├── LICENSE
├── OTVM-teaser.jpg
├── README.md
├── config.py
├── dataset.py
├── demo
│   └── dove
│       ├── frames
│       │   ├── 00000.jpg
│       │   ├── 00001.jpg
│       │   ├── 00002.jpg
│       │   ├── 00003.jpg
│       │   ├── 00004.jpg
│       │   ├── 00005.jpg
│       │   ├── 00006.jpg
│       │   ├── 00007.jpg
│       │   ├── 00008.jpg
│       │   ├── 00009.jpg
│       │   └── 00010.jpg
│       └── trimap
│           └── 00000.png
├── eval.py
├── helpers.py
├── models
│   ├── __init__.py
│   ├── alpha
│   │   ├── FBA
│   │   │   ├── layers_WS.py
│   │   │   ├── models.py
│   │   │   ├── resnet_GN_WS.py
│   │   │   └── resnet_bn.py
│   │   ├── __init__.py
│   │   ├── common.py
│   │   └── model.py
│   └── trimap
│       ├── STM.py
│       ├── __init__.py
│       └── model.py
├── scripts
│   ├── eval_s4.sh
│   ├── eval_s4_demo.sh
│   ├── train_s1_alpha.sh
│   ├── train_s1_trimap.sh
│   ├── train_s2_alpha.sh
│   ├── train_s3.sh
│   └── train_s4.sh
├── train.py
├── train_s1_trimap.py
└── utils
    ├── loss_func.py
    ├── optimizer.py
    ├── tmp
    │   ├── __init__.py
    │   ├── augmentation.py
    │   ├── closed_form_matting
    │   │   ├── .travis.yml
    │   │   ├── LICENSE
    │   │   ├── README.md
    │   │   ├── closed_form_matting
    │   │   │   ├── __init__.py
    │   │   │   ├── closed_form_matting.py
    │   │   │   └── solve_foreground_background.py
    │   │   ├── requirements.txt
    │   │   ├── setup.py
    │   │   ├── test_matting.py
    │   │   └── testdata
    │   │       ├── matlab_alpha.png
    │   │       ├── matlab_background.png
    │   │       ├── matlab_foreground.png
    │   │       ├── output_alpha.png
    │   │       ├── output_background.png
    │   │       ├── output_foreground.png
    │   │       ├── scribbles.png
    │   │       ├── source.png
    │   │       └── trimap.png
    │   ├── group_weight.py
    │   └── metric.py
    └── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Attribution-NonCommercial-ShareAlike 4.0 International
2 |
3 | =======================================================================
4 |
5 | Creative Commons Corporation ("Creative Commons") is not a law firm and
6 | does not provide legal services or legal advice. Distribution of
7 | Creative Commons public licenses does not create a lawyer-client or
8 | other relationship. Creative Commons makes its licenses and related
9 | information available on an "as-is" basis. Creative Commons gives no
10 | warranties regarding its licenses, any material licensed under their
11 | terms and conditions, or any related information. Creative Commons
12 | disclaims all liability for damages resulting from their use to the
13 | fullest extent possible.
14 |
15 | Using Creative Commons Public Licenses
16 |
17 | Creative Commons public licenses provide a standard set of terms and
18 | conditions that creators and other rights holders may use to share
19 | original works of authorship and other material subject to copyright
20 | and certain other rights specified in the public license below. The
21 | following considerations are for informational purposes only, are not
22 | exhaustive, and do not form part of our licenses.
23 |
24 | Considerations for licensors: Our public licenses are
25 | intended for use by those authorized to give the public
26 | permission to use material in ways otherwise restricted by
27 | copyright and certain other rights. Our licenses are
28 | irrevocable. Licensors should read and understand the terms
29 | and conditions of the license they choose before applying it.
30 | Licensors should also secure all rights necessary before
31 | applying our licenses so that the public can reuse the
32 | material as expected. Licensors should clearly mark any
33 | material not subject to the license. This includes other CC-
34 | licensed material, or material used under an exception or
35 | limitation to copyright. More considerations for licensors:
36 | wiki.creativecommons.org/Considerations_for_licensors
37 |
38 | Considerations for the public: By using one of our public
39 | licenses, a licensor grants the public permission to use the
40 | licensed material under specified terms and conditions. If
41 | the licensor's permission is not necessary for any reason--for
42 | example, because of any applicable exception or limitation to
43 | copyright--then that use is not regulated by the license. Our
44 | licenses grant only permissions under copyright and certain
45 | other rights that a licensor has authority to grant. Use of
46 | the licensed material may still be restricted for other
47 | reasons, including because others have copyright or other
48 | rights in the material. A licensor may make special requests,
49 | such as asking that all changes be marked or described.
50 | Although not required by our licenses, you are encouraged to
51 | respect those requests where reasonable. More_considerations
52 | for the public:
53 | wiki.creativecommons.org/Considerations_for_licensees
54 |
55 | =======================================================================
56 |
57 | Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
58 | Public License
59 |
60 | By exercising the Licensed Rights (defined below), You accept and agree
61 | to be bound by the terms and conditions of this Creative Commons
62 | Attribution-NonCommercial-ShareAlike 4.0 International Public License
63 | ("Public License"). To the extent this Public License may be
64 | interpreted as a contract, You are granted the Licensed Rights in
65 | consideration of Your acceptance of these terms and conditions, and the
66 | Licensor grants You such rights in consideration of benefits the
67 | Licensor receives from making the Licensed Material available under
68 | these terms and conditions.
69 |
70 |
71 | Section 1 -- Definitions.
72 |
73 | a. Adapted Material means material subject to Copyright and Similar
74 | Rights that is derived from or based upon the Licensed Material
75 | and in which the Licensed Material is translated, altered,
76 | arranged, transformed, or otherwise modified in a manner requiring
77 | permission under the Copyright and Similar Rights held by the
78 | Licensor. For purposes of this Public License, where the Licensed
79 | Material is a musical work, performance, or sound recording,
80 | Adapted Material is always produced where the Licensed Material is
81 | synched in timed relation with a moving image.
82 |
83 | b. Adapter's License means the license You apply to Your Copyright
84 | and Similar Rights in Your contributions to Adapted Material in
85 | accordance with the terms and conditions of this Public License.
86 |
87 | c. BY-NC-SA Compatible License means a license listed at
88 | creativecommons.org/compatiblelicenses, approved by Creative
89 | Commons as essentially the equivalent of this Public License.
90 |
91 | d. Copyright and Similar Rights means copyright and/or similar rights
92 | closely related to copyright including, without limitation,
93 | performance, broadcast, sound recording, and Sui Generis Database
94 | Rights, without regard to how the rights are labeled or
95 | categorized. For purposes of this Public License, the rights
96 | specified in Section 2(b)(1)-(2) are not Copyright and Similar
97 | Rights.
98 |
99 | e. Effective Technological Measures means those measures that, in the
100 | absence of proper authority, may not be circumvented under laws
101 | fulfilling obligations under Article 11 of the WIPO Copyright
102 | Treaty adopted on December 20, 1996, and/or similar international
103 | agreements.
104 |
105 | f. Exceptions and Limitations means fair use, fair dealing, and/or
106 | any other exception or limitation to Copyright and Similar Rights
107 | that applies to Your use of the Licensed Material.
108 |
109 | g. License Elements means the license attributes listed in the name
110 | of a Creative Commons Public License. The License Elements of this
111 | Public License are Attribution, NonCommercial, and ShareAlike.
112 |
113 | h. Licensed Material means the artistic or literary work, database,
114 | or other material to which the Licensor applied this Public
115 | License.
116 |
117 | i. Licensed Rights means the rights granted to You subject to the
118 | terms and conditions of this Public License, which are limited to
119 | all Copyright and Similar Rights that apply to Your use of the
120 | Licensed Material and that the Licensor has authority to license.
121 |
122 | j. Licensor means the individual(s) or entity(ies) granting rights
123 | under this Public License.
124 |
125 | k. NonCommercial means not primarily intended for or directed towards
126 | commercial advantage or monetary compensation. For purposes of
127 | this Public License, the exchange of the Licensed Material for
128 | other material subject to Copyright and Similar Rights by digital
129 | file-sharing or similar means is NonCommercial provided there is
130 | no payment of monetary compensation in connection with the
131 | exchange.
132 |
133 | l. Share means to provide material to the public by any means or
134 | process that requires permission under the Licensed Rights, such
135 | as reproduction, public display, public performance, distribution,
136 | dissemination, communication, or importation, and to make material
137 | available to the public including in ways that members of the
138 | public may access the material from a place and at a time
139 | individually chosen by them.
140 |
141 | m. Sui Generis Database Rights means rights other than copyright
142 | resulting from Directive 96/9/EC of the European Parliament and of
143 | the Council of 11 March 1996 on the legal protection of databases,
144 | as amended and/or succeeded, as well as other essentially
145 | equivalent rights anywhere in the world.
146 |
147 | n. You means the individual or entity exercising the Licensed Rights
148 | under this Public License. Your has a corresponding meaning.
149 |
150 |
151 | Section 2 -- Scope.
152 |
153 | a. License grant.
154 |
155 | 1. Subject to the terms and conditions of this Public License,
156 | the Licensor hereby grants You a worldwide, royalty-free,
157 | non-sublicensable, non-exclusive, irrevocable license to
158 | exercise the Licensed Rights in the Licensed Material to:
159 |
160 | a. reproduce and Share the Licensed Material, in whole or
161 | in part, for NonCommercial purposes only; and
162 |
163 | b. produce, reproduce, and Share Adapted Material for
164 | NonCommercial purposes only.
165 |
166 | 2. Exceptions and Limitations. For the avoidance of doubt, where
167 | Exceptions and Limitations apply to Your use, this Public
168 | License does not apply, and You do not need to comply with
169 | its terms and conditions.
170 |
171 | 3. Term. The term of this Public License is specified in Section
172 | 6(a).
173 |
174 | 4. Media and formats; technical modifications allowed. The
175 | Licensor authorizes You to exercise the Licensed Rights in
176 | all media and formats whether now known or hereafter created,
177 | and to make technical modifications necessary to do so. The
178 | Licensor waives and/or agrees not to assert any right or
179 | authority to forbid You from making technical modifications
180 | necessary to exercise the Licensed Rights, including
181 | technical modifications necessary to circumvent Effective
182 | Technological Measures. For purposes of this Public License,
183 | simply making modifications authorized by this Section 2(a)
184 | (4) never produces Adapted Material.
185 |
186 | 5. Downstream recipients.
187 |
188 | a. Offer from the Licensor -- Licensed Material. Every
189 | recipient of the Licensed Material automatically
190 | receives an offer from the Licensor to exercise the
191 | Licensed Rights under the terms and conditions of this
192 | Public License.
193 |
194 | b. Additional offer from the Licensor -- Adapted Material.
195 | Every recipient of Adapted Material from You
196 | automatically receives an offer from the Licensor to
197 | exercise the Licensed Rights in the Adapted Material
198 | under the conditions of the Adapter's License You apply.
199 |
200 | c. No downstream restrictions. You may not offer or impose
201 | any additional or different terms or conditions on, or
202 | apply any Effective Technological Measures to, the
203 | Licensed Material if doing so restricts exercise of the
204 | Licensed Rights by any recipient of the Licensed
205 | Material.
206 |
207 | 6. No endorsement. Nothing in this Public License constitutes or
208 | may be construed as permission to assert or imply that You
209 | are, or that Your use of the Licensed Material is, connected
210 | with, or sponsored, endorsed, or granted official status by,
211 | the Licensor or others designated to receive attribution as
212 | provided in Section 3(a)(1)(A)(i).
213 |
214 | b. Other rights.
215 |
216 | 1. Moral rights, such as the right of integrity, are not
217 | licensed under this Public License, nor are publicity,
218 | privacy, and/or other similar personality rights; however, to
219 | the extent possible, the Licensor waives and/or agrees not to
220 | assert any such rights held by the Licensor to the limited
221 | extent necessary to allow You to exercise the Licensed
222 | Rights, but not otherwise.
223 |
224 | 2. Patent and trademark rights are not licensed under this
225 | Public License.
226 |
227 | 3. To the extent possible, the Licensor waives any right to
228 | collect royalties from You for the exercise of the Licensed
229 | Rights, whether directly or through a collecting society
230 | under any voluntary or waivable statutory or compulsory
231 | licensing scheme. In all other cases the Licensor expressly
232 | reserves any right to collect such royalties, including when
233 | the Licensed Material is used other than for NonCommercial
234 | purposes.
235 |
236 |
237 | Section 3 -- License Conditions.
238 |
239 | Your exercise of the Licensed Rights is expressly made subject to the
240 | following conditions.
241 |
242 | a. Attribution.
243 |
244 | 1. If You Share the Licensed Material (including in modified
245 | form), You must:
246 |
247 | a. retain the following if it is supplied by the Licensor
248 | with the Licensed Material:
249 |
250 | i. identification of the creator(s) of the Licensed
251 | Material and any others designated to receive
252 | attribution, in any reasonable manner requested by
253 | the Licensor (including by pseudonym if
254 | designated);
255 |
256 | ii. a copyright notice;
257 |
258 | iii. a notice that refers to this Public License;
259 |
260 | iv. a notice that refers to the disclaimer of
261 | warranties;
262 |
263 | v. a URI or hyperlink to the Licensed Material to the
264 | extent reasonably practicable;
265 |
266 | b. indicate if You modified the Licensed Material and
267 | retain an indication of any previous modifications; and
268 |
269 | c. indicate the Licensed Material is licensed under this
270 | Public License, and include the text of, or the URI or
271 | hyperlink to, this Public License.
272 |
273 | 2. You may satisfy the conditions in Section 3(a)(1) in any
274 | reasonable manner based on the medium, means, and context in
275 | which You Share the Licensed Material. For example, it may be
276 | reasonable to satisfy the conditions by providing a URI or
277 | hyperlink to a resource that includes the required
278 | information.
279 | 3. If requested by the Licensor, You must remove any of the
280 | information required by Section 3(a)(1)(A) to the extent
281 | reasonably practicable.
282 |
283 | b. ShareAlike.
284 |
285 | In addition to the conditions in Section 3(a), if You Share
286 | Adapted Material You produce, the following conditions also apply.
287 |
288 | 1. The Adapter's License You apply must be a Creative Commons
289 | license with the same License Elements, this version or
290 | later, or a BY-NC-SA Compatible License.
291 |
292 | 2. You must include the text of, or the URI or hyperlink to, the
293 | Adapter's License You apply. You may satisfy this condition
294 | in any reasonable manner based on the medium, means, and
295 | context in which You Share Adapted Material.
296 |
297 | 3. You may not offer or impose any additional or different terms
298 | or conditions on, or apply any Effective Technological
299 | Measures to, Adapted Material that restrict exercise of the
300 | rights granted under the Adapter's License You apply.
301 |
302 |
303 | Section 4 -- Sui Generis Database Rights.
304 |
305 | Where the Licensed Rights include Sui Generis Database Rights that
306 | apply to Your use of the Licensed Material:
307 |
308 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right
309 | to extract, reuse, reproduce, and Share all or a substantial
310 | portion of the contents of the database for NonCommercial purposes
311 | only;
312 |
313 | b. if You include all or a substantial portion of the database
314 | contents in a database in which You have Sui Generis Database
315 | Rights, then the database in which You have Sui Generis Database
316 | Rights (but not its individual contents) is Adapted Material,
317 | including for purposes of Section 3(b); and
318 |
319 | c. You must comply with the conditions in Section 3(a) if You Share
320 | all or a substantial portion of the contents of the database.
321 |
322 | For the avoidance of doubt, this Section 4 supplements and does not
323 | replace Your obligations under this Public License where the Licensed
324 | Rights include other Copyright and Similar Rights.
325 |
326 |
327 | Section 5 -- Disclaimer of Warranties and Limitation of Liability.
328 |
329 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
330 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
331 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
332 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
333 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
334 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
335 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
336 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
337 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
338 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
339 |
340 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
341 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
342 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
343 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
344 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
345 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
346 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
347 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
348 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
349 |
350 | c. The disclaimer of warranties and limitation of liability provided
351 | above shall be interpreted in a manner that, to the extent
352 | possible, most closely approximates an absolute disclaimer and
353 | waiver of all liability.
354 |
355 |
356 | Section 6 -- Term and Termination.
357 |
358 | a. This Public License applies for the term of the Copyright and
359 | Similar Rights licensed here. However, if You fail to comply with
360 | this Public License, then Your rights under this Public License
361 | terminate automatically.
362 |
363 | b. Where Your right to use the Licensed Material has terminated under
364 | Section 6(a), it reinstates:
365 |
366 | 1. automatically as of the date the violation is cured, provided
367 | it is cured within 30 days of Your discovery of the
368 | violation; or
369 |
370 | 2. upon express reinstatement by the Licensor.
371 |
372 | For the avoidance of doubt, this Section 6(b) does not affect any
373 | right the Licensor may have to seek remedies for Your violations
374 | of this Public License.
375 |
376 | c. For the avoidance of doubt, the Licensor may also offer the
377 | Licensed Material under separate terms or conditions or stop
378 | distributing the Licensed Material at any time; however, doing so
379 | will not terminate this Public License.
380 |
381 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
382 | License.
383 |
384 |
385 | Section 7 -- Other Terms and Conditions.
386 |
387 | a. The Licensor shall not be bound by any additional or different
388 | terms or conditions communicated by You unless expressly agreed.
389 |
390 | b. Any arrangements, understandings, or agreements regarding the
391 | Licensed Material not stated herein are separate from and
392 | independent of the terms and conditions of this Public License.
393 |
394 |
395 | Section 8 -- Interpretation.
396 |
397 | a. For the avoidance of doubt, this Public License does not, and
398 | shall not be interpreted to, reduce, limit, restrict, or impose
399 | conditions on any use of the Licensed Material that could lawfully
400 | be made without permission under this Public License.
401 |
402 | b. To the extent possible, if any provision of this Public License is
403 | deemed unenforceable, it shall be automatically reformed to the
404 | minimum extent necessary to make it enforceable. If the provision
405 | cannot be reformed, it shall be severed from this Public License
406 | without affecting the enforceability of the remaining terms and
407 | conditions.
408 |
409 | c. No term or condition of this Public License will be waived and no
410 | failure to comply consented to unless expressly agreed to by the
411 | Licensor.
412 |
413 | d. Nothing in this Public License constitutes or may be interpreted
414 | as a limitation upon, or waiver of, any privileges and immunities
415 | that apply to the Licensor or You, including from the legal
416 | processes of any jurisdiction or authority.
417 |
418 | =======================================================================
419 |
420 | Creative Commons is not a party to its public
421 | licenses. Notwithstanding, Creative Commons may elect to apply one of
422 | its public licenses to material it publishes and in those instances
423 | will be considered the “Licensor.” The text of the Creative Commons
424 | public licenses is dedicated to the public domain under the CC0 Public
425 | Domain Dedication. Except for the limited purpose of indicating that
426 | material is shared under a Creative Commons public license or as
427 | otherwise permitted by the Creative Commons policies published at
428 | creativecommons.org/policies, Creative Commons does not authorize the
429 | use of the trademark "Creative Commons" or any other trademark or logo
430 | of Creative Commons without its prior written consent including,
431 | without limitation, in connection with any unauthorized modifications
432 | to any of its public licenses or any other arrangements,
433 | understandings, or agreements concerning use of licensed material. For
434 | the avoidance of doubt, this paragraph does not form part of the
435 | public licenses.
436 |
437 | Creative Commons may be contacted at creativecommons.org.
--------------------------------------------------------------------------------
/OTVM-teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/OTVM-teaser.jpg
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## One-Trimap Video Matting (ECCV 2022)
Hongje Seong, Seoung Wug Oh, Brian Price, Euntai Kim, Joon-Young Lee
2 |
3 | [[Paper]](https://arxiv.org/abs/2207.13353) [[Demo video]](https://youtu.be/qkda4fHSyQE)
4 |
5 | Official Pytorch implementation of the ECCV 2022 paper, "One-Trimap Video Matting".
6 |
7 | 
8 |
9 |
10 | ## Environments
11 | - Ubuntu 18.04
12 | - python 3.8
13 | - pytorch 1.8.2
14 | - CUDA 10.2
15 |
16 | ### Environment setting
17 | ```bash
18 | conda create -n otvm python=3.8
19 | conda activate otvm
20 | conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch-lts
21 | pip install opencv-contrib-python scikit-image scipy tqdm imgaug yacs albumentations
22 | ```
23 |
24 | ## Dataset
25 | To train OTVM, you need to prepare the [AIM](https://sites.google.com/view/deepimagematting) and [VideoMatting108](https://github.com/yunkezhang/TCVOM) datasets:
26 | ```
27 | PATH/TO/DATASET
28 | ├── Combined_Dataset
29 | │ ├── Adobe Deep Image Mattng Dataset License Agreement.pdf
30 | │ ├── README.txt
31 | │ ├── Test_set
32 | │ │ ├── Adobe-licensed images
33 | │ │ └── ...
34 | │ └── Training_set
35 | │ ├── Adobe-licensed images
36 | │ └── ...
37 | └── VideoMatting108
38 | ├── BG_done2
39 | │ ├── airport
40 | │ └── ...
41 | ├── FG_done
42 | │ ├── animal_still
43 | │ └── ...
44 | ├── flow_png_val
45 | │ ├── animal_still
46 | │ └── ...
47 | ├── frame_corr.json
48 | ├── train_videos_subset.txt
49 | ├── train_videos.txt
50 | ├── val_videos_subset.txt
51 | └── val_videos.txt
52 |
53 | ```
54 |
55 | ## Training
56 | ### Download pre-trained weights
57 | Download the pre-trained weights from [here](https://drive.google.com/drive/folders/1La53_oYZjhmcd2pfPPlnibLBPE12mc6b) and put them in the `weights/` directory.
58 | ```bash
59 | mkdir weights
60 | mv STM_weights.pth weights/
61 | mv FBA.pth weights/
62 | mv s1_OTVM_trimap.pth weights/
63 | mv s1_OTVM_alpha.pth weights/
64 | mv s2_OTVM_alpha.pth weights/
65 | mv s3_OTVM.pth weights/
66 | mv s4_OTVM.pth weights/
67 | ```
68 | Note: Initial weights of the trimap propagation and alpha prediction networks were taken from [STM](https://github.com/seoungwugoh/STM) and [FBA](https://github.com/MarcoForte/FBA_Matting), respectively.
69 |
79 | ### Change DATASET.PATH in config.py
80 | ```bash
81 | vim config.py
82 |
83 | # Change below path
84 | _C.DATASET.PATH = 'PATH/TO/DATASET'
85 | ```
86 |
87 | ### Stage-wise Training
88 | ```bash
89 | # options: scripts/train_XXX.sh [GPUs]
90 | bash scripts/train_s1_trimap.sh 0,1,2,3
91 | bash scripts/train_s1_alpha.sh 0,1,2,3
92 | bash scripts/train_s2_alpha.sh 0,1,2,3
93 | bash scripts/train_s3.sh 0,1,2,3
94 | bash scripts/train_s4.sh 0,1,2,3
95 | ```
96 |
97 | ## Inference (VideoMatting108 dataset)
98 | ```bash
99 | # options: scripts/eval_s4.sh [GPU]
100 | bash scripts/eval_s4.sh 0
101 | ```
102 |
103 | ## Inference (custom dataset)
104 | ```bash
105 | # options: scripts/eval_s4_demo.sh [GPU]
106 | # The results will be generated in: ./demo_results
107 | bash scripts/eval_s4_demo.sh 0
108 | ```
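
The bundled `demo/dove` sequence suggests the expected input layout: a `frames/` folder with the JPEG frames and a `trimap/` folder containing a trimap for the first frame only. A sketch of that layout for your own video (the sequence name `my_sequence` is a placeholder; `dove` is the provided example):
```
demo
└── my_sequence
    ├── frames
    │   ├── 00000.jpg
    │   ├── 00001.jpg
    │   └── ...
    └── trimap
        └── 00000.png
```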
109 |
110 | ## Bibtex
111 | ```
112 | @inproceedings{seong2022one,
113 | title={One-Trimap Video Matting},
114 | author={Seong, Hongje and Oh, Seoung Wug and Price, Brian and Kim, Euntai and Lee, Joon-Young},
115 | booktitle={European Conference on Computer Vision},
116 | year={2022}
117 | }
118 | ```
119 |
120 |
121 | ## Terms of Use
122 | This software is for non-commercial use only.
123 | The source code is released under the Attribution-NonCommercial-ShareAlike 4.0 (CC BY-NC-SA 4.0) License
124 | (see [this](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode) for details).
125 |
126 | [](http://creativecommons.org/licenses/by-nc-sa/4.0/)
127 |
128 | ## Acknowledgments
129 | This code is based on TCVOM (ACM MM 2021): [[link](https://github.com/yunkezhang/TCVOM)]
130 |
131 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
2 | from yacs.config import CfgNode as CN
3 |
4 | _C = CN()
5 | _C.SYSTEM = CN()
6 | # Number of workers for doing things
7 | _C.SYSTEM.NUM_WORKERS = 8
8 | # Specific random seed, -1 for random.
9 | _C.SYSTEM.RANDOM_SEED = 111
10 | _C.SYSTEM.OUTDIR = 'train_log'
11 | _C.SYSTEM.CUDNN_BENCHMARK = True
12 | _C.SYSTEM.CUDNN_DETERMINISTIC = False
13 | _C.SYSTEM.CUDNN_ENABLED = True
14 | _C.SYSTEM.TESTMODE = False
15 |
16 | _C.DATASET = CN()
17 | # dataset path
18 | _C.DATASET.PATH = 'PATH/TO/DATASET'
19 | _C.DATASET.MIN_EDGE_LENGTH = 1088
20 |
21 | _C.TEST = CN()
22 | _C.TEST.MEMORY_MAX_NUM = 5 # 2: First&Prev, 0: First, 1: Prev, 3~: Multiple
23 | _C.TEST.MEMORY_SKIP_FRAME = 10
24 |
25 | _C.TRAIN = CN()
26 | _C.TRAIN.STAGE = 1
27 | _C.TRAIN.BATCH_SIZE = 4
28 | _C.TRAIN.BASE_LR = 1e-5
29 | _C.TRAIN.LR_STRATEGY = 'stair' # 'poly', 'const' or 'stair'
30 | _C.TRAIN.WEIGHT_DECAY = 1e-4
31 | _C.TRAIN.TRAIN_INPUT_SIZE = (320,320)
32 | _C.TRAIN.FRAME_NUM = 3
33 | _C.TRAIN.FREEZE_BN = True
34 |
35 | # optimizer type
36 | _C.TRAIN.OPTIMIZER = 'radam' #adam, radam
37 | _C.TRAIN.TOTAL_EPOCHS = 200
38 | _C.TRAIN.IMAGE_FREQ = -1
39 | _C.TRAIN.SAVE_EVERY_EPOCH = 20
40 |
41 | _C.ALPHA = CN()
42 | _C.ALPHA.MODEL = 'fba'
43 |
44 |
45 | def get_cfg_defaults():
46 | """Get a yacs CfgNode object with default values for my_project."""
47 | # Return a clone so that the defaults will not be altered
48 | # This is for the "local variable" use pattern
49 | return _C.clone()
50 |
51 | # Alternatively, provide a way to import the defaults as
52 | # a global singleton:
53 | # cfg = _C # users can `from config import cfg`
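#
# Usage sketch (not part of the original file): a minimal example of consuming these
# defaults with the standard yacs API. The override values below are illustrative
# assumptions, not settings used by OTVM.
#
#   from config import get_cfg_defaults
#   cfg = get_cfg_defaults()
#   cfg.merge_from_list(['DATASET.PATH', '/data/matting', 'TRAIN.BATCH_SIZE', 8])
#   cfg.freeze()                       # make the config immutable before training starts
#   print(cfg.DATASET.PATH, cfg.TRAIN.BASE_LR)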
--------------------------------------------------------------------------------
/demo/dove/frames/00000.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/demo/dove/frames/00000.jpg
--------------------------------------------------------------------------------
/demo/dove/frames/00001.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/demo/dove/frames/00001.jpg
--------------------------------------------------------------------------------
/demo/dove/frames/00002.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/demo/dove/frames/00002.jpg
--------------------------------------------------------------------------------
/demo/dove/frames/00003.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/demo/dove/frames/00003.jpg
--------------------------------------------------------------------------------
/demo/dove/frames/00004.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/demo/dove/frames/00004.jpg
--------------------------------------------------------------------------------
/demo/dove/frames/00005.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/demo/dove/frames/00005.jpg
--------------------------------------------------------------------------------
/demo/dove/frames/00006.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/demo/dove/frames/00006.jpg
--------------------------------------------------------------------------------
/demo/dove/frames/00007.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/demo/dove/frames/00007.jpg
--------------------------------------------------------------------------------
/demo/dove/frames/00008.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/demo/dove/frames/00008.jpg
--------------------------------------------------------------------------------
/demo/dove/frames/00009.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/demo/dove/frames/00009.jpg
--------------------------------------------------------------------------------
/demo/dove/frames/00010.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/demo/dove/frames/00010.jpg
--------------------------------------------------------------------------------
/demo/dove/trimap/00000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/demo/dove/trimap/00000.png
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import time
4 | import timeit
5 | import cv2
6 | import torch
7 | import torch.backends.cudnn as cudnn
8 | import torch.nn.functional as F
9 | from torch import nn
10 | from torchvision.utils import save_image
11 | import tqdm
12 |
13 | from config import get_cfg_defaults
14 | from dataset import EvalDataset, VideoMatting108_Test, Demo_Test
15 | from helpers import *
16 |
17 | torch.set_grad_enabled(False)
18 |
19 | EPS = 0
20 |
21 | def parse_args():
22 | parser = argparse.ArgumentParser(description='Evaluate network')
23 |
24 | parser.add_argument("--gpu", type=str, default='0')
25 | parser.add_argument('--trimap', default='medium', choices=['narrow', 'medium', 'wide'])
26 | parser.add_argument("--viz", action='store_true')
27 | parser.add_argument("--demo", action='store_true')
28 |
29 | args = parser.parse_args()
30 |
31 | cfg = get_cfg_defaults()
32 | cfg.TRAIN.STAGE = 4
33 |
34 | if args.demo:
35 | cfg.SYSTEM.OUTDIR = './demo_results'
36 | cfg.DATASET.PATH = './demo'
37 |
38 | return args, cfg
39 |
40 |
41 | def main(cfg, args, GPU):
42 | os.environ['CUDA_VISIBLE_DEVICES'] = GPU
43 | if torch.cuda.is_available():
44 | print('using Cuda devices, num:', torch.cuda.device_count())
45 |
46 | MODEL = get_model_name(cfg)
47 | random_seed = cfg.SYSTEM.RANDOM_SEED
48 | output_dir = os.path.join(cfg.SYSTEM.OUTDIR, 'alpha')
49 | start = timeit.default_timer()
50 | cudnn.benchmark = False
51 | cudnn.deterministic = cfg.SYSTEM.CUDNN_DETERMINISTIC
52 | cudnn.enabled = cfg.SYSTEM.CUDNN_ENABLED
53 | if random_seed > 0:
54 | import random
55 | print('Seeding with', random_seed)
56 | random.seed(random_seed)
57 | torch.manual_seed(random_seed)
58 |
59 | if args.demo:
60 | outdir_tail = MODEL
61 | else:
62 | outdir_tail = os.path.join(args.trimap, MODEL)
63 | alpha_outdir = os.path.join(output_dir, 'test', outdir_tail)
64 | viz_outdir_img = os.path.join(output_dir, 'viz', 'img', outdir_tail)
65 | viz_outdir_vid = os.path.join(output_dir, 'viz', 'vid', outdir_tail)
66 |
67 | if args.trimap == 'narrow':
68 | dilate_kernel = 5 # width: 11
69 | elif args.trimap == 'medium':
70 | dilate_kernel = 12 # width: 25
71 | elif args.trimap == 'wide':
72 | dilate_kernel = 20 # width: 41
73 |
74 | model_trimap = get_model_trimap(cfg, mode='Test', dilate_kernel=dilate_kernel)
75 | model = get_model_alpha(cfg, model_trimap, mode='Test', dilate_kernel=dilate_kernel)
76 |
77 | load_ckpt = os.path.join('weights', '{:s}.pth'.format(MODEL))
78 | dct = torch.load(load_ckpt, map_location=torch.device('cpu'))
79 | model.load_state_dict(dct)
80 | model = nn.DataParallel(model.cuda())
81 |
82 |
83 | if args.demo:
84 | valid_dataset = Demo_Test(data_root=cfg.DATASET.PATH)
85 | else:
86 | valid_dataset = VideoMatting108_Test(
87 | data_root=cfg.DATASET.PATH,
88 | mode='val',
89 | )
90 | with torch.no_grad():
91 | eval(args, cfg, valid_dataset, model, alpha_outdir, viz_outdir_img, viz_outdir_vid, args.viz)
92 |
93 | end = timeit.default_timer()
94 | print('done | Total time: {}'.format(format_time(end-start)))
95 |
96 | def write_image(outdir, out, filename, max_batch=4):
97 | with torch.no_grad():
98 | scaled_imgs, tri_pred, tri_gt, alphas, scaled_gts, comps = out
99 | b, s, _, h, w = scaled_imgs.shape
100 | alphas = alphas.expand(-1,-1,3,-1,-1)
101 | scaled_gts = scaled_gts.expand(-1,-1,3,-1,-1)
102 |
103 | b = max_batch if b > max_batch else b
104 | img_list = list()
105 | img_list.append(scaled_imgs[:max_batch].reshape(b*s, 3, h, w))
106 | img_list.append(comps[:max_batch].reshape(b*s, 3, h, w))
107 | img_list.append(tri_gt[:max_batch].reshape(b*s, 3, h, w))
108 | img_list.append(scaled_gts[:max_batch].reshape(b*s, 3, h, w))
109 | img_list.append(tri_pred[:max_batch].reshape(b*s, 3, h, w))
110 | img_list.append(alphas[:max_batch].reshape(b*s, 3, h, w))
111 | imgs = torch.cat(img_list, dim=0).reshape(-1, 3, h, w)
112 |
113 | imgs = F.interpolate(imgs, size=(h//2, w//2), mode='bilinear', align_corners=False)
114 |
115 | save_image(imgs, outdir%(filename), nrow=int(s*b*2))
116 |
117 | def eval(args, cfg, valid_dataset, model, alpha_outdir, viz_outdir_img, viz_outdir_vid, VIZ):
118 | model.eval()
119 |
120 | for i_iter, (data_name, data_root, FG, BG, a, tri, seq_name) in enumerate(valid_dataset):
121 | if cfg.SYSTEM.TESTMODE:
122 | if i_iter not in [0, len(valid_dataset)-1]:
123 | continue
124 | torch.cuda.empty_cache()
125 | num_frames = 1
126 | eval_sequence = EvalDataset(
127 | data_name=data_name,
128 | data_root=data_root,
129 | FG=FG,
130 | BG=BG,
131 | a=a,
132 | tri_gt=tri, # GT trimap
133 | trimap=None,
134 | num_frames=num_frames,
135 | )
136 | eval_loader = torch.utils.data.DataLoader(
137 | eval_sequence,
138 | batch_size=1,
139 | # num_workers=cfg.SYSTEM.NUM_WORKERS,
140 | num_workers=0,
141 | pin_memory=False,
142 | drop_last=False,
143 | shuffle=False,
144 | sampler=None)
145 |
146 | print('[{}/{}] Set FIXED dilate of unknown region: [{}]'.format(i_iter, len(valid_dataset), args.trimap))
147 |
148 | save_path = os.path.join(alpha_outdir, 'pred', seq_name)
149 | os.makedirs(save_path, exist_ok=True)
150 | if VIZ:
151 | visualization_path_img = os.path.join(viz_outdir_img, 'viz', seq_name)
152 | visualization_path_vid = os.path.join(viz_outdir_vid, 'viz')
153 | os.makedirs(visualization_path_img, exist_ok=True)
154 | os.makedirs(visualization_path_vid, exist_ok=True)
155 |
156 | iterations = tqdm.tqdm(eval_loader)
157 | for i_seq, dp in enumerate(iterations):
158 | if cfg.SYSTEM.TESTMODE:
159 | if i_seq > 10:
160 | break
161 |
162 | def handle_batch(dp, first_frame, last_frame, memorize, max_memory_num, large_input):
163 | fg, bg, a, eps, tri_gt, tri, _, filename = dp # [B, 3, 3 or 1, H, W]
164 |
165 | if tri.dim() == 1:
166 | tri = None
167 | if tri_gt.dim() == 1:
168 | tri_gt = None
169 |
170 | out = model(a, fg, bg, tri=tri, tri_gt=tri_gt,
171 | first_frame=first_frame,
172 | last_frame=last_frame,
173 | memorize=memorize,
174 | max_memory_num=max_memory_num,
175 | large_input=large_input,)
176 | return out, filename[0]
177 |
178 | first_frame = (i_seq==0)
179 | last_frame = (i_seq==(len(iterations)-1))
180 | memorize = False
181 | MEMORY_SKIP_FRAME = cfg.TEST.MEMORY_SKIP_FRAME
182 | MEMORY_MAX_NUM = cfg.TEST.MEMORY_MAX_NUM
183 | large_input = False
184 | if min(dp[0].shape[-2:]) > 1100:
185 | MEMORY_SKIP_FRAME = int(MEMORY_SKIP_FRAME * 2)
186 | MEMORY_MAX_NUM = int(MEMORY_MAX_NUM / 2)
187 | large_input = True
188 | if MEMORY_SKIP_FRAME > 2:
189 | memorize = (i_seq % MEMORY_SKIP_FRAME) == 0
190 | max_memory_num = MEMORY_MAX_NUM
191 |
192 | if first_frame:
193 | print('[{}/{}] {} | {} | Large input: {}'.format(i_iter, len(valid_dataset), seq_name, dp[0].shape[-2:], large_input))
194 |
195 | torch.cuda.synchronize()
196 | out, filename = handle_batch(dp, first_frame, last_frame, memorize, max_memory_num, large_input,)
197 | torch.cuda.synchronize()
198 |
199 | scaled_imgs, tri_pred, tri_gt, alphas, scaled_gts = out
200 |
201 | green_bg = torch.zeros_like(scaled_imgs)
202 | green_bg[:,:,1] = 1.
203 | comps = scaled_imgs * alphas + green_bg * (1. - alphas)
204 |
205 | if VIZ:
206 | frame_path = os.path.join(visualization_path_img, 'f%d.jpg')
207 | else:
208 | frame_path = None
209 | alpha_pred_img = (alphas*255).byte().cpu().squeeze(0).squeeze(0).squeeze(0).numpy()
210 | filename_for_save = os.path.splitext(filename)[0]+'.png'
211 |
212 | def write_result_images(alpha_pred_img, path, VIZ, frame_path, vis_out, i_seq):
213 | if VIZ:
214 | write_image(frame_path,
215 | vis_out,
216 | i_seq)
217 | cv2.imwrite(path, alpha_pred_img)
218 |
219 | write_result_images(alpha_pred_img,
220 | os.path.join(save_path, filename_for_save),
221 | VIZ,
222 | frame_path,
223 | # [scaled_imgs, tri_pred, tri_gt, alphas, scaled_gts, comps],
224 | [scaled_imgs.cpu(), tri_pred.cpu(), tri_gt.cpu(), alphas.cpu(), scaled_gts.cpu(), comps.cpu()],
225 | i_seq)
226 |
227 |
228 | torch.cuda.synchronize()
229 |
230 | if VIZ:
231 | if '/' in seq_name:
232 | vid_name = seq_name.split('/')
233 | vid_name = '_'.join(vid_name)
234 | else:
235 | vid_name = seq_name
236 | vid_path = os.path.join(visualization_path_vid, '{}.mp4'.format(vid_name))
237 |
238 | def make_viz_video(frame_path, vid_path):
239 | os.system('ffmpeg -y -nostats -loglevel 0 -framerate 10 -i {} {}'.format(frame_path, vid_path))  # -y: overwrite existing output without prompting
240 | time.sleep(10) # wait 10 seconds
241 |
242 | make_viz_video(frame_path, vid_path)
243 |
244 | if __name__ == "__main__":
245 | args, cfg = parse_args()
246 | main(cfg, args, args.gpu)
247 |
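# Usage sketch (not part of the original file); the flags come from parse_args() above,
# and the GPU id / trimap width are illustrative:
#   python eval.py --gpu 0 --trimap medium --viz    # VideoMatting108 validation set
#   python eval.py --gpu 0 --demo                   # custom frames under ./demo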
--------------------------------------------------------------------------------
/helpers.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | #torch
3 | import torch
4 | import torch.nn.functional as F
5 | import torch.distributed as torch_dist
6 |
7 | import numpy as np
8 | import time
9 | import os
10 | import logging
11 | from pathlib import Path
12 | from importlib import reload
13 | import sys
14 |
15 | def ToCuda(xs):
16 | if torch.cuda.is_available():
17 | if isinstance(xs, list) or isinstance(xs, tuple):
18 | return [x.cuda() for x in xs]
19 | else:
20 | return xs.cuda()
21 | else:
22 | return xs
23 |
24 |
25 | def pad_divide_by(in_list, d, in_size):
26 | out_list = []
27 | h, w = in_size
28 | if h % d > 0:
29 | new_h = h + d - h % d
30 | else:
31 | new_h = h
32 | if w % d > 0:
33 | new_w = w + d - w % d
34 | else:
35 | new_w = w
36 | lh, uh = int((new_h-h) / 2), int(new_h-h) - int((new_h-h) / 2)
37 | lw, uw = int((new_w-w) / 2), int(new_w-w) - int((new_w-w) / 2)
38 | pad_array = (int(lw), int(uw), int(lh), int(uh))
39 | for inp in in_list:
40 | out_list.append(F.pad(inp, pad_array))
41 | return out_list, pad_array
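# Usage sketch (not part of the original file): pad a frame so H and W are divisible by
# the network stride (e.g. 32), run the model, then crop the output back to the input size.
# The tensor below is illustrative.
#   import torch
#   frame = torch.rand(1, 3, 270, 480)
#   (padded,), pads = pad_divide_by([frame], 32, frame.shape[-2:])  # pads = (lw, uw, lh, uh)
#   out = padded                                                    # stand-in for a network output
#   h, w = out.shape[-2:]
#   out = out[..., pads[2]:h - pads[3], pads[0]:w - pads[1]]        # back to 270x480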
42 |
43 |
44 |
45 | def overlay_davis(image,mask,colors=[255,0,0],cscale=2,alpha=0.4):
46 | """ Overlay segmentation on top of RGB image. from davis official"""
47 | # import skimage
48 | from scipy.ndimage.morphology import binary_erosion, binary_dilation
49 |
50 | colors = np.reshape(colors, (-1, 3))
51 | colors = np.atleast_2d(colors) * cscale
52 |
53 | im_overlay = image.copy()
54 | object_ids = np.unique(mask)
55 |
56 | for object_id in object_ids[1:]:
57 | # Overlay color on binary mask
58 | foreground = image*alpha + np.ones(image.shape)*(1-alpha) * np.array(colors[object_id])
59 | binary_mask = mask == object_id
60 |
61 | # Compose image
62 | im_overlay[binary_mask] = foreground[binary_mask]
63 |
64 | # contours = skimage.morphology.binary.binary_dilation(binary_mask) - binary_mask
65 | contours = binary_dilation(binary_mask) ^ binary_mask
66 | # contours = cv2.dilate(binary_mask, cv2.getStructuringElement(cv2.MORPH_CROSS,(3,3))) - binary_mask
67 | im_overlay[contours,:] = 0
68 |
69 | return im_overlay.astype(image.dtype)
70 |
71 |
72 | def torch_barrier():
73 | if torch_dist.is_available() and torch_dist.is_initialized():
74 | torch_dist.barrier()
75 |
76 | def reduce_tensor(inp):
77 | """
78 | Reduce the loss from all processes so that
79 | ALL PROCESSES have the averaged results.
80 | """
81 | if torch_dist.is_initialized():
82 | world_size = torch_dist.get_world_size()
83 | if world_size < 2:
84 | return inp
85 | with torch.no_grad():
86 | reduced_inp = inp
87 | torch.distributed.all_reduce(reduced_inp)
88 | torch.distributed.barrier()
89 | return reduced_inp / world_size
90 | return inp
91 |
92 | def print_loss_dict(loss, save=None):
93 | s = ''
94 | for key in sorted(loss.keys()):
95 | s += '{}: {:.6f}\n'.format(key, loss[key])
96 | print (s)
97 | if save is not None:
98 | with open(save, 'w') as f:
99 | f.write(s)
100 |
101 | class AverageMeter(object):
102 | """Computes and stores the average and current value"""
103 |
104 | def __init__(self):
105 | self.initialized = False
106 | self.val = None
107 | self.avg = None
108 | self.sum = None
109 | self.count = None
110 |
111 | def initialize(self, val, weight):
112 | self.val = val
113 | self.avg = val
114 | self.sum = val * weight
115 | self.count = weight
116 | self.initialized = True
117 |
118 | def update(self, val, weight=1):
119 | if not self.initialized:
120 | self.initialize(val, weight)
121 | else:
122 | self.add(val, weight)
123 |
124 | def add(self, val, weight):
125 | self.val = val
126 | self.sum += val * weight
127 | self.count += weight
128 | self.avg = self.sum / self.count
129 |
130 | def value(self):
131 | return self.val
132 |
133 | def average(self):
134 | return self.avg
135 |
136 | def create_logger(output_dir, cfg_name, phase='train'):
137 | root_output_dir = Path(output_dir)
138 | # set up logger
139 | if not root_output_dir.exists():
140 | print('=> creating {}'.format(root_output_dir))
141 | root_output_dir.mkdir()
142 |
143 | final_output_dir = root_output_dir / cfg_name
144 |
145 | print('=> creating {}'.format(final_output_dir))
146 | final_output_dir.mkdir(parents=True, exist_ok=True)
147 |
148 | time_str = time.strftime('%Y-%m-%d-%H-%M')
149 | log_file = '{}_{}_{}.log'.format(cfg_name, time_str, phase)
150 | final_log_file = final_output_dir / log_file
151 | head = '%(asctime)-15s %(message)s'
152 | # reset logging
153 | logging.shutdown()
154 | reload(logging)
155 | logging.basicConfig(filename=str(final_log_file),
156 | format=head)
157 | logger = logging.getLogger()
158 | logger.setLevel(logging.INFO)
159 | console = logging.StreamHandler()
160 | logging.getLogger('').addHandler(console)
161 |
162 | return logger, str(final_output_dir)
163 |
164 | def poly_lr(optimizer, base_lr, max_iters, cur_iters, power=0.9):
165 | lr = base_lr*((1-float(cur_iters)/max_iters)**(power))
166 | # optimizer.param_groups[0]['lr'] = lr
167 | for param_group in optimizer.param_groups:
168 | if 'lr_ratio' in param_group:
169 | param_group['lr'] = lr * param_group['lr_ratio']
170 | else:
171 | param_group['lr'] = lr
172 | return lr
173 |
174 | def const_lr(optimizer, base_lr, max_iters, cur_iters):
175 | # optimizer.param_groups[0]['lr'] = base_lr
176 | for param_group in optimizer.param_groups:
177 | if 'lr_ratio' in param_group:
178 | param_group['lr'] = base_lr * param_group['lr_ratio']
179 | else:
180 | param_group['lr'] = base_lr
181 | return base_lr
182 |
183 | def stair_lr(optimizer, base_lr, max_iters, cur_iters):
184 | # 0, 180
185 | ratios = [1, 0.1]
186 | progress = cur_iters / float(max_iters)
187 | if progress < 0.9:
188 | ratio = ratios[0]
189 | else:
190 | ratio = ratios[-1]
191 | lr = base_lr * ratio
192 | # optimizer.param_groups[0]['lr'] = lr
193 | for param_group in optimizer.param_groups:
194 | if 'lr_ratio' in param_group:
195 | param_group['lr'] = lr * param_group['lr_ratio']
196 | else:
197 | param_group['lr'] = lr
198 | return lr
199 |
200 | def worker_init_fn(worker_id):
201 | np.random.seed(np.random.get_state()[1][0] + worker_id)
202 |
203 | STR_DICT = {
204 | 'poly': poly_lr,
205 | 'const': const_lr,
206 | 'stair': stair_lr
207 | }
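
# Usage sketch (not part of the original file): the schedule is presumably selected by
# name to match cfg.TRAIN.LR_STRATEGY ('poly', 'const' or 'stair') from config.py, e.g.
#   adjust_lr = STR_DICT[cfg.TRAIN.LR_STRATEGY]
#   lr = adjust_lr(optimizer, cfg.TRAIN.BASE_LR, max_iters, cur_iters)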
208 |
209 |
210 |
211 | _, term_width = os.popen('stty size', 'r').read().split()
212 | term_width = int(term_width)
213 |
214 | TOTAL_BAR_LENGTH = 20.
215 | last_time = time.time()
216 | begin_time = last_time
217 |
218 | code_begin_time = time.time()
219 | memorize_iter_time = list()
220 | memorize_iter_time.append(code_begin_time)
221 |
222 | def progress_bar(current, total, current_epoch, start_epoch, end_epoch, mode=None, msg=None):
223 | # global last_time, begin_time, code_begin_time, runing_weight
224 | global last_time, begin_time, memorize_iter_time
225 | if current == 0:
226 | begin_time = time.time() # Reset for new bar.
227 |
228 | cur_len = int(TOTAL_BAR_LENGTH*current/total)
229 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1
230 |
231 | sys.stdout.write(' [')
232 | for i in range(cur_len):
233 | sys.stdout.write('=')
234 | sys.stdout.write('>')
235 | for i in range(rest_len):
236 | sys.stdout.write('.')
237 | sys.stdout.write(']')
238 |
239 | cur_time = time.time()
240 | step_time = cur_time - last_time
241 | last_time = cur_time
242 | tot_time = cur_time - begin_time
243 |
244 | L = []
245 | L.append(' E: %d' % current_epoch)
246 | L.append(' | Step: %s' % format_time(step_time))
247 | L.append(' | Tot: %s' % format_time(tot_time))
248 | if mode:
249 | memorize_iter_num = 1000
250 | total_time_from_code_begin = time.time()
251 | memorize_iter_time.append(total_time_from_code_begin)
252 | if len(memorize_iter_time) > memorize_iter_num:
253 | memorize_iter_time.pop(0)
254 | remain_iters = ((end_epoch-current_epoch)*total) - (current+1)
255 | eta = (memorize_iter_time[-1] - memorize_iter_time[0]) / (len(memorize_iter_time) - 1) * remain_iters
256 | L.append(' | ETA: %s' % format_time(eta))
257 | if msg:
258 | L.append(' | ' + msg)
259 |
260 | msg = ''.join(L)
261 | sys.stdout.write(msg)
262 | for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
263 | sys.stdout.write(' ')
264 |
265 | # Go back to the center of the bar.
266 | for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
267 | sys.stdout.write('\b')
268 | sys.stdout.write(' %d/%d ' % (current+1, total))
269 |
270 | if current < total-1:
271 | sys.stdout.write('\r')
272 | else:
273 | sys.stdout.write('\n')
274 | sys.stdout.flush()
275 |
276 | def format_time(seconds):
277 | days = int(seconds / 3600/24)
278 | seconds = seconds - days*3600*24
279 | hours = int(seconds / 3600)
280 | seconds = seconds - hours*3600
281 | minutes = int(seconds / 60)
282 | seconds = seconds - minutes*60
283 | secondsf = int(seconds)
284 | seconds = seconds - secondsf
285 | millis = int(seconds*1000)
286 |
287 | f = ''
288 | i = 1
289 | if days > 0:
290 | f += str(days) + 'D'
291 | i += 1
292 | if hours > 0 and i <= 2:
293 | f += str(hours) + 'h'
294 | i += 1
295 | if minutes > 0 and i <= 2:
296 | f += str(minutes) + 'm'
297 | i += 1
298 | if secondsf > 0 and i <= 2:
299 | f += str(secondsf) + 's'
300 | i += 1
301 | if millis > 0 and i <= 2:
302 | f += str(millis) + 'ms'
303 | i += 1
304 | if f == '':
305 | f = '0ms'
306 | return f
307 |
308 |
309 | def load_NoPrefix(path, length):
310 | # load dataparallel wrapped model properly
311 | state_dict = torch.load(path, map_location=torch.device('cpu'))
312 | if 'state_dict' in state_dict.keys():
313 | state_dict = state_dict['state_dict']
314 | # create new OrderedDict that does not contain `module.`
315 | from collections import OrderedDict
316 | new_state_dict = OrderedDict()
317 | for k, v in state_dict.items():
318 | name = k[length:] # remove `Scale.`
319 | new_state_dict[name] = v
320 | return new_state_dict
321 |
322 |
323 | def get_model_name(cfg):
324 | names = {1: 's1_OTVM_alpha',
325 | 2: 's2_OTVM_alpha',
326 | 3: 's3_OTVM',
327 | 4: 's4_OTVM'}
328 | return names[cfg.TRAIN.STAGE]
329 |
330 |
331 |
332 | def get_model_trimap(cfg, mode='Test', dilate_kernel=None):
333 | import models.trimap.model as model_trimap
334 | if mode == 'Train':
335 | model = model_trimap.FullModel
336 | elif mode == 'Test':
337 | model = model_trimap.FullModel_eval
338 |
339 | hdim = 16
340 |
341 | model_loaded = model(eps=0,
342 | stage=cfg.TRAIN.STAGE,
343 | dilate_kernel=dilate_kernel,
344 | hdim=hdim,)
345 |
346 | return model_loaded
347 |
348 | def get_model_alpha(cfg, model_trimap, mode='Test', dilate_kernel=None):
349 | import models.alpha.model as model_alpha
350 | if cfg.TRAIN.STAGE == 1:
351 | model_trimap = None
352 |
353 | if mode == 'Train':
354 | model = model_alpha.FullModel
355 | elif mode == 'Test':
356 | model = model_alpha.EvalModel
357 |
358 | model_loaded = model(dilate_kernel=dilate_kernel,
359 | trimap=model_trimap,
360 | stage=cfg.TRAIN.STAGE,)
361 |
362 | return model_loaded
363 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/models/__init__.py
--------------------------------------------------------------------------------
/models/alpha/FBA/layers_WS.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import functional as F
4 |
5 |
6 | class Conv2d(nn.Conv2d):  # Conv2d with Weight Standardization: weights are re-centered and re-scaled per output channel in forward()
7 |
8 | def __init__(self, in_channels, out_channels, kernel_size, stride=1,
9 | padding=0, dilation=1, groups=1, bias=True):
10 | super(Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride,
11 | padding, dilation, groups, bias)
12 |
13 | def forward(self, x):
14 | # return super(Conv2d, self).forward(x)
15 | weight = self.weight
16 | weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2,
17 | keepdim=True).mean(dim=3, keepdim=True)
18 | weight = weight - weight_mean
19 | # std = (weight).view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
20 | std = torch.sqrt(torch.var(weight.view(weight.size(0), -1), dim=1) + 1e-12).view(-1, 1, 1, 1) + 1e-5
21 | weight = weight / std.expand_as(weight)
22 | return F.conv2d(x, weight, self.bias, self.stride,
23 | self.padding, self.dilation, self.groups)
24 |
25 |
26 | def BatchNorm2d(num_features):  # despite the name, this returns GroupNorm (32 groups), the normalization paired with the weight-standardized convolutions above
27 | return nn.GroupNorm(num_channels=num_features, num_groups=32)
28 |
--------------------------------------------------------------------------------
/models/alpha/FBA/models.py:
--------------------------------------------------------------------------------
2 | import torch
3 | import torch.nn as nn
4 | from . import resnet_GN_WS
5 | from . import layers_WS as L
6 | from . import resnet_bn
7 |
8 | FEAT_DIM = 2048
9 | DEC_DIM = 256
10 |
11 | def FBA(refinement):
12 | builder = ModelBuilder()
13 | net_encoder = builder.build_encoder(arch='resnet50_GN_WS')
14 | net_decoder = builder.build_decoder(arch="fba_decoder", batch_norm=False)
15 |
16 | model = MattingModule(net_encoder, net_decoder, refinement)
17 |
18 | return model
19 |
20 |
21 | class MattingModule(nn.Module):
22 | def __init__(self, net_enc, net_dec, refinement):
23 | super(MattingModule, self).__init__()
24 | self.encoder = net_enc
25 | self.decoder = net_dec
26 | self.refinement = refinement
27 | if refinement:
28 | self.refine = RefinementModule()
29 | else:
30 | self.refine = None
31 |
32 | def forward(self, x, extras):
33 | image, two_chan_trimap = extras
34 | conv_out, indices = self.encoder(x)
35 |
36 | hid, output, x_dec = self.decoder(conv_out, image, indices, two_chan_trimap)
37 | pred_alpha = output[:, :1]
38 |
39 | if self.refine is not None:
40 | hid, refine_output, refine_trimap = self.refine(x_dec, image, two_chan_trimap, pred_alpha)
41 | else:
42 | refine_output = None
43 | refine_trimap = None
44 |
45 | return output, hid, refine_output, refine_trimap
46 |
47 |
48 | class ModelBuilder():
49 | def build_encoder(self, arch='resnet50_GN', num_channels_additional=None):
50 | if arch == 'resnet50_GN_WS':
51 | orig_resnet = resnet_GN_WS.__dict__['l_resnet50']()
52 | net_encoder = ResnetDilated(orig_resnet, dilate_scale=8, num_channels_additional=num_channels_additional)
53 | elif arch == 'resnet50_BN':
54 | orig_resnet = resnet_bn.__dict__['l_resnet50']()
55 | net_encoder = ResnetDilatedBN(orig_resnet, dilate_scale=8, num_channels_additional=num_channels_additional)
56 | elif arch == 'resnet18_GN_WS':
57 | orig_resnet = resnet_GN_WS.__dict__['l_resnet18']()
58 | net_encoder = ResnetDilated(orig_resnet, dilate_scale=8, num_channels_additional=num_channels_additional)
59 | elif arch == 'resnet34_GN_WS':
60 | orig_resnet = resnet_GN_WS.__dict__['l_resnet34']()
61 | net_encoder = ResnetDilated(orig_resnet, dilate_scale=8, num_channels_additional=num_channels_additional)
62 |
63 | else:
64 | raise Exception('Architecture undefined!')
65 |
66 | num_channels = 3 + 6 + 2
67 |
68 | if(num_channels > 3):
69 | print(f'modifying input layer to accept {num_channels} channels')
70 | net_encoder_sd = net_encoder.state_dict()
71 | conv1_weights = net_encoder_sd['conv1.weight']
72 |
73 | c_out, c_in, h, w = conv1_weights.size()
74 | conv1_mod = torch.zeros(c_out, num_channels, h, w)
75 | conv1_mod[:, :3, :, :] = conv1_weights  # keep the pretrained RGB filters; the extra input channels start at zero
76 |
77 | conv1 = net_encoder.conv1
78 | conv1.in_channels = num_channels
79 | conv1.weight = torch.nn.Parameter(conv1_mod)
80 |
81 | net_encoder.conv1 = conv1
82 |
83 | net_encoder_sd['conv1.weight'] = conv1_mod
84 |
85 | net_encoder.load_state_dict(net_encoder_sd)
86 | return net_encoder
87 |
88 | def build_decoder(self, arch='fba_decoder', batch_norm=False, memory_decoder=False):
89 | if arch == 'fba_decoder':
90 | net_decoder = fba_decoder(batch_norm=batch_norm, memory_decoder=memory_decoder)
91 |
92 | return net_decoder
93 |
94 |
95 | class ResnetDilatedBN(nn.Module):
96 | def __init__(self, orig_resnet, dilate_scale=8, num_channels_additional=None):
97 | super(ResnetDilatedBN, self).__init__()
98 | from functools import partial
99 |
100 | if dilate_scale == 8:
101 | orig_resnet.layer3.apply(
102 | partial(self._nostride_dilate, dilate=2))
103 | orig_resnet.layer4.apply(
104 | partial(self._nostride_dilate, dilate=4))
105 | elif dilate_scale == 16:
106 | orig_resnet.layer4.apply(
107 | partial(self._nostride_dilate, dilate=2))
108 |
109 | # take pretrained resnet, except AvgPool and FC
110 | self.conv1 = orig_resnet.conv1
111 | self.bn1 = orig_resnet.bn1
112 | self.relu1 = orig_resnet.relu1
113 | self.conv2 = orig_resnet.conv2
114 | self.bn2 = orig_resnet.bn2
115 | self.relu2 = orig_resnet.relu2
116 | self.conv3 = orig_resnet.conv3
117 | self.bn3 = orig_resnet.bn3
118 | self.relu3 = orig_resnet.relu3
119 | self.maxpool = orig_resnet.maxpool
120 | self.layer1 = orig_resnet.layer1
121 | self.layer2 = orig_resnet.layer2
122 | self.layer3 = orig_resnet.layer3
123 | self.layer4 = orig_resnet.layer4
124 |
125 | self.num_channels_additional = num_channels_additional
126 | if self.num_channels_additional is not None:
127 | self.conv1_a = resnet_bn.conv3x3(self.num_channels_additional, 64, stride=2)
128 |
129 | def _nostride_dilate(self, m, dilate):
130 | classname = m.__class__.__name__
131 | if classname.find('Conv') != -1:
132 | # the convolution with stride
133 | if m.stride == (2, 2):
134 | m.stride = (1, 1)
135 | if m.kernel_size == (3, 3):
136 | m.dilation = (dilate // 2, dilate // 2)
137 | m.padding = (dilate // 2, dilate // 2)
138 |             # other convolutions
139 | else:
140 | if m.kernel_size == (3, 3):
141 | m.dilation = (dilate, dilate)
142 | m.padding = (dilate, dilate)
143 |
144 | def forward(self, x, return_feature_maps=False):
145 | conv_out = [x]
146 | x = self.relu1(self.bn1(self.conv1(x)))
147 | x = self.relu2(self.bn2(self.conv2(x)))
148 | x = self.relu3(self.bn3(self.conv3(x)))
149 | conv_out.append(x)
150 | x, indices = self.maxpool(x)
151 | x = self.layer1(x)
152 | conv_out.append(x)
153 | x = self.layer2(x)
154 | conv_out.append(x)
155 | x = self.layer3(x)
156 | conv_out.append(x)
157 | x = self.layer4(x)
158 | conv_out.append(x)
159 |
160 | if return_feature_maps:
161 | return conv_out, indices
162 | return [x]
163 |
164 |
165 | class Resnet(nn.Module):
166 | def __init__(self, orig_resnet):
167 | super(Resnet, self).__init__()
168 |
169 | # take pretrained resnet, except AvgPool and FC
170 | self.conv1 = orig_resnet.conv1
171 | self.bn1 = orig_resnet.bn1
172 | self.relu1 = orig_resnet.relu1
173 | self.conv2 = orig_resnet.conv2
174 | self.bn2 = orig_resnet.bn2
175 | self.relu2 = orig_resnet.relu2
176 | self.conv3 = orig_resnet.conv3
177 | self.bn3 = orig_resnet.bn3
178 | self.relu3 = orig_resnet.relu3
179 | self.maxpool = orig_resnet.maxpool
180 | self.layer1 = orig_resnet.layer1
181 | self.layer2 = orig_resnet.layer2
182 | self.layer3 = orig_resnet.layer3
183 | self.layer4 = orig_resnet.layer4
184 |
185 | def forward(self, x, return_feature_maps=False):
186 | conv_out = []
187 |
188 | x = self.relu1(self.bn1(self.conv1(x)))
189 | x = self.relu2(self.bn2(self.conv2(x)))
190 | x = self.relu3(self.bn3(self.conv3(x)))
191 | conv_out.append(x)
192 | x, indices = self.maxpool(x)
193 |
194 | x = self.layer1(x)
195 | conv_out.append(x)
196 | x = self.layer2(x)
197 | conv_out.append(x)
198 | x = self.layer3(x)
199 | conv_out.append(x)
200 | x = self.layer4(x)
201 | conv_out.append(x)
202 |
203 | if return_feature_maps:
204 | return conv_out
205 | return [x]
206 |
207 |
208 | class ResnetDilated(nn.Module):
209 | def __init__(self, orig_resnet, dilate_scale=8, num_channels_additional=None):
210 | super(ResnetDilated, self).__init__()
211 | from functools import partial
212 |
213 | if dilate_scale == 8:
214 | orig_resnet.layer3.apply(
215 | partial(self._nostride_dilate, dilate=2))
216 | orig_resnet.layer4.apply(
217 | partial(self._nostride_dilate, dilate=4))
218 | elif dilate_scale == 16:
219 | orig_resnet.layer4.apply(
220 | partial(self._nostride_dilate, dilate=2))
221 |
222 | # take pretrained resnet, except AvgPool and FC
223 | self.conv1 = orig_resnet.conv1
224 | self.bn1 = orig_resnet.bn1
225 | self.relu = orig_resnet.relu
226 | self.maxpool = orig_resnet.maxpool
227 | self.layer1 = orig_resnet.layer1
228 | self.layer2 = orig_resnet.layer2
229 | self.layer3 = orig_resnet.layer3
230 | self.layer4 = orig_resnet.layer4
231 |
232 | self.num_channels_additional = num_channels_additional
233 | if self.num_channels_additional is not None:
234 | self.conv1_a = resnet_GN_WS.L.Conv2d(self.num_channels_additional, 64, kernel_size=7, stride=2, padding=3, bias=False)
235 |
236 | def _nostride_dilate(self, m, dilate):
237 | classname = m.__class__.__name__
238 | if classname.find('Conv') != -1:
239 | # the convolution with stride
240 | if m.stride == (2, 2):
241 | m.stride = (1, 1)
242 | if m.kernel_size == (3, 3):
243 | m.dilation = (dilate // 2, dilate // 2)
244 | m.padding = (dilate // 2, dilate // 2)
245 |             # other convolutions
246 | else:
247 | if m.kernel_size == (3, 3):
248 | m.dilation = (dilate, dilate)
249 | m.padding = (dilate, dilate)
250 |
251 | def forward(self, x, x_a=None):
252 | conv_out = [x] # OS=1
253 | if self.num_channels_additional is None:
254 | x = self.relu(self.bn1(self.conv1(x)))
255 | else:
256 | x = self.conv1(x) + self.conv1_a(x_a)
257 | x = self.relu(self.bn1(x))
258 | conv_out.append(x) # OS=2
259 | x, indices = self.maxpool(x)
260 | x = self.layer1(x)
261 | conv_out.append(x) # OS=4
262 | x = self.layer2(x)
263 | conv_out.append(x) # OS=8
264 | x = self.layer3(x)
265 | conv_out.append(x)
266 | x = self.layer4(x)
267 | conv_out.append(x)
268 |
269 | return conv_out, indices
270 |
271 |
272 | def norm(dim, bn=False):
273 | if(bn is False):
274 | return nn.GroupNorm(32, dim)
275 | else:
276 | return nn.BatchNorm2d(dim)
277 |
278 |
279 | def fba_fusion(alpha, img, F, B):
280 | F = ((alpha * img + (1 - alpha**2) * F - alpha * (1 - alpha) * B))
281 | B = ((1 - alpha) * img + (2 * alpha - alpha**2) * B - alpha * (1 - alpha) * F)
282 |
283 | F = torch.clamp(F, 0, 1)
284 | B = torch.clamp(B, 0, 1)
285 | la = 0.1
286 | alpha = (alpha * la + torch.sum((img - B) * (F - B), 1, keepdim=True)) / (torch.sum((F - B) * (F - B), 1, keepdim=True) + la)
287 | alpha = torch.clamp(alpha, 0, 1)
288 | return alpha, F, B
289 |
290 |
291 | class fba_decoder(nn.Module):
292 | def __init__(self, batch_norm=False, memory_decoder=False):
293 | super(fba_decoder, self).__init__()
294 | pool_scales = (1, 2, 3, 6)
295 | self.batch_norm = batch_norm
296 | self.memory_decoder = memory_decoder
297 |
298 | self.ppm = []
299 |
300 | for scale in pool_scales:
301 | self.ppm.append(nn.Sequential(
302 | nn.AdaptiveAvgPool2d(scale),
303 | L.Conv2d(FEAT_DIM, DEC_DIM, kernel_size=1, bias=True),
304 | norm(DEC_DIM, self.batch_norm),
305 | nn.LeakyReLU()
306 | ))
307 | self.ppm = nn.ModuleList(self.ppm)
308 |
309 | self.conv_up1 = nn.Sequential(
310 | L.Conv2d(FEAT_DIM + len(pool_scales) * DEC_DIM, DEC_DIM,
311 | kernel_size=3, padding=1, bias=True),
312 |
313 | norm(DEC_DIM, self.batch_norm),
314 | nn.LeakyReLU(),
315 | L.Conv2d(DEC_DIM, DEC_DIM, kernel_size=3, padding=1),
316 | norm(DEC_DIM, self.batch_norm),
317 | nn.LeakyReLU()
318 | )
319 |
320 | # if not self.memory_decoder:
321 | self.conv_up2 = nn.Sequential(
322 | L.Conv2d((FEAT_DIM//8) + DEC_DIM, DEC_DIM,
323 | kernel_size=3, padding=1, bias=True),
324 | norm(DEC_DIM, self.batch_norm),
325 | nn.LeakyReLU()
326 | )
327 | if(self.batch_norm):
328 | d_up3 = 128
329 | else:
330 | d_up3 = 64
331 | self.conv_up3 = nn.Sequential(
332 | L.Conv2d(DEC_DIM + d_up3, 64,
333 | kernel_size=3, padding=1, bias=True),
334 | norm(64, self.batch_norm),
335 | nn.LeakyReLU()
336 | )
337 |
338 | self.unpool = nn.MaxUnpool2d(2, stride=2)
339 |
340 | self.conv_up4 = nn.Sequential(
341 | nn.Conv2d(64 + 3 + 3 + 2, 32,
342 | kernel_size=3, padding=1, bias=True),
343 | nn.LeakyReLU(),
344 | nn.Conv2d(32, 16,
345 | kernel_size=3, padding=1, bias=True),
346 |
347 | nn.LeakyReLU(),
348 | nn.Conv2d(16, 7, kernel_size=1, padding=0, bias=True)
349 | )
350 |
351 | def forward(self, conv_out, img, indices, two_chan_trimap, extract_feature=False, x=None):
352 | # if extract_feature:
353 | conv5 = conv_out[-1]
354 |
355 | input_size = conv5.size()
356 | ppm_out = [conv5]
357 | for pool_scale in self.ppm:
358 | ppm_out.append(nn.functional.interpolate(
359 | pool_scale(conv5),
360 | (input_size[2], input_size[3]),
361 | mode='bilinear', align_corners=False))
362 | ppm_out = torch.cat(ppm_out, 1)
363 | x = self.conv_up1(ppm_out)
364 | # return x
365 | # else:
366 | x = torch.nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
367 |
368 | x = torch.cat((x, conv_out[-4]), 1)
369 |
370 | x = self.conv_up2(x)
371 | x = torch.nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
372 |
373 | x = torch.cat((x, conv_out[-5]), 1)
374 | x = self.conv_up3(x)
375 |
376 | x = torch.nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
377 | x = torch.cat((x, conv_out[-6][:, :3], img), 1)
378 | x2 = torch.cat((x, two_chan_trimap), 1)
379 |
380 | hid = self.conv_up4[:-1](x2)
381 | output = self.conv_up4[-1:](hid)
382 |
383 | alpha = torch.clamp(output[:, 0][:, None], 0, 1)
384 | F = torch.sigmoid(output[:, 1:4])
385 | B = torch.sigmoid(output[:, 4:7])
386 |
387 | # FBA Fusion
388 | alpha, F, B = fba_fusion(alpha, img, F, B)
389 |
390 | output = torch.cat((alpha, F, B), 1)
391 |
392 | return hid, output, x
393 |
394 |
395 | class RefinementModule(nn.Module):
396 | def __init__(self, batch_norm=False):
397 | super(RefinementModule, self).__init__()
398 | self.batch_norm = batch_norm
399 | self.conv1 = nn.Sequential(
400 | L.Conv2d((64 + 3 + 3) + 2 + 1, 64,
401 | kernel_size=3, padding=1, bias=True),
402 | norm(64, self.batch_norm),
403 | nn.LeakyReLU()
404 | )
405 | self.layer1 = resnet_GN_WS.BasicBlock(64, 64)
406 | self.layer2 = resnet_GN_WS.BasicBlock(64, 64)
407 | outdim = 10
408 | self.pred = nn.Sequential(
409 | nn.Conv2d(64, 32,
410 | kernel_size=3, padding=1, bias=True),
411 | nn.LeakyReLU(),
412 | nn.Conv2d(32, 16,
413 | kernel_size=3, padding=1, bias=True),
414 | nn.LeakyReLU(),
415 | nn.Conv2d(16, outdim, kernel_size=1, padding=0, bias=True)
416 | )
417 | def forward(self, x, img, two_chan_trimap, pred_alpha):
418 | x = torch.cat((x, two_chan_trimap, pred_alpha), 1)
419 | x = self.conv1(x)
420 | x = self.layer1(x)
421 | x = self.layer2(x)
422 | x = self.pred[:-1](x)
423 | output = self.pred[-1](x)
424 |
425 | a = output[:, :7]
426 | alpha = torch.clamp(a[:, 0][:, None], 0, 1)
427 | F = torch.sigmoid(a[:, 1:4])
428 | B = torch.sigmoid(a[:, 4:7])
429 | # FBA Fusion
430 | alpha, F, B = fba_fusion(alpha, img, F, B)
431 | alpha = torch.cat((alpha, F, B), 1)
432 |
433 | trimap = output[:, -3:]
434 |
435 | return x, alpha, trimap
436 |
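Note: the matting network above expects an 11-channel input (3 image + 6 trimap-derived + 2 two-channel trimap channels, matching num_channels = 3 + 6 + 2 in build_encoder), and its decoder predicts 7 channels (alpha, F, B) that fba_fusion reconciles with a regularised closed-form solve of the compositing equation img ≈ alpha*F + (1-alpha)*B. A rough usage sketch, assuming the repository root is on the import path; the tensors and the meaning of the 6 extra channels are illustrative only:

    import torch
    from models.alpha.FBA.models import FBA   # import path assumed from this repo's layout

    model = FBA(refinement=False).eval()
    B, H, W = 1, 64, 64                        # spatial size divisible by 8 for the dilated encoder
    image = torch.rand(B, 3, H, W)
    two_chan_trimap = torch.rand(B, 2, H, W)   # e.g. background / foreground maps
    trimap_feats = torch.rand(B, 6, H, W)      # extra trimap encoding built elsewhere in the pipeline
    x = torch.cat([image, trimap_feats, two_chan_trimap], dim=1)   # 3 + 6 + 2 = 11 channels

    with torch.no_grad():
        output, hid, _, _ = model(x, (image, two_chan_trimap))
    alpha, fg, bg = output[:, :1], output[:, 1:4], output[:, 4:7]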
--------------------------------------------------------------------------------
/models/alpha/FBA/resnet_GN_WS.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from . import layers_WS as L
4 |
5 | __all__ = ['ResNet', 'l_resnet18', 'l_resnet34', 'l_resnet50']
6 |
7 |
8 | def conv3x3(in_planes, out_planes, stride=1):
9 | """3x3 convolution with padding"""
10 | return L.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
11 | padding=1, bias=False)
12 |
13 |
14 | def conv1x1(in_planes, out_planes, stride=1):
15 | """1x1 convolution"""
16 | return L.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
17 |
18 |
19 | class BasicBlock(nn.Module):
20 | expansion = 1
21 |
22 | def __init__(self, inplanes, planes, stride=1, downsample=None):
23 | super(BasicBlock, self).__init__()
24 | self.conv1 = conv3x3(inplanes, planes, stride)
25 | self.bn1 = L.BatchNorm2d(planes)
26 | self.relu = nn.ReLU(inplace=True)
27 | self.conv2 = conv3x3(planes, planes)
28 | self.bn2 = L.BatchNorm2d(planes)
29 | self.downsample = downsample
30 | self.stride = stride
31 |
32 | def forward(self, x):
33 | identity = x
34 |
35 | out = self.conv1(x)
36 | out = self.bn1(out)
37 | out = self.relu(out)
38 |
39 | out = self.conv2(out)
40 | out = self.bn2(out)
41 |
42 | if self.downsample is not None:
43 | identity = self.downsample(x)
44 |
45 | out += identity
46 | out = self.relu(out)
47 |
48 | return out
49 |
50 |
51 | class Bottleneck(nn.Module):
52 | expansion = 4
53 |
54 | def __init__(self, inplanes, planes, stride=1, downsample=None):
55 | super(Bottleneck, self).__init__()
56 | self.conv1 = conv1x1(inplanes, planes)
57 | self.bn1 = L.BatchNorm2d(planes)
58 | self.conv2 = conv3x3(planes, planes, stride)
59 | self.bn2 = L.BatchNorm2d(planes)
60 | self.conv3 = conv1x1(planes, planes * self.expansion)
61 | self.bn3 = L.BatchNorm2d(planes * self.expansion)
62 | self.relu = nn.ReLU(inplace=True)
63 | self.downsample = downsample
64 | self.stride = stride
65 |
66 | def forward(self, x):
67 | identity = x
68 |
69 | out = self.conv1(x)
70 | out = self.bn1(out)
71 | out = self.relu(out)
72 |
73 | out = self.conv2(out)
74 | out = self.bn2(out)
75 | out = self.relu(out)
76 |
77 | out = self.conv3(out)
78 | out = self.bn3(out)
79 |
80 | if self.downsample is not None:
81 | identity = self.downsample(x)
82 |
83 | out += identity
84 | out = self.relu(out)
85 |
86 | return out
87 |
88 |
89 | class ResNet(nn.Module):
90 |
91 | def __init__(self, block, layers, num_classes=1000):
92 | super(ResNet, self).__init__()
93 | self.inplanes = 64
94 | self.conv1 = L.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
95 | bias=False)
96 | self.bn1 = L.BatchNorm2d(64)
97 | self.relu = nn.ReLU(inplace=True)
98 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True)
99 | self.layer1 = self._make_layer(block, 64, layers[0])
100 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
101 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
102 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
103 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
104 | self.fc = nn.Linear(512 * block.expansion, num_classes)
105 |
106 | def _make_layer(self, block, planes, blocks, stride=1):
107 | downsample = None
108 | if stride != 1 or self.inplanes != planes * block.expansion:
109 | downsample = nn.Sequential(
110 | conv1x1(self.inplanes, planes * block.expansion, stride),
111 | L.BatchNorm2d(planes * block.expansion),
112 | )
113 |
114 | layers = []
115 | layers.append(block(self.inplanes, planes, stride, downsample))
116 | self.inplanes = planes * block.expansion
117 | for _ in range(1, blocks):
118 | layers.append(block(self.inplanes, planes))
119 |
120 | return nn.Sequential(*layers)
121 |
122 | def forward(self, x):
123 | x = self.conv1(x)
124 | x = self.bn1(x)
125 | x = self.relu(x)
126 |         x, _ = self.maxpool(x)  # maxpool is built with return_indices=True; indices unused here
127 |
128 | x = self.layer1(x)
129 | x = self.layer2(x)
130 | x = self.layer3(x)
131 | x = self.layer4(x)
132 |
133 | x = self.avgpool(x)
134 | x = x.view(x.size(0), -1)
135 | x = self.fc(x)
136 |
137 | return x
138 |
139 |
140 | def load_NoPrefix(path, length):
141 |     # load a checkpoint saved from a wrapped (e.g. DataParallel) model and strip the name prefix
142 |     state_dict = torch.load(path, map_location='cpu')
143 |     # some checkpoints store the weights under a 'state_dict' entry
144 |     if 'state_dict' in state_dict:
145 |         state_dict = state_dict['state_dict']
146 | # create new OrderedDict that does not contain `module.`
147 | from collections import OrderedDict
148 | new_state_dict = OrderedDict()
149 | for k, v in state_dict.items():
150 |         name = k[length:]  # strip the wrapping prefix of the given length (e.g. 'module.')
151 | new_state_dict[name] = v
152 | return new_state_dict
153 |
154 |
155 | def my_load_state_dict(model, state_dict):
156 | # version 2: support tensor same name but size is different
157 |
158 | own_state = model.state_dict()
159 | for name, param in state_dict.items():
160 | if name in own_state:
161 | if isinstance(param, nn.Parameter):
162 | # backwards compatibility for serialized parameters
163 | param = param.data
164 | try:
165 | own_state[name].copy_(param)
166 |             except Exception:
167 |                 print('[Warning] Skipping "{}": model shape {} vs checkpoint shape {}.'.format(name, own_state[name].size(), param.size()))
168 | else:
169 | print('[Warning] Found key "{}" in file, but not in current model'.format(name))
170 |
171 | missing = set(own_state.keys()) - set(state_dict.keys())
172 | if len(missing) > 0:
173 |         print('[Warning] Keys "{}" not found in checkpoint file'.format(missing))
174 |
175 |
176 | def l_resnet18(**kwargs):
177 | """Constructs a ResNet-18 model.
178 | """
179 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
180 |
181 | return model
182 |
183 |
184 | def l_resnet34(**kwargs):
185 | """Constructs a ResNet-34 model.
186 | """
187 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
188 |
189 | return model
190 |
191 |
192 | def l_resnet50(**kwargs):
193 | """Constructs a ResNet-50 model.
194 | """
195 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
196 |
197 | return model
198 |
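Note: load_NoPrefix strips a fixed-length prefix from every checkpoint key and my_load_state_dict copies parameters tolerantly, printing warnings on shape or key mismatches instead of raising. A rough usage sketch; the checkpoint path and the 'module.' prefix are illustrative, not taken from this repository:

    from models.alpha.FBA.resnet_GN_WS import l_resnet50, load_NoPrefix, my_load_state_dict

    # strip a DataParallel-style 'module.' prefix, then load whatever matches the model
    state_dict = load_NoPrefix('weights/example_checkpoint.pth', len('module.'))
    model = l_resnet50()
    my_load_state_dict(model, state_dict)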
--------------------------------------------------------------------------------
/models/alpha/FBA/resnet_bn.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 | from torch.nn import BatchNorm2d
4 |
5 | __all__ = ['ResNet']
6 |
7 |
8 | def conv3x3(in_planes, out_planes, stride=1):
9 | "3x3 convolution with padding"
10 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
11 | padding=1, bias=False)
12 |
13 |
14 | class BasicBlock(nn.Module):
15 | expansion = 1
16 |
17 | def __init__(self, inplanes, planes, stride=1, downsample=None):
18 | super(BasicBlock, self).__init__()
19 | self.conv1 = conv3x3(inplanes, planes, stride)
20 | self.bn1 = BatchNorm2d(planes)
21 | self.relu = nn.ReLU(inplace=True)
22 | self.conv2 = conv3x3(planes, planes)
23 | self.bn2 = BatchNorm2d(planes)
24 | self.downsample = downsample
25 | self.stride = stride
26 |
27 | def forward(self, x):
28 | residual = x
29 |
30 | out = self.conv1(x)
31 | out = self.bn1(out)
32 | out = self.relu(out)
33 |
34 | out = self.conv2(out)
35 | out = self.bn2(out)
36 |
37 | if self.downsample is not None:
38 | residual = self.downsample(x)
39 |
40 | out += residual
41 | out = self.relu(out)
42 |
43 | return out
44 |
45 |
46 | class Bottleneck(nn.Module):
47 | expansion = 4
48 |
49 | def __init__(self, inplanes, planes, stride=1, downsample=None):
50 | super(Bottleneck, self).__init__()
51 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
52 | self.bn1 = BatchNorm2d(planes)
53 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
54 | padding=1, bias=False)
55 | self.bn2 = BatchNorm2d(planes, momentum=0.01)
56 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
57 | self.bn3 = BatchNorm2d(planes * 4)
58 | self.relu = nn.ReLU(inplace=True)
59 | self.downsample = downsample
60 | self.stride = stride
61 |
62 | def forward(self, x):
63 | residual = x
64 |
65 | out = self.conv1(x)
66 | out = self.bn1(out)
67 | out = self.relu(out)
68 |
69 | out = self.conv2(out)
70 | out = self.bn2(out)
71 | out = self.relu(out)
72 |
73 | out = self.conv3(out)
74 | out = self.bn3(out)
75 |
76 | if self.downsample is not None:
77 | residual = self.downsample(x)
78 |
79 | out += residual
80 | out = self.relu(out)
81 |
82 | return out
83 |
84 |
85 | class ResNet(nn.Module):
86 |
87 | def __init__(self, block, layers, num_classes=1000):
88 | self.inplanes = 128
89 | super(ResNet, self).__init__()
90 | self.conv1 = conv3x3(3, 64, stride=2)
91 | self.bn1 = BatchNorm2d(64)
92 | self.relu1 = nn.ReLU(inplace=True)
93 | self.conv2 = conv3x3(64, 64)
94 | self.bn2 = BatchNorm2d(64)
95 | self.relu2 = nn.ReLU(inplace=True)
96 | self.conv3 = conv3x3(64, 128)
97 | self.bn3 = BatchNorm2d(128)
98 | self.relu3 = nn.ReLU(inplace=True)
99 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True)
100 |
101 | self.layer1 = self._make_layer(block, 64, layers[0])
102 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
103 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
104 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
105 | self.avgpool = nn.AvgPool2d(7, stride=1)
106 | self.fc = nn.Linear(512 * block.expansion, num_classes)
107 |
108 | for m in self.modules():
109 | if isinstance(m, nn.Conv2d):
110 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
111 | m.weight.data.normal_(0, math.sqrt(2. / n))
112 | elif isinstance(m, BatchNorm2d):
113 | m.weight.data.fill_(1)
114 | m.bias.data.zero_()
115 |
116 | def _make_layer(self, block, planes, blocks, stride=1):
117 | downsample = None
118 | if stride != 1 or self.inplanes != planes * block.expansion:
119 | downsample = nn.Sequential(
120 | nn.Conv2d(self.inplanes, planes * block.expansion,
121 | kernel_size=1, stride=stride, bias=False),
122 | BatchNorm2d(planes * block.expansion),
123 | )
124 |
125 | layers = []
126 | layers.append(block(self.inplanes, planes, stride, downsample))
127 | self.inplanes = planes * block.expansion
128 | for i in range(1, blocks):
129 | layers.append(block(self.inplanes, planes))
130 |
131 | return nn.Sequential(*layers)
132 |
133 | def forward(self, x):
134 | x = self.relu1(self.bn1(self.conv1(x)))
135 | x = self.relu2(self.bn2(self.conv2(x)))
136 | x = self.relu3(self.bn3(self.conv3(x)))
137 | x, indices = self.maxpool(x)
138 |
139 | x = self.layer1(x)
140 | x = self.layer2(x)
141 | x = self.layer3(x)
142 | x = self.layer4(x)
143 |
144 | x = self.avgpool(x)
145 | x = x.view(x.size(0), -1)
146 | x = self.fc(x)
147 | return x
148 |
149 |
150 | def l_resnet50():
151 | """Constructs a ResNet-50 model.
152 | Args:
153 | pretrained (bool): If True, returns a model pre-trained on ImageNet
154 | """
155 | model = ResNet(Bottleneck, [3, 4, 6, 3])
156 | return model
157 |
--------------------------------------------------------------------------------
/models/alpha/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/models/alpha/__init__.py
--------------------------------------------------------------------------------
/models/alpha/common.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | def pad_divide_by(in_list, d, in_size, padval=0.):
7 | out_list = []
8 | h, w = in_size
9 | if h % d > 0:
10 | new_h = h + d - h % d
11 | else:
12 | new_h = h
13 | if w % d > 0:
14 | new_w = w + d - w % d
15 | else:
16 | new_w = w
17 | lh, uh = int((new_h-h) / 2), int(new_h-h) - int((new_h-h) / 2)
18 | lw, uw = int((new_w-w) / 2), int(new_w-w) - int((new_w-w) / 2)
19 | pad_array = (int(lw), int(uw), int(lh), int(uh))
20 | if sum(pad_array)>0:
21 | for inp in in_list:
22 | out_list.append(F.pad(inp, pad_array, value=padval))
23 | else:
24 | out_list = in_list
25 | if len(in_list) == 1:
26 | out_list = out_list[0]
27 | return out_list, pad_array
28 |
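Note: pad_divide_by pads its inputs so both spatial sides become multiples of d and also returns the (left, right, top, bottom) padding so that predictions can be cropped back afterwards; STM.py below applies the same crop pattern with a presumably equivalent helper from helpers.py. A small sketch with illustrative sizes (a single-element input list is returned as a bare tensor):

    import torch
    from models.alpha.common import pad_divide_by   # import path assumed from this repo's layout

    frame = torch.rand(1, 3, 270, 480)               # height not divisible by 16
    frame_p, pad = pad_divide_by([frame], 16, (270, 480))
    print(frame_p.shape)                             # torch.Size([1, 3, 272, 480]); pad == (0, 0, 1, 1)

    out = frame_p                                    # stand-in for a stride-16 network's output
    if pad[2] + pad[3] > 0:
        out = out[:, :, pad[2]:-pad[3], :]
    if pad[0] + pad[1] > 0:
        out = out[:, :, :, pad[0]:-pad[1]]
    print(out.shape)                                 # torch.Size([1, 3, 270, 480])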
--------------------------------------------------------------------------------
/models/trimap/STM.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torchvision import models
6 | from helpers import ToCuda, pad_divide_by
7 | import math
8 |
9 | class ResBlock(nn.Module):
10 | def __init__(self, indim, outdim=None, stride=1):
11 | super(ResBlock, self).__init__()
12 |         if outdim is None:
13 | outdim = indim
14 | if indim == outdim and stride==1:
15 | self.downsample = None
16 | else:
17 | self.downsample = nn.Conv2d(indim, outdim, kernel_size=3, padding=1, stride=stride)
18 |
19 | self.conv1 = nn.Conv2d(indim, outdim, kernel_size=3, padding=1, stride=stride)
20 | self.conv2 = nn.Conv2d(outdim, outdim, kernel_size=3, padding=1)
21 |
22 |
23 | def forward(self, x):
24 | r = self.conv1(F.relu(x))
25 | r = self.conv2(F.relu(r))
26 |
27 | if self.downsample is not None:
28 | x = self.downsample(x)
29 |
30 | return x + r
31 |
32 | class Encoder_M(nn.Module):
33 | def __init__(self, hdim=32):
34 | super(Encoder_M, self).__init__()
35 | self.hdim = hdim
36 |
37 | self.conv1_m = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
38 | self.conv1_o = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
39 | if self.hdim > 0:
40 | self.conv1_a = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
41 | self.conv1_h = nn.Conv2d(hdim, 64, kernel_size=7, stride=2, padding=3, bias=False)
42 |
43 | resnet = models.resnet50(pretrained=True)
44 | self.conv1 = resnet.conv1
45 | self.bn1 = resnet.bn1
46 | self.relu = resnet.relu # 1/2, 64
47 | self.maxpool = resnet.maxpool
48 |
49 | self.res2 = resnet.layer1 # 1/4, 256
50 | self.res3 = resnet.layer2 # 1/8, 512
51 |         self.res4 = resnet.layer3 # 1/16, 1024
52 |
53 | self.register_buffer('mean', torch.FloatTensor([0.485, 0.456, 0.406]).view(1,3,1,1))
54 | self.register_buffer('std', torch.FloatTensor([0.229, 0.224, 0.225]).view(1,3,1,1))
55 |
56 | def forward(self, in_f, in_m, in_o, in_a, in_h):
57 | f = (in_f - self.mean) / self.std
58 | m = torch.unsqueeze(in_m, dim=1).float() # add channel dim
59 | o = torch.unsqueeze(in_o, dim=1).float() # add channel dim
60 | if self.hdim > 0:
61 | a = torch.unsqueeze(in_a, dim=1).float() # add channel dim
62 |             h = in_h.float()  # hidden state already has a channel dimension
63 | x = self.conv1_m(m) + self.conv1_o(o) + self.conv1_a(a) + self.conv1_h(h)
64 | else:
65 | x = self.conv1_m(m) + self.conv1_o(o)
66 |
67 | x = self.conv1(f) + x
68 | x = self.bn1(x)
69 | c1 = self.relu(x) # 1/2, 64
70 | x = self.maxpool(c1) # 1/4, 64
71 | r2 = self.res2(x) # 1/4, 256
72 | r3 = self.res3(r2) # 1/8, 512
73 |         r4 = self.res4(r3)   # 1/16, 1024
74 | return r4, r3, r2, c1, f
75 |
76 | class Encoder_Q(nn.Module):
77 | def __init__(self):
78 | super(Encoder_Q, self).__init__()
79 | resnet = models.resnet50(pretrained=True)
80 | self.conv1 = resnet.conv1
81 | self.bn1 = resnet.bn1
82 | self.relu = resnet.relu # 1/2, 64
83 | self.maxpool = resnet.maxpool
84 |
85 | self.res2 = resnet.layer1 # 1/4, 256
86 | self.res3 = resnet.layer2 # 1/8, 512
87 |         self.res4 = resnet.layer3 # 1/16, 1024
88 |
89 | self.register_buffer('mean', torch.FloatTensor([0.485, 0.456, 0.406]).view(1,3,1,1))
90 | self.register_buffer('std', torch.FloatTensor([0.229, 0.224, 0.225]).view(1,3,1,1))
91 |
92 | def forward(self, in_f):
93 | f = (in_f - self.mean) / self.std
94 |
95 | x = self.conv1(f)
96 | x = self.bn1(x)
97 | c1 = self.relu(x) # 1/2, 64
98 | x = self.maxpool(c1) # 1/4, 64
99 | r2 = self.res2(x) # 1/4, 256
100 | r3 = self.res3(r2) # 1/8, 512
101 |         r4 = self.res4(r3)   # 1/16, 1024
102 | return r4, r3, r2, c1, f
103 |
104 |
105 | class Refine(nn.Module):
106 | def __init__(self, inplanes, planes, scale_factor=2):
107 | super(Refine, self).__init__()
108 | self.convFS = nn.Conv2d(inplanes, planes, kernel_size=(3,3), padding=(1,1), stride=1)
109 | self.ResFS = ResBlock(planes, planes)
110 | self.ResMM = ResBlock(planes, planes)
111 | self.scale_factor = scale_factor
112 |
113 | def forward(self, f, pm):
114 | s = self.ResFS(self.convFS(f))
115 | m = s + F.interpolate(pm, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
116 | m = self.ResMM(m)
117 | return m
118 |
119 | class Decoder(nn.Module):
120 | def __init__(self, mdim):
121 | super(Decoder, self).__init__()
122 | self.convFM = nn.Conv2d(1024, mdim, kernel_size=(3,3), padding=(1,1), stride=1)
123 | self.ResMM = ResBlock(mdim, mdim)
124 |         self.RF3 = Refine(512, mdim) # 1/16 -> 1/8
125 |         self.RF2 = Refine(256, mdim) # 1/8 -> 1/4
126 |
127 | self.pred = nn.Conv2d(mdim, 3, kernel_size=(3,3), padding=(1,1), stride=1)
128 |
129 | def forward(self, r4, r3, r2, VOS_mode=False):
130 | m4 = self.ResMM(self.convFM(r4))
131 | m3 = self.RF3(r3, m4) # out: 1/8, 256
132 | m2 = self.RF2(r2, m3) # out: 1/4, 256
133 |
134 | p2 = self.pred(F.relu(m2))
135 |
136 | p = F.interpolate(p2, scale_factor=4, mode='bilinear', align_corners=False)
137 | return p
138 |
139 |
140 | class Memory(nn.Module):
141 | def __init__(self):
142 | super(Memory, self).__init__()
143 |
144 | def forward(self, m_in, m_out, q_in, q_out): # m_in: o,c,t,h,w
145 | B, D_e, T, H, W = m_in.size()
146 | _, D_o, _, _, _ = m_out.size()
147 |
148 | mi = m_in.view(B, D_e, T*H*W)
149 | mi = torch.transpose(mi, 1, 2) # b, THW, emb
150 |
151 | qi = q_in.view(B, D_e, H*W) # b, emb, HW
152 |
153 | p = torch.bmm(mi, qi) # b, THW, HW
154 | p = p / math.sqrt(D_e)
155 | p = F.softmax(p, dim=1) # b, THW, HW
156 |
157 | mo = m_out.view(B, D_o, T*H*W)
158 | mem = torch.bmm(mo, p) # Weighted-sum B, D_o, HW
159 | mem = mem.view(B, D_o, H, W)
160 |
161 | mem_out = torch.cat([mem, q_out], dim=1)
162 |
163 | return mem_out
164 |
165 |
166 | class KeyValue(nn.Module):
167 | # Not using location
168 | def __init__(self, indim, keydim, valdim):
169 | super(KeyValue, self).__init__()
170 | self.Key = nn.Conv2d(indim, keydim, kernel_size=(3,3), padding=(1,1), stride=1)
171 | self.Value = nn.Conv2d(indim, valdim, kernel_size=(3,3), padding=(1,1), stride=1)
172 |
173 | def forward(self, x):
174 | return self.Key(x), self.Value(x)
175 |
176 |
177 |
178 |
179 | class STM(nn.Module):
180 | def __init__(self, hdim=-1):
181 | super(STM, self).__init__()
182 | self.hdim = hdim
183 |
184 | self.Encoder_M = Encoder_M(hdim=self.hdim)
185 | self.Encoder_Q = Encoder_Q()
186 |
187 | self.KV_M_r4 = KeyValue(1024, keydim=128, valdim=512)
188 | self.KV_Q_r4 = KeyValue(1024, keydim=128, valdim=512)
189 |
190 | self.Memory = Memory()
191 | self.Decoder = Decoder(256)
192 |
193 | def Pad_memory(self, mems, num_objects):
194 | pad_mems = []
195 | for mem in mems:
196 | batch_and_numobj, C, H, W = mem.shape
197 | batch_size = batch_and_numobj//num_objects
198 | pad_mems.append(mem.view(num_objects, batch_size, C, 1, H, W).transpose(1,0))
199 | return pad_mems
200 |
201 | def memorize(self, frame, masks, num_objects):
202 | # memorize a frame
203 | num_objects = num_objects[0].item()
204 |
205 | (frame, masks), pad = pad_divide_by([frame, masks], 16, (frame.size()[2], frame.size()[3]))
206 |
207 | # make batch arg list
208 | B_list = {'f':[], 'm':[], 'o':[], 'a':[], 'h':[]}
209 |         for o in range(1, num_objects+1):  # objects are indexed 1..num_objects
210 | B_list['f'].append(frame)
211 |             B_list['m'].append(masks[:,1]) # Unknown region
212 | B_list['o'].append(masks[:,2]) # Foreground region
213 | if self.hdim > 0:
214 | B_list['a'].append(masks[:,3]) # Alpha matte
215 | B_list['h'].append(masks[:,4:]) # hidden layer
216 |
217 | # make Batch
218 | B_ = {}
219 | B_['a'] = None
220 | B_['h'] = None
221 | for arg in B_list.keys():
222 | if len(B_list[arg]) > 0:
223 | B_[arg] = torch.cat(B_list[arg], dim=0)
224 |
225 | r4, _, _, _, _ = self.Encoder_M(B_['f'], B_['m'], B_['o'], B_['a'], B_['h'])
226 | k4, v4 = self.KV_M_r4(r4) # num_objects, 128 and 512, H/16, W/16
227 | k4, v4 = self.Pad_memory([k4, v4], num_objects=num_objects)
228 | return k4, v4
229 |
230 | def Soft_aggregation(self, ps, K):
231 | num_objects, H, W = ps.shape
232 | em = ToCuda(torch.zeros(1, K, H, W))
233 | em[0,0] = torch.prod(1-ps, dim=0) # bg prob
234 | em[0,1:num_objects+1] = ps # obj prob
235 | em = torch.clamp(em, 1e-7, 1-1e-7)
236 | logit = torch.log((em /(1-em)))
237 | return logit
238 |
239 |     def segment(self, frame, keys, values, num_objects, memory_update=False):  # memory_update accepted for caller compatibility; unused here
240 | # pad
241 | [frame], pad = pad_divide_by([frame], 16, (frame.size()[2], frame.size()[3]))
242 |
243 | r4, r3, r2, _, _ = self.Encoder_Q(frame)
244 | k4, v4 = self.KV_Q_r4(r4) # 1, dim, H/16, W/16
245 |
246 | # memory select kv:(1, K, C, T, H, W)
247 | m4 = self.Memory(keys.squeeze(1), values.squeeze(1), k4, v4)
248 | logits = self.Decoder(m4, r3, r2)
249 |
250 | logit = logits
251 |
252 | if pad[2]+pad[3] > 0:
253 | logit = logit[:,:,pad[2]:-pad[3],:]
254 | if pad[0]+pad[1] > 0:
255 | logit = logit[:,:,:,pad[0]:-pad[1]]
256 |
257 | return logit
258 |
259 | def forward(self, *args, **kwargs):
260 | if args[1].dim() > 4: # keys
261 | return self.segment(*args, **kwargs)
262 | else:
263 | return self.memorize(*args, **kwargs)
264 |
265 |
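Note: Memory.forward above is the STM space-time memory read: query keys attend over every memory location (softmax over T*H*W, scaled by the square root of the key dimension) and the attended memory values are concatenated with the query values before decoding. The same computation written out with toy shapes (all shapes illustrative):

    import math
    import torch
    import torch.nn.functional as F

    B, D_e, D_o, T, H, W = 1, 128, 512, 2, 4, 4       # key dim, value dim, two memory frames
    m_in  = torch.randn(B, D_e, T, H, W)              # memory keys
    m_out = torch.randn(B, D_o, T, H, W)              # memory values
    q_in  = torch.randn(B, D_e, H, W)                 # query keys
    q_out = torch.randn(B, D_o, H, W)                 # query values

    mi = m_in.view(B, D_e, T * H * W).transpose(1, 2)             # B, THW, D_e
    qi = q_in.view(B, D_e, H * W)                                 # B, D_e, HW
    p = F.softmax(torch.bmm(mi, qi) / math.sqrt(D_e), dim=1)      # attention over all memory locations
    mem = torch.bmm(m_out.view(B, D_o, T * H * W), p).view(B, D_o, H, W)
    read = torch.cat([mem, q_out], dim=1)                         # B, 2*D_o, H, W, fed to the decoder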
--------------------------------------------------------------------------------
/models/trimap/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/models/trimap/__init__.py
--------------------------------------------------------------------------------
/models/trimap/model.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | # general libs
7 | import sys
8 |
9 | sys.path.insert(0, '../')
10 | from helpers import *
11 |
12 | from .STM import STM
13 |
14 |
15 | class FullModel(nn.Module):
16 | def __init__(self, dilate_kernel=None, eps=0, ignore_label=255,
17 | stage=1,
18 | hdim=-1,):
19 | super(FullModel, self).__init__()
20 | self.DILATION_KERNEL = dilate_kernel
21 | self.EPS = eps
22 | self.IMG_SCALE = 1./255
23 | self.register_buffer('IMG_MEAN', torch.tensor([0.485, 0.456, 0.406]).reshape([1, 1, 3, 1, 1]).float())
24 | self.register_buffer('IMG_STD', torch.tensor([0.229, 0.224, 0.225]).reshape([1, 1, 3, 1, 1]).float())
25 |
26 | self.stage = stage
27 | self.hdim = hdim if self.stage > 2 else -1
28 | self.memory_update = False
29 |
30 | self.model = STM(hdim=self.hdim)
31 |
32 | self.num_object = 1
33 |
34 | self.ignore_label = ignore_label
35 | self.LOSS = nn.CrossEntropyLoss(weight=torch.tensor([1, 1, 1]).float(), ignore_index=ignore_label)
36 |
37 | def make_trimap(self, alpha, ignore_region):
38 | b = alpha.shape[0]
39 | alpha = torch.where(alpha < self.EPS, torch.zeros_like(alpha), alpha)
40 | alpha = torch.where(alpha > 1 - self.EPS, torch.ones_like(alpha), alpha)
41 | trimasks = ((alpha > 0) & (alpha < 1.)).float().split(1)
42 | trimaps = [None] * b
43 | for i in range(b):
44 | # trimap width: 1 - 51
45 | kernel_rad = int(torch.randint(0, 26, size=())) \
46 | if self.DILATION_KERNEL is None else self.DILATION_KERNEL
47 | trimaps[i] = F.max_pool2d(trimasks[i].squeeze(0), kernel_size=kernel_rad*2+1, stride=1, padding=kernel_rad)
48 | trimap = torch.stack(trimaps)
49 | # 0: bg, 1: un, 2: fg
50 | trimap1 = torch.where(trimap > 0.5, torch.ones_like(alpha), 2 * alpha).long()
51 | if ignore_region is not None:
52 | trimap1[ignore_region] = 0
53 | trimap3 = F.one_hot(trimap1.squeeze(2), num_classes=3).permute(0, 1, 4, 2, 3)
54 | return trimap3.float()
55 |
56 | def preprocess(self, a, fg, bg, ignore_region=None, tri=None):
57 | # Data preprocess
58 | with torch.no_grad():
59 | scaled_gts = a
60 | scaled_fgs = fg.flip([2]) * self.IMG_SCALE
61 | if bg is None:
62 | scaled_bgs = scaled_fgs
63 | scaled_imgs = scaled_fgs
64 | else:
65 | scaled_bgs = bg.flip([2]) * self.IMG_SCALE
66 | scaled_imgs = scaled_fgs * scaled_gts + scaled_bgs * (1. - scaled_gts)
67 |
68 | if tri is None:
69 | scaled_tris = self.make_trimap(scaled_gts, ignore_region)
70 | else:
71 | scaled_tris = tri
72 | imgs = scaled_imgs
73 | return scaled_imgs, scaled_fgs, scaled_bgs, scaled_gts, scaled_tris, imgs
74 |
75 | def _forward(self, imgs, tris, alpha, masks=None, og_shape=None):
76 | if self.stage == 1:
77 | batch_size, sample_length = imgs.shape[:2]
78 | num_object = torch.tensor([self.num_object]).to(torch.cuda.current_device())
79 | GT = tris.split(1, dim=0) # [1, S, C, H, W]
80 | FG = imgs.split(1, dim=0) # [1, S, C, H, W]
81 |
82 | if masks is not None:
83 | M = masks.squeeze(2).split(1, dim=1)
84 | E = []
85 | E_logits = []
86 | # we split batch here since the original code only supports b=1
87 | for b in range(batch_size):
88 | Fs = FG[b].split(1, dim=1) # [1, 1, C, H, W]
89 | GTs = GT[b].split(1, dim=1) # [1, 1, C, H, W]
90 | Es = [GTs[0].squeeze(1)] + [None] * (sample_length - 1) # [1, C, H, W]
91 | ELs = []
92 | for t in range(1, sample_length):
93 | input_Es = Es[t-1]
94 | # memorize
95 | prev_key, prev_value = self.model(Fs[t-1].squeeze(1), input_Es, num_object)
96 |
97 |                     if t-1 == 0:
98 | this_keys, this_values = prev_key, prev_value # only prev memory
99 | else:
100 | this_keys = torch.cat([keys, prev_key], dim=3)
101 | this_values = torch.cat([values, prev_value], dim=3)
102 |
103 | # segment
104 | logit = self.model(Fs[t].squeeze(1), this_keys, this_values, num_object)
105 | ELs.append(logit)
106 | Es[t] = F.softmax(logit, dim=1)
107 |
108 | # update
109 | keys, values = this_keys, this_values
110 | E.append(torch.cat(Es, dim=0)) # cat t
111 | E_logits.append(torch.cat(ELs, dim=0))
112 |
113 | pred = torch.stack(E, dim=0) # stack b
114 | E_logits = [None] + list(torch.stack(E_logits).split(1, dim=1))
115 | GT = torch.argmax(tris, dim=2)
116 | # Loss & Vis
117 | losses = []
118 | for t in range(1, sample_length):
119 | gt = GT[:,t].squeeze(1)
120 | p = E_logits[t].squeeze(1)
121 | if og_shape is not None:
122 | for b in range(batch_size):
123 | h, w = og_shape[b]
124 | gt[b, h:] = self.ignore_label
125 | gt[b, :, w:] = self.ignore_label
126 | if masks is not None:
127 | mask = M[t].squeeze(1)
128 | gt = torch.where(mask == 0, torch.ones_like(gt) * self.ignore_label, gt)
129 | losses.append(self.LOSS(p, gt))
130 | loss = sum(losses) / float(len(losses))
131 | return pred, loss
132 |
133 | def _forward_single_step(self, img_q, img, tri, alpha, hid, memories=None):
134 | num_object = torch.tensor([self.num_object]).to(torch.cuda.current_device())
135 | # we split batch here since the original code only supports b=1
136 | if self.hdim > 0:
137 | Es = torch.cat([tri, alpha, hid], dim=1)
138 | else:
139 | Es = tri
140 | # memorize
141 | prev_key, prev_value = self.model(img, Es, num_object)
142 |
143 | # update
144 | if memories is None:
145 | memories = dict()
146 | memories['key'] = prev_key
147 | memories['val'] = prev_value
148 | else:
149 | memories['key'] = torch.cat([memories['key'], prev_key], dim=3)
150 | memories['val'] = torch.cat([memories['val'], prev_value], dim=3)
151 |
152 | # segment
153 | logit = self.model(img_q, memories['key'], memories['val'], num_object)
154 | return logit, memories
155 |
156 | def forward(self, a, fg, bg, ignore_region=None, tri=None, og_shape=None,
157 | single_step=False, hid=None, memories=None):
158 | if single_step:
159 | # fg: query frame (normalized between 0~1) [B, 3, H, W]
160 | # bg: prev frame (normalized between 0~1) [B, 3, H, W]
161 | # tri: prev trimap (normalized between 0~1) [B, 3, H, W]
162 | # a: prev alpha (normalized between 0~1) [B, 1, H, W]
163 | logit, memories = self._forward_single_step(fg, bg, tri, a, hid, memories=memories)
164 | return logit, memories
165 | else:
166 | scaled_imgs, _, _, scaled_gts, tris, imgs = self.preprocess(a, fg, bg, ignore_region=ignore_region, tri=tri)
167 |
168 | pred, loss = self._forward(imgs, tris, scaled_gts, og_shape=og_shape)
169 |
170 | return [loss, scaled_imgs, pred, tris, scaled_gts]
171 |
172 |
173 | class FullModel_eval(FullModel):
174 | def _forward(self, imgs, tris, first_frame=False, masks=None, og_shape=None, save_memory=False, max_memory_num=2, memorize_gt=False):
175 | if self.stage == 1:
176 | num_object = torch.tensor([self.num_object]).to(torch.cuda.current_device())
177 |
178 | Fs = imgs
179 |
180 | if first_frame:
181 | Es = tris
182 | pred = Es
183 | else:
184 | logit = self.model(Fs, self.this_keys, self.this_values, num_object, memory_update=self.memory_update,)
185 | Es = F.softmax(logit, dim=1)
186 | pred = Es
187 |
188 | if save_memory and memorize_gt:
189 | Es = tris
190 | pred = tris
191 | prev_key, prev_value = self.model(Fs, Es, num_object)
192 |
193 | if max_memory_num == 0:
194 | if first_frame:
195 | self.this_keys = prev_key
196 | self.this_values = prev_value
197 | elif max_memory_num == 1:
198 | self.this_keys = prev_key
199 | self.this_values = prev_value
200 | else:
201 | if first_frame:
202 | self.this_keys = prev_key
203 | self.this_values = prev_value
204 | elif save_memory:
205 | self.this_keys = torch.cat([self.this_keys, prev_key], dim=3)
206 | self.this_values = torch.cat([self.this_values, prev_value], dim=3)
207 | else:
208 | if self.this_keys.size(3) == 1:
209 | self.this_keys = torch.cat([self.this_keys, prev_key], dim=3)
210 | self.this_values = torch.cat([self.this_values, prev_value], dim=3)
211 | else:
212 | self.this_keys = torch.cat([self.this_keys[:,:,:,:-1], prev_key], dim=3)
213 | self.this_values = torch.cat([self.this_values[:,:,:,:-1], prev_value], dim=3)
214 |
215 | if self.this_keys.size(3) > max_memory_num:
216 | if memorize_gt:
217 | self.this_keys = self.this_keys[:,:,:,1:]
218 | self.this_values = self.this_values[:,:,:,1:]
219 | else:
220 | self.this_keys = torch.cat([self.this_keys[:,:,:,:1], self.this_keys[:,:,:,2:]], dim=3)
221 | self.this_values = torch.cat([self.this_values[:,:,:,:1], self.this_values[:,:,:,2:]], dim=3)
222 |
223 | self.memory_update = save_memory
224 |
225 | return pred.unsqueeze(1), 0
226 |
227 | def _forward_memorize(self, img, tri, alpha, hid):
228 | num_object = torch.tensor([self.num_object]).to(torch.cuda.current_device())
229 | # we split batch here since the original code only supports b=1
230 | if self.hdim > 0:
231 | Es = torch.cat([tri, alpha, hid], dim=1)
232 | else:
233 | Es = tri
234 | # memorize
235 | prev_key, prev_value = self.model(img, Es, num_object)
236 | memories = {'key': prev_key,
237 | 'val': prev_value,
238 | }
239 | return memories
240 |
241 | def _forward_segment(self, img_q, memories=None, memory_update=False):
242 | num_object = torch.tensor([self.num_object]).to(torch.cuda.current_device())
243 | # segment
244 | logit = self.model(img_q, memories['key'], memories['val'], num_object)
245 | return logit
246 |
247 | def forward(self, a, fg, bg, tri=None, first_frame=False, og_shape=None,
248 | memorize=False, segment=False, memories=None, hid=None,
249 | save_memory=False, max_memory_num=2, memory_update=False,
250 | memorize_gt=False,):
251 | if memorize:
252 | # fg: query frame (normalized between 0~1) [B, 3, H, W]
253 | # bg: prev frame (normalized between 0~1) [B, 3, H, W]
254 | # tri: prev trimap (normalized between 0~1) [B, 3, H, W]
255 | # a: prev alpha (normalized between 0~1) [B, 1, H, W]
256 | memories = self._forward_memorize(bg, tri, a, hid)
257 | return memories
258 | elif segment:
259 | # fg: query frame (normalized between 0~1) [B, 3, H, W]
260 | # bg: prev frame (normalized between 0~1) [B, 3, H, W]
261 | # tri: prev trimap (normalized between 0~1) [B, 3, H, W]
262 | # a: prev alpha (normalized between 0~1) [B, 1, H, W]
263 | logit = self._forward_segment(fg, memories=memories)
264 | return logit
265 | else:
266 | scaled_imgs, _, _, scaled_gts, tris, imgs = self.preprocess(a, fg, bg)
267 | if tri is not None:
268 | tris = tri
269 | imgs_fw_HR = imgs.squeeze(0)
270 | tris_fw = tris.squeeze(0)
271 | _, _, H, W = imgs_fw_HR.shape
272 |
273 | imgs_fw = imgs_fw_HR
274 |
275 | pred, loss = self._forward(imgs_fw, tris_fw, first_frame=first_frame, og_shape=og_shape, save_memory=save_memory, max_memory_num=max_memory_num, memorize_gt=memorize_gt,)
276 |
277 | if first_frame:
278 | pred = tris
279 |
280 | return [loss,
281 | scaled_imgs, pred, tris, scaled_gts]
282 |
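Note: make_trimap above synthesises a training trimap from a ground-truth alpha matte: the partially transparent band becomes the unknown region, it is dilated with a max-pool of random radius, and the result is one-hot encoded as background / unknown / foreground. A per-frame sketch with a toy matte (the real method works on [B, S, 1, H, W] batches and randomises the radius):

    import torch
    import torch.nn.functional as F

    alpha = torch.zeros(1, 1, 8, 8)                  # toy matte: background | soft edge | foreground
    alpha[..., 5:] = 1.0
    alpha[..., 4] = 0.5

    unknown = ((alpha > 0) & (alpha < 1)).float()    # partially transparent pixels
    r = 1                                            # dilation radius (randomised during training)
    unknown = F.max_pool2d(unknown, kernel_size=2 * r + 1, stride=1, padding=r)

    # class ids: 0 = background, 1 = unknown, 2 = foreground
    trimap = torch.where(unknown > 0.5, torch.ones_like(alpha), 2 * alpha).long()
    trimap3 = F.one_hot(trimap.squeeze(1), num_classes=3).permute(0, 3, 1, 2).float()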
--------------------------------------------------------------------------------
/scripts/eval_s4.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | GPU=$1
4 |
5 | python eval.py --gpu $GPU
--------------------------------------------------------------------------------
/scripts/eval_s4_demo.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | GPU=$1
4 |
5 | python eval.py --gpu $GPU --demo
--------------------------------------------------------------------------------
/scripts/train_s1_alpha.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | GPUS=$1
3 | GPUS_ARRAY=($(echo $GPUS | tr ',' "\n"))
4 | NUMBER_OF_CUDA_DEVICES=${#GPUS_ARRAY[@]}
5 | if [ $NUMBER_OF_CUDA_DEVICES -gt 1 ]; then
6 | echo "Training with multiple GPUs: $GPUS"
7 | PY_CMD="-m torch.distributed.launch --nproc_per_node=$NUMBER_OF_CUDA_DEVICES --master_port $((RANDOM + 66000))"
8 | else
9 | echo "Training with a single GPU: $GPUS"
10 | PY_CMD=""
11 | fi
12 |
13 | python $PY_CMD train.py --stage 1 --gpu $GPUS
--------------------------------------------------------------------------------
/scripts/train_s1_trimap.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | GPUS=$1
3 |
4 | python train_s1_trimap.py --gpu $GPUS
5 |
--------------------------------------------------------------------------------
/scripts/train_s2_alpha.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | GPUS=$1
3 | GPUS_ARRAY=($(echo $GPUS | tr ',' "\n"))
4 | NUMBER_OF_CUDA_DEVICES=${#GPUS_ARRAY[@]}
5 | if [ $NUMBER_OF_CUDA_DEVICES -gt 1 ]; then
6 | echo "Training with multiple GPUs: $GPUS"
7 | PY_CMD="-m torch.distributed.launch --nproc_per_node=$NUMBER_OF_CUDA_DEVICES --master_port $((RANDOM + 66000))"
8 | else
9 | echo "Training with a single GPU: $GPUS"
10 | PY_CMD=""
11 | fi
12 |
13 | python $PY_CMD train.py --stage 2 --gpu $GPUS
14 |
--------------------------------------------------------------------------------
/scripts/train_s3.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | GPUS=$1
3 | GPUS_ARRAY=($(echo $GPUS | tr ',' "\n"))
4 | NUMBER_OF_CUDA_DEVICES=${#GPUS_ARRAY[@]}
5 | if [ $NUMBER_OF_CUDA_DEVICES -gt 1 ]; then
6 | echo "Training with multiple GPUs: $GPUS"
7 | PY_CMD="-m torch.distributed.launch --nproc_per_node=$NUMBER_OF_CUDA_DEVICES --master_port $((RANDOM + 66000))"
8 | else
9 | echo "Training with a single GPU: $GPUS"
10 | PY_CMD=""
11 | fi
12 |
13 | python $PY_CMD train.py --stage 3 --gpu $GPUS
14 |
--------------------------------------------------------------------------------
/scripts/train_s4.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | GPUS=$1
3 | GPUS_ARRAY=($(echo $GPUS | tr ',' "\n"))
4 | NUMBER_OF_CUDA_DEVICES=${#GPUS_ARRAY[@]}
5 | if [ $NUMBER_OF_CUDA_DEVICES -gt 1 ]; then
6 | echo "Training with multiple GPUs: $GPUS"
7 | PY_CMD="-m torch.distributed.launch --nproc_per_node=$NUMBER_OF_CUDA_DEVICES --master_port $((RANDOM + 66000))"
8 | else
9 | echo "Training with a single GPU: $GPUS"
10 | PY_CMD=""
11 | fi
12 |
13 | python $PY_CMD train.py --stage 4 --gpu $GPUS
14 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import logging
4 | import os
5 | import shutil
6 | import time
7 | import timeit
8 |
9 |
10 | import numpy as np
11 | import cv2 as cv
12 | import torch
13 | import torch.backends.cudnn as cudnn
14 | import torch.distributed as torch_dist
15 | import torch.nn.functional as F
16 | from torch import nn
17 | from torch.utils import data
18 | from torchvision.utils import save_image
19 |
20 | from config import get_cfg_defaults
21 | from dataset import DIM_Train, VideoMatting108_Train
22 | from helpers import *
23 | from utils.optimizer import RAdam
24 |
25 | def parse_args():
26 | parser = argparse.ArgumentParser(description='Train network')
27 | parser.add_argument("--stage", type=int, default=1)
28 | parser.add_argument("--gpu", type=str, default='0,1,2,3')
29 | parser.add_argument("--local_rank", type=int, default=-1)
30 |
31 | args = parser.parse_args()
32 |
33 | cfg = get_cfg_defaults()
34 | cfg.TRAIN.STAGE = args.stage
35 | cfg.freeze()
36 |
37 | return args, cfg
38 |
39 | def main(args, cfg):
40 | MODEL = get_model_name(cfg)
41 | random_seed = cfg.SYSTEM.RANDOM_SEED
42 | base_lr = cfg.TRAIN.BASE_LR
43 |
44 | weight_decay = cfg.TRAIN.WEIGHT_DECAY
45 | output_dir = os.path.join(cfg.SYSTEM.OUTDIR, 'checkpoint')
46 | if args.local_rank <= 0:
47 | os.makedirs(output_dir, exist_ok=True)
48 | start = timeit.default_timer()
49 | # cudnn related setting
50 | cudnn.benchmark = cfg.SYSTEM.CUDNN_BENCHMARK
51 | cudnn.deterministic = cfg.SYSTEM.CUDNN_DETERMINISTIC
52 | cudnn.enabled = cfg.SYSTEM.CUDNN_ENABLED
53 | if random_seed > 0:
54 | import random
55 | if args.local_rank <= 0:
56 | print('Seeding with', random_seed)
57 | random.seed(random_seed+args.local_rank)
58 | torch.manual_seed(random_seed+args.local_rank)
59 |
60 | args.world_size = 1
61 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
62 | if args.local_rank >= 0:
63 | device = torch.device('cuda:{}'.format(args.local_rank))
64 | torch.cuda.set_device(device)
65 | torch.distributed.init_process_group(
66 | backend="nccl", init_method="env://",
67 | )
68 | args.world_size = torch.distributed.get_world_size()
69 | else:
70 | if torch.cuda.is_available():
71 | print('using Cuda devices, num:', torch.cuda.device_count())
72 |
73 | if args.local_rank <= 0:
74 | logger, final_output_dir = create_logger(output_dir, MODEL, 'train')
75 | print(cfg)
76 | with open(os.path.join(final_output_dir, 'config.yaml'), 'w') as f:
77 | f.write(str(cfg))
78 | image_outdir = os.path.join(final_output_dir, 'training_images')
79 | os.makedirs(os.path.join(final_output_dir, 'training_images'), exist_ok=True)
80 | else:
81 | image_outdir = None
82 |
83 | if cfg.TRAIN.STAGE == 1:
84 | model_trimap = None
85 | else:
86 | model_trimap = get_model_trimap(cfg, mode='Train')
87 | model = get_model_alpha(cfg, model_trimap, mode='Train')
88 |
89 |
90 | if cfg.TRAIN.STAGE == 1:
91 | load_ckpt = './weights/FBA.pth'
92 | dct = torch.load(load_ckpt, map_location=torch.device('cpu'))
93 | if 'state_dict' in dct.keys():
94 | dct = dct['state_dict']
95 | missing_keys, unexpected_keys = model.NET.load_state_dict(dct, strict=False)
96 | if args.local_rank <= 0:
97 | logger.info('Missing keys: ' + str(sorted(missing_keys)))
98 | logger.info('Unexpected keys: ' + str(sorted(unexpected_keys)))
99 | logger.info("=> loaded checkpoint from Image Matting Weight: {}".format(load_ckpt))
100 | elif cfg.TRAIN.STAGE in [2,3]:
101 | load_ckpt = './weights/s1_OTVM_trimap.pth'
102 | dct = torch.load(load_ckpt, map_location=torch.device('cpu'))
103 | missing_keys, unexpected_keys = model.trimap.model.load_state_dict(dct, strict=False)
104 | if args.local_rank <= 0:
105 | logger.info('Missing keys: ' + str(sorted(missing_keys)))
106 | logger.info('Unexpected keys: ' + str(sorted(unexpected_keys)))
107 | logger.info("=> loaded checkpoint from Pretrained STM Weight: {}".format(load_ckpt))
108 |
109 | if cfg.TRAIN.STAGE == 2:
110 | load_ckpt = './weights/s1_OTVM_alpha.pth'
111 | elif cfg.TRAIN.STAGE == 3:
112 | load_ckpt = './weights/s2_OTVM_alpha.pth'
113 | dct = torch.load(load_ckpt, map_location=torch.device('cpu'))
114 | missing_keys, unexpected_keys = model.NET.load_state_dict(dct, strict=False)
115 | if args.local_rank <= 0:
116 | logger.info('Missing keys: ' + str(sorted(missing_keys)))
117 | logger.info('Unexpected keys: ' + str(sorted(unexpected_keys)))
118 | elif cfg.TRAIN.STAGE == 4:
119 | load_ckpt = './weights/s3_OTVM.pth'
120 | dct = torch.load(load_ckpt, map_location=torch.device('cpu'))
121 | model.load_state_dict(dct)
122 |
123 | torch_barrier()
124 |
125 | ADDITIONAL_INPUTS = dict()
126 |
127 | start_epoch = 0
128 |
129 | if args.local_rank >= 0:
130 | # FBA particularly uses batch_size == 1, thus no syncbn here
131 | if (not cfg.ALPHA.MODEL.endswith('fba')) and (not cfg.TRAIN.FREEZE_BN):
132 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
133 | model = model.to(device)
134 | find_unused_parameters = False
135 | if cfg.TRAIN.STAGE == 2:
136 | find_unused_parameters = True
137 | model = torch.nn.parallel.DistributedDataParallel(
138 | model,
139 | find_unused_parameters=find_unused_parameters,
140 | device_ids=[args.local_rank],
141 | output_device=args.local_rank,
142 | )
143 | else:
144 | model = torch.nn.DataParallel(model).cuda()
145 |
146 | if cfg.TRAIN.STAGE in [2,3]:
147 | params = list()
148 | for k, v in model.named_parameters():
149 | if v.requires_grad:
150 | _k = k[7:] # remove 'module.'
151 | if _k.startswith('NET.'):
152 | if cfg.TRAIN.STAGE == 3:
153 | if args.local_rank <= 0:
154 | logging.info('do NOT train parameter: %s'%(k))
155 | pass
156 | else:
157 | params.append({'params': v, 'lr': base_lr})
158 | elif _k.startswith('trimap.'):
159 | if cfg.TRAIN.STAGE == 2:
160 | if args.local_rank <= 0:
161 | logging.info('do NOT train parameter: %s'%(k))
162 | pass
163 | else:
164 | params.append({'params': v, 'lr': base_lr})
165 | else:
166 | if args.local_rank <= 0:
167 | logging.info('%s: Undefined parameter'%(k))
168 | params.append({'params': v, 'lr': base_lr})
169 | else:
170 | params_dict = {k: v for k, v in model.named_parameters() if v.requires_grad}
171 | params = [{'params': list(params_dict.values()), 'lr': base_lr}]
172 |
173 | params_count = 0
174 | if args.local_rank <= 0:
175 | logging.info('=> Parameters needs to be optimized:')
176 | for param in params:
177 | _param = param['params']
178 | if type(_param) is list:
179 | for _p in _param:
180 | params_count += _p.shape.numel()
181 | else:
182 | params_count += _param.shape.numel()
183 | logging.info('=> Total Parameters: {}'.format(params_count))
184 |
185 |
186 | if cfg.TRAIN.OPTIMIZER == 'adam':
187 | optimizer = torch.optim.Adam(params, lr=base_lr)
188 | elif cfg.TRAIN.OPTIMIZER == 'radam':
189 | optimizer = RAdam(params, lr=base_lr, weight_decay=weight_decay)
190 |
191 | if cfg.TRAIN.LR_STRATEGY == 'stair':
192 | adjust_lr = stair_lr
193 | elif cfg.TRAIN.LR_STRATEGY == 'poly':
194 | adjust_lr = poly_lr
195 | elif cfg.TRAIN.LR_STRATEGY == 'const':
196 | adjust_lr = const_lr
197 | else:
198 | raise NotImplementedError('[%s] is not supported in cfg.TRAIN.LR_STRATEGY'%(cfg.TRAIN.LR_STRATEGY))
199 |
200 | total_epochs = cfg.TRAIN.TOTAL_EPOCHS
201 |
202 | sample_length = cfg.TRAIN.FRAME_NUM
203 | if cfg.TRAIN.STAGE == 1:
204 | sample_length = 1
205 | if cfg.TRAIN.STAGE in [1,2,3]:
206 | train_dataset = DIM_Train(
207 | data_root=cfg.DATASET.PATH,
208 | image_shape=cfg.TRAIN.TRAIN_INPUT_SIZE,
209 | mode='train',
210 | sample_length=sample_length,
211 | )
212 | else:
213 | train_dataset = VideoMatting108_Train(
214 | data_root=cfg.DATASET.PATH,
215 | image_shape=cfg.TRAIN.TRAIN_INPUT_SIZE,
216 | mode='train',
217 | sample_length=sample_length,
218 | max_skip=15,
219 | do_affine=0.5,
220 | do_time_flip=0.5,
221 | )
222 |
223 | if cfg.SYSTEM.TESTMODE:
224 | start_epoch = max(start_epoch, total_epochs - 1)
225 | for epoch in range(start_epoch, total_epochs):
226 | train(epoch, cfg, args, train_dataset, base_lr, start_epoch, total_epochs,
227 | optimizer, model, adjust_lr, image_outdir, MODEL,
228 | ADDITIONAL_INPUTS)
229 | if args.local_rank <= 0:
230 | if (((epoch+1) % cfg.TRAIN.SAVE_EVERY_EPOCH) == 0) or ((epoch+1) == total_epochs):
231 | weight_fn = os.path.join(final_output_dir, 'checkpoint_{}.pth'.format(epoch+1))
232 | logger.info('=> saving checkpoint to {}'.format(weight_fn))
233 | if cfg.TRAIN.STAGE in [1,2]:
234 | torch.save(model.module.NET.state_dict(), weight_fn)
235 | else:
236 | torch.save(model.module.state_dict(), weight_fn)
237 | optim_fn = os.path.join(final_output_dir, 'optim_{}.pth'.format(epoch+1))
238 | torch.save(optimizer.state_dict(), optim_fn)
239 |
240 | if args.local_rank <= 0:
241 | weight_fn = os.path.join('weights', '{:s}.pth'.format(MODEL))
242 | logger.info('=> saving checkpoint to {}'.format(weight_fn))
243 | if cfg.TRAIN.STAGE in [1,2]:
244 | torch.save(model.module.NET.state_dict(), weight_fn)
245 | else:
246 | torch.save(model.module.state_dict(), weight_fn)
247 |
248 | end = timeit.default_timer()
249 | if args.local_rank <= 0:
250 | logger.info('Time: %d sec.' % np.int32((end-start)))
251 | logger.info('Done')
252 |
253 |
254 |
255 | def write_image(outdir, out, step, max_batch=1, trimap=False):
256 | with torch.no_grad():
257 | scaled_imgs, scaled_tris, alphas, comps, gts, fgs, bgs = out[:7]
258 | if trimap:
259 | pred_tris = out[7]
260 | b, s, _, h, w = scaled_imgs.shape
261 | b = max_batch if b > max_batch else b
262 | img_list = list()
263 | img_list.append(scaled_imgs[:max_batch].reshape(b*s, 3, h, w))
264 | img_list.append(scaled_tris[:max_batch].reshape(b*s, 1, h, w).expand(-1, 3, -1, -1))
265 | img_list.append(gts[:max_batch].reshape(b*s, 1, h, w).expand(-1, 3, -1, -1))
266 | img_list.append(alphas[:max_batch].reshape(b*s, 1, h, w).expand(-1, 3, -1, -1))
267 | if trimap:
268 | img_list.append(pred_tris[:max_batch].reshape(b*s, 3, h, w))
269 | img_list.append(comps[:max_batch].reshape(b*s, 3, h, w))
270 | img_list.append(fgs[:max_batch].reshape(b*s, 3, h, w))
271 | img_list.append(bgs[:max_batch].reshape(b*s, 3, h, w))
272 | imgs = torch.cat(img_list, dim=0).reshape(-1, 3, h, w)
273 | if h > 320:
274 | imgs = F.interpolate(imgs, scale_factor=320/h)
275 | save_image(imgs, os.path.join(outdir, '{}.png'.format(step)), nrow=int(s*b))
276 |
277 | def train(epoch, cfg, args, train_dataset, base_lr, start_epoch, total_epochs,
278 | optimizer, model, adjust_learning_rate, image_outdir, MODEL,
279 | ADDITIONAL_INPUTS):
280 | # Training
281 | torch.cuda.empty_cache()
282 | if cfg.TRAIN.STAGE in [1,2,3]:
283 | train_dataset_concat = [train_dataset] * 20
284 | else:
285 | if epoch < 100:
286 | SKIP = min(1+(epoch//5), 25)
287 | else:
288 | SKIP = max(44-(epoch//5), 10)
289 | train_dataset.max_skip = SKIP
290 | train_dataset_concat = [train_dataset] * 20
291 |
292 | train_dataset = data.ConcatDataset(train_dataset_concat)
293 | train_sampler = get_sampler(train_dataset)
294 | trainloader = torch.utils.data.DataLoader(
295 | train_dataset,
296 | batch_size=int(cfg.TRAIN.BATCH_SIZE // args.world_size),
297 | num_workers=cfg.SYSTEM.NUM_WORKERS,
298 | pin_memory=True,
299 | drop_last=True,
300 | shuffle=True if train_sampler is None else False,
301 | sampler=train_sampler)
302 |
303 | if args.local_rank >= 0:
304 | train_sampler.set_epoch(epoch)
305 |
306 | iters_per_epoch = len(trainloader)
307 | image_freq = cfg.TRAIN.IMAGE_FREQ if cfg.TRAIN.IMAGE_FREQ > 0 else 1e+8
308 | image_freq = min(image_freq, iters_per_epoch)
309 |
310 | # STM DISABLES BN DURING TRAINING
311 | model.train()
312 | if cfg.TRAIN.STAGE > 1:
313 | for m in model.module.trimap.modules():
314 | if isinstance(m, nn.BatchNorm2d):
315 | m.eval() # turn-off BN
316 | if cfg.TRAIN.FREEZE_BN:
317 | for m in model.modules():
318 | if isinstance(m, nn.BatchNorm2d):
319 | m.eval() # turn-off BN
320 | if cfg.TRAIN.STAGE == 2:
321 | model.module.trimap.eval()
322 | if args.local_rank <= 0:
323 | logging.info('Set trimap model to eval mode')
324 | if cfg.TRAIN.STAGE == 3:
325 | model.module.NET.eval()
326 | if args.local_rank <= 0:
327 | logging.info('Set alpha model to eval mode')
328 |
329 | sub_losses = ['L_alpha', 'L_comp', 'L_grad'] if not cfg.ALPHA.MODEL.endswith('fba') else \
330 | ['L_alpha_comp', 'L_lap', 'L_grad']
331 |
332 | data_time = AverageMeter()
333 | losses = AverageMeter()
334 | sub_losses_avg = [AverageMeter() for _ in range(len(sub_losses))]
335 | tic = time.time()
336 | cur_iters = epoch*iters_per_epoch
337 |
338 | prefetcher = data_prefetcher(trainloader)
339 | dp = prefetcher.next()
340 | i_iter = 0
341 | while dp[0] is not None:
342 | if cfg.SYSTEM.TESTMODE:
343 | if i_iter > 20:
344 | print()
345 | break
346 | def step(i_iter, dp, tic):
347 | data_time.update(time.time() - tic)
348 |
349 | def handle_batch():
350 | fg, bg, a, ir, tri, _ = dp # [B, 3, 3 or 1, H, W]
351 |
352 | bg = bg if bg.dim() > 1 else None
353 | a = a if a.dim() > 1 else None
354 | ir = ir if ir.dim() > 1 else None
355 |
356 | out = model(a, fg, bg, ignore_region=ir, tri=tri)
357 | L_alpha = out[0].mean()
358 | L_comp = out[1].mean()
359 | L_grad = out[2].mean()
360 | vis_alpha = L_alpha.detach()#.item()
361 | vis_comp = L_comp.detach()#.item()
362 | vis_grad = L_grad.detach()#.item()
363 | if cfg.TRAIN.STAGE == 1:
364 | loss = L_alpha + L_comp + L_grad
365 | batch_out = [loss.detach(), vis_alpha, vis_comp, vis_grad, out[4:-1]]
366 | else:
367 | L_tri = out[3].mean()
368 | loss = L_alpha + L_comp + L_grad + L_tri
369 | batch_out = [loss.detach(), vis_alpha, vis_comp, vis_grad, out[4:]]
370 |
371 | model.zero_grad()
372 | loss.backward()
373 | optimizer.step()
374 |
375 | return batch_out
376 |
377 | loss, vis_alpha, vis_comp, vis_grad, vis_images = handle_batch()
378 |
379 | reduced_loss = reduce_tensor(loss)
380 | reduced_sub_losses = [reduce_tensor(vis_alpha), reduce_tensor(vis_comp), reduce_tensor(vis_grad)]
381 |
382 | # update average loss
383 | losses.update(reduced_loss.item())
384 | sub_losses_avg[0].update(reduced_sub_losses[0].item())
385 | sub_losses_avg[1].update(reduced_sub_losses[1].item())
386 | sub_losses_avg[2].update(reduced_sub_losses[2].item())
387 |
388 | torch_barrier()
389 |
390 | current_lr = adjust_learning_rate(optimizer,
391 | base_lr,
392 | total_epochs * iters_per_epoch,
393 | i_iter+cur_iters)
394 |
395 | if args.local_rank <= 0:
396 | progress_bar(i_iter, iters_per_epoch, epoch, start_epoch, total_epochs, 'finetuning',
397 | 'Data: {data_time} | '
398 | 'Loss: {loss.val:.4f} ({loss.avg:.4f}) | '
399 | '{sub_losses[0]}: {sub_losses_avg[0].val:.4f} ({sub_losses_avg[0].avg:.4f})'.format(
400 | data_time=format_time(data_time.sum),
401 | loss=losses,
402 | sub_losses=sub_losses,
403 | sub_losses_avg=sub_losses_avg))
404 |
405 | if i_iter % image_freq == 0 and args.local_rank <= 0:
406 | write_image(image_outdir, vis_images, i_iter+cur_iters, trimap=(cfg.TRAIN.STAGE > 1))
407 | return current_lr
408 |
409 | current_lr = step(i_iter, dp, tic)
410 | tic = time.time()
411 |
412 | dp = prefetcher.next()
413 | i_iter += 1
414 |
415 | if args.local_rank <= 0:
416 | logger_str = '{:s} | E [{:d}] | I [{:d}] | LR [{:.1e}] | Total Loss:{: 4.6f}'.format(
417 | MODEL, epoch+1, i_iter+1, current_lr, losses.avg)
418 | logger_str += ' | {} [{: 4.6f}] | {} [{: 4.6f}] | {} [{: 4.6f}]'.format(
419 | sub_losses[0], sub_losses_avg[0].avg,
420 | sub_losses[1], sub_losses_avg[1].avg,
421 | sub_losses[2], sub_losses_avg[2].avg)
422 | logging.info(logger_str)
423 |
424 | class data_prefetcher():
425 | def __init__(self, loader):
426 | self.loader = iter(loader)
427 | self.stream = torch.cuda.Stream()
428 | self.preload()
429 |
430 | def preload(self):
431 | try:
432 | self.next_fg, self.next_bg, self.next_a, self.next_ir, self.next_tri, self.next_idx = next(self.loader)
433 | except StopIteration:
434 | self.next_fg = None
435 | self.next_bg = None
436 | self.next_a = None
437 | self.next_ir = None
438 | self.next_tri = None
439 | self.next_idx = None
440 | return
441 | with torch.cuda.stream(self.stream):
442 | self.next_fg = self.next_fg.cuda(non_blocking=True)
443 | self.next_bg = self.next_bg.cuda(non_blocking=True)
444 | self.next_a = self.next_a.cuda(non_blocking=True)
445 | self.next_ir = self.next_ir.cuda(non_blocking=True)
446 | self.next_tri = self.next_tri.cuda(non_blocking=True)
447 | self.next_idx = self.next_idx.cuda(non_blocking=True)
448 |
449 | def next(self):
450 | torch.cuda.current_stream().wait_stream(self.stream)
451 | fg = self.next_fg
452 | bg = self.next_bg
453 | a = self.next_a
454 | ir = self.next_ir
455 | tri = self.next_tri
456 | idx = self.next_idx
457 | if fg is not None:
458 | fg.record_stream(torch.cuda.current_stream())
459 | if bg is not None:
460 | bg.record_stream(torch.cuda.current_stream())
461 | if a is not None:
462 | a.record_stream(torch.cuda.current_stream())
463 | if ir is not None:
464 | ir.record_stream(torch.cuda.current_stream())
465 | if tri is not None:
466 | tri.record_stream(torch.cuda.current_stream())
467 | if idx is not None:
468 | idx.record_stream(torch.cuda.current_stream())
469 | self.preload()
470 | return fg, bg, a, ir, tri, idx
471 |
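The data_prefetcher above overlaps host-to-device copies with GPU compute: each batch is moved with `.cuda(non_blocking=True)` on a side CUDA stream, and `record_stream` tells the caching allocator the tensors are still in use on the main stream. A minimal consumption sketch, assuming a CUDA device and a loader that yields the same 6-tuple of CPU tensors:

```python
# Sketch only: mirrors how train() above drives data_prefetcher.
prefetcher = data_prefetcher(trainloader)
dp = prefetcher.next()           # first batch is already on the GPU
i_iter = 0
while dp[0] is not None:         # fg becomes None once the loader is exhausted
    fg, bg, a, ir, tri, idx = dp
    # ... forward / backward / optimizer.step() on this batch ...
    dp = prefetcher.next()       # the next batch was copied while this one was being processed
    i_iter += 1
```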
472 |
473 |
474 |
475 | def get_sampler(dataset, shuffle=True):
476 | if torch_dist.is_available() and torch_dist.is_initialized():
477 | from torch.utils.data.distributed import DistributedSampler
478 | return DistributedSampler(dataset, shuffle=shuffle)
479 | else:
480 | return None
481 |
482 |
483 | def IoU(pred, true):
484 | _, _, n_class, _, _ = pred.shape
485 |
486 | _, xx = torch.max(pred, dim=2)
487 | _, yy = torch.max(true, dim=2)
488 | iou = list()
489 | for n in range(n_class):
490 | x = (xx == n).float()
491 | y = (yy == n).float()
492 |
493 | i = torch.sum(torch.sum(x*y, dim=-1), dim=-1) # sum over spatial dims
494 | u = torch.sum(torch.sum((x+y)-(x*y), dim=-1), dim=-1)
495 |
496 | iou.append(((i + 1e-4) / (u + 1e-4)).mean().item() * 100.) # b
497 |
498 | # mean over mini-batch
499 | return sum(iou)/n_class, iou
500 |
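IoU expects 5-D tensors shaped [B, S, C, H, W]; it argmaxes over the class axis to get label maps, then averages each class's intersection-over-union over the batch (the 1e-4 terms make an empty class count as a perfect match). A small synthetic check with shapes chosen purely for illustration:

```python
import torch

# Batch 1, sequence 1, 3 trimap classes (bg / unknown / fg), 4x4 frames.
pred = torch.zeros(1, 1, 3, 4, 4)
true = torch.zeros(1, 1, 3, 4, 4)
pred[:, :, 2, :2, :] = 1.0   # predict "fg" on the top half
pred[:, :, 0, 2:, :] = 1.0   # predict "bg" on the bottom half
true[:, :, 2, :2, :] = 1.0   # ground truth agrees exactly
true[:, :, 0, 2:, :] = 1.0

miou, per_class = IoU(pred, true)
print(miou, per_class)       # IoU in percent; every class scores 100 here
```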
501 |
502 | if __name__ == "__main__":
503 | args, cfg = parse_args()
504 | main(args, cfg)
505 |
--------------------------------------------------------------------------------
/train_s1_trimap.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os
4 | import time
5 | import timeit
6 | import shutil
7 |
8 | import numpy as np
9 | import torch
10 | import torch.backends.cudnn as cudnn
11 | import torch.nn.functional as F
12 | from torch import nn, optim
13 | from torch.utils import data
14 | from torchvision.utils import save_image
15 |
16 | from config import get_cfg_defaults
17 | from dataset import DIM_Train
18 | from helpers import *
19 | from utils.optimizer import RAdam
20 |
21 | def parse_args():
22 | parser = argparse.ArgumentParser(description='Train network')
23 | parser.add_argument("--gpu", type=str, default='0,1,2,3')
24 |
25 | args = parser.parse_args()
26 |
27 | cfg = get_cfg_defaults()
28 | cfg.freeze()
29 |
30 | return args, cfg
31 |
32 |
33 | def main(args, cfg):
34 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
35 | if torch.cuda.is_available():
36 | print('using Cuda devices, num:', torch.cuda.device_count())
37 |
38 | MODEL = 's1_OTVM_trimap'
39 | random_seed = cfg.SYSTEM.RANDOM_SEED
40 | base_lr = cfg.TRAIN.BASE_LR
41 | weight_decay = cfg.TRAIN.WEIGHT_DECAY
42 | output_dir = os.path.join(cfg.SYSTEM.OUTDIR, 'checkpoint')
43 | os.makedirs(output_dir, exist_ok=True)
44 | start = timeit.default_timer()
45 | # cudnn related setting
46 | cudnn.benchmark = cfg.SYSTEM.CUDNN_BENCHMARK
47 | cudnn.deterministic = cfg.SYSTEM.CUDNN_DETERMINISTIC
48 | cudnn.enabled = cfg.SYSTEM.CUDNN_ENABLED
49 | if random_seed > 0:
50 | import random
51 | print('Seeding with', random_seed)
52 | random.seed(random_seed)
53 | torch.manual_seed(random_seed)
54 |
55 | logger, final_output_dir = create_logger(output_dir, MODEL, 'train')
56 | print(cfg)
57 | with open(os.path.join(final_output_dir, 'config.yaml'), 'w') as f:
58 | f.write(str(cfg))
59 | image_outdir = os.path.join(final_output_dir, 'training_images')
60 | os.makedirs(os.path.join(final_output_dir, 'training_images'), exist_ok=True)
61 |
62 | model = get_model_trimap(cfg, mode='Train')
63 | torch_barrier()
64 |
65 | start_epoch = 0
66 |
67 | load_ckpt = './weights/STM_weights.pth'
68 | dct = load_NoPrefix(load_ckpt, 7)
69 | missing_keys, unexpected_keys = model.model.load_state_dict(dct, strict=False)
70 | logger.info('Missing keys: ' + str(sorted(missing_keys)))
71 | logger.info('Unexpected keys: ' + str(sorted(unexpected_keys)))
72 | logger.info("=> loaded checkpoint from {}".format(load_ckpt))
73 |
74 | model = torch.nn.DataParallel(model).cuda()
75 |
76 | # optimizer
77 | params_dict = {k: v for k, v in model.named_parameters() if v.requires_grad}
78 |
79 | params_count = 0
80 |     logging.info('=> Parameters to be optimized:')
81 | for k in sorted(params_dict):
82 | params_count += params_dict[k].shape.numel()
83 | logging.info('=> Total Parameters: {}'.format(params_count))
84 |
85 | params = [{'params': list(params_dict.values()), 'lr': base_lr}]
86 | if cfg.TRAIN.OPTIMIZER == 'adam':
87 | optimizer = torch.optim.Adam(params, lr=base_lr)
88 | elif cfg.TRAIN.OPTIMIZER == 'radam':
89 | optimizer = RAdam(params, lr=base_lr, weight_decay=weight_decay)
90 |
91 | if cfg.TRAIN.LR_STRATEGY == 'stair':
92 | adjust_lr = stair_lr
93 | elif cfg.TRAIN.LR_STRATEGY == 'poly':
94 | adjust_lr = poly_lr
95 | elif cfg.TRAIN.LR_STRATEGY == 'const':
96 | adjust_lr = const_lr
97 | else:
98 | raise NotImplementedError('[%s] is not supported in cfg.TRAIN.LR_STRATEGY'%(cfg.TRAIN.LR_STRATEGY))
99 |
100 | total_epochs = cfg.TRAIN.TOTAL_EPOCHS
101 |
102 | train_dataset = DIM_Train(
103 | data_root=cfg.DATASET.PATH,
104 | image_shape=cfg.TRAIN.TRAIN_INPUT_SIZE,
105 | mode='train',
106 | sample_length=3,
107 | )
108 | train_dataset = [train_dataset] * 20
109 |
110 | train_dataset = data.ConcatDataset(train_dataset)
111 | trainloader = torch.utils.data.DataLoader(
112 | train_dataset,
113 | batch_size=cfg.TRAIN.BATCH_SIZE,
114 | num_workers=cfg.SYSTEM.NUM_WORKERS,
115 | pin_memory=False,
116 | drop_last=True,
117 | shuffle=True)
118 |
119 | if cfg.SYSTEM.TESTMODE:
120 | start_epoch += 199
121 | for epoch in range(start_epoch, total_epochs):
122 | train(epoch, cfg, trainloader, base_lr, start_epoch, total_epochs,
123 | optimizer, model, adjust_lr, image_outdir, MODEL)
124 |
125 | if (((epoch+1) % cfg.TRAIN.SAVE_EVERY_EPOCH) == 0) or ((epoch+1) == total_epochs):
126 | weight_fn = os.path.join(final_output_dir, 'checkpoint_{}.pth'.format(epoch+1))
127 | logger.info('=> saving checkpoint to {}'.format(weight_fn))
128 | torch.save(model.module.model.state_dict(), weight_fn)
129 | optim_fn = os.path.join(final_output_dir, 'optim_{}.pth'.format(epoch+1))
130 | torch.save(optimizer.state_dict(), optim_fn)
131 |
132 | weight_fn = os.path.join('weights', '{:s}.pth'.format(MODEL))
133 | logger.info('=> saving checkpoint to {}'.format(weight_fn))
134 | torch.save(model.module.model.state_dict(), weight_fn)
135 | end = timeit.default_timer()
136 | logger.info('Time: %d sec.' % np.int32((end-start)))
137 | logger.info('Done')
138 |
139 |
140 |
141 | def write_image(outdir, out, step, max_batch=1):
142 | with torch.no_grad():
143 | scaled_imgs, pred, tris, scaled_gts = out
144 | b, s, _, h, w = scaled_imgs.shape
145 | b = max_batch if b > max_batch else b
146 | img_list = list()
147 | img_list.append(scaled_imgs[:max_batch].reshape(b*s, 3, h, w))
148 | img_list.append(tris[:max_batch].reshape(b*s, 3, h, w))
149 | img_list.append(pred[:max_batch].reshape(b*s, 3, h, w))
150 | imgs = torch.cat(img_list, dim=0).reshape(-1, 3, h, w)
151 | if h > 320:
152 | imgs = F.interpolate(imgs, scale_factor=320/h)
153 | save_image(imgs, os.path.join(outdir, '{}.png'.format(step)), nrow=int(s*b))
154 |
155 | def train(epoch, cfg, trainloader, base_lr, start_epoch, total_epochs,
156 | optimizer, model, adjust_learning_rate, image_outdir, MODEL):
157 | # Training
158 | iters_per_epoch = len(trainloader)
159 | image_freq = cfg.TRAIN.IMAGE_FREQ if cfg.TRAIN.IMAGE_FREQ > 0 else 1e+8
160 | image_freq = min(image_freq, iters_per_epoch)
161 |
162 | # STM DISABLES BN DURING TRAINING
163 | model.train()
164 | for m in model.modules():
165 | if isinstance(m, nn.BatchNorm2d):
166 | m.eval() # turn-off BN
167 |
168 | data_time = AverageMeter()
169 | losses = AverageMeter()
170 | IOU = AverageMeter()
171 | tic = time.time()
172 | cur_iters = epoch*iters_per_epoch
173 |
174 | prefetcher = data_prefetcher(trainloader)
175 | dp = prefetcher.next()
176 | i_iter = 0
177 | while dp[0] is not None:
178 | data_time.update(time.time() - tic)
179 | if cfg.SYSTEM.TESTMODE:
180 | if i_iter > 20:
181 | print()
182 | break
183 |
184 | def handle_batch():
185 | fg, bg, a, ir, tri, _ = dp # [B, 3, 3 or 1, H, W]
186 |
187 | bg = bg if bg.dim() > 1 else None
188 | a = a if a.dim() > 1 else None
189 | ir = ir if ir.dim() > 1 else None
190 |
191 | out = model(a, fg, bg, ignore_region=ir, tri=tri)
192 | loss = out[0].mean()
193 |
194 |
195 | model.zero_grad()
196 | loss.backward()
197 | optimizer.step()
198 | return loss.detach(), out[1:]
199 |
200 | loss, vis_out = handle_batch()
201 |
202 | reduced_loss = reduce_tensor(loss)
203 |
204 | # update average loss
205 | losses.update(reduced_loss.item())
206 |
207 | tri_pred = vis_out[1]
208 | tri_gt = vis_out[2]
209 | mIoU, _ = IoU(tri_pred, tri_gt)
210 | IOU.update(mIoU)
211 | torch_barrier()
212 |
213 | current_lr = adjust_learning_rate(optimizer,
214 | base_lr,
215 | total_epochs * iters_per_epoch,
216 | i_iter+cur_iters)
217 |
218 | tic = time.time()
219 | progress_bar(i_iter, iters_per_epoch, epoch, start_epoch, total_epochs, 'finetuning',
220 | 'Data: {data_time} | '
221 | 'Loss: {loss.val:.4f} ({loss.avg:.4f}) | '
222 | 'IOU: {IOU.val:.4f} ({IOU.avg:.4f})'.format(
223 | data_time=format_time(data_time.sum),
224 | loss=losses,
225 | IOU=IOU))
226 |
227 | if i_iter % image_freq == 0:
228 | write_image(image_outdir, vis_out, i_iter+cur_iters)
229 |
230 | dp = prefetcher.next()
231 | i_iter += 1
232 |
233 | logger_str = '{:s} | E [{:d}] | I [{:d}] | LR [{:.1e}] | CE:{: 4.6f} | mIoU:{: 4.6f}'
234 | logger_format = [MODEL, epoch+1, i_iter+1, current_lr, losses.avg, IOU.avg]
235 | logging.info(logger_str.format(*logger_format))
236 |
237 | class data_prefetcher():
238 | def __init__(self, loader):
239 | self.loader = iter(loader)
240 | self.stream = torch.cuda.Stream()
241 | self.preload()
242 |
243 | def preload(self):
244 | try:
245 | self.next_fg, self.next_bg, self.next_a, self.next_ir, self.next_tri, self.next_idx = next(self.loader)
246 | except StopIteration:
247 | self.next_fg = None
248 | self.next_bg = None
249 | self.next_a = None
250 | self.next_ir = None
251 | self.next_tri = None
252 | self.next_idx = None
253 | return
254 | with torch.cuda.stream(self.stream):
255 | self.next_fg = self.next_fg.cuda(non_blocking=True)
256 | self.next_bg = self.next_bg.cuda(non_blocking=True)
257 | self.next_a = self.next_a.cuda(non_blocking=True)
258 | self.next_ir = self.next_ir.cuda(non_blocking=True)
259 | self.next_tri = self.next_tri.cuda(non_blocking=True)
260 | self.next_idx = self.next_idx.cuda(non_blocking=True)
261 |
262 | def next(self):
263 | torch.cuda.current_stream().wait_stream(self.stream)
264 | fg = self.next_fg
265 | bg = self.next_bg
266 | a = self.next_a
267 | ir = self.next_ir
268 | tri = self.next_tri
269 | idx = self.next_idx
270 | if fg is not None:
271 | fg.record_stream(torch.cuda.current_stream())
272 | if bg is not None:
273 | bg.record_stream(torch.cuda.current_stream())
274 | if a is not None:
275 | a.record_stream(torch.cuda.current_stream())
276 | if ir is not None:
277 | ir.record_stream(torch.cuda.current_stream())
278 | if tri is not None:
279 | tri.record_stream(torch.cuda.current_stream())
280 | if idx is not None:
281 | idx.record_stream(torch.cuda.current_stream())
282 | self.preload()
283 | return fg, bg, a, ir, tri, idx
284 |
285 |
286 |
287 | def IoU(pred, true):
288 | _, _, n_class, _, _ = pred.shape
289 |
290 | _, xx = torch.max(pred, dim=2)
291 | _, yy = torch.max(true, dim=2)
292 | iou = list()
293 | for n in range(n_class):
294 | x = (xx == n).float()
295 | y = (yy == n).float()
296 |
297 | i = torch.sum(torch.sum(x*y, dim=-1), dim=-1) # sum over spatial dims
298 | u = torch.sum(torch.sum((x+y)-(x*y), dim=-1), dim=-1)
299 |
300 | iou.append(((i + 1e-4) / (u + 1e-4)).mean().item() * 100.) # b
301 |
302 | # mean over mini-batch
303 | return sum(iou)/n_class, iou
304 |
305 |
306 | if __name__ == "__main__":
307 | args, cfg = parse_args()
308 | main(args, cfg)
309 |
--------------------------------------------------------------------------------
/utils/loss_func.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | def L1_mask(x, y, mask=None, epsilon=1.001e-5, normalize=True):
5 | res = torch.abs(x - y)
6 | b,c,h,w = y.shape
7 | if mask is not None:
8 | res = res * mask
9 | if normalize:
10 | _safe = torch.sum((mask > epsilon).float()).clamp(epsilon, b*c*h*w+1)
11 | return torch.sum(res) / _safe
12 | else:
13 | return torch.sum(res)
14 | if normalize:
15 | return torch.mean(res)
16 | else:
17 | return torch.sum(res)
18 |
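L1_mask is a mean absolute error that, when a mask is given, averages only over the masked pixels (the clamp keeps the divisor finite when the mask is empty). A hedged usage sketch for the typical matting case, where the mask marks the unknown trimap region:

```python
import torch

pred_alpha = torch.rand(2, 1, 64, 64)                 # predicted alpha, B x 1 x H x W
gt_alpha = torch.rand(2, 1, 64, 64)                   # ground-truth alpha
unknown = (torch.rand(2, 1, 64, 64) > 0.7).float()    # 1 inside the unknown region

loss_all = L1_mask(pred_alpha, gt_alpha)                    # plain mean L1 over every pixel
loss_unknown = L1_mask(pred_alpha, gt_alpha, mask=unknown)  # L1 averaged over masked pixels only
```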
19 |
20 | def L1_mask_hard_mining(x, y, mask):
21 | input_size = x.size()
22 | res = torch.sum(torch.abs(x - y), dim=1, keepdim=True)
23 | with torch.no_grad():
24 | idx = mask > 0.5
25 | res_sort = [torch.sort(res[i, idx[i, ...]])[0] for i in range(idx.shape[0])]
26 | res_sort = [i[int(i.shape[0] * 0.5)].item() for i in res_sort]
27 | new_mask = mask.clone()
28 | for i in range(res.shape[0]):
29 | new_mask[i, ...] = ((mask[i, ...] > 0.5) & (res[i, ...] > res_sort[i])).float()
30 |
31 | res = res * new_mask
32 | final_res = torch.sum(res) / torch.sum(new_mask)
33 | return final_res, new_mask
34 |
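L1_mask_hard_mining keeps, per sample, only the masked pixels whose per-pixel L1 error exceeds that sample's median error, so the loss concentrates on the hardest half of the unknown region; it returns the mined loss together with the new mask. A small illustrative call (shapes are arbitrary, assuming every sample has at least one masked pixel):

```python
import torch

pred = torch.rand(2, 1, 32, 32)
gt = torch.rand(2, 1, 32, 32)
unknown = (torch.rand(2, 1, 32, 32) > 0.5).float()

hard_loss, hard_mask = L1_mask_hard_mining(pred, gt, unknown)
# hard_mask keeps roughly the worst 50% of the pixels inside `unknown` for each sample
print(hard_loss.item(), hard_mask.sum().item(), unknown.sum().item())
```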
35 | def get_gradient(image):
36 | b, c, h, w = image.shape
37 | dy = image[:, :, 1:, :] - image[:, :, :-1, :]
38 | dx = image[:, :, :, 1:] - image[:, :, :, :-1]
39 |
40 | dy = F.pad(dy, (0, 0, 0, 1))
41 | dx = F.pad(dx, (0, 1, 0, 0))
42 | return dx, dy
43 |
44 | def L1_grad(pred, gt, mask=None, epsilon=1.001e-5, normalize=True):
45 | fake_grad_x, fake_grad_y = get_gradient(pred)
46 | true_grad_x, true_grad_y = get_gradient(gt)
47 |
48 | mag_fake = torch.sqrt(fake_grad_x ** 2 + fake_grad_y ** 2 + epsilon)
49 | mag_true = torch.sqrt(true_grad_x ** 2 + true_grad_y ** 2 + epsilon)
50 |
51 | return L1_mask(mag_fake, mag_true, mask=mask, normalize=normalize)
52 |
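get_gradient returns forward finite differences padded back to the input size, and L1_grad compares the gradient magnitudes of prediction and target, penalizing over-smoothed or noisy alpha edges. A quick sketch:

```python
import torch

pred = torch.rand(1, 1, 64, 64)
gt = torch.rand(1, 1, 64, 64)

dx, dy = get_gradient(pred)       # both 1 x 1 x 64 x 64; the last column / row are zero-padded
grad_loss = L1_grad(pred, gt)     # L1 between gradient magnitudes over the whole image
grad_loss_masked = L1_grad(pred, gt, mask=(gt > 0.5).float())  # optionally restrict to a region
```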
53 | '''
54 | Ported from https://github.com/ceciliavision/perceptual-reflection-removal/blob/master/main.py
55 | '''
56 | def exclusion_loss(img1, img2, level, epsilon=1.001e-5, normalize=True):
57 | gradx_loss=[]
58 | grady_loss=[]
59 | for l in range(level):
60 | gradx1, grady1 = get_gradient(img1)
61 | gradx2, grady2 = get_gradient(img2)
62 |
63 | alphax=2.0*torch.mean(torch.abs(gradx1))/(torch.mean(torch.abs(gradx2)) + epsilon)
64 | alphay=2.0*torch.mean(torch.abs(grady1))/(torch.mean(torch.abs(grady2)) + epsilon)
65 |
66 | gradx1_s=(torch.sigmoid(gradx1)*2)-1
67 | grady1_s=(torch.sigmoid(grady1)*2)-1
68 | gradx2_s=(torch.sigmoid(gradx2*alphax)*2)-1
69 | grady2_s=(torch.sigmoid(grady2*alphay)*2)-1
70 |
71 | safe_x = torch.mean((gradx1_s ** 2) * (gradx2_s ** 2), dim=(1,2,3)) + epsilon
72 | safe_y = torch.mean((grady1_s ** 2) * (grady2_s ** 2), dim=(1,2,3)) + epsilon
73 | gradx_loss.append(safe_x ** 0.25)
74 | grady_loss.append(safe_y ** 0.25)
75 |
76 | img1 = F.avg_pool2d(img1, kernel_size=2, stride=2)
77 | img2 = F.avg_pool2d(img2, kernel_size=2, stride=2)
78 |
79 | if normalize:
80 | return torch.mean(sum(gradx_loss) / float(level)) + torch.mean(sum(grady_loss) / float(level))
81 | else:
82 | return torch.sum(sum(gradx_loss) / float(level)) + torch.sum(sum(grady_loss) / float(level))
83 |
84 | def sparsity_loss(prediction, trimask, eps=1e-5, gamma=0.9):
85 | mask = trimask > 0.5
86 | pred = prediction[mask]
87 | loss = torch.sum(torch.pow(pred+eps, gamma) + torch.pow(1.-pred+eps, gamma) - 1.)
88 | return loss
89 |
90 | '''
91 | Borrowed from https://gist.github.com/alper111/b9c6d80e2dba1ee0bfac15eb7dad09c8
92 | It directly follows OpenCV's image pyramid implementation pyrDown() and pyrUp().
93 | Reference: https://docs.opencv.org/4.4.0/d4/d86/group__imgproc__filter.html#gaf9bba239dfca11654cb7f50f889fc2ff
94 | '''
95 | class LapLoss(torch.nn.Module):
96 | def __init__(self, max_levels=5):
97 | super(LapLoss, self).__init__()
98 | self.max_levels = max_levels
99 | kernel = torch.tensor([[1., 4., 6., 4., 1],
100 | [4., 16., 24., 16., 4.],
101 | [6., 24., 36., 24., 6.],
102 | [4., 16., 24., 16., 4.],
103 | [1., 4., 6., 4., 1.]])
104 | kernel /= 256.
105 | self.register_buffer('KERNEL', kernel.float())
106 |
107 | def downsample(self, x):
108 | # rejecting even rows and columns
109 | return x[:, :, ::2, ::2]
110 |
111 | def upsample(self, x):
112 | # Padding zeros interleaved in x (similar to unpooling where indices are always at top-left corner)
113 |         # Original code only works when x.shape[2] == x.shape[3] because it uses the wrong index order
114 | # after the first permute
115 | cc = torch.cat([x, torch.zeros(x.shape[0], x.shape[1], x.shape[2], x.shape[3], device=x.device)], dim=3)
116 | cc = cc.view(x.shape[0], x.shape[1], x.shape[2]*2, x.shape[3])
117 | cc = cc.permute(0,1,3,2)
118 | cc = torch.cat([cc, torch.zeros(x.shape[0], x.shape[1], x.shape[3], x.shape[2]*2, device=x.device)], dim=3)
119 | cc = cc.view(x.shape[0], x.shape[1], x.shape[3]*2, x.shape[2]*2)
120 | x_up = cc.permute(0,1,3,2)
121 | return self.conv_gauss(x_up, 4*self.KERNEL.repeat(x.shape[1], 1, 1, 1))
122 |
123 | def conv_gauss(self, img, kernel):
124 | img = F.pad(img, (2, 2, 2, 2), mode='reflect')
125 | out = F.conv2d(img, kernel, groups=img.shape[1])
126 | return out
127 |
128 | def laplacian_pyramid(self, img):
129 | current = img
130 | pyr = []
131 | for level in range(self.max_levels):
132 | filtered = self.conv_gauss(current, \
133 | self.KERNEL.repeat(img.shape[1], 1, 1, 1))
134 | down = self.downsample(filtered)
135 | up = self.upsample(down)
136 | diff = current-up
137 | pyr.append(diff)
138 | current = down
139 | return pyr
140 |
141 | def forward(self, img, tgt, mask=None, normalize=True):
142 | (img, tgt), pad = self.pad_divide_by([img, tgt], 32, (img.size()[2], img.size()[3]))
143 |
144 | pyr_input = self.laplacian_pyramid(img)
145 | pyr_target = self.laplacian_pyramid(tgt)
146 | loss = sum((2 ** level) * L1_mask(ab[0], ab[1], mask=mask, normalize=False) \
147 | for level, ab in enumerate(zip(pyr_input, pyr_target)))
148 | if normalize:
149 | b,c,h,w = tgt.shape
150 | if mask is not None:
151 |                 _safe = torch.sum((mask > 1e-6).float()).clamp(1.001e-5, b*c*h*w+1)  # fixed epsilon; `epsilon` was undefined in this scope
152 | else:
153 | _safe = b*c*h*w
154 | return loss / _safe
155 | return loss
156 |
157 | def pad_divide_by(self, in_list, d, in_size):
158 | out_list = []
159 | h, w = in_size
160 | if h % d > 0:
161 | new_h = h + d - h % d
162 | else:
163 | new_h = h
164 | if w % d > 0:
165 | new_w = w + d - w % d
166 | else:
167 | new_w = w
168 | lh, uh = int((new_h-h) / 2), int(new_h-h) - int((new_h-h) / 2)
169 | lw, uw = int((new_w-w) / 2), int(new_w-w) - int((new_w-w) / 2)
170 | pad_array = (int(lw), int(uw), int(lh), int(uh))
171 | for inp in in_list:
172 | out_list.append(F.pad(inp, pad_array))
173 | return out_list, pad_array
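LapLoss pads both inputs to a multiple of 32, builds their Laplacian pyramids, and sums level-wise L1 differences weighted by 2^level, which emphasizes fine detail while remaining tolerant to low-frequency shifts. A minimal usage sketch (the module and its inputs should live on the same device):

```python
import torch

lap = LapLoss(max_levels=5)

pred_alpha = torch.rand(2, 1, 90, 150)   # sizes need not be multiples of 32; forward() pads them
gt_alpha = torch.rand(2, 1, 90, 150)

loss = lap(pred_alpha, gt_alpha)                        # normalized by the number of pixels
loss_sum = lap(pred_alpha, gt_alpha, normalize=False)   # raw weighted sum over pyramid levels
```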
--------------------------------------------------------------------------------
/utils/optimizer.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.optim.optimizer import Optimizer, required
4 |
5 | class RAdam(Optimizer):
6 |
7 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=False):
8 | if not 0.0 <= lr:
9 | raise ValueError("Invalid learning rate: {}".format(lr))
10 | if not 0.0 <= eps:
11 | raise ValueError("Invalid epsilon value: {}".format(eps))
12 | if not 0.0 <= betas[0] < 1.0:
13 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
14 | if not 0.0 <= betas[1] < 1.0:
15 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
16 |
17 | self.degenerated_to_sgd = degenerated_to_sgd
18 | if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
19 | for param in params:
20 | if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]):
21 | param['buffer'] = [[None, None, None] for _ in range(10)]
22 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)])
23 | super(RAdam, self).__init__(params, defaults)
24 |
25 | def __setstate__(self, state):
26 | super(RAdam, self).__setstate__(state)
27 |
28 | def step(self, closure=None):
29 |
30 | loss = None
31 | if closure is not None:
32 | loss = closure()
33 |
34 | for group in self.param_groups:
35 |
36 | for p in group['params']:
37 | if p.grad is None:
38 | continue
39 | grad = p.grad.data.float()
40 | if grad.is_sparse:
41 | raise RuntimeError('RAdam does not support sparse gradients')
42 |
43 | p_data_fp32 = p.data.float()
44 |
45 | state = self.state[p]
46 |
47 | if len(state) == 0:
48 | state['step'] = 0
49 | state['exp_avg'] = torch.zeros_like(p_data_fp32)
50 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
51 | else:
52 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
53 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
54 |
55 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
56 | beta1, beta2 = group['betas']
57 |
58 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
59 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
60 |
61 | state['step'] += 1
62 | buffered = group['buffer'][int(state['step'] % 10)]
63 | if state['step'] == buffered[0]:
64 | N_sma, step_size = buffered[1], buffered[2]
65 | else:
66 | buffered[0] = state['step']
67 | beta2_t = beta2 ** state['step']
68 | N_sma_max = 2 / (1 - beta2) - 1
69 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
70 | buffered[1] = N_sma
71 |
72 | # more conservative since it's an approximated value
73 | if N_sma >= 5:
74 | step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
75 | elif self.degenerated_to_sgd:
76 | step_size = 1.0 / (1 - beta1 ** state['step'])
77 | else:
78 | step_size = -1
79 | buffered[2] = step_size
80 |
81 | # more conservative since it's an approximated value
82 | if N_sma >= 5:
83 | if group['weight_decay'] != 0:
84 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
85 | denom = exp_avg_sq.sqrt().add_(group['eps'])
86 | p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom)
87 | p.data.copy_(p_data_fp32)
88 | elif step_size > 0:
89 | if group['weight_decay'] != 0:
90 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
91 | p_data_fp32.add_(-step_size * group['lr'], exp_avg)
92 | p.data.copy_(p_data_fp32)
93 |
94 | return loss
95 |
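RAdam rectifies Adam's adaptive step while the second-moment estimate is still unreliable: when the approximated SMA length N_sma is below 5 it either skips the update entirely or, with degenerated_to_sgd=True, falls back to an SGD-style step. A minimal usage sketch on a toy module:

```python
import torch
import torch.nn as nn

model = nn.Linear(16, 1)
optimizer = RAdam(model.parameters(), lr=1e-3, weight_decay=1e-4)

x, y = torch.randn(8, 16), torch.randn(8, 1)
for _ in range(5):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()   # the first few steps may be skipped until the variance estimate warms up
```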
96 | class PlainRAdam(Optimizer):
97 |
98 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=False):
99 | if not 0.0 <= lr:
100 | raise ValueError("Invalid learning rate: {}".format(lr))
101 | if not 0.0 <= eps:
102 | raise ValueError("Invalid epsilon value: {}".format(eps))
103 | if not 0.0 <= betas[0] < 1.0:
104 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
105 | if not 0.0 <= betas[1] < 1.0:
106 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
107 |
108 | self.degenerated_to_sgd = degenerated_to_sgd
109 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
110 |
111 | super(PlainRAdam, self).__init__(params, defaults)
112 |
113 | def __setstate__(self, state):
114 | super(PlainRAdam, self).__setstate__(state)
115 |
116 | def step(self, closure=None):
117 |
118 | loss = None
119 | if closure is not None:
120 | loss = closure()
121 |
122 | for group in self.param_groups:
123 |
124 | for p in group['params']:
125 | if p.grad is None:
126 | continue
127 | grad = p.grad.data.float()
128 | if grad.is_sparse:
129 | raise RuntimeError('RAdam does not support sparse gradients')
130 |
131 | p_data_fp32 = p.data.float()
132 |
133 | state = self.state[p]
134 |
135 | if len(state) == 0:
136 | state['step'] = 0
137 | state['exp_avg'] = torch.zeros_like(p_data_fp32)
138 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
139 | else:
140 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
141 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
142 |
143 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
144 | beta1, beta2 = group['betas']
145 |
146 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
147 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
148 |
149 | state['step'] += 1
150 | beta2_t = beta2 ** state['step']
151 | N_sma_max = 2 / (1 - beta2) - 1
152 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
153 |
154 |
155 | # more conservative since it's an approximated value
156 | if N_sma >= 5:
157 | if group['weight_decay'] != 0:
158 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
159 | step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
160 | denom = exp_avg_sq.sqrt().add_(group['eps'])
161 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
162 | p.data.copy_(p_data_fp32)
163 | elif self.degenerated_to_sgd:
164 | if group['weight_decay'] != 0:
165 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
166 | step_size = group['lr'] / (1 - beta1 ** state['step'])
167 | p_data_fp32.add_(-step_size, exp_avg)
168 | p.data.copy_(p_data_fp32)
169 |
170 | return loss
171 |
172 |
173 | class AdamW(Optimizer):
174 |
175 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup = 0):
176 | if not 0.0 <= lr:
177 | raise ValueError("Invalid learning rate: {}".format(lr))
178 | if not 0.0 <= eps:
179 | raise ValueError("Invalid epsilon value: {}".format(eps))
180 | if not 0.0 <= betas[0] < 1.0:
181 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
182 | if not 0.0 <= betas[1] < 1.0:
183 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
184 |
185 | defaults = dict(lr=lr, betas=betas, eps=eps,
186 | weight_decay=weight_decay, warmup = warmup)
187 | super(AdamW, self).__init__(params, defaults)
188 |
189 | def __setstate__(self, state):
190 | super(AdamW, self).__setstate__(state)
191 |
192 | def step(self, closure=None):
193 | loss = None
194 | if closure is not None:
195 | loss = closure()
196 |
197 | for group in self.param_groups:
198 |
199 | for p in group['params']:
200 | if p.grad is None:
201 | continue
202 | grad = p.grad.data.float()
203 | if grad.is_sparse:
204 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
205 |
206 | p_data_fp32 = p.data.float()
207 |
208 | state = self.state[p]
209 |
210 | if len(state) == 0:
211 | state['step'] = 0
212 | state['exp_avg'] = torch.zeros_like(p_data_fp32)
213 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
214 | else:
215 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
216 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
217 |
218 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
219 | beta1, beta2 = group['betas']
220 |
221 | state['step'] += 1
222 |
223 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
224 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
225 |
226 | denom = exp_avg_sq.sqrt().add_(group['eps'])
227 | bias_correction1 = 1 - beta1 ** state['step']
228 | bias_correction2 = 1 - beta2 ** state['step']
229 |
230 | if group['warmup'] > state['step']:
231 | scheduled_lr = 1e-8 + state['step'] * group['lr'] / group['warmup']
232 | else:
233 | scheduled_lr = group['lr']
234 |
235 | step_size = scheduled_lr * math.sqrt(bias_correction2) / bias_correction1
236 |
237 | if group['weight_decay'] != 0:
238 | p_data_fp32.add_(-group['weight_decay'] * scheduled_lr, p_data_fp32)
239 |
240 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
241 |
242 | p.data.copy_(p_data_fp32)
243 |
244 | return loss
245 |
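This AdamW variant ramps the learning rate linearly from near zero to `lr` over the first `warmup` steps (scheduled_lr ≈ step * lr / warmup) and applies decoupled weight decay at the scheduled rate. A short sketch of the ramp on a single toy parameter:

```python
import torch
import torch.nn as nn

param = nn.Parameter(torch.randn(4))
optimizer = AdamW([param], lr=1e-3, weight_decay=1e-2, warmup=100)

for step in range(1, 6):
    optimizer.zero_grad()
    (param ** 2).sum().backward()
    optimizer.step()
    # effective lr after this call is roughly step * 1e-3 / 100, i.e. 1e-5, 2e-5, ...
```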
--------------------------------------------------------------------------------
/utils/tmp/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/utils/tmp/__init__.py
--------------------------------------------------------------------------------
/utils/tmp/augmentation.py:
--------------------------------------------------------------------------------
1 | import easing_functions as ef
2 | import random
3 | import torch
4 | from torchvision import transforms
5 | from torchvision.transforms import functional as F
6 |
7 |
8 | class MotionAugmentation:
9 | def __init__(self,
10 | size,
11 | prob_fgr_affine,
12 | prob_bgr_affine,
13 | prob_noise,
14 | prob_color_jitter,
15 | prob_grayscale,
16 | prob_sharpness,
17 | prob_blur,
18 | prob_hflip,
19 | prob_pause,
20 | static_affine=True,
21 | aspect_ratio_range=(0.9, 1.1)):
22 | self.size = size
23 | self.prob_fgr_affine = prob_fgr_affine
24 | self.prob_bgr_affine = prob_bgr_affine
25 | self.prob_noise = prob_noise
26 | self.prob_color_jitter = prob_color_jitter
27 | self.prob_grayscale = prob_grayscale
28 | self.prob_sharpness = prob_sharpness
29 | self.prob_blur = prob_blur
30 | self.prob_hflip = prob_hflip
31 | self.prob_pause = prob_pause
32 | self.static_affine = static_affine
33 | self.aspect_ratio_range = aspect_ratio_range
34 |
35 | def __call__(self, fgrs, phas, bgrs):
36 | # Foreground affine
37 | if random.random() < self.prob_fgr_affine:
38 | fgrs, phas = self._motion_affine(fgrs, phas)
39 |
40 | # Background affine
41 | if random.random() < self.prob_bgr_affine / 2:
42 | bgrs = self._motion_affine(bgrs)
43 | if random.random() < self.prob_bgr_affine / 2:
44 | fgrs, phas, bgrs = self._motion_affine(fgrs, phas, bgrs)
45 |
46 | # Still Affine
47 | if self.static_affine:
48 | fgrs, phas = self._static_affine(fgrs, phas, scale_ranges=(0.5, 1))
49 | bgrs = self._static_affine(bgrs, scale_ranges=(1, 1.5))
50 |
51 | # To tensor
52 | fgrs = torch.stack([F.to_tensor(fgr) for fgr in fgrs])
53 | phas = torch.stack([F.to_tensor(pha) for pha in phas])
54 | bgrs = torch.stack([F.to_tensor(bgr) for bgr in bgrs])
55 |
56 | # Resize
57 | params = transforms.RandomResizedCrop.get_params(fgrs, scale=(1, 1), ratio=self.aspect_ratio_range)
58 | fgrs = F.resized_crop(fgrs, *params, self.size, interpolation=F.InterpolationMode.BILINEAR)
59 | phas = F.resized_crop(phas, *params, self.size, interpolation=F.InterpolationMode.BILINEAR)
60 | params = transforms.RandomResizedCrop.get_params(bgrs, scale=(1, 1), ratio=self.aspect_ratio_range)
61 | bgrs = F.resized_crop(bgrs, *params, self.size, interpolation=F.InterpolationMode.BILINEAR)
62 |
63 | # Horizontal flip
64 | if random.random() < self.prob_hflip:
65 | fgrs = F.hflip(fgrs)
66 | phas = F.hflip(phas)
67 | if random.random() < self.prob_hflip:
68 | bgrs = F.hflip(bgrs)
69 |
70 | # Noise
71 | if random.random() < self.prob_noise:
72 | fgrs, bgrs = self._motion_noise(fgrs, bgrs)
73 |
74 | # Color jitter
75 | if random.random() < self.prob_color_jitter:
76 | fgrs = self._motion_color_jitter(fgrs)
77 | if random.random() < self.prob_color_jitter:
78 | bgrs = self._motion_color_jitter(bgrs)
79 |
80 | # Grayscale
81 | if random.random() < self.prob_grayscale:
82 | fgrs = F.rgb_to_grayscale(fgrs, num_output_channels=3).contiguous()
83 | bgrs = F.rgb_to_grayscale(bgrs, num_output_channels=3).contiguous()
84 |
85 | # Sharpen
86 | if random.random() < self.prob_sharpness:
87 | sharpness = random.random() * 8
88 | fgrs = F.adjust_sharpness(fgrs, sharpness)
89 | phas = F.adjust_sharpness(phas, sharpness)
90 | bgrs = F.adjust_sharpness(bgrs, sharpness)
91 |
92 | # Blur
93 | if random.random() < self.prob_blur / 3:
94 | fgrs, phas = self._motion_blur(fgrs, phas)
95 | if random.random() < self.prob_blur / 3:
96 | bgrs = self._motion_blur(bgrs)
97 | if random.random() < self.prob_blur / 3:
98 | fgrs, phas, bgrs = self._motion_blur(fgrs, phas, bgrs)
99 |
100 | # Pause
101 | if random.random() < self.prob_pause:
102 | fgrs, phas, bgrs = self._motion_pause(fgrs, phas, bgrs)
103 |
104 | return fgrs, phas, bgrs
105 |
106 | def _static_affine(self, *imgs, scale_ranges):
107 | params = transforms.RandomAffine.get_params(
108 | degrees=(-10, 10), translate=(0.1, 0.1), scale_ranges=scale_ranges,
109 | shears=(-5, 5), img_size=imgs[0][0].size)
110 | imgs = [[F.affine(t, *params, F.InterpolationMode.BILINEAR) for t in img] for img in imgs]
111 | return imgs if len(imgs) > 1 else imgs[0]
112 |
113 | def _motion_affine(self, *imgs):
114 | config = dict(degrees=(-10, 10), translate=(0.1, 0.1),
115 | scale_ranges=(0.9, 1.1), shears=(-5, 5), img_size=imgs[0][0].size)
116 | angleA, (transXA, transYA), scaleA, (shearXA, shearYA) = transforms.RandomAffine.get_params(**config)
117 | angleB, (transXB, transYB), scaleB, (shearXB, shearYB) = transforms.RandomAffine.get_params(**config)
118 |
119 | T = len(imgs[0])
120 | easing = random_easing_fn()
121 | for t in range(T):
122 | percentage = easing(t / (T - 1))
123 | angle = lerp(angleA, angleB, percentage)
124 | transX = lerp(transXA, transXB, percentage)
125 | transY = lerp(transYA, transYB, percentage)
126 | scale = lerp(scaleA, scaleB, percentage)
127 | shearX = lerp(shearXA, shearXB, percentage)
128 | shearY = lerp(shearYA, shearYB, percentage)
129 | for img in imgs:
130 | img[t] = F.affine(img[t], angle, (transX, transY), scale, (shearX, shearY), F.InterpolationMode.BILINEAR)
131 | return imgs if len(imgs) > 1 else imgs[0]
132 |
133 | def _motion_noise(self, *imgs):
134 | grain_size = random.random() * 3 + 1 # range 1 ~ 4
135 | monochrome = random.random() < 0.5
136 | for img in imgs:
137 | T, C, H, W = img.shape
138 | noise = torch.randn((T, 1 if monochrome else C, round(H / grain_size), round(W / grain_size)))
139 | noise.mul_(random.random() * 0.2 / grain_size)
140 | if grain_size != 1:
141 | noise = F.resize(noise, (H, W))
142 | img.add_(noise).clamp_(0, 1)
143 | return imgs if len(imgs) > 1 else imgs[0]
144 |
145 | def _motion_color_jitter(self, *imgs):
146 | brightnessA, brightnessB, contrastA, contrastB, saturationA, saturationB, hueA, hueB \
147 | = torch.randn(8).mul(0.1).tolist()
148 | strength = random.random() * 0.2
149 | easing = random_easing_fn()
150 | T = len(imgs[0])
151 | for t in range(T):
152 | percentage = easing(t / (T - 1)) * strength
153 | for img in imgs:
154 | img[t] = F.adjust_brightness(img[t], max(1 + lerp(brightnessA, brightnessB, percentage), 0.1))
155 | img[t] = F.adjust_contrast(img[t], max(1 + lerp(contrastA, contrastB, percentage), 0.1))
156 |                 img[t] = F.adjust_saturation(img[t], max(1 + lerp(saturationA, saturationB, percentage), 0.1))
157 | img[t] = F.adjust_hue(img[t], min(0.5, max(-0.5, lerp(hueA, hueB, percentage) * 0.1)))
158 | return imgs if len(imgs) > 1 else imgs[0]
159 |
160 | def _motion_blur(self, *imgs):
161 | blurA = random.random() * 10
162 | blurB = random.random() * 10
163 |
164 | T = len(imgs[0])
165 | easing = random_easing_fn()
166 | for t in range(T):
167 | percentage = easing(t / (T - 1))
168 | blur = max(lerp(blurA, blurB, percentage), 0)
169 | if blur != 0:
170 | kernel_size = int(blur * 2)
171 | if kernel_size % 2 == 0:
172 | kernel_size += 1 # Make kernel_size odd
173 | for img in imgs:
174 | img[t] = F.gaussian_blur(img[t], kernel_size, sigma=blur)
175 |
176 | return imgs if len(imgs) > 1 else imgs[0]
177 |
178 | def _motion_pause(self, *imgs):
179 | T = len(imgs[0])
180 | pause_frame = random.choice(range(T - 1))
181 | pause_length = random.choice(range(T - pause_frame))
182 | for img in imgs:
183 | img[pause_frame + 1 : pause_frame + pause_length] = img[pause_frame]
184 | return imgs if len(imgs) > 1 else imgs[0]
185 |
186 |
187 | def lerp(a, b, percentage):
188 | return a * (1 - percentage) + b * percentage
189 |
190 |
191 | def random_easing_fn():
192 | if random.random() < 0.2:
193 | return ef.LinearInOut()
194 | else:
195 | return random.choice([
196 | ef.BackEaseIn,
197 | ef.BackEaseOut,
198 | ef.BackEaseInOut,
199 | ef.BounceEaseIn,
200 | ef.BounceEaseOut,
201 | ef.BounceEaseInOut,
202 | ef.CircularEaseIn,
203 | ef.CircularEaseOut,
204 | ef.CircularEaseInOut,
205 | ef.CubicEaseIn,
206 | ef.CubicEaseOut,
207 | ef.CubicEaseInOut,
208 | ef.ExponentialEaseIn,
209 | ef.ExponentialEaseOut,
210 | ef.ExponentialEaseInOut,
211 | ef.ElasticEaseIn,
212 | ef.ElasticEaseOut,
213 | ef.ElasticEaseInOut,
214 | ef.QuadEaseIn,
215 | ef.QuadEaseOut,
216 | ef.QuadEaseInOut,
217 | ef.QuarticEaseIn,
218 | ef.QuarticEaseOut,
219 | ef.QuarticEaseInOut,
220 | ef.QuinticEaseIn,
221 | ef.QuinticEaseOut,
222 | ef.QuinticEaseInOut,
223 | ef.SineEaseIn,
224 | ef.SineEaseOut,
225 | ef.SineEaseInOut,
226 | Step,
227 | ])()
228 |
229 | class Step: # Custom easing function for sudden change.
230 | def __call__(self, value):
231 | return 0 if value < 0.5 else 1
232 |
233 |
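The motion augmentations above sample two parameter endpoints (A and B) and interpolate between them per frame with `lerp`, following a randomly chosen easing curve so the simulated motion is not always linear; `Step` produces an abrupt change halfway through the clip. A tiny sketch of that pattern, assuming the `easing_functions` package is installed:

```python
# Interpolate one parameter (e.g. a rotation angle) across T frames.
T = 8
angleA, angleB = -10.0, 10.0
easing = random_easing_fn()          # LinearInOut, one of the ef.* curves, or Step
trajectory = [lerp(angleA, angleB, easing(t / (T - 1))) for t in range(T)]
print(trajectory)                    # per-frame angles following the sampled easing curve
```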
234 | # ---------------------------- Frame Sampler ----------------------------
235 |
236 |
237 | class TrainFrameSampler:
238 | def __init__(self, speed=[0.5, 1, 2, 3, 4, 5]):
239 | self.speed = speed
240 |
241 | def __call__(self, seq_length):
242 | frames = list(range(seq_length))
243 |
244 | # Speed up
245 | speed = random.choice(self.speed)
246 | frames = [int(f * speed) for f in frames]
247 |
248 | # Shift
249 | shift = random.choice(range(seq_length))
250 | frames = [f + shift for f in frames]
251 |
252 | # Reverse
253 | if random.random() < 0.5:
254 | frames = frames[::-1]
255 |
256 | return frames
257 |
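TrainFrameSampler converts a requested clip length into source frame indices by scaling with a random playback speed, shifting by a random offset, and optionally reversing, which lets the same clip imitate different motion speeds and directions. A quick illustration (the output is random; the resulting indices can exceed the clip length, so the caller presumably clamps or wraps them):

```python
sampler = TrainFrameSampler(speed=[0.5, 1, 2])
print(sampler(4))   # e.g. [3, 5, 7, 9] for speed 2 and shift 3, possibly reversed
```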
258 | class ValidFrameSampler:
259 | def __call__(self, seq_length):
260 | return range(seq_length)
--------------------------------------------------------------------------------
/utils/tmp/closed_form_matting/.travis.yml:
--------------------------------------------------------------------------------
1 | language: python
2 | python:
3 | - "3.5"
4 | - "3.6"
5 | install:
6 | - pip install -r requirements.txt
7 | script:
8 | - pytest
9 |
--------------------------------------------------------------------------------
/utils/tmp/closed_form_matting/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Marco Forte
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/utils/tmp/closed_form_matting/README.md:
--------------------------------------------------------------------------------
1 | # Closed-Form Matting
2 | [](https://travis-ci.org/MarcoForte/closed-form-matting)
3 |
4 |
5 | Python implementation of image matting method proposed in A. Levin D. Lischinski and Y. Weiss. A Closed Form Solution to Natural Image Matting. IEEE Conf. on Computer Vision and Pattern Recognition (CVPR), June 2006, New York
6 |
7 | The repository also contains implementation of background/foreground reconstruction method proposed in Levin, Anat, Dani Lischinski, and Yair Weiss. "A closed-form solution to natural image matting." IEEE Transactions on Pattern Analysis and Machine Intelligence 30.2 (2008): 228-242.
8 |
9 | ## Requirements
10 | - python 3.5+ (Though it should run on 2.7)
11 | - scipy
12 | - numpy
13 | - opencv-python
14 |
15 | ## Installation
16 |
17 | Clone this repository and install the closed-form-matting package via pip.
18 |
19 | ```bash
20 | git clone https://github.com/MarcoForte/closed-form-matting.git
21 | cd closed-form-matting/
22 | pip install .
23 | ```
24 |
25 | ## Usage
26 |
27 | ### Closed-Form matting
28 | CLI interface:
29 |
30 | ```bash
31 | # Scribbles input
32 | closed-form-matting ./testdata/source.png -s ./testdata/scribbles.png -o output_alpha.png
33 |
34 | # Trimap input
35 | closed-form-matting ./testdata/source.png -t ./testdata/trimap.png -o output_alpha.png
36 |
37 | # Add flag --solve-fg to compute foreground color and output RGBA image instead
38 | # of alpha.
39 | ```
40 |
41 |
42 | Python interface:
43 |
44 | ```python
45 | import closed_form_matting
46 | ...
47 | # For scribbles input
48 | alpha = closed_form_matting.closed_form_matting_with_scribbles(image, scribbles)
49 |
50 | # For trimap input
51 | alpha = closed_form_matting.closed_form_matting_with_trimap(image, trimap)
52 |
53 | # For prior with confidence
54 | alpha = closed_form_matting.closed_form_matting_with_prior(
55 | image, prior, prior_confidence, optional_const_mask)
56 |
57 | # To get Matting Laplacian for image
58 | laplacian = closed_form_matting.compute_laplacian(image, optional_const_mask)
59 | ```
60 |
61 | ### Foreground and Background Reconstruction
62 | CLI interface (requires opencv-python):
63 |
64 | ```bash
65 | solve-foreground-background image.png alpha.png foreground.png background.png
66 | ```
67 |
68 | Python interface:
69 |
70 | ```python
71 | from closed_form_matting import solve_foreground_background
72 | ...
73 | foreground, background = solve_foreground_background(image, alpha)
74 | ```
75 |
76 | ## Results
77 | | Original image | Scribbled image | Output alpha | Output foreground |
78 | |------------------|-----------------|--------------|-------------------|
79 | |  |  |  |  |
80 |
81 |
82 | ## More Information
83 | The computation is generally faster than the matlab version thanks to more vectorization.
84 | Note: the computed laplacian is slightly different because array ordering in numpy differs from matlab. To get the same laplacian as in matlab, change
85 | 
86 | `indsM = np.arange(h*w).reshape((h, w))`
87 | `ravelImg = img.reshape(h*w, d)`
88 | to
89 | `indsM = np.arange(h*w).reshape((h, w), order='F')`
90 | `ravelImg = img.reshape(h*w, d, order='F')`.
91 | Again note that this will result in incorrect alpha if the `D_s, b_s` orderings are not also changed to `order='F'`.
92 |
93 | For more information see the original paper http://www.wisdom.weizmann.ac.il/~levina/papers/Matting-Levin-Lischinski-Weiss-CVPR06.pdf
94 | The original matlab code is here http://www.wisdom.weizmann.ac.il/~levina/matting.tar.gz
95 |
96 | ## Disclaimer
97 |
98 | The code is free for academic/research purpose. Use at your own risk and we are not responsible for any loss resulting from this code. Feel free to submit pull request for bug fixes.
99 |
--------------------------------------------------------------------------------
/utils/tmp/closed_form_matting/closed_form_matting/__init__.py:
--------------------------------------------------------------------------------
1 | # #!/usr/bin/env python
2 | # """Init script when importing closed-form-matting package"""
3 |
4 | # from closed_form_matting.closed_form_matting import (
5 | # compute_laplacian,
6 | # closed_form_matting_with_prior,
7 | # closed_form_matting_with_trimap,
8 | # closed_form_matting_with_scribbles,
9 | # )
10 | # from closed_form_matting.solve_foreground_background import (
11 | # solve_foreground_background
12 | # )
13 |
14 | # __version__ = '1.0.0'
15 | # __all__ = [
16 | # 'compute_laplacian',
17 | # 'closed_form_matting_with_prior',
18 | # 'closed_form_matting_with_trimap',
19 | # 'closed_form_matting_with_scribbles',
20 | # 'solve_foreground_background',
21 | # ]
22 |
--------------------------------------------------------------------------------
/utils/tmp/closed_form_matting/closed_form_matting/closed_form_matting.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """Implementation of Closed-Form Matting.
3 |
4 | This module implements natural image matting method described in:
5 | Levin, Anat, Dani Lischinski, and Yair Weiss. "A closed-form solution to natural image matting."
6 | IEEE Transactions on Pattern Analysis and Machine Intelligence 30.2 (2008): 228-242.
7 |
8 | The code can be used in two ways:
9 | 1. By importing solve_foreground_background in your code:
10 | ```
11 | import closed_form_matting
12 | ...
13 | # For scribbles input
14 | alpha = closed_form_matting.closed_form_matting_with_scribbles(image, scribbles)
15 |
16 | # For trimap input
17 | alpha = closed_form_matting.closed_form_matting_with_trimap(image, trimap)
18 |
19 | # For prior with confidence
20 | alpha = closed_form_matting.closed_form_matting_with_prior(
21 | image, prior, prior_confidence, optional_const_mask)
22 |
23 | # To get Matting Laplacian for image
24 | laplacian = compute_laplacian(image, optional_const_mask)
25 | ```
26 | 2. From command line:
27 | ```
28 | # Scribbles input
29 | ./closed_form_matting.py input_image.png -s scribbles_image.png -o output_alpha.png
30 |
31 | # Trimap input
32 | ./closed_form_matting.py input_image.png -t scribbles_image.png -o output_alpha.png
33 |
34 | # Add flag --solve-fg to compute foreground color and output RGBA image instead
35 | # of alpha.
36 | ```
37 | """
38 |
39 | from __future__ import division
40 |
41 | import logging
42 |
43 | import cv2
44 | import numpy as np
45 | from numpy.lib.stride_tricks import as_strided
46 | import scipy.sparse
47 | import scipy.sparse.linalg
48 |
49 |
50 | def _rolling_block(A, block=(3, 3)):
51 | """Applies sliding window to given matrix."""
52 | shape = (A.shape[0] - block[0] + 1, A.shape[1] - block[1] + 1) + block
53 | strides = (A.strides[0], A.strides[1]) + A.strides
54 | return as_strided(A, shape=shape, strides=strides)
55 |
56 |
57 | def compute_laplacian(img, mask=None, eps=10**(-7), win_rad=1):
58 | """Computes Matting Laplacian for a given image.
59 |
60 | Args:
61 | img: 3-dim numpy matrix with input image
62 | mask: mask of pixels for which Laplacian will be computed.
63 | If not set Laplacian will be computed for all pixels.
64 | eps: regularization parameter controlling alpha smoothness
65 | from Eq. 12 of the original paper. Defaults to 1e-7.
66 | win_rad: radius of window used to build Matting Laplacian (i.e.
67 | radius of omega_k in Eq. 12).
68 | Returns: sparse matrix holding Matting Laplacian.
69 | """
70 |
71 | win_size = (win_rad * 2 + 1) ** 2
72 | h, w, d = img.shape
73 | # Number of window centre indices in h, w axes
74 | c_h, c_w = h - 2 * win_rad, w - 2 * win_rad
75 | win_diam = win_rad * 2 + 1
76 |
77 | indsM = np.arange(h * w).reshape((h, w))
78 | ravelImg = img.reshape(h * w, d)
79 | win_inds = _rolling_block(indsM, block=(win_diam, win_diam))
80 |
81 | win_inds = win_inds.reshape(c_h, c_w, win_size)
82 | if mask is not None:
83 | mask = cv2.dilate(
84 | mask.astype(np.uint8),
85 | np.ones((win_diam, win_diam), np.uint8)
86 | ).astype(np.bool)
87 | win_mask = np.sum(mask.ravel()[win_inds], axis=2)
88 | win_inds = win_inds[win_mask > 0, :]
89 | else:
90 | win_inds = win_inds.reshape(-1, win_size)
91 |
92 |
93 | winI = ravelImg[win_inds]
94 |
95 | win_mu = np.mean(winI, axis=1, keepdims=True)
96 | win_var = np.einsum('...ji,...jk ->...ik', winI, winI) / win_size - np.einsum('...ji,...jk ->...ik', win_mu, win_mu)
97 |
98 | inv = np.linalg.inv(win_var + (eps/win_size)*np.eye(3))
99 |
100 | X = np.einsum('...ij,...jk->...ik', winI - win_mu, inv)
101 | vals = np.eye(win_size) - (1.0/win_size)*(1 + np.einsum('...ij,...kj->...ik', X, winI - win_mu))
102 |
103 | nz_indsCol = np.tile(win_inds, win_size).ravel()
104 | nz_indsRow = np.repeat(win_inds, win_size).ravel()
105 | nz_indsVal = vals.ravel()
106 | L = scipy.sparse.coo_matrix((nz_indsVal, (nz_indsRow, nz_indsCol)), shape=(h*w, h*w))
107 | return L
108 |
109 |
110 | def closed_form_matting_with_prior(image, prior, prior_confidence, consts_map=None):
111 | """Applies closed form matting with prior alpha map to image.
112 |
113 | Args:
114 | image: 3-dim numpy matrix with input image.
115 | prior: matrix of same width and height as input image holding apriori alpha map.
116 |         prior_confidence: matrix of the same shape as prior holding confidence of prior alpha.
117 | consts_map: binary mask of pixels that aren't expected to change due to high
118 | prior confidence.
119 |
120 | Returns: 2-dim matrix holding computed alpha map.
121 | """
122 |
123 | assert image.shape[:2] == prior.shape, ('prior must be 2D matrix with height and width equal '
124 | 'to image.')
125 | assert image.shape[:2] == prior_confidence.shape, ('prior_confidence must be 2D matrix with '
126 | 'height and width equal to image.')
127 | assert (consts_map is None) or image.shape[:2] == consts_map.shape, (
128 | 'consts_map must be 2D matrix with height and width equal to image.')
129 |
130 | logging.info('Computing Matting Laplacian.')
131 | laplacian = compute_laplacian(image, ~consts_map if consts_map is not None else None)
132 |
133 | confidence = scipy.sparse.diags(prior_confidence.ravel())
134 | logging.info('Solving for alpha.')
135 | solution = scipy.sparse.linalg.spsolve(
136 | laplacian + confidence,
137 | prior.ravel() * prior_confidence.ravel()
138 | )
139 | alpha = np.minimum(np.maximum(solution.reshape(prior.shape), 0), 1)
140 | return alpha
141 |
142 |
143 | def closed_form_matting_with_trimap(image, trimap, trimap_confidence=100.0):
144 | """Apply Closed-Form matting to given image using trimap."""
145 |
146 | assert image.shape[:2] == trimap.shape, ('trimap must be 2D matrix with height and width equal '
147 | 'to image.')
148 | consts_map = (trimap < 0.1) | (trimap > 0.9)
149 | return closed_form_matting_with_prior(image, trimap, trimap_confidence * consts_map, consts_map)
150 |
151 |
152 | def closed_form_matting_with_scribbles(image, scribbles, scribbles_confidence=100.0):
153 | """Apply Closed-Form matting to given image using scribbles image."""
154 |
155 | assert image.shape == scribbles.shape, 'scribbles must have exactly same shape as image.'
156 | prior = np.sign(np.sum(scribbles - image, axis=2)) / 2 + 0.5
157 | consts_map = prior != 0.5
158 | return closed_form_matting_with_prior(
159 | image,
160 | prior,
161 | scribbles_confidence * consts_map,
162 | consts_map
163 | )
164 |
165 |
166 | closed_form_matting = closed_form_matting_with_trimap
167 |
168 | def main():
169 | import argparse
170 |
171 | logging.basicConfig(level=logging.INFO)
172 | arg_parser = argparse.ArgumentParser(description=__doc__)
173 | arg_parser.add_argument('image', type=str, help='input image')
174 |
175 | arg_parser.add_argument('-t', '--trimap', type=str, help='input trimap')
176 | arg_parser.add_argument('-s', '--scribbles', type=str, help='input scribbles')
177 | arg_parser.add_argument('-o', '--output', type=str, required=True, help='output image')
178 | arg_parser.add_argument(
179 | '--solve-fg', dest='solve_fg', action='store_true',
180 | help='compute foreground color and output RGBA image'
181 | )
182 | args = arg_parser.parse_args()
183 |
184 | image = cv2.imread(args.image, cv2.IMREAD_COLOR) / 255.0
185 |
186 | if args.scribbles:
187 | scribbles = cv2.imread(args.scribbles, cv2.IMREAD_COLOR) / 255.0
188 | alpha = closed_form_matting_with_scribbles(image, scribbles)
189 | elif args.trimap:
190 | trimap = cv2.imread(args.trimap, cv2.IMREAD_GRAYSCALE) / 255.0
191 | alpha = closed_form_matting_with_trimap(image, trimap)
192 | else:
193 | logging.error('Either trimap or scribbles must be specified.')
194 | arg_parser.print_help()
195 | exit(-1)
196 |
197 | if args.solve_fg:
198 | from closed_form_matting.solve_foreground_background import solve_foreground_background
199 | foreground, _ = solve_foreground_background(image, alpha)
200 | output = np.concatenate((foreground, alpha[:, :, np.newaxis]), axis=2)
201 | else:
202 | output = alpha
203 |
204 | cv2.imwrite(args.output, output * 255.0)
205 |
206 |
207 | if __name__ == "__main__":
208 | main()
209 |
--------------------------------------------------------------------------------
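For orientation, a minimal sketch of driving the module above from Python rather than from the CLI, mirroring the calls in `test_matting.py`; the file names are placeholders and the snippet assumes the `closed_form_matting` package is importable (e.g. run from its parent directory or after `pip install`):

```
import cv2
import numpy as np
from closed_form_matting import closed_form_matting_with_trimap

# Placeholder paths; image and trimap must share the same height and width.
image = cv2.imread('image.png', cv2.IMREAD_COLOR) / 255.0        # H x W x 3, floats in [0, 1]
trimap = cv2.imread('trimap.png', cv2.IMREAD_GRAYSCALE) / 255.0  # H x W, 0 = bg, 1 = fg, else unknown

alpha = closed_form_matting_with_trimap(image, trimap)           # H x W alpha map in [0, 1]
cv2.imwrite('alpha.png', (alpha * 255.0).astype(np.uint8))
```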
/utils/tmp/closed_form_matting/closed_form_matting/solve_foreground_background.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """Computes foreground and background images given source image and transparency map.
3 |
4 | This module implements foreground and background reconstruction method described in Section 7 of:
5 | Levin, Anat, Dani Lischinski, and Yair Weiss. "A closed-form solution to natural image
6 | matting." IEEE Transactions on Pattern Analysis and Machine Intelligence 30.2 (2008): 228-242.
7 |
8 | Please note that the cost function optimized by this code doesn't perfectly match Eq. 19 of the
9 | paper, since our implementation mimics the `solveFB.m` Matlab function provided by the authors of
10 | the original paper (this implementation is
11 | available at http://people.csail.mit.edu/alevin/matting.tar.gz).
12 |
13 | The code can be used in two ways:
14 |     1. By importing solve_foreground_background in your code:
15 |         ```
16 |         from solve_foreground_background import solve_foreground_background
17 |         ...
18 |         foreground, background = solve_foreground_background(image, alpha)
19 |         ```
20 |     2. From command line:
21 |         ```
22 |         ./solve_foreground_background.py image.png alpha.png foreground.png background.png
23 |         ```
24 |
25 | Authors: Mikhail Erofeev, Yury Gitman.
26 | """
27 |
28 | import numpy as np
29 | import scipy.sparse
30 | import scipy.sparse.linalg
31 |
32 | # CONST_ALPHA_MARGIN = 0.02
33 | CONST_ALPHA_MARGIN = 0.
34 |
35 |
36 | def __spdiagonal(diag):
37 | """Produces sparse matrix with given vector on its main diagonal."""
38 | return scipy.sparse.spdiags(diag, (0,), len(diag), len(diag))
39 |
40 |
41 | def get_grad_operator(mask):
42 | """Returns sparse matrix computing horizontal, vertical, and two diagonal gradients."""
43 | horizontal_left = np.ravel_multi_index(np.nonzero(mask[:, :-1] | mask[:, 1:]), mask.shape)
44 | horizontal_right = horizontal_left + 1
45 |
46 | vertical_top = np.ravel_multi_index(np.nonzero(mask[:-1, :] | mask[1:, :]), mask.shape)
47 | vertical_bottom = vertical_top + mask.shape[1]
48 |
49 | diag_main_1 = np.ravel_multi_index(np.nonzero(mask[:-1, :-1] | mask[1:, 1:]), mask.shape)
50 | diag_main_2 = diag_main_1 + mask.shape[1] + 1
51 |
52 | diag_sub_1 = np.ravel_multi_index(np.nonzero(mask[:-1, 1:] | mask[1:, :-1]), mask.shape) + 1
53 | diag_sub_2 = diag_sub_1 + mask.shape[1] - 1
54 |
55 | indices = np.stack((
56 | np.concatenate((horizontal_left, vertical_top, diag_main_1, diag_sub_1)),
57 | np.concatenate((horizontal_right, vertical_bottom, diag_main_2, diag_sub_2))
58 | ), axis=-1)
59 | return scipy.sparse.coo_matrix(
60 | (np.tile([-1, 1], len(indices)), (np.arange(indices.size) // 2, indices.flatten())),
61 | shape=(len(indices), mask.size))
62 |
63 |
64 | def get_const_conditions(image, alpha):
65 | """Returns sparse diagonal matrix and vector encoding color prior conditions."""
66 | falpha = alpha.flatten()
67 | weights = (
68 | (falpha < CONST_ALPHA_MARGIN) * 100.0 +
69 | 0.03 * (1.0 - falpha) * (falpha < 0.3) +
70 | 0.01 * (falpha > 1.0 - CONST_ALPHA_MARGIN)
71 | )
72 | conditions = __spdiagonal(weights)
73 |
74 | mask = falpha < 1.0 - CONST_ALPHA_MARGIN
75 | right_hand = (weights * mask)[:, np.newaxis] * image.reshape((alpha.size, -1))
76 | return conditions, right_hand
77 |
78 |
79 | def solve_foreground_background(image, alpha):
80 | """Compute foreground and background image given source image and transparency map."""
81 |
82 | consts = (alpha < CONST_ALPHA_MARGIN) | (alpha > 1.0 - CONST_ALPHA_MARGIN)
83 | grad = get_grad_operator(~consts)
84 | grad_weights = np.power(np.abs(grad * alpha.flatten()), 0.5)
85 |
86 | grad_only_positive = grad.maximum(0)
87 | grad_weights_f = grad_weights + 0.003 * grad_only_positive * (1.0 - alpha.flatten())
88 | grad_weights_b = grad_weights + 0.003 * grad_only_positive * alpha.flatten()
89 |
90 | grad_pad = scipy.sparse.coo_matrix(grad.shape)
91 |
92 | smoothness_conditions = scipy.sparse.vstack((
93 | scipy.sparse.hstack((__spdiagonal(grad_weights_f) * grad, grad_pad)),
94 | scipy.sparse.hstack((grad_pad, __spdiagonal(grad_weights_b) * grad))
95 | ))
96 |
97 | composite_conditions = scipy.sparse.hstack((
98 | __spdiagonal(alpha.flatten()),
99 | __spdiagonal(1.0 - alpha.flatten())
100 | ))
101 |
102 | const_conditions_f, b_const_f = get_const_conditions(image, 1.0 - alpha)
103 | const_conditions_b, b_const_b = get_const_conditions(image, alpha)
104 |
105 | non_zero_conditions = scipy.sparse.vstack((
106 | composite_conditions,
107 | scipy.sparse.hstack((
108 | const_conditions_f,
109 | scipy.sparse.coo_matrix(const_conditions_f.shape)
110 | )),
111 | scipy.sparse.hstack((
112 | scipy.sparse.coo_matrix(const_conditions_b.shape),
113 | const_conditions_b
114 | ))
115 | ))
116 |
117 | b_composite = image.reshape(alpha.size, -1)
118 |
119 | right_hand = non_zero_conditions.transpose() * np.concatenate((b_composite,
120 | b_const_f,
121 | b_const_b))
122 |
123 |     conditions = scipy.sparse.vstack((
124 | non_zero_conditions,
125 | smoothness_conditions
126 | ))
127 |     left_hand = conditions.transpose() * conditions
128 |
129 | solution = scipy.sparse.linalg.spsolve(left_hand, right_hand).reshape(2, *image.shape)
130 | foreground = solution[0, :, :, :].reshape(*image.shape)
131 | background = solution[1, :, :, :].reshape(*image.shape)
132 | return foreground, background
133 |
134 |
135 | def main():
136 |     """Parse command line arguments and apply solve_foreground_background."""
137 |
138 | import argparse
139 | import cv2
140 | arg_parser = argparse.ArgumentParser(description=__doc__)
141 | arg_parser.add_argument('image', type=str)
142 | arg_parser.add_argument('alpha', type=str)
143 | arg_parser.add_argument('foreground', type=str)
144 | arg_parser.add_argument('background', type=str, default=None, nargs='?')
145 | args = arg_parser.parse_args()
146 |
147 | image = cv2.imread(args.image) / 255.0
148 | alpha = cv2.imread(args.alpha, 0) / 255.0
149 | foreground, background = solve_foreground_background(image, alpha)
150 | cv2.imwrite(args.foreground, foreground * 255.0)
151 | if args.background:
152 | cv2.imwrite(args.background, background * 255.0)
153 |
154 |
155 | if __name__ == "__main__":
156 | main()
157 |
--------------------------------------------------------------------------------
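As a usage note, a hedged sketch of compositing the recovered foreground onto a new backdrop once `solve_foreground_background` has been run; the alpha comes from one of the matting entry points above, and `new_background.png` is a placeholder that must match the source image's size:

```
import cv2
import numpy as np
from closed_form_matting import closed_form_matting_with_trimap, solve_foreground_background

image = cv2.imread('image.png', cv2.IMREAD_COLOR) / 255.0
trimap = cv2.imread('trimap.png', cv2.IMREAD_GRAYSCALE) / 255.0
new_bg = cv2.imread('new_background.png', cv2.IMREAD_COLOR) / 255.0  # same H x W as image

alpha = closed_form_matting_with_trimap(image, trimap)
foreground, _ = solve_foreground_background(image, alpha)

# Standard compositing: C = alpha * F + (1 - alpha) * B_new
a = alpha[:, :, np.newaxis]
composite = a * foreground + (1.0 - a) * new_bg
cv2.imwrite('composite.png', (composite * 255.0).astype(np.uint8))
```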
/utils/tmp/closed_form_matting/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | scipy
3 | opencv_python
4 |
--------------------------------------------------------------------------------
/utils/tmp/closed_form_matting/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """Setting up closed-form-matting package during pip installation"""
3 |
4 | import os
5 | import re
6 |
7 | import setuptools
8 |
9 | # Project root directory
10 | root_dir = os.path.dirname(__file__)
11 |
12 | # Get version string from __init__.py in the package
13 | with open(os.path.join(root_dir, 'closed_form_matting', '__init__.py')) as f:
14 | version = re.search(r'__version__ = \'(.*?)\'', f.read()).group(1)
15 |
16 | # Get dependency list from requirements.txt
17 | with open(os.path.join(root_dir, 'requirements.txt')) as f:
18 | requirements = f.read().split()
19 |
20 | setuptools.setup(
21 | name='closed-form-matting',
22 | version=version,
23 | author='Marco Forte',
24 | author_email='fortemarco.irl@gmail.com',
25 | maintainer='Marco Forte',
26 | maintainer_email='fortemarco.irl@gmail.com',
27 | url='https://github.com/MarcoForte/closed-form-matting',
28 | description='A closed-form solution to natural image matting',
29 | long_description=open(os.path.join(root_dir, 'README.md')).read(),
30 | long_description_content_type='text/markdown',
31 | packages=setuptools.find_packages(),
32 | classifiers=[
33 | 'Development Status :: 4 - Beta',
34 | 'Intended Audience :: Science/Research',
35 | 'Topic :: Scientific/Engineering :: Image Processing',
36 | 'License :: OSI Approved :: MIT License',
37 | 'Programming Language :: Python :: 3',
38 | ],
39 | keywords=['closed-form matting', 'image matting', 'image processing'],
40 | license='MIT',
41 | python_requires='>=3.5',
42 | install_requires=requirements,
43 | entry_points={
44 | 'console_scripts': [
45 | 'closed-form-matting=closed_form_matting.closed_form_matting:main',
46 | 'solve-foreground-background=closed_form_matting.solve_foreground_background:main',
47 | ],
48 | },
49 | )
50 |
--------------------------------------------------------------------------------
/utils/tmp/closed_form_matting/test_matting.py:
--------------------------------------------------------------------------------
1 | """Tests for Closed-Form matting and foreground/background solver."""
2 | import unittest
3 |
4 | import cv2
5 | import numpy as np
6 |
7 | import closed_form_matting
8 |
9 | class TestMatting(unittest.TestCase):
10 | def test_solution_close_to_original_implementation(self):
11 | image = cv2.imread('testdata/source.png', cv2.IMREAD_COLOR) / 255.0
12 |         scribbles = cv2.imread('testdata/scribbles.png', cv2.IMREAD_COLOR) / 255.0
13 |
14 |         alpha = closed_form_matting.closed_form_matting_with_scribbles(image, scribbles)
15 | foreground, background = closed_form_matting.solve_foreground_background(image, alpha)
16 |
17 | matlab_alpha = cv2.imread('testdata/matlab_alpha.png', cv2.IMREAD_GRAYSCALE) / 255.0
18 | matlab_foreground = cv2.imread('testdata/matlab_foreground.png', cv2.IMREAD_COLOR) / 255.0
19 | matlab_background = cv2.imread('testdata/matlab_background.png', cv2.IMREAD_COLOR) / 255.0
20 |
21 | sad_alpha = np.mean(np.abs(alpha - matlab_alpha))
22 | sad_foreground = np.mean(np.abs(foreground - matlab_foreground))
23 | sad_background = np.mean(np.abs(background - matlab_background))
24 |
25 | self.assertLess(sad_alpha, 1e-2)
26 | self.assertLess(sad_foreground, 1e-2)
27 | self.assertLess(sad_background, 1e-2)
28 |
29 | def test_matting_with_trimap(self):
30 | image = cv2.imread('testdata/source.png', cv2.IMREAD_COLOR) / 255.0
31 | trimap = cv2.imread('testdata/trimap.png', cv2.IMREAD_GRAYSCALE) / 255.0
32 |
33 | alpha = closed_form_matting.closed_form_matting_with_trimap(image, trimap)
34 |
35 | reference_alpha = cv2.imread('testdata/output_alpha.png', cv2.IMREAD_GRAYSCALE) / 255.0
36 |
37 | sad_alpha = np.mean(np.abs(alpha - reference_alpha))
38 | self.assertLess(sad_alpha, 1e-3)
39 |
--------------------------------------------------------------------------------
/utils/tmp/closed_form_matting/testdata/matlab_alpha.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/utils/tmp/closed_form_matting/testdata/matlab_alpha.png
--------------------------------------------------------------------------------
/utils/tmp/closed_form_matting/testdata/matlab_background.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/utils/tmp/closed_form_matting/testdata/matlab_background.png
--------------------------------------------------------------------------------
/utils/tmp/closed_form_matting/testdata/matlab_foreground.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/utils/tmp/closed_form_matting/testdata/matlab_foreground.png
--------------------------------------------------------------------------------
/utils/tmp/closed_form_matting/testdata/output_alpha.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/utils/tmp/closed_form_matting/testdata/output_alpha.png
--------------------------------------------------------------------------------
/utils/tmp/closed_form_matting/testdata/output_background.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/utils/tmp/closed_form_matting/testdata/output_background.png
--------------------------------------------------------------------------------
/utils/tmp/closed_form_matting/testdata/output_foreground.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/utils/tmp/closed_form_matting/testdata/output_foreground.png
--------------------------------------------------------------------------------
/utils/tmp/closed_form_matting/testdata/scribbles.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/utils/tmp/closed_form_matting/testdata/scribbles.png
--------------------------------------------------------------------------------
/utils/tmp/closed_form_matting/testdata/source.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/utils/tmp/closed_form_matting/testdata/source.png
--------------------------------------------------------------------------------
/utils/tmp/closed_form_matting/testdata/trimap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hongje/OTVM/accc0fcc5b25c1ea4a343e094e3c33f2a3aade72/utils/tmp/closed_form_matting/testdata/trimap.png
--------------------------------------------------------------------------------
/utils/tmp/group_weight.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from models.alpha.FBA import layers_WS
4 |
5 | def group_weight(module, lr_encoder, lr_decoder, WD):
6 | group_decay = { 'encoder': [], 'decoder':[]}
7 | group_bias = { 'encoder': [], 'decoder':[]}
8 | group_GN = { 'encoder': [], 'decoder':[]}
9 |
10 |
11 | for name, m in module.named_modules():
12 | # if hasattr(m, 'requires_grad'):
13 | # if m.requires_grad:
14 | # continue
15 |
16 | part = 'decoder'
17 | if('encoder' in name):
18 | part = 'encoder'
19 |
20 | if isinstance(m, nn.Linear):
21 | group_decay[part].append(m.weight)
22 | if m.bias is not None:
23 | group_bias[part].append(m.bias)
24 |
25 | elif isinstance(m, nn.Conv2d) and m.weight.requires_grad:
26 | group_decay[part].append(m.weight)
27 | if m.bias is not None:
28 | group_bias[part].append(m.bias)
29 | elif isinstance(m, layers_WS.Conv2d) and m.weight.requires_grad:
30 | group_decay[part].append(m.weight)
31 | if m.bias is not None:
32 | group_bias[part].append(m.bias)
33 |
34 | elif isinstance(m, nn.GroupNorm):
35 | if m.weight is not None:
36 | group_GN[part].append(m.weight)
37 | if m.bias is not None:
38 | group_GN[part].append(m.bias)
39 |
40 |
41 | print(len(list(module.parameters())), len(group_decay['encoder']) + len(group_bias['encoder']) + len(group_GN['encoder']) + len(group_decay['decoder']) + len(group_bias['decoder']) + len(group_GN['decoder']) , len(list(module.modules())))
42 | # assert len(list(module.parameters())) == len(group_decay) + len(group_bias) + len(group_GN)
43 | groups = [dict(params=group_decay['decoder'], lr =lr_decoder, weight_decay=WD), dict(params=group_bias['decoder'], lr=2*lr_decoder, weight_decay=0.0), dict(params=group_GN['decoder'], lr=lr_decoder, weight_decay=1e-5),
44 | dict(params=group_decay['encoder'], lr=lr_encoder, weight_decay=WD), dict(params=group_bias['encoder'], lr=2*lr_encoder, weight_decay=0.0), dict(params=group_GN['encoder'], lr=lr_encoder, weight_decay=1e-5)]
45 |
46 | # groups= [dict(params=module.decoder.conv_pred.parameters(), lr=lr, weight_decay=0.0)]
47 | return groups
--------------------------------------------------------------------------------
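A hedged sketch of how these parameter groups would typically be consumed, assuming the FBA `layers_WS` import at the top of this file resolves in your checkout and that encoder submodules contain 'encoder' in their names, as `group_weight` expects; the toy model here is a placeholder:

```
import torch
import torch.nn as nn
from utils.tmp.group_weight import group_weight  # path as laid out in this tree

# Toy stand-in: group_weight() routes parameters by whether 'encoder' appears in the module name.
model = nn.ModuleDict({'encoder': nn.Conv2d(3, 8, 3), 'decoder': nn.Conv2d(8, 1, 3)})

groups = group_weight(model, lr_encoder=1e-5, lr_decoder=1e-4, WD=1e-4)
optimizer = torch.optim.Adam(groups)  # per-group lr and weight_decay are taken from `groups`
```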
/utils/tmp/metric.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import math
7 | import time
8 | import skimage.measure
9 |
10 | from PIL import Image
11 | from scipy import ndimage
12 | from scipy.ndimage import distance_transform_edt
13 | from multiprocessing import Pool
14 |
15 |
16 | def findMaxConnectedRegion(x):
17 | assert len(x.shape) == 2
18 | cc, num = skimage.measure.label(x, connectivity=1, return_num=True)
19 | omega = np.zeros_like(x)
20 | if num > 0:
21 | # find the largest connected region
22 | max_id = np.argmax(np.bincount(cc.flatten())[1:]) + 1
23 | omega[cc == max_id] = 1
24 | return omega
25 |
26 | def genGaussKernel(sigma, q=2):
27 | pi = math.pi
28 | eps = 1e-2
29 |
30 | def gauss(x, sigma):
31 | return np.exp(-np.power(x,2)/(2*np.power(sigma,2))) / (sigma*np.sqrt(2*pi))
32 |
33 | def dgauss(x, sigma):
34 | return -x * gauss(x,sigma) / np.power(sigma, 2)
35 |
36 | hsize = int(np.ceil(sigma*np.sqrt(-2*np.log(np.sqrt(2*pi)*sigma*eps))))
37 | size = 2 * hsize + 1
38 | hx = np.zeros([size, size], dtype=np.float32)
39 | for i in range(size):
40 | for j in range(size):
41 | u, v = i-hsize, j-hsize
42 | hx[i,j] = gauss(u,sigma) * dgauss(v,sigma)
43 |
44 | hx = hx / np.sqrt(np.sum(np.power(np.abs(hx), 2)))
45 | hy = hx.transpose(1, 0)
46 | return hx, hy, size
47 |
48 | def calcOpticalFlow(frames):
49 | prev, curr = frames
50 | flow = cv2.calcOpticalFlowFarneback(prev.astype(np.uint8), curr.astype(np.uint8), None,
51 | 0.5, 5, 10, 2, 7, 1.5,
52 | cv2.OPTFLOW_FARNEBACK_GAUSSIAN)
53 | return flow
54 |
55 |
56 | class ImageFilter(nn.Module):
57 | def __init__(self, chn, kernel_size, weight, device):
58 | super(ImageFilter, self).__init__()
59 | self.kernel_size = kernel_size
60 | assert kernel_size == weight.size(-1)
61 | self.filter = nn.Conv2d(chn, chn, kernel_size, padding=0, bias=False)
62 | self.filter.weight = nn.Parameter(weight)
63 | self.device = device
64 |
65 | def pad(self, x):
66 | assert len(x.shape) == 3
67 | x = x.unsqueeze(-1).permute((0,3,1,2))
68 | b, c, h, w = x.shape
69 | pad = self.kernel_size // 2
70 | y = torch.zeros([b, c, h+pad*2, w+pad*2]).to(self.device)
71 | y[:,:,0:pad,0:pad] = x[:,:,0:1,0:1].repeat(1,1,pad,pad)
72 | y[:,:,0:pad,w+pad:] = x[:,:,0:1,-1:].repeat(1,1,pad,pad)
73 | y[:,:,h+pad:,0:pad] = x[:,:,-1:,0:1].repeat(1,1,pad,pad)
74 | y[:,:,h+pad:,w+pad:] = x[:,:,-1:,-1:].repeat(1,1,pad,pad)
75 |
76 | y[:,:,0:pad,pad:w+pad] = x[:,:,0:1,:].repeat(1,1,pad,1)
77 | y[:,:,pad:h+pad,0:pad] = x[:,:,:,0:1].repeat(1,1,1,pad)
78 | y[:,:,h+pad:,pad:w+pad] = x[:,:,-1:,:].repeat(1,1,pad,1)
79 | y[:,:,pad:h+pad,w+pad:] = x[:,:,:,-1:].repeat(1,1,1,pad)
80 |
81 | y[:,:,pad:h+pad, pad:w+pad] = x
82 | return y
83 |
84 | def forward(self, x):
85 | y = self.filter(self.pad(x))
86 | return y
87 |
88 |
89 | class BatchMetric(object):
90 | def __init__(self, device, grad_sigma=1.4, grad_q=2,
91 | conn_step=0.1, conn_thresh=0.5, conn_theta=0.15, conn_p=1):
92 | # parameters for connectivity
93 | self.conn_step = conn_step
94 | self.conn_thresh = conn_thresh
95 | self.conn_theta = conn_theta
96 | self.conn_p = conn_p
97 | self.device = device
98 |
99 | hx, hy, size = genGaussKernel(grad_sigma, grad_q)
100 | self.hx = hx
101 | self.hy = hy
102 | self.kernel_size = size
103 | kx = self.hx[::-1, ::-1].copy()
104 | ky = self.hy[::-1, ::-1].copy()
105 | kernel_x = torch.from_numpy(kx).unsqueeze(0).unsqueeze(0)
106 | kernel_y = torch.from_numpy(ky).unsqueeze(0).unsqueeze(0)
107 | self.fx = ImageFilter(1, self.kernel_size, kernel_x, self.device).cuda(self.device)
108 | self.fy = ImageFilter(1, self.kernel_size, kernel_y, self.device).cuda(self.device)
109 |
110 | def run(self, input, target, mask=None):
111 | torch.cuda.empty_cache()
112 | input_t = torch.from_numpy(input.astype(np.float32)).to(self.device)
113 | target_t = torch.from_numpy(target.astype(np.float32)).to(self.device)
114 | if mask is None:
115 | mask = torch.zeros_like(target_t).to(self.device)
116 | mask[(target_t>0) * (target_t<255)] = 1
117 | else:
118 | mask = torch.from_numpy(mask.astype(np.float32)).to(self.device)
119 | mask = (mask == 128).float()
120 | sad = self.BatchSAD(input_t, target_t, mask)
121 | mse = self.BatchMSE(input_t, target_t, mask)
122 | grad = self.BatchGradient(input_t, target_t, mask)
123 | conn = self.BatchConnectivity(input_t, target_t, mask)
124 | return sad, mse, grad, conn
125 |
126 | def run_video(self, input, target, mask=None):
127 | torch.cuda.empty_cache()
128 | input_t = torch.from_numpy(input.astype(np.float32)).to(self.device)
129 | target_t = torch.from_numpy(target.astype(np.float32)).to(self.device)
130 | if mask is None:
131 | mask = torch.zeros_like(target_t).to(self.device)
132 | mask[(target_t>0) * (target_t<255)] = 1
133 | else:
134 | mask = torch.from_numpy(mask.astype(np.float32)).to(self.device)
135 | mask = (mask == 128).float()
136 | errs, nums = [], []
137 | err, n = self.SSDA(input_t, target_t, mask)
138 | errs.append(err)
139 | nums.append(n)
140 | err, n = self.dtSSD(input_t, target_t, mask)
141 | errs.append(err)
142 | nums.append(n)
143 | err, n = self.MESSDdt(input_t, target_t, mask)
144 | errs.append(err)
145 | nums.append(n)
146 | return errs, nums
147 |
148 | def run_metric(self, metric, input, target, mask=None):
149 | torch.cuda.empty_cache()
150 | input_t = torch.from_numpy(input.astype(np.float32)).to(self.device)
151 | target_t = torch.from_numpy(target.astype(np.float32)).to(self.device)
152 | if mask is None:
153 | mask = torch.zeros_like(target_t).to(self.device)
154 | mask[(target_t>0) * (target_t<255)] = 1
155 | else:
156 | mask = torch.from_numpy(mask.astype(np.float32)).to(self.device)
157 | mask = (mask == 128).float()
158 |
159 | if metric == 'sad':
160 | ret = self.BatchSAD(input_t, target_t, mask)
161 | elif metric == 'mse':
162 | ret = self.BatchMSE(input_t, target_t, mask)
163 | elif metric == 'grad':
164 | ret = self.BatchGradient(input_t, target_t, mask)
165 | elif metric == 'conn':
166 | ret = self.BatchConnectivity(input_t, target_t, mask)
167 | elif metric == 'ssda':
168 | ret = self.SSDA(input_t, target_t, mask)
169 | elif metric == 'dtssd':
170 | ret = self.dtSSD(input_t, target_t, mask)
171 | elif metric == 'messddt':
172 | ret = self.MESSDdt(input_t, target_t, mask)
173 | else:
174 | raise NotImplementedError
175 | return ret
176 |
177 | def BatchSAD(self, pred, target, mask):
178 | B = target.size(0)
179 | error_map = (pred - target).abs() / 255.
180 | batch_loss = (error_map * mask).view(B, -1).sum(dim=-1)
181 | batch_loss = batch_loss / 1000.
182 | return batch_loss.data.cpu().numpy()
183 |
184 | def BatchMSE(self, pred, target, mask):
185 | B = target.size(0)
186 | error_map = (pred-target) / 255.
187 | batch_loss = (error_map.pow(2) * mask).view(B, -1).sum(dim=-1)
188 | batch_loss = batch_loss / (mask.view(B, -1).sum(dim=-1) + 1.)
189 | return batch_loss.data.cpu().numpy()
190 |
191 | def BatchGradient(self, pred, target, mask):
192 | B = target.size(0)
193 | pred = pred / 255.
194 | target = target / 255.
195 |
196 | pred_x_t = self.fx(pred).squeeze(1)
197 | pred_y_t = self.fy(pred).squeeze(1)
198 | target_x_t = self.fx(target).squeeze(1)
199 | target_y_t = self.fy(target).squeeze(1)
200 | pred_amp = (pred_x_t.pow(2) + pred_y_t.pow(2)).sqrt()
201 | target_amp = (target_x_t.pow(2) + target_y_t.pow(2)).sqrt()
202 | error_map = (pred_amp - target_amp).pow(2)
203 | batch_loss = (error_map * mask).view(B, -1).sum(dim=-1)
204 | return batch_loss.data.cpu().numpy()
205 |
206 | def BatchConnectivity(self, pred, target, mask):
207 | step = self.conn_step
208 | theta = self.conn_theta
209 |
210 | pred = pred / 255.
211 | target = target / 255.
212 | B, dimy, dimx = pred.shape
213 | thresh_steps = torch.arange(0, 1+step, step).to(self.device)
214 | l_map = torch.ones_like(pred).to(self.device)*(-1)
215 | pool = Pool(B)
216 | for i in range(1, len(thresh_steps)):
217 | pred_alpha_thresh = pred>=thresh_steps[i]
218 | target_alpha_thresh = target>=thresh_steps[i]
219 | mask_i = pred_alpha_thresh * target_alpha_thresh
220 | omegas = []
221 | items = [mask_ij.data.cpu().numpy() for mask_ij in mask_i]
222 | for omega in pool.imap(findMaxConnectedRegion, items):
223 | omegas.append(omega)
224 | omegas = torch.from_numpy(np.array(omegas)).to(self.device)
225 | flag = (l_map==-1) * (omegas==0)
226 | l_map[flag==1] = thresh_steps[i-1]
227 | l_map[l_map==-1] = 1
228 | pred_d = pred - l_map
229 | target_d = target - l_map
230 | pred_phi = 1 - pred_d*(pred_d>=theta).float()
231 | target_phi = 1 - target_d*(target_d>=theta).float()
232 | batch_loss = ((pred_phi-target_phi).abs()*mask).view([B, -1]).sum(-1)
233 | pool.close()
234 | return batch_loss.data.cpu().numpy()
235 |
236 | def GaussianGradient(self, mat):
237 | gx = np.zeros_like(mat)
238 | gy = np.zeros_like(mat)
239 | for i in range(mat.shape[0]):
240 |             gx[i, ...] = ndimage.convolve(mat[i], self.hx, mode='nearest')
241 |             gy[i, ...] = ndimage.convolve(mat[i], self.hy, mode='nearest')
242 | return gx, gy
243 |
244 | def SSDA(self, pred, target, mask=None):
245 | B, h, w = target.shape
246 | pred = pred / 255.
247 | target = target / 255.
248 | error = ((pred-target).pow(2) * mask).view(B, -1).sum(dim=1).sqrt()
249 | num = mask.view(B, -1).sum(dim=1) + 1.
250 | return error.data.cpu().numpy(), num.data.cpu().numpy()
251 |
252 | def dtSSD(self, pred, target, mask=None):
253 | B, h, w = target.shape
254 | pred = pred / 255.
255 | target = target / 255.
256 | pred_0 = pred[:-1, ...]
257 | pred_1 = pred[1:, ...]
258 | target_0 = target[:-1, ...]
259 | target_1 = target[1:, ...]
260 | mask_0 = mask[:-1, ...]
261 | error_map = ((pred_1-pred_0) - (target_1-target_0)).pow(2)
262 | error = (error_map * mask_0).view(mask_0.shape[0], -1).sum(dim=1).sqrt()
263 | num = mask_0.view(mask_0.shape[0], -1).sum(dim=1) + 1.
264 | return error.data.cpu().numpy(), num.data.cpu().numpy()
265 |
266 | def MESSDdt(self, pred, target, mask=None):
267 | B, h, w = target.shape
268 |
269 | pool = Pool(B)
270 | flows = []
271 | items = [t for t in target.data.cpu().numpy()]
272 | for flow in pool.imap(calcOpticalFlow, zip(items[:-1], items[1:])):
273 | flows.append(flow)
274 | flow = torch.from_numpy(np.rint(np.array(flows)).astype(np.int64)).to(self.device)
275 | pool.close()
276 |
277 | pred = pred / 255.
278 | target = target / 255.
279 | pred_0 = pred[:-1, ...]
280 | pred_1 = pred[1:, ...]
281 | target_0 = target[:-1, ...]
282 | target_1 = target[1:, ...]
283 | mask_0 = mask[:-1, ...]
284 | mask_1 = mask[1:, ...]
285 |
286 | B, h, w = target_0.shape
287 | x = torch.arange(0, w).to(self.device)
288 | y = torch.arange(0, h).to(self.device)
289 | xx, yy = torch.meshgrid([y, x])
290 | coords = torch.stack([yy, xx], dim=2).unsqueeze(0).repeat((B, 1, 1, 1))
291 | coords_n = (coords + flow)
292 | coords_y = coords_n[..., 0].clamp(0, h-1)
293 | coords_x = coords_n[..., 1].clamp(0, w-1)
294 | indices = coords_y * w + coords_x
295 | pred_1 = torch.take(pred_1, indices)
296 | target_1 = torch.take(target_1, indices)
297 | mask_1 = torch.take(mask_1, indices)
298 |
299 | error_map = (pred_0-target_0).pow(2) * mask_0 - (pred_1-target_1).pow(2) * mask_1
300 | error = error_map.abs().view(mask_0.shape[0], -1).sum(dim=1)
301 | num = mask_0.view(mask_0.shape[0], -1).sum(dim=1) + 1.
302 | return error.data.cpu().numpy(), num.data.cpu().numpy()
--------------------------------------------------------------------------------
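A hedged sketch of evaluating a batch of predicted alpha mattes with `BatchMetric.run`, assuming a CUDA device is available and the script is run from the repository root; inputs are [B, H, W] arrays in the [0, 255] range, and when no explicit mask is given the unknown region is derived from the ground truth:

```
import numpy as np
import torch
from utils.tmp.metric import BatchMetric  # path as laid out in this tree

device = torch.device('cuda:0')
metric = BatchMetric(device)

# Toy data: 2 frames of 240x320 alpha mattes in [0, 255].
pred = np.random.randint(0, 256, (2, 240, 320)).astype(np.float32)
gt = np.random.randint(0, 256, (2, 240, 320)).astype(np.float32)

sad, mse, grad, conn = metric.run(pred, gt)  # one value per sample, as numpy arrays
print(sad, mse, grad, conn)
```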
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import time
3 | from pathlib import Path
4 | import cv2 as cv
5 | import numpy as np
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import torch.distributed as torch_dist
11 |
12 | def dt(a):
13 | # a: tensor, [B, S, H, W]
14 | ac = a.cpu().numpy()
15 | b, s = a.shape[:2]
16 | z = []
17 | for i in range(b):
18 | y = []
19 | for j in range(s):
20 | x = ac[i,j]
21 | y.append(cv.distanceTransform((x * 255).astype(np.uint8), cv.DIST_L2, 0))
22 | z.append(np.stack(y))
23 | return torch.from_numpy(np.stack(z)).float().to(a.device)
24 |
25 | def trimap_transform(trimap):
26 | # trimap: tensor, [B, S, 2, H, W]
27 | b, s, _, h, w = trimap.shape
28 |
29 | clicks = torch.zeros((b, s, 6, h, w), device=trimap.device)
30 | for k in range(2):
31 | tk = trimap[:, :, k]
32 | if torch.sum(tk != 0) > 0:
33 | dt_mask = -dt(1. - tk)**2
34 | L = 320
35 | clicks[:, :, 3*k] = torch.exp(dt_mask / (2 * ((0.02 * L)**2)))
36 | clicks[:, :, 3*k+1] = torch.exp(dt_mask / (2 * ((0.08 * L)**2)))
37 | clicks[:, :, 3*k+2] = torch.exp(dt_mask / (2 * ((0.16 * L)**2)))
38 |
39 | return clicks
40 |
41 | def torch_barrier():
42 | if torch_dist.is_initialized():
43 | torch_dist.barrier()
44 |
45 | def reduce_tensor(inp):
46 | """
47 | Reduce the loss from all processes so that
48 |     ALL PROCESSES have the averaged results.
49 | """
50 | if torch_dist.is_initialized():
51 | world_size = torch_dist.get_world_size()
52 | if world_size < 2:
53 | return inp
54 | with torch.no_grad():
55 | reduced_inp = inp
56 | torch.distributed.all_reduce(reduced_inp)
57 | torch.distributed.barrier()
58 | return reduced_inp / world_size
59 | return inp
60 |
61 | def print_loss_dict(loss, save=None):
62 | s = ''
63 | for key in sorted(loss.keys()):
64 | s += '{}: {:.6f}\n'.format(key, loss[key])
65 | print (s)
66 | if save is not None:
67 | with open(save, 'w') as f:
68 | f.write(s)
69 |
70 | def coords_grid(batch, ht, wd):
71 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
72 | coords = torch.stack(coords[::-1], axis=0)
73 | return coords.unsqueeze(0).repeat(batch, 1, 1, 1)
74 |
75 | def grid_sampler(img, coords, mode='bilinear'):
76 | """ Wrapper for grid_sample, uses pixel coordinates
77 | img: [B, C, H, W]
78 | coords: [B, 2, H, W]
79 | """
80 | H, W = img.shape[-2:]
81 | xgrid, ygrid = coords.split(1, dim=1)
82 | xgrid = 2*xgrid/(W-1) - 1
83 | ygrid = 2*ygrid/(H-1) - 1
84 |
85 | grid = torch.cat([xgrid, ygrid], dim=1).permute(0, 2, 3, 1)
86 | img = F.grid_sample(img, grid, mode=mode, align_corners=True)
87 |
88 | return img
89 |
90 | def flow_dt(a, ha, gt, hgt, flow, trimask, metric=False, cuda=True):
91 | '''
92 | All tensors in [B, C, H, W]
93 | a: current prediction
94 | gt: current groundtruth
95 | ha: adjacent frame prediction
96 | hgt: adjacent frame groundtruth
97 | flow: optical flow from current frame to adjacent frame
98 | trimask: current frame trimask
99 | '''
100 | # Warp ha back to a and hgt back to gt
101 | with torch.no_grad():
102 | B, _, H, W = a.shape
103 | mask = torch.isnan(flow) # B, 1, H, W
104 | coords = coords_grid(B, H, W) # B, 2, H, W
105 | if cuda:
106 | coords = coords.to(torch.cuda.current_device())
107 | flow[mask] = 0
108 | flow_coords = coords + flow
109 | mask = (~mask[:, :1, :, :]) * trimask.bool()
110 | valid = mask.sum()
111 | if valid == 0:
112 | if metric:
113 | return valid.float(), valid.float(), valid.float()
114 | else:
115 | return valid.float()
116 |
117 | pgt = grid_sampler(hgt, flow_coords)
118 | pa = grid_sampler(ha, flow_coords)
119 | error = torch.abs((a[mask] - gt[mask]) - (pa[mask] - pgt[mask])) # L1 instead of L2
120 | if metric:
121 | error2 = torch.abs((a[mask] - gt[mask]) ** 2 - (pa[mask] - pgt[mask]) ** 2)
122 | return error.sum(), error2.sum(), valid
123 | return error.mean()
124 |
125 | class AverageMeter(object):
126 | """Computes and stores the average and current value"""
127 |
128 | def __init__(self):
129 | self.initialized = False
130 | self.val = None
131 | self.avg = None
132 | self.sum = None
133 | self.count = None
134 |
135 | def initialize(self, val, weight):
136 | self.val = val
137 | self.avg = val
138 | self.sum = val * weight
139 | self.count = weight
140 | self.initialized = True
141 |
142 | def update(self, val, weight=1):
143 | if not self.initialized:
144 | self.initialize(val, weight)
145 | else:
146 | self.add(val, weight)
147 |
148 | def add(self, val, weight):
149 | self.val = val
150 | self.sum += val * weight
151 | self.count += weight
152 | self.avg = self.sum / self.count
153 |
154 | def value(self):
155 | return self.val
156 |
157 | def average(self):
158 | return self.avg
159 |
160 | def create_logger(output_dir, cfg_name, phase='train'):
161 | root_output_dir = Path(output_dir)
162 | # set up logger
163 | if not root_output_dir.exists():
164 | print('=> creating {}'.format(root_output_dir))
165 | root_output_dir.mkdir()
166 |
167 | final_output_dir = root_output_dir / cfg_name
168 |
169 | print('=> creating {}'.format(final_output_dir))
170 | final_output_dir.mkdir(parents=True, exist_ok=True)
171 |
172 | time_str = time.strftime('%Y-%m-%d-%H-%M')
173 | log_file = '{}_{}_{}.log'.format(cfg_name, time_str, phase)
174 | final_log_file = final_output_dir / log_file
175 | head = '%(asctime)-15s %(message)s'
176 | logging.basicConfig(filename=str(final_log_file),
177 | format=head)
178 | logger = logging.getLogger()
179 | logger.setLevel(logging.INFO)
180 | console = logging.StreamHandler()
181 | logging.getLogger('').addHandler(console)
182 |
183 | return logger, str(final_output_dir)
184 |
185 | def poly_lr(optimizer, base_lr, max_iters, cur_iters, power=0.9):
186 | lr = base_lr*((1-float(cur_iters)/max_iters)**(power))
187 | optimizer.param_groups[0]['lr'] = lr
188 | return lr
189 |
190 | def const_lr(optimizer, base_lr, max_iters, cur_iters):
191 | return base_lr
192 |
193 | OPT_DICT = {
194 | 'adam': torch.optim.Adam,
195 | 'adamw': torch.optim.AdamW,
196 | 'sgd': torch.optim.SGD,
197 | }
198 |
199 | STR_DICT = {
200 | 'poly': poly_lr,
201 | 'const': const_lr,
202 | }
203 |
--------------------------------------------------------------------------------
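Finally, a hedged sketch tying together the small training helpers above (OPT_DICT, STR_DICT / poly_lr, AverageMeter); the toy model and loss are placeholders, and the import path assumes the script is run from the repository root:

```
import torch
import torch.nn as nn
from utils.utils import OPT_DICT, STR_DICT, AverageMeter

model = nn.Linear(4, 1)                       # placeholder network
base_lr, max_iters = 1e-4, 100
optimizer = OPT_DICT['adam'](model.parameters(), lr=base_lr)
meter = AverageMeter()

for it in range(max_iters):
    lr = STR_DICT['poly'](optimizer, base_lr, max_iters, it)  # poly_lr updates param_groups[0]
    loss = model(torch.randn(8, 4)).pow(2).mean()             # placeholder loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    meter.update(loss.item())

print('avg loss: {:.6f}, final lr: {:.6f}'.format(meter.average(), lr))
```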