├── .gitignore
├── LICENSE
├── datasets
│   ├── CRVD_seq.py
│   ├── __init__.py
│   └── sRGB_seq.py
├── imgs
│   └── figure_overview.png
├── models
│   ├── __init__.py
│   ├── basicvsr_plusplus.py
│   ├── birnn.py
│   ├── components.py
│   ├── flornn.py
│   ├── flornn_raw.py
│   ├── forwardrnn.py
│   ├── init.py
│   └── rvidenet
│       ├── isp.pth
│       └── isp.py
├── pytorch_pwc
│   ├── correlation
│   │   └── correlation.py
│   ├── extract_flow.py
│   └── pwc.py
├── readme.md
├── requirements.yaml
├── softmax_splatting
│   └── softsplat.py
├── test_models
│   ├── CRVD_test.py
│   └── sRGB_test.py
├── train_models
│   ├── CRVD_train.py
│   ├── base_functions.py
│   ├── sRGB_train.py
│   └── sRGB_train_distributed.py
└── utils
    ├── fastdvdnet_utils.py
    ├── io.py
    ├── raw.py
    ├── ssim.py
    └── warp.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .xml
2 | .idea
3 | .idea/workspace.xml
4 | .DS_Store
5 | */__pycache__
6 | *.pyc
7 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc.
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 | Preamble
9 |
10 | The GNU General Public License is a free, copyleft license for
11 | software and other kinds of works.
12 |
13 | The licenses for most software and other practical works are designed
14 | to take away your freedom to share and change the works. By contrast,
15 | the GNU General Public License is intended to guarantee your freedom to
16 | share and change all versions of a program--to make sure it remains free
17 | software for all its users. We, the Free Software Foundation, use the
18 | GNU General Public License for most of our software; it applies also to
19 | any other work released this way by its authors. You can apply it to
20 | your programs, too.
21 |
22 | When we speak of free software, we are referring to freedom, not
23 | price. Our General Public Licenses are designed to make sure that you
24 | have the freedom to distribute copies of free software (and charge for
25 | them if you wish), that you receive source code or can get it if you
26 | want it, that you can change the software or use pieces of it in new
27 | free programs, and that you know you can do these things.
28 |
29 | To protect your rights, we need to prevent others from denying you
30 | these rights or asking you to surrender the rights. Therefore, you have
31 | certain responsibilities if you distribute copies of the software, or if
32 | you modify it: responsibilities to respect the freedom of others.
33 |
34 | For example, if you distribute copies of such a program, whether
35 | gratis or for a fee, you must pass on to the recipients the same
36 | freedoms that you received. You must make sure that they, too, receive
37 | or can get the source code. And you must show them these terms so they
38 | know their rights.
39 |
40 | Developers that use the GNU GPL protect your rights with two steps:
41 | (1) assert copyright on the software, and (2) offer you this License
42 | giving you legal permission to copy, distribute and/or modify it.
43 |
44 | For the developers' and authors' protection, the GPL clearly explains
45 | that there is no warranty for this free software. For both users' and
46 | authors' sake, the GPL requires that modified versions be marked as
47 | changed, so that their problems will not be attributed erroneously to
48 | authors of previous versions.
49 |
50 | Some devices are designed to deny users access to install or run
51 | modified versions of the software inside them, although the manufacturer
52 | can do so. This is fundamentally incompatible with the aim of
53 | protecting users' freedom to change the software. The systematic
54 | pattern of such abuse occurs in the area of products for individuals to
55 | use, which is precisely where it is most unacceptable. Therefore, we
56 | have designed this version of the GPL to prohibit the practice for those
57 | products. If such problems arise substantially in other domains, we
58 | stand ready to extend this provision to those domains in future versions
59 | of the GPL, as needed to protect the freedom of users.
60 |
61 | Finally, every program is threatened constantly by software patents.
62 | States should not allow patents to restrict development and use of
63 | software on general-purpose computers, but in those that do, we wish to
64 | avoid the special danger that patents applied to a free program could
65 | make it effectively proprietary. To prevent this, the GPL assures that
66 | patents cannot be used to render the program non-free.
67 |
68 | The precise terms and conditions for copying, distribution and
69 | modification follow.
70 |
71 | TERMS AND CONDITIONS
72 |
73 | 0. Definitions.
74 |
75 | "This License" refers to version 3 of the GNU General Public License.
76 |
77 | "Copyright" also means copyright-like laws that apply to other kinds of
78 | works, such as semiconductor masks.
79 |
80 | "The Program" refers to any copyrightable work licensed under this
81 | License. Each licensee is addressed as "you". "Licensees" and
82 | "recipients" may be individuals or organizations.
83 |
84 | To "modify" a work means to copy from or adapt all or part of the work
85 | in a fashion requiring copyright permission, other than the making of an
86 | exact copy. The resulting work is called a "modified version" of the
87 | earlier work or a work "based on" the earlier work.
88 |
89 | A "covered work" means either the unmodified Program or a work based
90 | on the Program.
91 |
92 | To "propagate" a work means to do anything with it that, without
93 | permission, would make you directly or secondarily liable for
94 | infringement under applicable copyright law, except executing it on a
95 | computer or modifying a private copy. Propagation includes copying,
96 | distribution (with or without modification), making available to the
97 | public, and in some countries other activities as well.
98 |
99 | To "convey" a work means any kind of propagation that enables other
100 | parties to make or receive copies. Mere interaction with a user through
101 | a computer network, with no transfer of a copy, is not conveying.
102 |
103 | An interactive user interface displays "Appropriate Legal Notices"
104 | to the extent that it includes a convenient and prominently visible
105 | feature that (1) displays an appropriate copyright notice, and (2)
106 | tells the user that there is no warranty for the work (except to the
107 | extent that warranties are provided), that licensees may convey the
108 | work under this License, and how to view a copy of this License. If
109 | the interface presents a list of user commands or options, such as a
110 | menu, a prominent item in the list meets this criterion.
111 |
112 | 1. Source Code.
113 |
114 | The "source code" for a work means the preferred form of the work
115 | for making modifications to it. "Object code" means any non-source
116 | form of a work.
117 |
118 | A "Standard Interface" means an interface that either is an official
119 | standard defined by a recognized standards body, or, in the case of
120 | interfaces specified for a particular programming language, one that
121 | is widely used among developers working in that language.
122 |
123 | The "System Libraries" of an executable work include anything, other
124 | than the work as a whole, that (a) is included in the normal form of
125 | packaging a Major Component, but which is not part of that Major
126 | Component, and (b) serves only to enable use of the work with that
127 | Major Component, or to implement a Standard Interface for which an
128 | implementation is available to the public in source code form. A
129 | "Major Component", in this context, means a major essential component
130 | (kernel, window system, and so on) of the specific operating system
131 | (if any) on which the executable work runs, or a compiler used to
132 | produce the work, or an object code interpreter used to run it.
133 |
134 | The "Corresponding Source" for a work in object code form means all
135 | the source code needed to generate, install, and (for an executable
136 | work) run the object code and to modify the work, including scripts to
137 | control those activities. However, it does not include the work's
138 | System Libraries, or general-purpose tools or generally available free
139 | programs which are used unmodified in performing those activities but
140 | which are not part of the work. For example, Corresponding Source
141 | includes interface definition files associated with source files for
142 | the work, and the source code for shared libraries and dynamically
143 | linked subprograms that the work is specifically designed to require,
144 | such as by intimate data communication or control flow between those
145 | subprograms and other parts of the work.
146 |
147 | The Corresponding Source need not include anything that users
148 | can regenerate automatically from other parts of the Corresponding
149 | Source.
150 |
151 | The Corresponding Source for a work in source code form is that
152 | same work.
153 |
154 | 2. Basic Permissions.
155 |
156 | All rights granted under this License are granted for the term of
157 | copyright on the Program, and are irrevocable provided the stated
158 | conditions are met. This License explicitly affirms your unlimited
159 | permission to run the unmodified Program. The output from running a
160 | covered work is covered by this License only if the output, given its
161 | content, constitutes a covered work. This License acknowledges your
162 | rights of fair use or other equivalent, as provided by copyright law.
163 |
164 | You may make, run and propagate covered works that you do not
165 | convey, without conditions so long as your license otherwise remains
166 | in force. You may convey covered works to others for the sole purpose
167 | of having them make modifications exclusively for you, or provide you
168 | with facilities for running those works, provided that you comply with
169 | the terms of this License in conveying all material for which you do
170 | not control copyright. Those thus making or running the covered works
171 | for you must do so exclusively on your behalf, under your direction
172 | and control, on terms that prohibit them from making any copies of
173 | your copyrighted material outside their relationship with you.
174 |
175 | Conveying under any other circumstances is permitted solely under
176 | the conditions stated below. Sublicensing is not allowed; section 10
177 | makes it unnecessary.
178 |
179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180 |
181 | No covered work shall be deemed part of an effective technological
182 | measure under any applicable law fulfilling obligations under article
183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184 | similar laws prohibiting or restricting circumvention of such
185 | measures.
186 |
187 | When you convey a covered work, you waive any legal power to forbid
188 | circumvention of technological measures to the extent such circumvention
189 | is effected by exercising rights under this License with respect to
190 | the covered work, and you disclaim any intention to limit operation or
191 | modification of the work as a means of enforcing, against the work's
192 | users, your or third parties' legal rights to forbid circumvention of
193 | technological measures.
194 |
195 | 4. Conveying Verbatim Copies.
196 |
197 | You may convey verbatim copies of the Program's source code as you
198 | receive it, in any medium, provided that you conspicuously and
199 | appropriately publish on each copy an appropriate copyright notice;
200 | keep intact all notices stating that this License and any
201 | non-permissive terms added in accord with section 7 apply to the code;
202 | keep intact all notices of the absence of any warranty; and give all
203 | recipients a copy of this License along with the Program.
204 |
205 | You may charge any price or no price for each copy that you convey,
206 | and you may offer support or warranty protection for a fee.
207 |
208 | 5. Conveying Modified Source Versions.
209 |
210 | You may convey a work based on the Program, or the modifications to
211 | produce it from the Program, in the form of source code under the
212 | terms of section 4, provided that you also meet all of these conditions:
213 |
214 | a) The work must carry prominent notices stating that you modified
215 | it, and giving a relevant date.
216 |
217 | b) The work must carry prominent notices stating that it is
218 | released under this License and any conditions added under section
219 | 7. This requirement modifies the requirement in section 4 to
220 | "keep intact all notices".
221 |
222 | c) You must license the entire work, as a whole, under this
223 | License to anyone who comes into possession of a copy. This
224 | License will therefore apply, along with any applicable section 7
225 | additional terms, to the whole of the work, and all its parts,
226 | regardless of how they are packaged. This License gives no
227 | permission to license the work in any other way, but it does not
228 | invalidate such permission if you have separately received it.
229 |
230 | d) If the work has interactive user interfaces, each must display
231 | Appropriate Legal Notices; however, if the Program has interactive
232 | interfaces that do not display Appropriate Legal Notices, your
233 | work need not make them do so.
234 |
235 | A compilation of a covered work with other separate and independent
236 | works, which are not by their nature extensions of the covered work,
237 | and which are not combined with it such as to form a larger program,
238 | in or on a volume of a storage or distribution medium, is called an
239 | "aggregate" if the compilation and its resulting copyright are not
240 | used to limit the access or legal rights of the compilation's users
241 | beyond what the individual works permit. Inclusion of a covered work
242 | in an aggregate does not cause this License to apply to the other
243 | parts of the aggregate.
244 |
245 | 6. Conveying Non-Source Forms.
246 |
247 | You may convey a covered work in object code form under the terms
248 | of sections 4 and 5, provided that you also convey the
249 | machine-readable Corresponding Source under the terms of this License,
250 | in one of these ways:
251 |
252 | a) Convey the object code in, or embodied in, a physical product
253 | (including a physical distribution medium), accompanied by the
254 | Corresponding Source fixed on a durable physical medium
255 | customarily used for software interchange.
256 |
257 | b) Convey the object code in, or embodied in, a physical product
258 | (including a physical distribution medium), accompanied by a
259 | written offer, valid for at least three years and valid for as
260 | long as you offer spare parts or customer support for that product
261 | model, to give anyone who possesses the object code either (1) a
262 | copy of the Corresponding Source for all the software in the
263 | product that is covered by this License, on a durable physical
264 | medium customarily used for software interchange, for a price no
265 | more than your reasonable cost of physically performing this
266 | conveying of source, or (2) access to copy the
267 | Corresponding Source from a network server at no charge.
268 |
269 | c) Convey individual copies of the object code with a copy of the
270 | written offer to provide the Corresponding Source. This
271 | alternative is allowed only occasionally and noncommercially, and
272 | only if you received the object code with such an offer, in accord
273 | with subsection 6b.
274 |
275 | d) Convey the object code by offering access from a designated
276 | place (gratis or for a charge), and offer equivalent access to the
277 | Corresponding Source in the same way through the same place at no
278 | further charge. You need not require recipients to copy the
279 | Corresponding Source along with the object code. If the place to
280 | copy the object code is a network server, the Corresponding Source
281 | may be on a different server (operated by you or a third party)
282 | that supports equivalent copying facilities, provided you maintain
283 | clear directions next to the object code saying where to find the
284 | Corresponding Source. Regardless of what server hosts the
285 | Corresponding Source, you remain obligated to ensure that it is
286 | available for as long as needed to satisfy these requirements.
287 |
288 | e) Convey the object code using peer-to-peer transmission, provided
289 | you inform other peers where the object code and Corresponding
290 | Source of the work are being offered to the general public at no
291 | charge under subsection 6d.
292 |
293 | A separable portion of the object code, whose source code is excluded
294 | from the Corresponding Source as a System Library, need not be
295 | included in conveying the object code work.
296 |
297 | A "User Product" is either (1) a "consumer product", which means any
298 | tangible personal property which is normally used for personal, family,
299 | or household purposes, or (2) anything designed or sold for incorporation
300 | into a dwelling. In determining whether a product is a consumer product,
301 | doubtful cases shall be resolved in favor of coverage. For a particular
302 | product received by a particular user, "normally used" refers to a
303 | typical or common use of that class of product, regardless of the status
304 | of the particular user or of the way in which the particular user
305 | actually uses, or expects or is expected to use, the product. A product
306 | is a consumer product regardless of whether the product has substantial
307 | commercial, industrial or non-consumer uses, unless such uses represent
308 | the only significant mode of use of the product.
309 |
310 | "Installation Information" for a User Product means any methods,
311 | procedures, authorization keys, or other information required to install
312 | and execute modified versions of a covered work in that User Product from
313 | a modified version of its Corresponding Source. The information must
314 | suffice to ensure that the continued functioning of the modified object
315 | code is in no case prevented or interfered with solely because
316 | modification has been made.
317 |
318 | If you convey an object code work under this section in, or with, or
319 | specifically for use in, a User Product, and the conveying occurs as
320 | part of a transaction in which the right of possession and use of the
321 | User Product is transferred to the recipient in perpetuity or for a
322 | fixed term (regardless of how the transaction is characterized), the
323 | Corresponding Source conveyed under this section must be accompanied
324 | by the Installation Information. But this requirement does not apply
325 | if neither you nor any third party retains the ability to install
326 | modified object code on the User Product (for example, the work has
327 | been installed in ROM).
328 |
329 | The requirement to provide Installation Information does not include a
330 | requirement to continue to provide support service, warranty, or updates
331 | for a work that has been modified or installed by the recipient, or for
332 | the User Product in which it has been modified or installed. Access to a
333 | network may be denied when the modification itself materially and
334 | adversely affects the operation of the network or violates the rules and
335 | protocols for communication across the network.
336 |
337 | Corresponding Source conveyed, and Installation Information provided,
338 | in accord with this section must be in a format that is publicly
339 | documented (and with an implementation available to the public in
340 | source code form), and must require no special password or key for
341 | unpacking, reading or copying.
342 |
343 | 7. Additional Terms.
344 |
345 | "Additional permissions" are terms that supplement the terms of this
346 | License by making exceptions from one or more of its conditions.
347 | Additional permissions that are applicable to the entire Program shall
348 | be treated as though they were included in this License, to the extent
349 | that they are valid under applicable law. If additional permissions
350 | apply only to part of the Program, that part may be used separately
351 | under those permissions, but the entire Program remains governed by
352 | this License without regard to the additional permissions.
353 |
354 | When you convey a copy of a covered work, you may at your option
355 | remove any additional permissions from that copy, or from any part of
356 | it. (Additional permissions may be written to require their own
357 | removal in certain cases when you modify the work.) You may place
358 | additional permissions on material, added by you to a covered work,
359 | for which you have or can give appropriate copyright permission.
360 |
361 | Notwithstanding any other provision of this License, for material you
362 | add to a covered work, you may (if authorized by the copyright holders of
363 | that material) supplement the terms of this License with terms:
364 |
365 | a) Disclaiming warranty or limiting liability differently from the
366 | terms of sections 15 and 16 of this License; or
367 |
368 | b) Requiring preservation of specified reasonable legal notices or
369 | author attributions in that material or in the Appropriate Legal
370 | Notices displayed by works containing it; or
371 |
372 | c) Prohibiting misrepresentation of the origin of that material, or
373 | requiring that modified versions of such material be marked in
374 | reasonable ways as different from the original version; or
375 |
376 | d) Limiting the use for publicity purposes of names of licensors or
377 | authors of the material; or
378 |
379 | e) Declining to grant rights under trademark law for use of some
380 | trade names, trademarks, or service marks; or
381 |
382 | f) Requiring indemnification of licensors and authors of that
383 | material by anyone who conveys the material (or modified versions of
384 | it) with contractual assumptions of liability to the recipient, for
385 | any liability that these contractual assumptions directly impose on
386 | those licensors and authors.
387 |
388 | All other non-permissive additional terms are considered "further
389 | restrictions" within the meaning of section 10. If the Program as you
390 | received it, or any part of it, contains a notice stating that it is
391 | governed by this License along with a term that is a further
392 | restriction, you may remove that term. If a license document contains
393 | a further restriction but permits relicensing or conveying under this
394 | License, you may add to a covered work material governed by the terms
395 | of that license document, provided that the further restriction does
396 | not survive such relicensing or conveying.
397 |
398 | If you add terms to a covered work in accord with this section, you
399 | must place, in the relevant source files, a statement of the
400 | additional terms that apply to those files, or a notice indicating
401 | where to find the applicable terms.
402 |
403 | Additional terms, permissive or non-permissive, may be stated in the
404 | form of a separately written license, or stated as exceptions;
405 | the above requirements apply either way.
406 |
407 | 8. Termination.
408 |
409 | You may not propagate or modify a covered work except as expressly
410 | provided under this License. Any attempt otherwise to propagate or
411 | modify it is void, and will automatically terminate your rights under
412 | this License (including any patent licenses granted under the third
413 | paragraph of section 11).
414 |
415 | However, if you cease all violation of this License, then your
416 | license from a particular copyright holder is reinstated (a)
417 | provisionally, unless and until the copyright holder explicitly and
418 | finally terminates your license, and (b) permanently, if the copyright
419 | holder fails to notify you of the violation by some reasonable means
420 | prior to 60 days after the cessation.
421 |
422 | Moreover, your license from a particular copyright holder is
423 | reinstated permanently if the copyright holder notifies you of the
424 | violation by some reasonable means, this is the first time you have
425 | received notice of violation of this License (for any work) from that
426 | copyright holder, and you cure the violation prior to 30 days after
427 | your receipt of the notice.
428 |
429 | Termination of your rights under this section does not terminate the
430 | licenses of parties who have received copies or rights from you under
431 | this License. If your rights have been terminated and not permanently
432 | reinstated, you do not qualify to receive new licenses for the same
433 | material under section 10.
434 |
435 | 9. Acceptance Not Required for Having Copies.
436 |
437 | You are not required to accept this License in order to receive or
438 | run a copy of the Program. Ancillary propagation of a covered work
439 | occurring solely as a consequence of using peer-to-peer transmission
440 | to receive a copy likewise does not require acceptance. However,
441 | nothing other than this License grants you permission to propagate or
442 | modify any covered work. These actions infringe copyright if you do
443 | not accept this License. Therefore, by modifying or propagating a
444 | covered work, you indicate your acceptance of this License to do so.
445 |
446 | 10. Automatic Licensing of Downstream Recipients.
447 |
448 | Each time you convey a covered work, the recipient automatically
449 | receives a license from the original licensors, to run, modify and
450 | propagate that work, subject to this License. You are not responsible
451 | for enforcing compliance by third parties with this License.
452 |
453 | An "entity transaction" is a transaction transferring control of an
454 | organization, or substantially all assets of one, or subdividing an
455 | organization, or merging organizations. If propagation of a covered
456 | work results from an entity transaction, each party to that
457 | transaction who receives a copy of the work also receives whatever
458 | licenses to the work the party's predecessor in interest had or could
459 | give under the previous paragraph, plus a right to possession of the
460 | Corresponding Source of the work from the predecessor in interest, if
461 | the predecessor has it or can get it with reasonable efforts.
462 |
463 | You may not impose any further restrictions on the exercise of the
464 | rights granted or affirmed under this License. For example, you may
465 | not impose a license fee, royalty, or other charge for exercise of
466 | rights granted under this License, and you may not initiate litigation
467 | (including a cross-claim or counterclaim in a lawsuit) alleging that
468 | any patent claim is infringed by making, using, selling, offering for
469 | sale, or importing the Program or any portion of it.
470 |
471 | 11. Patents.
472 |
473 | A "contributor" is a copyright holder who authorizes use under this
474 | License of the Program or a work on which the Program is based. The
475 | work thus licensed is called the contributor's "contributor version".
476 |
477 | A contributor's "essential patent claims" are all patent claims
478 | owned or controlled by the contributor, whether already acquired or
479 | hereafter acquired, that would be infringed by some manner, permitted
480 | by this License, of making, using, or selling its contributor version,
481 | but do not include claims that would be infringed only as a
482 | consequence of further modification of the contributor version. For
483 | purposes of this definition, "control" includes the right to grant
484 | patent sublicenses in a manner consistent with the requirements of
485 | this License.
486 |
487 | Each contributor grants you a non-exclusive, worldwide, royalty-free
488 | patent license under the contributor's essential patent claims, to
489 | make, use, sell, offer for sale, import and otherwise run, modify and
490 | propagate the contents of its contributor version.
491 |
492 | In the following three paragraphs, a "patent license" is any express
493 | agreement or commitment, however denominated, not to enforce a patent
494 | (such as an express permission to practice a patent or covenant not to
495 | sue for patent infringement). To "grant" such a patent license to a
496 | party means to make such an agreement or commitment not to enforce a
497 | patent against the party.
498 |
499 | If you convey a covered work, knowingly relying on a patent license,
500 | and the Corresponding Source of the work is not available for anyone
501 | to copy, free of charge and under the terms of this License, through a
502 | publicly available network server or other readily accessible means,
503 | then you must either (1) cause the Corresponding Source to be so
504 | available, or (2) arrange to deprive yourself of the benefit of the
505 | patent license for this particular work, or (3) arrange, in a manner
506 | consistent with the requirements of this License, to extend the patent
507 | license to downstream recipients. "Knowingly relying" means you have
508 | actual knowledge that, but for the patent license, your conveying the
509 | covered work in a country, or your recipient's use of the covered work
510 | in a country, would infringe one or more identifiable patents in that
511 | country that you have reason to believe are valid.
512 |
513 | If, pursuant to or in connection with a single transaction or
514 | arrangement, you convey, or propagate by procuring conveyance of, a
515 | covered work, and grant a patent license to some of the parties
516 | receiving the covered work authorizing them to use, propagate, modify
517 | or convey a specific copy of the covered work, then the patent license
518 | you grant is automatically extended to all recipients of the covered
519 | work and works based on it.
520 |
521 | A patent license is "discriminatory" if it does not include within
522 | the scope of its coverage, prohibits the exercise of, or is
523 | conditioned on the non-exercise of one or more of the rights that are
524 | specifically granted under this License. You may not convey a covered
525 | work if you are a party to an arrangement with a third party that is
526 | in the business of distributing software, under which you make payment
527 | to the third party based on the extent of your activity of conveying
528 | the work, and under which the third party grants, to any of the
529 | parties who would receive the covered work from you, a discriminatory
530 | patent license (a) in connection with copies of the covered work
531 | conveyed by you (or copies made from those copies), or (b) primarily
532 | for and in connection with specific products or compilations that
533 | contain the covered work, unless you entered into that arrangement,
534 | or that patent license was granted, prior to 28 March 2007.
535 |
536 | Nothing in this License shall be construed as excluding or limiting
537 | any implied license or other defenses to infringement that may
538 | otherwise be available to you under applicable patent law.
539 |
540 | 12. No Surrender of Others' Freedom.
541 |
542 | If conditions are imposed on you (whether by court order, agreement or
543 | otherwise) that contradict the conditions of this License, they do not
544 | excuse you from the conditions of this License. If you cannot convey a
545 | covered work so as to satisfy simultaneously your obligations under this
546 | License and any other pertinent obligations, then as a consequence you may
547 | not convey it at all. For example, if you agree to terms that obligate you
548 | to collect a royalty for further conveying from those to whom you convey
549 | the Program, the only way you could satisfy both those terms and this
550 | License would be to refrain entirely from conveying the Program.
551 |
552 | 13. Use with the GNU Affero General Public License.
553 |
554 | Notwithstanding any other provision of this License, you have
555 | permission to link or combine any covered work with a work licensed
556 | under version 3 of the GNU Affero General Public License into a single
557 | combined work, and to convey the resulting work. The terms of this
558 | License will continue to apply to the part which is the covered work,
559 | but the special requirements of the GNU Affero General Public License,
560 | section 13, concerning interaction through a network will apply to the
561 | combination as such.
562 |
563 | 14. Revised Versions of this License.
564 |
565 | The Free Software Foundation may publish revised and/or new versions of
566 | the GNU General Public License from time to time. Such new versions will
567 | be similar in spirit to the present version, but may differ in detail to
568 | address new problems or concerns.
569 |
570 | Each version is given a distinguishing version number. If the
571 | Program specifies that a certain numbered version of the GNU General
572 | Public License "or any later version" applies to it, you have the
573 | option of following the terms and conditions either of that numbered
574 | version or of any later version published by the Free Software
575 | Foundation. If the Program does not specify a version number of the
576 | GNU General Public License, you may choose any version ever published
577 | by the Free Software Foundation.
578 |
579 | If the Program specifies that a proxy can decide which future
580 | versions of the GNU General Public License can be used, that proxy's
581 | public statement of acceptance of a version permanently authorizes you
582 | to choose that version for the Program.
583 |
584 | Later license versions may give you additional or different
585 | permissions. However, no additional obligations are imposed on any
586 | author or copyright holder as a result of your choosing to follow a
587 | later version.
588 |
589 | 15. Disclaimer of Warranty.
590 |
591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599 |
600 | 16. Limitation of Liability.
601 |
602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610 | SUCH DAMAGES.
611 |
612 | 17. Interpretation of Sections 15 and 16.
613 |
614 | If the disclaimer of warranty and limitation of liability provided
615 | above cannot be given local legal effect according to their terms,
616 | reviewing courts shall apply local law that most closely approximates
617 | an absolute waiver of all civil liability in connection with the
618 | Program, unless a warranty or assumption of liability accompanies a
619 | copy of the Program in return for a fee.
620 |
621 | END OF TERMS AND CONDITIONS
622 |
623 | How to Apply These Terms to Your New Programs
624 |
625 | If you develop a new program, and you want it to be of the greatest
626 | possible use to the public, the best way to achieve this is to make it
627 | free software which everyone can redistribute and change under these terms.
628 |
629 | To do so, attach the following notices to the program. It is safest
630 | to attach them to the start of each source file to most effectively
631 | state the exclusion of warranty; and each file should have at least
632 | the "copyright" line and a pointer to where the full notice is found.
633 |
634 | <one line to give the program's name and a brief idea of what it does.>
635 | Copyright (C) <year>  <name of author>
636 |
637 | This program is free software: you can redistribute it and/or modify
638 | it under the terms of the GNU General Public License as published by
639 | the Free Software Foundation, either version 3 of the License, or
640 | (at your option) any later version.
641 |
642 | This program is distributed in the hope that it will be useful,
643 | but WITHOUT ANY WARRANTY; without even the implied warranty of
644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645 | GNU General Public License for more details.
646 |
647 | You should have received a copy of the GNU General Public License
648 | along with this program.  If not, see <https://www.gnu.org/licenses/>.
649 |
650 | Also add information on how to contact you by electronic and paper mail.
651 |
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 |
655 | <program>  Copyright (C) <year>  <name of author>
656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 | This is free software, and you are welcome to redistribute it
658 | under certain conditions; type `show c' for details.
659 |
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License. Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 |
664 | You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | <https://www.gnu.org/licenses/>.
668 |
669 | The GNU General Public License does not permit incorporating your program
670 | into proprietary programs. If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library. If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License.  But first, please read
674 | <https://www.gnu.org/licenses/why-not-lgpl.html>.
675 |
--------------------------------------------------------------------------------
/datasets/CRVD_seq.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import os
4 | import torch
5 | from torch.utils.data.dataset import Dataset
6 |
7 | iso_list = [1600, 3200, 6400, 12800, 25600]
8 | a_list = [3.513262, 6.955588, 13.486051, 26.585953, 52.032536]
9 | g_noise_var_list = [11.917691, 38.117816, 130.818508, 484.539790, 1819.818657]
10 |
11 | def pack_gbrg_raw_torch(raw): # T H W
12 | T, H, W = raw.shape
13 | im = raw.unsqueeze(1)
14 |
15 | out = torch.cat((im[:, :, 1:H:2, 0:W:2],
16 | im[:, :, 1:H:2, 1:W:2],
17 | im[:, :, 0:H:2, 0:W:2],
18 | im[:, :, 0:H:2, 1:W:2]), dim=1)
19 | return out
20 |
21 | def normalize_raw_torch(raw):
22 | black_level = 240
23 | white_level = 2 ** 12 - 1
24 | raw = torch.clamp(raw.type(torch.float32) - black_level, 0) / (white_level - black_level)
25 | return raw
26 |
27 | def open_CRVD_seq_raw(seq_path, file_pattern='frame%d_noisy0.tiff'):
28 | frame_list = []
29 | for i in range(7):
30 | raw = cv2.imread(os.path.join(seq_path, file_pattern % (i+1)), -1)
31 | raw = np.asarray(raw)
32 | raw = np.expand_dims(raw, axis=0)
33 | frame_list.append(raw)
34 | seq = np.concatenate(frame_list, axis=0)
35 | return seq
36 |
37 | def open_CRVD_seq_raw_outdoor(seq_path, file_pattern='frame%d_noisy0.tiff'):
38 | frame_list = []
39 | for i in range(50):
40 | raw = cv2.imread(os.path.join(seq_path, file_pattern % i), -1)
41 | raw = np.asarray(raw)
42 | raw = np.expand_dims(raw, axis=0)
43 | frame_list.append(raw)
44 | seq = np.concatenate(frame_list, axis=0)
45 | return seq
46 |
47 | def crop_position(patch_size, H, W):
48 | position_h = np.random.randint(0, (H - patch_size)//2 - 1) * 2
49 | position_w = np.random.randint(0, (W - patch_size)//2 - 1) * 2
50 | aug = np.random.randint(0, 8)
51 | return position_h, position_w, aug
52 |
53 | def aug_crop(img, patch_size, position_h, position_w, aug):
54 | patch = img[:, position_h:position_h + patch_size + 2, position_w:position_w + patch_size + 2]
55 |
56 | if aug == 0:
57 | patch = patch[:, :-2, :-2]
58 | elif aug == 1:
59 | patch = np.flip(patch, axis=1)
60 | patch = patch[:, 1:-1, :-2]
61 | elif aug == 2:
62 | patch = np.flip(np.flip(patch, axis=1), axis=2)
63 | patch = patch[:, 1:-1, 1:-1]
64 | elif aug == 3:
65 | patch = np.flip(patch, axis=2)
66 | patch = patch[:, :-2, 1:-1]
67 | elif aug == 4:
68 | patch = np.transpose(np.flip(patch, axis=2), (0, 2, 1))
69 | patch = patch[:, :-2, 1:-1]
70 | elif aug == 5:
71 | patch = np.transpose(np.flip(np.flip(patch, axis=1), axis=2), (0, 2, 1))
72 | patch = patch[:, :-2, :-2]
73 | elif aug == 6:
74 | patch = np.transpose(patch, (0, 2, 1))
75 | patch = patch[:, 1:-1, 1:-1]
76 | elif aug == 7:
77 | patch = np.transpose(np.flip(patch, axis=1), (0, 2, 1))
78 | patch = patch[:, 1:-1, :-2]
79 | return patch
80 |
81 |
82 | class CRVDTrainDataset(Dataset):
83 | def __init__(self, CRVD_path, patch_size, patches_per_epoch, mirror_seq=True):
84 | self.CRVD_path = CRVD_path
85 | self.patches_per_epoch = patches_per_epoch
86 | self.patch_size = patch_size * 2
87 | self.mirror_seq = mirror_seq
88 | self.scene_id_list = [1, 2, 3, 4, 5, 6]
89 | self.seqs = {}
90 |
91 | for iso in iso_list:
92 | for scene_id in self.scene_id_list:
93 | self.seqs['%d_%d_clean' % (iso, scene_id)] = open_CRVD_seq_raw(os.path.join(self.CRVD_path, 'indoor_raw_gt/scene%d/ISO%d' % (scene_id, iso)),
94 | 'frame%d_clean_and_slightly_denoised.tiff')
95 | for i in range(10):
96 | self.seqs['%d_%d_noisy_%d' % (iso, scene_id, i)] = open_CRVD_seq_raw(os.path.join(self.CRVD_path, 'indoor_raw_noisy/scene%d/ISO%d' % (scene_id, iso)),
97 | 'frame%d_noisy{}.tiff'.format(i))
98 |
99 | def __getitem__(self, index):
100 | index = index % (len(iso_list) * len(self.scene_id_list) * 10)
101 | iso_index = index // (len(self.scene_id_list) * 10)
102 | scene_index = (index - iso_index * len(self.scene_id_list) * 10) // 10
103 | noisy_index = index % 10
104 | iso = iso_list[iso_index]
105 | scene_id = self.scene_id_list[scene_index]
106 |
107 | seq = self.seqs['%d_%d_clean' % (iso, scene_id)]
108 | seqn = self.seqs['%d_%d_noisy_%d' % (iso, scene_id, noisy_index)]
109 | T, H, W = seq.shape
110 | position_h, position_w, aug = crop_position(self.patch_size, H, W)
111 | seq = aug_crop(seq, self.patch_size, position_h, position_w, aug)
112 | seqn = aug_crop(seqn, self.patch_size, position_h, position_w, aug)
113 | clean_list, noisy_list = [], []
114 | for i in range(T):
115 | clean_list.append(np.expand_dims(seq[i], axis=0))
116 | noisy_list.append(np.expand_dims(seqn[i], axis=0))
117 | seq = torch.from_numpy(np.concatenate(clean_list, axis=0).astype(np.int32))
118 | seqn = torch.from_numpy(np.concatenate(noisy_list, axis=0).astype(np.int32))
119 | seq = normalize_raw_torch(pack_gbrg_raw_torch(seq))
120 | seqn = normalize_raw_torch(pack_gbrg_raw_torch(seqn))
121 |
122 | if self.mirror_seq:
123 | seq = torch.cat((seq, torch.flip(seq, dims=[0])), dim=0)
124 | seqn = torch.cat((seqn, torch.flip(seqn, dims=[0])), dim=0)
125 |
126 | a = torch.tensor(a_list[iso_index], dtype=torch.float32).view((1, 1, 1, 1)) / (2 ** 12 - 1 - 240)
127 | b = torch.tensor(g_noise_var_list[iso_index], dtype=torch.float32).view((1, 1, 1, 1)) / ((2 ** 12 - 1 - 240) ** 2)
128 |
129 | return {'seq': seq,
130 | 'seqn': seqn,
131 | 'a': a, 'b': b}
132 |
133 | def __len__(self):
134 | return self.patches_per_epoch
135 |
136 | class CRVDTestDataset(Dataset):
137 | def __init__(self, CRVD_path):
138 | self.CRVD_path = CRVD_path
139 | self.scene_id_list = [7, 8, 9, 10, 11]
140 | self.seqs = {}
141 |
142 | for iso in iso_list:
143 | for scene_id in self.scene_id_list:
144 | self.seqs['%d_%d_clean' % (iso, scene_id)] = open_CRVD_seq_raw(os.path.join(self.CRVD_path, 'indoor_raw_gt/scene%d/ISO%d' % (scene_id, iso)),
145 | 'frame%d_clean_and_slightly_denoised.tiff')
146 | self.seqs['%d_%d_noisy' % (iso, scene_id)] = open_CRVD_seq_raw(os.path.join(self.CRVD_path, 'indoor_raw_noisy/scene%d/ISO%d' % (scene_id, iso)),
147 | 'frame%d_noisy0.tiff')
148 |
149 | def __getitem__(self, index):
150 | iso = iso_list[index // len(self.scene_id_list)]
151 | scene_id = self.scene_id_list[index % len(self.scene_id_list)]
152 |
153 | seq = torch.from_numpy(self.seqs['%d_%d_clean' % (iso, scene_id)].astype(np.float32))
154 | seqn = torch.from_numpy(self.seqs['%d_%d_noisy' % (iso, scene_id)].astype(np.float32))
155 | seq = normalize_raw_torch(pack_gbrg_raw_torch(seq))
156 | seqn = normalize_raw_torch(pack_gbrg_raw_torch(seqn))
157 | a = torch.tensor(a_list[index // len(self.scene_id_list)], dtype=torch.float32).view((1, 1, 1, 1)) / (2 ** 12 - 1 - 240)
158 | b = torch.tensor(g_noise_var_list[index // len(self.scene_id_list)], dtype=torch.float32).view((1, 1, 1, 1)) / ((2 ** 12 - 1 - 240) ** 2)
159 |
160 | return {'seq': seq,
161 | 'seqn': seqn,
162 | 'iso': iso, 'a': a, 'b': b, 'scene_id': scene_id}
163 |
164 | def __len__(self):
165 | return len(iso_list) * len(self.scene_id_list)
166 |
167 | class CRVDOurdoorDataset(Dataset):
168 | def __init__(self, CRVD_path):
169 | self.CRVD_path = CRVD_path
170 | self.scene_id_list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
171 | self.seqs = {}
172 |
173 | self.iso = 25600
174 | for scene_id in self.scene_id_list:
175 | self.seqs['%d_%d_noisy' % (self.iso, scene_id)] = open_CRVD_seq_raw_outdoor(os.path.join(self.CRVD_path, 'outdoor_raw_noisy/scene%d/iso%d' % (scene_id, self.iso)),
176 | 'frame%d.tiff')
177 |
178 | def __getitem__(self, index):
179 | scene_id = self.scene_id_list[index]
180 |
181 | seqn = torch.from_numpy(self.seqs['%d_%d_noisy' % (self.iso, scene_id)].astype(np.float32))
182 | seqn = normalize_raw_torch(pack_gbrg_raw_torch(seqn))
183 | a = torch.tensor(a_list[4], dtype=torch.float32).view((1, 1, 1, 1)) / (2 ** 12 - 1 - 240)
184 | b = torch.tensor(g_noise_var_list[4], dtype=torch.float32).view((1, 1, 1, 1)) / ((2 ** 12 - 1 - 240) ** 2)
185 |
186 | return {'seqn': seqn,
187 | 'iso': self.iso, 'a': a, 'b': b, 'scene_id': scene_id}
188 |
189 | def __len__(self):
190 | return len(self.scene_id_list)
--------------------------------------------------------------------------------
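As a quick sanity check of the GBRG packing and normalization logic in `CRVD_seq.py` above, here is a dependency-free sketch (plain Python lists instead of tensors; the demo mosaic values are made up for illustration):

```python
# Minimal re-implementation of the 2x2 -> 4-channel GBRG packing used in
# pack_gbrg_raw_torch, on a nested list so it runs without torch.
# The channel order mirrors the original slicing: rows 1::2 / cols 0::2,
# rows 1::2 / cols 1::2, rows 0::2 / cols 0::2, rows 0::2 / cols 1::2.

def pack_gbrg(raw):
    """raw: H x W nested list with even H and W. Returns a list of
    4 channels, each of shape H/2 x W/2."""
    H, W = len(raw), len(raw[0])
    channels = []
    for (r0, c0) in [(1, 0), (1, 1), (0, 0), (0, 1)]:
        channels.append([[raw[r][c] for c in range(c0, W, 2)]
                         for r in range(r0, H, 2)])
    return channels

def normalize(v, black=240, white=2 ** 12 - 1):
    # clamp at the black level, then scale to [0, 1],
    # mirroring normalize_raw_torch
    return max(v - black, 0) / (white - black)

demo = [[ 0,  1,  2,  3],
        [10, 11, 12, 13],
        [20, 21, 22, 23],
        [30, 31, 32, 33]]
packed = pack_gbrg(demo)
# packed[0] picks odd rows / even cols: [[10, 12], [30, 32]]
```

Note that `crop_position` returns even offsets and `aug_crop` works on the unpacked mosaic for the same reason: crops must stay aligned to the 2x2 Bayer tile or the channel meaning of each packed plane would flip.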
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from datasets.CRVD_seq import CRVDTrainDataset, CRVDTestDataset, CRVDOurdoorDataset
2 | from datasets.sRGB_seq import SrgbTrainDataset, SrgbValDataset
--------------------------------------------------------------------------------
/datasets/sRGB_seq.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import numpy as np
3 | import os
4 | import torch
5 | from torch.utils.data.dataset import Dataset
6 | from utils.fastdvdnet_utils import open_sequence
7 | from utils.io import list_dir, open_images_uint8
8 |
9 |
10 |
11 | class SrgbTrainDataset(Dataset):
12 | def __init__(self, seq_dir, train_length, patch_size, patches_per_epoch, temp_stride=3, image_postfix='png', pin_memory=False):
13 | self.seq_dir = seq_dir
14 | self.train_length = train_length
15 | self.patch_size = patch_size
16 | self.patches_per_epoch = patches_per_epoch
17 | self.temp_stride = temp_stride
18 | self.pin_memory = pin_memory
19 |
20 | self.seq_names = list_dir(seq_dir)
21 | self.seqs = {}
22 | for seq_name in self.seq_names:
23 | self.seqs[seq_name] = {}
24 | self.seqs[seq_name]['clean_image_files'] = list_dir(os.path.join(self.seq_dir, seq_name),
25 | postfix=image_postfix, full_path=True)
26 | if self.pin_memory:
27 | self.seqs[seq_name]['clean_images'] = open_images_uint8(self.seqs[seq_name]['clean_image_files'])
28 |
29 | self.seq_count = []
30 | for i in range(len(self.seq_names)):
31 | count = (len(self.seqs[self.seq_names[i]]['clean_image_files']) - self.train_length + self.temp_stride) // self.temp_stride
32 | self.seq_count.append(count)
33 | self.seq_count_cum = np.cumsum(self.seq_count)
34 |
35 | def __getitem__(self, index):
36 | if self.patches_per_epoch is not None:
37 | index = index % self.seq_count_cum[-1]
38 | for i in range(len(self.seq_count_cum)):
39 | if index < self.seq_count_cum[i]:
40 | seq_name = self.seq_names[i]
41 | seq_index = index if i == 0 else index - self.seq_count_cum[i - 1]
42 | break
43 | center_frame_index = seq_index * self.temp_stride + (self.train_length//2)
44 | if self.pin_memory:
45 | clean_images = self.seqs[seq_name]['clean_images']
46 | else:
47 | clean_images = open_images_uint8(self.seqs[seq_name]['clean_image_files'])
48 | data = clean_images[center_frame_index - (self.train_length // 2):center_frame_index +
49 | (self.train_length // 2) + (self.train_length % 2)]
50 |
51 | # crop patches
52 | num_frames, C, H, W = data.shape
53 | position_H = np.random.randint(0, H - self.patch_size + 1)
54 | position_W = np.random.randint(0, W - self.patch_size + 1)
55 | data = data[:, :, position_H:position_H+self.patch_size, position_W:position_W+self.patch_size]
56 |
57 | return_dict = {'data':data}
58 | return return_dict
59 |
60 | def __len__(self):
61 | if self.patches_per_epoch is None:
62 | return self.seq_count_cum[-1]
63 | else:
64 | return self.patches_per_epoch
65 |
66 | """
67 | Dataset related functions
68 | Copyright (C) 2018, Matias Tassano
69 | This program is free software: you can use, modify and/or
70 | redistribute it under the terms of the GNU General Public
71 | License as published by the Free Software Foundation, either
72 | version 3 of the License, or (at your option) any later
73 | version. You should have received a copy of this license along
74 | this program. If not, see <http://www.gnu.org/licenses/>.
75 | """
76 |
77 | NUMFRXSEQ_VAL = 85 # number of frames of each sequence to include in validation dataset
78 | VALSEQPATT = '*' # pattern for name of validation sequence
79 |
80 | class SrgbValDataset(Dataset):
81 | """Validation dataset. Loads all the images in the dataset folder on memory.
82 | """
83 | def __init__(self, valsetdir, gray_mode=False, num_input_frames=NUMFRXSEQ_VAL):
84 | self.gray_mode = gray_mode
85 |
86 | # Look for subdirs with individual sequences
87 | seqs_dirs = sorted(glob.glob(os.path.join(valsetdir, VALSEQPATT)))
88 |
89 | # open individual sequences and append them to the sequence list
90 | sequences = []
91 | for seq_dir in seqs_dirs:
92 | seq, _, _ = open_sequence(seq_dir, gray_mode, expand_if_needed=False, \
93 | max_num_fr=num_input_frames)
94 | # seq is [num_frames, C, H, W]
95 | sequences.append(seq)
96 |
97 | self.seqs_dirs = seqs_dirs
98 | self.sequences = sequences
99 |
100 | def __getitem__(self, index):
101 | return {'seq':torch.from_numpy(self.sequences[index]), 'name':self.seqs_dirs[index]}
102 |
103 | def __len__(self):
104 | return len(self.sequences)
--------------------------------------------------------------------------------
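The index arithmetic in `SrgbTrainDataset.__getitem__` above (cumulative clip counts, wrap-around when `patches_per_epoch` oversamples, then stride-based center-frame lookup) can be sketched standalone; the sequence lengths below are illustrative, not from the dataset:

```python
# Standalone sketch of the flat-index -> (sequence id, center frame)
# mapping used by SrgbTrainDataset. Names are hypothetical.
from itertools import accumulate

def locate(index, seq_lengths, train_length, temp_stride):
    """Map a flat patch index to (sequence index, center frame index)."""
    # clips that fit in each sequence, as in the __init__ loop above
    counts = [(n - train_length + temp_stride) // temp_stride
              for n in seq_lengths]
    cum = list(accumulate(counts))
    index = index % cum[-1]  # wrap when patches_per_epoch > total clips
    for i, c in enumerate(cum):
        if index < c:
            seq_index = index if i == 0 else index - cum[i - 1]
            center = seq_index * temp_stride + train_length // 2
            return i, center

# Two sequences of 10 frames, clips of length 5, stride 3:
# each sequence yields (10 - 5 + 3) // 3 = 2 clips, so 4 clips total.
```

With those numbers, indices 0 and 1 land in sequence 0 (centers 2 and 5), indices 2 and 3 land in sequence 1, and index 4 wraps back to the start.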
/imgs/figure_overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nagejacob/FloRNN/5419715af261bf1d619818baaf26708b81781f4a/imgs/figure_overview.png
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from models.rvidenet.isp import ISP
2 | from models.basicvsr_plusplus import BasicVSRPlusPlus
3 | from models.birnn import BiRNN
4 | from models.flornn import FloRNN
5 | from models.flornn_raw import FloRNNRaw
6 | from models.forwardrnn import ForwardRNN
--------------------------------------------------------------------------------
/models/basicvsr_plusplus.py:
--------------------------------------------------------------------------------
1 | from mmcv.cnn import constant_init
2 | from mmcv.ops import ModulatedDeformConv2d, modulated_deform_conv2d
3 | from models.components import ResBlocks, D
4 | from pytorch_pwc.extract_flow import extract_flow_torch
5 | from pytorch_pwc.pwc import PWCNet
6 | import torch
7 | import torch.nn as nn
8 | from utils.warp import warp
9 |
10 | class BasicVSRPlusPlus(nn.Module):
11 | def __init__(self, img_channels=3, spatial_blocks=-1, temporal_blocks=-1, num_channels=64):
12 | super(BasicVSRPlusPlus, self).__init__()
13 | self.num_channels = num_channels
14 | self.pwcnet = PWCNet()
15 |
16 | self.feat_extract = ResBlocks(input_channels=img_channels * 2, num_resblocks=spatial_blocks, num_channels=num_channels)
17 |
18 | self.backbone = nn.ModuleDict()
19 | self.deform_align = nn.ModuleDict()
20 | self.module_names = ['forward_1', 'backward_1', 'forward_2', 'backward_2']
21 | for i, module_name in enumerate(self.module_names):
22 | self.backbone[module_name] = ResBlocks(input_channels=num_channels * (i+2), num_resblocks=temporal_blocks, num_channels=num_channels)
23 | self.deform_align[module_name] = SecondOrderDeformableAlignment(
24 | 2 * num_channels,
25 | num_channels,
26 | 3,
27 | padding=1,
28 | deform_groups=16,
29 | max_residue_magnitude=10)
30 |
31 | self.d = D(in_channels=num_channels * 4, mid_channels=num_channels * 2, out_channels=img_channels)
32 | self.device = torch.device('cuda')
33 |
34 | def trainable_parameters(self):
35 | return [{'params':self.feat_extract.parameters()}, {'params':self.backbone.parameters()},
36 | {'params':self.deform_align.parameters()}, {'params':self.d.parameters()}]
37 |
38 | def spatial_feature(self, seqn, noise_level_map):
39 | spatial_hs = []
40 | for i in range(seqn.shape[1]):
41 | spatial_h = self.feat_extract(torch.cat((seqn[:, i].cuda(), noise_level_map[:, i].cuda()), dim=1))
42 | if not self.training:
43 | spatial_h = spatial_h.cpu()
44 | spatial_hs.append(spatial_h)
45 | return spatial_hs
46 |
47 | def extract_flows(self, seqn):
48 | N, T, C, H, W = seqn.shape
49 | forward_flows, backward_flows = [], []
50 | for i in range(T-1):
51 | forward_flow = extract_flow_torch(self.pwcnet, seqn[:, i+1].cuda(), seqn[:, i].cuda())
52 | backward_flow = extract_flow_torch(self.pwcnet, seqn[:, i].cuda(), seqn[:, i+1].cuda())
53 | if not self.training:
54 | forward_flow = forward_flow.cpu()
55 | backward_flow = backward_flow.cpu()
56 | forward_flows.append(forward_flow)
57 | backward_flows.append(backward_flow)
58 | return forward_flows, backward_flows
59 |
60 | def forward(self, seqn, noise_level_map):
61 | if self.training:
62 | self.device = torch.device('cuda')
63 | return self.forward_train(seqn, noise_level_map)
64 | else:
65 | self.device = torch.device('cpu')
66 | return self.forward_test(seqn, noise_level_map)
67 |
68 | def forward_train(self, seqn, noise_level_map):
69 | N, T, C, H, W = seqn.shape
70 | hs = {}
71 | for module_name in self.module_names:
72 | hs[module_name] = [None] * T
73 | zeros_h = torch.zeros((N, self.num_channels, H, W), device=seqn.device)
74 | zeros_flow = torch.zeros((N, 2, H, W), device=seqn.device)
75 | seqdn = torch.empty_like(seqn)
76 |
77 | # extract flows
78 | forward_flows, backward_flows = self.extract_flows(seqn)
79 |
80 | # extract spatial features
81 | hs['spatial'] = self.spatial_feature(seqn, noise_level_map)
82 |
83 | # extract forward features
84 | spatial_h = hs['spatial'][0]
85 | forward_h = self.backbone['forward_1'](torch.cat((spatial_h, zeros_h), dim=1))
86 | hs['forward_1'][0] = forward_h
87 |
88 | spatial_h = hs['spatial'][1]
89 | flow_n1 = forward_flows[0]
90 | forward_h_n1 = forward_h
91 | aligned_forward_h_n1, _ = warp(forward_h_n1, flow_n1)
92 | feat_prop = self.deform_align['forward_1'](torch.cat((forward_h_n1, zeros_h), dim=1),
93 | torch.cat((aligned_forward_h_n1, spatial_h, zeros_h), dim=1),
94 | flow_n1, zeros_flow)
95 | forward_h = self.backbone['forward_1'](torch.cat((spatial_h, feat_prop), dim=1))
96 | hs['forward_1'][1] = forward_h
97 |
98 | for i in range(2, T):
99 | spatial_h = hs['spatial'][i]
100 | flow_n1 = forward_flows[i - 1]
101 | forward_h_n1 = forward_h
102 | aligned_forward_h_n1, _ = warp(forward_h_n1, flow_n1)
103 | flow_n2 = flow_n1 + warp(forward_flows[i - 2], flow_n1)[0]
104 | forward_h_n2 = hs['forward_1'][i - 2]
105 | aligned_forward_h_n2, _ = warp(forward_h_n2, flow_n2)
106 | feat_prop = self.deform_align['forward_1'](torch.cat((forward_h_n1, forward_h_n2), dim=1), torch.cat(
107 | (aligned_forward_h_n1, spatial_h, aligned_forward_h_n2), dim=1), flow_n1, flow_n2)
108 | forward_h = self.backbone['forward_1'](torch.cat((spatial_h, feat_prop), dim=1))
109 | hs['forward_1'][i] = forward_h
110 |
111 | # extract backward features
112 | spatial_h = hs['spatial'][-1]
113 | backward_h = self.backbone['backward_1'](torch.cat((spatial_h, zeros_h, hs['forward_1'][-1]), dim=1))
114 | hs['backward_1'][-1] = backward_h
115 |
116 | spatial_h = hs['spatial'][-2]
117 | flow_p1 = backward_flows[-1]
118 | backward_h_p1 = backward_h
119 | aligned_backward_h_p1, _ = warp(backward_h_p1, flow_p1)
120 | feat_prop = self.deform_align['backward_1'](torch.cat((backward_h_p1, zeros_h), dim=1),
121 | torch.cat((aligned_backward_h_p1, spatial_h, zeros_h), dim=1),
122 | flow_p1, zeros_flow)
123 | backward_h = self.backbone['backward_1'](torch.cat((spatial_h, feat_prop, hs['forward_1'][-2]), dim=1))
124 | hs['backward_1'][-2] = backward_h
125 |
126 | for i in range(3, T + 1):
127 | spatial_h = hs['spatial'][T - i]
128 | flow_p1 = backward_flows[T - i]
129 | backward_h_p1 = backward_h
130 | aligned_backward_h_p1, _ = warp(backward_h_p1, flow_p1)
131 | flow_p2 = flow_p1 + warp(backward_flows[T - i + 1], flow_p1)[0]
132 | backward_h_p2 = hs['backward_1'][T - i + 1]
133 | aligned_backward_h_p2, _ = warp(backward_h_p2, flow_p2)
134 | feat_prop = self.deform_align['backward_1'](torch.cat((backward_h_p1, backward_h_p2), dim=1),
135 | torch.cat((aligned_backward_h_p1, spatial_h, backward_h_p2),
136 | dim=1), flow_p1, flow_p2)
137 | backward_h = self.backbone['backward_1'](
138 | torch.cat((spatial_h, feat_prop, hs['forward_1'][T - i]), dim=1))
139 | hs['backward_1'][T - i] = backward_h
140 |
141 | # extract forward features
142 | spatial_h = hs['spatial'][0]
143 | forward_h = self.backbone['forward_2'](torch.cat((spatial_h, zeros_h,
144 | hs['forward_1'][0],
145 | hs['backward_1'][0]), dim=1))
146 | hs['forward_2'][0] = forward_h
147 |
148 | spatial_h = hs['spatial'][1]
149 | flow_n1 = forward_flows[0]
150 | forward_h_n1 = forward_h
151 | aligned_forward_h_n1, _ = warp(forward_h_n1, flow_n1)
152 | feat_prop = self.deform_align['forward_2'](torch.cat((forward_h_n1, zeros_h), dim=1),
153 | torch.cat((aligned_forward_h_n1, spatial_h, zeros_h), dim=1),
154 | flow_n1, zeros_flow)
155 | forward_h = self.backbone['forward_2'](
156 | torch.cat((spatial_h, feat_prop, hs['forward_1'][1], hs['backward_1'][1]), dim=1))
157 | hs['forward_2'][1] = forward_h
158 |
159 | for i in range(2, T):
160 | spatial_h = hs['spatial'][i]
161 | flow_n1 = forward_flows[i - 1]
162 | forward_h_n1 = forward_h
163 | aligned_forward_h_n1, _ = warp(forward_h_n1, flow_n1)
164 | flow_n2 = flow_n1 + warp(forward_flows[i - 2], flow_n1)[0]
165 | forward_h_n2 = hs['forward_2'][i - 2]
166 | aligned_forward_h_n2, _ = warp(forward_h_n2, flow_n2)
167 | feat_prop = self.deform_align['forward_2'](torch.cat((forward_h_n1, forward_h_n2), dim=1), torch.cat(
168 | (aligned_forward_h_n1, spatial_h, aligned_forward_h_n2), dim=1), flow_n1, flow_n2)
169 | forward_h = self.backbone['forward_2'](
170 | torch.cat((spatial_h, feat_prop, hs['forward_1'][i], hs['backward_1'][i]), dim=1))
171 | hs['forward_2'][i] = forward_h
172 |
173 | # extract backward features
174 | spatial_h = hs['spatial'][-1]
175 | backward_h = self.backbone['backward_2'](
176 | torch.cat((spatial_h, zeros_h, hs['forward_1'][-1], hs['backward_1'][-1], hs['forward_2'][-1]),
177 | dim=1))
178 | hs['backward_2'][-1] = backward_h
179 |
180 | spatial_h = hs['spatial'][-2]
181 | flow_p1 = backward_flows[-1]
182 | backward_h_p1 = backward_h
183 | aligned_backward_h_p1, _ = warp(backward_h_p1, flow_p1)
184 | feat_prop = self.deform_align['backward_2'](torch.cat((backward_h_p1, zeros_h), dim=1),
185 | torch.cat((aligned_backward_h_p1, spatial_h, zeros_h), dim=1),
186 | flow_p1, zeros_flow)
187 | backward_h = self.backbone['backward_2'](
188 | torch.cat((spatial_h, feat_prop, hs['forward_1'][-2], hs['backward_1'][-2], hs['forward_2'][-2]),
189 | dim=1))
190 | hs['backward_2'][-2] = backward_h
191 |
192 | for i in range(3, T + 1):
193 | spatial_h = hs['spatial'][T - i]
194 | flow_p1 = backward_flows[T - i]
195 | backward_h_p1 = backward_h
196 | aligned_backward_h_p1, _ = warp(backward_h_p1, flow_p1)
197 | flow_p2 = flow_p1 + warp(backward_flows[T - i + 1], flow_p1)[0]
198 | backward_h_p2 = hs['backward_2'][T - i + 1]
199 | aligned_backward_h_p2, _ = warp(backward_h_p2, flow_p2)
200 | feat_prop = self.deform_align['backward_2'](torch.cat((backward_h_p1, backward_h_p2), dim=1),
201 | torch.cat((aligned_backward_h_p1, spatial_h, backward_h_p2),
202 | dim=1), flow_p1, flow_p2)
203 | backward_h = self.backbone['backward_2'](torch.cat((spatial_h, feat_prop, hs['forward_1'][T - i],
204 | hs['backward_1'][T - i], hs['forward_2'][T - i]),
205 | dim=1))
206 | hs['backward_2'][T - i] = backward_h
207 |
208 | # generate results
209 | for i in range(T):
210 | seqdn[:, i] = self.d(torch.cat((hs['forward_1'][i], hs['backward_1'][i], hs['forward_2'][i], hs['backward_2'][i]), dim=1))
211 |
212 | return seqdn
213 |
214 | def forward_test(self, seqn, noise_level_map):
215 | N, T, C, H, W = seqn.shape
216 | hs = {}
217 | for module_name in self.module_names:
218 | hs[module_name] = [None] * T
219 | zeros_h = torch.zeros((N, self.num_channels, H, W), device=seqn.device)
220 | zeros_flow = torch.zeros((1, 2, H, W), device=seqn.device)
221 | seqdn = torch.empty_like(seqn)
222 |
223 | # extract flows
224 | forward_flows, backward_flows = self.extract_flows(seqn)
225 |
226 | # extract spatial features
227 | hs['spatial'] = self.spatial_feature(seqn, noise_level_map)
228 |
229 | # extract forward features
230 | spatial_h = hs['spatial'][0].cuda()
231 | forward_h = self.backbone['forward_1'](torch.cat((spatial_h, zeros_h.cuda()), dim=1))
232 | hs['forward_1'][0] = forward_h.cpu()
233 |
234 | spatial_h = hs['spatial'][1].cuda()
235 | flow_n1 = forward_flows[0].cuda()
236 | forward_h_n1 = forward_h
237 | aligned_forward_h_n1, _ = warp(forward_h_n1, flow_n1)
238 | feat_prop = self.deform_align['forward_1'](torch.cat((forward_h_n1, zeros_h.cuda()), dim=1),
239 | torch.cat((aligned_forward_h_n1, spatial_h, zeros_h.cuda()), dim=1),
240 | flow_n1, zeros_flow.cuda())
241 | forward_h = self.backbone['forward_1'](torch.cat((spatial_h, feat_prop), dim=1))
242 | hs['forward_1'][1] = forward_h.cpu()
243 |
244 | for i in range(2, T):
245 | spatial_h = hs['spatial'][i].cuda()
246 | flow_n1 = forward_flows[i - 1].cuda()
247 | forward_h_n1 = forward_h
248 | aligned_forward_h_n1, _ = warp(forward_h_n1, flow_n1)
249 | flow_n2 = flow_n1 + warp(forward_flows[i - 2].cuda(), flow_n1)[0]
250 | forward_h_n2 = hs['forward_1'][i - 2].cuda()
251 | aligned_forward_h_n2, _ = warp(forward_h_n2, flow_n2)
252 | feat_prop = self.deform_align['forward_1'](torch.cat((forward_h_n1, forward_h_n2), dim=1), torch.cat(
253 | (aligned_forward_h_n1, spatial_h, aligned_forward_h_n2), dim=1), flow_n1, flow_n2)
254 | forward_h = self.backbone['forward_1'](torch.cat((spatial_h, feat_prop), dim=1))
255 | hs['forward_1'][i] = forward_h.cpu()
256 |
257 | # extract backward features
258 | spatial_h = hs['spatial'][-1].cuda()
259 | backward_h = self.backbone['backward_1'](torch.cat((spatial_h, zeros_h.cuda(), hs['forward_1'][-1].cuda()), dim=1))
260 | hs['backward_1'][-1] = backward_h.cpu()
261 |
262 | spatial_h = hs['spatial'][-2].cuda()
263 | flow_p1 = backward_flows[-1].cuda()
264 | backward_h_p1 = backward_h
265 | aligned_backward_h_p1, _ = warp(backward_h_p1, flow_p1)
266 | feat_prop = self.deform_align['backward_1'](torch.cat((backward_h_p1, zeros_h.cuda()), dim=1),
267 | torch.cat((aligned_backward_h_p1, spatial_h, zeros_h.cuda()), dim=1),
268 | flow_p1, zeros_flow.cuda())
269 | backward_h = self.backbone['backward_1'](torch.cat((spatial_h, feat_prop, hs['forward_1'][-2].cuda()), dim=1))
270 | hs['backward_1'][-2] = backward_h.cpu()
271 |
272 | for i in range(3, T + 1):
273 | spatial_h = hs['spatial'][T - i].cuda()
274 | flow_p1 = backward_flows[T - i].cuda()
275 | backward_h_p1 = backward_h
276 | aligned_backward_h_p1, _ = warp(backward_h_p1, flow_p1)
277 | flow_p2 = flow_p1 + warp(backward_flows[T - i + 1].cuda(), flow_p1)[0]
278 | backward_h_p2 = hs['backward_1'][T - i + 1].cuda()
279 | aligned_backward_h_p2, _ = warp(backward_h_p2, flow_p2)
280 | feat_prop = self.deform_align['backward_1'](torch.cat((backward_h_p1, backward_h_p2), dim=1),
281 |                                                         torch.cat((aligned_backward_h_p1, spatial_h, aligned_backward_h_p2),
282 | dim=1), flow_p1, flow_p2)
283 | backward_h = self.backbone['backward_1'](
284 | torch.cat((spatial_h, feat_prop, hs['forward_1'][T - i].cuda()), dim=1))
285 | hs['backward_1'][T - i] = backward_h.cpu()
286 |
287 |         # extract forward features (second propagation stage)
288 | spatial_h = hs['spatial'][0].cuda()
289 | forward_h = self.backbone['forward_2'](torch.cat((spatial_h, zeros_h.cuda(),
290 | hs['forward_1'][0].cuda(),
291 | hs['backward_1'][0].cuda()), dim=1))
292 | hs['forward_2'][0] = forward_h.cpu()
293 |
294 | spatial_h = hs['spatial'][1].cuda()
295 | flow_n1 = forward_flows[0].cuda()
296 | forward_h_n1 = forward_h
297 | aligned_forward_h_n1, _ = warp(forward_h_n1, flow_n1)
298 | feat_prop = self.deform_align['forward_2'](torch.cat((forward_h_n1, zeros_h.cuda()), dim=1),
299 | torch.cat((aligned_forward_h_n1, spatial_h, zeros_h.cuda()), dim=1),
300 | flow_n1, zeros_flow.cuda())
301 | forward_h = self.backbone['forward_2'](
302 | torch.cat((spatial_h, feat_prop, hs['forward_1'][1].cuda(), hs['backward_1'][1].cuda()), dim=1))
303 | hs['forward_2'][1] = forward_h.cpu()
304 |
305 | for i in range(2, T):
306 | spatial_h = hs['spatial'][i].cuda()
307 | flow_n1 = forward_flows[i - 1].cuda()
308 | forward_h_n1 = forward_h
309 | aligned_forward_h_n1, _ = warp(forward_h_n1, flow_n1)
310 | flow_n2 = flow_n1 + warp(forward_flows[i - 2].cuda(), flow_n1)[0]
311 | forward_h_n2 = hs['forward_2'][i - 2].cuda()
312 | aligned_forward_h_n2, _ = warp(forward_h_n2, flow_n2)
313 | feat_prop = self.deform_align['forward_2'](torch.cat((forward_h_n1, forward_h_n2), dim=1), torch.cat(
314 | (aligned_forward_h_n1, spatial_h, aligned_forward_h_n2), dim=1), flow_n1, flow_n2)
315 | forward_h = self.backbone['forward_2'](
316 | torch.cat((spatial_h, feat_prop, hs['forward_1'][i].cuda(), hs['backward_1'][i].cuda()), dim=1))
317 | hs['forward_2'][i] = forward_h.cpu()
318 |
319 |         # extract backward features (second propagation stage)
320 | spatial_h = hs['spatial'][-1].cuda()
321 | backward_h = self.backbone['backward_2'](
322 | torch.cat((spatial_h, zeros_h.cuda(), hs['forward_1'][-1].cuda(), hs['backward_1'][-1].cuda(), hs['forward_2'][-1].cuda()),
323 | dim=1))
324 | hs['backward_2'][-1] = backward_h.cpu()
325 |
326 | spatial_h = hs['spatial'][-2].cuda()
327 | flow_p1 = backward_flows[-1].cuda()
328 | backward_h_p1 = backward_h
329 | aligned_backward_h_p1, _ = warp(backward_h_p1, flow_p1)
330 | feat_prop = self.deform_align['backward_2'](torch.cat((backward_h_p1, zeros_h.cuda()), dim=1),
331 | torch.cat((aligned_backward_h_p1, spatial_h, zeros_h.cuda()), dim=1),
332 | flow_p1, zeros_flow.cuda())
333 | backward_h = self.backbone['backward_2'](
334 | torch.cat((spatial_h, feat_prop, hs['forward_1'][-2].cuda(), hs['backward_1'][-2].cuda(), hs['forward_2'][-2].cuda()),
335 | dim=1))
336 | hs['backward_2'][-2] = backward_h.cpu()
337 |
338 | for i in range(3, T + 1):
339 | spatial_h = hs['spatial'][T - i].cuda()
340 | flow_p1 = backward_flows[T - i].cuda()
341 | backward_h_p1 = backward_h
342 | aligned_backward_h_p1, _ = warp(backward_h_p1, flow_p1)
343 | flow_p2 = flow_p1 + warp(backward_flows[T - i + 1].cuda(), flow_p1)[0]
344 | backward_h_p2 = hs['backward_2'][T - i + 1].cuda()
345 | aligned_backward_h_p2, _ = warp(backward_h_p2, flow_p2)
346 | feat_prop = self.deform_align['backward_2'](torch.cat((backward_h_p1, backward_h_p2), dim=1),
347 |                                                         torch.cat((aligned_backward_h_p1, spatial_h, aligned_backward_h_p2),
348 | dim=1), flow_p1, flow_p2)
349 | backward_h = self.backbone['backward_2'](torch.cat((spatial_h, feat_prop, hs['forward_1'][T - i].cuda(),
350 | hs['backward_1'][T - i].cuda(), hs['forward_2'][T - i].cuda()),
351 | dim=1))
352 | hs['backward_2'][T - i] = backward_h.cpu()
353 |
354 | # generate results
355 | for i in range(T):
356 | seqdn[:, i] = self.d(
357 | torch.cat((hs['forward_1'][i].cuda(), hs['backward_1'][i].cuda(), hs['forward_2'][i].cuda(), hs['backward_2'][i].cuda()), dim=1)).cpu()
358 |
359 | return seqdn
360 |
361 |
362 | class SecondOrderDeformableAlignment(ModulatedDeformConv2d):
363 | """Second-order deformable alignment module.
364 | Args:
365 | in_channels (int): Same as nn.Conv2d.
366 | out_channels (int): Same as nn.Conv2d.
367 | kernel_size (int or tuple[int]): Same as nn.Conv2d.
368 | stride (int or tuple[int]): Same as nn.Conv2d.
369 | padding (int or tuple[int]): Same as nn.Conv2d.
370 | dilation (int or tuple[int]): Same as nn.Conv2d.
371 | groups (int): Same as nn.Conv2d.
372 | bias (bool or str): If specified as `auto`, it will be decided by the
373 | norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
374 | False.
375 | max_residue_magnitude (int): The maximum magnitude of the offset
376 | residue (Eq. 6 in paper). Default: 10.
377 | """
378 |
379 | def __init__(self, *args, **kwargs):
380 | self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 10)
381 |
382 | super(SecondOrderDeformableAlignment, self).__init__(*args, **kwargs)
383 |
384 | self.conv_offset = nn.Sequential(
385 | nn.Conv2d(3 * self.out_channels + 4, self.out_channels, 3, 1, 1),
386 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
387 | nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
388 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
389 | nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
390 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
391 | nn.Conv2d(self.out_channels, 27 * self.deform_groups, 3, 1, 1),
392 | )
393 |
394 | self.init_offset()
395 |
396 | def init_offset(self):
397 | constant_init(self.conv_offset[-1], val=0, bias=0)
398 |
399 | def forward(self, x, extra_feat, flow_1, flow_2):
400 | extra_feat = torch.cat([extra_feat, flow_1, flow_2], dim=1)
401 | out = self.conv_offset(extra_feat)
402 | o1, o2, mask = torch.chunk(out, 3, dim=1)
403 |
404 | # offset
405 | offset = self.max_residue_magnitude * torch.tanh(
406 | torch.cat((o1, o2), dim=1))
407 | offset_1, offset_2 = torch.chunk(offset, 2, dim=1)
408 | offset_1 = offset_1 + flow_1.flip(1).repeat(1,
409 | offset_1.size(1) // 2, 1,
410 | 1)
411 | offset_2 = offset_2 + flow_2.flip(1).repeat(1,
412 | offset_2.size(1) // 2, 1,
413 | 1)
414 | offset = torch.cat([offset_1, offset_2], dim=1)
415 |
416 | # mask
417 | mask = torch.sigmoid(mask)
418 |
419 | return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias,
420 | self.stride, self.padding,
421 | self.dilation, self.groups,
422 | self.deform_groups)
--------------------------------------------------------------------------------
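The second-order propagation above chains two optical flows: `flow_n2 = flow_n1 + warp(forward_flows[i - 2], flow_n1)[0]`, i.e. the displacement from frame `i` to frame `i-2` is the one-step displacement plus the previous one-step displacement sampled at the displaced position. A toy pure-Python 1D sketch (not part of the repository; `warp_1d` and `compose_flows` are illustrative names standing in for `utils.warp.warp` and the composition line):

```python
# 1D analogue of the flow composition in forward_test. Flows are integer
# displacement lists instead of (N, 2, H, W) tensors.

def warp_1d(field, flow):
    # Backward warp: sample `field` at x + flow[x] (nearest neighbour,
    # clamped at the borders).
    n = len(field)
    return [field[min(max(x + flow[x], 0), n - 1)] for x in range(n)]

def compose_flows(flow_n1, flow_prev):
    # flow_n1:   displacement from frame i   to frame i-1
    # flow_prev: displacement from frame i-1 to frame i-2
    # returns:   displacement from frame i   to frame i-2
    return [f1 + w for f1, w in zip(flow_n1, warp_1d(flow_prev, flow_n1))]

# A scene shifting uniformly by one pixel per frame composes to a
# two-pixel displacement over two frames.
print(compose_flows([-1, -1, -1, -1], [-1, -1, -1, -1]))  # [-2, -2, -2, -2]
```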
/models/birnn.py:
--------------------------------------------------------------------------------
1 | from models.components import ResBlocks, D
2 | from pytorch_pwc.extract_flow import extract_flow_torch
3 | from pytorch_pwc.pwc import PWCNet
4 | import torch
5 | import torch.nn as nn
6 | from utils.warp import warp
7 |
8 | class BiRNN(nn.Module):
9 | def __init__(self, img_channels=3, num_resblocks=6, num_channels=64):
10 | super(BiRNN, self).__init__()
11 | self.num_channels = num_channels
12 | self.pwcnet = PWCNet()
13 | self.forward_rnn = ResBlocks(input_channels=img_channels + img_channels + num_channels, num_resblocks=num_resblocks, num_channels=num_channels)
14 | self.backward_rnn = ResBlocks(input_channels=img_channels + img_channels + num_channels, num_resblocks=num_resblocks, num_channels=num_channels)
15 | self.d = D(in_channels=num_channels * 2, mid_channels=num_channels * 2, out_channels=img_channels)
16 |
17 | def trainable_parameters(self):
18 | return [{'params':self.forward_rnn.parameters()}, {'params':self.backward_rnn.parameters()}, {'params':self.d.parameters()}]
19 |
20 | def forward(self, seqn, noise_level_map):
21 | if self.training:
22 | feature_device = torch.device('cuda')
23 | else:
24 | feature_device = torch.device('cpu')
25 | N, T, C, H, W = seqn.shape
26 | forward_hs = torch.empty((N, T, self.num_channels, H, W), device=feature_device)
27 | backward_hs = torch.empty((N, T, self.num_channels, H, W), device=feature_device)
28 | seqdn = torch.empty_like(seqn)
29 |
30 | # extract forward features
31 | init_forward_h = torch.zeros((N, self.num_channels, H, W), device=seqn.device)
32 | forward_h = self.forward_rnn(torch.cat((seqn[:, 0], noise_level_map[:, 0], init_forward_h), dim=1))
33 | forward_hs[:, 0] = forward_h.to(feature_device)
34 | for i in range(1, T):
35 | flow = extract_flow_torch(self.pwcnet, seqn[:, i], seqn[:, i-1])
36 | aligned_forward_h, _ = warp(forward_h, flow)
37 | forward_h = self.forward_rnn(torch.cat((seqn[:, i], noise_level_map[:, i], aligned_forward_h), dim=1))
38 | forward_hs[:, i] = forward_h.to(feature_device)
39 |
40 | # extract backward features
41 | init_backward_h = torch.zeros((N, self.num_channels, H, W), device=seqn.device)
42 | backward_h = self.backward_rnn(torch.cat((seqn[:, -1], noise_level_map[:, -1], init_backward_h), dim=1))
43 | backward_hs[:, -1] = backward_h.to(feature_device)
44 | for i in range(2, T+1):
45 | flow = extract_flow_torch(self.pwcnet, seqn[:, T-i], seqn[:, T-i+1])
46 | aligned_backward_h, _ = warp(backward_h, flow)
47 | backward_h = self.backward_rnn(torch.cat((seqn[:, T-i], noise_level_map[:, T-i], aligned_backward_h), dim=1))
48 | backward_hs[:, T-i] = backward_h.to(feature_device)
49 |
50 | # generate results
51 | for i in range(T):
52 | seqdn[:, i] = self.d(torch.cat((forward_hs[:, i].to(seqn.device), backward_hs[:, i].to(seqn.device)), dim=1))
53 |
54 | return seqdn
55 |
--------------------------------------------------------------------------------
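The backward recursion in `BiRNN.forward` iterates `for i in range(2, T+1)` and indexes frame `T - i`, visiting frames `T-2` down to `0` (frame `T-1` seeds the recurrence). A minimal sketch of that index mapping (illustrative helper name, not in the repository):

```python
# Frames visited by the backward RNN after the initial frame T-1.
def backward_indices(T):
    return [T - i for i in range(2, T + 1)]

print(backward_indices(5))  # [3, 2, 1, 0]
```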
/models/components.py:
--------------------------------------------------------------------------------
1 | import functools
2 | from models.init import init_fn
3 | import torch
4 | import torch.nn as nn
5 |
6 | class ResBlock(nn.Module):
7 | def __init__(self, in_channels, mid_channels, out_channels):
8 | super(ResBlock, self).__init__()
9 | self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=mid_channels, kernel_size=3, stride=1, padding=1, bias=False)
10 | self.relu = nn.ReLU(inplace=True)
11 | self.conv2 = nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, bias=False)
12 |
13 | def forward(self, x):
14 | output = self.conv2(self.relu(self.conv1(x)))
15 | output = torch.add(output, x)
16 | return output
17 |
18 | class ResBlocks(nn.Module):
19 | def __init__(self, input_channels, num_resblocks, num_channels):
20 | super(ResBlocks, self).__init__()
21 | self.input_channels = input_channels
22 | self.first_conv = nn.Conv2d(in_channels=self.input_channels, out_channels=num_channels, kernel_size=3, stride=1, padding=1, bias=False)
23 |
24 | modules = []
25 | for _ in range(num_resblocks):
26 | modules.append(ResBlock(in_channels=num_channels, mid_channels=num_channels, out_channels=num_channels))
27 | self.resblocks = nn.Sequential(*modules)
28 |
29 | fn = functools.partial(init_fn, init_type='kaiming_normal', init_bn_type='uniform', gain=0.2)
30 | self.apply(fn)
31 |
32 | def forward(self, h):
33 | shallow_feature = self.first_conv(h)
34 | new_h = self.resblocks(shallow_feature)
35 | return new_h
36 |
37 | class D(nn.Module):
38 | def __init__(self, in_channels, mid_channels, out_channels):
39 | super(D, self).__init__()
40 | layers = []
41 | layers.append(nn.Conv2d(in_channels=in_channels, out_channels=mid_channels, kernel_size=3, stride=1, padding=1, bias=False))
42 | layers.append(nn.ReLU())
43 | layers.append(nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, bias=False))
44 | self.convs = nn.Sequential(*layers)
45 |
46 | fn = functools.partial(init_fn, init_type='kaiming_normal', init_bn_type='uniform', gain=0.2)
47 | self.apply(fn)
48 |
49 | def forward(self, x):
50 | x = self.convs(x)
51 | return x
--------------------------------------------------------------------------------
/models/flornn.py:
--------------------------------------------------------------------------------
1 | from models.components import ResBlocks, D
2 | from pytorch_pwc.extract_flow import extract_flow_torch
3 | from pytorch_pwc.pwc import PWCNet
4 | from softmax_splatting.softsplat import FunctionSoftsplat
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from utils.warp import warp
9 |
10 | def expand(ten, size_h, size_w, value=0):
11 | return F.pad(ten, pad=[size_w, size_w, size_h, size_h], mode='constant', value=value)
12 |
13 | def split_border(ten, size_h, size_w):
14 | img = ten[:, :, size_h:-size_h, size_w:-size_w]
15 | return img, ten
16 |
17 | def merge_border(img, border, size_h, size_w):
18 | expanded_img = F.pad(img, pad=[size_w, size_w, size_h, size_h], mode='constant')
19 | expanded_img[:, :, :size_h, :] = border[:, :, :size_h, :]
20 | expanded_img[:, :, -size_h:, :] = border[:, :, -size_h:, :]
21 | expanded_img[:, :, :, :size_w] = border[:, :, :, :size_w]
22 | expanded_img[:, :, :, -size_w:] = border[:, :, :, -size_w:]
23 | return expanded_img
24 |
25 | class FloRNN(nn.Module):
26 | def __init__(self, img_channels, num_resblocks=6, num_channels=64, forward_count=2, border_ratio=0.3):
27 | super(FloRNN, self).__init__()
28 | self.num_channels = num_channels
29 | self.forward_count = forward_count
30 | self.pwcnet = PWCNet()
31 | self.forward_rnn = ResBlocks(input_channels=img_channels + img_channels + num_channels, num_resblocks=num_resblocks, num_channels=num_channels)
32 | self.backward_rnn = ResBlocks(input_channels=img_channels + img_channels + num_channels, num_resblocks=num_resblocks, num_channels=num_channels)
33 | self.d = D(in_channels=num_channels * 2, mid_channels=num_channels * 2, out_channels=img_channels)
34 | self.border_ratio = border_ratio
35 |
36 | def trainable_parameters(self):
37 | return [{'params':self.forward_rnn.parameters()}, {'params':self.backward_rnn.parameters()}, {'params':self.d.parameters()}]
38 |
39 | def forward(self, seqn_not_pad, noise_level_map_not_pad):
40 | N, T, C, H, W = seqn_not_pad.shape
41 | seqdn = torch.empty_like(seqn_not_pad)
42 | expanded_forward_flow_queue = []
43 | border_queue = []
44 | size_h, size_w = int(H * self.border_ratio), int(W * self.border_ratio)
45 |
46 | # reflect pad seqn and noise_level_map
47 | seqn = torch.empty((N, T+self.forward_count, C, H, W), device=seqn_not_pad.device)
48 | noise_level_map = torch.empty((N, T+self.forward_count, C, H, W), device=noise_level_map_not_pad.device)
49 | seqn[:, :T] = seqn_not_pad
50 | noise_level_map[:, :T] = noise_level_map_not_pad
51 | for i in range(self.forward_count):
52 | seqn[:, T+i] = seqn_not_pad[:, T-2-i]
53 | noise_level_map[:, T+i] = noise_level_map_not_pad[:, T-2-i]
54 |
55 | init_backward_h = torch.zeros((N, self.num_channels, H, W), device=seqn.device)
56 | backward_h = self.backward_rnn(torch.cat((seqn[:, 0], noise_level_map[:, 0], init_backward_h), dim=1))
57 | init_forward_h = torch.zeros((N, self.num_channels, H, W), device=seqn.device)
58 | forward_h = self.forward_rnn(torch.cat((seqn[:, 0], noise_level_map[:, 0], init_forward_h), dim=1))
59 |
60 | for i in range(1, T+self.forward_count):
61 | forward_flow = extract_flow_torch(self.pwcnet, seqn[:, i-1], seqn[:, i])
62 |
63 | expanded_backward_h, expanded_forward_flow = expand(backward_h, size_h, size_w), expand(forward_flow, size_h, size_w)
64 | expanded_forward_flow_queue.append(expanded_forward_flow)
65 | aligned_expanded_backward_h = FunctionSoftsplat(expanded_backward_h, expanded_forward_flow, None, 'average')
66 | aligned_backward_h, border = split_border(aligned_expanded_backward_h, size_h, size_w)
67 | border_queue.append(border)
68 |
69 | backward_h = self.backward_rnn(torch.cat((seqn[:, i], noise_level_map[:, i], aligned_backward_h), dim=1))
70 |
71 | if i >= self.forward_count:
72 | aligned_backward_h = backward_h
73 | for j in reversed(range(self.forward_count)):
74 | aligned_backward_h = merge_border(aligned_backward_h, border_queue[j], size_h, size_w)
75 | aligned_backward_h, _ = warp(aligned_backward_h, expanded_forward_flow_queue[j])
76 | aligned_backward_h, _ = split_border(aligned_backward_h, size_h, size_w)
77 |
78 | seqdn[:, i - self.forward_count] = self.d(torch.cat((forward_h, aligned_backward_h), dim=1))
79 |
80 | backward_flow = extract_flow_torch(self.pwcnet, seqn[:, i-self.forward_count+1], seqn[:, i-self.forward_count])
81 | aligned_forward_h, _ = warp(forward_h, backward_flow)
82 | forward_h = self.forward_rnn(torch.cat((seqn[:, i-self.forward_count+1], noise_level_map[:, i-self.forward_count+1], aligned_forward_h), dim=1))
83 | expanded_forward_flow_queue.pop(0)
84 | border_queue.pop(0)
85 |
86 | return seqdn
87 |
88 |
--------------------------------------------------------------------------------
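The padding loop in `FloRNN.forward` (`seqn[:, T+i] = seqn_not_pad[:, T-2-i]`) reflects the sequence about its last frame, so the forward branch always has `forward_count` look-ahead frames, even at the end of the clip. A stand-alone sketch on frame indices instead of tensors (`reflect_pad_time` is an illustrative name, not a repository function):

```python
# Temporal reflect padding: frames T..T+forward_count-1 mirror
# frames T-2, T-3, ... about the last frame T-1.
def reflect_pad_time(num_frames, forward_count):
    frames = list(range(num_frames))
    return frames + [num_frames - 2 - i for i in range(forward_count)]

print(reflect_pad_time(5, 2))  # [0, 1, 2, 3, 4, 3, 2]
```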
/models/flornn_raw.py:
--------------------------------------------------------------------------------
1 | from models.components import ResBlocks, D
2 | from pytorch_pwc.extract_flow import extract_flow_torch
3 | from pytorch_pwc.pwc import PWCNet
4 | from softmax_splatting.softsplat import FunctionSoftsplat
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from utils.raw import demosaic
9 | from utils.warp import warp
10 |
11 | def expand(ten, size_h, size_w, value=0):
12 | return F.pad(ten, pad=[size_w, size_w, size_h, size_h], mode='constant', value=value)
13 |
14 | def split_border(ten, size_h, size_w):
15 | img = ten[:, :, size_h:-size_h, size_w:-size_w]
16 | return img, ten
17 |
18 | def merge_border(img, border, size_h, size_w):
19 | expanded_img = F.pad(img, pad=[size_w, size_w, size_h, size_h], mode='constant')
20 | expanded_img[:, :, :size_h, :] = border[:, :, :size_h, :]
21 | expanded_img[:, :, -size_h:, :] = border[:, :, -size_h:, :]
22 | expanded_img[:, :, :, :size_w] = border[:, :, :, :size_w]
23 | expanded_img[:, :, :, -size_w:] = border[:, :, :, -size_w:]
24 | return expanded_img
25 |
26 | class FloRNNRaw(nn.Module):
27 | def __init__(self, img_channels, num_resblocks=6, num_channels=64, forward_count=2, border_ratio=0.1):
28 | super(FloRNNRaw, self).__init__()
29 | self.num_channels = num_channels
30 | self.forward_count = forward_count
31 | self.pwcnet = PWCNet()
32 | self.forward_rnn = ResBlocks(input_channels=img_channels + 2 + num_channels, num_resblocks=num_resblocks, num_channels=num_channels)
33 | self.backward_rnn = ResBlocks(input_channels=img_channels + 2 + num_channels, num_resblocks=num_resblocks, num_channels=num_channels)
34 | self.d = D(in_channels=num_channels * 2, mid_channels=num_channels * 2, out_channels=img_channels)
35 | self.border_ratio = border_ratio
36 |
37 | def trainable_parameters(self):
38 | return [{'params':self.forward_rnn.parameters()}, {'params':self.backward_rnn.parameters()}, {'params':self.d.parameters()}]
39 |
40 | def forward(self, seqn_not_pad, a_not_pad, b_not_pad):
41 | N, T, C, H, W = seqn_not_pad.shape
42 | seqdn = torch.empty_like(seqn_not_pad)
43 | expanded_forward_flow_queue = []
44 | border_queue = []
45 | size_h, size_w = int(H * self.border_ratio), int(W * self.border_ratio)
46 |
47 |         # reflect pad seqn and the per-frame noise parameter maps a, b
48 | seqn = torch.empty((N, T+self.forward_count, C, H, W), device=seqn_not_pad.device)
49 | a = torch.empty((N, T + self.forward_count, 1, H, W), device=a_not_pad.device)
50 | b = torch.empty((N, T + self.forward_count, 1, H, W), device=b_not_pad.device)
51 | seqn[:, :T] = seqn_not_pad
52 | a[:, :T] = a_not_pad
53 | b[:, :T] = b_not_pad
54 | for i in range(self.forward_count):
55 | seqn[:, T+i] = seqn_not_pad[:, T-2-i]
56 | a[:, T + i] = a_not_pad[:, T - 2 - i]
57 | b[:, T + i] = b_not_pad[:, T - 2 - i]
58 | srgb_seqn = demosaic(seqn)
59 |
60 | init_backward_h = torch.zeros((N, self.num_channels, H, W), device=seqn.device)
61 | backward_h = self.backward_rnn(torch.cat((seqn[:, 0], a[:, 0], b[:, 0], init_backward_h), dim=1))
62 | init_forward_h = torch.zeros((N, self.num_channels, H, W), device=seqn.device)
63 | forward_h = self.forward_rnn(torch.cat((seqn[:, 0], a[:, 0], b[:, 0], init_forward_h), dim=1))
64 |
65 | for i in range(1, T+self.forward_count):
66 | forward_flow = extract_flow_torch(self.pwcnet, srgb_seqn[:, i-1], srgb_seqn[:, i])
67 |
68 | expanded_backward_h, expanded_forward_flow = expand(backward_h, size_h, size_w), expand(forward_flow, size_h, size_w)
69 | expanded_forward_flow_queue.append(expanded_forward_flow)
70 | aligned_expanded_backward_h = FunctionSoftsplat(expanded_backward_h, expanded_forward_flow, None, 'average')
71 | aligned_backward_h, border = split_border(aligned_expanded_backward_h, size_h, size_w)
72 | border_queue.append(border)
73 |
74 | backward_h = self.backward_rnn(torch.cat((seqn[:, i], a[:, i], b[:, i], aligned_backward_h), dim=1))
75 |
76 | if i >= self.forward_count:
77 | aligned_backward_h = backward_h
78 | for j in reversed(range(self.forward_count)):
79 | aligned_backward_h = merge_border(aligned_backward_h, border_queue[j], size_h, size_w)
80 | aligned_backward_h, _ = warp(aligned_backward_h, expanded_forward_flow_queue[j])
81 | aligned_backward_h, _ = split_border(aligned_backward_h, size_h, size_w)
82 |
83 | seqdn[:, i - self.forward_count] = self.d(torch.cat((forward_h, aligned_backward_h), dim=1))
84 |
85 | backward_flow = extract_flow_torch(self.pwcnet, srgb_seqn[:, i-self.forward_count+1], srgb_seqn[:, i-self.forward_count])
86 | aligned_forward_h, _ = warp(forward_h, backward_flow)
87 | forward_h = self.forward_rnn(torch.cat((seqn[:, i-self.forward_count+1], a[:, i-self.forward_count+1], b[:, i-self.forward_count+1], aligned_forward_h), dim=1))
88 | expanded_forward_flow_queue.pop(0)
89 | border_queue.pop(0)
90 |
91 | return seqdn
92 |
93 |
--------------------------------------------------------------------------------
/models/forwardrnn.py:
--------------------------------------------------------------------------------
1 | from models.components import ResBlocks, D
2 | from pytorch_pwc.extract_flow import extract_flow_torch
3 | from pytorch_pwc.pwc import PWCNet
4 | import torch
5 | import torch.nn as nn
6 | from utils.warp import warp
7 |
8 | class ForwardRNN(nn.Module):
9 | def __init__(self, img_channels=3, num_resblocks=6, num_channels=64):
10 | super(ForwardRNN, self).__init__()
11 | self.num_channels = num_channels
12 | self.pwcnet = PWCNet()
13 | self.forward_rnn = ResBlocks(input_channels=img_channels + img_channels + num_channels, num_resblocks=num_resblocks, num_channels=num_channels)
14 | self.d = D(in_channels=num_channels, mid_channels=num_channels, out_channels=img_channels)
15 |
16 | def trainable_parameters(self):
17 | return [{'params':self.forward_rnn.parameters()}, {'params':self.d.parameters()}]
18 |
19 | def forward(self, seqn, noise_level_map):
20 | N, T, C, H, W = seqn.shape
21 | seqdn = torch.empty_like(seqn)
22 |
23 | init_h = torch.zeros((N, self.num_channels, H, W), device=seqn.device)
24 | h = self.forward_rnn(torch.cat((seqn[:, 0], noise_level_map[:, 0], init_h), dim=1))
25 | seqdn[:, 0] = self.d(h)
26 |
27 | for i in range(1, T):
28 | flow = extract_flow_torch(self.pwcnet, seqn[:, i], seqn[:, i-1])
29 | aligned_h, _ = warp(h, flow)
30 | h = self.forward_rnn(torch.cat((seqn[:, i], noise_level_map[:, i], aligned_h), dim=1))
31 | seqdn[:, i] = self.d(h)
32 |
33 | return seqdn
34 |
--------------------------------------------------------------------------------
/models/init.py:
--------------------------------------------------------------------------------
1 | """
2 | # --------------------------------------------
3 | # weights initialization
4 | # --------------------------------------------
5 | """
6 | from torch.nn import init
7 |
8 | """
9 | # Kai Zhang, https://github.com/cszn/KAIR
10 | #
11 | # Args:
12 | # init_type:
13 | # normal; normal; xavier_normal; xavier_uniform;
14 | # kaiming_normal; kaiming_uniform; orthogonal
15 | # init_bn_type:
16 | # uniform; constant
17 | # gain:
18 | # 0.2
19 | """
20 |
21 | def init_fn(m, init_type='kaiming_normal', init_bn_type='uniform', gain=0.2):
22 | classname = m.__class__.__name__
23 |
24 | if classname.find('Conv') != -1 or classname.find('Linear') != -1:
25 |
26 | if init_type == 'normal':
27 | init.normal_(m.weight.data, 0, 0.1)
28 | m.weight.data.clamp_(-1, 1).mul_(gain)
29 |
30 | elif init_type == 'uniform':
31 | init.uniform_(m.weight.data, -0.2, 0.2)
32 | m.weight.data.mul_(gain)
33 |
34 | elif init_type == 'xavier_normal':
35 | init.xavier_normal_(m.weight.data, gain=gain)
36 | m.weight.data.clamp_(-1, 1)
37 |
38 | elif init_type == 'xavier_uniform':
39 | init.xavier_uniform_(m.weight.data, gain=gain)
40 |
41 | elif init_type == 'kaiming_normal':
42 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in', nonlinearity='relu')
43 | m.weight.data.clamp_(-1, 1).mul_(gain)
44 |
45 | elif init_type == 'kaiming_uniform':
46 | init.kaiming_uniform_(m.weight.data, a=0, mode='fan_in', nonlinearity='relu')
47 | m.weight.data.mul_(gain)
48 |
49 | elif init_type == 'orthogonal':
50 | init.orthogonal_(m.weight.data, gain=gain)
51 |
52 | else:
53 | raise NotImplementedError('Initialization method [{:s}] is not implemented'.format(init_type))
54 |
55 | if m.bias is not None:
56 | m.bias.data.zero_()
57 |
58 | elif classname.find('BatchNorm2d') != -1:
59 |
60 | if init_bn_type == 'uniform': # preferred
61 | if m.affine:
62 | init.uniform_(m.weight.data, 0.1, 1.0)
63 | init.constant_(m.bias.data, 0.0)
64 | elif init_bn_type == 'constant':
65 | if m.affine:
66 | init.constant_(m.weight.data, 1.0)
67 | init.constant_(m.bias.data, 0.0)
68 | else:
69 | raise NotImplementedError('Initialization method [{:s}] is not implemented'.format(init_bn_type))
--------------------------------------------------------------------------------
/models/rvidenet/isp.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nagejacob/FloRNN/5419715af261bf1d619818baaf26708b81781f4a/models/rvidenet/isp.pth
--------------------------------------------------------------------------------
/models/rvidenet/isp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.init as init
4 |
5 | class ISP(nn.Module):
6 |
7 | def __init__(self):
8 | super(ISP, self).__init__()
9 |
10 | self.conv1_1 = nn.Conv2d(4, 32, kernel_size=3, stride=1, padding=1)
11 | self.conv1_2 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
12 | self.pool1 = nn.MaxPool2d(kernel_size=2)
13 |
14 | self.conv2_1 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
15 | self.conv2_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
16 | self.pool2 = nn.MaxPool2d(kernel_size=2)
17 |
18 | self.conv3_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
19 | self.conv3_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
20 |
21 | self.upv4 = nn.ConvTranspose2d(128, 64, 2, stride=2)
22 | self.conv4_1 = nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1)
23 | self.conv4_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
24 |
25 | self.upv5 = nn.ConvTranspose2d(64, 32, 2, stride=2)
26 | self.conv5_1 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1)
27 | self.conv5_2 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
28 |
29 | self.conv6_1 = nn.Conv2d(32, 12, kernel_size=1, stride=1)
30 |
31 | def forward(self, x):
32 | conv1 = self.lrelu(self.conv1_1(x))
33 | conv1 = self.lrelu(self.conv1_2(conv1))
34 | pool1 = self.pool1(conv1)
35 |
36 | conv2 = self.lrelu(self.conv2_1(pool1))
37 | conv2 = self.lrelu(self.conv2_2(conv2))
38 |         pool2 = self.pool2(conv2)
39 |
40 | conv3 = self.lrelu(self.conv3_1(pool2))
41 | conv3 = self.lrelu(self.conv3_2(conv3))
42 |
43 | up4 = self.upv4(conv3)
44 | up4 = torch.cat([up4, conv2], 1)
45 | conv4 = self.lrelu(self.conv4_1(up4))
46 | conv4 = self.lrelu(self.conv4_2(conv4))
47 |
48 | up5 = self.upv5(conv4)
49 | up5 = torch.cat([up5, conv1], 1)
50 | conv5 = self.lrelu(self.conv5_1(up5))
51 | conv5 = self.lrelu(self.conv5_2(conv5))
52 |
53 | conv6 = self.conv6_1(conv5)
54 | out = nn.functional.pixel_shuffle(conv6, 2)
55 | return out
56 |
57 | def _initialize_weights(self):
58 | for m in self.modules():
59 | if isinstance(m, nn.Conv2d):
60 | m.weight.data.normal_(0.0, 0.02)
61 | if m.bias is not None:
62 | m.bias.data.normal_(0.0, 0.02)
63 | if isinstance(m, nn.ConvTranspose2d):
64 | m.weight.data.normal_(0.0, 0.02)
65 |
66 | def lrelu(self, x):
67 | outt = torch.max(0.2 * x, x)
68 | return outt
69 |
70 |
71 | def initialize_weights(net_l, scale=1):
72 | if not isinstance(net_l, list):
73 | net_l = [net_l]
74 | for net in net_l:
75 | for m in net.modules():
76 | if isinstance(m, nn.Conv2d):
77 | init.kaiming_normal_(m.weight, a=0, mode='fan_in')
78 | m.weight.data *= scale
79 | if m.bias is not None:
80 | m.bias.data.zero_()
81 | elif isinstance(m, nn.Linear):
82 | init.kaiming_normal_(m.weight, a=0, mode='fan_in')
83 | m.weight.data *= scale
84 | if m.bias is not None:
85 | m.bias.data.zero_()
86 | elif isinstance(m, nn.BatchNorm2d):
87 | init.constant_(m.weight, 1)
88 | init.constant_(m.bias.data, 0.0)
89 |
90 |
91 | def make_layer(block, n_layers):
92 | layers = []
93 | for _ in range(n_layers):
94 | layers.append(block())
95 | return nn.Sequential(*layers)
--------------------------------------------------------------------------------
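`ISP.forward` ends by mapping the 12 `conv6` channels to an RGB image at twice the resolution via `pixel_shuffle(conv6, 2)`: 12 = 3 (RGB) * 2*2 (sub-pixel positions). A minimal pure-Python sketch of that rearrangement on nested lists (an assumption-level re-implementation for illustration, not the torch kernel):

```python
# Sub-pixel rearrangement: channel c*r*r + dy*r + dx feeds output
# pixel (y*r + dy, x*r + dx) of output channel c.
def pixel_shuffle(x, r):
    # x: nested list of shape [C*r*r][H][W] -> returns [C][H*r][W*r]
    crr, h, w = len(x), len(x[0]), len(x[0][0])
    c = crr // (r * r)
    out = [[[0] * (w * r) for _ in range(h * r)] for _ in range(c)]
    for ch in range(crr):
        base, rem = divmod(ch, r * r)
        dy, dx = divmod(rem, r)
        for y in range(h):
            for xx in range(w):
                out[base][y * r + dy][xx * r + dx] = x[ch][y][xx]
    return out

# Four 1x1 channels shuffle into one 2x2 plane.
print(pixel_shuffle([[[1]], [[2]], [[3]], [[4]]], 2))  # [[[1, 2], [3, 4]]]
```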
/pytorch_pwc/correlation/correlation.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import torch
4 |
5 | import cupy
6 | import re
7 |
8 | kernel_Correlation_rearrange = '''
9 | extern "C" __global__ void kernel_Correlation_rearrange(
10 | const int n,
11 | const float* input,
12 | float* output
13 | ) {
14 | int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x;
15 |
16 | if (intIndex >= n) {
17 | return;
18 | }
19 |
20 | int intSample = blockIdx.z;
21 | int intChannel = blockIdx.y;
22 |
23 | float fltValue = input[(((intSample * SIZE_1(input)) + intChannel) * SIZE_2(input) * SIZE_3(input)) + intIndex];
24 |
25 | __syncthreads();
26 |
27 | int intPaddedY = (intIndex / SIZE_3(input)) + 4;
28 | int intPaddedX = (intIndex % SIZE_3(input)) + 4;
29 | int intRearrange = ((SIZE_3(input) + 8) * intPaddedY) + intPaddedX;
30 |
31 | output[(((intSample * SIZE_1(output) * SIZE_2(output)) + intRearrange) * SIZE_1(input)) + intChannel] = fltValue;
32 | }
33 | '''
34 |
35 | kernel_Correlation_updateOutput = '''
36 | extern "C" __global__ void kernel_Correlation_updateOutput(
37 | const int n,
38 | const float* rbot0,
39 | const float* rbot1,
40 | float* top
41 | ) {
42 | extern __shared__ char patch_data_char[];
43 |
44 | float *patch_data = (float *)patch_data_char;
45 |
46 | // First (upper left) position of kernel upper-left corner in current center position of neighborhood in image 1
47 | int x1 = blockIdx.x + 4;
48 | int y1 = blockIdx.y + 4;
49 | int item = blockIdx.z;
50 | int ch_off = threadIdx.x;
51 |
52 |   // Load 3D patch into shared memory
53 | for (int j = 0; j < 1; j++) { // HEIGHT
54 | for (int i = 0; i < 1; i++) { // WIDTH
55 | int ji_off = (j + i) * SIZE_3(rbot0);
56 | for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS
57 | int idx1 = ((item * SIZE_1(rbot0) + y1+j) * SIZE_2(rbot0) + x1+i) * SIZE_3(rbot0) + ch;
58 | int idxPatchData = ji_off + ch;
59 | patch_data[idxPatchData] = rbot0[idx1];
60 | }
61 | }
62 | }
63 |
64 | __syncthreads();
65 |
66 | __shared__ float sum[32];
67 |
68 | // Compute correlation
69 | for (int top_channel = 0; top_channel < SIZE_1(top); top_channel++) {
70 | sum[ch_off] = 0;
71 |
72 | int s2o = top_channel % 9 - 4;
73 | int s2p = top_channel / 9 - 4;
74 |
75 | for (int j = 0; j < 1; j++) { // HEIGHT
76 | for (int i = 0; i < 1; i++) { // WIDTH
77 | int ji_off = (j + i) * SIZE_3(rbot0);
78 | for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS
79 | int x2 = x1 + s2o;
80 | int y2 = y1 + s2p;
81 |
82 | int idxPatchData = ji_off + ch;
83 | int idx2 = ((item * SIZE_1(rbot0) + y2+j) * SIZE_2(rbot0) + x2+i) * SIZE_3(rbot0) + ch;
84 |
85 | sum[ch_off] += patch_data[idxPatchData] * rbot1[idx2];
86 | }
87 | }
88 | }
89 |
90 | __syncthreads();
91 |
92 | if (ch_off == 0) {
93 | float total_sum = 0;
94 | for (int idx = 0; idx < 32; idx++) {
95 | total_sum += sum[idx];
96 | }
97 | const int sumelems = SIZE_3(rbot0);
98 | const int index = ((top_channel*SIZE_2(top) + blockIdx.y)*SIZE_3(top))+blockIdx.x;
99 | top[index + item*SIZE_1(top)*SIZE_2(top)*SIZE_3(top)] = total_sum / (float)sumelems;
100 | }
101 | }
102 | }
103 | '''
104 |
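The 81 output channels produced by `kernel_Correlation_updateOutput` enumerate a 9×9 displacement neighborhood, decoded inside the kernel by `s2o = top_channel % 9 - 4` and `s2p = top_channel / 9 - 4`. A minimal torch-free sketch of that decoding (the function name is ours, for illustration):

```python
def decode_channel(top_channel, radius=4):
    # Map a correlation output channel (0..80) to its 2-D displacement,
    # mirroring the s2o/s2p arithmetic in the CUDA kernel above.
    side = 2 * radius + 1  # 9 displacements per axis
    s2o = top_channel % side - radius   # horizontal offset
    s2p = top_channel // side - radius  # vertical offset
    return s2o, s2p

print(decode_channel(0), decode_channel(40), decode_channel(80))
# -> (-4, -4) (0, 0) (4, 4)
```

Channel 40 is the zero-displacement (center) correlation, which is why the cost volume has exactly 81 = 9×9 channels.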
105 | kernel_Correlation_updateGradFirst = '''
106 | #define ROUND_OFF 50000
107 |
108 | extern "C" __global__ void kernel_Correlation_updateGradFirst(
109 | const int n,
110 | const int intSample,
111 | const float* rbot0,
112 | const float* rbot1,
113 | const float* gradOutput,
114 | float* gradFirst,
115 | float* gradSecond
116 | ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
117 | int n = intIndex % SIZE_1(gradFirst); // channels
118 | int l = (intIndex / SIZE_1(gradFirst)) % SIZE_3(gradFirst) + 4; // w-pos
119 | int m = (intIndex / SIZE_1(gradFirst) / SIZE_3(gradFirst)) % SIZE_2(gradFirst) + 4; // h-pos
120 |
121 | // round_off is a trick to enable integer division with ceil, even for negative numbers
122 | // We use a large offset, for the inner part not to become negative.
123 | const int round_off = ROUND_OFF;
124 | const int round_off_s1 = round_off;
125 |
126 |   // We add round_off before the int division and subtract round_off after it, to ensure the formula matches ceil behavior:
127 | int xmin = (l - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4)
128 |   int ymin = (m - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (m - 4)
129 |
130 | // Same here:
131 | int xmax = (l - 4 + round_off_s1) - round_off; // floor (l - 4)
132 | int ymax = (m - 4 + round_off_s1) - round_off; // floor (m - 4)
133 |
134 | float sum = 0;
135 | if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) {
136 | xmin = max(0,xmin);
137 | xmax = min(SIZE_3(gradOutput)-1,xmax);
138 |
139 | ymin = max(0,ymin);
140 | ymax = min(SIZE_2(gradOutput)-1,ymax);
141 |
142 | for (int p = -4; p <= 4; p++) {
143 | for (int o = -4; o <= 4; o++) {
144 | // Get rbot1 data:
145 | int s2o = o;
146 | int s2p = p;
147 | int idxbot1 = ((intSample * SIZE_1(rbot0) + (m+s2p)) * SIZE_2(rbot0) + (l+s2o)) * SIZE_3(rbot0) + n;
148 | float bot1tmp = rbot1[idxbot1]; // rbot1[l+s2o,m+s2p,n]
149 |
150 | // Index offset for gradOutput in following loops:
151 | int op = (p+4) * 9 + (o+4); // index[o,p]
152 | int idxopoffset = (intSample * SIZE_1(gradOutput) + op);
153 |
154 | for (int y = ymin; y <= ymax; y++) {
155 | for (int x = xmin; x <= xmax; x++) {
156 | int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p]
157 | sum += gradOutput[idxgradOutput] * bot1tmp;
158 | }
159 | }
160 | }
161 | }
162 | }
163 | const int sumelems = SIZE_1(gradFirst);
164 | const int bot0index = ((n * SIZE_2(gradFirst)) + (m-4)) * SIZE_3(gradFirst) + (l-4);
165 | gradFirst[bot0index + intSample*SIZE_1(gradFirst)*SIZE_2(gradFirst)*SIZE_3(gradFirst)] = sum / (float)sumelems;
166 | } }
167 | '''
168 |
169 | kernel_Correlation_updateGradSecond = '''
170 | #define ROUND_OFF 50000
171 |
172 | extern "C" __global__ void kernel_Correlation_updateGradSecond(
173 | const int n,
174 | const int intSample,
175 | const float* rbot0,
176 | const float* rbot1,
177 | const float* gradOutput,
178 | float* gradFirst,
179 | float* gradSecond
180 | ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
181 | int n = intIndex % SIZE_1(gradSecond); // channels
182 | int l = (intIndex / SIZE_1(gradSecond)) % SIZE_3(gradSecond) + 4; // w-pos
183 | int m = (intIndex / SIZE_1(gradSecond) / SIZE_3(gradSecond)) % SIZE_2(gradSecond) + 4; // h-pos
184 |
185 | // round_off is a trick to enable integer division with ceil, even for negative numbers
186 | // We use a large offset, for the inner part not to become negative.
187 | const int round_off = ROUND_OFF;
188 | const int round_off_s1 = round_off;
189 |
190 | float sum = 0;
191 | for (int p = -4; p <= 4; p++) {
192 | for (int o = -4; o <= 4; o++) {
193 | int s2o = o;
194 | int s2p = p;
195 |
196 | //Get X,Y ranges and clamp
197 |       // We add round_off before the int division and subtract round_off after it, to ensure the formula matches ceil behavior:
198 | int xmin = (l - 4 - s2o + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o)
199 |       int ymin = (m - 4 - s2p + round_off_s1 - 1) + 1 - round_off; // ceil (m - 4 - s2p)
200 |
201 | // Same here:
202 | int xmax = (l - 4 - s2o + round_off_s1) - round_off; // floor (l - 4 - s2o)
203 | int ymax = (m - 4 - s2p + round_off_s1) - round_off; // floor (m - 4 - s2p)
204 |
205 | if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) {
206 | xmin = max(0,xmin);
207 | xmax = min(SIZE_3(gradOutput)-1,xmax);
208 |
209 | ymin = max(0,ymin);
210 | ymax = min(SIZE_2(gradOutput)-1,ymax);
211 |
212 | // Get rbot0 data:
213 | int idxbot0 = ((intSample * SIZE_1(rbot0) + (m-s2p)) * SIZE_2(rbot0) + (l-s2o)) * SIZE_3(rbot0) + n;
214 |         float bot0tmp = rbot0[idxbot0]; // rbot0[l-s2o,m-s2p,n]
215 |
216 | // Index offset for gradOutput in following loops:
217 | int op = (p+4) * 9 + (o+4); // index[o,p]
218 | int idxopoffset = (intSample * SIZE_1(gradOutput) + op);
219 |
220 | for (int y = ymin; y <= ymax; y++) {
221 | for (int x = xmin; x <= xmax; x++) {
222 | int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p]
223 | sum += gradOutput[idxgradOutput] * bot0tmp;
224 | }
225 | }
226 | }
227 | }
228 | }
229 | const int sumelems = SIZE_1(gradSecond);
230 | const int bot1index = ((n * SIZE_2(gradSecond)) + (m-4)) * SIZE_3(gradSecond) + (l-4);
231 | gradSecond[bot1index + intSample*SIZE_1(gradSecond)*SIZE_2(gradSecond)*SIZE_3(gradSecond)] = sum / (float)sumelems;
232 | } }
233 | '''
234 |
235 |
236 | def cupy_kernel(strFunction, objVariables):
237 | strKernel = globals()[strFunction]
238 |
239 | while True:
240 |         objMatch = re.search(r'(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel)
241 |
242 | if objMatch is None:
243 | break
244 | # end
245 |
246 | intArg = int(objMatch.group(2))
247 |
248 | strTensor = objMatch.group(4)
249 | intSizes = objVariables[strTensor].size()
250 |
251 | strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg]))
252 | # end
253 |
254 | while True:
255 |         objMatch = re.search(r'(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel)
256 |
257 | if objMatch is None:
258 | break
259 | # end
260 |
261 | intArgs = int(objMatch.group(2))
262 | strArgs = objMatch.group(4).split(',')
263 |
264 | strTensor = strArgs[0]
265 | intStrides = objVariables[strTensor].stride()
266 | strIndex = ['((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(
267 | intStrides[intArg]) + ')' for intArg in range(intArgs)]
268 |
269 | strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']')
270 | # end
271 |
272 | return strKernel
273 |
274 |
275 | # end
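`cupy_kernel` specializes the CUDA source at launch time by textually replacing each `SIZE_k(tensor)` macro with the concrete k-th dimension of the bound tensor. A simplified, torch-free illustration of the same substitution loop (function name and shapes are ours):

```python
import re

def substitute_sizes(kernel_src, shapes):
    # Replace every SIZE_k(name) with shapes[name][k], like the first
    # while-loop of cupy_kernel (shapes: tensor name -> tuple of dims).
    while True:
        m = re.search(r'(SIZE_)([0-4])(\()([^\)]*)(\))', kernel_src)
        if m is None:
            break
        kernel_src = kernel_src.replace(m.group(), str(shapes[m.group(4)][int(m.group(2))]))
    return kernel_src

src = 'int n = SIZE_2(input) * SIZE_3(input);'
print(substitute_sizes(src, {'input': (1, 3, 480, 854)}))
# -> int n = 480 * 854;
```

Because the sizes are baked into the source string, `cupy_launch` memoizes one compiled kernel per distinct shape.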
276 |
277 | @cupy.memoize(for_each_device=True)
278 | def cupy_launch(strFunction, strKernel):
279 | return cupy.cuda.compile_with_cache(strKernel).get_function(strFunction)
280 |
281 |
282 | # end
283 |
284 | class _FunctionCorrelation(torch.autograd.Function):
285 | @staticmethod
286 | def forward(self, first, second):
287 | rbot0 = first.new_zeros([first.shape[0], first.shape[2] + 8, first.shape[3] + 8, first.shape[1]])
288 | rbot1 = first.new_zeros([first.shape[0], first.shape[2] + 8, first.shape[3] + 8, first.shape[1]])
289 |
290 | self.save_for_backward(first, second, rbot0, rbot1)
291 |
292 | assert (first.is_contiguous() == True)
293 | assert (second.is_contiguous() == True)
294 |
295 | output = first.new_zeros([first.shape[0], 81, first.shape[2], first.shape[3]])
296 |
297 | if first.is_cuda == True:
298 | n = first.shape[2] * first.shape[3]
299 | cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', {
300 | 'input': first,
301 | 'output': rbot0
302 | }))(
303 | grid=tuple([int((n + 16 - 1) / 16), first.shape[1], first.shape[0]]),
304 | block=tuple([16, 1, 1]),
305 | args=[n, first.data_ptr(), rbot0.data_ptr()]
306 | )
307 |
308 | n = second.shape[2] * second.shape[3]
309 | cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', {
310 | 'input': second,
311 | 'output': rbot1
312 | }))(
313 | grid=tuple([int((n + 16 - 1) / 16), second.shape[1], second.shape[0]]),
314 | block=tuple([16, 1, 1]),
315 | args=[n, second.data_ptr(), rbot1.data_ptr()]
316 | )
317 |
318 | n = output.shape[1] * output.shape[2] * output.shape[3]
319 | cupy_launch('kernel_Correlation_updateOutput', cupy_kernel('kernel_Correlation_updateOutput', {
320 | 'rbot0': rbot0,
321 | 'rbot1': rbot1,
322 | 'top': output
323 | }))(
324 | grid=tuple([output.shape[3], output.shape[2], output.shape[0]]),
325 | block=tuple([32, 1, 1]),
326 | shared_mem=first.shape[1] * 4,
327 | args=[n, rbot0.data_ptr(), rbot1.data_ptr(), output.data_ptr()]
328 | )
329 |
330 | elif first.is_cuda == False:
331 | raise NotImplementedError()
332 |
333 | # end
334 |
335 | return output
336 |
337 | # end
338 |
339 | @staticmethod
340 | def backward(self, gradOutput):
341 | first, second, rbot0, rbot1 = self.saved_tensors
342 |
343 | assert (gradOutput.is_contiguous() == True)
344 |
345 | gradFirst = first.new_zeros([first.shape[0], first.shape[1], first.shape[2], first.shape[3]]) if \
346 | self.needs_input_grad[0] == True else None
347 | gradSecond = first.new_zeros([first.shape[0], first.shape[1], first.shape[2], first.shape[3]]) if \
348 | self.needs_input_grad[1] == True else None
349 |
350 | if first.is_cuda == True:
351 | if gradFirst is not None:
352 | for intSample in range(first.shape[0]):
353 | n = first.shape[1] * first.shape[2] * first.shape[3]
354 | cupy_launch('kernel_Correlation_updateGradFirst',
355 | cupy_kernel('kernel_Correlation_updateGradFirst', {
356 | 'rbot0': rbot0,
357 | 'rbot1': rbot1,
358 | 'gradOutput': gradOutput,
359 | 'gradFirst': gradFirst,
360 | 'gradSecond': None
361 | }))(
362 | grid=tuple([int((n + 512 - 1) / 512), 1, 1]),
363 | block=tuple([512, 1, 1]),
364 | args=[n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(),
365 | gradFirst.data_ptr(), None]
366 | )
367 | # end
368 | # end
369 |
370 | if gradSecond is not None:
371 | for intSample in range(first.shape[0]):
372 | n = first.shape[1] * first.shape[2] * first.shape[3]
373 | cupy_launch('kernel_Correlation_updateGradSecond',
374 | cupy_kernel('kernel_Correlation_updateGradSecond', {
375 | 'rbot0': rbot0,
376 | 'rbot1': rbot1,
377 | 'gradOutput': gradOutput,
378 | 'gradFirst': None,
379 | 'gradSecond': gradSecond
380 | }))(
381 | grid=tuple([int((n + 512 - 1) / 512), 1, 1]),
382 | block=tuple([512, 1, 1]),
383 | args=[n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), None,
384 | gradSecond.data_ptr()]
385 | )
386 | # end
387 | # end
388 |
389 | elif first.is_cuda == False:
390 | raise NotImplementedError()
391 |
392 | # end
393 |
394 | return gradFirst, gradSecond
395 |
396 |
397 | # end
398 | # end
399 |
400 | def FunctionCorrelation(tenFirst, tenSecond):
401 | return _FunctionCorrelation.apply(tenFirst, tenSecond)
402 |
403 |
404 | # end
405 |
406 | class ModuleCorrelation(torch.nn.Module):
407 | def __init__(self):
408 | super(ModuleCorrelation, self).__init__()
409 |
410 | # end
411 |
412 | def forward(self, tenFirst, tenSecond):
413 | return _FunctionCorrelation.apply(tenFirst, tenSecond)
414 | # end
415 | # end
--------------------------------------------------------------------------------
/pytorch_pwc/extract_flow.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 |
4 | # im1_torch, im2_torch in shape (N, C, H, W)
5 | def extract_flow_torch(model, im1_torch, im2_torch):
6 | # interpolate images so that new_H and new_W are divisible by 64
7 | assert im1_torch.shape == im2_torch.shape
8 | N, C, H, W = im1_torch.shape
9 | device = im1_torch.device
10 | new_H = int(math.floor(math.ceil(H / 64.0) * 64.0))
11 | new_W = int(math.floor(math.ceil(W / 64.0) * 64.0))
12 | im1_torch = torch.nn.functional.interpolate(input=im1_torch, size=(new_H, new_W), mode='bilinear',
13 | align_corners=False)
14 | im2_torch = torch.nn.functional.interpolate(input=im2_torch, size=(new_H, new_W), mode='bilinear',
15 | align_corners=False)
16 | model.eval()
17 | with torch.no_grad():
18 | flo12 = model(im1_torch, im2_torch)
19 | flo12 = 20.0 * torch.nn.functional.interpolate(input=flo12, size=(H, W), mode='bilinear',
20 | align_corners=False)
21 | flo12[:, 0, :, :] *= float(W) / float(new_W)
22 | flo12[:, 1, :, :] *= float(H) / float(new_H)
23 | return flo12
24 |
25 | # im1_np, im2_np in shape (C, H, W)
26 | def extract_flow_np(model, im1_np, im2_np):
27 | im1_torch = torch.from_numpy(im1_np).unsqueeze(0).to(torch.device('cuda'))
28 | im2_torch = torch.from_numpy(im2_np).unsqueeze(0).to(torch.device('cuda'))
29 | flo12_torch = extract_flow_torch(model, im1_torch, im2_torch)
30 | flo12_np = flo12_torch.detach().cpu().squeeze(0).numpy()
31 | return flo12_np
32 |
33 |
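PWC-Net requires spatial dimensions divisible by 64, so `extract_flow_torch` resizes the inputs up, predicts flow, resizes the flow back to (H, W), and rescales each flow component by the ratio of original to padded size (on top of the fixed 20× factor). The rounding and per-axis factors can be checked in isolation (helper name is ours):

```python
import math

def pwc_resize_params(H, W, multiple=64):
    # Round H and W up to the next multiple of 64, as in extract_flow_torch,
    # and return the per-axis factors applied to the flow after resizing back.
    new_H = int(math.ceil(H / multiple) * multiple)
    new_W = int(math.ceil(W / multiple) * multiple)
    return new_H, new_W, W / new_W, H / new_H

print(pwc_resize_params(480, 854))
# -> (512, 896, 0.953125, 0.9375)
```

For a 480×854 DAVIS frame, the network therefore runs at 512×896 and the horizontal/vertical flow components are shrunk by 854/896 and 480/512 respectively.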
--------------------------------------------------------------------------------
/pytorch_pwc/pwc.py:
--------------------------------------------------------------------------------
1 | from .correlation import correlation # the custom cost volume layer
2 | import torch
3 |
4 | ##########################################################
5 |
6 | assert(int(str('').join(torch.__version__.split('.')[0:2])) >= 13) # requires at least pytorch version 1.3.0
7 |
8 | # torch.set_grad_enabled(False)
9 |
10 | torch.backends.cudnn.enabled = True # make sure to use cudnn for computational performance
11 |
12 | arguments_strModel = 'default'
13 |
14 | backwarp_tenGrid = {}
15 | backwarp_tenPartial = {}
16 |
17 | def backwarp(tenInput, tenFlow):
18 | if (str(tenFlow.shape)+str(tenFlow.device)) not in backwarp_tenGrid:
19 | tenHor = torch.linspace(-1.0 + (1.0 / tenFlow.shape[3]), 1.0 - (1.0 / tenFlow.shape[3]), tenFlow.shape[3]).view(1, 1, 1, -1).expand(-1, -1, tenFlow.shape[2], -1)
20 | tenVer = torch.linspace(-1.0 + (1.0 / tenFlow.shape[2]), 1.0 - (1.0 / tenFlow.shape[2]), tenFlow.shape[2]).view(1, 1, -1, 1).expand(-1, -1, -1, tenFlow.shape[3])
21 |
22 | backwarp_tenGrid[str(tenFlow.shape) + str(tenFlow.device)] = torch.cat([ tenHor, tenVer ], 1).to(tenFlow.device)
23 | # end
24 |
25 | if (str(tenFlow.shape)+str(tenFlow.device)) not in backwarp_tenPartial:
26 | backwarp_tenPartial[str(tenFlow.shape)+str(tenFlow.device)] = tenFlow.new_ones([ tenFlow.shape[0], 1, tenFlow.shape[2], tenFlow.shape[3] ])
27 | # end
28 |
29 | tenFlow = torch.cat([ tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0) ], 1)
30 | tenInput = torch.cat([ tenInput, backwarp_tenPartial[str(tenFlow.shape)+str(tenFlow.device)] ], 1)
31 |
32 | tenOutput = torch.nn.functional.grid_sample(input=tenInput, grid=(backwarp_tenGrid[str(tenFlow.shape) + str(tenFlow.device)] + tenFlow).permute(0, 2, 3, 1), mode='bilinear', padding_mode='zeros', align_corners=False)
33 |
34 | tenMask = tenOutput[:, -1:, :, :]; tenMask[tenMask > 0.999] = 1.0; tenMask[tenMask < 1.0] = 0.0
35 |
36 | return tenOutput[:, :-1, :, :] * tenMask
37 | # end
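`backwarp` samples with `align_corners=False`, so its base grid places the n pixel centers at `-1 + 1/n, ..., 1 - 1/n` in normalized coordinates; the flow is then divided by half the image extent to live in the same units. A torch-free sketch of the 1-D coordinate construction (function name is ours):

```python
def sampling_grid_1d(n):
    # Cell-center coordinates in [-1, 1], matching
    # torch.linspace(-1 + 1/n, 1 - 1/n, n) used to build backwarp's grid.
    step = 2.0 / n
    return [-1.0 + step / 2 + i * step for i in range(n)]

print(sampling_grid_1d(4))
# -> [-0.75, -0.25, 0.25, 0.75]
```

The extra all-ones channel concatenated to `tenInput` acts as a validity mask: positions warped from outside the frame sample less than 1.0 there and are zeroed out.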
38 |
39 | ##########################################################
40 |
41 | class PWCNet(torch.nn.Module):
42 | def __init__(self):
43 | super(PWCNet, self).__init__()
44 |
45 | class Extractor(torch.nn.Module):
46 | def __init__(self):
47 | super(Extractor, self).__init__()
48 |
49 | self.netOne = torch.nn.Sequential(
50 | torch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=2, padding=1),
51 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
52 | torch.nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1),
53 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
54 | torch.nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1),
55 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
56 | )
57 |
58 | self.netTwo = torch.nn.Sequential(
59 | torch.nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1),
60 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
61 | torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1),
62 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
63 | torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1),
64 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
65 | )
66 |
67 | self.netThr = torch.nn.Sequential(
68 | torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1),
69 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
70 | torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
71 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
72 | torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
73 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
74 | )
75 |
76 | self.netFou = torch.nn.Sequential(
77 | torch.nn.Conv2d(in_channels=64, out_channels=96, kernel_size=3, stride=2, padding=1),
78 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
79 | torch.nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1),
80 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
81 | torch.nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1),
82 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
83 | )
84 |
85 | self.netFiv = torch.nn.Sequential(
86 | torch.nn.Conv2d(in_channels=96, out_channels=128, kernel_size=3, stride=2, padding=1),
87 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
88 | torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
89 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
90 | torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
91 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
92 | )
93 |
94 | self.netSix = torch.nn.Sequential(
95 | torch.nn.Conv2d(in_channels=128, out_channels=196, kernel_size=3, stride=2, padding=1),
96 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
97 | torch.nn.Conv2d(in_channels=196, out_channels=196, kernel_size=3, stride=1, padding=1),
98 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
99 | torch.nn.Conv2d(in_channels=196, out_channels=196, kernel_size=3, stride=1, padding=1),
100 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
101 | )
102 | # end
103 |
104 | def forward(self, tenInput):
105 | tenOne = self.netOne(tenInput)
106 | tenTwo = self.netTwo(tenOne)
107 | tenThr = self.netThr(tenTwo)
108 | tenFou = self.netFou(tenThr)
109 | tenFiv = self.netFiv(tenFou)
110 | tenSix = self.netSix(tenFiv)
111 |
112 | return [ tenOne, tenTwo, tenThr, tenFou, tenFiv, tenSix ]
113 | # end
114 | # end
115 |
116 | class Decoder(torch.nn.Module):
117 | def __init__(self, intLevel):
118 | super(Decoder, self).__init__()
119 |
120 | intPrevious = [ None, None, 81 + 32 + 2 + 2, 81 + 64 + 2 + 2, 81 + 96 + 2 + 2, 81 + 128 + 2 + 2, 81, None ][intLevel + 1]
121 | intCurrent = [ None, None, 81 + 32 + 2 + 2, 81 + 64 + 2 + 2, 81 + 96 + 2 + 2, 81 + 128 + 2 + 2, 81, None ][intLevel + 0]
122 |
123 | if intLevel < 6: self.netUpflow = torch.nn.ConvTranspose2d(in_channels=2, out_channels=2, kernel_size=4, stride=2, padding=1)
124 | if intLevel < 6: self.netUpfeat = torch.nn.ConvTranspose2d(in_channels=intPrevious + 128 + 128 + 96 + 64 + 32, out_channels=2, kernel_size=4, stride=2, padding=1)
125 | if intLevel < 6: self.fltBackwarp = [ None, None, None, 5.0, 2.5, 1.25, 0.625, None ][intLevel + 1]
126 |
127 | self.netOne = torch.nn.Sequential(
128 | torch.nn.Conv2d(in_channels=intCurrent, out_channels=128, kernel_size=3, stride=1, padding=1),
129 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
130 | )
131 |
132 | self.netTwo = torch.nn.Sequential(
133 | torch.nn.Conv2d(in_channels=intCurrent + 128, out_channels=128, kernel_size=3, stride=1, padding=1),
134 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
135 | )
136 |
137 | self.netThr = torch.nn.Sequential(
138 | torch.nn.Conv2d(in_channels=intCurrent + 128 + 128, out_channels=96, kernel_size=3, stride=1, padding=1),
139 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
140 | )
141 |
142 | self.netFou = torch.nn.Sequential(
143 | torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96, out_channels=64, kernel_size=3, stride=1, padding=1),
144 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
145 | )
146 |
147 | self.netFiv = torch.nn.Sequential(
148 | torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96 + 64, out_channels=32, kernel_size=3, stride=1, padding=1),
149 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
150 | )
151 |
152 | self.netSix = torch.nn.Sequential(
153 | torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96 + 64 + 32, out_channels=2, kernel_size=3, stride=1, padding=1)
154 | )
155 | # end
156 |
157 | def forward(self, tenFirst, tenSecond, objPrevious):
158 | tenFlow = None
159 | tenFeat = None
160 |
161 | if objPrevious is None:
162 | tenFlow = None
163 | tenFeat = None
164 | tenVolume = torch.nn.functional.leaky_relu(input=correlation.FunctionCorrelation(tenFirst=tenFirst, tenSecond=tenSecond), negative_slope=0.1, inplace=False)
165 |
166 | tenFeat = torch.cat([ tenVolume ], 1)
167 |
168 | elif objPrevious is not None:
169 | tenFlow = self.netUpflow(objPrevious['tenFlow'])
170 | tenFeat = self.netUpfeat(objPrevious['tenFeat'])
171 |
172 | tenVolume = torch.nn.functional.leaky_relu(input=correlation.FunctionCorrelation(tenFirst=tenFirst, tenSecond=backwarp(tenInput=tenSecond, tenFlow=tenFlow * self.fltBackwarp)), negative_slope=0.1, inplace=False)
173 |
174 | tenFeat = torch.cat([ tenVolume, tenFirst, tenFlow, tenFeat ], 1)
175 |
176 | # end
177 |
178 | tenFeat = torch.cat([ self.netOne(tenFeat), tenFeat ], 1)
179 | tenFeat = torch.cat([ self.netTwo(tenFeat), tenFeat ], 1)
180 | tenFeat = torch.cat([ self.netThr(tenFeat), tenFeat ], 1)
181 | tenFeat = torch.cat([ self.netFou(tenFeat), tenFeat ], 1)
182 | tenFeat = torch.cat([ self.netFiv(tenFeat), tenFeat ], 1)
183 |
184 | tenFlow = self.netSix(tenFeat)
185 |
186 | return {
187 | 'tenFlow': tenFlow,
188 | 'tenFeat': tenFeat
189 | }
190 | # end
191 | # end
192 |
193 | class Refiner(torch.nn.Module):
194 | def __init__(self):
195 | super(Refiner, self).__init__()
196 |
197 | self.netMain = torch.nn.Sequential(
198 | torch.nn.Conv2d(in_channels=81 + 32 + 2 + 2 + 128 + 128 + 96 + 64 + 32, out_channels=128, kernel_size=3, stride=1, padding=1, dilation=1),
199 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
200 | torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=2, dilation=2),
201 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
202 | torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=4, dilation=4),
203 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
204 | torch.nn.Conv2d(in_channels=128, out_channels=96, kernel_size=3, stride=1, padding=8, dilation=8),
205 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
206 | torch.nn.Conv2d(in_channels=96, out_channels=64, kernel_size=3, stride=1, padding=16, dilation=16),
207 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
208 | torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1, dilation=1),
209 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
210 | torch.nn.Conv2d(in_channels=32, out_channels=2, kernel_size=3, stride=1, padding=1, dilation=1)
211 | )
212 | # end
213 |
214 | def forward(self, tenInput):
215 | return self.netMain(tenInput)
216 | # end
217 | # end
218 |
219 | self.netExtractor = Extractor()
220 |
221 | self.netTwo = Decoder(2)
222 | self.netThr = Decoder(3)
223 | self.netFou = Decoder(4)
224 | self.netFiv = Decoder(5)
225 | self.netSix = Decoder(6)
226 |
227 | self.netRefiner = Refiner()
228 |
229 | self.load_state_dict({ strKey.replace('module', 'net'): tenWeight for strKey, tenWeight in torch.hub.load_state_dict_from_url(url='http://content.sniklaus.com/github/pytorch-pwc/network-' + arguments_strModel + '.pytorch', file_name='pwc-' + arguments_strModel).items() })
230 | # end
231 |
232 | def forward(self, tenFirst, tenSecond):
233 | tenFirst = self.netExtractor(tenFirst)
234 | tenSecond = self.netExtractor(tenSecond)
235 |
236 | objEstimate = self.netSix(tenFirst[-1], tenSecond[-1], None)
237 | objEstimate = self.netFiv(tenFirst[-2], tenSecond[-2], objEstimate)
238 | objEstimate = self.netFou(tenFirst[-3], tenSecond[-3], objEstimate)
239 | objEstimate = self.netThr(tenFirst[-4], tenSecond[-4], objEstimate)
240 | objEstimate = self.netTwo(tenFirst[-5], tenSecond[-5], objEstimate)
241 |
242 | return objEstimate['tenFlow'] + self.netRefiner(objEstimate['tenFeat'])
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # Unidirectional Video Denoising by Mimicking Backward Recurrent Modules with Look-ahead Forward Ones
2 | This is the source code for our paper "Unidirectional Video Denoising by Mimicking Backward Recurrent Modules with Look-ahead Forward Ones" (ECCV 2022).
3 | 
4 |
5 | ## Usage
6 | ### Dependencies
7 | You can create a conda environment with all the dependencies by running
8 |
9 | ```conda env create -f requirements.yaml -n ```
10 |
11 | ### Datasets
12 | For synthetic Gaussian noise, the [DAVIS-2017-trainval-480p](https://data.vision.ee.ethz.ch/csergi/share/davis/DAVIS-2017-trainval-480p.zip) dataset is used for training,
13 | while [DAVIS-2017-test-dev-480p](https://data.vision.ee.ethz.ch/csergi/share/davis/DAVIS-2017-test-dev-480p.zip) and [Set8](https://www.dropbox.com/sh/20n4cscqkqsfgoj/AABGftyJuJDwuCLGczL-fKvBa/test_sequences?dl=0&subfolder_nav_tracking=1) are used for testing.
14 | For real-world raw noise, the [CRVD](https://github.com/cao-cong/RViDeNet#captured-raw-video-denoising-dataset-crvd-dataset) dataset is used for training and testing.
15 |
16 | ### Testing
17 | Download pretrained models from [Google Drive](https://drive.google.com/drive/folders/1A854tOA6_qB14ax3JZ7bb7tLo0UovkyI?usp=sharing) or [Baidu Netdisk](https://pan.baidu.com/s/1YomcegvdtoxVPr96odCo8w?pwd=ogua).
18 | We also provide denoised results (the tractor sequence from DAVIS-2017-test-dev-480p) for visual comparison.
19 | 1. For synthetic Gaussian noise,
20 | ```
21 | cd test_models
22 | python sRGB_test.py \
23 | --model_file \
24 | --test_path
25 | ```
26 | 2. For real-world raw noise,
27 | ```
28 | cd test_models
29 | python CRVD_test.py \
30 | --model_file \
31 | --test_path
32 | ```
33 |
34 | ### Training
35 | 1. For synthetic Gaussian noise,
36 | ```
37 | cd train_models
38 | python sRGB_train.py \
39 | --trainset_dir \
40 | --valset_dir \
41 | --log_dir
42 | ```
43 | 2. For real-world raw noise,
44 | ```
45 | cd train_models
46 | python CRVD_train.py \
47 | --CRVD_dir \
48 | --log_dir
49 | ```
50 | 3. For distributed training with synthetic Gaussian noise,
51 | ```
52 | cd train_models
53 | python -m torch.distributed.launch --nproc_per_node=4 sRGB_train_distributed.py \
54 | --trainset_dir \
55 | --valset_dir \
56 | --log_dir
57 | ```
58 |
59 | ## Citation
60 |
61 | If you find our work useful in your research or publication, please cite:
62 | ```
63 | @inproceedings{li2022unidirectional,
64 | title={Unidirectional Video Denoising by Mimicking Backward Recurrent Modules with Look-ahead Forward Ones},
65 | author={Li, Junyi and Wu, Xiaohe and Niu, Zhenxing and Zuo, Wangmeng},
66 | booktitle={ECCV},
67 | year={2022}
68 | }
69 | ```
70 |
--------------------------------------------------------------------------------
/requirements.yaml:
--------------------------------------------------------------------------------
1 | name: video_denoising
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - defaults
6 | dependencies:
7 | - cudatoolkit=10.2.89
8 | - cupy=8.2.0
9 | - imageio=2.9.0
10 | - mkl=2019.4
11 | - more-itertools=8.6.0
12 | - opencv=3.4.2
13 | - pip=20.0.2
14 | - pypng=0.0.20
15 | - python=3.7
16 | - pytorch=1.7.0
17 | - scikit-image=0.16.2
18 | - scipy=1.5.2
19 | - torchvision=0.8.1
20 | - pip:
21 | - future==0.18.2
22 | - tensorboardx==2.0
--------------------------------------------------------------------------------
/softmax_splatting/softsplat.py:
--------------------------------------------------------------------------------
1 | # borrowed from https://github.com/sniklaus/softmax-splatting
2 |
3 | import torch
4 | import cupy
5 | import re
6 |
7 | kernel_Softsplat_updateOutput = '''
8 | extern "C" __global__ void kernel_Softsplat_updateOutput(
9 | const int n,
10 | const float* input,
11 | const float* flow,
12 | float* output
13 | ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
14 | const int intN = ( intIndex / SIZE_3(output) / SIZE_2(output) / SIZE_1(output) ) % SIZE_0(output);
15 | const int intC = ( intIndex / SIZE_3(output) / SIZE_2(output) ) % SIZE_1(output);
16 | const int intY = ( intIndex / SIZE_3(output) ) % SIZE_2(output);
17 | const int intX = ( intIndex ) % SIZE_3(output);
18 |
19 | float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX);
20 | float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX);
21 |
22 | int intNorthwestX = (int) (floor(fltOutputX));
23 | int intNorthwestY = (int) (floor(fltOutputY));
24 | int intNortheastX = intNorthwestX + 1;
25 | int intNortheastY = intNorthwestY;
26 | int intSouthwestX = intNorthwestX;
27 | int intSouthwestY = intNorthwestY + 1;
28 | int intSoutheastX = intNorthwestX + 1;
29 | int intSoutheastY = intNorthwestY + 1;
30 |
31 | float fltNorthwest = ((float) (intSoutheastX) - fltOutputX) * ((float) (intSoutheastY) - fltOutputY);
32 | float fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (intSouthwestY) - fltOutputY);
33 | float fltSouthwest = ((float) (intNortheastX) - fltOutputX) * (fltOutputY - (float) (intNortheastY));
34 | float fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * (fltOutputY - (float) (intNorthwestY));
35 |
36 | if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(output)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(output))) {
37 | atomicAdd(&output[OFFSET_4(output, intN, intC, intNorthwestY, intNorthwestX)], VALUE_4(input, intN, intC, intY, intX) * fltNorthwest);
38 | }
39 |
40 | if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(output)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(output))) {
41 | atomicAdd(&output[OFFSET_4(output, intN, intC, intNortheastY, intNortheastX)], VALUE_4(input, intN, intC, intY, intX) * fltNortheast);
42 | }
43 |
44 | if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(output)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(output))) {
45 | atomicAdd(&output[OFFSET_4(output, intN, intC, intSouthwestY, intSouthwestX)], VALUE_4(input, intN, intC, intY, intX) * fltSouthwest);
46 | }
47 |
48 | if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(output)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(output))) {
49 | atomicAdd(&output[OFFSET_4(output, intN, intC, intSoutheastY, intSoutheastX)], VALUE_4(input, intN, intC, intY, intX) * fltSoutheast);
50 | }
51 | } }
52 | '''
53 |
54 | kernel_Softsplat_updateGradInput = '''
55 | extern "C" __global__ void kernel_Softsplat_updateGradInput(
56 | const int n,
57 | const float* input,
58 | const float* flow,
59 | const float* gradOutput,
60 | float* gradInput,
61 | float* gradFlow
62 | ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
63 | const int intN = ( intIndex / SIZE_3(gradInput) / SIZE_2(gradInput) / SIZE_1(gradInput) ) % SIZE_0(gradInput);
64 | const int intC = ( intIndex / SIZE_3(gradInput) / SIZE_2(gradInput) ) % SIZE_1(gradInput);
65 | const int intY = ( intIndex / SIZE_3(gradInput) ) % SIZE_2(gradInput);
66 | const int intX = ( intIndex ) % SIZE_3(gradInput);
67 |
68 | float fltGradInput = 0.0;
69 |
70 | float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX);
71 | float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX);
72 |
73 | int intNorthwestX = (int) (floor(fltOutputX));
74 | int intNorthwestY = (int) (floor(fltOutputY));
75 | int intNortheastX = intNorthwestX + 1;
76 | int intNortheastY = intNorthwestY;
77 | int intSouthwestX = intNorthwestX;
78 | int intSouthwestY = intNorthwestY + 1;
79 | int intSoutheastX = intNorthwestX + 1;
80 | int intSoutheastY = intNorthwestY + 1;
81 |
82 | float fltNorthwest = ((float) (intSoutheastX) - fltOutputX) * ((float) (intSoutheastY) - fltOutputY);
83 | float fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (intSouthwestY) - fltOutputY);
84 | float fltSouthwest = ((float) (intNortheastX) - fltOutputX) * (fltOutputY - (float) (intNortheastY));
85 | float fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * (fltOutputY - (float) (intNorthwestY));
86 |
87 | if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(gradOutput)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(gradOutput))) {
88 | fltGradInput += VALUE_4(gradOutput, intN, intC, intNorthwestY, intNorthwestX) * fltNorthwest;
89 | }
90 |
91 | if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(gradOutput)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(gradOutput))) {
92 | fltGradInput += VALUE_4(gradOutput, intN, intC, intNortheastY, intNortheastX) * fltNortheast;
93 | }
94 |
95 | if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(gradOutput)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(gradOutput))) {
96 | fltGradInput += VALUE_4(gradOutput, intN, intC, intSouthwestY, intSouthwestX) * fltSouthwest;
97 | }
98 |
99 | if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(gradOutput)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(gradOutput))) {
100 | fltGradInput += VALUE_4(gradOutput, intN, intC, intSoutheastY, intSoutheastX) * fltSoutheast;
101 | }
102 |
103 | gradInput[intIndex] = fltGradInput;
104 | } }
105 | '''
106 |
107 | kernel_Softsplat_updateGradFlow = '''
108 | extern "C" __global__ void kernel_Softsplat_updateGradFlow(
109 | const int n,
110 | const float* input,
111 | const float* flow,
112 | const float* gradOutput,
113 | float* gradInput,
114 | float* gradFlow
115 | ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
116 | float fltGradFlow = 0.0;
117 |
118 | const int intN = ( intIndex / SIZE_3(gradFlow) / SIZE_2(gradFlow) / SIZE_1(gradFlow) ) % SIZE_0(gradFlow);
119 | const int intC = ( intIndex / SIZE_3(gradFlow) / SIZE_2(gradFlow) ) % SIZE_1(gradFlow);
120 | const int intY = ( intIndex / SIZE_3(gradFlow) ) % SIZE_2(gradFlow);
121 | const int intX = ( intIndex ) % SIZE_3(gradFlow);
122 |
123 | float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX);
124 | float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX);
125 |
126 | int intNorthwestX = (int) (floor(fltOutputX));
127 | int intNorthwestY = (int) (floor(fltOutputY));
128 | int intNortheastX = intNorthwestX + 1;
129 | int intNortheastY = intNorthwestY;
130 | int intSouthwestX = intNorthwestX;
131 | int intSouthwestY = intNorthwestY + 1;
132 | int intSoutheastX = intNorthwestX + 1;
133 | int intSoutheastY = intNorthwestY + 1;
134 |
135 | float fltNorthwest = 0.0;
136 | float fltNortheast = 0.0;
137 | float fltSouthwest = 0.0;
138 | float fltSoutheast = 0.0;
139 |
140 | if (intC == 0) {
141 | fltNorthwest = ((float) (-1.0)) * ((float) (intSoutheastY) - fltOutputY);
142 | fltNortheast = ((float) (+1.0)) * ((float) (intSouthwestY) - fltOutputY);
143 | fltSouthwest = ((float) (-1.0)) * (fltOutputY - (float) (intNortheastY));
144 | fltSoutheast = ((float) (+1.0)) * (fltOutputY - (float) (intNorthwestY));
145 |
146 | } else if (intC == 1) {
147 | fltNorthwest = ((float) (intSoutheastX) - fltOutputX) * ((float) (-1.0));
148 | fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (-1.0));
149 | fltSouthwest = ((float) (intNortheastX) - fltOutputX) * ((float) (+1.0));
150 | fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * ((float) (+1.0));
151 |
152 | }
153 |
154 | for (int intChannel = 0; intChannel < SIZE_1(gradOutput); intChannel += 1) {
155 | float fltInput = VALUE_4(input, intN, intChannel, intY, intX);
156 |
157 | if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(gradOutput)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(gradOutput))) {
158 | fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intNorthwestY, intNorthwestX) * fltNorthwest;
159 | }
160 |
161 | if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(gradOutput)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(gradOutput))) {
162 | fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intNortheastY, intNortheastX) * fltNortheast;
163 | }
164 |
165 | if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(gradOutput)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(gradOutput))) {
166 | fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intSouthwestY, intSouthwestX) * fltSouthwest;
167 | }
168 |
169 | if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(gradOutput)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(gradOutput))) {
170 | fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intSoutheastY, intSoutheastX) * fltSoutheast;
171 | }
172 | }
173 |
174 | gradFlow[intIndex] = fltGradFlow;
175 | } }
176 | '''
177 |
178 | def cupy_kernel(strFunction, objVariables):
179 | strKernel = globals()[strFunction]
180 |
181 | while True:
182 | objMatch = re.search(r'(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel)
183 |
184 | if objMatch is None:
185 | break
186 | # end
187 |
188 | intArg = int(objMatch.group(2))
189 |
190 | strTensor = objMatch.group(4)
191 | intSizes = objVariables[strTensor].size()
192 |
193 | strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg]))
194 | # end
195 |
196 | while True:
197 | objMatch = re.search(r'(OFFSET_)([0-4])(\()([^\)]+)(\))', strKernel)
198 |
199 | if objMatch is None:
200 | break
201 | # end
202 |
203 | intArgs = int(objMatch.group(2))
204 | strArgs = objMatch.group(4).split(',')
205 |
206 | strTensor = strArgs[0]
207 | intStrides = objVariables[strTensor].stride()
208 | strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg]) + ')' for intArg in range(intArgs) ]
209 |
210 | strKernel = strKernel.replace(objMatch.group(0), '(' + str.join('+', strIndex) + ')')
211 | # end
212 |
213 | while True:
214 | objMatch = re.search(r'(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel)
215 |
216 | if objMatch is None:
217 | break
218 | # end
219 |
220 | intArgs = int(objMatch.group(2))
221 | strArgs = objMatch.group(4).split(',')
222 |
223 | strTensor = strArgs[0]
224 | intStrides = objVariables[strTensor].stride()
225 | strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg]) + ')' for intArg in range(intArgs) ]
226 |
227 | strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']')
228 | # end
229 |
230 | return strKernel
231 | # end
232 |
233 | @cupy.memoize(for_each_device=True)
234 | def cupy_launch(strFunction, strKernel):
235 | return cupy.cuda.compile_with_cache(strKernel).get_function(strFunction)
236 | # end
237 |
238 | class _FunctionSoftsplat(torch.autograd.Function):
239 | @staticmethod
240 | def forward(self, input, flow):
241 | self.save_for_backward(input, flow)
242 |
243 | intSamples = input.shape[0]
244 | intInputDepth, intInputHeight, intInputWidth = input.shape[1], input.shape[2], input.shape[3]
245 | intFlowDepth, intFlowHeight, intFlowWidth = flow.shape[1], flow.shape[2], flow.shape[3]
246 |
247 | assert(intFlowDepth == 2)
248 | assert(intInputHeight == intFlowHeight)
249 | assert(intInputWidth == intFlowWidth)
250 |
251 | input = input.contiguous(); assert(input.is_cuda == True)
252 | flow = flow.contiguous(); assert(flow.is_cuda == True)
253 |
254 | output = input.new_zeros([ intSamples, intInputDepth, intInputHeight, intInputWidth ])
255 |
256 | if input.is_cuda == True:
257 | n = output.nelement()
258 | cupy_launch('kernel_Softsplat_updateOutput', cupy_kernel('kernel_Softsplat_updateOutput', {
259 | 'input': input,
260 | 'flow': flow,
261 | 'output': output
262 | }))(
263 | grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]),
264 | block=tuple([ 512, 1, 1 ]),
265 | args=[ n, input.data_ptr(), flow.data_ptr(), output.data_ptr() ]
266 | )
267 |
268 | elif input.is_cuda == False:
269 | raise NotImplementedError()
270 |
271 | # end
272 |
273 | return output
274 | # end
275 |
276 | @staticmethod
277 | def backward(self, gradOutput):
278 | input, flow = self.saved_tensors
279 |
280 | intSamples = input.shape[0]
281 | intInputDepth, intInputHeight, intInputWidth = input.shape[1], input.shape[2], input.shape[3]
282 | intFlowDepth, intFlowHeight, intFlowWidth = flow.shape[1], flow.shape[2], flow.shape[3]
283 |
284 | assert(intFlowDepth == 2)
285 | assert(intInputHeight == intFlowHeight)
286 | assert(intInputWidth == intFlowWidth)
287 |
288 | gradOutput = gradOutput.contiguous(); assert(gradOutput.is_cuda == True)
289 |
290 | gradInput = input.new_zeros([ intSamples, intInputDepth, intInputHeight, intInputWidth ]) if self.needs_input_grad[0] == True else None
291 | gradFlow = input.new_zeros([ intSamples, intFlowDepth, intFlowHeight, intFlowWidth ]) if self.needs_input_grad[1] == True else None
292 |
293 | if input.is_cuda == True:
294 | if gradInput is not None:
295 | n = gradInput.nelement()
296 | cupy_launch('kernel_Softsplat_updateGradInput', cupy_kernel('kernel_Softsplat_updateGradInput', {
297 | 'input': input,
298 | 'flow': flow,
299 | 'gradOutput': gradOutput,
300 | 'gradInput': gradInput,
301 | 'gradFlow': gradFlow
302 | }))(
303 | grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]),
304 | block=tuple([ 512, 1, 1 ]),
305 | args=[ n, input.data_ptr(), flow.data_ptr(), gradOutput.data_ptr(), gradInput.data_ptr(), None ]
306 | )
307 | # end
308 |
309 | if gradFlow is not None:
310 | n = gradFlow.nelement()
311 | cupy_launch('kernel_Softsplat_updateGradFlow', cupy_kernel('kernel_Softsplat_updateGradFlow', {
312 | 'input': input,
313 | 'flow': flow,
314 | 'gradOutput': gradOutput,
315 | 'gradInput': gradInput,
316 | 'gradFlow': gradFlow
317 | }))(
318 | grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]),
319 | block=tuple([ 512, 1, 1 ]),
320 | args=[ n, input.data_ptr(), flow.data_ptr(), gradOutput.data_ptr(), None, gradFlow.data_ptr() ]
321 | )
322 | # end
323 |
324 | elif input.is_cuda == False:
325 | raise NotImplementedError()
326 |
327 | # end
328 |
329 | return gradInput, gradFlow
330 | # end
331 | # end
332 |
333 | def FunctionSoftsplat(tenInput, tenFlow, tenMetric, strType):
334 | assert(tenMetric is None or tenMetric.shape[1] == 1)
335 | assert(strType in ['summation', 'average', 'linear', 'softmax'])
336 |
337 | if strType == 'average':
338 | tenInput = torch.cat([ tenInput, tenInput.new_ones(tenInput.shape[0], 1, tenInput.shape[2], tenInput.shape[3]) ], 1)
339 |
340 | elif strType == 'linear':
341 | tenInput = torch.cat([ tenInput * tenMetric, tenMetric ], 1)
342 |
343 | elif strType == 'softmax':
344 | tenInput = torch.cat([ tenInput * tenMetric.exp(), tenMetric.exp() ], 1)
345 |
346 | # end
347 |
348 | tenOutput = _FunctionSoftsplat.apply(tenInput, tenFlow)
349 |
350 | if strType != 'summation':
351 | tenNormalize = tenOutput[:, -1:, :, :]
352 |
353 | tenNormalize[tenNormalize == 0.0] = 1.0
354 |
355 | tenOutput = tenOutput[:, :-1, :, :] / tenNormalize
356 | # end
357 |
358 | return tenOutput
359 | # end
360 |
361 | class ModuleSoftsplat(torch.nn.Module):
362 | def __init__(self, strType):
363 | super(ModuleSoftsplat, self).__init__()
364 |
365 | self.strType = strType
366 | # end
367 |
368 | def forward(self, tenInput, tenFlow, tenMetric=None):
369 | return FunctionSoftsplat(tenInput, tenFlow, tenMetric, self.strType)
370 | # end
371 | # end
372 |
--------------------------------------------------------------------------------
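The three CUDA string kernels above implement bilinear forward splatting: each source pixel is pushed to its four nearest target pixels (NW/NE/SW/SE) with bilinear weights that sum to 1 when all four land in bounds. A minimal NumPy sketch of the summation-splat forward path (function name is ours; the loops are deliberately naive, for illustration only):

```python
import numpy as np

def splat_numpy(inp, flow):
    """Reference forward warp (summation splatting) mirroring the CUDA kernel.

    inp:  (C, H, W) float array
    flow: (2, H, W) float array, channel 0 = x displacement, 1 = y displacement
    """
    C, H, W = inp.shape
    out = np.zeros_like(inp)
    for y in range(H):
        for x in range(W):
            ox = x + flow[0, y, x]
            oy = y + flow[1, y, x]
            nw_x, nw_y = int(np.floor(ox)), int(np.floor(oy))
            # bilinear weights for the four neighbouring target pixels
            for tx, ty, w in [
                (nw_x,     nw_y,     (nw_x + 1 - ox) * (nw_y + 1 - oy)),  # NW
                (nw_x + 1, nw_y,     (ox - nw_x)     * (nw_y + 1 - oy)),  # NE
                (nw_x,     nw_y + 1, (nw_x + 1 - ox) * (oy - nw_y)),      # SW
                (nw_x + 1, nw_y + 1, (ox - nw_x)     * (oy - nw_y)),      # SE
            ]:
                if 0 <= tx < W and 0 <= ty < H:
                    out[:, ty, tx] += inp[:, y, x] * w
    return out
```

A half-pixel horizontal flow splits a unit pixel 50/50 between two neighbours, which is exactly what the `fltNorthwest`/`fltNortheast` terms compute in the kernel.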
/test_models/CRVD_test.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('..')
3 | import argparse
4 | from datasets import CRVDTestDataset
5 | import numpy as np
6 | import os
7 | from models import ISP, FloRNNRaw
8 | from skimage.metrics import peak_signal_noise_ratio as compare_psnr
9 | from skimage.metrics import structural_similarity
10 | import torch
11 | import torch.nn as nn
12 | from utils.io import np2image_bgr
13 |
14 | def raw_ssim(pack1, pack2):
15 | test_raw_ssim = 0
16 | for i in range(4):
17 | test_raw_ssim += structural_similarity(pack1[i], pack2[i], data_range=1.0)
18 | return test_raw_ssim / 4
19 |
20 | def denoise_seq(seqn, a, b, model):
21 | T, C, H, W = seqn.shape
22 | a = a.expand((1, T, 1, H, W)).cuda()
23 | b = b.expand((1, T, 1, H, W)).cuda()
24 | seqdn = model(seqn.unsqueeze(0), a, b)[0]
25 | seqdn = torch.clamp(seqdn, 0, 1)
26 | return seqdn
27 |
28 | def main(**args):
29 | dataset_val = CRVDTestDataset(CRVD_path=args['crvd_dir'])
30 | isp = ISP().cuda()
31 | isp.load_state_dict(torch.load(args['isp_path'])['state_dict'])
32 |
33 | if args['model'] == 'FloRNNRaw':
34 | model = FloRNNRaw(img_channels=4, num_resblocks=args['num_resblocks'], forward_count=args['forward_count'], border_ratio=args['border_ratio'])
35 |
36 | state_temp_dict = torch.load(args['model_file'])['state_dict']
37 | model = nn.DataParallel(model).cuda()
38 | model.load_state_dict(state_temp_dict)
39 | model.eval()
40 |
41 | iso_psnr, iso_ssim = {}, {}
42 | for data in dataset_val:
43 |
44 | # our channels: RGGB; RViDeNet channels: RGBG. We must pass an RGBG pack to the ISP, since it was pretrained by RViDeNet
45 | seq = data['seq'].cuda()
46 | seqn = data['seqn'].cuda()
47 |
48 | with torch.no_grad():
49 | seqdn = denoise_seq(seqn, data['a'], data['b'], model)
50 | seqn[:, 2:] = torch.flip(seqn[:, 2:], dims=[1])
51 | seqdn[:, 2:] = torch.flip(seqdn[:, 2:], dims=[1])
52 | seq[:, 2:] = torch.flip(seq[:, 2:], dims=[1])
53 |
54 | seq_raw_psnr, seq_srgb_psnr, seq_raw_ssim, seq_srgb_ssim = 0, 0, 0, 0
55 | for i in range(seq.shape[0]):
56 | gt_raw_frame = seq[i].cpu().numpy()
57 | denoised_raw_frame = (np.uint16(seqdn[i].cpu().numpy() * (2 ** 12 - 1 - 240) + 240).astype(np.float32) - 240) / (2 ** 12 - 1 - 240)
58 | with torch.no_grad():
59 | gt_srgb_frame = np.uint8(np.clip(isp(seq[i:i+1]).cpu().numpy()[0], 0, 1) * 255).astype(np.float32) / 255.
60 | denoised_srgb_frame = np.uint8(np.clip(isp(seqdn[i:i+1]).cpu().numpy()[0], 0, 1) * 255).astype(np.float32) / 255.
61 |
62 | seq_raw_psnr += compare_psnr(gt_raw_frame, denoised_raw_frame, data_range=1.0)
63 | seq_srgb_psnr += compare_psnr(gt_srgb_frame, denoised_srgb_frame, data_range=1.0)
64 | seq_raw_ssim += raw_ssim(gt_raw_frame, denoised_raw_frame)
65 | seq_srgb_ssim += structural_similarity(np.transpose(gt_srgb_frame, (1, 2, 0)), np.transpose(denoised_srgb_frame, (1, 2, 0)),
66 | data_range=1.0, multichannel=True)
67 |
68 | seq_raw_psnr /= seq.shape[0]
69 | seq_srgb_psnr /= seq.shape[0]
70 | seq_raw_ssim /= seq.shape[0]
71 | seq_srgb_ssim /= seq.shape[0]
72 |
73 | if (str(data['iso'])+'raw') not in iso_psnr.keys():
74 | iso_psnr[str(data['iso'])+'raw'] = seq_raw_psnr / 5
75 | iso_psnr[str(data['iso'])+'srgb'] = seq_srgb_psnr / 5
76 | iso_ssim[str(data['iso'])+'raw'] = seq_raw_ssim / 5
77 | iso_ssim[str(data['iso'])+'srgb'] = seq_srgb_ssim / 5
78 | else:
79 | iso_psnr[str(data['iso'])+'raw'] += seq_raw_psnr / 5
80 | iso_psnr[str(data['iso']) + 'srgb'] += seq_srgb_psnr / 5
81 | iso_ssim[str(data['iso']) + 'raw'] += seq_raw_ssim / 5
82 | iso_ssim[str(data['iso']) + 'srgb'] += seq_srgb_ssim / 5
83 |
84 | dataset_raw_psnr, dataset_srgb_psnr, dataset_raw_ssim, dataset_srgb_ssim = 0, 0, 0, 0
85 | for iso in [1600, 3200, 6400, 12800, 25600]:
86 | print('iso %d, raw: %6.4f/%6.4f, srgb: %6.4f/%6.4f' % (iso, iso_psnr[str(iso)+'raw'], iso_ssim[str(iso)+'raw'],
87 | iso_psnr[str(iso)+'srgb'], iso_ssim[str(iso)+'srgb']))
88 | dataset_raw_psnr += iso_psnr[str(iso)+'raw']
89 | dataset_srgb_psnr += iso_psnr[str(iso)+'srgb']
90 | dataset_raw_ssim += iso_ssim[str(iso)+'raw']
91 | dataset_srgb_ssim += iso_ssim[str(iso)+'srgb']
92 |
93 | print('CRVD, raw: %6.4f/%6.4f, srgb: %6.4f/%6.4f' % (dataset_raw_psnr / 5, dataset_raw_ssim / 5, dataset_srgb_psnr / 5, dataset_srgb_ssim / 5))
94 |
95 | if __name__ == '__main__':
96 | parser = argparse.ArgumentParser(description="test raw model")
97 | parser.add_argument("--model", type=str, default='FloRNNRaw') # model in ['FloRNNRaw']
98 | parser.add_argument("--num_resblocks", type=int, default=15)
99 | parser.add_argument("--forward_count", type=int, default=3)
100 | parser.add_argument("--border_ratio", type=float, default=0.1)
101 | parser.add_argument("--model_file", type=str, default='/home/nagejacob/Documents/codes/VDN/logs/ours_raw/ckpt_e12.pth')
102 | parser.add_argument("--crvd_dir", type=str, default="/hdd/Documents/datasets/CRVD")
103 | parser.add_argument("--isp_path", type=str, default="../models/rvidenet/isp.pth")
104 | argspar = parser.parse_args()
105 |
106 | print("\n### Testing model ###")
107 | print("> Parameters:")
108 | for p, v in zip(argspar.__dict__.keys(), argspar.__dict__.values()):
109 | print('\t{}: {}'.format(p, v))
110 | print('\n')
111 |
112 | main(**vars(argspar))
--------------------------------------------------------------------------------
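The raw-domain PSNR in `CRVD_test.py` is computed after round-tripping the denoised frame through CRVD's 12-bit integer representation with a black level of 240 and a white level of 2^12 - 1. That inline expression can be isolated as a small helper (helper name is ours):

```python
import numpy as np

BLACK, WHITE = 240, 2 ** 12 - 1  # CRVD black level and 12-bit white level

def quantize_raw(x):
    """Round-trip a normalized raw frame through 12-bit integer storage,
    as done before computing raw PSNR in CRVD_test.py."""
    dn = (x * (WHITE - BLACK) + BLACK).astype(np.uint16).astype(np.float32)
    return (dn - BLACK) / (WHITE - BLACK)
```

The round trip is lossless at the endpoints and introduces at most one digital-number step (1 / 3855) of error in between, so it matches RViDeNet's evaluation protocol without visibly changing the scores.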
/test_models/sRGB_test.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('..')
3 | import argparse
4 | from datasets import SrgbValDataset
5 | from models import ForwardRNN, BiRNN, FloRNN, BasicVSRPlusPlus
6 | import time
7 | import torch
8 | import torch.nn as nn
9 | from utils.fastdvdnet_utils import fastdvdnet_batch_psnr
10 | from utils.ssim import batch_ssim
11 |
12 | def count_params(model):
13 | params = sum(p.numel() for p in model.parameters())
14 | print(params / 1000 / 1000)
15 | return params
16 |
17 | def denoise_seq(seqn, noise_std, model):
18 |
19 | # sequence dimensions
20 | numframes, C, H, W = seqn.shape
21 |
22 | # build noise map from noise std---assuming Gaussian noise
23 | noise_level_map = noise_std.expand((numframes, C, H, W)).cuda()
24 |
25 | with torch.no_grad():
26 | denframes = model(seqn.unsqueeze(0), noise_level_map.unsqueeze(0))
27 |
28 | denframes = torch.clamp(denframes.squeeze(0), 0., 1.)
29 |
30 | # free memory up
31 | del noise_level_map
32 | torch.cuda.empty_cache()
33 |
34 | # convert to appropriate type and return
35 | return denframes
36 |
37 | def test(**args):
38 | test_set = SrgbValDataset(args['test_path'], num_input_frames=args['max_num_fr_per_seq'])
39 |
40 | if args['model'] == 'ForwardRNN':
41 | model = ForwardRNN(img_channels=3, num_resblocks=args['num_resblocks'])
42 | elif args['model'] == 'BiRNN':
43 | model = BiRNN(img_channels=3, num_resblocks=args['num_resblocks'])
44 | elif args['model'] == 'FloRNN':
45 | model = FloRNN(img_channels=3, num_resblocks=args['num_resblocks'], forward_count=args['forward_count'], border_ratio=args['border_ratio'])
46 | elif args['model'] == 'BasicVSRPlusPlus':
47 | model = BasicVSRPlusPlus(img_channels=3, spatial_blocks=6, temporal_blocks=6, num_channels=64)
48 |
49 | state_temp_dict = torch.load(args['model_file'])['state_dict']
50 | model = nn.DataParallel(model).cuda()
51 | model.load_state_dict(state_temp_dict)
52 | model = model.module
53 | model.eval()
54 |
55 | dataset_psnr, dataset_ssim, seq_count = 0, 0, 0
56 | total_time, total_frames = 0, 0
57 | for data in test_set:
58 | seq = data['seq']
59 |
60 | # Add noise
61 | torch.manual_seed(0)
62 | noise = torch.empty_like(seq).normal_(mean=0, std=args['noise_sigma'])
63 | seqn = seq + noise
64 | noise_std = torch.FloatTensor([args['noise_sigma']])
65 | seqn = seqn.contiguous()
66 |
67 | torch.cuda.synchronize()
68 | start_time = time.time()
69 | with torch.no_grad():
70 | denframes = denoise_seq(seqn, noise_std=noise_std, model=model)
71 |
72 | torch.cuda.synchronize()
73 | total_time += time.time() - start_time
74 | total_frames += seqn.shape[0]
75 |
76 | psnr = fastdvdnet_batch_psnr(denframes, seq, 1.)
77 | ssim = batch_ssim(denframes, seq, 1.)
78 | dataset_psnr += psnr
79 | dataset_ssim += ssim
80 | seq_count += 1
81 | name = data['name'].split('/')[-2] + '/' + data['name'].split('/')[-1]
82 | print('{0:50}: PSNR: {1:.4f} dB, SSIM: {2:.4f}'.format(name, psnr, ssim))
83 | if args['display_time']:
84 | print('frames: %d, time/frame: %6.4f s' % (total_frames, total_time / total_frames))
85 |
86 | print('sigma %d, PSNR: %6.4f, SSIM: %6.4f' % (int(round(args['noise_sigma'] * 255.)),
87 | dataset_psnr/seq_count, dataset_ssim/seq_count))
88 |
89 | if __name__ == "__main__":
90 | # Parse arguments
91 | parser = argparse.ArgumentParser(description="test sRGB model")
92 | parser.add_argument("--model", type=str, default='BasicVSRPlusPlus') # model in ['ForwardRNN', 'BiRNN', 'FloRNN', 'BasicVSRPlusPlus']
93 | parser.add_argument("--num_resblocks", type=int, default=15)
94 | parser.add_argument("--forward_count", type=int, default=3)
95 | parser.add_argument("--border_ratio", type=float, default=0.1)
96 | parser.add_argument("--model_file", type=str, default='/home/nagejacob/Documents/codes/VDN/logs/basicvsr_plusplus/ckpt_e12.pth')
97 | parser.add_argument("--test_path", type=str, default="/hdd/Documents/datasets/Set8")
98 | parser.add_argument("--max_num_fr_per_seq", type=int, default=85)
99 | parser.add_argument("--noise_sigma", type=float, default=20, help='noise level used on the test set')
100 | parser.add_argument("--display_time", type=bool, default=False)
101 | argspar = parser.parse_args()
102 | # Normalize noise sigma to [0, 1]
103 | argspar.noise_sigma /= 255.
104 |
105 |
106 | print("\n### Testing model ###")
107 | print("> Parameters:")
108 | for p, v in zip(argspar.__dict__.keys(), argspar.__dict__.values()):
109 | print('\t{}: {}'.format(p, v))
110 | print('\n')
111 |
112 | for sigma in [10, 20, 30, 40, 50]:
113 | argspar.noise_sigma = sigma / 255.
114 | print('sigma=%d' % sigma)
115 | test(**vars(argspar))
116 |
--------------------------------------------------------------------------------
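`sRGB_test.py` adds white Gaussian noise with std sigma/255 and scores PSNR against the clean sequence. The same recipe in plain NumPy (helper names are ours) is handy for sanity-checking noise levels: the noisy-vs-clean PSNR at sigma = 20 should land near 20*log10(255/20), about 22.1 dB, before any denoising:

```python
import numpy as np

def add_gaussian_noise(seq, sigma, seed=0):
    """Additive white Gaussian noise; sigma given on the [0, 1] intensity scale."""
    rng = np.random.default_rng(seed)
    return seq + rng.normal(0.0, sigma, size=seq.shape)

def psnr(ref, img, data_range=1.0):
    """Peak signal-to-noise ratio in dB."""
    mse = np.mean((ref - img) ** 2)
    return 10.0 * np.log10(data_range ** 2 / mse)
```

Fixing the seed, as the test script does with `torch.manual_seed(0)`, keeps the noisy inputs identical across models so PSNR differences reflect the denoiser, not the noise draw.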
/train_models/CRVD_train.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('..')
3 | import argparse
4 | from datasets import CRVDTrainDataset, CRVDTestDataset
5 | from models import FloRNNRaw
6 | import numpy as np
7 | import os
8 | from skimage.metrics import peak_signal_noise_ratio as compare_psnr
9 | import time
10 | import torch
11 | from torch.utils.data import DataLoader
12 | from train_models.base_functions import batch_psnr, resume_training, save_model
13 | from utils.io import log
14 |
15 | torch.backends.cudnn.benchmark = True
16 |
17 | def main(**args):
18 | dataset_train = CRVDTrainDataset(CRVD_path=args['CRVD_dir'],
19 | patch_size=args['patch_size'],
20 | patches_per_epoch=args['patches_per_epoch'],
21 | mirror_seq=args['mirror_seq'])
22 | loader_train = DataLoader(dataset=dataset_train, batch_size=args['batch_size'], num_workers=4, shuffle=True, drop_last=True)
23 | dataset_val = CRVDTestDataset(CRVD_path=args['CRVD_dir'])
24 |
25 | if args['model'] == 'FloRNNRaw':
26 | model = FloRNNRaw(img_channels=4, num_resblocks=args['num_resblocks'], forward_count=args['forward_count'],
27 | border_ratio=args['border_ratio'])
28 | model = torch.nn.DataParallel(model).cuda()
29 |
30 | criterion = torch.nn.MSELoss(reduction='sum').cuda()
31 | optimizer = torch.optim.Adam(model.module.trainable_parameters(), lr=args['lr'])
32 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args['milestones'], gamma=0.1)
33 |
34 | start_epoch = resume_training(args, model, optimizer, scheduler)
35 | for epoch in range(start_epoch, args['epochs']):
36 | start_time = time.time()
37 |
38 | # training
39 | model.train()
40 | for i, data in enumerate(loader_train):
41 | seq = data['seq'].cuda()
42 | N, T, C, H, W = seq.shape
43 |
44 | seqn = data['seqn'].cuda()
45 | a = data['a'].expand((N, T, 1, H, W)).cuda()
46 | b = data['b'].expand((N, T, 1, H, W)).cuda()
47 |
48 | seqdn = model(seqn, a, b)
49 |
50 | if args['model'] in ['FloRNNRaw']:
51 | end_index = -1 if (args['forward_count'] == -1) else (-args['forward_count'])
52 | loss = criterion(seq[:, 1:end_index], seqdn[:, 1:end_index]) / (N * 2)
53 | else:
54 | loss = criterion(seq, seqdn) / (N * 2)
55 |
56 | optimizer.zero_grad()
57 | loss.backward()
58 | optimizer.step()
59 |
60 | if (i+1) % args['print_every'] == 0:
61 | train_psnr = torch.mean(batch_psnr(seq, seqdn)).item()
62 | log(args["log_file"], "[epoch {}][{}/{}] loss: {:1.4f} PSNR_train: {:1.4f}\n". \
63 | format(epoch + 1, i + 1, int(args['patches_per_epoch'] // args['batch_size']), loss.item(), train_psnr))
64 |
65 | scheduler.step()
66 |
67 | # evaluating
68 | model.eval()
69 | iso_psnr = {}
70 | for data in dataset_val:
71 | seq = data['seq']
72 | T, C, H, W = seq.shape
73 |
74 | seqn = data['seqn'].cuda()
75 | a = data['a'].expand((T, 1, H, W)).cuda()
76 | b = data['b'].expand((T, 1, H, W)).cuda()
77 |
78 | with torch.no_grad():
79 | seqdn = torch.clamp(model(seqn.unsqueeze(0), a.unsqueeze(0), b.unsqueeze(0)).squeeze(0), 0., 1.)
80 |
81 | # calculate PSNR the same way as RViDeNet
82 | seq_psnr = 0
83 | for i in range(T):
84 | seq_psnr += compare_psnr(seq[i].numpy(),
85 | (np.uint16(seqdn[i].cpu().numpy() * (2 ** 12 - 1 - 240) + 240).astype(np.float32) - 240) / (2 ** 12 - 1 - 240),
86 | data_range=1.0)
87 | seq_psnr /= T
88 |
89 | if str(data['iso']) not in iso_psnr.keys():
90 | iso_psnr[str(data['iso'])] = seq_psnr
91 | else:
92 | iso_psnr[str(data['iso'])] += seq_psnr
93 | dataset_psnr = 0
94 | for iso in [1600, 3200, 6400, 12800, 25600]:
95 | log(args['log_file'], 'iso %d, %6.4f\n' % (iso, iso_psnr[str(iso)] / 5))
96 | dataset_psnr += iso_psnr[str(iso)] / 5
97 | dataset_psnr = dataset_psnr / 5
98 |
99 | log(args["log_file"], "\n[epoch %d] PSNR_val: %.4f, %0.2f hour/epoch\n\n" % (epoch + 1, dataset_psnr, (time.time()-start_time)/3600))
100 |
101 | # save model
102 | save_model(args, model, optimizer, scheduler, epoch + 1)
103 |
104 |
105 | if __name__ == '__main__':
106 | parser = argparse.ArgumentParser(description="Train the denoiser")
107 |
108 | # Model parameters
109 | parser.add_argument("--model", type=str, default='FloRNNRaw')
110 | parser.add_argument("--num_resblocks", type=int, default=15)
111 | parser.add_argument("--forward_count", type=int, default=3)
112 | parser.add_argument("--border_ratio", type=float, default=0.1)
113 |
114 | # Training parameters
115 | parser.add_argument("--batch_size", type=int, default=16)
116 | parser.add_argument("--epochs", "--e", type=int, default=12)
117 | parser.add_argument("--milestones", nargs='+', type=int, default=[11])
118 | parser.add_argument("--lr", type=float, default=1e-4)
119 | parser.add_argument("--print_every", type=int, default=100)
120 | parser.add_argument("--patch_size", "--p", type=int, default=96, help="Patch size")
121 | parser.add_argument("--patches_per_epoch", "--n", type=int, default=256000, help="Number of patches")
122 | parser.add_argument("--mirror_seq", type=bool, default=True)
123 |
124 | # Paths
125 | parser.add_argument("--CRVD_dir", type=str, default='/hdd/Documents/datasets/CRVD')
126 | parser.add_argument("--log_dir", type=str, default="../logs/FloRNNRaw")
127 | argspar = parser.parse_args()
128 |
129 | argspar.log_file = os.path.join(argspar.log_dir, 'log.out')
130 |
131 | if not os.path.exists(argspar.log_dir):
132 | os.makedirs(argspar.log_dir)
133 | log(argspar.log_file, "\n### Training the denoiser ###\n")
134 | log(argspar.log_file, "> Parameters:\n")
135 | for p, v in zip(argspar.__dict__.keys(), argspar.__dict__.values()):
136 | log(argspar.log_file, '\t{}: {}\n'.format(p, v))
137 |
138 | main(**vars(argspar))
--------------------------------------------------------------------------------
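The per-ISO tensors `a` and `b` broadcast alongside each sequence in the training loop parameterize CRVD's signal-dependent raw noise: in the Poisson-Gaussian model used by RViDeNet-style pipelines, the noise variance at a normalized clean intensity x is a*x + b. A minimal sketch of sampling such heteroscedastic noise under that assumption (the helper name and the Gaussian approximation are ours, not code from this repository):

```python
import numpy as np

def add_poisson_gaussian_noise(clean, a, b, seed=0):
    """Sample heteroscedastic noise with variance a*clean + b (Gaussian
    approximation of the Poisson-Gaussian raw noise model; an assumption
    for illustration, not taken from this repository)."""
    rng = np.random.default_rng(seed)
    std = np.sqrt(np.clip(a * clean + b, 0.0, None))
    return clean + rng.normal(0.0, 1.0, size=clean.shape) * std
```

Higher ISO settings correspond to larger `a` and `b`, which is why the evaluation loop reports PSNR per ISO before averaging over the five levels.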
/train_models/base_functions.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import math
3 | import os
4 | import re
5 | import torch
6 | from utils.io import log
7 |
8 | def resume_training(args, model, optimizer, scheduler):
9 | """ Resumes previous training or starts anew
10 | """
11 | model_files = glob.glob(os.path.join(args['log_dir'], '*.pth'))
12 |
13 | if len(model_files) == 0:
14 | start_epoch = 0
15 | else:
16 | log(args['log_file'], "> Resuming previous training\n")
17 | epochs_exist = []
18 | for model_file in model_files:
19 | result = re.findall('ckpt_e(.*).pth', model_file)
20 | epochs_exist.append(int(result[0]))
21 | max_epoch = max(epochs_exist)
22 | max_epoch_model_file = os.path.join(args['log_dir'], 'ckpt_e%d.pth' % max_epoch)
23 | checkpoint = torch.load(max_epoch_model_file)
24 | model.load_state_dict(checkpoint['state_dict'])
25 | optimizer.load_state_dict(checkpoint['optimizer'])
26 | scheduler.load_state_dict(checkpoint['scheduler'])
27 |
28 | start_epoch = max_epoch
29 |
30 | return start_epoch
31 |
32 | def save_model(args, model, optimizer, scheduler, epoch):
33 | save_dict = {
34 | 'args': args,
35 | 'state_dict': model.state_dict(),
36 | 'optimizer' : optimizer.state_dict(),
37 | 'scheduler': scheduler.state_dict()}
38 |
39 | torch.save(save_dict, os.path.join(args['log_dir'], 'ckpt_e{}.pth'.format(epoch)))
40 |
41 | # the same as skimage.metrics.peak_signal_noise_ratio
42 | def batch_psnr(a, b):
43 | a = torch.clamp(a, 0, 1)
44 | b = torch.clamp(b, 0, 1)
45 | x = torch.mean((a - b) ** 2, dim=[-3, -2, -1])
46 | return 20 * torch.log(1 / torch.sqrt(x)) / math.log(10)
47 |
--------------------------------------------------------------------------------
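The expression `20 * torch.log(1 / torch.sqrt(x)) / math.log(10)` in `batch_psnr` is algebraically identical to 10*log10(1/mse), i.e. standard PSNR with a data range of 1, computed per image over the last three axes. A small NumPy mirror makes the equivalence easy to verify:

```python
import math
import numpy as np

def batch_psnr_np(a, b):
    """NumPy mirror of batch_psnr: per-image PSNR over the last three axes,
    with inputs clipped to [0, 1] first."""
    a = np.clip(a, 0.0, 1.0)
    b = np.clip(b, 0.0, 1.0)
    mse = np.mean((a - b) ** 2, axis=(-3, -2, -1))
    return 20.0 * np.log(1.0 / np.sqrt(mse)) / math.log(10.0)
```

Because 20*ln(1/sqrt(mse))/ln(10) = 10*log10(1/mse), this agrees with `skimage.metrics.peak_signal_noise_ratio` at `data_range=1.0`, as the comment above `batch_psnr` claims.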
/train_models/sRGB_train.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('..')
3 | import argparse
4 | from datasets import SrgbTrainDataset, SrgbValDataset
5 | from models import ForwardRNN, BiRNN, FloRNN
6 | import os
7 | import time
8 | import torch
9 | from torch.utils.data import DataLoader
10 | from train_models.base_functions import resume_training, save_model
11 | from utils.fastdvdnet_utils import fastdvdnet_batch_psnr, normalize_augment
12 | from utils.io import log
13 |
14 | torch.backends.cudnn.benchmark = True
15 |
16 | def main(**args):
17 | dataset_train = SrgbTrainDataset(seq_dir=args['trainset_dir'],
18 | train_length=args['train_length'],
19 | patch_size=args['patch_size'],
20 | patches_per_epoch=args['patches_per_epoch'],
21 | image_postfix='jpg',
22 | pin_memory=True)
23 | loader_train = DataLoader(dataset=dataset_train, batch_size=args['batch_size'], num_workers=4, shuffle=True, drop_last=True)
24 | dataset_val = SrgbValDataset(valsetdir=args['valset_dir'])
25 | loader_val = DataLoader(dataset=dataset_val, batch_size=1)
26 |
27 | if args['model'] == 'ForwardRNN':
28 | model = ForwardRNN(img_channels=3, num_resblocks=args['num_resblocks'])
29 | elif args['model'] == 'BiRNN':
30 | model = BiRNN(img_channels=3, num_resblocks=args['num_resblocks'])
31 | elif args['model'] == 'FloRNN':
32 | model = FloRNN(img_channels=3, num_resblocks=args['num_resblocks'], forward_count=args['forward_count'],
33 | border_ratio=args['border_ratio'])
34 | model = torch.nn.DataParallel(model).cuda()
35 |
36 | criterion = torch.nn.MSELoss(reduction='sum').cuda()
37 | optimizer = torch.optim.Adam(model.module.trainable_parameters(), lr=args['lr'])
38 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args['milestones'], gamma=0.1)
39 |
40 | start_epoch = resume_training(args, model, optimizer, scheduler)
41 | for epoch in range(start_epoch, args['epochs']):
42 | start_time = time.time()
43 |
44 | # training
45 | model.train()
46 | for i, data in enumerate(loader_train):
47 | seq = data['data'].cuda()
48 | seq = normalize_augment(seq)
49 |
50 | N, T, C, H, W = seq.shape
51 | stdn = torch.empty((N, 1, 1, 1, 1)).cuda().uniform_(args['noise_ival'][0], to=args['noise_ival'][1])
52 | noise_level_map = stdn.expand_as(seq)
53 |
54 | noise = torch.normal(mean=torch.zeros_like(seq), std=noise_level_map)
55 | seqn = seq + noise
56 | seqdn = model(seqn, noise_level_map)
57 |
58 | if args['model'] in ['FloRNN']:
59 | end_index = -1 if (args['forward_count'] == -1) else (-args['forward_count'])
60 | loss = criterion(seq[:, 1:end_index], seqdn[:, 1:end_index]) / (N * 2)
61 | else:
62 | loss = criterion(seq, seqdn) / (N * 2)
63 |
64 | loss.backward()
65 | optimizer.step()
66 | optimizer.zero_grad()
67 |
68 | if (i+1) % args['print_every'] == 0:
69 | train_psnr = fastdvdnet_batch_psnr(seq, seqdn)
70 | log(args["log_file"], "[epoch {}][{}/{}] loss: {:1.4f} PSNR_train: {:1.4f}\n". \
71 | format(epoch + 1, i + 1, int(args['patches_per_epoch'] // args['batch_size']), loss.item(), train_psnr))
72 |
73 | scheduler.step()
74 |
75 | # evaluating
76 | model.eval()
77 | psnr_val = 0
78 | for i, data in enumerate(loader_val):
79 | seq = data['seq'].cuda()
80 |
81 | torch.manual_seed(0)
82 |         stdn = torch.FloatTensor([args['val_noiseL']]).to(seq.device)
83 | noise_level_map = stdn.expand_as(seq)
84 | noise = torch.empty_like(seq).normal_(mean=0, std=args['val_noiseL'])
85 | seqn = seq + noise
86 |
87 | with torch.no_grad():
88 | seqdn = model(seqn, noise_level_map)
89 | psnr_val += fastdvdnet_batch_psnr(seq, seqdn)
90 |
91 | psnr_val = psnr_val / len(dataset_val)
92 | log(args["log_file"], "\n[epoch %d] PSNR_val: %.4f, %0.2f hour/epoch\n\n" % (epoch + 1, psnr_val, (time.time()-start_time)/3600))
93 |
94 | # save model
95 | save_model(args, model, optimizer, scheduler, epoch + 1)
96 |
97 |
98 | if __name__ == '__main__':
99 | parser = argparse.ArgumentParser(description="Train the denoiser")
100 |
101 | # Model parameters
102 | parser.add_argument("--model", type=str, default='FloRNN')
103 | parser.add_argument("--num_resblocks", type=int, default=15)
104 | parser.add_argument("--forward_count", type=int, default=3)
105 | parser.add_argument("--border_ratio", type=float, default=0.1)
106 |
107 | # Training parameters
108 | parser.add_argument("--batch_size", type=int, default=8)
109 | parser.add_argument("--epochs", "--e", type=int, default=12)
110 |     parser.add_argument("--milestones", nargs='+', type=int, default=[11])
111 | parser.add_argument("--lr", type=float, default=1e-4)
112 | parser.add_argument("--print_every", type=int, default=100)
113 | parser.add_argument("--noise_ival", nargs=2, type=int, default=[0, 55])
114 | parser.add_argument("--val_noiseL", type=float, default=20)
115 | parser.add_argument("--patch_size", "--p", type=int, default=96, help="Patch size")
116 | parser.add_argument("--patches_per_epoch", "--n", type=int, default=128000, help="Number of patches")
117 |
118 | # Paths
119 | parser.add_argument("--trainset_dir", type=str, default='/hdd/Documents/datasets/DAVIS-2017-trainval-480p')
120 | parser.add_argument("--valset_dir", type=str, default='/hdd/Documents/datasets/Set8')
121 | parser.add_argument("--log_dir", type=str, default="../logs/FloRNN")
122 | argspar = parser.parse_args()
123 |
124 | argspar.log_file = os.path.join(argspar.log_dir, 'log.out')
125 | argspar.train_length = 10 if (argspar.forward_count == -1) else (8 + argspar.forward_count)
126 |
127 | # Normalize noise between [0, 1]
128 | argspar.val_noiseL /= 255.
129 | argspar.noise_ival[0] /= 255.
130 | argspar.noise_ival[1] /= 255.
131 |
132 | if not os.path.exists(argspar.log_dir):
133 | os.makedirs(argspar.log_dir)
134 | log(argspar.log_file, "\n### Training the denoiser ###\n")
135 | log(argspar.log_file, "> Parameters:\n")
136 | for p, v in zip(argspar.__dict__.keys(), argspar.__dict__.values()):
137 | log(argspar.log_file, '\t{}: {}\n'.format(p, v))
138 |
139 | main(**vars(argspar))
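The training loop above draws one noise level per sequence, broadcasts it into a full noise-level map, and synthesizes Gaussian noise from that map. A standalone NumPy sketch of that sampling step (shapes and the [0, 55]/255 interval follow the defaults above; the variable names are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
N, T, C, H, W = 2, 10, 3, 8, 8
seq = rng.random((N, T, C, H, W)).astype(np.float32)     # clean frames in [0, 1]

# one sigma per sequence, broadcast over frames, channels and pixels
sigma = rng.uniform(0 / 255., 55 / 255., size=(N, 1, 1, 1, 1)).astype(np.float32)
noise_level_map = np.broadcast_to(sigma, seq.shape)      # conditioning input for the model
noise = rng.standard_normal(seq.shape).astype(np.float32) * sigma
seqn = seq + noise                                       # noisy input to the denoiser
```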
--------------------------------------------------------------------------------
/train_models/sRGB_train_distributed.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('..')
3 | import argparse
4 | from datasets import SrgbTrainDataset, SrgbValDataset
5 | from models import BasicVSRPlusPlus
6 | import os
7 | import time
8 | import torch
9 | from torch.utils.data import DataLoader
10 | from train_models.base_functions import resume_training, save_model
11 | from utils.fastdvdnet_utils import fastdvdnet_batch_psnr, normalize_augment
12 | from utils.io import log
13 |
14 | torch.backends.cudnn.benchmark = True
15 |
16 | def main(**args):
17 | torch.cuda.set_device(args['local_rank'])
18 | torch.distributed.init_process_group(backend='nccl', init_method=args['init_method'], rank=args['local_rank'], world_size=args['world_size'])
19 |
20 | dataset_train = SrgbTrainDataset(seq_dir=args['trainset_dir'],
21 | train_length=args['train_length'],
22 | patch_size=args['patch_size'],
23 | patches_per_epoch=args['patches_per_epoch'],
24 | image_postfix='jpg',
25 | pin_memory=True)
26 | sampler_train = torch.utils.data.distributed.DistributedSampler(dataset=dataset_train, shuffle=True)
27 | loader_train = DataLoader(dataset=dataset_train, batch_size=args['batch_size'], sampler=sampler_train, num_workers=4, drop_last=True)
28 | dataset_val = SrgbValDataset(valsetdir=args['valset_dir'])
29 | loader_val = DataLoader(dataset=dataset_val, batch_size=1)
30 |
31 | if args['model'] == 'BasicVSRPlusPlus':
32 | model = BasicVSRPlusPlus(img_channels=3, spatial_blocks=6, temporal_blocks=6, num_channels=64)
33 | model = model.to(torch.device('cuda', args['local_rank']))
34 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args['local_rank']], output_device=args['local_rank'], find_unused_parameters=True)
35 |
36 | criterion = torch.nn.MSELoss(reduction='sum').to(torch.device('cuda', args['local_rank']))
37 | optimizer = torch.optim.Adam(model.module.trainable_parameters(), lr=args['lr'])
38 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args['milestones'], gamma=0.1)
39 |
40 | start_epoch = resume_training(args, model, optimizer, scheduler)
41 | for epoch in range(start_epoch, args['epochs']):
42 | sampler_train.set_epoch(epoch)
43 | start_time = time.time()
44 |
45 | # training
46 | model.train()
47 | for i, data in enumerate(loader_train):
48 | seq = data['data'].to(torch.device('cuda', args['local_rank']))
49 | seq = normalize_augment(seq)
50 |
51 | N, T, C, H, W = seq.shape
52 | stdn = torch.empty((N, 1, 1, 1, 1)).to(torch.device('cuda', args['local_rank'])).uniform_(args['noise_ival'][0], to=args['noise_ival'][1])
53 | noise_level_map = stdn.expand_as(seq)
54 |
55 | noise = torch.normal(mean=torch.zeros_like(seq), std=noise_level_map)
56 | seqn = seq + noise
57 | seqdn = model(seqn, noise_level_map)
58 |
59 | if args['model'] in ['FloRNN']:
60 | end_index = -1 if (args['forward_count'] == -1) else (-args['forward_count'])
61 | loss = criterion(seq[:, 1:end_index], seqdn[:, 1:end_index]) / (N * 2)
62 | else:
63 | loss = criterion(seq, seqdn) / (N * 2)
64 |
65 | loss.backward()
66 | optimizer.step()
67 | optimizer.zero_grad()
68 |
69 | if (i+1) % args['print_every'] == 0 and args['local_rank'] == 0:
70 | train_psnr = fastdvdnet_batch_psnr(seq, seqdn)
71 | log(args["log_file"], "[epoch {}][{}/{}] loss: {:1.4f} PSNR_train: {:1.4f}\n". \
72 | format(epoch + 1, i + 1, int(args['patches_per_epoch'] // args['batch_size'] // args['world_size']), loss.item(), train_psnr))
73 |
74 | scheduler.step()
75 |
76 | # evaluating
77 | if args['local_rank'] == 0:
78 | model.eval()
79 | psnr_val = 0
80 | for i, data in enumerate(loader_val):
81 |                 seq = data['seq'].to(torch.device('cuda', args['local_rank']))
82 |
83 | torch.manual_seed(0)
84 |                 stdn = torch.FloatTensor([args['val_noiseL']]).to(seq.device)
85 | noise_level_map = stdn.expand_as(seq)
86 | noise = torch.empty_like(seq).normal_(mean=0, std=args['val_noiseL'])
87 | seqn = seq + noise
88 |
89 | with torch.no_grad():
90 | seqdn = model(seqn, noise_level_map)
91 | psnr_val += fastdvdnet_batch_psnr(seq, seqdn)
92 |
93 | psnr_val = psnr_val / len(dataset_val)
94 | log(args["log_file"], "\n[epoch %d] PSNR_val: %.4f, %0.2f hour/epoch\n\n" % (epoch + 1, psnr_val, (time.time()-start_time)/3600))
95 |
96 | # save model
97 | save_model(args, model, optimizer, scheduler, epoch + 1)
98 |
99 |
100 | if __name__ == '__main__':
101 | parser = argparse.ArgumentParser(description="Train the denoiser")
102 | parser.add_argument("--local_rank", type=int, default=0)
103 |
104 | # Model parameters
105 | parser.add_argument("--model", type=str, default='BasicVSRPlusPlus')
106 |
107 | # Training parameters
108 | parser.add_argument("--batch_size", type=int, default=8)
109 | parser.add_argument("--world_size", type=int, default=4)
110 | parser.add_argument("--init_method", default='tcp://127.0.0.1:25000')
111 | parser.add_argument("--epochs", "--e", type=int, default=12)
112 |     parser.add_argument("--milestones", nargs='+', type=int, default=[11])
113 | parser.add_argument("--lr", type=float, default=1e-4)
114 | parser.add_argument("--print_every", type=int, default=100)
115 | parser.add_argument("--noise_ival", nargs=2, type=int, default=[0, 55])
116 | parser.add_argument("--val_noiseL", type=float, default=20)
117 | parser.add_argument("--patch_size", "--p", type=int, default=96, help="Patch size")
118 | parser.add_argument("--patches_per_epoch", "--n", type=int, default=128000, help="Number of patches")
119 |
120 | # Paths
121 | parser.add_argument("--trainset_dir", type=str, default='/mnt/disk10T/Documents/datasets/DAVIS-2017-trainval-480p')
122 | parser.add_argument("--valset_dir", type=str, default='/mnt/disk10T/Documents/datasets/Set8')
123 | parser.add_argument("--log_dir", type=str, default="../logs/BiRNN_plusplus")
124 | argspar = parser.parse_args()
125 |
126 | argspar.log_file = os.path.join(argspar.log_dir, 'log.out')
127 | argspar.train_length = 10
128 | argspar.batch_size = argspar.batch_size // argspar.world_size
129 |
130 | # Normalize noise between [0, 1]
131 | argspar.val_noiseL /= 255.
132 | argspar.noise_ival[0] /= 255.
133 | argspar.noise_ival[1] /= 255.
134 |
135 | if argspar.local_rank == 0:
136 | if not os.path.exists(argspar.log_dir):
137 | os.makedirs(argspar.log_dir)
138 | log(argspar.log_file, "\n### Training the denoiser ###\n")
139 | log(argspar.log_file, "> Parameters:\n")
140 | for p, v in zip(argspar.__dict__.keys(), argspar.__dict__.values()):
141 | log(argspar.log_file, '\t{}: {}\n'.format(p, v))
142 |
143 | main(**vars(argspar))
--------------------------------------------------------------------------------
/utils/fastdvdnet_utils.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import glob
3 | import numpy as np
4 | import os
5 | from random import choices
6 | from skimage.metrics import peak_signal_noise_ratio as compare_psnr
7 | import torch
8 |
9 | IMAGETYPES = ('*.bmp', '*.png', '*.jpg', '*.jpeg', '*.tif') # Supported image types
10 |
11 | def fastdvdnet_batch_psnr(img, imclean, data_range=1.):
12 | r"""
13 | Computes the PSNR along the batch dimension (not pixel-wise)
14 |
15 | Args:
16 | img: a `torch.Tensor` containing the restored image
17 | imclean: a `torch.Tensor` containing the reference image
18 | data_range: The data range of the input image (distance between
19 | minimum and maximum possible values). By default, this is estimated
20 | from the image data-type.
21 | """
22 | img_cpu = img.data.cpu().numpy().astype(np.float32)
23 | imgclean = imclean.data.cpu().numpy().astype(np.float32)
24 | psnr = 0
25 | for i in range(img_cpu.shape[0]):
26 | psnr += compare_psnr(imgclean[i, :, :, :], img_cpu[i, :, :, :],
27 | data_range=data_range)
28 | return psnr/img_cpu.shape[0]
29 |
30 | def get_imagenames(seq_dir, pattern=None):
31 | """ Get ordered list of filenames
32 | """
33 | files = []
34 | for typ in IMAGETYPES:
35 | files.extend(glob.glob(os.path.join(seq_dir, typ)))
36 |
37 | # filter filenames
38 |     if pattern is not None:
39 | ffiltered = [f for f in files if pattern in os.path.split(f)[-1]]
40 | files = ffiltered
41 | del ffiltered
42 |
43 |     # sort filenames numerically by the digits in each name
44 | files.sort(key=lambda f: int(''.join(filter(str.isdigit, f))))
45 | return files
46 |
47 | def open_sequence(seq_dir, gray_mode, expand_if_needed=False, max_num_fr=85):
48 |     r""" Opens a sequence of images and expands it to even sizes if necessary
49 |
50 |     Args:
51 |         seq_dir: string, path to the image sequence directory
52 |         gray_mode: boolean, True if the images are to be opened in grayscale mode
53 |         expand_if_needed: if True, the spatial dimensions will be expanded if
54 |             their size is odd
55 |         max_num_fr: maximum number of frames to load
56 |     Returns:
57 |         seq: array of dims [num_frames, C, H, W], C=1 grayscale or C=3 RGB, H and W are even.
58 |             Each image gets normalized to the range [0, 1].
59 |         expanded_h: True if original dim H was odd and image got expanded in this dimension.
60 |         expanded_w: True if original dim W was odd and image got expanded in this dimension.
61 |     """
62 | # Get ordered list of filenames
63 | files = get_imagenames(seq_dir)
64 |
65 | seq_list = []
66 | # print("\tOpen sequence in folder: ", seq_dir)
67 | for fpath in files[0:max_num_fr]:
68 |
69 | img, expanded_h, expanded_w = open_image(fpath,\
70 | gray_mode=gray_mode,\
71 | expand_if_needed=expand_if_needed,\
72 | expand_axis0=False)
73 | seq_list.append(img)
74 | seq = np.stack(seq_list, axis=0)
75 | return seq, expanded_h, expanded_w
76 |
77 | def open_image(fpath, gray_mode, expand_if_needed=False, expand_axis0=True, normalize_data=True):
78 |     r""" Opens an image and expands it if necessary
79 | Args:
80 | fpath: string, path of image file
81 |         gray_mode: boolean, True if the image is to be opened
82 |             in grayscale mode
83 | expand_if_needed: if True, the spatial dimensions will be expanded if
84 | size is odd
85 | expand_axis0: if True, output will have a fourth dimension
86 | Returns:
87 | img: image of dims NxCxHxW, N=1, C=1 grayscale or C=3 RGB, H and W are even.
88 | if expand_axis0=False, the output will have a shape CxHxW.
89 |             The image gets normalized to the range [0, 1].
90 | expanded_h: True if original dim H was odd and image got expanded in this dimension.
91 | expanded_w: True if original dim W was odd and image got expanded in this dimension.
92 | """
93 | if not gray_mode:
94 | # Open image as a CxHxW torch.Tensor
95 | img = cv2.imread(fpath)
96 | # from HxWxC to CxHxW, RGB image
97 | img = (cv2.cvtColor(img, cv2.COLOR_BGR2RGB)).transpose(2, 0, 1)
98 | else:
99 | # from HxWxC to CxHxW grayscale image (C=1)
100 | img = cv2.imread(fpath, cv2.IMREAD_GRAYSCALE)
101 | img = np.expand_dims(img, 0)
102 |
103 | if expand_axis0:
104 | img = np.expand_dims(img, 0)
105 |
106 | # Handle odd sizes
107 | expanded_h = False
108 | expanded_w = False
109 | sh_im = img.shape
110 | if expand_if_needed:
111 | if sh_im[-2]%2 == 1:
112 | expanded_h = True
113 | if expand_axis0:
114 | img = np.concatenate((img, \
115 | img[:, :, -1, :][:, :, np.newaxis, :]), axis=2)
116 | else:
117 | img = np.concatenate((img, \
118 | img[:, -1, :][:, np.newaxis, :]), axis=1)
119 |
120 |
121 | if sh_im[-1]%2 == 1:
122 | expanded_w = True
123 | if expand_axis0:
124 | img = np.concatenate((img, \
125 | img[:, :, :, -1][:, :, :, np.newaxis]), axis=3)
126 | else:
127 | img = np.concatenate((img, \
128 | img[:, :, -1][:, :, np.newaxis]), axis=2)
129 |
130 | if normalize_data:
131 | img = normalize(img)
132 | return img, expanded_h, expanded_w
133 |
134 | def normalize(data):
135 |     r"""Normalizes a uint8 image to a float32 image in the range [0, 1]
136 |
137 | Args:
138 |         data: a uint8 numpy array to normalize from [0, 255] to [0, 1]
139 | """
140 | return np.float32(data/255.)
141 |
142 | def normalize_augment(img_train):
143 |     '''Normalizes an input patch of dim [N, num_frames, C, H, W] from [0., 255.] to \
144 |     [0., 1.] and applies one randomly chosen augmentation (flip/rotation/additive \
145 |     noise) to the whole temporal patch, returning a tensor of the same shape.
146 |     '''
147 | def transform(sample):
148 | # define transformations
149 | do_nothing = lambda x: x
150 | do_nothing.__name__ = 'do_nothing'
151 | flipud = lambda x: torch.flip(x, dims=[2])
152 |         flipud.__name__ = 'flipud'
153 | rot90 = lambda x: torch.rot90(x, k=1, dims=[2, 3])
154 | rot90.__name__ = 'rot90'
155 | rot90_flipud = lambda x: torch.flip(torch.rot90(x, k=1, dims=[2, 3]), dims=[2])
156 | rot90_flipud.__name__ = 'rot90_flipud'
157 | rot180 = lambda x: torch.rot90(x, k=2, dims=[2, 3])
158 | rot180.__name__ = 'rot180'
159 | rot180_flipud = lambda x: torch.flip(torch.rot90(x, k=2, dims=[2, 3]), dims=[2])
160 | rot180_flipud.__name__ = 'rot180_flipud'
161 | rot270 = lambda x: torch.rot90(x, k=3, dims=[2, 3])
162 | rot270.__name__ = 'rot270'
163 | rot270_flipud = lambda x: torch.flip(torch.rot90(x, k=3, dims=[2, 3]), dims=[2])
164 | rot270_flipud.__name__ = 'rot270_flipud'
165 | add_csnt = lambda x: x + torch.normal(mean=torch.zeros(x.size()[0], 1, 1, 1), \
166 | std=(5/255.)).expand_as(x).to(x.device)
167 | add_csnt.__name__ = 'add_csnt'
168 |
169 | # define transformations and their frequency, then pick one.
170 | aug_list = [do_nothing, flipud, rot90, rot90_flipud, \
171 | rot180, rot180_flipud, rot270, rot270_flipud, add_csnt]
172 |         w_aug = [32, 12, 12, 12, 12, 12, 12, 12, 12]  # one-fourth chance of do_nothing
173 | transf = choices(aug_list, w_aug)
174 |
175 | # transform all images in array
176 | return transf[0](sample)
177 |
178 | N, T, C, H, W = img_train.shape
179 | # convert to [N, num_frames*C. H, W] in [0., 1.] from [N, num_frames, C. H, W] in [0., 255.]
180 | img_train = img_train.type(torch.float32).view(N, -1, H, W) / 255.
181 |
182 | # augment
183 | img_train = transform(img_train)
184 |
185 | # view back
186 | img_train = img_train.view(N, T, C, H, W)
187 |
188 | return img_train
189 |
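`normalize_augment` draws one augmentation per batch with `random.choices` and the weight table above; `do_nothing` gets 32/128 = 1/4 of the probability mass. A small standalone sketch of just that weighted selection (names mirror the ones above):

```python
from random import choices, seed

aug_names = ['do_nothing', 'flipud', 'rot90', 'rot90_flipud',
             'rot180', 'rot180_flipud', 'rot270', 'rot270_flipud', 'add_csnt']
w_aug = [32, 12, 12, 12, 12, 12, 12, 12, 12]

seed(0)
picked = choices(aug_names, w_aug)[0]   # weighted draw, one winner per batch
p_do_nothing = w_aug[0] / sum(w_aug)    # 32 / 128 = 0.25
```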
190 | def remove_dataparallel_wrapper(state_dict):
191 |     r"""Converts a DataParallel state dict to a plain one by removing the
192 |     "module." prefix from each key
193 |
194 |
195 | Args:
196 | state_dict: a torch.nn.DataParallel state dictionary
197 | """
198 | from collections import OrderedDict
199 |
200 | new_state_dict = OrderedDict()
201 | for k, v in state_dict.items():
202 | name = k[7:] # remove 'module.' of DataParallel
203 | new_state_dict[name] = v
204 |
205 | return new_state_dict
206 |
--------------------------------------------------------------------------------
/utils/io.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import imageio
3 | import numpy as np
4 | import os
5 | import torch
6 |
7 | def list_dir(dir, postfix=None, full_path=False):
8 | if full_path:
9 | if postfix is None:
10 | names = sorted([name for name in os.listdir(dir) if not name.startswith('.')])
11 | return sorted([os.path.join(dir, name) for name in names])
12 | else:
13 | names = sorted([name for name in os.listdir(dir) if (not name.startswith('.') and name.endswith(postfix))])
14 | return sorted([os.path.join(dir, name) for name in names])
15 | else:
16 | if postfix is None:
17 | return sorted([name for name in os.listdir(dir) if not name.startswith('.')])
18 | else:
19 | return sorted([name for name in os.listdir(dir) if (not name.startswith('.') and name.endswith(postfix))])
20 |
21 | def open_images_uint8(image_files):
22 | image_list = []
23 | for image_file in image_files:
24 | image = imageio.imread(image_file).astype(np.uint8)
25 | if len(image.shape) == 3:
26 | image = np.transpose(image, (2, 0, 1))
27 | image_list.append(image)
28 | seq = np.stack(image_list, axis=0)
29 | return seq
30 |
31 | def log(log_file, msg, also_print=True):
32 | with open(log_file, 'a+') as F:
33 |         F.write(msg)
34 |     if also_print:
35 |         print(msg, end='')
36 |
37 | # return pytorch image in shape 1x3xHxW
38 | def image2tensor(image_file):
39 | image = imageio.imread(image_file).astype(np.float32) / np.float32(255.0)
40 | if len(image.shape) == 3:
41 | image = np.transpose(image, (2, 0, 1))
42 | elif len(image.shape) == 2:
43 | image = np.expand_dims(image, 0)
44 | image = np.asarray(image, dtype=np.float32)
45 | image = torch.from_numpy(image).unsqueeze(0)
46 | return image
47 |
48 | # save numpy image in shape 3xHxW
49 | def np2image(image, image_file):
50 | image = np.transpose(image, (1, 2, 0))
51 | image = np.clip(image, 0., 1.)
52 | image = image * 255.
53 | image = image.astype(np.uint8)
54 | imageio.imwrite(image_file, image)
55 |
56 | def np2image_bgr(image, image_file):
57 | image = np.transpose(image, (1, 2, 0))
58 | image = np.clip(image, 0., 1.)
59 | image = image * 255.
60 | image = image.astype(np.uint8)
61 | cv2.imwrite(image_file, image)
62 |
63 | # save tensor image in shape 1x3xHxW
64 | def tensor2image(image, image_file):
65 | image = image.detach().cpu().squeeze(0).numpy()
66 | np2image(image, image_file)
67 |
68 |
--------------------------------------------------------------------------------
/utils/raw.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | # naively convert a packed raw (RGGB) sequence to a pseudo-RGB sequence for computing optical flow
4 | def demosaic(raw_seq):
5 | N, T, C, H, W = raw_seq.shape
6 | rgb_seq = torch.empty((N, T, 3, H, W), dtype=raw_seq.dtype, device=raw_seq.device)
7 | rgb_seq[:, :, 0] = raw_seq[:, :, 0]
8 | rgb_seq[:, :, 1] = (raw_seq[:, :, 1] + raw_seq[:, :, 2]) / 2
9 | rgb_seq[:, :, 2] = (raw_seq[:, :, 3])
10 | return rgb_seq
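`demosaic` here is only channel selection plus averaging the two greens of a packed RGGB frame, not a real demosaicking algorithm. A NumPy sketch of the same mapping (the constant channel values are illustrative):

```python
import numpy as np

# packed raw with channels [R, G1, G2, B], shape (N, T, 4, H, W)
raw_seq = np.zeros((1, 2, 4, 3, 3), dtype=np.float32)
raw_seq[:, :, 0] = 0.8   # R
raw_seq[:, :, 1] = 0.4   # G1
raw_seq[:, :, 2] = 0.6   # G2
raw_seq[:, :, 3] = 0.2   # B

rgb_seq = np.empty((1, 2, 3, 3, 3), dtype=np.float32)
rgb_seq[:, :, 0] = raw_seq[:, :, 0]                           # red passes through
rgb_seq[:, :, 1] = (raw_seq[:, :, 1] + raw_seq[:, :, 2]) / 2  # average the two greens
rgb_seq[:, :, 2] = raw_seq[:, :, 3]                           # blue passes through
```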
--------------------------------------------------------------------------------
/utils/ssim.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from skimage.metrics import structural_similarity as compare_ssim
3 |
4 | # img: T, C, H, W, imclean: T, C, H, W
5 | def batch_ssim(img, imclean, data_range):
6 |
7 | img = img.data.cpu().numpy().astype(np.float32)
8 | img = np.transpose(img, (0, 2, 3, 1))
9 | img_clean = imclean.data.cpu().numpy().astype(np.float32)
10 | img_clean = np.transpose(img_clean, (0, 2, 3, 1))
11 |
12 | ssim = 0
13 | for i in range(img.shape[0]):
14 | origin_i = img_clean[i, :, :, :]
15 | denoised_i = img[i, :, :, :]
16 |         ssim += compare_ssim(origin_i.astype(float), denoised_i.astype(float), multichannel=True, win_size=11, K1=0.01,
17 |                              K2=0.03, sigma=1.5, gaussian_weights=True, data_range=data_range)
18 | return ssim/img.shape[0]
--------------------------------------------------------------------------------
/utils/warp.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def warp(x, flo):
4 | '''
5 | warp an image/tensor (im2) back to im1, according to the optical flow
6 | x: [B, C, H, W] (im2)
7 | flo: [B, 2, H, W] (flow)
8 | '''
9 | B, C, H, W = x.size()
10 | # mesh grid
11 | xx = torch.arange(0, W, device=x.device).view(1, -1).repeat(H, 1)
12 | yy = torch.arange(0, H, device=x.device).view(-1, 1).repeat(1, W)
13 | xx = xx.view(1, 1, H, W).repeat(B, 1, 1, 1)
14 | yy = yy.view(1, 1, H, W).repeat(B, 1, 1, 1)
15 | grid = torch.cat((xx, yy), 1).float()
16 |
17 |     # grid was already created on x's device, so no transfer is needed;
18 |     # Variable is deprecated, plain tensors work directly
19 |     vgrid = grid + flo
20 |
21 | # scale grid to [-1, 1]
22 | vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :].clone() / max(W - 1, 1) - 1.0
23 | vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :].clone() / max(H - 1, 1) - 1.0
24 |
25 | vgrid = vgrid.permute(0, 2, 3, 1)
26 | output = torch.nn.functional.grid_sample(x, vgrid, align_corners=True)
27 |     mask = torch.ones((B, C, H, W), device=x.device)
28 | mask = torch.nn.functional.grid_sample(mask, vgrid, align_corners=True)
29 |
30 | mask[mask < 0.9999] = 0
31 | mask[mask > 0] = 1
32 |
33 | return output * mask, mask
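The warp above builds an identity mesh grid, adds the flow, and rescales coordinates to `grid_sample`'s [-1, 1] convention (with `align_corners=True`, pixel 0 maps to -1 and pixel W-1 to +1). A NumPy sketch of just that coordinate normalization, for the zero-flow case:

```python
import numpy as np

H, W = 4, 6
xx, yy = np.meshgrid(np.arange(W, dtype=np.float32),
                     np.arange(H, dtype=np.float32))

# zero flow: the sampling grid is the identity, rescaled to [-1, 1]
vx = 2.0 * xx / max(W - 1, 1) - 1.0
vy = 2.0 * yy / max(H - 1, 1) - 1.0
```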
--------------------------------------------------------------------------------