├── LICENSE ├── OTVM-teaser.jpg ├── README.md ├── config.py ├── dataset.py ├── demo └── dove │ ├── frames │ ├── 00000.jpg │ ├── 00001.jpg │ ├── 00002.jpg │ ├── 00003.jpg │ ├── 00004.jpg │ ├── 00005.jpg │ ├── 00006.jpg │ ├── 00007.jpg │ ├── 00008.jpg │ ├── 00009.jpg │ └── 00010.jpg │ └── trimap │ └── 00000.png ├── eval.py ├── helpers.py ├── models ├── __init__.py ├── alpha │ ├── FBA │ │ ├── layers_WS.py │ │ ├── models.py │ │ ├── resnet_GN_WS.py │ │ └── resnet_bn.py │ ├── __init__.py │ ├── common.py │ └── model.py └── trimap │ ├── STM.py │ ├── __init__.py │ └── model.py ├── scripts ├── eval_s4.sh ├── eval_s4_demo.sh ├── train_s1_alpha.sh ├── train_s1_trimap.sh ├── train_s2_alpha.sh ├── train_s3.sh └── train_s4.sh ├── train.py ├── train_s1_trimap.py └── utils ├── loss_func.py ├── optimizer.py ├── tmp ├── __init__.py ├── augmentation.py ├── closed_form_matting │ ├── .travis.yml │ ├── LICENSE │ ├── README.md │ ├── closed_form_matting │ │ ├── __init__.py │ │ ├── closed_form_matting.py │ │ └── solve_foreground_background.py │ ├── requirements.txt │ ├── setup.py │ ├── test_matting.py │ └── testdata │ │ ├── matlab_alpha.png │ │ ├── matlab_background.png │ │ ├── matlab_foreground.png │ │ ├── output_alpha.png │ │ ├── output_background.png │ │ ├── output_foreground.png │ │ ├── scribbles.png │ │ ├── source.png │ │ └── trimap.png ├── group_weight.py └── metric.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial-ShareAlike 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. 
More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More_considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International 58 | Public License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial-ShareAlike 4.0 International Public License 63 | ("Public License"). To the extent this Public License may be 64 | interpreted as a contract, You are granted the Licensed Rights in 65 | consideration of Your acceptance of these terms and conditions, and the 66 | Licensor grants You such rights in consideration of benefits the 67 | Licensor receives from making the Licensed Material available under 68 | these terms and conditions. 69 | 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. BY-NC-SA Compatible License means a license listed at 88 | creativecommons.org/compatiblelicenses, approved by Creative 89 | Commons as essentially the equivalent of this Public License. 90 | 91 | d. Copyright and Similar Rights means copyright and/or similar rights 92 | closely related to copyright including, without limitation, 93 | performance, broadcast, sound recording, and Sui Generis Database 94 | Rights, without regard to how the rights are labeled or 95 | categorized. For purposes of this Public License, the rights 96 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 97 | Rights. 98 | 99 | e. 
Effective Technological Measures means those measures that, in the 100 | absence of proper authority, may not be circumvented under laws 101 | fulfilling obligations under Article 11 of the WIPO Copyright 102 | Treaty adopted on December 20, 1996, and/or similar international 103 | agreements. 104 | 105 | f. Exceptions and Limitations means fair use, fair dealing, and/or 106 | any other exception or limitation to Copyright and Similar Rights 107 | that applies to Your use of the Licensed Material. 108 | 109 | g. License Elements means the license attributes listed in the name 110 | of a Creative Commons Public License. The License Elements of this 111 | Public License are Attribution, NonCommercial, and ShareAlike. 112 | 113 | h. Licensed Material means the artistic or literary work, database, 114 | or other material to which the Licensor applied this Public 115 | License. 116 | 117 | i. Licensed Rights means the rights granted to You subject to the 118 | terms and conditions of this Public License, which are limited to 119 | all Copyright and Similar Rights that apply to Your use of the 120 | Licensed Material and that the Licensor has authority to license. 121 | 122 | j. Licensor means the individual(s) or entity(ies) granting rights 123 | under this Public License. 124 | 125 | k. NonCommercial means not primarily intended for or directed towards 126 | commercial advantage or monetary compensation. For purposes of 127 | this Public License, the exchange of the Licensed Material for 128 | other material subject to Copyright and Similar Rights by digital 129 | file-sharing or similar means is NonCommercial provided there is 130 | no payment of monetary compensation in connection with the 131 | exchange. 132 | 133 | l. Share means to provide material to the public by any means or 134 | process that requires permission under the Licensed Rights, such 135 | as reproduction, public display, public performance, distribution, 136 | dissemination, communication, or importation, and to make material 137 | available to the public including in ways that members of the 138 | public may access the material from a place and at a time 139 | individually chosen by them. 140 | 141 | m. Sui Generis Database Rights means rights other than copyright 142 | resulting from Directive 96/9/EC of the European Parliament and of 143 | the Council of 11 March 1996 on the legal protection of databases, 144 | as amended and/or succeeded, as well as other essentially 145 | equivalent rights anywhere in the world. 146 | 147 | n. You means the individual or entity exercising the Licensed Rights 148 | under this Public License. Your has a corresponding meaning. 149 | 150 | 151 | Section 2 -- Scope. 152 | 153 | a. License grant. 154 | 155 | 1. Subject to the terms and conditions of this Public License, 156 | the Licensor hereby grants You a worldwide, royalty-free, 157 | non-sublicensable, non-exclusive, irrevocable license to 158 | exercise the Licensed Rights in the Licensed Material to: 159 | 160 | a. reproduce and Share the Licensed Material, in whole or 161 | in part, for NonCommercial purposes only; and 162 | 163 | b. produce, reproduce, and Share Adapted Material for 164 | NonCommercial purposes only. 165 | 166 | 2. Exceptions and Limitations. For the avoidance of doubt, where 167 | Exceptions and Limitations apply to Your use, this Public 168 | License does not apply, and You do not need to comply with 169 | its terms and conditions. 170 | 171 | 3. Term. 
The term of this Public License is specified in Section 172 | 6(a). 173 | 174 | 4. Media and formats; technical modifications allowed. The 175 | Licensor authorizes You to exercise the Licensed Rights in 176 | all media and formats whether now known or hereafter created, 177 | and to make technical modifications necessary to do so. The 178 | Licensor waives and/or agrees not to assert any right or 179 | authority to forbid You from making technical modifications 180 | necessary to exercise the Licensed Rights, including 181 | technical modifications necessary to circumvent Effective 182 | Technological Measures. For purposes of this Public License, 183 | simply making modifications authorized by this Section 2(a) 184 | (4) never produces Adapted Material. 185 | 186 | 5. Downstream recipients. 187 | 188 | a. Offer from the Licensor -- Licensed Material. Every 189 | recipient of the Licensed Material automatically 190 | receives an offer from the Licensor to exercise the 191 | Licensed Rights under the terms and conditions of this 192 | Public License. 193 | 194 | b. Additional offer from the Licensor -- Adapted Material. 195 | Every recipient of Adapted Material from You 196 | automatically receives an offer from the Licensor to 197 | exercise the Licensed Rights in the Adapted Material 198 | under the conditions of the Adapter's License You apply. 199 | 200 | c. No downstream restrictions. You may not offer or impose 201 | any additional or different terms or conditions on, or 202 | apply any Effective Technological Measures to, the 203 | Licensed Material if doing so restricts exercise of the 204 | Licensed Rights by any recipient of the Licensed 205 | Material. 206 | 207 | 6. No endorsement. Nothing in this Public License constitutes or 208 | may be construed as permission to assert or imply that You 209 | are, or that Your use of the Licensed Material is, connected 210 | with, or sponsored, endorsed, or granted official status by, 211 | the Licensor or others designated to receive attribution as 212 | provided in Section 3(a)(1)(A)(i). 213 | 214 | b. Other rights. 215 | 216 | 1. Moral rights, such as the right of integrity, are not 217 | licensed under this Public License, nor are publicity, 218 | privacy, and/or other similar personality rights; however, to 219 | the extent possible, the Licensor waives and/or agrees not to 220 | assert any such rights held by the Licensor to the limited 221 | extent necessary to allow You to exercise the Licensed 222 | Rights, but not otherwise. 223 | 224 | 2. Patent and trademark rights are not licensed under this 225 | Public License. 226 | 227 | 3. To the extent possible, the Licensor waives any right to 228 | collect royalties from You for the exercise of the Licensed 229 | Rights, whether directly or through a collecting society 230 | under any voluntary or waivable statutory or compulsory 231 | licensing scheme. In all other cases the Licensor expressly 232 | reserves any right to collect such royalties, including when 233 | the Licensed Material is used other than for NonCommercial 234 | purposes. 235 | 236 | 237 | Section 3 -- License Conditions. 238 | 239 | Your exercise of the Licensed Rights is expressly made subject to the 240 | following conditions. 241 | 242 | a. Attribution. 243 | 244 | 1. If You Share the Licensed Material (including in modified 245 | form), You must: 246 | 247 | a. retain the following if it is supplied by the Licensor 248 | with the Licensed Material: 249 | 250 | i. 
identification of the creator(s) of the Licensed 251 | Material and any others designated to receive 252 | attribution, in any reasonable manner requested by 253 | the Licensor (including by pseudonym if 254 | designated); 255 | 256 | ii. a copyright notice; 257 | 258 | iii. a notice that refers to this Public License; 259 | 260 | iv. a notice that refers to the disclaimer of 261 | warranties; 262 | 263 | v. a URI or hyperlink to the Licensed Material to the 264 | extent reasonably practicable; 265 | 266 | b. indicate if You modified the Licensed Material and 267 | retain an indication of any previous modifications; and 268 | 269 | c. indicate the Licensed Material is licensed under this 270 | Public License, and include the text of, or the URI or 271 | hyperlink to, this Public License. 272 | 273 | 2. You may satisfy the conditions in Section 3(a)(1) in any 274 | reasonable manner based on the medium, means, and context in 275 | which You Share the Licensed Material. For example, it may be 276 | reasonable to satisfy the conditions by providing a URI or 277 | hyperlink to a resource that includes the required 278 | information. 279 | 3. If requested by the Licensor, You must remove any of the 280 | information required by Section 3(a)(1)(A) to the extent 281 | reasonably practicable. 282 | 283 | b. ShareAlike. 284 | 285 | In addition to the conditions in Section 3(a), if You Share 286 | Adapted Material You produce, the following conditions also apply. 287 | 288 | 1. The Adapter's License You apply must be a Creative Commons 289 | license with the same License Elements, this version or 290 | later, or a BY-NC-SA Compatible License. 291 | 292 | 2. You must include the text of, or the URI or hyperlink to, the 293 | Adapter's License You apply. You may satisfy this condition 294 | in any reasonable manner based on the medium, means, and 295 | context in which You Share Adapted Material. 296 | 297 | 3. You may not offer or impose any additional or different terms 298 | or conditions on, or apply any Effective Technological 299 | Measures to, Adapted Material that restrict exercise of the 300 | rights granted under the Adapter's License You apply. 301 | 302 | 303 | Section 4 -- Sui Generis Database Rights. 304 | 305 | Where the Licensed Rights include Sui Generis Database Rights that 306 | apply to Your use of the Licensed Material: 307 | 308 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 309 | to extract, reuse, reproduce, and Share all or a substantial 310 | portion of the contents of the database for NonCommercial purposes 311 | only; 312 | 313 | b. if You include all or a substantial portion of the database 314 | contents in a database in which You have Sui Generis Database 315 | Rights, then the database in which You have Sui Generis Database 316 | Rights (but not its individual contents) is Adapted Material, 317 | including for purposes of Section 3(b); and 318 | 319 | c. You must comply with the conditions in Section 3(a) if You Share 320 | all or a substantial portion of the contents of the database. 321 | 322 | For the avoidance of doubt, this Section 4 supplements and does not 323 | replace Your obligations under this Public License where the Licensed 324 | Rights include other Copyright and Similar Rights. 325 | 326 | 327 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 328 | 329 | a. 
UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 330 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 331 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 332 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 333 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 334 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 335 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 336 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 337 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 338 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 339 | 340 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 341 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 342 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 343 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 344 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 345 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 346 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 347 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 348 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 349 | 350 | c. The disclaimer of warranties and limitation of liability provided 351 | above shall be interpreted in a manner that, to the extent 352 | possible, most closely approximates an absolute disclaimer and 353 | waiver of all liability. 354 | 355 | 356 | Section 6 -- Term and Termination. 357 | 358 | a. This Public License applies for the term of the Copyright and 359 | Similar Rights licensed here. However, if You fail to comply with 360 | this Public License, then Your rights under this Public License 361 | terminate automatically. 362 | 363 | b. Where Your right to use the Licensed Material has terminated under 364 | Section 6(a), it reinstates: 365 | 366 | 1. automatically as of the date the violation is cured, provided 367 | it is cured within 30 days of Your discovery of the 368 | violation; or 369 | 370 | 2. upon express reinstatement by the Licensor. 371 | 372 | For the avoidance of doubt, this Section 6(b) does not affect any 373 | right the Licensor may have to seek remedies for Your violations 374 | of this Public License. 375 | 376 | c. For the avoidance of doubt, the Licensor may also offer the 377 | Licensed Material under separate terms or conditions or stop 378 | distributing the Licensed Material at any time; however, doing so 379 | will not terminate this Public License. 380 | 381 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 382 | License. 383 | 384 | 385 | Section 7 -- Other Terms and Conditions. 386 | 387 | a. The Licensor shall not be bound by any additional or different 388 | terms or conditions communicated by You unless expressly agreed. 389 | 390 | b. Any arrangements, understandings, or agreements regarding the 391 | Licensed Material not stated herein are separate from and 392 | independent of the terms and conditions of this Public License. 393 | 394 | 395 | Section 8 -- Interpretation. 396 | 397 | a. For the avoidance of doubt, this Public License does not, and 398 | shall not be interpreted to, reduce, limit, restrict, or impose 399 | conditions on any use of the Licensed Material that could lawfully 400 | be made without permission under this Public License. 401 | 402 | b. 
To the extent possible, if any provision of this Public License is 403 | deemed unenforceable, it shall be automatically reformed to the 404 | minimum extent necessary to make it enforceable. If the provision 405 | cannot be reformed, it shall be severed from this Public License 406 | without affecting the enforceability of the remaining terms and 407 | conditions. 408 | 409 | c. No term or condition of this Public License will be waived and no 410 | failure to comply consented to unless expressly agreed to by the 411 | Licensor. 412 | 413 | d. Nothing in this Public License constitutes or may be interpreted 414 | as a limitation upon, or waiver of, any privileges and immunities 415 | that apply to the Licensor or You, including from the legal 416 | processes of any jurisdiction or authority. 417 | 418 | ======================================================================= 419 | 420 | Creative Commons is not a party to its public 421 | licenses. Notwithstanding, Creative Commons may elect to apply one of 422 | its public licenses to material it publishes and in those instances 423 | will be considered the “Licensor.” The text of the Creative Commons 424 | public licenses is dedicated to the public domain under the CC0 Public 425 | Domain Dedication. Except for the limited purpose of indicating that 426 | material is shared under a Creative Commons public license or as 427 | otherwise permitted by the Creative Commons policies published at 428 | creativecommons.org/policies, Creative Commons does not authorize the 429 | use of the trademark "Creative Commons" or any other trademark or logo 430 | of Creative Commons without its prior written consent including, 431 | without limitation, in connection with any unauthorized modifications 432 | to any of its public licenses or any other arrangements, 433 | understandings, or agreements concerning use of licensed material. For 434 | the avoidance of doubt, this paragraph does not form part of the 435 | public licenses. 436 | 437 | Creative Commons may be contacted at creativecommons.org. -------------------------------------------------------------------------------- /OTVM-teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/OTVM-teaser.jpg -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## One-Trimap Video Matting (ECCV 2022)
Hongje Seong, Seoung Wug Oh, Brian Price, Euntai Kim, Joon-Young Lee 2 | 3 | [[Paper]](https://arxiv.org/abs/2207.13353) [[Demo video]](https://youtu.be/qkda4fHSyQE) 4 | 5 | Official PyTorch implementation of the ECCV 2022 paper, "One-Trimap Video Matting". 6 | 7 | ![Teaser image](./OTVM-teaser.jpg) 8 | 9 | 10 | ## Environments 11 | - Ubuntu 18.04 12 | - Python 3.8 13 | - PyTorch 1.8.2 14 | - CUDA 10.2 15 | 16 | ### Environment setting 17 | ```bash 18 | conda create -n otvm python=3.8 19 | conda activate otvm 20 | conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch-lts 21 | pip install opencv-contrib-python scikit-image scipy tqdm imgaug yacs albumentations 22 | ``` 23 | 24 | ## Dataset 25 | To train OTVM, you need to prepare the [AIM](https://sites.google.com/view/deepimagematting) and [VideoMatting108](https://github.com/yunkezhang/TCVOM) datasets. 26 | ``` 27 | PATH/TO/DATASET 28 | ├── Combined_Dataset 29 | │   ├── Adobe Deep Image Mattng Dataset License Agreement.pdf 30 | │   ├── README.txt 31 | │   ├── Test_set 32 | │ │   ├── Adobe-licensed images 33 | │ │   └── ... 34 | │   └── Training_set 35 | │    ├── Adobe-licensed images 36 | │    └── ... 37 | └── VideoMatting108 38 |    ├── BG_done2 39 | │   ├── airport 40 | │   └── ... 41 |    ├── FG_done 42 | │   ├── animal_still 43 | │   └── ... 44 |    ├── flow_png_val 45 | │   ├── animal_still 46 | │   └── ... 47 |    ├── frame_corr.json 48 |    ├── train_videos_subset.txt 49 |    ├── train_videos.txt 50 |    ├── val_videos_subset.txt 51 |    └── val_videos.txt 52 | 53 | ``` 54 | 55 | ## Training 56 | ### Download pre-trained weights 57 | Download the pre-trained weights from [here](https://drive.google.com/drive/folders/1La53_oYZjhmcd2pfPPlnibLBPE12mc6b) and put them in the `weights/` directory. 58 | ```bash 59 | mkdir weights 60 | mv STM_weights.pth weights/ 61 | mv FBA.pth weights/ 62 | mv s1_OTVM_trimap.pth weights/ 63 | mv s1_OTVM_alpha.pth weights/ 64 | mv s2_OTVM_alpha.pth weights/ 65 | mv s3_OTVM.pth weights/ 66 | mv s4_OTVM.pth weights/ 67 | ``` 68 | Note: Initial weights of the trimap propagation and alpha prediction networks were taken from [STM](https://github.com/seoungwugoh/STM) and [FBA](https://github.com/MarcoForte/FBA_Matting), respectively. 79 | ### Change DATASET.PATH in config.py 80 | ```bash 81 | vim config.py 82 | 83 | # Change below path 84 | _C.DATASET.PATH = 'PATH/TO/DATASET' 85 | ``` 86 |
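If you prefer not to edit `config.py` in place, the same default can also be overridden programmatically through yacs. This is only a minimal sketch for a custom Python launcher (the provided `scripts/*.sh` read the value baked into `config.py`), and `/path/to/datasets` is a placeholder path:

```python
# Minimal sketch: override DATASET.PATH at run time via yacs instead of
# editing config.py. Assumes you drive training/evaluation from your own script.
from config import get_cfg_defaults

cfg = get_cfg_defaults()                                    # repo defaults (yacs CfgNode)
cfg.merge_from_list(['DATASET.PATH', '/path/to/datasets'])  # placeholder path
cfg.freeze()                                                # make the config read-only
print(cfg.DATASET.PATH)
```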
87 | ### Stage-wise Training 88 | ```bash 89 | # options: scripts/train_XXX.sh [GPUs] 90 | bash scripts/train_s1_trimap.sh 0,1,2,3 91 | bash scripts/train_s1_alpha.sh 0,1,2,3 92 | bash scripts/train_s2_alpha.sh 0,1,2,3 93 | bash scripts/train_s3.sh 0,1,2,3 94 | bash scripts/train_s4.sh 0,1,2,3 95 | ``` 96 | 97 | ## Inference (VideoMatting108 dataset) 98 | ```bash 99 | # options: scripts/eval_s4.sh [GPU] 100 | bash scripts/eval_s4.sh 0 101 | ``` 102 | 103 | ## Inference (custom dataset) 104 | ```bash 105 | # options: scripts/eval_s4_demo.sh [GPU] 106 | # The results will be generated in: ./demo_results 107 | bash scripts/eval_s4_demo.sh 0 108 | ``` 109 | 110 | ## Bibtex 111 | ``` 112 | @inproceedings{seong2022one, 113 | title={One-Trimap Video Matting}, 114 | author={Seong, Hongje and Oh, Seoung Wug and Price, Brian and Kim, Euntai and Lee, Joon-Young}, 115 | booktitle={European Conference on Computer Vision}, 116 | year={2022} 117 | } 118 | ``` 119 | 120 | 121 | ## Terms of Use 122 | This software is for non-commercial use only. 123 | The source code is released under the Attribution-NonCommercial-ShareAlike (CC BY-NC-SA) License 124 | (see [this](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode) for details). 125 | 126 | [![Creative Commons License](https://i.creativecommons.org/l/by-nc-sa/4.0/88x31.png)](http://creativecommons.org/licenses/by-nc-sa/4.0/) 127 | 128 | ## Acknowledgments 129 | This code is based on TCVOM (ACM MM 2021): [[link](https://github.com/yunkezhang/TCVOM)] 130 | 131 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | 4 | _C = CN() 5 | _C.SYSTEM = CN() 6 | # Number of workers for doing things 7 | _C.SYSTEM.NUM_WORKERS = 8 8 | # Specific random seed, -1 for random. 9 | _C.SYSTEM.RANDOM_SEED = 111 10 | _C.SYSTEM.OUTDIR = 'train_log' 11 | _C.SYSTEM.CUDNN_BENCHMARK = True 12 | _C.SYSTEM.CUDNN_DETERMINISTIC = False 13 | _C.SYSTEM.CUDNN_ENABLED = True 14 | _C.SYSTEM.TESTMODE = False 15 | 16 | _C.DATASET = CN() 17 | # dataset path 18 | _C.DATASET.PATH = 'PATH/TO/DATASET' 19 | _C.DATASET.MIN_EDGE_LENGTH = 1088 20 | 21 | _C.TEST = CN() 22 | _C.TEST.MEMORY_MAX_NUM = 5 # 2: First&Prev, 0: First, 1: Prev, 3~: Multiple 23 | _C.TEST.MEMORY_SKIP_FRAME = 10 24 | 25 | _C.TRAIN = CN() 26 | _C.TRAIN.STAGE = 1 27 | _C.TRAIN.BATCH_SIZE = 4 28 | _C.TRAIN.BASE_LR = 1e-5 29 | _C.TRAIN.LR_STRATEGY = 'stair' # 'poly', 'const' or 'stair' 30 | _C.TRAIN.WEIGHT_DECAY = 1e-4 31 | _C.TRAIN.TRAIN_INPUT_SIZE = (320,320) 32 | _C.TRAIN.FRAME_NUM = 3 33 | _C.TRAIN.FREEZE_BN = True 34 | 35 | # optimizer type 36 | _C.TRAIN.OPTIMIZER = 'radam' # adam, radam 37 | _C.TRAIN.TOTAL_EPOCHS = 200 38 | _C.TRAIN.IMAGE_FREQ = -1 39 | _C.TRAIN.SAVE_EVERY_EPOCH = 20 40 | 41 | _C.ALPHA = CN() 42 | _C.ALPHA.MODEL = 'fba' 43 | 44 | 45 | def get_cfg_defaults(): 46 | """Get a yacs CfgNode object with default values for OTVM.""" 47 | # Return a clone so that the defaults will not be altered 48 | # This is for the "local variable" use pattern 49 | return _C.clone() 50 | 51 | # Alternatively, provide a way to import the defaults as 52 | # a global singleton: 53 | # cfg = _C # users can `from config import cfg` -------------------------------------------------------------------------------- /demo/dove/frames/00000.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/demo/dove/frames/00000.jpg -------------------------------------------------------------------------------- /demo/dove/frames/00001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/demo/dove/frames/00001.jpg -------------------------------------------------------------------------------- /demo/dove/frames/00002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/demo/dove/frames/00002.jpg -------------------------------------------------------------------------------- /demo/dove/frames/00003.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/demo/dove/frames/00003.jpg
-------------------------------------------------------------------------------- /demo/dove/frames/00004.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/demo/dove/frames/00004.jpg -------------------------------------------------------------------------------- /demo/dove/frames/00005.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/demo/dove/frames/00005.jpg -------------------------------------------------------------------------------- /demo/dove/frames/00006.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/demo/dove/frames/00006.jpg -------------------------------------------------------------------------------- /demo/dove/frames/00007.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/demo/dove/frames/00007.jpg -------------------------------------------------------------------------------- /demo/dove/frames/00008.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/demo/dove/frames/00008.jpg -------------------------------------------------------------------------------- /demo/dove/frames/00009.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/demo/dove/frames/00009.jpg -------------------------------------------------------------------------------- /demo/dove/frames/00010.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/demo/dove/frames/00010.jpg -------------------------------------------------------------------------------- /demo/dove/trimap/00000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/demo/dove/trimap/00000.png -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import time 4 | import timeit 5 | import cv2 6 | import torch 7 | import torch.backends.cudnn as cudnn 8 | import torch.nn.functional as F 9 | from torch import nn 10 | from torchvision.utils import save_image 11 | import tqdm 12 | 13 | from config import get_cfg_defaults 14 | from dataset import EvalDataset, VideoMatting108_Test, Demo_Test 15 | from helpers import * 16 | 17 | torch.set_grad_enabled(False) 18 | 19 | EPS = 0 20 | 21 | def parse_args(): 22 | parser = argparse.ArgumentParser(description='Train network') 23 | 24 | parser.add_argument("--gpu", type=str, default='0') 25 | parser.add_argument('--trimap', default='medium', choices=['narrow', 'medium', 'wide']) 26 | parser.add_argument("--viz", action='store_true') 27 | parser.add_argument("--demo", action='store_true') 28 | 29 | args = parser.parse_args() 30 | 31 | cfg = get_cfg_defaults() 32 | cfg.TRAIN.STAGE = 4 33 | 34 | if 
args.demo: 35 | cfg.SYSTEM.OUTDIR = './demo_results' 36 | cfg.DATASET.PATH = './demo' 37 | 38 | return args, cfg 39 | 40 | 41 | def main(cfg, args, GPU): 42 | os.environ['CUDA_VISIBLE_DEVICES'] = GPU 43 | if torch.cuda.is_available(): 44 | print('using Cuda devices, num:', torch.cuda.device_count()) 45 | 46 | MODEL = get_model_name(cfg) 47 | random_seed = cfg.SYSTEM.RANDOM_SEED 48 | output_dir = os.path.join(cfg.SYSTEM.OUTDIR, 'alpha') 49 | start = timeit.default_timer() 50 | cudnn.benchmark = False 51 | cudnn.deterministic = cfg.SYSTEM.CUDNN_DETERMINISTIC 52 | cudnn.enabled = cfg.SYSTEM.CUDNN_ENABLED 53 | if random_seed > 0: 54 | import random 55 | print('Seeding with', random_seed) 56 | random.seed(random_seed) 57 | torch.manual_seed(random_seed) 58 | 59 | if args.demo: 60 | outdir_tail = MODEL 61 | else: 62 | outdir_tail = os.path.join(args.trimap, MODEL) 63 | alpha_outdir = os.path.join(output_dir, 'test', outdir_tail) 64 | viz_outdir_img = os.path.join(output_dir, 'viz', 'img', outdir_tail) 65 | viz_outdir_vid = os.path.join(output_dir, 'viz', 'vid', outdir_tail) 66 | 67 | if args.trimap == 'narrow': 68 | dilate_kernel = 5 # width: 11 69 | elif args.trimap == 'medium': 70 | dilate_kernel = 12 # width: 25 71 | elif args.trimap == 'wide': 72 | dilate_kernel = 20 # width: 41 73 | 74 | model_trimap = get_model_trimap(cfg, mode='Test', dilate_kernel=dilate_kernel) 75 | model = get_model_alpha(cfg, model_trimap, mode='Test', dilate_kernel=dilate_kernel) 76 | 77 | load_ckpt = os.path.join('weights', '{:s}.pth'.format(MODEL)) 78 | dct = torch.load(load_ckpt, map_location=torch.device('cpu')) 79 | model.load_state_dict(dct) 80 | model = nn.DataParallel(model.cuda()) 81 | 82 | 83 | if args.demo: 84 | valid_dataset = Demo_Test(data_root=cfg.DATASET.PATH) 85 | else: 86 | valid_dataset = VideoMatting108_Test( 87 | data_root=cfg.DATASET.PATH, 88 | mode='val', 89 | ) 90 | with torch.no_grad(): 91 | eval(args, cfg, valid_dataset, model, alpha_outdir, viz_outdir_img, viz_outdir_vid, args.viz) 92 | 93 | end = timeit.default_timer() 94 | print('done | Total time: {}'.format(format_time(end-start))) 95 | 96 | def write_image(outdir, out, filename, max_batch=4): 97 | with torch.no_grad(): 98 | scaled_imgs, tri_pred, tri_gt, alphas, scaled_gts, comps = out 99 | b, s, _, h, w = scaled_imgs.shape 100 | alphas = alphas.expand(-1,-1,3,-1,-1) 101 | scaled_gts = scaled_gts.expand(-1,-1,3,-1,-1) 102 | 103 | b = max_batch if b > max_batch else b 104 | img_list = list() 105 | img_list.append(scaled_imgs[:max_batch].reshape(b*s, 3, h, w)) 106 | img_list.append(comps[:max_batch].reshape(b*s, 3, h, w)) 107 | img_list.append(tri_gt[:max_batch].reshape(b*s, 3, h, w)) 108 | img_list.append(scaled_gts[:max_batch].reshape(b*s, 3, h, w)) 109 | img_list.append(tri_pred[:max_batch].reshape(b*s, 3, h, w)) 110 | img_list.append(alphas[:max_batch].reshape(b*s, 3, h, w)) 111 | imgs = torch.cat(img_list, dim=0).reshape(-1, 3, h, w) 112 | 113 | imgs = F.interpolate(imgs, size=(h//2, w//2), mode='bilinear', align_corners=False) 114 | 115 | save_image(imgs, outdir%(filename), nrow=int(s*b*2)) 116 | 117 | def eval(args, cfg, valid_dataset, model, alpha_outdir, viz_outdir_img, viz_outdir_vid, VIZ): 118 | model.eval() 119 | 120 | for i_iter, (data_name, data_root, FG, BG, a, tri, seq_name) in enumerate(valid_dataset): 121 | if cfg.SYSTEM.TESTMODE: 122 | if i_iter not in [0, len(valid_dataset)-1]: 123 | continue 124 | torch.cuda.empty_cache() 125 | num_frames = 1 126 | eval_sequence = EvalDataset( 127 | data_name=data_name, 128 | 
data_root=data_root, 129 | FG=FG, 130 | BG=BG, 131 | a=a, 132 | tri_gt=tri, # GT trimap 133 | trimap=None, 134 | num_frames=num_frames, 135 | ) 136 | eval_loader = torch.utils.data.DataLoader( 137 | eval_sequence, 138 | batch_size=1, 139 | # num_workers=cfg.SYSTEM.NUM_WORKERS, 140 | num_workers=0, 141 | pin_memory=False, 142 | drop_last=False, 143 | shuffle=False, 144 | sampler=None) 145 | 146 | print('[{}/{}] Set FIXED dilate of unknown region: [{}]'.format(i_iter, len(valid_dataset), args.trimap)) 147 | 148 | save_path = os.path.join(alpha_outdir, 'pred', seq_name) 149 | os.makedirs(save_path, exist_ok=True) 150 | if VIZ: 151 | visualization_path_img = os.path.join(viz_outdir_img, 'viz', seq_name) 152 | visualization_path_vid = os.path.join(viz_outdir_vid, 'viz') 153 | os.makedirs(visualization_path_img, exist_ok=True) 154 | os.makedirs(visualization_path_vid, exist_ok=True) 155 | 156 | iterations = tqdm.tqdm(eval_loader) 157 | for i_seq, dp in enumerate(iterations): 158 | if cfg.SYSTEM.TESTMODE: 159 | if i_seq > 10: 160 | break 161 | 162 | def handle_batch(dp, first_frame, last_frame, memorize, max_memory_num, large_input): 163 | fg, bg, a, eps, tri_gt, tri, _, filename = dp # [B, 3, 3 or 1, H, W] 164 | 165 | if tri.dim() == 1: 166 | tri = None 167 | if tri_gt.dim() == 1: 168 | tri_gt = None 169 | 170 | out = model(a, fg, bg, tri=tri, tri_gt=tri_gt, 171 | first_frame=first_frame, 172 | last_frame=last_frame, 173 | memorize=memorize, 174 | max_memory_num=max_memory_num, 175 | large_input=large_input,) 176 | return out, filename[0] 177 | 178 | first_frame = (i_seq==0) 179 | last_frame = (i_seq==(len(iterations)-1)) 180 | memorize = False 181 | MEMORY_SKIP_FRAME = cfg.TEST.MEMORY_SKIP_FRAME 182 | MEMORY_MAX_NUM = cfg.TEST.MEMORY_MAX_NUM 183 | large_input = False 184 | if min(dp[0].shape[-2:]) > 1100: 185 | MEMORY_SKIP_FRAME = int(MEMORY_SKIP_FRAME * 2) 186 | MEMORY_MAX_NUM = int(MEMORY_MAX_NUM / 2) 187 | large_input = True 188 | if MEMORY_SKIP_FRAME > 2: 189 | memorize = (i_seq % MEMORY_SKIP_FRAME) == 0 190 | max_memory_num = MEMORY_MAX_NUM 191 | 192 | if first_frame: 193 | print('[{}/{}] {} | {} | Large input: {}'.format(i_iter, len(valid_dataset), seq_name, dp[0].shape[-2:], large_input)) 194 | 195 | torch.cuda.synchronize() 196 | out, filename = handle_batch(dp, first_frame, last_frame, memorize, max_memory_num, large_input,) 197 | torch.cuda.synchronize() 198 | 199 | scaled_imgs, tri_pred, tri_gt, alphas, scaled_gts = out 200 | 201 | green_bg = torch.zeros_like(scaled_imgs) 202 | green_bg[:,:,1] = 1. 203 | comps = scaled_imgs * alphas + green_bg * (1. 
- alphas) 204 | 205 | if VIZ: 206 | frame_path = os.path.join(visualization_path_img, 'f%d.jpg') 207 | else: 208 | frame_path = None 209 | alpha_pred_img = (alphas*255).byte().cpu().squeeze(0).squeeze(0).squeeze(0).numpy() 210 | filename_for_save = os.path.splitext(filename)[0]+'.png' 211 | 212 | def write_result_images(alpha_pred_img, path, VIZ, frame_path, vis_out, i_seq): 213 | if VIZ: 214 | write_image(frame_path, 215 | vis_out, 216 | i_seq) 217 | cv2.imwrite(path, alpha_pred_img) 218 | 219 | write_result_images(alpha_pred_img, 220 | os.path.join(save_path, filename_for_save), 221 | VIZ, 222 | frame_path, 223 | # [scaled_imgs, tri_pred, tri_gt, alphas, scaled_gts, comps], 224 | [scaled_imgs.cpu(), tri_pred.cpu(), tri_gt.cpu(), alphas.cpu(), scaled_gts.cpu(), comps.cpu()], 225 | i_seq) 226 | 227 | 228 | torch.cuda.synchronize() 229 | 230 | if VIZ: 231 | if '/' in seq_name: 232 | vid_name = seq_name.split('/') 233 | vid_name = '_'.join(vid_name) 234 | else: 235 | vid_name = seq_name 236 | vid_path = os.path.join(visualization_path_vid, '{}.mp4'.format(vid_name)) 237 | 238 | def make_viz_video(frame_path, vid_path): 239 | os.system('ffmpeg -framerate 10 -i {} {} -nostats -loglevel 0 -y'.format(frame_path, vid_path)) 240 | time.sleep(10) # wait 10 seconds 241 | 242 | make_viz_video(frame_path, vid_path) 243 | 244 | if __name__ == "__main__": 245 | args, cfg = parse_args() 246 | main(cfg, args, args.gpu) 247 | -------------------------------------------------------------------------------- /helpers.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | #torch 3 | import torch 4 | import torch.nn.functional as F 5 | import torch.distributed as torch_dist 6 | 7 | import numpy as np 8 | import time 9 | import os 10 | import logging 11 | from pathlib import Path 12 | from importlib import reload 13 | import sys 14 | 15 | def ToCuda(xs): 16 | if torch.cuda.is_available(): 17 | if isinstance(xs, list) or isinstance(xs, tuple): 18 | return [x.cuda() for x in xs] 19 | else: 20 | return xs.cuda() 21 | else: 22 | return xs 23 | 24 | 25 | def pad_divide_by(in_list, d, in_size): 26 | out_list = [] 27 | h, w = in_size 28 | if h % d > 0: 29 | new_h = h + d - h % d 30 | else: 31 | new_h = h 32 | if w % d > 0: 33 | new_w = w + d - w % d 34 | else: 35 | new_w = w 36 | lh, uh = int((new_h-h) / 2), int(new_h-h) - int((new_h-h) / 2) 37 | lw, uw = int((new_w-w) / 2), int(new_w-w) - int((new_w-w) / 2) 38 | pad_array = (int(lw), int(uw), int(lh), int(uh)) 39 | for inp in in_list: 40 | out_list.append(F.pad(inp, pad_array)) 41 | return out_list, pad_array 42 | 43 | 44 | 45 | def overlay_davis(image,mask,colors=[255,0,0],cscale=2,alpha=0.4): 46 | """ Overlay segmentation on top of RGB image. 
from davis official""" 47 | # import skimage 48 | from scipy.ndimage.morphology import binary_erosion, binary_dilation 49 | 50 | colors = np.reshape(colors, (-1, 3)) 51 | colors = np.atleast_2d(colors) * cscale 52 | 53 | im_overlay = image.copy() 54 | object_ids = np.unique(mask) 55 | 56 | for object_id in object_ids[1:]: 57 | # Overlay color on binary mask 58 | foreground = image*alpha + np.ones(image.shape)*(1-alpha) * np.array(colors[object_id]) 59 | binary_mask = mask == object_id 60 | 61 | # Compose image 62 | im_overlay[binary_mask] = foreground[binary_mask] 63 | 64 | # countours = skimage.morphology.binary.binary_dilation(binary_mask) - binary_mask 65 | countours = binary_dilation(binary_mask) ^ binary_mask 66 | # countours = cv2.dilate(binary_mask, cv2.getStructuringElement(cv2.MORPH_CROSS,(3,3))) - binary_mask 67 | im_overlay[countours,:] = 0 68 | 69 | return im_overlay.astype(image.dtype) 70 | 71 | 72 | def torch_barrier(): 73 | if torch_dist.is_available() and torch_dist.is_initialized(): 74 | torch_dist.barrier() 75 | 76 | def reduce_tensor(inp): 77 | """ 78 | Reduce the loss from all processes so that 79 | ALL PROCESSES has the averaged results. 80 | """ 81 | if torch_dist.is_initialized(): 82 | world_size = torch_dist.get_world_size() 83 | if world_size < 2: 84 | return inp 85 | with torch.no_grad(): 86 | reduced_inp = inp 87 | torch.distributed.all_reduce(reduced_inp) 88 | torch.distributed.barrier() 89 | return reduced_inp / world_size 90 | return inp 91 | 92 | def print_loss_dict(loss, save=None): 93 | s = '' 94 | for key in sorted(loss.keys()): 95 | s += '{}: {:.6f}\n'.format(key, loss[key]) 96 | print (s) 97 | if save is not None: 98 | with open(save, 'w') as f: 99 | f.write(s) 100 | 101 | class AverageMeter(object): 102 | """Computes and stores the average and current value""" 103 | 104 | def __init__(self): 105 | self.initialized = False 106 | self.val = None 107 | self.avg = None 108 | self.sum = None 109 | self.count = None 110 | 111 | def initialize(self, val, weight): 112 | self.val = val 113 | self.avg = val 114 | self.sum = val * weight 115 | self.count = weight 116 | self.initialized = True 117 | 118 | def update(self, val, weight=1): 119 | if not self.initialized: 120 | self.initialize(val, weight) 121 | else: 122 | self.add(val, weight) 123 | 124 | def add(self, val, weight): 125 | self.val = val 126 | self.sum += val * weight 127 | self.count += weight 128 | self.avg = self.sum / self.count 129 | 130 | def value(self): 131 | return self.val 132 | 133 | def average(self): 134 | return self.avg 135 | 136 | def create_logger(output_dir, cfg_name, phase='train'): 137 | root_output_dir = Path(output_dir) 138 | # set up logger 139 | if not root_output_dir.exists(): 140 | print('=> creating {}'.format(root_output_dir)) 141 | root_output_dir.mkdir() 142 | 143 | final_output_dir = root_output_dir / cfg_name 144 | 145 | print('=> creating {}'.format(final_output_dir)) 146 | final_output_dir.mkdir(parents=True, exist_ok=True) 147 | 148 | time_str = time.strftime('%Y-%m-%d-%H-%M') 149 | log_file = '{}_{}_{}.log'.format(cfg_name, time_str, phase) 150 | final_log_file = final_output_dir / log_file 151 | head = '%(asctime)-15s %(message)s' 152 | # reset logging 153 | logging.shutdown() 154 | reload(logging) 155 | logging.basicConfig(filename=str(final_log_file), 156 | format=head) 157 | logger = logging.getLogger() 158 | logger.setLevel(logging.INFO) 159 | console = logging.StreamHandler() 160 | logging.getLogger('').addHandler(console) 161 | 162 | return logger, 
str(final_output_dir) 163 | 164 | def poly_lr(optimizer, base_lr, max_iters, cur_iters, power=0.9): 165 | lr = base_lr*((1-float(cur_iters)/max_iters)**(power)) 166 | # optimizer.param_groups[0]['lr'] = lr 167 | for param_group in optimizer.param_groups: 168 | if 'lr_ratio' in param_group: 169 | param_group['lr'] = lr * param_group['lr_ratio'] 170 | else: 171 | param_group['lr'] = lr 172 | return lr 173 | 174 | def const_lr(optimizer, base_lr, max_iters, cur_iters): 175 | # optimizer.param_groups[0]['lr'] = base_lr 176 | for param_group in optimizer.param_groups: 177 | if 'lr_ratio' in param_group: 178 | param_group['lr'] = base_lr * param_group['lr_ratio'] 179 | else: 180 | param_group['lr'] = base_lr 181 | return base_lr 182 | 183 | def stair_lr(optimizer, base_lr, max_iters, cur_iters): 184 | # 0, 180 185 | ratios = [1, 0.1] 186 | progress = cur_iters / float(max_iters) 187 | if progress < 0.9: 188 | ratio = ratios[0] 189 | else: 190 | ratio = ratios[-1] 191 | lr = base_lr * ratio 192 | # optimizer.param_groups[0]['lr'] = lr 193 | for param_group in optimizer.param_groups: 194 | if 'lr_ratio' in param_group: 195 | param_group['lr'] = lr * param_group['lr_ratio'] 196 | else: 197 | param_group['lr'] = lr 198 | return lr 199 | 200 | def worker_init_fn(worker_id): 201 | np.random.seed(np.random.get_state()[1][0] + worker_id) 202 | 203 | STR_DICT = { 204 | 'poly': poly_lr, 205 | 'const': const_lr, 206 | 'stair': stair_lr 207 | } 208 | 209 | 210 | 211 | _, term_width = os.popen('stty size', 'r').read().split() 212 | term_width = int(term_width) 213 | 214 | TOTAL_BAR_LENGTH = 20. 215 | last_time = time.time() 216 | begin_time = last_time 217 | 218 | code_begin_time = time.time() 219 | memorize_iter_time = list() 220 | memorize_iter_time.append(code_begin_time) 221 | 222 | def progress_bar(current, total, current_epoch, start_epoch, end_epoch, mode=None, msg=None): 223 | # global last_time, begin_time, code_begin_time, runing_weight 224 | global last_time, begin_time, memorize_iter_time 225 | if current == 0: 226 | begin_time = time.time() # Reset for new bar. 227 | 228 | cur_len = int(TOTAL_BAR_LENGTH*current/total) 229 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1 230 | 231 | sys.stdout.write(' [') 232 | for i in range(cur_len): 233 | sys.stdout.write('=') 234 | sys.stdout.write('>') 235 | for i in range(rest_len): 236 | sys.stdout.write('.') 237 | sys.stdout.write(']') 238 | 239 | cur_time = time.time() 240 | step_time = cur_time - last_time 241 | last_time = cur_time 242 | tot_time = cur_time - begin_time 243 | 244 | L = [] 245 | L.append(' E: %d' % current_epoch) 246 | L.append(' | Step: %s' % format_time(step_time)) 247 | L.append(' | Tot: %s' % format_time(tot_time)) 248 | if mode: 249 | memorize_iter_num = 1000 250 | total_time_from_code_begin = time.time() 251 | memorize_iter_time.append(total_time_from_code_begin) 252 | if len(memorize_iter_time) > memorize_iter_num: 253 | memorize_iter_time.pop(0) 254 | remain_iters = ((end_epoch-current_epoch)*total) - (current+1) 255 | eta = (memorize_iter_time[-1] - memorize_iter_time[0]) / (len(memorize_iter_time) - 1) * remain_iters 256 | L.append(' | ETA: %s' % format_time(eta)) 257 | if msg: 258 | L.append(' | ' + msg) 259 | 260 | msg = ''.join(L) 261 | sys.stdout.write(msg) 262 | for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3): 263 | sys.stdout.write(' ') 264 | 265 | # Go back to the center of the bar. 
266 | for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2): 267 | sys.stdout.write('\b') 268 | sys.stdout.write(' %d/%d ' % (current+1, total)) 269 | 270 | if current < total-1: 271 | sys.stdout.write('\r') 272 | else: 273 | sys.stdout.write('\n') 274 | sys.stdout.flush() 275 | 276 | def format_time(seconds): 277 | days = int(seconds / 3600/24) 278 | seconds = seconds - days*3600*24 279 | hours = int(seconds / 3600) 280 | seconds = seconds - hours*3600 281 | minutes = int(seconds / 60) 282 | seconds = seconds - minutes*60 283 | secondsf = int(seconds) 284 | seconds = seconds - secondsf 285 | millis = int(seconds*1000) 286 | 287 | f = '' 288 | i = 1 289 | if days > 0: 290 | f += str(days) + 'D' 291 | i += 1 292 | if hours > 0 and i <= 2: 293 | f += str(hours) + 'h' 294 | i += 1 295 | if minutes > 0 and i <= 2: 296 | f += str(minutes) + 'm' 297 | i += 1 298 | if secondsf > 0 and i <= 2: 299 | f += str(secondsf) + 's' 300 | i += 1 301 | if millis > 0 and i <= 2: 302 | f += str(millis) + 'ms' 303 | i += 1 304 | if f == '': 305 | f = '0ms' 306 | return f 307 | 308 | 309 | def load_NoPrefix(path, length): 310 | # load dataparallel wrapped model properly 311 | state_dict = torch.load(path, map_location=torch.device('cpu')) 312 | if 'state_dict' in state_dict.keys(): 313 | state_dict = state_dict['state_dict'] 314 | # create new OrderedDict that does not contain `module.` 315 | from collections import OrderedDict 316 | new_state_dict = OrderedDict() 317 | for k, v in state_dict.items(): 318 | name = k[length:] # remove `Scale.` 319 | new_state_dict[name] = v 320 | return new_state_dict 321 | 322 | 323 | def get_model_name(cfg): 324 | names = {1: 's1_OTVM_alpha', 325 | 2: 's2_OTVM_alpha', 326 | 3: 's3_OTVM', 327 | 4: 's4_OTVM'} 328 | return names[cfg.TRAIN.STAGE] 329 | 330 | 331 | 332 | def get_model_trimap(cfg, mode='Test', dilate_kernel=None): 333 | import models.trimap.model as model_trimap 334 | if mode == 'Train': 335 | model = model_trimap.FullModel 336 | elif mode == 'Test': 337 | model = model_trimap.FullModel_eval 338 | 339 | hdim = 16 340 | 341 | model_loded = model(eps=0, 342 | stage=cfg.TRAIN.STAGE, 343 | dilate_kernel=dilate_kernel, 344 | hdim=hdim,) 345 | 346 | return model_loded 347 | 348 | def get_model_alpha(cfg, model_trimap, mode='Test', dilate_kernel=None): 349 | import models.alpha.model as model_alpha 350 | if cfg.TRAIN.STAGE == 1: 351 | model_trimap = None 352 | 353 | if mode == 'Train': 354 | model = model_alpha.FullModel 355 | elif mode == 'Test': 356 | model = model_alpha.EvalModel 357 | 358 | model_loded = model(dilate_kernel=dilate_kernel, 359 | trimap=model_trimap, 360 | stage=cfg.TRAIN.STAGE,) 361 | 362 | return model_loded 363 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/models/__init__.py -------------------------------------------------------------------------------- /models/alpha/FBA/layers_WS.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class Conv2d(nn.Conv2d): 7 | 8 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 9 | padding=0, dilation=1, groups=1, bias=True): 10 | super(Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride, 11 | padding, dilation, groups, bias) 12 | 
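# Descriptive note: this Conv2d subclass implements Weight Standardization (WS).
# In forward() below, the kernel weights are zero-centered and divided by their
# per-filter standard deviation before F.conv2d is applied, and the BatchNorm2d()
# helper at the bottom of this file actually returns a 32-group GroupNorm.
# Together these give the GN+WS encoder used by the FBA alpha network
# (see models.py, arch='resnet50_GN_WS').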
13 | def forward(self, x): 14 | # return super(Conv2d, self).forward(x) 15 | weight = self.weight 16 | weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2, 17 | keepdim=True).mean(dim=3, keepdim=True) 18 | weight = weight - weight_mean 19 | # std = (weight).view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5 20 | std = torch.sqrt(torch.var(weight.view(weight.size(0), -1), dim=1) + 1e-12).view(-1, 1, 1, 1) + 1e-5 21 | weight = weight / std.expand_as(weight) 22 | return F.conv2d(x, weight, self.bias, self.stride, 23 | self.padding, self.dilation, self.groups) 24 | 25 | 26 | def BatchNorm2d(num_features): 27 | return nn.GroupNorm(num_channels=num_features, num_groups=32) 28 | -------------------------------------------------------------------------------- /models/alpha/FBA/models.py: -------------------------------------------------------------------------------- 1 | from numpy import not_equal 2 | import torch 3 | import torch.nn as nn 4 | from . import resnet_GN_WS 5 | from . import layers_WS as L 6 | from . import resnet_bn 7 | 8 | FEAT_DIM = 2048 9 | DEC_DIM = 256 10 | 11 | def FBA(refinement): 12 | builder = ModelBuilder() 13 | net_encoder = builder.build_encoder(arch='resnet50_GN_WS') 14 | net_decoder = builder.build_decoder(arch="fba_decoder", batch_norm=False) 15 | 16 | model = MattingModule(net_encoder, net_decoder, refinement) 17 | 18 | return model 19 | 20 | 21 | class MattingModule(nn.Module): 22 | def __init__(self, net_enc, net_dec, refinement): 23 | super(MattingModule, self).__init__() 24 | self.encoder = net_enc 25 | self.decoder = net_dec 26 | self.refinement = refinement 27 | if refinement: 28 | self.refine = RefinementModule() 29 | else: 30 | self.refine = None 31 | 32 | def forward(self, x, extras): 33 | image, two_chan_trimap = extras 34 | conv_out, indices = self.encoder(x) 35 | 36 | hid, output, x_dec = self.decoder(conv_out, image, indices, two_chan_trimap) 37 | pred_alpha = output[:, :1] 38 | 39 | if self.refine is not None: 40 | hid, refine_output, refine_trimap = self.refine(x_dec, image, two_chan_trimap, pred_alpha) 41 | else: 42 | refine_output = None 43 | refine_trimap = None 44 | 45 | return output, hid, refine_output, refine_trimap 46 | 47 | 48 | class ModelBuilder(): 49 | def build_encoder(self, arch='resnet50_GN', num_channels_additional=None): 50 | if arch == 'resnet50_GN_WS': 51 | orig_resnet = resnet_GN_WS.__dict__['l_resnet50']() 52 | net_encoder = ResnetDilated(orig_resnet, dilate_scale=8, num_channels_additional=num_channels_additional) 53 | elif arch == 'resnet50_BN': 54 | orig_resnet = resnet_bn.__dict__['l_resnet50']() 55 | net_encoder = ResnetDilatedBN(orig_resnet, dilate_scale=8, num_channels_additional=num_channels_additional) 56 | elif arch == 'resnet18_GN_WS': 57 | orig_resnet = resnet_GN_WS.__dict__['l_resnet18']() 58 | net_encoder = ResnetDilated(orig_resnet, dilate_scale=8, num_channels_additional=num_channels_additional) 59 | elif arch == 'resnet34_GN_WS': 60 | orig_resnet = resnet_GN_WS.__dict__['l_resnet34']() 61 | net_encoder = ResnetDilated(orig_resnet, dilate_scale=8, num_channels_additional=num_channels_additional) 62 | 63 | else: 64 | raise Exception('Architecture undefined!') 65 | 66 | num_channels = 3 + 6 + 2 67 | 68 | if(num_channels > 3): 69 | print(f'modifying input layer to accept {num_channels} channels') 70 | net_encoder_sd = net_encoder.state_dict() 71 | conv1_weights = net_encoder_sd['conv1.weight'] 72 | 73 | c_out, c_in, h, w = conv1_weights.size() 74 | conv1_mod = torch.zeros(c_out, num_channels, h, w) 75 | 
conv1_mod[:, :3, :, :] = conv1_weights 76 | 77 | conv1 = net_encoder.conv1 78 | conv1.in_channels = num_channels 79 | conv1.weight = torch.nn.Parameter(conv1_mod) 80 | 81 | net_encoder.conv1 = conv1 82 | 83 | net_encoder_sd['conv1.weight'] = conv1_mod 84 | 85 | net_encoder.load_state_dict(net_encoder_sd) 86 | return net_encoder 87 | 88 | def build_decoder(self, arch='fba_decoder', batch_norm=False, memory_decoder=False): 89 | if arch == 'fba_decoder': 90 | net_decoder = fba_decoder(batch_norm=batch_norm, memory_decoder=memory_decoder) 91 | 92 | return net_decoder 93 | 94 | 95 | class ResnetDilatedBN(nn.Module): 96 | def __init__(self, orig_resnet, dilate_scale=8, num_channels_additional=None): 97 | super(ResnetDilatedBN, self).__init__() 98 | from functools import partial 99 | 100 | if dilate_scale == 8: 101 | orig_resnet.layer3.apply( 102 | partial(self._nostride_dilate, dilate=2)) 103 | orig_resnet.layer4.apply( 104 | partial(self._nostride_dilate, dilate=4)) 105 | elif dilate_scale == 16: 106 | orig_resnet.layer4.apply( 107 | partial(self._nostride_dilate, dilate=2)) 108 | 109 | # take pretrained resnet, except AvgPool and FC 110 | self.conv1 = orig_resnet.conv1 111 | self.bn1 = orig_resnet.bn1 112 | self.relu1 = orig_resnet.relu1 113 | self.conv2 = orig_resnet.conv2 114 | self.bn2 = orig_resnet.bn2 115 | self.relu2 = orig_resnet.relu2 116 | self.conv3 = orig_resnet.conv3 117 | self.bn3 = orig_resnet.bn3 118 | self.relu3 = orig_resnet.relu3 119 | self.maxpool = orig_resnet.maxpool 120 | self.layer1 = orig_resnet.layer1 121 | self.layer2 = orig_resnet.layer2 122 | self.layer3 = orig_resnet.layer3 123 | self.layer4 = orig_resnet.layer4 124 | 125 | self.num_channels_additional = num_channels_additional 126 | if self.num_channels_additional is not None: 127 | self.conv1_a = resnet_bn.conv3x3(self.num_channels_additional, 64, stride=2) 128 | 129 | def _nostride_dilate(self, m, dilate): 130 | classname = m.__class__.__name__ 131 | if classname.find('Conv') != -1: 132 | # the convolution with stride 133 | if m.stride == (2, 2): 134 | m.stride = (1, 1) 135 | if m.kernel_size == (3, 3): 136 | m.dilation = (dilate // 2, dilate // 2) 137 | m.padding = (dilate // 2, dilate // 2) 138 | # other convoluions 139 | else: 140 | if m.kernel_size == (3, 3): 141 | m.dilation = (dilate, dilate) 142 | m.padding = (dilate, dilate) 143 | 144 | def forward(self, x, return_feature_maps=False): 145 | conv_out = [x] 146 | x = self.relu1(self.bn1(self.conv1(x))) 147 | x = self.relu2(self.bn2(self.conv2(x))) 148 | x = self.relu3(self.bn3(self.conv3(x))) 149 | conv_out.append(x) 150 | x, indices = self.maxpool(x) 151 | x = self.layer1(x) 152 | conv_out.append(x) 153 | x = self.layer2(x) 154 | conv_out.append(x) 155 | x = self.layer3(x) 156 | conv_out.append(x) 157 | x = self.layer4(x) 158 | conv_out.append(x) 159 | 160 | if return_feature_maps: 161 | return conv_out, indices 162 | return [x] 163 | 164 | 165 | class Resnet(nn.Module): 166 | def __init__(self, orig_resnet): 167 | super(Resnet, self).__init__() 168 | 169 | # take pretrained resnet, except AvgPool and FC 170 | self.conv1 = orig_resnet.conv1 171 | self.bn1 = orig_resnet.bn1 172 | self.relu1 = orig_resnet.relu1 173 | self.conv2 = orig_resnet.conv2 174 | self.bn2 = orig_resnet.bn2 175 | self.relu2 = orig_resnet.relu2 176 | self.conv3 = orig_resnet.conv3 177 | self.bn3 = orig_resnet.bn3 178 | self.relu3 = orig_resnet.relu3 179 | self.maxpool = orig_resnet.maxpool 180 | self.layer1 = orig_resnet.layer1 181 | self.layer2 = orig_resnet.layer2 182 | self.layer3 = 
orig_resnet.layer3 183 | self.layer4 = orig_resnet.layer4 184 | 185 | def forward(self, x, return_feature_maps=False): 186 | conv_out = [] 187 | 188 | x = self.relu1(self.bn1(self.conv1(x))) 189 | x = self.relu2(self.bn2(self.conv2(x))) 190 | x = self.relu3(self.bn3(self.conv3(x))) 191 | conv_out.append(x) 192 | x, indices = self.maxpool(x) 193 | 194 | x = self.layer1(x) 195 | conv_out.append(x) 196 | x = self.layer2(x) 197 | conv_out.append(x) 198 | x = self.layer3(x) 199 | conv_out.append(x) 200 | x = self.layer4(x) 201 | conv_out.append(x) 202 | 203 | if return_feature_maps: 204 | return conv_out 205 | return [x] 206 | 207 | 208 | class ResnetDilated(nn.Module): 209 | def __init__(self, orig_resnet, dilate_scale=8, num_channels_additional=None): 210 | super(ResnetDilated, self).__init__() 211 | from functools import partial 212 | 213 | if dilate_scale == 8: 214 | orig_resnet.layer3.apply( 215 | partial(self._nostride_dilate, dilate=2)) 216 | orig_resnet.layer4.apply( 217 | partial(self._nostride_dilate, dilate=4)) 218 | elif dilate_scale == 16: 219 | orig_resnet.layer4.apply( 220 | partial(self._nostride_dilate, dilate=2)) 221 | 222 | # take pretrained resnet, except AvgPool and FC 223 | self.conv1 = orig_resnet.conv1 224 | self.bn1 = orig_resnet.bn1 225 | self.relu = orig_resnet.relu 226 | self.maxpool = orig_resnet.maxpool 227 | self.layer1 = orig_resnet.layer1 228 | self.layer2 = orig_resnet.layer2 229 | self.layer3 = orig_resnet.layer3 230 | self.layer4 = orig_resnet.layer4 231 | 232 | self.num_channels_additional = num_channels_additional 233 | if self.num_channels_additional is not None: 234 | self.conv1_a = resnet_GN_WS.L.Conv2d(self.num_channels_additional, 64, kernel_size=7, stride=2, padding=3, bias=False) 235 | 236 | def _nostride_dilate(self, m, dilate): 237 | classname = m.__class__.__name__ 238 | if classname.find('Conv') != -1: 239 | # the convolution with stride 240 | if m.stride == (2, 2): 241 | m.stride = (1, 1) 242 | if m.kernel_size == (3, 3): 243 | m.dilation = (dilate // 2, dilate // 2) 244 | m.padding = (dilate // 2, dilate // 2) 245 | # other convoluions 246 | else: 247 | if m.kernel_size == (3, 3): 248 | m.dilation = (dilate, dilate) 249 | m.padding = (dilate, dilate) 250 | 251 | def forward(self, x, x_a=None): 252 | conv_out = [x] # OS=1 253 | if self.num_channels_additional is None: 254 | x = self.relu(self.bn1(self.conv1(x))) 255 | else: 256 | x = self.conv1(x) + self.conv1_a(x_a) 257 | x = self.relu(self.bn1(x)) 258 | conv_out.append(x) # OS=2 259 | x, indices = self.maxpool(x) 260 | x = self.layer1(x) 261 | conv_out.append(x) # OS=4 262 | x = self.layer2(x) 263 | conv_out.append(x) # OS=8 264 | x = self.layer3(x) 265 | conv_out.append(x) 266 | x = self.layer4(x) 267 | conv_out.append(x) 268 | 269 | return conv_out, indices 270 | 271 | 272 | def norm(dim, bn=False): 273 | if(bn is False): 274 | return nn.GroupNorm(32, dim) 275 | else: 276 | return nn.BatchNorm2d(dim) 277 | 278 | 279 | def fba_fusion(alpha, img, F, B): 280 | F = ((alpha * img + (1 - alpha**2) * F - alpha * (1 - alpha) * B)) 281 | B = ((1 - alpha) * img + (2 * alpha - alpha**2) * B - alpha * (1 - alpha) * F) 282 | 283 | F = torch.clamp(F, 0, 1) 284 | B = torch.clamp(B, 0, 1) 285 | la = 0.1 286 | alpha = (alpha * la + torch.sum((img - B) * (F - B), 1, keepdim=True)) / (torch.sum((F - B) * (F - B), 1, keepdim=True) + la) 287 | alpha = torch.clamp(alpha, 0, 1) 288 | return alpha, F, B 289 | 290 | 291 | class fba_decoder(nn.Module): 292 | def __init__(self, batch_norm=False, memory_decoder=False): 
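        # fba_decoder overview: conv5 (the deepest encoder feature) is pooled by a
        # pyramid pooling module at scales (1, 2, 3, 6), concatenated, and then
        # progressively upsampled with skip connections from earlier encoder stages
        # (plus the image and the two-channel trimap at the last stage). The final
        # head (conv_up4) predicts 7 channels (1 alpha + 3 foreground F + 3 background B),
        # which are refined by the closed-form fba_fusion step defined above.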
293 | super(fba_decoder, self).__init__() 294 | pool_scales = (1, 2, 3, 6) 295 | self.batch_norm = batch_norm 296 | self.memory_decoder = memory_decoder 297 | 298 | self.ppm = [] 299 | 300 | for scale in pool_scales: 301 | self.ppm.append(nn.Sequential( 302 | nn.AdaptiveAvgPool2d(scale), 303 | L.Conv2d(FEAT_DIM, DEC_DIM, kernel_size=1, bias=True), 304 | norm(DEC_DIM, self.batch_norm), 305 | nn.LeakyReLU() 306 | )) 307 | self.ppm = nn.ModuleList(self.ppm) 308 | 309 | self.conv_up1 = nn.Sequential( 310 | L.Conv2d(FEAT_DIM + len(pool_scales) * DEC_DIM, DEC_DIM, 311 | kernel_size=3, padding=1, bias=True), 312 | 313 | norm(DEC_DIM, self.batch_norm), 314 | nn.LeakyReLU(), 315 | L.Conv2d(DEC_DIM, DEC_DIM, kernel_size=3, padding=1), 316 | norm(DEC_DIM, self.batch_norm), 317 | nn.LeakyReLU() 318 | ) 319 | 320 | # if not self.memory_decoder: 321 | self.conv_up2 = nn.Sequential( 322 | L.Conv2d((FEAT_DIM//8) + DEC_DIM, DEC_DIM, 323 | kernel_size=3, padding=1, bias=True), 324 | norm(DEC_DIM, self.batch_norm), 325 | nn.LeakyReLU() 326 | ) 327 | if(self.batch_norm): 328 | d_up3 = 128 329 | else: 330 | d_up3 = 64 331 | self.conv_up3 = nn.Sequential( 332 | L.Conv2d(DEC_DIM + d_up3, 64, 333 | kernel_size=3, padding=1, bias=True), 334 | norm(64, self.batch_norm), 335 | nn.LeakyReLU() 336 | ) 337 | 338 | self.unpool = nn.MaxUnpool2d(2, stride=2) 339 | 340 | self.conv_up4 = nn.Sequential( 341 | nn.Conv2d(64 + 3 + 3 + 2, 32, 342 | kernel_size=3, padding=1, bias=True), 343 | nn.LeakyReLU(), 344 | nn.Conv2d(32, 16, 345 | kernel_size=3, padding=1, bias=True), 346 | 347 | nn.LeakyReLU(), 348 | nn.Conv2d(16, 7, kernel_size=1, padding=0, bias=True) 349 | ) 350 | 351 | def forward(self, conv_out, img, indices, two_chan_trimap, extract_feature=False, x=None): 352 | # if extract_feature: 353 | conv5 = conv_out[-1] 354 | 355 | input_size = conv5.size() 356 | ppm_out = [conv5] 357 | for pool_scale in self.ppm: 358 | ppm_out.append(nn.functional.interpolate( 359 | pool_scale(conv5), 360 | (input_size[2], input_size[3]), 361 | mode='bilinear', align_corners=False)) 362 | ppm_out = torch.cat(ppm_out, 1) 363 | x = self.conv_up1(ppm_out) 364 | # return x 365 | # else: 366 | x = torch.nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) 367 | 368 | x = torch.cat((x, conv_out[-4]), 1) 369 | 370 | x = self.conv_up2(x) 371 | x = torch.nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) 372 | 373 | x = torch.cat((x, conv_out[-5]), 1) 374 | x = self.conv_up3(x) 375 | 376 | x = torch.nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) 377 | x = torch.cat((x, conv_out[-6][:, :3], img), 1) 378 | x2 = torch.cat((x, two_chan_trimap), 1) 379 | 380 | hid = self.conv_up4[:-1](x2) 381 | output = self.conv_up4[-1:](hid) 382 | 383 | alpha = torch.clamp(output[:, 0][:, None], 0, 1) 384 | F = torch.sigmoid(output[:, 1:4]) 385 | B = torch.sigmoid(output[:, 4:7]) 386 | 387 | # FBA Fusion 388 | alpha, F, B = fba_fusion(alpha, img, F, B) 389 | 390 | output = torch.cat((alpha, F, B), 1) 391 | 392 | return hid, output, x 393 | 394 | 395 | class RefinementModule(nn.Module): 396 | def __init__(self, batch_norm=False): 397 | super(RefinementModule, self).__init__() 398 | self.batch_norm = batch_norm 399 | self.conv1 = nn.Sequential( 400 | L.Conv2d((64 + 3 + 3) + 2 + 1, 64, 401 | kernel_size=3, padding=1, bias=True), 402 | norm(64, self.batch_norm), 403 | nn.LeakyReLU() 404 | ) 405 | self.layer1 = resnet_GN_WS.BasicBlock(64, 64) 406 | self.layer2 = 
resnet_GN_WS.BasicBlock(64, 64) 407 | outdim = 10 408 | self.pred = nn.Sequential( 409 | nn.Conv2d(64, 32, 410 | kernel_size=3, padding=1, bias=True), 411 | nn.LeakyReLU(), 412 | nn.Conv2d(32, 16, 413 | kernel_size=3, padding=1, bias=True), 414 | nn.LeakyReLU(), 415 | nn.Conv2d(16, outdim, kernel_size=1, padding=0, bias=True) 416 | ) 417 | def forward(self, x, img, two_chan_trimap, pred_alpha): 418 | x = torch.cat((x, two_chan_trimap, pred_alpha), 1) 419 | x = self.conv1(x) 420 | x = self.layer1(x) 421 | x = self.layer2(x) 422 | x = self.pred[:-1](x) 423 | output = self.pred[-1](x) 424 | 425 | a = output[:, :7] 426 | alpha = torch.clamp(a[:, 0][:, None], 0, 1) 427 | F = torch.sigmoid(a[:, 1:4]) 428 | B = torch.sigmoid(a[:, 4:7]) 429 | # FBA Fusion 430 | alpha, F, B = fba_fusion(alpha, img, F, B) 431 | alpha = torch.cat((alpha, F, B), 1) 432 | 433 | trimap = output[:, -3:] 434 | 435 | return x, alpha, trimap 436 | -------------------------------------------------------------------------------- /models/alpha/FBA/resnet_GN_WS.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from . import layers_WS as L 4 | 5 | __all__ = ['ResNet', 'l_resnet18', 'l_resnet34', 'l_resnet50'] 6 | 7 | 8 | def conv3x3(in_planes, out_planes, stride=1): 9 | """3x3 convolution with padding""" 10 | return L.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 11 | padding=1, bias=False) 12 | 13 | 14 | def conv1x1(in_planes, out_planes, stride=1): 15 | """1x1 convolution""" 16 | return L.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 17 | 18 | 19 | class BasicBlock(nn.Module): 20 | expansion = 1 21 | 22 | def __init__(self, inplanes, planes, stride=1, downsample=None): 23 | super(BasicBlock, self).__init__() 24 | self.conv1 = conv3x3(inplanes, planes, stride) 25 | self.bn1 = L.BatchNorm2d(planes) 26 | self.relu = nn.ReLU(inplace=True) 27 | self.conv2 = conv3x3(planes, planes) 28 | self.bn2 = L.BatchNorm2d(planes) 29 | self.downsample = downsample 30 | self.stride = stride 31 | 32 | def forward(self, x): 33 | identity = x 34 | 35 | out = self.conv1(x) 36 | out = self.bn1(out) 37 | out = self.relu(out) 38 | 39 | out = self.conv2(out) 40 | out = self.bn2(out) 41 | 42 | if self.downsample is not None: 43 | identity = self.downsample(x) 44 | 45 | out += identity 46 | out = self.relu(out) 47 | 48 | return out 49 | 50 | 51 | class Bottleneck(nn.Module): 52 | expansion = 4 53 | 54 | def __init__(self, inplanes, planes, stride=1, downsample=None): 55 | super(Bottleneck, self).__init__() 56 | self.conv1 = conv1x1(inplanes, planes) 57 | self.bn1 = L.BatchNorm2d(planes) 58 | self.conv2 = conv3x3(planes, planes, stride) 59 | self.bn2 = L.BatchNorm2d(planes) 60 | self.conv3 = conv1x1(planes, planes * self.expansion) 61 | self.bn3 = L.BatchNorm2d(planes * self.expansion) 62 | self.relu = nn.ReLU(inplace=True) 63 | self.downsample = downsample 64 | self.stride = stride 65 | 66 | def forward(self, x): 67 | identity = x 68 | 69 | out = self.conv1(x) 70 | out = self.bn1(out) 71 | out = self.relu(out) 72 | 73 | out = self.conv2(out) 74 | out = self.bn2(out) 75 | out = self.relu(out) 76 | 77 | out = self.conv3(out) 78 | out = self.bn3(out) 79 | 80 | if self.downsample is not None: 81 | identity = self.downsample(x) 82 | 83 | out += identity 84 | out = self.relu(out) 85 | 86 | return out 87 | 88 | 89 | class ResNet(nn.Module): 90 | 91 | def __init__(self, block, layers, num_classes=1000): 92 | super(ResNet, self).__init__() 93 | 
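        # This backbone swaps plain nn.Conv2d / nn.BatchNorm2d for the wrappers in
        # layers_WS (imported as L), and the max-pool below is created with
        # return_indices=True so its indices can be reused for unpooling later.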
self.inplanes = 64 94 | self.conv1 = L.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 95 | bias=False) 96 | self.bn1 = L.BatchNorm2d(64) 97 | self.relu = nn.ReLU(inplace=True) 98 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True) 99 | self.layer1 = self._make_layer(block, 64, layers[0]) 100 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 101 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 102 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 103 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 104 | self.fc = nn.Linear(512 * block.expansion, num_classes) 105 | 106 | def _make_layer(self, block, planes, blocks, stride=1): 107 | downsample = None 108 | if stride != 1 or self.inplanes != planes * block.expansion: 109 | downsample = nn.Sequential( 110 | conv1x1(self.inplanes, planes * block.expansion, stride), 111 | L.BatchNorm2d(planes * block.expansion), 112 | ) 113 | 114 | layers = [] 115 | layers.append(block(self.inplanes, planes, stride, downsample)) 116 | self.inplanes = planes * block.expansion 117 | for _ in range(1, blocks): 118 | layers.append(block(self.inplanes, planes)) 119 | 120 | return nn.Sequential(*layers) 121 | 122 | def forward(self, x): 123 | x = self.conv1(x) 124 | x = self.bn1(x) 125 | x = self.relu(x) 126 | x = self.maxpool(x) 127 | 128 | x = self.layer1(x) 129 | x = self.layer2(x) 130 | x = self.layer3(x) 131 | x = self.layer4(x) 132 | 133 | x = self.avgpool(x) 134 | x = x.view(x.size(0), -1) 135 | x = self.fc(x) 136 | 137 | return x 138 | 139 | 140 | def load_NoPrefix(path, length): 141 | # load dataparallel wrapped model properly 142 | try: 143 | state_dict = torch.load(path, map_location='cpu') 144 | except: 145 | state_dict = torch.load(path, map_location='cpu')['state_dict'] 146 | # create new OrderedDict that does not contain `module.` 147 | from collections import OrderedDict 148 | new_state_dict = OrderedDict() 149 | for k, v in state_dict.items(): 150 | name = k[length:] # remove `Scale.` 151 | new_state_dict[name] = v 152 | return new_state_dict 153 | 154 | 155 | def my_load_state_dict(model, state_dict): 156 | # version 2: support tensor same name but size is different 157 | 158 | own_state = model.state_dict() 159 | for name, param in state_dict.items(): 160 | if name in own_state: 161 | if isinstance(param, nn.Parameter): 162 | # backwards compatibility for serialized parameters 163 | param = param.data 164 | try: 165 | own_state[name].copy_(param) 166 | except: 167 | print('While copying the parameter named {}, whose dimensions in the model are {} and whose dimensions in the checkpoint are {}.'.format(name, own_state[name].size(), param.size())) 168 | else: 169 | print('[Warning] Found key "{}" in file, but not in current model'.format(name)) 170 | 171 | missing = set(own_state.keys()) - set(state_dict.keys()) 172 | if len(missing) > 0: 173 | print('[Warning] Cant find keys "{}" in file'.format(missing)) 174 | 175 | 176 | def l_resnet18(**kwargs): 177 | """Constructs a ResNet-18 model. 178 | """ 179 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 180 | 181 | return model 182 | 183 | 184 | def l_resnet34(**kwargs): 185 | """Constructs a ResNet-34 model. 186 | """ 187 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 188 | 189 | return model 190 | 191 | 192 | def l_resnet50(**kwargs): 193 | """Constructs a ResNet-50 model. 
194 | """ 195 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 196 | 197 | return model 198 | -------------------------------------------------------------------------------- /models/alpha/FBA/resnet_bn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | from torch.nn import BatchNorm2d 4 | 5 | __all__ = ['ResNet'] 6 | 7 | 8 | def conv3x3(in_planes, out_planes, stride=1): 9 | "3x3 convolution with padding" 10 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 11 | padding=1, bias=False) 12 | 13 | 14 | class BasicBlock(nn.Module): 15 | expansion = 1 16 | 17 | def __init__(self, inplanes, planes, stride=1, downsample=None): 18 | super(BasicBlock, self).__init__() 19 | self.conv1 = conv3x3(inplanes, planes, stride) 20 | self.bn1 = BatchNorm2d(planes) 21 | self.relu = nn.ReLU(inplace=True) 22 | self.conv2 = conv3x3(planes, planes) 23 | self.bn2 = BatchNorm2d(planes) 24 | self.downsample = downsample 25 | self.stride = stride 26 | 27 | def forward(self, x): 28 | residual = x 29 | 30 | out = self.conv1(x) 31 | out = self.bn1(out) 32 | out = self.relu(out) 33 | 34 | out = self.conv2(out) 35 | out = self.bn2(out) 36 | 37 | if self.downsample is not None: 38 | residual = self.downsample(x) 39 | 40 | out += residual 41 | out = self.relu(out) 42 | 43 | return out 44 | 45 | 46 | class Bottleneck(nn.Module): 47 | expansion = 4 48 | 49 | def __init__(self, inplanes, planes, stride=1, downsample=None): 50 | super(Bottleneck, self).__init__() 51 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 52 | self.bn1 = BatchNorm2d(planes) 53 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 54 | padding=1, bias=False) 55 | self.bn2 = BatchNorm2d(planes, momentum=0.01) 56 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 57 | self.bn3 = BatchNorm2d(planes * 4) 58 | self.relu = nn.ReLU(inplace=True) 59 | self.downsample = downsample 60 | self.stride = stride 61 | 62 | def forward(self, x): 63 | residual = x 64 | 65 | out = self.conv1(x) 66 | out = self.bn1(out) 67 | out = self.relu(out) 68 | 69 | out = self.conv2(out) 70 | out = self.bn2(out) 71 | out = self.relu(out) 72 | 73 | out = self.conv3(out) 74 | out = self.bn3(out) 75 | 76 | if self.downsample is not None: 77 | residual = self.downsample(x) 78 | 79 | out += residual 80 | out = self.relu(out) 81 | 82 | return out 83 | 84 | 85 | class ResNet(nn.Module): 86 | 87 | def __init__(self, block, layers, num_classes=1000): 88 | self.inplanes = 128 89 | super(ResNet, self).__init__() 90 | self.conv1 = conv3x3(3, 64, stride=2) 91 | self.bn1 = BatchNorm2d(64) 92 | self.relu1 = nn.ReLU(inplace=True) 93 | self.conv2 = conv3x3(64, 64) 94 | self.bn2 = BatchNorm2d(64) 95 | self.relu2 = nn.ReLU(inplace=True) 96 | self.conv3 = conv3x3(64, 128) 97 | self.bn3 = BatchNorm2d(128) 98 | self.relu3 = nn.ReLU(inplace=True) 99 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True) 100 | 101 | self.layer1 = self._make_layer(block, 64, layers[0]) 102 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 103 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 104 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 105 | self.avgpool = nn.AvgPool2d(7, stride=1) 106 | self.fc = nn.Linear(512 * block.expansion, num_classes) 107 | 108 | for m in self.modules(): 109 | if isinstance(m, nn.Conv2d): 110 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 
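                # Kaiming (He) initialization: n is the conv fan-out
                # (kernel_h * kernel_w * out_channels), so weights are drawn from
                # N(0, sqrt(2 / n)); BatchNorm weights are set to 1 and biases to 0 below.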
111 | m.weight.data.normal_(0, math.sqrt(2. / n)) 112 | elif isinstance(m, BatchNorm2d): 113 | m.weight.data.fill_(1) 114 | m.bias.data.zero_() 115 | 116 | def _make_layer(self, block, planes, blocks, stride=1): 117 | downsample = None 118 | if stride != 1 or self.inplanes != planes * block.expansion: 119 | downsample = nn.Sequential( 120 | nn.Conv2d(self.inplanes, planes * block.expansion, 121 | kernel_size=1, stride=stride, bias=False), 122 | BatchNorm2d(planes * block.expansion), 123 | ) 124 | 125 | layers = [] 126 | layers.append(block(self.inplanes, planes, stride, downsample)) 127 | self.inplanes = planes * block.expansion 128 | for i in range(1, blocks): 129 | layers.append(block(self.inplanes, planes)) 130 | 131 | return nn.Sequential(*layers) 132 | 133 | def forward(self, x): 134 | x = self.relu1(self.bn1(self.conv1(x))) 135 | x = self.relu2(self.bn2(self.conv2(x))) 136 | x = self.relu3(self.bn3(self.conv3(x))) 137 | x, indices = self.maxpool(x) 138 | 139 | x = self.layer1(x) 140 | x = self.layer2(x) 141 | x = self.layer3(x) 142 | x = self.layer4(x) 143 | 144 | x = self.avgpool(x) 145 | x = x.view(x.size(0), -1) 146 | x = self.fc(x) 147 | return x 148 | 149 | 150 | def l_resnet50(): 151 | """Constructs a ResNet-50 model. 152 | Args: 153 | pretrained (bool): If True, returns a model pre-trained on ImageNet 154 | """ 155 | model = ResNet(Bottleneck, [3, 4, 6, 3]) 156 | return model 157 | -------------------------------------------------------------------------------- /models/alpha/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/models/alpha/__init__.py -------------------------------------------------------------------------------- /models/alpha/common.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def pad_divide_by(in_list, d, in_size, padval=0.): 7 | out_list = [] 8 | h, w = in_size 9 | if h % d > 0: 10 | new_h = h + d - h % d 11 | else: 12 | new_h = h 13 | if w % d > 0: 14 | new_w = w + d - w % d 15 | else: 16 | new_w = w 17 | lh, uh = int((new_h-h) / 2), int(new_h-h) - int((new_h-h) / 2) 18 | lw, uw = int((new_w-w) / 2), int(new_w-w) - int((new_w-w) / 2) 19 | pad_array = (int(lw), int(uw), int(lh), int(uh)) 20 | if sum(pad_array)>0: 21 | for inp in in_list: 22 | out_list.append(F.pad(inp, pad_array, value=padval)) 23 | else: 24 | out_list = in_list 25 | if len(in_list) == 1: 26 | out_list = out_list[0] 27 | return out_list, pad_array 28 | -------------------------------------------------------------------------------- /models/trimap/STM.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torchvision import models 6 | from helpers import ToCuda, pad_divide_by 7 | import math 8 | 9 | class ResBlock(nn.Module): 10 | def __init__(self, indim, outdim=None, stride=1): 11 | super(ResBlock, self).__init__() 12 | if outdim == None: 13 | outdim = indim 14 | if indim == outdim and stride==1: 15 | self.downsample = None 16 | else: 17 | self.downsample = nn.Conv2d(indim, outdim, kernel_size=3, padding=1, stride=stride) 18 | 19 | self.conv1 = nn.Conv2d(indim, outdim, kernel_size=3, padding=1, stride=stride) 20 | self.conv2 = nn.Conv2d(outdim, outdim, 
kernel_size=3, padding=1) 21 | 22 | 23 | def forward(self, x): 24 | r = self.conv1(F.relu(x)) 25 | r = self.conv2(F.relu(r)) 26 | 27 | if self.downsample is not None: 28 | x = self.downsample(x) 29 | 30 | return x + r 31 | 32 | class Encoder_M(nn.Module): 33 | def __init__(self, hdim=32): 34 | super(Encoder_M, self).__init__() 35 | self.hdim = hdim 36 | 37 | self.conv1_m = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False) 38 | self.conv1_o = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False) 39 | if self.hdim > 0: 40 | self.conv1_a = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False) 41 | self.conv1_h = nn.Conv2d(hdim, 64, kernel_size=7, stride=2, padding=3, bias=False) 42 | 43 | resnet = models.resnet50(pretrained=True) 44 | self.conv1 = resnet.conv1 45 | self.bn1 = resnet.bn1 46 | self.relu = resnet.relu # 1/2, 64 47 | self.maxpool = resnet.maxpool 48 | 49 | self.res2 = resnet.layer1 # 1/4, 256 50 | self.res3 = resnet.layer2 # 1/8, 512 51 | self.res4 = resnet.layer3 # 1/8, 1024 52 | 53 | self.register_buffer('mean', torch.FloatTensor([0.485, 0.456, 0.406]).view(1,3,1,1)) 54 | self.register_buffer('std', torch.FloatTensor([0.229, 0.224, 0.225]).view(1,3,1,1)) 55 | 56 | def forward(self, in_f, in_m, in_o, in_a, in_h): 57 | f = (in_f - self.mean) / self.std 58 | m = torch.unsqueeze(in_m, dim=1).float() # add channel dim 59 | o = torch.unsqueeze(in_o, dim=1).float() # add channel dim 60 | if self.hdim > 0: 61 | a = torch.unsqueeze(in_a, dim=1).float() # add channel dim 62 | h = in_h.float() # add channel dim 63 | x = self.conv1_m(m) + self.conv1_o(o) + self.conv1_a(a) + self.conv1_h(h) 64 | else: 65 | x = self.conv1_m(m) + self.conv1_o(o) 66 | 67 | x = self.conv1(f) + x 68 | x = self.bn1(x) 69 | c1 = self.relu(x) # 1/2, 64 70 | x = self.maxpool(c1) # 1/4, 64 71 | r2 = self.res2(x) # 1/4, 256 72 | r3 = self.res3(r2) # 1/8, 512 73 | r4 = self.res4(r3) # 1/8, 1024 74 | return r4, r3, r2, c1, f 75 | 76 | class Encoder_Q(nn.Module): 77 | def __init__(self): 78 | super(Encoder_Q, self).__init__() 79 | resnet = models.resnet50(pretrained=True) 80 | self.conv1 = resnet.conv1 81 | self.bn1 = resnet.bn1 82 | self.relu = resnet.relu # 1/2, 64 83 | self.maxpool = resnet.maxpool 84 | 85 | self.res2 = resnet.layer1 # 1/4, 256 86 | self.res3 = resnet.layer2 # 1/8, 512 87 | self.res4 = resnet.layer3 # 1/8, 1024 88 | 89 | self.register_buffer('mean', torch.FloatTensor([0.485, 0.456, 0.406]).view(1,3,1,1)) 90 | self.register_buffer('std', torch.FloatTensor([0.229, 0.224, 0.225]).view(1,3,1,1)) 91 | 92 | def forward(self, in_f): 93 | f = (in_f - self.mean) / self.std 94 | 95 | x = self.conv1(f) 96 | x = self.bn1(x) 97 | c1 = self.relu(x) # 1/2, 64 98 | x = self.maxpool(c1) # 1/4, 64 99 | r2 = self.res2(x) # 1/4, 256 100 | r3 = self.res3(r2) # 1/8, 512 101 | r4 = self.res4(r3) # 1/8, 1024 102 | return r4, r3, r2, c1, f 103 | 104 | 105 | class Refine(nn.Module): 106 | def __init__(self, inplanes, planes, scale_factor=2): 107 | super(Refine, self).__init__() 108 | self.convFS = nn.Conv2d(inplanes, planes, kernel_size=(3,3), padding=(1,1), stride=1) 109 | self.ResFS = ResBlock(planes, planes) 110 | self.ResMM = ResBlock(planes, planes) 111 | self.scale_factor = scale_factor 112 | 113 | def forward(self, f, pm): 114 | s = self.ResFS(self.convFS(f)) 115 | m = s + F.interpolate(pm, scale_factor=self.scale_factor, mode='bilinear', align_corners=False) 116 | m = self.ResMM(m) 117 | return m 118 | 119 | class Decoder(nn.Module): 120 | def __init__(self, mdim): 121 | 
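        # Trimap decoder head: the memory readout is compressed by convFM and ResMM,
        # refined through two Refine blocks using the skip features r3 and r2, and the
        # 3-channel prediction (background / unknown / foreground logits) is bilinearly
        # upsampled by 4x back to the input resolution in forward().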
super(Decoder, self).__init__() 122 | self.convFM = nn.Conv2d(1024, mdim, kernel_size=(3,3), padding=(1,1), stride=1) 123 | self.ResMM = ResBlock(mdim, mdim) 124 | self.RF3 = Refine(512, mdim) # 1/8 -> 1/4 125 | self.RF2 = Refine(256, mdim) # 1/4 -> 1 126 | 127 | self.pred = nn.Conv2d(mdim, 3, kernel_size=(3,3), padding=(1,1), stride=1) 128 | 129 | def forward(self, r4, r3, r2, VOS_mode=False): 130 | m4 = self.ResMM(self.convFM(r4)) 131 | m3 = self.RF3(r3, m4) # out: 1/8, 256 132 | m2 = self.RF2(r2, m3) # out: 1/4, 256 133 | 134 | p2 = self.pred(F.relu(m2)) 135 | 136 | p = F.interpolate(p2, scale_factor=4, mode='bilinear', align_corners=False) 137 | return p 138 | 139 | 140 | class Memory(nn.Module): 141 | def __init__(self): 142 | super(Memory, self).__init__() 143 | 144 | def forward(self, m_in, m_out, q_in, q_out): # m_in: o,c,t,h,w 145 | B, D_e, T, H, W = m_in.size() 146 | _, D_o, _, _, _ = m_out.size() 147 | 148 | mi = m_in.view(B, D_e, T*H*W) 149 | mi = torch.transpose(mi, 1, 2) # b, THW, emb 150 | 151 | qi = q_in.view(B, D_e, H*W) # b, emb, HW 152 | 153 | p = torch.bmm(mi, qi) # b, THW, HW 154 | p = p / math.sqrt(D_e) 155 | p = F.softmax(p, dim=1) # b, THW, HW 156 | 157 | mo = m_out.view(B, D_o, T*H*W) 158 | mem = torch.bmm(mo, p) # Weighted-sum B, D_o, HW 159 | mem = mem.view(B, D_o, H, W) 160 | 161 | mem_out = torch.cat([mem, q_out], dim=1) 162 | 163 | return mem_out 164 | 165 | 166 | class KeyValue(nn.Module): 167 | # Not using location 168 | def __init__(self, indim, keydim, valdim): 169 | super(KeyValue, self).__init__() 170 | self.Key = nn.Conv2d(indim, keydim, kernel_size=(3,3), padding=(1,1), stride=1) 171 | self.Value = nn.Conv2d(indim, valdim, kernel_size=(3,3), padding=(1,1), stride=1) 172 | 173 | def forward(self, x): 174 | return self.Key(x), self.Value(x) 175 | 176 | 177 | 178 | 179 | class STM(nn.Module): 180 | def __init__(self, hdim=-1): 181 | super(STM, self).__init__() 182 | self.hdim = hdim 183 | 184 | self.Encoder_M = Encoder_M(hdim=self.hdim) 185 | self.Encoder_Q = Encoder_Q() 186 | 187 | self.KV_M_r4 = KeyValue(1024, keydim=128, valdim=512) 188 | self.KV_Q_r4 = KeyValue(1024, keydim=128, valdim=512) 189 | 190 | self.Memory = Memory() 191 | self.Decoder = Decoder(256) 192 | 193 | def Pad_memory(self, mems, num_objects): 194 | pad_mems = [] 195 | for mem in mems: 196 | batch_and_numobj, C, H, W = mem.shape 197 | batch_size = batch_and_numobj//num_objects 198 | pad_mems.append(mem.view(num_objects, batch_size, C, 1, H, W).transpose(1,0)) 199 | return pad_mems 200 | 201 | def memorize(self, frame, masks, num_objects): 202 | # memorize a frame 203 | num_objects = num_objects[0].item() 204 | 205 | (frame, masks), pad = pad_divide_by([frame, masks], 16, (frame.size()[2], frame.size()[3])) 206 | 207 | # make batch arg list 208 | B_list = {'f':[], 'm':[], 'o':[], 'a':[], 'h':[]} 209 | for o in range(1, num_objects+1): # 1 - no 210 | B_list['f'].append(frame) 211 | B_list['m'].append(masks[:,1]) # Unkown region 212 | B_list['o'].append(masks[:,2]) # Foreground region 213 | if self.hdim > 0: 214 | B_list['a'].append(masks[:,3]) # Alpha matte 215 | B_list['h'].append(masks[:,4:]) # hidden layer 216 | 217 | # make Batch 218 | B_ = {} 219 | B_['a'] = None 220 | B_['h'] = None 221 | for arg in B_list.keys(): 222 | if len(B_list[arg]) > 0: 223 | B_[arg] = torch.cat(B_list[arg], dim=0) 224 | 225 | r4, _, _, _, _ = self.Encoder_M(B_['f'], B_['m'], B_['o'], B_['a'], B_['h']) 226 | k4, v4 = self.KV_M_r4(r4) # num_objects, 128 and 512, H/16, W/16 227 | k4, v4 = 
self.Pad_memory([k4, v4], num_objects=num_objects) 228 | return k4, v4 229 | 230 | def Soft_aggregation(self, ps, K): 231 | num_objects, H, W = ps.shape 232 | em = ToCuda(torch.zeros(1, K, H, W)) 233 | em[0,0] = torch.prod(1-ps, dim=0) # bg prob 234 | em[0,1:num_objects+1] = ps # obj prob 235 | em = torch.clamp(em, 1e-7, 1-1e-7) 236 | logit = torch.log((em /(1-em))) 237 | return logit 238 | 239 | def segment(self, frame, keys, values, num_objects): 240 | # pad 241 | [frame], pad = pad_divide_by([frame], 16, (frame.size()[2], frame.size()[3])) 242 | 243 | r4, r3, r2, _, _ = self.Encoder_Q(frame) 244 | k4, v4 = self.KV_Q_r4(r4) # 1, dim, H/16, W/16 245 | 246 | # memory select kv:(1, K, C, T, H, W) 247 | m4 = self.Memory(keys.squeeze(1), values.squeeze(1), k4, v4) 248 | logits = self.Decoder(m4, r3, r2) 249 | 250 | logit = logits 251 | 252 | if pad[2]+pad[3] > 0: 253 | logit = logit[:,:,pad[2]:-pad[3],:] 254 | if pad[0]+pad[1] > 0: 255 | logit = logit[:,:,:,pad[0]:-pad[1]] 256 | 257 | return logit 258 | 259 | def forward(self, *args, **kwargs): 260 | if args[1].dim() > 4: # keys 261 | return self.segment(*args, **kwargs) 262 | else: 263 | return self.memorize(*args, **kwargs) 264 | 265 | -------------------------------------------------------------------------------- /models/trimap/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/models/trimap/__init__.py -------------------------------------------------------------------------------- /models/trimap/model.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | # general libs 7 | import sys 8 | 9 | sys.path.insert(0, '../') 10 | from helpers import * 11 | 12 | from .STM import STM 13 | 14 | 15 | class FullModel(nn.Module): 16 | def __init__(self, dilate_kernel=None, eps=0, ignore_label=255, 17 | stage=1, 18 | hdim=-1,): 19 | super(FullModel, self).__init__() 20 | self.DILATION_KERNEL = dilate_kernel 21 | self.EPS = eps 22 | self.IMG_SCALE = 1./255 23 | self.register_buffer('IMG_MEAN', torch.tensor([0.485, 0.456, 0.406]).reshape([1, 1, 3, 1, 1]).float()) 24 | self.register_buffer('IMG_STD', torch.tensor([0.229, 0.224, 0.225]).reshape([1, 1, 3, 1, 1]).float()) 25 | 26 | self.stage = stage 27 | self.hdim = hdim if self.stage > 2 else -1 28 | self.memory_update = False 29 | 30 | self.model = STM(hdim=self.hdim) 31 | 32 | self.num_object = 1 33 | 34 | self.ignore_label = ignore_label 35 | self.LOSS = nn.CrossEntropyLoss(weight=torch.tensor([1, 1, 1]).float(), ignore_index=ignore_label) 36 | 37 | def make_trimap(self, alpha, ignore_region): 38 | b = alpha.shape[0] 39 | alpha = torch.where(alpha < self.EPS, torch.zeros_like(alpha), alpha) 40 | alpha = torch.where(alpha > 1 - self.EPS, torch.ones_like(alpha), alpha) 41 | trimasks = ((alpha > 0) & (alpha < 1.)).float().split(1) 42 | trimaps = [None] * b 43 | for i in range(b): 44 | # trimap width: 1 - 51 45 | kernel_rad = int(torch.randint(0, 26, size=())) \ 46 | if self.DILATION_KERNEL is None else self.DILATION_KERNEL 47 | trimaps[i] = F.max_pool2d(trimasks[i].squeeze(0), kernel_size=kernel_rad*2+1, stride=1, padding=kernel_rad) 48 | trimap = torch.stack(trimaps) 49 | # 0: bg, 1: un, 2: fg 50 | trimap1 = torch.where(trimap > 0.5, torch.ones_like(alpha), 2 * alpha).long() 51 | if ignore_region is not None: 52 | 
trimap1[ignore_region] = 0 53 | trimap3 = F.one_hot(trimap1.squeeze(2), num_classes=3).permute(0, 1, 4, 2, 3) 54 | return trimap3.float() 55 | 56 | def preprocess(self, a, fg, bg, ignore_region=None, tri=None): 57 | # Data preprocess 58 | with torch.no_grad(): 59 | scaled_gts = a 60 | scaled_fgs = fg.flip([2]) * self.IMG_SCALE 61 | if bg is None: 62 | scaled_bgs = scaled_fgs 63 | scaled_imgs = scaled_fgs 64 | else: 65 | scaled_bgs = bg.flip([2]) * self.IMG_SCALE 66 | scaled_imgs = scaled_fgs * scaled_gts + scaled_bgs * (1. - scaled_gts) 67 | 68 | if tri is None: 69 | scaled_tris = self.make_trimap(scaled_gts, ignore_region) 70 | else: 71 | scaled_tris = tri 72 | imgs = scaled_imgs 73 | return scaled_imgs, scaled_fgs, scaled_bgs, scaled_gts, scaled_tris, imgs 74 | 75 | def _forward(self, imgs, tris, alpha, masks=None, og_shape=None): 76 | if self.stage == 1: 77 | batch_size, sample_length = imgs.shape[:2] 78 | num_object = torch.tensor([self.num_object]).to(torch.cuda.current_device()) 79 | GT = tris.split(1, dim=0) # [1, S, C, H, W] 80 | FG = imgs.split(1, dim=0) # [1, S, C, H, W] 81 | 82 | if masks is not None: 83 | M = masks.squeeze(2).split(1, dim=1) 84 | E = [] 85 | E_logits = [] 86 | # we split batch here since the original code only supports b=1 87 | for b in range(batch_size): 88 | Fs = FG[b].split(1, dim=1) # [1, 1, C, H, W] 89 | GTs = GT[b].split(1, dim=1) # [1, 1, C, H, W] 90 | Es = [GTs[0].squeeze(1)] + [None] * (sample_length - 1) # [1, C, H, W] 91 | ELs = [] 92 | for t in range(1, sample_length): 93 | input_Es = Es[t-1] 94 | # memorize 95 | prev_key, prev_value = self.model(Fs[t-1].squeeze(1), input_Es, num_object) 96 | 97 | if t-1 == 0: # 98 | this_keys, this_values = prev_key, prev_value # only prev memory 99 | else: 100 | this_keys = torch.cat([keys, prev_key], dim=3) 101 | this_values = torch.cat([values, prev_value], dim=3) 102 | 103 | # segment 104 | logit = self.model(Fs[t].squeeze(1), this_keys, this_values, num_object) 105 | ELs.append(logit) 106 | Es[t] = F.softmax(logit, dim=1) 107 | 108 | # update 109 | keys, values = this_keys, this_values 110 | E.append(torch.cat(Es, dim=0)) # cat t 111 | E_logits.append(torch.cat(ELs, dim=0)) 112 | 113 | pred = torch.stack(E, dim=0) # stack b 114 | E_logits = [None] + list(torch.stack(E_logits).split(1, dim=1)) 115 | GT = torch.argmax(tris, dim=2) 116 | # Loss & Vis 117 | losses = [] 118 | for t in range(1, sample_length): 119 | gt = GT[:,t].squeeze(1) 120 | p = E_logits[t].squeeze(1) 121 | if og_shape is not None: 122 | for b in range(batch_size): 123 | h, w = og_shape[b] 124 | gt[b, h:] = self.ignore_label 125 | gt[b, :, w:] = self.ignore_label 126 | if masks is not None: 127 | mask = M[t].squeeze(1) 128 | gt = torch.where(mask == 0, torch.ones_like(gt) * self.ignore_label, gt) 129 | losses.append(self.LOSS(p, gt)) 130 | loss = sum(losses) / float(len(losses)) 131 | return pred, loss 132 | 133 | def _forward_single_step(self, img_q, img, tri, alpha, hid, memories=None): 134 | num_object = torch.tensor([self.num_object]).to(torch.cuda.current_device()) 135 | # we split batch here since the original code only supports b=1 136 | if self.hdim > 0: 137 | Es = torch.cat([tri, alpha, hid], dim=1) 138 | else: 139 | Es = tri 140 | # memorize 141 | prev_key, prev_value = self.model(img, Es, num_object) 142 | 143 | # update 144 | if memories is None: 145 | memories = dict() 146 | memories['key'] = prev_key 147 | memories['val'] = prev_value 148 | else: 149 | memories['key'] = torch.cat([memories['key'], prev_key], dim=3) 150 | 
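            # dim=3 is the temporal axis of the memory bank (Pad_memory reshapes each
            # entry to [B, num_objects, C, T, H, W]), so each call appends the newly
            # memorized frame's key/value to the existing memory.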
memories['val'] = torch.cat([memories['val'], prev_value], dim=3) 151 | 152 | # segment 153 | logit = self.model(img_q, memories['key'], memories['val'], num_object) 154 | return logit, memories 155 | 156 | def forward(self, a, fg, bg, ignore_region=None, tri=None, og_shape=None, 157 | single_step=False, hid=None, memories=None): 158 | if single_step: 159 | # fg: query frame (normalized between 0~1) [B, 3, H, W] 160 | # bg: prev frame (normalized between 0~1) [B, 3, H, W] 161 | # tri: prev trimap (normalized between 0~1) [B, 3, H, W] 162 | # a: prev alpha (normalized between 0~1) [B, 1, H, W] 163 | logit, memories = self._forward_single_step(fg, bg, tri, a, hid, memories=memories) 164 | return logit, memories 165 | else: 166 | scaled_imgs, _, _, scaled_gts, tris, imgs = self.preprocess(a, fg, bg, ignore_region=ignore_region, tri=tri) 167 | 168 | pred, loss = self._forward(imgs, tris, scaled_gts, og_shape=og_shape) 169 | 170 | return [loss, scaled_imgs, pred, tris, scaled_gts] 171 | 172 | 173 | class FullModel_eval(FullModel): 174 | def _forward(self, imgs, tris, first_frame=False, masks=None, og_shape=None, save_memory=False, max_memory_num=2, memorize_gt=False): 175 | if self.stage == 1: 176 | num_object = torch.tensor([self.num_object]).to(torch.cuda.current_device()) 177 | 178 | Fs = imgs 179 | 180 | if first_frame: 181 | Es = tris 182 | pred = Es 183 | else: 184 | logit = self.model(Fs, self.this_keys, self.this_values, num_object, memory_update=self.memory_update,) 185 | Es = F.softmax(logit, dim=1) 186 | pred = Es 187 | 188 | if save_memory and memorize_gt: 189 | Es = tris 190 | pred = tris 191 | prev_key, prev_value = self.model(Fs, Es, num_object) 192 | 193 | if max_memory_num == 0: 194 | if first_frame: 195 | self.this_keys = prev_key 196 | self.this_values = prev_value 197 | elif max_memory_num == 1: 198 | self.this_keys = prev_key 199 | self.this_values = prev_value 200 | else: 201 | if first_frame: 202 | self.this_keys = prev_key 203 | self.this_values = prev_value 204 | elif save_memory: 205 | self.this_keys = torch.cat([self.this_keys, prev_key], dim=3) 206 | self.this_values = torch.cat([self.this_values, prev_value], dim=3) 207 | else: 208 | if self.this_keys.size(3) == 1: 209 | self.this_keys = torch.cat([self.this_keys, prev_key], dim=3) 210 | self.this_values = torch.cat([self.this_values, prev_value], dim=3) 211 | else: 212 | self.this_keys = torch.cat([self.this_keys[:,:,:,:-1], prev_key], dim=3) 213 | self.this_values = torch.cat([self.this_values[:,:,:,:-1], prev_value], dim=3) 214 | 215 | if self.this_keys.size(3) > max_memory_num: 216 | if memorize_gt: 217 | self.this_keys = self.this_keys[:,:,:,1:] 218 | self.this_values = self.this_values[:,:,:,1:] 219 | else: 220 | self.this_keys = torch.cat([self.this_keys[:,:,:,:1], self.this_keys[:,:,:,2:]], dim=3) 221 | self.this_values = torch.cat([self.this_values[:,:,:,:1], self.this_values[:,:,:,2:]], dim=3) 222 | 223 | self.memory_update = save_memory 224 | 225 | return pred.unsqueeze(1), 0 226 | 227 | def _forward_memorize(self, img, tri, alpha, hid): 228 | num_object = torch.tensor([self.num_object]).to(torch.cuda.current_device()) 229 | # we split batch here since the original code only supports b=1 230 | if self.hdim > 0: 231 | Es = torch.cat([tri, alpha, hid], dim=1) 232 | else: 233 | Es = tri 234 | # memorize 235 | prev_key, prev_value = self.model(img, Es, num_object) 236 | memories = {'key': prev_key, 237 | 'val': prev_value, 238 | } 239 | return memories 240 | 241 | def _forward_segment(self, img_q, 
memories=None, memory_update=False): 242 | num_object = torch.tensor([self.num_object]).to(torch.cuda.current_device()) 243 | # segment 244 | logit = self.model(img_q, memories['key'], memories['val'], num_object) 245 | return logit 246 | 247 | def forward(self, a, fg, bg, tri=None, first_frame=False, og_shape=None, 248 | memorize=False, segment=False, memories=None, hid=None, 249 | save_memory=False, max_memory_num=2, memory_update=False, 250 | memorize_gt=False,): 251 | if memorize: 252 | # fg: query frame (normalized between 0~1) [B, 3, H, W] 253 | # bg: prev frame (normalized between 0~1) [B, 3, H, W] 254 | # tri: prev trimap (normalized between 0~1) [B, 3, H, W] 255 | # a: prev alpha (normalized between 0~1) [B, 1, H, W] 256 | memories = self._forward_memorize(bg, tri, a, hid) 257 | return memories 258 | elif segment: 259 | # fg: query frame (normalized between 0~1) [B, 3, H, W] 260 | # bg: prev frame (normalized between 0~1) [B, 3, H, W] 261 | # tri: prev trimap (normalized between 0~1) [B, 3, H, W] 262 | # a: prev alpha (normalized between 0~1) [B, 1, H, W] 263 | logit = self._forward_segment(fg, memories=memories) 264 | return logit 265 | else: 266 | scaled_imgs, _, _, scaled_gts, tris, imgs = self.preprocess(a, fg, bg) 267 | if tri is not None: 268 | tris = tri 269 | imgs_fw_HR = imgs.squeeze(0) 270 | tris_fw = tris.squeeze(0) 271 | _, _, H, W = imgs_fw_HR.shape 272 | 273 | imgs_fw = imgs_fw_HR 274 | 275 | pred, loss = self._forward(imgs_fw, tris_fw, first_frame=first_frame, og_shape=og_shape, save_memory=save_memory, max_memory_num=max_memory_num, memorize_gt=memorize_gt,) 276 | 277 | if first_frame: 278 | pred = tris 279 | 280 | return [loss, 281 | scaled_imgs, pred, tris, scaled_gts] 282 | -------------------------------------------------------------------------------- /scripts/eval_s4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | GPU=$1 4 | 5 | python eval.py --gpu $GPU -------------------------------------------------------------------------------- /scripts/eval_s4_demo.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | GPU=$1 4 | 5 | python eval.py --gpu $GPU --demo -------------------------------------------------------------------------------- /scripts/train_s1_alpha.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | GPUS=$1 3 | GPUS_ARRAY=($(echo $GPUS | tr ',' "\n")) 4 | NUMBER_OF_CUDA_DEVICES=${#GPUS_ARRAY[@]} 5 | if [ $NUMBER_OF_CUDA_DEVICES -gt 1 ]; then 6 | echo "Training with multiple GPUs: $GPUS" 7 | PY_CMD="-m torch.distributed.launch --nproc_per_node=$NUMBER_OF_CUDA_DEVICES --master_port $((RANDOM + 66000))" 8 | else 9 | echo "Training with a single GPU: $GPUS" 10 | PY_CMD="" 11 | fi 12 | 13 | python $PY_CMD train.py --stage 1 --gpu $GPUS -------------------------------------------------------------------------------- /scripts/train_s1_trimap.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | GPUS=$1 3 | 4 | python train_s1_trimap.py --gpu $GPUS 5 | -------------------------------------------------------------------------------- /scripts/train_s2_alpha.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | GPUS=$1 3 | GPUS_ARRAY=($(echo $GPUS | tr ',' "\n")) 4 | NUMBER_OF_CUDA_DEVICES=${#GPUS_ARRAY[@]} 5 | if [ $NUMBER_OF_CUDA_DEVICES -gt 1 ]; then 6 | echo "Training with multiple GPUs: 
$GPUS" 7 | PY_CMD="-m torch.distributed.launch --nproc_per_node=$NUMBER_OF_CUDA_DEVICES --master_port $((RANDOM + 66000))" 8 | else 9 | echo "Training with a single GPU: $GPUS" 10 | PY_CMD="" 11 | fi 12 | 13 | python $PY_CMD train.py --stage 2 --gpu $GPUS 14 | -------------------------------------------------------------------------------- /scripts/train_s3.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | GPUS=$1 3 | GPUS_ARRAY=($(echo $GPUS | tr ',' "\n")) 4 | NUMBER_OF_CUDA_DEVICES=${#GPUS_ARRAY[@]} 5 | if [ $NUMBER_OF_CUDA_DEVICES -gt 1 ]; then 6 | echo "Training with multiple GPUs: $GPUS" 7 | PY_CMD="-m torch.distributed.launch --nproc_per_node=$NUMBER_OF_CUDA_DEVICES --master_port $((RANDOM + 66000))" 8 | else 9 | echo "Training with a single GPU: $GPUS" 10 | PY_CMD="" 11 | fi 12 | 13 | python $PY_CMD train.py --stage 3 --gpu $GPUS 14 | -------------------------------------------------------------------------------- /scripts/train_s4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | GPUS=$1 3 | GPUS_ARRAY=($(echo $GPUS | tr ',' "\n")) 4 | NUMBER_OF_CUDA_DEVICES=${#GPUS_ARRAY[@]} 5 | if [ $NUMBER_OF_CUDA_DEVICES -gt 1 ]; then 6 | echo "Training with multiple GPUs: $GPUS" 7 | PY_CMD="-m torch.distributed.launch --nproc_per_node=$NUMBER_OF_CUDA_DEVICES --master_port $((RANDOM + 66000))" 8 | else 9 | echo "Training with a single GPU: $GPUS" 10 | PY_CMD="" 11 | fi 12 | 13 | python $PY_CMD train.py --stage 4 --gpu $GPUS 14 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from io import UnsupportedOperation 3 | import logging 4 | import os 5 | import shutil 6 | import time 7 | import timeit 8 | import shutil 9 | 10 | import numpy as np 11 | import cv2 as cv 12 | import torch 13 | import torch.backends.cudnn as cudnn 14 | import torch.distributed as torch_dist 15 | import torch.nn.functional as F 16 | from torch import nn 17 | from torch.utils import data 18 | from torchvision.utils import save_image 19 | 20 | from config import get_cfg_defaults 21 | from dataset import DIM_Train, VideoMatting108_Train 22 | from helpers import * 23 | from utils.optimizer import RAdam 24 | 25 | def parse_args(): 26 | parser = argparse.ArgumentParser(description='Train network') 27 | parser.add_argument("--stage", type=int, default=1) 28 | parser.add_argument("--gpu", type=str, default='0,1,2,3') 29 | parser.add_argument("--local_rank", type=int, default=-1) 30 | 31 | args = parser.parse_args() 32 | 33 | cfg = get_cfg_defaults() 34 | cfg.TRAIN.STAGE = args.stage 35 | cfg.freeze() 36 | 37 | return args, cfg 38 | 39 | def main(args, cfg): 40 | MODEL = get_model_name(cfg) 41 | random_seed = cfg.SYSTEM.RANDOM_SEED 42 | base_lr = cfg.TRAIN.BASE_LR 43 | 44 | weight_decay = cfg.TRAIN.WEIGHT_DECAY 45 | output_dir = os.path.join(cfg.SYSTEM.OUTDIR, 'checkpoint') 46 | if args.local_rank <= 0: 47 | os.makedirs(output_dir, exist_ok=True) 48 | start = timeit.default_timer() 49 | # cudnn related setting 50 | cudnn.benchmark = cfg.SYSTEM.CUDNN_BENCHMARK 51 | cudnn.deterministic = cfg.SYSTEM.CUDNN_DETERMINISTIC 52 | cudnn.enabled = cfg.SYSTEM.CUDNN_ENABLED 53 | if random_seed > 0: 54 | import random 55 | if args.local_rank <= 0: 56 | print('Seeding with', random_seed) 57 | random.seed(random_seed+args.local_rank) 58 | torch.manual_seed(random_seed+args.local_rank) 59 | 60 
| args.world_size = 1 61 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 62 | if args.local_rank >= 0: 63 | device = torch.device('cuda:{}'.format(args.local_rank)) 64 | torch.cuda.set_device(device) 65 | torch.distributed.init_process_group( 66 | backend="nccl", init_method="env://", 67 | ) 68 | args.world_size = torch.distributed.get_world_size() 69 | else: 70 | if torch.cuda.is_available(): 71 | print('using Cuda devices, num:', torch.cuda.device_count()) 72 | 73 | if args.local_rank <= 0: 74 | logger, final_output_dir = create_logger(output_dir, MODEL, 'train') 75 | print(cfg) 76 | with open(os.path.join(final_output_dir, 'config.yaml'), 'w') as f: 77 | f.write(str(cfg)) 78 | image_outdir = os.path.join(final_output_dir, 'training_images') 79 | os.makedirs(os.path.join(final_output_dir, 'training_images'), exist_ok=True) 80 | else: 81 | image_outdir = None 82 | 83 | if cfg.TRAIN.STAGE == 1: 84 | model_trimap = None 85 | else: 86 | model_trimap = get_model_trimap(cfg, mode='Train') 87 | model = get_model_alpha(cfg, model_trimap, mode='Train') 88 | 89 | 90 | if cfg.TRAIN.STAGE == 1: 91 | load_ckpt = './weights/FBA.pth' 92 | dct = torch.load(load_ckpt, map_location=torch.device('cpu')) 93 | if 'state_dict' in dct.keys(): 94 | dct = dct['state_dict'] 95 | missing_keys, unexpected_keys = model.NET.load_state_dict(dct, strict=False) 96 | if args.local_rank <= 0: 97 | logger.info('Missing keys: ' + str(sorted(missing_keys))) 98 | logger.info('Unexpected keys: ' + str(sorted(unexpected_keys))) 99 | logger.info("=> loaded checkpoint from Image Matting Weight: {}".format(load_ckpt)) 100 | elif cfg.TRAIN.STAGE in [2,3]: 101 | load_ckpt = './weights/s1_OTVM_trimap.pth' 102 | dct = torch.load(load_ckpt, map_location=torch.device('cpu')) 103 | missing_keys, unexpected_keys = model.trimap.model.load_state_dict(dct, strict=False) 104 | if args.local_rank <= 0: 105 | logger.info('Missing keys: ' + str(sorted(missing_keys))) 106 | logger.info('Unexpected keys: ' + str(sorted(unexpected_keys))) 107 | logger.info("=> loaded checkpoint from Pretrained STM Weight: {}".format(load_ckpt)) 108 | 109 | if cfg.TRAIN.STAGE == 2: 110 | load_ckpt = './weights/s1_OTVM_alpha.pth' 111 | elif cfg.TRAIN.STAGE == 3: 112 | load_ckpt = './weights/s2_OTVM_alpha.pth' 113 | dct = torch.load(load_ckpt, map_location=torch.device('cpu')) 114 | missing_keys, unexpected_keys = model.NET.load_state_dict(dct, strict=False) 115 | if args.local_rank <= 0: 116 | logger.info('Missing keys: ' + str(sorted(missing_keys))) 117 | logger.info('Unexpected keys: ' + str(sorted(unexpected_keys))) 118 | elif cfg.TRAIN.STAGE == 4: 119 | load_ckpt = './weights/s3_OTVM.pth' 120 | dct = torch.load(load_ckpt, map_location=torch.device('cpu')) 121 | model.load_state_dict(dct) 122 | 123 | torch_barrier() 124 | 125 | ADDITIONAL_INPUTS = dict() 126 | 127 | start_epoch = 0 128 | 129 | if args.local_rank >= 0: 130 | # FBA particularly uses batch_size == 1, thus no syncbn here 131 | if (not cfg.ALPHA.MODEL.endswith('fba')) and (not cfg.TRAIN.FREEZE_BN): 132 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 133 | model = model.to(device) 134 | find_unused_parameters = False 135 | if cfg.TRAIN.STAGE == 2: 136 | find_unused_parameters = True 137 | model = torch.nn.parallel.DistributedDataParallel( 138 | model, 139 | find_unused_parameters=find_unused_parameters, 140 | device_ids=[args.local_rank], 141 | output_device=args.local_rank, 142 | ) 143 | else: 144 | model = torch.nn.DataParallel(model).cuda() 145 | 146 | if cfg.TRAIN.STAGE in [2,3]: 147 
| params = list() 148 | for k, v in model.named_parameters(): 149 | if v.requires_grad: 150 | _k = k[7:] # remove 'module.' 151 | if _k.startswith('NET.'): 152 | if cfg.TRAIN.STAGE == 3: 153 | if args.local_rank <= 0: 154 | logging.info('do NOT train parameter: %s'%(k)) 155 | pass 156 | else: 157 | params.append({'params': v, 'lr': base_lr}) 158 | elif _k.startswith('trimap.'): 159 | if cfg.TRAIN.STAGE == 2: 160 | if args.local_rank <= 0: 161 | logging.info('do NOT train parameter: %s'%(k)) 162 | pass 163 | else: 164 | params.append({'params': v, 'lr': base_lr}) 165 | else: 166 | if args.local_rank <= 0: 167 | logging.info('%s: Undefined parameter'%(k)) 168 | params.append({'params': v, 'lr': base_lr}) 169 | else: 170 | params_dict = {k: v for k, v in model.named_parameters() if v.requires_grad} 171 | params = [{'params': list(params_dict.values()), 'lr': base_lr}] 172 | 173 | params_count = 0 174 | if args.local_rank <= 0: 175 | logging.info('=> Parameters needs to be optimized:') 176 | for param in params: 177 | _param = param['params'] 178 | if type(_param) is list: 179 | for _p in _param: 180 | params_count += _p.shape.numel() 181 | else: 182 | params_count += _param.shape.numel() 183 | logging.info('=> Total Parameters: {}'.format(params_count)) 184 | 185 | 186 | if cfg.TRAIN.OPTIMIZER == 'adam': 187 | optimizer = torch.optim.Adam(params, lr=base_lr) 188 | elif cfg.TRAIN.OPTIMIZER == 'radam': 189 | optimizer = RAdam(params, lr=base_lr, weight_decay=weight_decay) 190 | 191 | if cfg.TRAIN.LR_STRATEGY == 'stair': 192 | adjust_lr = stair_lr 193 | elif cfg.TRAIN.LR_STRATEGY == 'poly': 194 | adjust_lr = poly_lr 195 | elif cfg.TRAIN.LR_STRATEGY == 'const': 196 | adjust_lr = const_lr 197 | else: 198 | raise NotImplementedError('[%s] is not supported in cfg.TRAIN.LR_STRATEGY'%(cfg.TRAIN.LR_STRATEGY)) 199 | 200 | total_epochs = cfg.TRAIN.TOTAL_EPOCHS 201 | 202 | sample_length = cfg.TRAIN.FRAME_NUM 203 | if cfg.TRAIN.STAGE == 1: 204 | sample_length = 1 205 | if cfg.TRAIN.STAGE in [1,2,3]: 206 | train_dataset = DIM_Train( 207 | data_root=cfg.DATASET.PATH, 208 | image_shape=cfg.TRAIN.TRAIN_INPUT_SIZE, 209 | mode='train', 210 | sample_length=sample_length, 211 | ) 212 | else: 213 | train_dataset = VideoMatting108_Train( 214 | data_root=cfg.DATASET.PATH, 215 | image_shape=cfg.TRAIN.TRAIN_INPUT_SIZE, 216 | mode='train', 217 | sample_length=sample_length, 218 | max_skip=15, 219 | do_affine=0.5, 220 | do_time_flip=0.5, 221 | ) 222 | 223 | if cfg.SYSTEM.TESTMODE: 224 | start_epoch = max(start_epoch, total_epochs - 1) 225 | for epoch in range(start_epoch, total_epochs): 226 | train(epoch, cfg, args, train_dataset, base_lr, start_epoch, total_epochs, 227 | optimizer, model, adjust_lr, image_outdir, MODEL, 228 | ADDITIONAL_INPUTS) 229 | if args.local_rank <= 0: 230 | if (((epoch+1) % cfg.TRAIN.SAVE_EVERY_EPOCH) == 0) or ((epoch+1) == total_epochs): 231 | weight_fn = os.path.join(final_output_dir, 'checkpoint_{}.pth'.format(epoch+1)) 232 | logger.info('=> saving checkpoint to {}'.format(weight_fn)) 233 | if cfg.TRAIN.STAGE in [1,2]: 234 | torch.save(model.module.NET.state_dict(), weight_fn) 235 | else: 236 | torch.save(model.module.state_dict(), weight_fn) 237 | optim_fn = os.path.join(final_output_dir, 'optim_{}.pth'.format(epoch+1)) 238 | torch.save(optimizer.state_dict(), optim_fn) 239 | 240 | if args.local_rank <= 0: 241 | weight_fn = os.path.join('weights', '{:s}.pth'.format(MODEL)) 242 | logger.info('=> saving checkpoint to {}'.format(weight_fn)) 243 | if cfg.TRAIN.STAGE in [1,2]: 244 | 
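            # Stages 1 and 2 save only the alpha sub-network (model.module.NET);
            # from stage 3 onward the full state dict (trimap + alpha) is saved.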
torch.save(model.module.NET.state_dict(), weight_fn) 245 | else: 246 | torch.save(model.module.state_dict(), weight_fn) 247 | 248 | end = timeit.default_timer() 249 | if args.local_rank <= 0: 250 | logger.info('Time: %d sec.' % np.int32((end-start))) 251 | logger.info('Done') 252 | 253 | 254 | 255 | def write_image(outdir, out, step, max_batch=1, trimap=False): 256 | with torch.no_grad(): 257 | scaled_imgs, scaled_tris, alphas, comps, gts, fgs, bgs = out[:7] 258 | if trimap: 259 | pred_tris = out[7] 260 | b, s, _, h, w = scaled_imgs.shape 261 | b = max_batch if b > max_batch else b 262 | img_list = list() 263 | img_list.append(scaled_imgs[:max_batch].reshape(b*s, 3, h, w)) 264 | img_list.append(scaled_tris[:max_batch].reshape(b*s, 1, h, w).expand(-1, 3, -1, -1)) 265 | img_list.append(gts[:max_batch].reshape(b*s, 1, h, w).expand(-1, 3, -1, -1)) 266 | img_list.append(alphas[:max_batch].reshape(b*s, 1, h, w).expand(-1, 3, -1, -1)) 267 | if trimap: 268 | img_list.append(pred_tris[:max_batch].reshape(b*s, 3, h, w)) 269 | img_list.append(comps[:max_batch].reshape(b*s, 3, h, w)) 270 | img_list.append(fgs[:max_batch].reshape(b*s, 3, h, w)) 271 | img_list.append(bgs[:max_batch].reshape(b*s, 3, h, w)) 272 | imgs = torch.cat(img_list, dim=0).reshape(-1, 3, h, w) 273 | if h > 320: 274 | imgs = F.interpolate(imgs, scale_factor=320/h) 275 | save_image(imgs, os.path.join(outdir, '{}.png'.format(step)), nrow=int(s*b)) 276 | 277 | def train(epoch, cfg, args, train_dataset, base_lr, start_epoch, total_epochs, 278 | optimizer, model, adjust_learning_rate, image_outdir, MODEL, 279 | ADDITIONAL_INPUTS): 280 | # Training 281 | torch.cuda.empty_cache() 282 | if cfg.TRAIN.STAGE in [1,2,3]: 283 | train_dataset_concat = [train_dataset] * 20 284 | else: 285 | if epoch < 100: 286 | SKIP = min(1+(epoch//5), 25) 287 | else: 288 | SKIP = max(44-(epoch//5), 10) 289 | train_dataset.max_skip = SKIP 290 | train_dataset_concat = [train_dataset] * 20 291 | 292 | train_dataset = data.ConcatDataset(train_dataset_concat) 293 | train_sampler = get_sampler(train_dataset) 294 | trainloader = torch.utils.data.DataLoader( 295 | train_dataset, 296 | batch_size=int(cfg.TRAIN.BATCH_SIZE // args.world_size), 297 | num_workers=cfg.SYSTEM.NUM_WORKERS, 298 | pin_memory=True, 299 | drop_last=True, 300 | shuffle=True if train_sampler is None else False, 301 | sampler=train_sampler) 302 | 303 | if args.local_rank >= 0: 304 | train_sampler.set_epoch(epoch) 305 | 306 | iters_per_epoch = len(trainloader) 307 | image_freq = cfg.TRAIN.IMAGE_FREQ if cfg.TRAIN.IMAGE_FREQ > 0 else 1e+8 308 | image_freq = min(image_freq, iters_per_epoch) 309 | 310 | # STM DISABLES BN DURING TRAINING 311 | model.train() 312 | if cfg.TRAIN.STAGE > 1: 313 | for m in model.module.trimap.modules(): 314 | if isinstance(m, nn.BatchNorm2d): 315 | m.eval() # turn-off BN 316 | if cfg.TRAIN.FREEZE_BN: 317 | for m in model.modules(): 318 | if isinstance(m, nn.BatchNorm2d): 319 | m.eval() # turn-off BN 320 | if cfg.TRAIN.STAGE == 2: 321 | model.module.trimap.eval() 322 | if args.local_rank <= 0: 323 | logging.info('Set trimap model to eval mode') 324 | if cfg.TRAIN.STAGE == 3: 325 | model.module.NET.eval() 326 | if args.local_rank <= 0: 327 | logging.info('Set alpha model to eval mode') 328 | 329 | sub_losses = ['L_alpha', 'L_comp', 'L_grad'] if not cfg.ALPHA.MODEL.endswith('fba') else \ 330 | ['L_alpha_comp', 'L_lap', 'L_grad'] 331 | 332 | data_time = AverageMeter() 333 | losses = AverageMeter() 334 | sub_losses_avg = [AverageMeter() for _ in range(len(sub_losses))] 335 | tic = 
time.time() 336 | cur_iters = epoch*iters_per_epoch 337 | 338 | prefetcher = data_prefetcher(trainloader) 339 | dp = prefetcher.next() 340 | i_iter = 0 341 | while dp[0] is not None: 342 | if cfg.SYSTEM.TESTMODE: 343 | if i_iter > 20: 344 | print() 345 | break 346 | def step(i_iter, dp, tic): 347 | data_time.update(time.time() - tic) 348 | 349 | def handle_batch(): 350 | fg, bg, a, ir, tri, _ = dp # [B, 3, 3 or 1, H, W] 351 | 352 | bg = bg if bg.dim() > 1 else None 353 | a = a if a.dim() > 1 else None 354 | ir = ir if ir.dim() > 1 else None 355 | 356 | out = model(a, fg, bg, ignore_region=ir, tri=tri) 357 | L_alpha = out[0].mean() 358 | L_comp = out[1].mean() 359 | L_grad = out[2].mean() 360 | vis_alpha = L_alpha.detach()#.item() 361 | vis_comp = L_comp.detach()#.item() 362 | vis_grad = L_grad.detach()#.item() 363 | if cfg.TRAIN.STAGE == 1: 364 | loss = L_alpha + L_comp + L_grad 365 | batch_out = [loss.detach(), vis_alpha, vis_comp, vis_grad, out[4:-1]] 366 | else: 367 | L_tri = out[3].mean() 368 | loss = L_alpha + L_comp + L_grad + L_tri 369 | batch_out = [loss.detach(), vis_alpha, vis_comp, vis_grad, out[4:]] 370 | 371 | model.zero_grad() 372 | loss.backward() 373 | optimizer.step() 374 | 375 | return batch_out 376 | 377 | loss, vis_alpha, vis_comp, vis_grad, vis_images = handle_batch() 378 | 379 | reduced_loss = reduce_tensor(loss) 380 | reduced_sub_losses = [reduce_tensor(vis_alpha), reduce_tensor(vis_comp), reduce_tensor(vis_grad)] 381 | 382 | # update average loss 383 | losses.update(reduced_loss.item()) 384 | sub_losses_avg[0].update(reduced_sub_losses[0].item()) 385 | sub_losses_avg[1].update(reduced_sub_losses[1].item()) 386 | sub_losses_avg[2].update(reduced_sub_losses[2].item()) 387 | 388 | torch_barrier() 389 | 390 | current_lr = adjust_learning_rate(optimizer, 391 | base_lr, 392 | total_epochs * iters_per_epoch, 393 | i_iter+cur_iters) 394 | 395 | if args.local_rank <= 0: 396 | progress_bar(i_iter, iters_per_epoch, epoch, start_epoch, total_epochs, 'finetuning', 397 | 'Data: {data_time} | ' 398 | 'Loss: {loss.val:.4f} ({loss.avg:.4f}) | ' 399 | '{sub_losses[0]}: {sub_losses_avg[0].val:.4f} ({sub_losses_avg[0].avg:.4f})'.format( 400 | data_time=format_time(data_time.sum), 401 | loss=losses, 402 | sub_losses=sub_losses, 403 | sub_losses_avg=sub_losses_avg)) 404 | 405 | if i_iter % image_freq == 0 and args.local_rank <= 0: 406 | write_image(image_outdir, vis_images, i_iter+cur_iters, trimap=(cfg.TRAIN.STAGE > 1)) 407 | return current_lr 408 | 409 | current_lr = step(i_iter, dp, tic) 410 | tic = time.time() 411 | 412 | dp = prefetcher.next() 413 | i_iter += 1 414 | 415 | if args.local_rank <= 0: 416 | logger_str = '{:s} | E [{:d}] | I [{:d}] | LR [{:.1e}] | Total Loss:{: 4.6f}'.format( 417 | MODEL, epoch+1, i_iter+1, current_lr, losses.avg) 418 | logger_str += ' | {} [{: 4.6f}] | {} [{: 4.6f}] | {} [{: 4.6f}]'.format( 419 | sub_losses[0], sub_losses_avg[0].avg, 420 | sub_losses[1], sub_losses_avg[1].avg, 421 | sub_losses[2], sub_losses_avg[2].avg) 422 | logging.info(logger_str) 423 | 424 | class data_prefetcher(): 425 | def __init__(self, loader): 426 | self.loader = iter(loader) 427 | self.stream = torch.cuda.Stream() 428 | self.preload() 429 | 430 | def preload(self): 431 | try: 432 | self.next_fg, self.next_bg, self.next_a, self.next_ir, self.next_tri, self.next_idx = next(self.loader) 433 | except StopIteration: 434 | self.next_fg = None 435 | self.next_bg = None 436 | self.next_a = None 437 | self.next_ir = None 438 | self.next_tri = None 439 | self.next_idx = None 440 | 
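            # On StopIteration every next_* field stays None, so next() hands back a
            # None batch and the training loop's `while dp[0] is not None` check ends the epoch.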
return 441 | with torch.cuda.stream(self.stream): 442 | self.next_fg = self.next_fg.cuda(non_blocking=True) 443 | self.next_bg = self.next_bg.cuda(non_blocking=True) 444 | self.next_a = self.next_a.cuda(non_blocking=True) 445 | self.next_ir = self.next_ir.cuda(non_blocking=True) 446 | self.next_tri = self.next_tri.cuda(non_blocking=True) 447 | self.next_idx = self.next_idx.cuda(non_blocking=True) 448 | 449 | def next(self): 450 | torch.cuda.current_stream().wait_stream(self.stream) 451 | fg = self.next_fg 452 | bg = self.next_bg 453 | a = self.next_a 454 | ir = self.next_ir 455 | tri = self.next_tri 456 | idx = self.next_idx 457 | if fg is not None: 458 | fg.record_stream(torch.cuda.current_stream()) 459 | if bg is not None: 460 | bg.record_stream(torch.cuda.current_stream()) 461 | if a is not None: 462 | a.record_stream(torch.cuda.current_stream()) 463 | if ir is not None: 464 | ir.record_stream(torch.cuda.current_stream()) 465 | if tri is not None: 466 | tri.record_stream(torch.cuda.current_stream()) 467 | if idx is not None: 468 | idx.record_stream(torch.cuda.current_stream()) 469 | self.preload() 470 | return fg, bg, a, ir, tri, idx 471 | 472 | 473 | 474 | 475 | def get_sampler(dataset, shuffle=True): 476 | if torch_dist.is_available() and torch_dist.is_initialized(): 477 | from torch.utils.data.distributed import DistributedSampler 478 | return DistributedSampler(dataset, shuffle=shuffle) 479 | else: 480 | return None 481 | 482 | 483 | def IoU(pred, true): 484 | _, _, n_class, _, _ = pred.shape 485 | 486 | _, xx = torch.max(pred, dim=2) 487 | _, yy = torch.max(true, dim=2) 488 | iou = list() 489 | for n in range(n_class): 490 | x = (xx == n).float() 491 | y = (yy == n).float() 492 | 493 | i = torch.sum(torch.sum(x*y, dim=-1), dim=-1) # sum over spatial dims 494 | u = torch.sum(torch.sum((x+y)-(x*y), dim=-1), dim=-1) 495 | 496 | iou.append(((i + 1e-4) / (u + 1e-4)).mean().item() * 100.) 
# b 497 | 498 | # mean over mini-batch 499 | return sum(iou)/n_class, iou 500 | 501 | 502 | if __name__ == "__main__": 503 | args, cfg = parse_args() 504 | main(args, cfg) 505 | -------------------------------------------------------------------------------- /train_s1_trimap.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import time 5 | import timeit 6 | import shutil 7 | 8 | import numpy as np 9 | import torch 10 | import torch.backends.cudnn as cudnn 11 | import torch.nn.functional as F 12 | from torch import nn, optim 13 | from torch.utils import data 14 | from torchvision.utils import save_image 15 | 16 | from config import get_cfg_defaults 17 | from dataset import DIM_Train 18 | from helpers import * 19 | from utils.optimizer import RAdam 20 | 21 | def parse_args(): 22 | parser = argparse.ArgumentParser(description='Train network') 23 | parser.add_argument("--gpu", type=str, default='0,1,2,3') 24 | 25 | args = parser.parse_args() 26 | 27 | cfg = get_cfg_defaults() 28 | cfg.freeze() 29 | 30 | return args, cfg 31 | 32 | 33 | def main(args, cfg): 34 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 35 | if torch.cuda.is_available(): 36 | print('using Cuda devices, num:', torch.cuda.device_count()) 37 | 38 | MODEL = 's1_OTVM_trimap' 39 | random_seed = cfg.SYSTEM.RANDOM_SEED 40 | base_lr = cfg.TRAIN.BASE_LR 41 | weight_decay = cfg.TRAIN.WEIGHT_DECAY 42 | output_dir = os.path.join(cfg.SYSTEM.OUTDIR, 'checkpoint') 43 | os.makedirs(output_dir, exist_ok=True) 44 | start = timeit.default_timer() 45 | # cudnn related setting 46 | cudnn.benchmark = cfg.SYSTEM.CUDNN_BENCHMARK 47 | cudnn.deterministic = cfg.SYSTEM.CUDNN_DETERMINISTIC 48 | cudnn.enabled = cfg.SYSTEM.CUDNN_ENABLED 49 | if random_seed > 0: 50 | import random 51 | print('Seeding with', random_seed) 52 | random.seed(random_seed) 53 | torch.manual_seed(random_seed) 54 | 55 | logger, final_output_dir = create_logger(output_dir, MODEL, 'train') 56 | print(cfg) 57 | with open(os.path.join(final_output_dir, 'config.yaml'), 'w') as f: 58 | f.write(str(cfg)) 59 | image_outdir = os.path.join(final_output_dir, 'training_images') 60 | os.makedirs(os.path.join(final_output_dir, 'training_images'), exist_ok=True) 61 | 62 | model = get_model_trimap(cfg, mode='Train') 63 | torch_barrier() 64 | 65 | start_epoch = 0 66 | 67 | load_ckpt = './weights/STM_weights.pth' 68 | dct = load_NoPrefix(load_ckpt, 7) 69 | missing_keys, unexpected_keys = model.model.load_state_dict(dct, strict=False) 70 | logger.info('Missing keys: ' + str(sorted(missing_keys))) 71 | logger.info('Unexpected keys: ' + str(sorted(unexpected_keys))) 72 | logger.info("=> loaded checkpoint from {}".format(load_ckpt)) 73 | 74 | model = torch.nn.DataParallel(model).cuda() 75 | 76 | # optimizer 77 | params_dict = {k: v for k, v in model.named_parameters() if v.requires_grad} 78 | 79 | params_count = 0 80 | logging.info('=> Parameters needs to be optimized:') 81 | for k in sorted(params_dict): 82 | params_count += params_dict[k].shape.numel() 83 | logging.info('=> Total Parameters: {}'.format(params_count)) 84 | 85 | params = [{'params': list(params_dict.values()), 'lr': base_lr}] 86 | if cfg.TRAIN.OPTIMIZER == 'adam': 87 | optimizer = torch.optim.Adam(params, lr=base_lr) 88 | elif cfg.TRAIN.OPTIMIZER == 'radam': 89 | optimizer = RAdam(params, lr=base_lr, weight_decay=weight_decay) 90 | 91 | if cfg.TRAIN.LR_STRATEGY == 'stair': 92 | adjust_lr = stair_lr 93 | elif cfg.TRAIN.LR_STRATEGY == 'poly': 94 | 
adjust_lr = poly_lr 95 | elif cfg.TRAIN.LR_STRATEGY == 'const': 96 | adjust_lr = const_lr 97 | else: 98 | raise NotImplementedError('[%s] is not supported in cfg.TRAIN.LR_STRATEGY'%(cfg.TRAIN.LR_STRATEGY)) 99 | 100 | total_epochs = cfg.TRAIN.TOTAL_EPOCHS 101 | 102 | train_dataset = DIM_Train( 103 | data_root=cfg.DATASET.PATH, 104 | image_shape=cfg.TRAIN.TRAIN_INPUT_SIZE, 105 | mode='train', 106 | sample_length=3, 107 | ) 108 | train_dataset = [train_dataset] * 20 109 | 110 | train_dataset = data.ConcatDataset(train_dataset) 111 | trainloader = torch.utils.data.DataLoader( 112 | train_dataset, 113 | batch_size=cfg.TRAIN.BATCH_SIZE, 114 | num_workers=cfg.SYSTEM.NUM_WORKERS, 115 | pin_memory=False, 116 | drop_last=True, 117 | shuffle=True) 118 | 119 | if cfg.SYSTEM.TESTMODE: 120 | start_epoch += 199 121 | for epoch in range(start_epoch, total_epochs): 122 | train(epoch, cfg, trainloader, base_lr, start_epoch, total_epochs, 123 | optimizer, model, adjust_lr, image_outdir, MODEL) 124 | 125 | if (((epoch+1) % cfg.TRAIN.SAVE_EVERY_EPOCH) == 0) or ((epoch+1) == total_epochs): 126 | weight_fn = os.path.join(final_output_dir, 'checkpoint_{}.pth'.format(epoch+1)) 127 | logger.info('=> saving checkpoint to {}'.format(weight_fn)) 128 | torch.save(model.module.model.state_dict(), weight_fn) 129 | optim_fn = os.path.join(final_output_dir, 'optim_{}.pth'.format(epoch+1)) 130 | torch.save(optimizer.state_dict(), optim_fn) 131 | 132 | weight_fn = os.path.join('weights', '{:s}.pth'.format(MODEL)) 133 | logger.info('=> saving checkpoint to {}'.format(weight_fn)) 134 | torch.save(model.module.model.state_dict(), weight_fn) 135 | end = timeit.default_timer() 136 | logger.info('Time: %d sec.' % np.int32((end-start))) 137 | logger.info('Done') 138 | 139 | 140 | 141 | def write_image(outdir, out, step, max_batch=1): 142 | with torch.no_grad(): 143 | scaled_imgs, pred, tris, scaled_gts = out 144 | b, s, _, h, w = scaled_imgs.shape 145 | b = max_batch if b > max_batch else b 146 | img_list = list() 147 | img_list.append(scaled_imgs[:max_batch].reshape(b*s, 3, h, w)) 148 | img_list.append(tris[:max_batch].reshape(b*s, 3, h, w)) 149 | img_list.append(pred[:max_batch].reshape(b*s, 3, h, w)) 150 | imgs = torch.cat(img_list, dim=0).reshape(-1, 3, h, w) 151 | if h > 320: 152 | imgs = F.interpolate(imgs, scale_factor=320/h) 153 | save_image(imgs, os.path.join(outdir, '{}.png'.format(step)), nrow=int(s*b)) 154 | 155 | def train(epoch, cfg, trainloader, base_lr, start_epoch, total_epochs, 156 | optimizer, model, adjust_learning_rate, image_outdir, MODEL): 157 | # Training 158 | iters_per_epoch = len(trainloader) 159 | image_freq = cfg.TRAIN.IMAGE_FREQ if cfg.TRAIN.IMAGE_FREQ > 0 else 1e+8 160 | image_freq = min(image_freq, iters_per_epoch) 161 | 162 | # STM DISABLES BN DURING TRAINING 163 | model.train() 164 | for m in model.modules(): 165 | if isinstance(m, nn.BatchNorm2d): 166 | m.eval() # turn-off BN 167 | 168 | data_time = AverageMeter() 169 | losses = AverageMeter() 170 | IOU = AverageMeter() 171 | tic = time.time() 172 | cur_iters = epoch*iters_per_epoch 173 | 174 | prefetcher = data_prefetcher(trainloader) 175 | dp = prefetcher.next() 176 | i_iter = 0 177 | while dp[0] is not None: 178 | data_time.update(time.time() - tic) 179 | if cfg.SYSTEM.TESTMODE: 180 | if i_iter > 20: 181 | print() 182 | break 183 | 184 | def handle_batch(): 185 | fg, bg, a, ir, tri, _ = dp # [B, 3, 3 or 1, H, W] 186 | 187 | bg = bg if bg.dim() > 1 else None 188 | a = a if a.dim() > 1 else None 189 | ir = ir if ir.dim() > 1 else None 190 | 191 | 
out = model(a, fg, bg, ignore_region=ir, tri=tri) 192 | loss = out[0].mean() 193 | 194 | 195 | model.zero_grad() 196 | loss.backward() 197 | optimizer.step() 198 | return loss.detach(), out[1:] 199 | 200 | loss, vis_out = handle_batch() 201 | 202 | reduced_loss = reduce_tensor(loss) 203 | 204 | # update average loss 205 | losses.update(reduced_loss.item()) 206 | 207 | tri_pred = vis_out[1] 208 | tri_gt = vis_out[2] 209 | mIoU, _ = IoU(tri_pred, tri_gt) 210 | IOU.update(mIoU) 211 | torch_barrier() 212 | 213 | current_lr = adjust_learning_rate(optimizer, 214 | base_lr, 215 | total_epochs * iters_per_epoch, 216 | i_iter+cur_iters) 217 | 218 | tic = time.time() 219 | progress_bar(i_iter, iters_per_epoch, epoch, start_epoch, total_epochs, 'finetuning', 220 | 'Data: {data_time} | ' 221 | 'Loss: {loss.val:.4f} ({loss.avg:.4f}) | ' 222 | 'IOU: {IOU.val:.4f} ({IOU.avg:.4f})'.format( 223 | data_time=format_time(data_time.sum), 224 | loss=losses, 225 | IOU=IOU)) 226 | 227 | if i_iter % image_freq == 0: 228 | write_image(image_outdir, vis_out, i_iter+cur_iters) 229 | 230 | dp = prefetcher.next() 231 | i_iter += 1 232 | 233 | logger_str = '{:s} | E [{:d}] | I [{:d}] | LR [{:.1e}] | CE:{: 4.6f} | mIoU:{: 4.6f}' 234 | logger_format = [MODEL, epoch+1, i_iter+1, current_lr, losses.avg, IOU.avg] 235 | logging.info(logger_str.format(*logger_format)) 236 | 237 | class data_prefetcher(): 238 | def __init__(self, loader): 239 | self.loader = iter(loader) 240 | self.stream = torch.cuda.Stream() 241 | self.preload() 242 | 243 | def preload(self): 244 | try: 245 | self.next_fg, self.next_bg, self.next_a, self.next_ir, self.next_tri, self.next_idx = next(self.loader) 246 | except StopIteration: 247 | self.next_fg = None 248 | self.next_bg = None 249 | self.next_a = None 250 | self.next_ir = None 251 | self.next_tri = None 252 | self.next_idx = None 253 | return 254 | with torch.cuda.stream(self.stream): 255 | self.next_fg = self.next_fg.cuda(non_blocking=True) 256 | self.next_bg = self.next_bg.cuda(non_blocking=True) 257 | self.next_a = self.next_a.cuda(non_blocking=True) 258 | self.next_ir = self.next_ir.cuda(non_blocking=True) 259 | self.next_tri = self.next_tri.cuda(non_blocking=True) 260 | self.next_idx = self.next_idx.cuda(non_blocking=True) 261 | 262 | def next(self): 263 | torch.cuda.current_stream().wait_stream(self.stream) 264 | fg = self.next_fg 265 | bg = self.next_bg 266 | a = self.next_a 267 | ir = self.next_ir 268 | tri = self.next_tri 269 | idx = self.next_idx 270 | if fg is not None: 271 | fg.record_stream(torch.cuda.current_stream()) 272 | if bg is not None: 273 | bg.record_stream(torch.cuda.current_stream()) 274 | if a is not None: 275 | a.record_stream(torch.cuda.current_stream()) 276 | if ir is not None: 277 | ir.record_stream(torch.cuda.current_stream()) 278 | if tri is not None: 279 | tri.record_stream(torch.cuda.current_stream()) 280 | if idx is not None: 281 | idx.record_stream(torch.cuda.current_stream()) 282 | self.preload() 283 | return fg, bg, a, ir, tri, idx 284 | 285 | 286 | 287 | def IoU(pred, true): 288 | _, _, n_class, _, _ = pred.shape 289 | 290 | _, xx = torch.max(pred, dim=2) 291 | _, yy = torch.max(true, dim=2) 292 | iou = list() 293 | for n in range(n_class): 294 | x = (xx == n).float() 295 | y = (yy == n).float() 296 | 297 | i = torch.sum(torch.sum(x*y, dim=-1), dim=-1) # sum over spatial dims 298 | u = torch.sum(torch.sum((x+y)-(x*y), dim=-1), dim=-1) 299 | 300 | iou.append(((i + 1e-4) / (u + 1e-4)).mean().item() * 100.) 
# b 301 | 302 | # mean over mini-batch 303 | return sum(iou)/n_class, iou 304 | 305 | 306 | if __name__ == "__main__": 307 | args, cfg = parse_args() 308 | main(args, cfg) 309 | -------------------------------------------------------------------------------- /utils/loss_func.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | def L1_mask(x, y, mask=None, epsilon=1.001e-5, normalize=True): 5 | res = torch.abs(x - y) 6 | b,c,h,w = y.shape 7 | if mask is not None: 8 | res = res * mask 9 | if normalize: 10 | _safe = torch.sum((mask > epsilon).float()).clamp(epsilon, b*c*h*w+1) 11 | return torch.sum(res) / _safe 12 | else: 13 | return torch.sum(res) 14 | if normalize: 15 | return torch.mean(res) 16 | else: 17 | return torch.sum(res) 18 | 19 | 20 | def L1_mask_hard_mining(x, y, mask): 21 | input_size = x.size() 22 | res = torch.sum(torch.abs(x - y), dim=1, keepdim=True) 23 | with torch.no_grad(): 24 | idx = mask > 0.5 25 | res_sort = [torch.sort(res[i, idx[i, ...]])[0] for i in range(idx.shape[0])] 26 | res_sort = [i[int(i.shape[0] * 0.5)].item() for i in res_sort] 27 | new_mask = mask.clone() 28 | for i in range(res.shape[0]): 29 | new_mask[i, ...] = ((mask[i, ...] > 0.5) & (res[i, ...] > res_sort[i])).float() 30 | 31 | res = res * new_mask 32 | final_res = torch.sum(res) / torch.sum(new_mask) 33 | return final_res, new_mask 34 | 35 | def get_gradient(image): 36 | b, c, h, w = image.shape 37 | dy = image[:, :, 1:, :] - image[:, :, :-1, :] 38 | dx = image[:, :, :, 1:] - image[:, :, :, :-1] 39 | 40 | dy = F.pad(dy, (0, 0, 0, 1)) 41 | dx = F.pad(dx, (0, 1, 0, 0)) 42 | return dx, dy 43 | 44 | def L1_grad(pred, gt, mask=None, epsilon=1.001e-5, normalize=True): 45 | fake_grad_x, fake_grad_y = get_gradient(pred) 46 | true_grad_x, true_grad_y = get_gradient(gt) 47 | 48 | mag_fake = torch.sqrt(fake_grad_x ** 2 + fake_grad_y ** 2 + epsilon) 49 | mag_true = torch.sqrt(true_grad_x ** 2 + true_grad_y ** 2 + epsilon) 50 | 51 | return L1_mask(mag_fake, mag_true, mask=mask, normalize=normalize) 52 | 53 | ''' 54 | Ported from https://github.com/ceciliavision/perceptual-reflection-removal/blob/master/main.py 55 | ''' 56 | def exclusion_loss(img1, img2, level, epsilon=1.001e-5, normalize=True): 57 | gradx_loss=[] 58 | grady_loss=[] 59 | for l in range(level): 60 | gradx1, grady1 = get_gradient(img1) 61 | gradx2, grady2 = get_gradient(img2) 62 | 63 | alphax=2.0*torch.mean(torch.abs(gradx1))/(torch.mean(torch.abs(gradx2)) + epsilon) 64 | alphay=2.0*torch.mean(torch.abs(grady1))/(torch.mean(torch.abs(grady2)) + epsilon) 65 | 66 | gradx1_s=(torch.sigmoid(gradx1)*2)-1 67 | grady1_s=(torch.sigmoid(grady1)*2)-1 68 | gradx2_s=(torch.sigmoid(gradx2*alphax)*2)-1 69 | grady2_s=(torch.sigmoid(grady2*alphay)*2)-1 70 | 71 | safe_x = torch.mean((gradx1_s ** 2) * (gradx2_s ** 2), dim=(1,2,3)) + epsilon 72 | safe_y = torch.mean((grady1_s ** 2) * (grady2_s ** 2), dim=(1,2,3)) + epsilon 73 | gradx_loss.append(safe_x ** 0.25) 74 | grady_loss.append(safe_y ** 0.25) 75 | 76 | img1 = F.avg_pool2d(img1, kernel_size=2, stride=2) 77 | img2 = F.avg_pool2d(img2, kernel_size=2, stride=2) 78 | 79 | if normalize: 80 | return torch.mean(sum(gradx_loss) / float(level)) + torch.mean(sum(grady_loss) / float(level)) 81 | else: 82 | return torch.sum(sum(gradx_loss) / float(level)) + torch.sum(sum(grady_loss) / float(level)) 83 | 84 | def sparsity_loss(prediction, trimask, eps=1e-5, gamma=0.9): 85 | mask = trimask > 0.5 86 | pred = prediction[mask] 87 | loss 
= torch.sum(torch.pow(pred+eps, gamma) + torch.pow(1.-pred+eps, gamma) - 1.) 88 | return loss 89 | 90 | ''' 91 | Borrowed from https://gist.github.com/alper111/b9c6d80e2dba1ee0bfac15eb7dad09c8 92 | It directly follows OpenCV's image pyramid implementation pyrDown() and pyrUp(). 93 | Reference: https://docs.opencv.org/4.4.0/d4/d86/group__imgproc__filter.html#gaf9bba239dfca11654cb7f50f889fc2ff 94 | ''' 95 | class LapLoss(torch.nn.Module): 96 | def __init__(self, max_levels=5): 97 | super(LapLoss, self).__init__() 98 | self.max_levels = max_levels 99 | kernel = torch.tensor([[1., 4., 6., 4., 1], 100 | [4., 16., 24., 16., 4.], 101 | [6., 24., 36., 24., 6.], 102 | [4., 16., 24., 16., 4.], 103 | [1., 4., 6., 4., 1.]]) 104 | kernel /= 256. 105 | self.register_buffer('KERNEL', kernel.float()) 106 | 107 | def downsample(self, x): 108 | # rejecting even rows and columns 109 | return x[:, :, ::2, ::2] 110 | 111 | def upsample(self, x): 112 | # Padding zeros interleaved in x (similar to unpooling where indices are always at top-left corner) 113 | # Original code only works when x.shape[2] == x.shape[3] because it uses the wrong indice order 114 | # after the first permute 115 | cc = torch.cat([x, torch.zeros(x.shape[0], x.shape[1], x.shape[2], x.shape[3], device=x.device)], dim=3) 116 | cc = cc.view(x.shape[0], x.shape[1], x.shape[2]*2, x.shape[3]) 117 | cc = cc.permute(0,1,3,2) 118 | cc = torch.cat([cc, torch.zeros(x.shape[0], x.shape[1], x.shape[3], x.shape[2]*2, device=x.device)], dim=3) 119 | cc = cc.view(x.shape[0], x.shape[1], x.shape[3]*2, x.shape[2]*2) 120 | x_up = cc.permute(0,1,3,2) 121 | return self.conv_gauss(x_up, 4*self.KERNEL.repeat(x.shape[1], 1, 1, 1)) 122 | 123 | def conv_gauss(self, img, kernel): 124 | img = F.pad(img, (2, 2, 2, 2), mode='reflect') 125 | out = F.conv2d(img, kernel, groups=img.shape[1]) 126 | return out 127 | 128 | def laplacian_pyramid(self, img): 129 | current = img 130 | pyr = [] 131 | for level in range(self.max_levels): 132 | filtered = self.conv_gauss(current, \ 133 | self.KERNEL.repeat(img.shape[1], 1, 1, 1)) 134 | down = self.downsample(filtered) 135 | up = self.upsample(down) 136 | diff = current-up 137 | pyr.append(diff) 138 | current = down 139 | return pyr 140 | 141 | def forward(self, img, tgt, mask=None, normalize=True): 142 | (img, tgt), pad = self.pad_divide_by([img, tgt], 32, (img.size()[2], img.size()[3])) 143 | 144 | pyr_input = self.laplacian_pyramid(img) 145 | pyr_target = self.laplacian_pyramid(tgt) 146 | loss = sum((2 ** level) * L1_mask(ab[0], ab[1], mask=mask, normalize=False) \ 147 | for level, ab in enumerate(zip(pyr_input, pyr_target))) 148 | if normalize: 149 | b,c,h,w = tgt.shape 150 | if mask is not None: 151 | _safe = torch.sum((mask > 1e-6).float()).clamp(epsilon, b*c*h*w+1) 152 | else: 153 | _safe = b*c*h*w 154 | return loss / _safe 155 | return loss 156 | 157 | def pad_divide_by(self, in_list, d, in_size): 158 | out_list = [] 159 | h, w = in_size 160 | if h % d > 0: 161 | new_h = h + d - h % d 162 | else: 163 | new_h = h 164 | if w % d > 0: 165 | new_w = w + d - w % d 166 | else: 167 | new_w = w 168 | lh, uh = int((new_h-h) / 2), int(new_h-h) - int((new_h-h) / 2) 169 | lw, uw = int((new_w-w) / 2), int(new_w-w) - int((new_w-w) / 2) 170 | pad_array = (int(lw), int(uw), int(lh), int(uh)) 171 | for inp in in_list: 172 | out_list.append(F.pad(inp, pad_array)) 173 | return out_list, pad_array -------------------------------------------------------------------------------- /utils/optimizer.py: 
-------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim.optimizer import Optimizer, required 4 | 5 | class RAdam(Optimizer): 6 | 7 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=False): 8 | if not 0.0 <= lr: 9 | raise ValueError("Invalid learning rate: {}".format(lr)) 10 | if not 0.0 <= eps: 11 | raise ValueError("Invalid epsilon value: {}".format(eps)) 12 | if not 0.0 <= betas[0] < 1.0: 13 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 14 | if not 0.0 <= betas[1] < 1.0: 15 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 16 | 17 | self.degenerated_to_sgd = degenerated_to_sgd 18 | if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict): 19 | for param in params: 20 | if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]): 21 | param['buffer'] = [[None, None, None] for _ in range(10)] 22 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)]) 23 | super(RAdam, self).__init__(params, defaults) 24 | 25 | def __setstate__(self, state): 26 | super(RAdam, self).__setstate__(state) 27 | 28 | def step(self, closure=None): 29 | 30 | loss = None 31 | if closure is not None: 32 | loss = closure() 33 | 34 | for group in self.param_groups: 35 | 36 | for p in group['params']: 37 | if p.grad is None: 38 | continue 39 | grad = p.grad.data.float() 40 | if grad.is_sparse: 41 | raise RuntimeError('RAdam does not support sparse gradients') 42 | 43 | p_data_fp32 = p.data.float() 44 | 45 | state = self.state[p] 46 | 47 | if len(state) == 0: 48 | state['step'] = 0 49 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 50 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 51 | else: 52 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 53 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 54 | 55 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 56 | beta1, beta2 = group['betas'] 57 | 58 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 59 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 60 | 61 | state['step'] += 1 62 | buffered = group['buffer'][int(state['step'] % 10)] 63 | if state['step'] == buffered[0]: 64 | N_sma, step_size = buffered[1], buffered[2] 65 | else: 66 | buffered[0] = state['step'] 67 | beta2_t = beta2 ** state['step'] 68 | N_sma_max = 2 / (1 - beta2) - 1 69 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 70 | buffered[1] = N_sma 71 | 72 | # more conservative since it's an approximated value 73 | if N_sma >= 5: 74 | step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 75 | elif self.degenerated_to_sgd: 76 | step_size = 1.0 / (1 - beta1 ** state['step']) 77 | else: 78 | step_size = -1 79 | buffered[2] = step_size 80 | 81 | # more conservative since it's an approximated value 82 | if N_sma >= 5: 83 | if group['weight_decay'] != 0: 84 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 85 | denom = exp_avg_sq.sqrt().add_(group['eps']) 86 | p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom) 87 | p.data.copy_(p_data_fp32) 88 | elif step_size > 0: 89 | if group['weight_decay'] != 0: 90 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 91 | p_data_fp32.add_(-step_size * 
group['lr'], exp_avg) 92 | p.data.copy_(p_data_fp32) 93 | 94 | return loss 95 | 96 | class PlainRAdam(Optimizer): 97 | 98 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=False): 99 | if not 0.0 <= lr: 100 | raise ValueError("Invalid learning rate: {}".format(lr)) 101 | if not 0.0 <= eps: 102 | raise ValueError("Invalid epsilon value: {}".format(eps)) 103 | if not 0.0 <= betas[0] < 1.0: 104 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 105 | if not 0.0 <= betas[1] < 1.0: 106 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 107 | 108 | self.degenerated_to_sgd = degenerated_to_sgd 109 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 110 | 111 | super(PlainRAdam, self).__init__(params, defaults) 112 | 113 | def __setstate__(self, state): 114 | super(PlainRAdam, self).__setstate__(state) 115 | 116 | def step(self, closure=None): 117 | 118 | loss = None 119 | if closure is not None: 120 | loss = closure() 121 | 122 | for group in self.param_groups: 123 | 124 | for p in group['params']: 125 | if p.grad is None: 126 | continue 127 | grad = p.grad.data.float() 128 | if grad.is_sparse: 129 | raise RuntimeError('RAdam does not support sparse gradients') 130 | 131 | p_data_fp32 = p.data.float() 132 | 133 | state = self.state[p] 134 | 135 | if len(state) == 0: 136 | state['step'] = 0 137 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 138 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 139 | else: 140 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 141 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 142 | 143 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 144 | beta1, beta2 = group['betas'] 145 | 146 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 147 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 148 | 149 | state['step'] += 1 150 | beta2_t = beta2 ** state['step'] 151 | N_sma_max = 2 / (1 - beta2) - 1 152 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 153 | 154 | 155 | # more conservative since it's an approximated value 156 | if N_sma >= 5: 157 | if group['weight_decay'] != 0: 158 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 159 | step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 160 | denom = exp_avg_sq.sqrt().add_(group['eps']) 161 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 162 | p.data.copy_(p_data_fp32) 163 | elif self.degenerated_to_sgd: 164 | if group['weight_decay'] != 0: 165 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 166 | step_size = group['lr'] / (1 - beta1 ** state['step']) 167 | p_data_fp32.add_(-step_size, exp_avg) 168 | p.data.copy_(p_data_fp32) 169 | 170 | return loss 171 | 172 | 173 | class AdamW(Optimizer): 174 | 175 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup = 0): 176 | if not 0.0 <= lr: 177 | raise ValueError("Invalid learning rate: {}".format(lr)) 178 | if not 0.0 <= eps: 179 | raise ValueError("Invalid epsilon value: {}".format(eps)) 180 | if not 0.0 <= betas[0] < 1.0: 181 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 182 | if not 0.0 <= betas[1] < 1.0: 183 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 184 | 185 | defaults = dict(lr=lr, betas=betas, eps=eps, 186 | 
weight_decay=weight_decay, warmup = warmup) 187 | super(AdamW, self).__init__(params, defaults) 188 | 189 | def __setstate__(self, state): 190 | super(AdamW, self).__setstate__(state) 191 | 192 | def step(self, closure=None): 193 | loss = None 194 | if closure is not None: 195 | loss = closure() 196 | 197 | for group in self.param_groups: 198 | 199 | for p in group['params']: 200 | if p.grad is None: 201 | continue 202 | grad = p.grad.data.float() 203 | if grad.is_sparse: 204 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 205 | 206 | p_data_fp32 = p.data.float() 207 | 208 | state = self.state[p] 209 | 210 | if len(state) == 0: 211 | state['step'] = 0 212 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 213 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 214 | else: 215 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 216 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 217 | 218 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 219 | beta1, beta2 = group['betas'] 220 | 221 | state['step'] += 1 222 | 223 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 224 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 225 | 226 | denom = exp_avg_sq.sqrt().add_(group['eps']) 227 | bias_correction1 = 1 - beta1 ** state['step'] 228 | bias_correction2 = 1 - beta2 ** state['step'] 229 | 230 | if group['warmup'] > state['step']: 231 | scheduled_lr = 1e-8 + state['step'] * group['lr'] / group['warmup'] 232 | else: 233 | scheduled_lr = group['lr'] 234 | 235 | step_size = scheduled_lr * math.sqrt(bias_correction2) / bias_correction1 236 | 237 | if group['weight_decay'] != 0: 238 | p_data_fp32.add_(-group['weight_decay'] * scheduled_lr, p_data_fp32) 239 | 240 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 241 | 242 | p.data.copy_(p_data_fp32) 243 | 244 | return loss 245 | -------------------------------------------------------------------------------- /utils/tmp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/utils/tmp/__init__.py -------------------------------------------------------------------------------- /utils/tmp/augmentation.py: -------------------------------------------------------------------------------- 1 | import easing_functions as ef 2 | import random 3 | import torch 4 | from torchvision import transforms 5 | from torchvision.transforms import functional as F 6 | 7 | 8 | class MotionAugmentation: 9 | def __init__(self, 10 | size, 11 | prob_fgr_affine, 12 | prob_bgr_affine, 13 | prob_noise, 14 | prob_color_jitter, 15 | prob_grayscale, 16 | prob_sharpness, 17 | prob_blur, 18 | prob_hflip, 19 | prob_pause, 20 | static_affine=True, 21 | aspect_ratio_range=(0.9, 1.1)): 22 | self.size = size 23 | self.prob_fgr_affine = prob_fgr_affine 24 | self.prob_bgr_affine = prob_bgr_affine 25 | self.prob_noise = prob_noise 26 | self.prob_color_jitter = prob_color_jitter 27 | self.prob_grayscale = prob_grayscale 28 | self.prob_sharpness = prob_sharpness 29 | self.prob_blur = prob_blur 30 | self.prob_hflip = prob_hflip 31 | self.prob_pause = prob_pause 32 | self.static_affine = static_affine 33 | self.aspect_ratio_range = aspect_ratio_range 34 | 35 | def __call__(self, fgrs, phas, bgrs): 36 | # Foreground affine 37 | if random.random() < self.prob_fgr_affine: 38 | fgrs, phas = self._motion_affine(fgrs, phas) 39 | 40 | # Background affine 41 | if random.random() < 
self.prob_bgr_affine / 2: 42 | bgrs = self._motion_affine(bgrs) 43 | if random.random() < self.prob_bgr_affine / 2: 44 | fgrs, phas, bgrs = self._motion_affine(fgrs, phas, bgrs) 45 | 46 | # Still Affine 47 | if self.static_affine: 48 | fgrs, phas = self._static_affine(fgrs, phas, scale_ranges=(0.5, 1)) 49 | bgrs = self._static_affine(bgrs, scale_ranges=(1, 1.5)) 50 | 51 | # To tensor 52 | fgrs = torch.stack([F.to_tensor(fgr) for fgr in fgrs]) 53 | phas = torch.stack([F.to_tensor(pha) for pha in phas]) 54 | bgrs = torch.stack([F.to_tensor(bgr) for bgr in bgrs]) 55 | 56 | # Resize 57 | params = transforms.RandomResizedCrop.get_params(fgrs, scale=(1, 1), ratio=self.aspect_ratio_range) 58 | fgrs = F.resized_crop(fgrs, *params, self.size, interpolation=F.InterpolationMode.BILINEAR) 59 | phas = F.resized_crop(phas, *params, self.size, interpolation=F.InterpolationMode.BILINEAR) 60 | params = transforms.RandomResizedCrop.get_params(bgrs, scale=(1, 1), ratio=self.aspect_ratio_range) 61 | bgrs = F.resized_crop(bgrs, *params, self.size, interpolation=F.InterpolationMode.BILINEAR) 62 | 63 | # Horizontal flip 64 | if random.random() < self.prob_hflip: 65 | fgrs = F.hflip(fgrs) 66 | phas = F.hflip(phas) 67 | if random.random() < self.prob_hflip: 68 | bgrs = F.hflip(bgrs) 69 | 70 | # Noise 71 | if random.random() < self.prob_noise: 72 | fgrs, bgrs = self._motion_noise(fgrs, bgrs) 73 | 74 | # Color jitter 75 | if random.random() < self.prob_color_jitter: 76 | fgrs = self._motion_color_jitter(fgrs) 77 | if random.random() < self.prob_color_jitter: 78 | bgrs = self._motion_color_jitter(bgrs) 79 | 80 | # Grayscale 81 | if random.random() < self.prob_grayscale: 82 | fgrs = F.rgb_to_grayscale(fgrs, num_output_channels=3).contiguous() 83 | bgrs = F.rgb_to_grayscale(bgrs, num_output_channels=3).contiguous() 84 | 85 | # Sharpen 86 | if random.random() < self.prob_sharpness: 87 | sharpness = random.random() * 8 88 | fgrs = F.adjust_sharpness(fgrs, sharpness) 89 | phas = F.adjust_sharpness(phas, sharpness) 90 | bgrs = F.adjust_sharpness(bgrs, sharpness) 91 | 92 | # Blur 93 | if random.random() < self.prob_blur / 3: 94 | fgrs, phas = self._motion_blur(fgrs, phas) 95 | if random.random() < self.prob_blur / 3: 96 | bgrs = self._motion_blur(bgrs) 97 | if random.random() < self.prob_blur / 3: 98 | fgrs, phas, bgrs = self._motion_blur(fgrs, phas, bgrs) 99 | 100 | # Pause 101 | if random.random() < self.prob_pause: 102 | fgrs, phas, bgrs = self._motion_pause(fgrs, phas, bgrs) 103 | 104 | return fgrs, phas, bgrs 105 | 106 | def _static_affine(self, *imgs, scale_ranges): 107 | params = transforms.RandomAffine.get_params( 108 | degrees=(-10, 10), translate=(0.1, 0.1), scale_ranges=scale_ranges, 109 | shears=(-5, 5), img_size=imgs[0][0].size) 110 | imgs = [[F.affine(t, *params, F.InterpolationMode.BILINEAR) for t in img] for img in imgs] 111 | return imgs if len(imgs) > 1 else imgs[0] 112 | 113 | def _motion_affine(self, *imgs): 114 | config = dict(degrees=(-10, 10), translate=(0.1, 0.1), 115 | scale_ranges=(0.9, 1.1), shears=(-5, 5), img_size=imgs[0][0].size) 116 | angleA, (transXA, transYA), scaleA, (shearXA, shearYA) = transforms.RandomAffine.get_params(**config) 117 | angleB, (transXB, transYB), scaleB, (shearXB, shearYB) = transforms.RandomAffine.get_params(**config) 118 | 119 | T = len(imgs[0]) 120 | easing = random_easing_fn() 121 | for t in range(T): 122 | percentage = easing(t / (T - 1)) 123 | angle = lerp(angleA, angleB, percentage) 124 | transX = lerp(transXA, transXB, percentage) 125 | transY = lerp(transYA, 
transYB, percentage) 126 | scale = lerp(scaleA, scaleB, percentage) 127 | shearX = lerp(shearXA, shearXB, percentage) 128 | shearY = lerp(shearYA, shearYB, percentage) 129 | for img in imgs: 130 | img[t] = F.affine(img[t], angle, (transX, transY), scale, (shearX, shearY), F.InterpolationMode.BILINEAR) 131 | return imgs if len(imgs) > 1 else imgs[0] 132 | 133 | def _motion_noise(self, *imgs): 134 | grain_size = random.random() * 3 + 1 # range 1 ~ 4 135 | monochrome = random.random() < 0.5 136 | for img in imgs: 137 | T, C, H, W = img.shape 138 | noise = torch.randn((T, 1 if monochrome else C, round(H / grain_size), round(W / grain_size))) 139 | noise.mul_(random.random() * 0.2 / grain_size) 140 | if grain_size != 1: 141 | noise = F.resize(noise, (H, W)) 142 | img.add_(noise).clamp_(0, 1) 143 | return imgs if len(imgs) > 1 else imgs[0] 144 | 145 | def _motion_color_jitter(self, *imgs): 146 | brightnessA, brightnessB, contrastA, contrastB, saturationA, saturationB, hueA, hueB \ 147 | = torch.randn(8).mul(0.1).tolist() 148 | strength = random.random() * 0.2 149 | easing = random_easing_fn() 150 | T = len(imgs[0]) 151 | for t in range(T): 152 | percentage = easing(t / (T - 1)) * strength 153 | for img in imgs: 154 | img[t] = F.adjust_brightness(img[t], max(1 + lerp(brightnessA, brightnessB, percentage), 0.1)) 155 | img[t] = F.adjust_contrast(img[t], max(1 + lerp(contrastA, contrastB, percentage), 0.1)) 156 | img[t] = F.adjust_saturation(img[t], max(1 + lerp(brightnessA, brightnessB, percentage), 0.1)) 157 | img[t] = F.adjust_hue(img[t], min(0.5, max(-0.5, lerp(hueA, hueB, percentage) * 0.1))) 158 | return imgs if len(imgs) > 1 else imgs[0] 159 | 160 | def _motion_blur(self, *imgs): 161 | blurA = random.random() * 10 162 | blurB = random.random() * 10 163 | 164 | T = len(imgs[0]) 165 | easing = random_easing_fn() 166 | for t in range(T): 167 | percentage = easing(t / (T - 1)) 168 | blur = max(lerp(blurA, blurB, percentage), 0) 169 | if blur != 0: 170 | kernel_size = int(blur * 2) 171 | if kernel_size % 2 == 0: 172 | kernel_size += 1 # Make kernel_size odd 173 | for img in imgs: 174 | img[t] = F.gaussian_blur(img[t], kernel_size, sigma=blur) 175 | 176 | return imgs if len(imgs) > 1 else imgs[0] 177 | 178 | def _motion_pause(self, *imgs): 179 | T = len(imgs[0]) 180 | pause_frame = random.choice(range(T - 1)) 181 | pause_length = random.choice(range(T - pause_frame)) 182 | for img in imgs: 183 | img[pause_frame + 1 : pause_frame + pause_length] = img[pause_frame] 184 | return imgs if len(imgs) > 1 else imgs[0] 185 | 186 | 187 | def lerp(a, b, percentage): 188 | return a * (1 - percentage) + b * percentage 189 | 190 | 191 | def random_easing_fn(): 192 | if random.random() < 0.2: 193 | return ef.LinearInOut() 194 | else: 195 | return random.choice([ 196 | ef.BackEaseIn, 197 | ef.BackEaseOut, 198 | ef.BackEaseInOut, 199 | ef.BounceEaseIn, 200 | ef.BounceEaseOut, 201 | ef.BounceEaseInOut, 202 | ef.CircularEaseIn, 203 | ef.CircularEaseOut, 204 | ef.CircularEaseInOut, 205 | ef.CubicEaseIn, 206 | ef.CubicEaseOut, 207 | ef.CubicEaseInOut, 208 | ef.ExponentialEaseIn, 209 | ef.ExponentialEaseOut, 210 | ef.ExponentialEaseInOut, 211 | ef.ElasticEaseIn, 212 | ef.ElasticEaseOut, 213 | ef.ElasticEaseInOut, 214 | ef.QuadEaseIn, 215 | ef.QuadEaseOut, 216 | ef.QuadEaseInOut, 217 | ef.QuarticEaseIn, 218 | ef.QuarticEaseOut, 219 | ef.QuarticEaseInOut, 220 | ef.QuinticEaseIn, 221 | ef.QuinticEaseOut, 222 | ef.QuinticEaseInOut, 223 | ef.SineEaseIn, 224 | ef.SineEaseOut, 225 | ef.SineEaseInOut, 226 | Step, 227 | ])() 228 
| 229 | class Step: # Custom easing function for sudden change. 230 | def __call__(self, value): 231 | return 0 if value < 0.5 else 1 232 | 233 | 234 | # ---------------------------- Frame Sampler ---------------------------- 235 | 236 | 237 | class TrainFrameSampler: 238 | def __init__(self, speed=[0.5, 1, 2, 3, 4, 5]): 239 | self.speed = speed 240 | 241 | def __call__(self, seq_length): 242 | frames = list(range(seq_length)) 243 | 244 | # Speed up 245 | speed = random.choice(self.speed) 246 | frames = [int(f * speed) for f in frames] 247 | 248 | # Shift 249 | shift = random.choice(range(seq_length)) 250 | frames = [f + shift for f in frames] 251 | 252 | # Reverse 253 | if random.random() < 0.5: 254 | frames = frames[::-1] 255 | 256 | return frames 257 | 258 | class ValidFrameSampler: 259 | def __call__(self, seq_length): 260 | return range(seq_length) -------------------------------------------------------------------------------- /utils/tmp/closed_form_matting/.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - "3.5" 4 | - "3.6" 5 | install: 6 | - pip install -r requirements.txt 7 | script: 8 | - pytest 9 | -------------------------------------------------------------------------------- /utils/tmp/closed_form_matting/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Marco Forte 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /utils/tmp/closed_form_matting/README.md: -------------------------------------------------------------------------------- 1 | # Closed-Form Matting 2 | [![Build Status](https://travis-ci.org/MarcoForte/closed-form-matting.svg?branch=master)](https://travis-ci.org/MarcoForte/closed-form-matting) 3 | 4 | 5 | Python implementation of image matting method proposed in A. Levin D. Lischinski and Y. Weiss. A Closed Form Solution to Natural Image Matting. IEEE Conf. on Computer Vision and Pattern Recognition (CVPR), June 2006, New York 6 | 7 | The repository also contains implementation of background/foreground reconstruction method proposed in Levin, Anat, Dani Lischinski, and Yair Weiss. "A closed-form solution to natural image matting." IEEE Transactions on Pattern Analysis and Machine Intelligence 30.2 (2008): 228-242. 
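At its core, the prior-based solver in this package reduces matting to a single sparse linear system. The sketch below is illustrative only (the helper name `solve_alpha_from_prior` is ours, not part of the package API); it mirrors what `closed_form_matting_with_prior` does internally:

```python
import numpy as np
import scipy.sparse
import scipy.sparse.linalg

def solve_alpha_from_prior(laplacian, prior, prior_confidence):
    """Solve (L + C) * alpha = C * prior, where C = diag(prior_confidence)."""
    confidence = scipy.sparse.diags(prior_confidence.ravel())
    solution = scipy.sparse.linalg.spsolve(
        laplacian + confidence,                       # L + C
        prior.ravel() * prior_confidence.ravel())     # C * prior
    return np.clip(solution.reshape(prior.shape), 0.0, 1.0)  # clamp to [0, 1]
```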
8 | 9 | ## Requirements 10 | - python 3.5+ (Though it should run on 2.7) 11 | - scipy 12 | - numpy 13 | - opencv-python 14 | 15 | ## Installation 16 | 17 | Clone this repository and install the closed-form-matting package via pip. 18 | 19 | ```bash 20 | git clone https://github.com/MarcoForte/closed-form-matting.git 21 | cd closed-form-matting/ 22 | pip install . 23 | ``` 24 | 25 | ## Usage 26 | 27 | ### Closed-Form matting 28 | CLI inerface: 29 | 30 | ```bash 31 | # Scribbles input 32 | closed-form-matting ./testdata/source.png -s ./testdata/scribbles.png -o output_alpha.png 33 | 34 | # Trimap input 35 | closed-form-matting ./testdata/source.png -t ./testdata/trimap.png -o output_alpha.png 36 | 37 | # Add flag --solve-fg to compute foreground color and output RGBA image instead 38 | # of alpha. 39 | ``` 40 | 41 | 42 | Python interface: 43 | 44 | ```python 45 | import closed_form_matting 46 | ... 47 | # For scribles input 48 | alpha = closed_form_matting.closed_form_matting_with_scribbles(image, scribbles) 49 | 50 | # For trimap input 51 | alpha = closed_form_matting.closed_form_matting_with_trimap(image, trimap) 52 | 53 | # For prior with confidence 54 | alpha = closed_form_matting.closed_form_matting_with_prior( 55 | image, prior, prior_confidence, optional_const_mask) 56 | 57 | # To get Matting Laplacian for image 58 | laplacian = closed_form_matting.compute_laplacian(image, optional_const_mask) 59 | ``` 60 | 61 | ### Foreground and Background Reconstruction 62 | CLI interface (requires opencv-python): 63 | 64 | ```bash 65 | solve-foreground-background image.png alpha.png foreground.png background.png 66 | ``` 67 | 68 | Python interface: 69 | 70 | ```python 71 | from closed_form_matting import solve_foreground_background 72 | ... 73 | foreground, background = solve_foreground_background(image, alpha) 74 | ``` 75 | 76 | ## Results 77 | | Original image | Scribbled image | Output alpha | Output foreground | 78 | |------------------|-----------------|--------------|-------------------| 79 | | ![Original image](testdata/source.png) | ![Scribbled image](testdata/scribbles.png) | ![Output alpha](testdata/output_alpha.png) | ![Output foreground](testdata/output_foreground.png) | 80 | 81 | 82 | ## More Information 83 | The computation is generally faster than the matlab version thanks to more vectorization. 84 | Note. The computed laplacian is slightly different due to array ordering in numpy being different than in matlab. To get same laplacian as in matlab change, 85 | 86 | `indsM = np.arange(h*w).reshape((h, w))` 87 | `ravelImg = img.reshape(h*w, d)` 88 | to 89 | `indsM = np.arange(h*w).reshape((h, w), order='F')` 90 | `ravelImg = img.reshape(h*w, d, , order='F')`. 91 | Again note that this will result in incorrect alpha if the `D_s, b_s` orderings are not also changed to `order='F'F`. 92 | 93 | For more information see the original paper http://www.wisdom.weizmann.ac.il/~levina/papers/Matting-Levin-Lischinski-Weiss-CVPR06.pdf 94 | The original matlab code is here http://www.wisdom.weizmann.ac.il/~levina/matting.tar.gz 95 | 96 | ## Disclaimer 97 | 98 | The code is free for academic/research purpose. Use at your own risk and we are not responsible for any loss resulting from this code. Feel free to submit pull request for bug fixes. 
99 | -------------------------------------------------------------------------------- /utils/tmp/closed_form_matting/closed_form_matting/__init__.py: -------------------------------------------------------------------------------- 1 | # #!/usr/bin/env python 2 | # """Init script when importing closed-form-matting package""" 3 | 4 | # from closed_form_matting.closed_form_matting import ( 5 | # compute_laplacian, 6 | # closed_form_matting_with_prior, 7 | # closed_form_matting_with_trimap, 8 | # closed_form_matting_with_scribbles, 9 | # ) 10 | # from closed_form_matting.solve_foreground_background import ( 11 | # solve_foreground_background 12 | # ) 13 | 14 | # __version__ = '1.0.0' 15 | # __all__ = [ 16 | # 'compute_laplacian', 17 | # 'closed_form_matting_with_prior', 18 | # 'closed_form_matting_with_trimap', 19 | # 'closed_form_matting_with_scribbles', 20 | # 'solve_foreground_background', 21 | # ] 22 | -------------------------------------------------------------------------------- /utils/tmp/closed_form_matting/closed_form_matting/closed_form_matting.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Implementation of Closed-Form Matting. 3 | 4 | This module implements natural image matting method described in: 5 | Levin, Anat, Dani Lischinski, and Yair Weiss. "A closed-form solution to natural image matting." 6 | IEEE Transactions on Pattern Analysis and Machine Intelligence 30.2 (2008): 228-242. 7 | 8 | The code can be used in two ways: 9 | 1. By importing solve_foregound_background in your code: 10 | ``` 11 | import closed_form_matting 12 | ... 13 | # For scribles input 14 | alpha = closed_form_matting.closed_form_matting_with_scribbles(image, scribbles) 15 | 16 | # For trimap input 17 | alpha = closed_form_matting.closed_form_matting_with_trimap(image, trimap) 18 | 19 | # For prior with confidence 20 | alpha = closed_form_matting.closed_form_matting_with_prior( 21 | image, prior, prior_confidence, optional_const_mask) 22 | 23 | # To get Matting Laplacian for image 24 | laplacian = compute_laplacian(image, optional_const_mask) 25 | ``` 26 | 2. From command line: 27 | ``` 28 | # Scribbles input 29 | ./closed_form_matting.py input_image.png -s scribbles_image.png -o output_alpha.png 30 | 31 | # Trimap input 32 | ./closed_form_matting.py input_image.png -t scribbles_image.png -o output_alpha.png 33 | 34 | # Add flag --solve-fg to compute foreground color and output RGBA image instead 35 | # of alpha. 36 | ``` 37 | """ 38 | 39 | from __future__ import division 40 | 41 | import logging 42 | 43 | import cv2 44 | import numpy as np 45 | from numpy.lib.stride_tricks import as_strided 46 | import scipy.sparse 47 | import scipy.sparse.linalg 48 | 49 | 50 | def _rolling_block(A, block=(3, 3)): 51 | """Applies sliding window to given matrix.""" 52 | shape = (A.shape[0] - block[0] + 1, A.shape[1] - block[1] + 1) + block 53 | strides = (A.strides[0], A.strides[1]) + A.strides 54 | return as_strided(A, shape=shape, strides=strides) 55 | 56 | 57 | def compute_laplacian(img, mask=None, eps=10**(-7), win_rad=1): 58 | """Computes Matting Laplacian for a given image. 59 | 60 | Args: 61 | img: 3-dim numpy matrix with input image 62 | mask: mask of pixels for which Laplacian will be computed. 63 | If not set Laplacian will be computed for all pixels. 64 | eps: regularization parameter controlling alpha smoothness 65 | from Eq. 12 of the original paper. Defaults to 1e-7. 66 | win_rad: radius of window used to build Matting Laplacian (i.e. 
67 | radius of omega_k in Eq. 12). 68 | Returns: sparse matrix holding Matting Laplacian. 69 | """ 70 | 71 | win_size = (win_rad * 2 + 1) ** 2 72 | h, w, d = img.shape 73 | # Number of window centre indices in h, w axes 74 | c_h, c_w = h - 2 * win_rad, w - 2 * win_rad 75 | win_diam = win_rad * 2 + 1 76 | 77 | indsM = np.arange(h * w).reshape((h, w)) 78 | ravelImg = img.reshape(h * w, d) 79 | win_inds = _rolling_block(indsM, block=(win_diam, win_diam)) 80 | 81 | win_inds = win_inds.reshape(c_h, c_w, win_size) 82 | if mask is not None: 83 | mask = cv2.dilate( 84 | mask.astype(np.uint8), 85 | np.ones((win_diam, win_diam), np.uint8) 86 | ).astype(np.bool) 87 | win_mask = np.sum(mask.ravel()[win_inds], axis=2) 88 | win_inds = win_inds[win_mask > 0, :] 89 | else: 90 | win_inds = win_inds.reshape(-1, win_size) 91 | 92 | 93 | winI = ravelImg[win_inds] 94 | 95 | win_mu = np.mean(winI, axis=1, keepdims=True) 96 | win_var = np.einsum('...ji,...jk ->...ik', winI, winI) / win_size - np.einsum('...ji,...jk ->...ik', win_mu, win_mu) 97 | 98 | inv = np.linalg.inv(win_var + (eps/win_size)*np.eye(3)) 99 | 100 | X = np.einsum('...ij,...jk->...ik', winI - win_mu, inv) 101 | vals = np.eye(win_size) - (1.0/win_size)*(1 + np.einsum('...ij,...kj->...ik', X, winI - win_mu)) 102 | 103 | nz_indsCol = np.tile(win_inds, win_size).ravel() 104 | nz_indsRow = np.repeat(win_inds, win_size).ravel() 105 | nz_indsVal = vals.ravel() 106 | L = scipy.sparse.coo_matrix((nz_indsVal, (nz_indsRow, nz_indsCol)), shape=(h*w, h*w)) 107 | return L 108 | 109 | 110 | def closed_form_matting_with_prior(image, prior, prior_confidence, consts_map=None): 111 | """Applies closed form matting with prior alpha map to image. 112 | 113 | Args: 114 | image: 3-dim numpy matrix with input image. 115 | prior: matrix of same width and height as input image holding apriori alpha map. 116 | prior_confidence: matrix of the same shape as prior hodling confidence of prior alpha. 117 | consts_map: binary mask of pixels that aren't expected to change due to high 118 | prior confidence. 119 | 120 | Returns: 2-dim matrix holding computed alpha map. 
121 | """ 122 | 123 | assert image.shape[:2] == prior.shape, ('prior must be 2D matrix with height and width equal ' 124 | 'to image.') 125 | assert image.shape[:2] == prior_confidence.shape, ('prior_confidence must be 2D matrix with ' 126 | 'height and width equal to image.') 127 | assert (consts_map is None) or image.shape[:2] == consts_map.shape, ( 128 | 'consts_map must be 2D matrix with height and width equal to image.') 129 | 130 | logging.info('Computing Matting Laplacian.') 131 | laplacian = compute_laplacian(image, ~consts_map if consts_map is not None else None) 132 | 133 | confidence = scipy.sparse.diags(prior_confidence.ravel()) 134 | logging.info('Solving for alpha.') 135 | solution = scipy.sparse.linalg.spsolve( 136 | laplacian + confidence, 137 | prior.ravel() * prior_confidence.ravel() 138 | ) 139 | alpha = np.minimum(np.maximum(solution.reshape(prior.shape), 0), 1) 140 | return alpha 141 | 142 | 143 | def closed_form_matting_with_trimap(image, trimap, trimap_confidence=100.0): 144 | """Apply Closed-Form matting to given image using trimap.""" 145 | 146 | assert image.shape[:2] == trimap.shape, ('trimap must be 2D matrix with height and width equal ' 147 | 'to image.') 148 | consts_map = (trimap < 0.1) | (trimap > 0.9) 149 | return closed_form_matting_with_prior(image, trimap, trimap_confidence * consts_map, consts_map) 150 | 151 | 152 | def closed_form_matting_with_scribbles(image, scribbles, scribbles_confidence=100.0): 153 | """Apply Closed-Form matting to given image using scribbles image.""" 154 | 155 | assert image.shape == scribbles.shape, 'scribbles must have exactly same shape as image.' 156 | prior = np.sign(np.sum(scribbles - image, axis=2)) / 2 + 0.5 157 | consts_map = prior != 0.5 158 | return closed_form_matting_with_prior( 159 | image, 160 | prior, 161 | scribbles_confidence * consts_map, 162 | consts_map 163 | ) 164 | 165 | 166 | closed_form_matting = closed_form_matting_with_trimap 167 | 168 | def main(): 169 | import argparse 170 | 171 | logging.basicConfig(level=logging.INFO) 172 | arg_parser = argparse.ArgumentParser(description=__doc__) 173 | arg_parser.add_argument('image', type=str, help='input image') 174 | 175 | arg_parser.add_argument('-t', '--trimap', type=str, help='input trimap') 176 | arg_parser.add_argument('-s', '--scribbles', type=str, help='input scribbles') 177 | arg_parser.add_argument('-o', '--output', type=str, required=True, help='output image') 178 | arg_parser.add_argument( 179 | '--solve-fg', dest='solve_fg', action='store_true', 180 | help='compute foreground color and output RGBA image' 181 | ) 182 | args = arg_parser.parse_args() 183 | 184 | image = cv2.imread(args.image, cv2.IMREAD_COLOR) / 255.0 185 | 186 | if args.scribbles: 187 | scribbles = cv2.imread(args.scribbles, cv2.IMREAD_COLOR) / 255.0 188 | alpha = closed_form_matting_with_scribbles(image, scribbles) 189 | elif args.trimap: 190 | trimap = cv2.imread(args.trimap, cv2.IMREAD_GRAYSCALE) / 255.0 191 | alpha = closed_form_matting_with_trimap(image, trimap) 192 | else: 193 | logging.error('Either trimap or scribbles must be specified.') 194 | arg_parser.print_help() 195 | exit(-1) 196 | 197 | if args.solve_fg: 198 | from closed_form_matting.solve_foreground_background import solve_foreground_background 199 | foreground, _ = solve_foreground_background(image, alpha) 200 | output = np.concatenate((foreground, alpha[:, :, np.newaxis]), axis=2) 201 | else: 202 | output = alpha 203 | 204 | cv2.imwrite(args.output, output * 255.0) 205 | 206 | 207 | if __name__ == "__main__": 208 
| main() 209 | -------------------------------------------------------------------------------- /utils/tmp/closed_form_matting/closed_form_matting/solve_foreground_background.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Computes foreground and background images given source image and transparency map. 3 | 4 | This module implements foreground and background reconstruction method described in Section 7 of: 5 | Levin, Anat, Dani Lischinski, and Yair Weiss. "A closed-form solution to natural image 6 | matting." IEEE Transactions on Pattern Analysis and Machine Intelligence 30.2 (2008): 228-242. 7 | 8 | Please note, that the cost-function optimized by this code doesn't perfectly match Eq. 19 of the 9 | paper, since our implementation mimics `solveFB.m` Matlab function provided by the authors of the 10 | original paper (this implementation is 11 | availale at http://people.csail.mit.edu/alevin/matting.tar.gz). 12 | 13 | The code can be used in two ways: 14 | 1. By importing solve_foregound_background in your code: 15 | ``` 16 | from solve_foregound_background import solve_foregound_background 17 | ... 18 | foreground, background = solve_foregound_background(image, alpha) 19 | ``` 20 | 2. From command line: 21 | ``` 22 | ./solve_foregound_background.py image.png alpha.png foreground.png background.png 23 | ``` 24 | 25 | Authors: Mikhail Erofeev, Yury Gitman. 26 | """ 27 | 28 | import numpy as np 29 | import scipy.sparse 30 | import scipy.sparse.linalg 31 | 32 | # CONST_ALPHA_MARGIN = 0.02 33 | CONST_ALPHA_MARGIN = 0. 34 | 35 | 36 | def __spdiagonal(diag): 37 | """Produces sparse matrix with given vector on its main diagonal.""" 38 | return scipy.sparse.spdiags(diag, (0,), len(diag), len(diag)) 39 | 40 | 41 | def get_grad_operator(mask): 42 | """Returns sparse matrix computing horizontal, vertical, and two diagonal gradients.""" 43 | horizontal_left = np.ravel_multi_index(np.nonzero(mask[:, :-1] | mask[:, 1:]), mask.shape) 44 | horizontal_right = horizontal_left + 1 45 | 46 | vertical_top = np.ravel_multi_index(np.nonzero(mask[:-1, :] | mask[1:, :]), mask.shape) 47 | vertical_bottom = vertical_top + mask.shape[1] 48 | 49 | diag_main_1 = np.ravel_multi_index(np.nonzero(mask[:-1, :-1] | mask[1:, 1:]), mask.shape) 50 | diag_main_2 = diag_main_1 + mask.shape[1] + 1 51 | 52 | diag_sub_1 = np.ravel_multi_index(np.nonzero(mask[:-1, 1:] | mask[1:, :-1]), mask.shape) + 1 53 | diag_sub_2 = diag_sub_1 + mask.shape[1] - 1 54 | 55 | indices = np.stack(( 56 | np.concatenate((horizontal_left, vertical_top, diag_main_1, diag_sub_1)), 57 | np.concatenate((horizontal_right, vertical_bottom, diag_main_2, diag_sub_2)) 58 | ), axis=-1) 59 | return scipy.sparse.coo_matrix( 60 | (np.tile([-1, 1], len(indices)), (np.arange(indices.size) // 2, indices.flatten())), 61 | shape=(len(indices), mask.size)) 62 | 63 | 64 | def get_const_conditions(image, alpha): 65 | """Returns sparse diagonal matrix and vector encoding color prior conditions.""" 66 | falpha = alpha.flatten() 67 | weights = ( 68 | (falpha < CONST_ALPHA_MARGIN) * 100.0 + 69 | 0.03 * (1.0 - falpha) * (falpha < 0.3) + 70 | 0.01 * (falpha > 1.0 - CONST_ALPHA_MARGIN) 71 | ) 72 | conditions = __spdiagonal(weights) 73 | 74 | mask = falpha < 1.0 - CONST_ALPHA_MARGIN 75 | right_hand = (weights * mask)[:, np.newaxis] * image.reshape((alpha.size, -1)) 76 | return conditions, right_hand 77 | 78 | 79 | def solve_foreground_background(image, alpha): 80 | """Compute foreground and background image given 
source image and transparency map.""" 81 | 82 | consts = (alpha < CONST_ALPHA_MARGIN) | (alpha > 1.0 - CONST_ALPHA_MARGIN) 83 | grad = get_grad_operator(~consts) 84 | grad_weights = np.power(np.abs(grad * alpha.flatten()), 0.5) 85 | 86 | grad_only_positive = grad.maximum(0) 87 | grad_weights_f = grad_weights + 0.003 * grad_only_positive * (1.0 - alpha.flatten()) 88 | grad_weights_b = grad_weights + 0.003 * grad_only_positive * alpha.flatten() 89 | 90 | grad_pad = scipy.sparse.coo_matrix(grad.shape) 91 | 92 | smoothness_conditions = scipy.sparse.vstack(( 93 | scipy.sparse.hstack((__spdiagonal(grad_weights_f) * grad, grad_pad)), 94 | scipy.sparse.hstack((grad_pad, __spdiagonal(grad_weights_b) * grad)) 95 | )) 96 | 97 | composite_conditions = scipy.sparse.hstack(( 98 | __spdiagonal(alpha.flatten()), 99 | __spdiagonal(1.0 - alpha.flatten()) 100 | )) 101 | 102 | const_conditions_f, b_const_f = get_const_conditions(image, 1.0 - alpha) 103 | const_conditions_b, b_const_b = get_const_conditions(image, alpha) 104 | 105 | non_zero_conditions = scipy.sparse.vstack(( 106 | composite_conditions, 107 | scipy.sparse.hstack(( 108 | const_conditions_f, 109 | scipy.sparse.coo_matrix(const_conditions_f.shape) 110 | )), 111 | scipy.sparse.hstack(( 112 | scipy.sparse.coo_matrix(const_conditions_b.shape), 113 | const_conditions_b 114 | )) 115 | )) 116 | 117 | b_composite = image.reshape(alpha.size, -1) 118 | 119 | right_hand = non_zero_conditions.transpose() * np.concatenate((b_composite, 120 | b_const_f, 121 | b_const_b)) 122 | 123 | conditons = scipy.sparse.vstack(( 124 | non_zero_conditions, 125 | smoothness_conditions 126 | )) 127 | left_hand = conditons.transpose() * conditons 128 | 129 | solution = scipy.sparse.linalg.spsolve(left_hand, right_hand).reshape(2, *image.shape) 130 | foreground = solution[0, :, :, :].reshape(*image.shape) 131 | background = solution[1, :, :, :].reshape(*image.shape) 132 | return foreground, background 133 | 134 | 135 | def main(): 136 | """Parse command line arguments and apply solve_foregound_background.""" 137 | 138 | import argparse 139 | import cv2 140 | arg_parser = argparse.ArgumentParser(description=__doc__) 141 | arg_parser.add_argument('image', type=str) 142 | arg_parser.add_argument('alpha', type=str) 143 | arg_parser.add_argument('foreground', type=str) 144 | arg_parser.add_argument('background', type=str, default=None, nargs='?') 145 | args = arg_parser.parse_args() 146 | 147 | image = cv2.imread(args.image) / 255.0 148 | alpha = cv2.imread(args.alpha, 0) / 255.0 149 | foreground, background = solve_foreground_background(image, alpha) 150 | cv2.imwrite(args.foreground, foreground * 255.0) 151 | if args.background: 152 | cv2.imwrite(args.background, background * 255.0) 153 | 154 | 155 | if __name__ == "__main__": 156 | main() 157 | -------------------------------------------------------------------------------- /utils/tmp/closed_form_matting/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scipy 3 | opencv_python 4 | -------------------------------------------------------------------------------- /utils/tmp/closed_form_matting/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Setting up closed-form-matting package during pip installation""" 3 | 4 | import os 5 | import re 6 | 7 | import setuptools 8 | 9 | # Project root directory 10 | root_dir = os.path.dirname(__file__) 11 | 12 | # Get version string from __init__.py in the 
13 | with open(os.path.join(root_dir, 'closed_form_matting', '__init__.py')) as f:
14 | version = re.search(r'__version__ = \'(.*?)\'', f.read()).group(1)
15 | 
16 | # Get dependency list from requirements.txt
17 | with open(os.path.join(root_dir, 'requirements.txt')) as f:
18 | requirements = f.read().split()
19 | 
20 | setuptools.setup(
21 | name='closed-form-matting',
22 | version=version,
23 | author='Marco Forte',
24 | author_email='fortemarco.irl@gmail.com',
25 | maintainer='Marco Forte',
26 | maintainer_email='fortemarco.irl@gmail.com',
27 | url='https://github.com/MarcoForte/closed-form-matting',
28 | description='A closed-form solution to natural image matting',
29 | long_description=open(os.path.join(root_dir, 'README.md')).read(),
30 | long_description_content_type='text/markdown',
31 | packages=setuptools.find_packages(),
32 | classifiers=[
33 | 'Development Status :: 4 - Beta',
34 | 'Intended Audience :: Science/Research',
35 | 'Topic :: Scientific/Engineering :: Image Processing',
36 | 'License :: OSI Approved :: MIT License',
37 | 'Programming Language :: Python :: 3',
38 | ],
39 | keywords=['closed-form matting', 'image matting', 'image processing'],
40 | license='MIT',
41 | python_requires='>=3.5',
42 | install_requires=requirements,
43 | entry_points={
44 | 'console_scripts': [
45 | 'closed-form-matting=closed_form_matting.closed_form_matting:main',
46 | 'solve-foreground-background=closed_form_matting.solve_foreground_background:main',
47 | ],
48 | },
49 | )
50 | 
--------------------------------------------------------------------------------
/utils/tmp/closed_form_matting/test_matting.py:
--------------------------------------------------------------------------------
1 | """Tests for Closed-Form matting and foreground/background solver."""
2 | import unittest
3 | 
4 | import cv2
5 | import numpy as np
6 | 
7 | import closed_form_matting
8 | 
9 | class TestMatting(unittest.TestCase):
10 | def test_solution_close_to_original_implementation(self):
11 | image = cv2.imread('testdata/source.png', cv2.IMREAD_COLOR) / 255.0
12 | scribbles = cv2.imread('testdata/scribbles.png', cv2.IMREAD_COLOR) / 255.0
13 | 
14 | alpha = closed_form_matting.closed_form_matting_with_scribbles(image, scribbles)
15 | foreground, background = closed_form_matting.solve_foreground_background(image, alpha)
16 | 
17 | matlab_alpha = cv2.imread('testdata/matlab_alpha.png', cv2.IMREAD_GRAYSCALE) / 255.0
18 | matlab_foreground = cv2.imread('testdata/matlab_foreground.png', cv2.IMREAD_COLOR) / 255.0
19 | matlab_background = cv2.imread('testdata/matlab_background.png', cv2.IMREAD_COLOR) / 255.0
20 | 
21 | sad_alpha = np.mean(np.abs(alpha - matlab_alpha))
22 | sad_foreground = np.mean(np.abs(foreground - matlab_foreground))
23 | sad_background = np.mean(np.abs(background - matlab_background))
24 | 
25 | self.assertLess(sad_alpha, 1e-2)
26 | self.assertLess(sad_foreground, 1e-2)
27 | self.assertLess(sad_background, 1e-2)
28 | 
29 | def test_matting_with_trimap(self):
30 | image = cv2.imread('testdata/source.png', cv2.IMREAD_COLOR) / 255.0
31 | trimap = cv2.imread('testdata/trimap.png', cv2.IMREAD_GRAYSCALE) / 255.0
32 | 
33 | alpha = closed_form_matting.closed_form_matting_with_trimap(image, trimap)
34 | 
35 | reference_alpha = cv2.imread('testdata/output_alpha.png', cv2.IMREAD_GRAYSCALE) / 255.0
36 | 
37 | sad_alpha = np.mean(np.abs(alpha - reference_alpha))
38 | self.assertLess(sad_alpha, 1e-3)
39 | 
--------------------------------------------------------------------------------
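For orientation, a minimal usage sketch for the vendored closed_form_matting package above; it simply mirrors the calls exercised in test_matting.py, and the input and output file names are placeholders rather than files shipped with this repository:

```python
import cv2
import closed_form_matting

# Image and trimap scaled to [0, 1], following the tests above.
image = cv2.imread('source.png', cv2.IMREAD_COLOR) / 255.0
trimap = cv2.imread('trimap.png', cv2.IMREAD_GRAYSCALE) / 255.0

# Closed-form alpha matting from a trimap, then foreground/background
# reconstruction (Section 7 of Levin et al.).
alpha = closed_form_matting.closed_form_matting_with_trimap(image, trimap)
foreground, background = closed_form_matting.solve_foreground_background(image, alpha)

cv2.imwrite('alpha.png', alpha * 255.0)
cv2.imwrite('foreground.png', foreground * 255.0)
```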
/utils/tmp/closed_form_matting/testdata/matlab_alpha.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/utils/tmp/closed_form_matting/testdata/matlab_alpha.png -------------------------------------------------------------------------------- /utils/tmp/closed_form_matting/testdata/matlab_background.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/utils/tmp/closed_form_matting/testdata/matlab_background.png -------------------------------------------------------------------------------- /utils/tmp/closed_form_matting/testdata/matlab_foreground.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/utils/tmp/closed_form_matting/testdata/matlab_foreground.png -------------------------------------------------------------------------------- /utils/tmp/closed_form_matting/testdata/output_alpha.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/utils/tmp/closed_form_matting/testdata/output_alpha.png -------------------------------------------------------------------------------- /utils/tmp/closed_form_matting/testdata/output_background.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/utils/tmp/closed_form_matting/testdata/output_background.png -------------------------------------------------------------------------------- /utils/tmp/closed_form_matting/testdata/output_foreground.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/utils/tmp/closed_form_matting/testdata/output_foreground.png -------------------------------------------------------------------------------- /utils/tmp/closed_form_matting/testdata/scribbles.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/utils/tmp/closed_form_matting/testdata/scribbles.png -------------------------------------------------------------------------------- /utils/tmp/closed_form_matting/testdata/source.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/utils/tmp/closed_form_matting/testdata/source.png -------------------------------------------------------------------------------- /utils/tmp/closed_form_matting/testdata/trimap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/utils/tmp/closed_form_matting/testdata/trimap.png -------------------------------------------------------------------------------- /utils/tmp/group_weight.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models.models_TCVOM.FBA import layers_WS 4 | 5 | def group_weight(module, lr_encoder, lr_decoder, WD): 6 | 
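# Split the module's parameters into six optimizer groups so that encoder and decoder
# can use different learning rates and weight decay is applied selectively:
#   conv / linear weights -> weight decay WD
#   biases                -> no weight decay, 2x learning rate
#   GroupNorm parameters  -> small fixed weight decay (1e-5)
# Parameters go into the 'encoder' groups when 'encoder' appears in their module name;
# everything else is treated as part of the decoder.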
group_decay = { 'encoder': [], 'decoder':[]} 7 | group_bias = { 'encoder': [], 'decoder':[]} 8 | group_GN = { 'encoder': [], 'decoder':[]} 9 | 10 | 11 | for name, m in module.named_modules(): 12 | # if hasattr(m, 'requires_grad'): 13 | # if m.requires_grad: 14 | # continue 15 | 16 | part = 'decoder' 17 | if('encoder' in name): 18 | part = 'encoder' 19 | 20 | if isinstance(m, nn.Linear): 21 | group_decay[part].append(m.weight) 22 | if m.bias is not None: 23 | group_bias[part].append(m.bias) 24 | 25 | elif isinstance(m, nn.Conv2d) and m.weight.requires_grad: 26 | group_decay[part].append(m.weight) 27 | if m.bias is not None: 28 | group_bias[part].append(m.bias) 29 | elif isinstance(m, layers_WS.Conv2d) and m.weight.requires_grad: 30 | group_decay[part].append(m.weight) 31 | if m.bias is not None: 32 | group_bias[part].append(m.bias) 33 | 34 | elif isinstance(m, nn.GroupNorm): 35 | if m.weight is not None: 36 | group_GN[part].append(m.weight) 37 | if m.bias is not None: 38 | group_GN[part].append(m.bias) 39 | 40 | 41 | print(len(list(module.parameters())), len(group_decay['encoder']) + len(group_bias['encoder']) + len(group_GN['encoder']) + len(group_decay['decoder']) + len(group_bias['decoder']) + len(group_GN['decoder']) , len(list(module.modules()))) 42 | # assert len(list(module.parameters())) == len(group_decay) + len(group_bias) + len(group_GN) 43 | groups = [dict(params=group_decay['decoder'], lr =lr_decoder, weight_decay=WD), dict(params=group_bias['decoder'], lr=2*lr_decoder, weight_decay=0.0), dict(params=group_GN['decoder'], lr=lr_decoder, weight_decay=1e-5), 44 | dict(params=group_decay['encoder'], lr=lr_encoder, weight_decay=WD), dict(params=group_bias['encoder'], lr=2*lr_encoder, weight_decay=0.0), dict(params=group_GN['encoder'], lr=lr_encoder, weight_decay=1e-5)] 45 | 46 | # groups= [dict(params=module.decoder.conv_pred.parameters(), lr=lr, weight_decay=0.0)] 47 | return groups -------------------------------------------------------------------------------- /utils/tmp/metric.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import math 7 | import time 8 | import skimage.measure 9 | 10 | from PIL import Image 11 | from scipy import ndimage 12 | from scipy.ndimage.morphology import distance_transform_edt 13 | from multiprocessing import Pool 14 | 15 | 16 | def findMaxConnectedRegion(x): 17 | assert len(x.shape) == 2 18 | cc, num = skimage.measure.label(x, connectivity=1, return_num=True) 19 | omega = np.zeros_like(x) 20 | if num > 0: 21 | # find the largest connected region 22 | max_id = np.argmax(np.bincount(cc.flatten())[1:]) + 1 23 | omega[cc == max_id] = 1 24 | return omega 25 | 26 | def genGaussKernel(sigma, q=2): 27 | pi = math.pi 28 | eps = 1e-2 29 | 30 | def gauss(x, sigma): 31 | return np.exp(-np.power(x,2)/(2*np.power(sigma,2))) / (sigma*np.sqrt(2*pi)) 32 | 33 | def dgauss(x, sigma): 34 | return -x * gauss(x,sigma) / np.power(sigma, 2) 35 | 36 | hsize = int(np.ceil(sigma*np.sqrt(-2*np.log(np.sqrt(2*pi)*sigma*eps)))) 37 | size = 2 * hsize + 1 38 | hx = np.zeros([size, size], dtype=np.float32) 39 | for i in range(size): 40 | for j in range(size): 41 | u, v = i-hsize, j-hsize 42 | hx[i,j] = gauss(u,sigma) * dgauss(v,sigma) 43 | 44 | hx = hx / np.sqrt(np.sum(np.power(np.abs(hx), 2))) 45 | hy = hx.transpose(1, 0) 46 | return hx, hy, size 47 | 48 | def calcOpticalFlow(frames): 49 | prev, curr = frames 50 | flow = 
cv2.calcOpticalFlowFarneback(prev.astype(np.uint8), curr.astype(np.uint8), None, 51 | 0.5, 5, 10, 2, 7, 1.5, 52 | cv2.OPTFLOW_FARNEBACK_GAUSSIAN) 53 | return flow 54 | 55 | 56 | class ImageFilter(nn.Module): 57 | def __init__(self, chn, kernel_size, weight, device): 58 | super(ImageFilter, self).__init__() 59 | self.kernel_size = kernel_size 60 | assert kernel_size == weight.size(-1) 61 | self.filter = nn.Conv2d(chn, chn, kernel_size, padding=0, bias=False) 62 | self.filter.weight = nn.Parameter(weight) 63 | self.device = device 64 | 65 | def pad(self, x): 66 | assert len(x.shape) == 3 67 | x = x.unsqueeze(-1).permute((0,3,1,2)) 68 | b, c, h, w = x.shape 69 | pad = self.kernel_size // 2 70 | y = torch.zeros([b, c, h+pad*2, w+pad*2]).to(self.device) 71 | y[:,:,0:pad,0:pad] = x[:,:,0:1,0:1].repeat(1,1,pad,pad) 72 | y[:,:,0:pad,w+pad:] = x[:,:,0:1,-1:].repeat(1,1,pad,pad) 73 | y[:,:,h+pad:,0:pad] = x[:,:,-1:,0:1].repeat(1,1,pad,pad) 74 | y[:,:,h+pad:,w+pad:] = x[:,:,-1:,-1:].repeat(1,1,pad,pad) 75 | 76 | y[:,:,0:pad,pad:w+pad] = x[:,:,0:1,:].repeat(1,1,pad,1) 77 | y[:,:,pad:h+pad,0:pad] = x[:,:,:,0:1].repeat(1,1,1,pad) 78 | y[:,:,h+pad:,pad:w+pad] = x[:,:,-1:,:].repeat(1,1,pad,1) 79 | y[:,:,pad:h+pad,w+pad:] = x[:,:,:,-1:].repeat(1,1,1,pad) 80 | 81 | y[:,:,pad:h+pad, pad:w+pad] = x 82 | return y 83 | 84 | def forward(self, x): 85 | y = self.filter(self.pad(x)) 86 | return y 87 | 88 | 89 | class BatchMetric(object): 90 | def __init__(self, device, grad_sigma=1.4, grad_q=2, 91 | conn_step=0.1, conn_thresh=0.5, conn_theta=0.15, conn_p=1): 92 | # parameters for connectivity 93 | self.conn_step = conn_step 94 | self.conn_thresh = conn_thresh 95 | self.conn_theta = conn_theta 96 | self.conn_p = conn_p 97 | self.device = device 98 | 99 | hx, hy, size = genGaussKernel(grad_sigma, grad_q) 100 | self.hx = hx 101 | self.hy = hy 102 | self.kernel_size = size 103 | kx = self.hx[::-1, ::-1].copy() 104 | ky = self.hy[::-1, ::-1].copy() 105 | kernel_x = torch.from_numpy(kx).unsqueeze(0).unsqueeze(0) 106 | kernel_y = torch.from_numpy(ky).unsqueeze(0).unsqueeze(0) 107 | self.fx = ImageFilter(1, self.kernel_size, kernel_x, self.device).cuda(self.device) 108 | self.fy = ImageFilter(1, self.kernel_size, kernel_y, self.device).cuda(self.device) 109 | 110 | def run(self, input, target, mask=None): 111 | torch.cuda.empty_cache() 112 | input_t = torch.from_numpy(input.astype(np.float32)).to(self.device) 113 | target_t = torch.from_numpy(target.astype(np.float32)).to(self.device) 114 | if mask is None: 115 | mask = torch.zeros_like(target_t).to(self.device) 116 | mask[(target_t>0) * (target_t<255)] = 1 117 | else: 118 | mask = torch.from_numpy(mask.astype(np.float32)).to(self.device) 119 | mask = (mask == 128).float() 120 | sad = self.BatchSAD(input_t, target_t, mask) 121 | mse = self.BatchMSE(input_t, target_t, mask) 122 | grad = self.BatchGradient(input_t, target_t, mask) 123 | conn = self.BatchConnectivity(input_t, target_t, mask) 124 | return sad, mse, grad, conn 125 | 126 | def run_video(self, input, target, mask=None): 127 | torch.cuda.empty_cache() 128 | input_t = torch.from_numpy(input.astype(np.float32)).to(self.device) 129 | target_t = torch.from_numpy(target.astype(np.float32)).to(self.device) 130 | if mask is None: 131 | mask = torch.zeros_like(target_t).to(self.device) 132 | mask[(target_t>0) * (target_t<255)] = 1 133 | else: 134 | mask = torch.from_numpy(mask.astype(np.float32)).to(self.device) 135 | mask = (mask == 128).float() 136 | errs, nums = [], [] 137 | err, n = self.SSDA(input_t, target_t, mask) 
138 | errs.append(err) 139 | nums.append(n) 140 | err, n = self.dtSSD(input_t, target_t, mask) 141 | errs.append(err) 142 | nums.append(n) 143 | err, n = self.MESSDdt(input_t, target_t, mask) 144 | errs.append(err) 145 | nums.append(n) 146 | return errs, nums 147 | 148 | def run_metric(self, metric, input, target, mask=None): 149 | torch.cuda.empty_cache() 150 | input_t = torch.from_numpy(input.astype(np.float32)).to(self.device) 151 | target_t = torch.from_numpy(target.astype(np.float32)).to(self.device) 152 | if mask is None: 153 | mask = torch.zeros_like(target_t).to(self.device) 154 | mask[(target_t>0) * (target_t<255)] = 1 155 | else: 156 | mask = torch.from_numpy(mask.astype(np.float32)).to(self.device) 157 | mask = (mask == 128).float() 158 | 159 | if metric == 'sad': 160 | ret = self.BatchSAD(input_t, target_t, mask) 161 | elif metric == 'mse': 162 | ret = self.BatchMSE(input_t, target_t, mask) 163 | elif metric == 'grad': 164 | ret = self.BatchGradient(input_t, target_t, mask) 165 | elif metric == 'conn': 166 | ret = self.BatchConnectivity(input_t, target_t, mask) 167 | elif metric == 'ssda': 168 | ret = self.SSDA(input_t, target_t, mask) 169 | elif metric == 'dtssd': 170 | ret = self.dtSSD(input_t, target_t, mask) 171 | elif metric == 'messddt': 172 | ret = self.MESSDdt(input_t, target_t, mask) 173 | else: 174 | raise NotImplementedError 175 | return ret 176 | 177 | def BatchSAD(self, pred, target, mask): 178 | B = target.size(0) 179 | error_map = (pred - target).abs() / 255. 180 | batch_loss = (error_map * mask).view(B, -1).sum(dim=-1) 181 | batch_loss = batch_loss / 1000. 182 | return batch_loss.data.cpu().numpy() 183 | 184 | def BatchMSE(self, pred, target, mask): 185 | B = target.size(0) 186 | error_map = (pred-target) / 255. 187 | batch_loss = (error_map.pow(2) * mask).view(B, -1).sum(dim=-1) 188 | batch_loss = batch_loss / (mask.view(B, -1).sum(dim=-1) + 1.) 189 | return batch_loss.data.cpu().numpy() 190 | 191 | def BatchGradient(self, pred, target, mask): 192 | B = target.size(0) 193 | pred = pred / 255. 194 | target = target / 255. 195 | 196 | pred_x_t = self.fx(pred).squeeze(1) 197 | pred_y_t = self.fy(pred).squeeze(1) 198 | target_x_t = self.fx(target).squeeze(1) 199 | target_y_t = self.fy(target).squeeze(1) 200 | pred_amp = (pred_x_t.pow(2) + pred_y_t.pow(2)).sqrt() 201 | target_amp = (target_x_t.pow(2) + target_y_t.pow(2)).sqrt() 202 | error_map = (pred_amp - target_amp).pow(2) 203 | batch_loss = (error_map * mask).view(B, -1).sum(dim=-1) 204 | return batch_loss.data.cpu().numpy() 205 | 206 | def BatchConnectivity(self, pred, target, mask): 207 | step = self.conn_step 208 | theta = self.conn_theta 209 | 210 | pred = pred / 255. 211 | target = target / 255. 
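# Connectivity error: sweep alpha thresholds from 0 to 1 in steps of conn_step; at each
# level, binarise prediction and ground truth and keep the largest connected region they
# share (findMaxConnectedRegion, evaluated per batch item in a multiprocessing Pool).
# l_map stores, per pixel, the threshold just below the first level at which the pixel
# drops out of that shared region (pixels that never drop out keep 1); deviations of the
# alpha values from l_map larger than conn_theta are then penalised inside the mask.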
212 | B, dimy, dimx = pred.shape 213 | thresh_steps = torch.arange(0, 1+step, step).to(self.device) 214 | l_map = torch.ones_like(pred).to(self.device)*(-1) 215 | pool = Pool(B) 216 | for i in range(1, len(thresh_steps)): 217 | pred_alpha_thresh = pred>=thresh_steps[i] 218 | target_alpha_thresh = target>=thresh_steps[i] 219 | mask_i = pred_alpha_thresh * target_alpha_thresh 220 | omegas = [] 221 | items = [mask_ij.data.cpu().numpy() for mask_ij in mask_i] 222 | for omega in pool.imap(findMaxConnectedRegion, items): 223 | omegas.append(omega) 224 | omegas = torch.from_numpy(np.array(omegas)).to(self.device) 225 | flag = (l_map==-1) * (omegas==0) 226 | l_map[flag==1] = thresh_steps[i-1] 227 | l_map[l_map==-1] = 1 228 | pred_d = pred - l_map 229 | target_d = target - l_map 230 | pred_phi = 1 - pred_d*(pred_d>=theta).float() 231 | target_phi = 1 - target_d*(target_d>=theta).float() 232 | batch_loss = ((pred_phi-target_phi).abs()*mask).view([B, -1]).sum(-1) 233 | pool.close() 234 | return batch_loss.data.cpu().numpy() 235 | 236 | def GaussianGradient(self, mat): 237 | gx = np.zeros_like(mat) 238 | gy = np.zeros_like(mat) 239 | for i in range(mat.shape[0]): 240 | gx[i, ...] = ndimage.filters.convolve(mat[i], self.hx, mode='nearest') 241 | gy[i, ...] = ndimage.filters.convolve(mat[i], self.hy, mode='nearest') 242 | return gx, gy 243 | 244 | def SSDA(self, pred, target, mask=None): 245 | B, h, w = target.shape 246 | pred = pred / 255. 247 | target = target / 255. 248 | error = ((pred-target).pow(2) * mask).view(B, -1).sum(dim=1).sqrt() 249 | num = mask.view(B, -1).sum(dim=1) + 1. 250 | return error.data.cpu().numpy(), num.data.cpu().numpy() 251 | 252 | def dtSSD(self, pred, target, mask=None): 253 | B, h, w = target.shape 254 | pred = pred / 255. 255 | target = target / 255. 256 | pred_0 = pred[:-1, ...] 257 | pred_1 = pred[1:, ...] 258 | target_0 = target[:-1, ...] 259 | target_1 = target[1:, ...] 260 | mask_0 = mask[:-1, ...] 261 | error_map = ((pred_1-pred_0) - (target_1-target_0)).pow(2) 262 | error = (error_map * mask_0).view(mask_0.shape[0], -1).sum(dim=1).sqrt() 263 | num = mask_0.view(mask_0.shape[0], -1).sum(dim=1) + 1. 264 | return error.data.cpu().numpy(), num.data.cpu().numpy() 265 | 266 | def MESSDdt(self, pred, target, mask=None): 267 | B, h, w = target.shape 268 | 269 | pool = Pool(B) 270 | flows = [] 271 | items = [t for t in target.data.cpu().numpy()] 272 | for flow in pool.imap(calcOpticalFlow, zip(items[:-1], items[1:])): 273 | flows.append(flow) 274 | flow = torch.from_numpy(np.rint(np.array(flows)).astype(np.int64)).to(self.device) 275 | pool.close() 276 | 277 | pred = pred / 255. 278 | target = target / 255. 279 | pred_0 = pred[:-1, ...] 280 | pred_1 = pred[1:, ...] 281 | target_0 = target[:-1, ...] 282 | target_1 = target[1:, ...] 283 | mask_0 = mask[:-1, ...] 284 | mask_1 = mask[1:, ...] 
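# Warp the frame t+1 predictions, ground truth, and masks back onto frame t using the
# rounded Farneback flow computed above from consecutive ground-truth frames, then
# accumulate the absolute difference between the masked squared errors of frame t and
# of the warped frame t+1.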
285 | 286 | B, h, w = target_0.shape 287 | x = torch.arange(0, w).to(self.device) 288 | y = torch.arange(0, h).to(self.device) 289 | xx, yy = torch.meshgrid([y, x]) 290 | coords = torch.stack([yy, xx], dim=2).unsqueeze(0).repeat((B, 1, 1, 1)) 291 | coords_n = (coords + flow) 292 | coords_y = coords_n[..., 0].clamp(0, h-1) 293 | coords_x = coords_n[..., 1].clamp(0, w-1) 294 | indices = coords_y * w + coords_x 295 | pred_1 = torch.take(pred_1, indices) 296 | target_1 = torch.take(target_1, indices) 297 | mask_1 = torch.take(mask_1, indices) 298 | 299 | error_map = (pred_0-target_0).pow(2) * mask_0 - (pred_1-target_1).pow(2) * mask_1 300 | error = error_map.abs().view(mask_0.shape[0], -1).sum(dim=1) 301 | num = mask_0.view(mask_0.shape[0], -1).sum(dim=1) + 1. 302 | return error.data.cpu().numpy(), num.data.cpu().numpy() -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | from pathlib import Path 4 | import cv2 as cv 5 | import numpy as np 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.distributed as torch_dist 11 | 12 | def dt(a): 13 | # a: tensor, [B, S, H, W] 14 | ac = a.cpu().numpy() 15 | b, s = a.shape[:2] 16 | z = [] 17 | for i in range(b): 18 | y = [] 19 | for j in range(s): 20 | x = ac[i,j] 21 | y.append(cv.distanceTransform((x * 255).astype(np.uint8), cv.DIST_L2, 0)) 22 | z.append(np.stack(y)) 23 | return torch.from_numpy(np.stack(z)).float().to(a.device) 24 | 25 | def trimap_transform(trimap): 26 | # trimap: tensor, [B, S, 2, H, W] 27 | b, s, _, h, w = trimap.shape 28 | 29 | clicks = torch.zeros((b, s, 6, h, w), device=trimap.device) 30 | for k in range(2): 31 | tk = trimap[:, :, k] 32 | if torch.sum(tk != 0) > 0: 33 | dt_mask = -dt(1. - tk)**2 34 | L = 320 35 | clicks[:, :, 3*k] = torch.exp(dt_mask / (2 * ((0.02 * L)**2))) 36 | clicks[:, :, 3*k+1] = torch.exp(dt_mask / (2 * ((0.08 * L)**2))) 37 | clicks[:, :, 3*k+2] = torch.exp(dt_mask / (2 * ((0.16 * L)**2))) 38 | 39 | return clicks 40 | 41 | def torch_barrier(): 42 | if torch_dist.is_initialized(): 43 | torch_dist.barrier() 44 | 45 | def reduce_tensor(inp): 46 | """ 47 | Reduce the loss from all processes so that 48 | ALL PROCESSES has the averaged results. 
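Returns the input unchanged when torch.distributed is not initialized or the
world size is smaller than 2.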
49 | """ 50 | if torch_dist.is_initialized(): 51 | world_size = torch_dist.get_world_size() 52 | if world_size < 2: 53 | return inp 54 | with torch.no_grad(): 55 | reduced_inp = inp 56 | torch.distributed.all_reduce(reduced_inp) 57 | torch.distributed.barrier() 58 | return reduced_inp / world_size 59 | return inp 60 | 61 | def print_loss_dict(loss, save=None): 62 | s = '' 63 | for key in sorted(loss.keys()): 64 | s += '{}: {:.6f}\n'.format(key, loss[key]) 65 | print (s) 66 | if save is not None: 67 | with open(save, 'w') as f: 68 | f.write(s) 69 | 70 | def coords_grid(batch, ht, wd): 71 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) 72 | coords = torch.stack(coords[::-1], axis=0) 73 | return coords.unsqueeze(0).repeat(batch, 1, 1, 1) 74 | 75 | def grid_sampler(img, coords, mode='bilinear'): 76 | """ Wrapper for grid_sample, uses pixel coordinates 77 | img: [B, C, H, W] 78 | coords: [B, 2, H, W] 79 | """ 80 | H, W = img.shape[-2:] 81 | xgrid, ygrid = coords.split(1, dim=1) 82 | xgrid = 2*xgrid/(W-1) - 1 83 | ygrid = 2*ygrid/(H-1) - 1 84 | 85 | grid = torch.cat([xgrid, ygrid], dim=1).permute(0, 2, 3, 1) 86 | img = F.grid_sample(img, grid, mode=mode, align_corners=True) 87 | 88 | return img 89 | 90 | def flow_dt(a, ha, gt, hgt, flow, trimask, metric=False, cuda=True): 91 | ''' 92 | All tensors in [B, C, H, W] 93 | a: current prediction 94 | gt: current groundtruth 95 | ha: adjacent frame prediction 96 | hgt: adjacent frame groundtruth 97 | flow: optical flow from current frame to adjacent frame 98 | trimask: current frame trimask 99 | ''' 100 | # Warp ha back to a and hgt back to gt 101 | with torch.no_grad(): 102 | B, _, H, W = a.shape 103 | mask = torch.isnan(flow) # B, 1, H, W 104 | coords = coords_grid(B, H, W) # B, 2, H, W 105 | if cuda: 106 | coords = coords.to(torch.cuda.current_device()) 107 | flow[mask] = 0 108 | flow_coords = coords + flow 109 | mask = (~mask[:, :1, :, :]) * trimask.bool() 110 | valid = mask.sum() 111 | if valid == 0: 112 | if metric: 113 | return valid.float(), valid.float(), valid.float() 114 | else: 115 | return valid.float() 116 | 117 | pgt = grid_sampler(hgt, flow_coords) 118 | pa = grid_sampler(ha, flow_coords) 119 | error = torch.abs((a[mask] - gt[mask]) - (pa[mask] - pgt[mask])) # L1 instead of L2 120 | if metric: 121 | error2 = torch.abs((a[mask] - gt[mask]) ** 2 - (pa[mask] - pgt[mask]) ** 2) 122 | return error.sum(), error2.sum(), valid 123 | return error.mean() 124 | 125 | class AverageMeter(object): 126 | """Computes and stores the average and current value""" 127 | 128 | def __init__(self): 129 | self.initialized = False 130 | self.val = None 131 | self.avg = None 132 | self.sum = None 133 | self.count = None 134 | 135 | def initialize(self, val, weight): 136 | self.val = val 137 | self.avg = val 138 | self.sum = val * weight 139 | self.count = weight 140 | self.initialized = True 141 | 142 | def update(self, val, weight=1): 143 | if not self.initialized: 144 | self.initialize(val, weight) 145 | else: 146 | self.add(val, weight) 147 | 148 | def add(self, val, weight): 149 | self.val = val 150 | self.sum += val * weight 151 | self.count += weight 152 | self.avg = self.sum / self.count 153 | 154 | def value(self): 155 | return self.val 156 | 157 | def average(self): 158 | return self.avg 159 | 160 | def create_logger(output_dir, cfg_name, phase='train'): 161 | root_output_dir = Path(output_dir) 162 | # set up logger 163 | if not root_output_dir.exists(): 164 | print('=> creating {}'.format(root_output_dir)) 165 | root_output_dir.mkdir() 
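# Log file goes to <output_dir>/<cfg_name>/<cfg_name>_<timestamp>_<phase>.log and is
# also mirrored to the console.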
166 | 167 | final_output_dir = root_output_dir / cfg_name 168 | 169 | print('=> creating {}'.format(final_output_dir)) 170 | final_output_dir.mkdir(parents=True, exist_ok=True) 171 | 172 | time_str = time.strftime('%Y-%m-%d-%H-%M') 173 | log_file = '{}_{}_{}.log'.format(cfg_name, time_str, phase) 174 | final_log_file = final_output_dir / log_file 175 | head = '%(asctime)-15s %(message)s' 176 | logging.basicConfig(filename=str(final_log_file), 177 | format=head) 178 | logger = logging.getLogger() 179 | logger.setLevel(logging.INFO) 180 | console = logging.StreamHandler() 181 | logging.getLogger('').addHandler(console) 182 | 183 | return logger, str(final_output_dir) 184 | 185 | def poly_lr(optimizer, base_lr, max_iters, cur_iters, power=0.9): 186 | lr = base_lr*((1-float(cur_iters)/max_iters)**(power)) 187 | optimizer.param_groups[0]['lr'] = lr 188 | return lr 189 | 190 | def const_lr(optimizer, base_lr, max_iters, cur_iters): 191 | return base_lr 192 | 193 | OPT_DICT = { 194 | 'adam': torch.optim.Adam, 195 | 'adamw': torch.optim.AdamW, 196 | 'sgd': torch.optim.SGD, 197 | } 198 | 199 | STR_DICT = { 200 | 'poly': poly_lr, 201 | 'const': const_lr, 202 | } 203 | --------------------------------------------------------------------------------
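For orientation, a minimal sketch of how the helpers in utils/utils.py above fit together in a training loop. This is not taken from train.py: the model, loss, iteration count, and hyper-parameter values are placeholders, and it assumes the repository root is on PYTHONPATH so that `utils.utils` is importable.

```python
import torch
import torch.nn as nn
from utils.utils import AverageMeter, create_logger, poly_lr, OPT_DICT

logger, out_dir = create_logger('output', 'example_cfg', phase='train')

model = nn.Conv2d(3, 1, kernel_size=3, padding=1)      # placeholder model
base_lr, max_iters = 1e-3, 1000
optimizer = OPT_DICT['adam'](model.parameters(), lr=base_lr)

loss_meter = AverageMeter()
for it in range(max_iters):
    lr = poly_lr(optimizer, base_lr, max_iters, it)    # polynomial decay on param_groups[0]
    pred = model(torch.randn(2, 3, 64, 64))            # placeholder batch
    loss = pred.abs().mean()                           # placeholder loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    loss_meter.update(loss.item())
    if it % 100 == 0:
        logger.info('iter {} lr {:.2e} loss {:.4f}'.format(it, lr, loss_meter.average()))
```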