├── .gitignore ├── LICENSE ├── datasets ├── CRVD_seq.py ├── __init__.py └── sRGB_seq.py ├── imgs └── figure_overview.png ├── models ├── __init__.py ├── basicvsr_plusplus.py ├── birnn.py ├── components.py ├── flornn.py ├── flornn_raw.py ├── forwardrnn.py ├── init.py └── rvidenet │ ├── isp.pth │ └── isp.py ├── pytorch_pwc ├── correlation │ └── correlation.py ├── extract_flow.py └── pwc.py ├── readme.md ├── requirements.yaml ├── softmax_splatting └── softsplat.py ├── test_models ├── CRVD_test.py └── sRGB_test.py ├── train_models ├── CRVD_train.py ├── base_functions.py ├── sRGB_train.py └── sRGB_train_distributed.py └── utils ├── fastdvdnet_utils.py ├── io.py ├── raw.py ├── ssim.py └── warp.py /.gitignore: -------------------------------------------------------------------------------- 1 | .xml 2 | .idea 3 | .idea/workspace.xml 4 | .DS_Store 5 | */__pycache__git 6 | .pyc 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | Preamble 9 | 10 | The GNU General Public License is a free, copyleft license for 11 | software and other kinds of works. 12 | 13 | The licenses for most software and other practical works are designed 14 | to take away your freedom to share and change the works. By contrast, 15 | the GNU General Public License is intended to guarantee your freedom to 16 | share and change all versions of a program--to make sure it remains free 17 | software for all its users. We, the Free Software Foundation, use the 18 | GNU General Public License for most of our software; it applies also to 19 | any other work released this way by its authors. You can apply it to 20 | your programs, too. 
21 | 22 | When we speak of free software, we are referring to freedom, not 23 | price. Our General Public Licenses are designed to make sure that you 24 | have the freedom to distribute copies of free software (and charge for 25 | them if you wish), that you receive source code or can get it if you 26 | want it, that you can change the software or use pieces of it in new 27 | free programs, and that you know you can do these things. 28 | 29 | To protect your rights, we need to prevent others from denying you 30 | these rights or asking you to surrender the rights. Therefore, you have 31 | certain responsibilities if you distribute copies of the software, or if 32 | you modify it: responsibilities to respect the freedom of others. 33 | 34 | For example, if you distribute copies of such a program, whether 35 | gratis or for a fee, you must pass on to the recipients the same 36 | freedoms that you received. You must make sure that they, too, receive 37 | or can get the source code. And you must show them these terms so they 38 | know their rights. 39 | 40 | Developers that use the GNU GPL protect your rights with two steps: 41 | (1) assert copyright on the software, and (2) offer you this License 42 | giving you legal permission to copy, distribute and/or modify it. 43 | 44 | For the developers' and authors' protection, the GPL clearly explains 45 | that there is no warranty for this free software. For both users' and 46 | authors' sake, the GPL requires that modified versions be marked as 47 | changed, so that their problems will not be attributed erroneously to 48 | authors of previous versions. 49 | 50 | Some devices are designed to deny users access to install or run 51 | modified versions of the software inside them, although the manufacturer 52 | can do so. This is fundamentally incompatible with the aim of 53 | protecting users' freedom to change the software. 
The systematic 54 | pattern of such abuse occurs in the area of products for individuals to 55 | use, which is precisely where it is most unacceptable. Therefore, we 56 | have designed this version of the GPL to prohibit the practice for those 57 | products. If such problems arise substantially in other domains, we 58 | stand ready to extend this provision to those domains in future versions 59 | of the GPL, as needed to protect the freedom of users. 60 | 61 | Finally, every program is threatened constantly by software patents. 62 | States should not allow patents to restrict development and use of 63 | software on general-purpose computers, but in those that do, we wish to 64 | avoid the special danger that patents applied to a free program could 65 | make it effectively proprietary. To prevent this, the GPL assures that 66 | patents cannot be used to render the program non-free. 67 | 68 | The precise terms and conditions for copying, distribution and 69 | modification follow. 70 | 71 | TERMS AND CONDITIONS 72 | 73 | 0. Definitions. 74 | 75 | "This License" refers to version 3 of the GNU General Public License. 76 | 77 | "Copyright" also means copyright-like laws that apply to other kinds of 78 | works, such as semiconductor masks. 79 | 80 | "The Program" refers to any copyrightable work licensed under this 81 | License. Each licensee is addressed as "you". "Licensees" and 82 | "recipients" may be individuals or organizations. 83 | 84 | To "modify" a work means to copy from or adapt all or part of the work 85 | in a fashion requiring copyright permission, other than the making of an 86 | exact copy. The resulting work is called a "modified version" of the 87 | earlier work or a work "based on" the earlier work. 88 | 89 | A "covered work" means either the unmodified Program or a work based 90 | on the Program. 
91 | 92 | To "propagate" a work means to do anything with it that, without 93 | permission, would make you directly or secondarily liable for 94 | infringement under applicable copyright law, except executing it on a 95 | computer or modifying a private copy. Propagation includes copying, 96 | distribution (with or without modification), making available to the 97 | public, and in some countries other activities as well. 98 | 99 | To "convey" a work means any kind of propagation that enables other 100 | parties to make or receive copies. Mere interaction with a user through 101 | a computer network, with no transfer of a copy, is not conveying. 102 | 103 | An interactive user interface displays "Appropriate Legal Notices" 104 | to the extent that it includes a convenient and prominently visible 105 | feature that (1) displays an appropriate copyright notice, and (2) 106 | tells the user that there is no warranty for the work (except to the 107 | extent that warranties are provided), that licensees may convey the 108 | work under this License, and how to view a copy of this License. If 109 | the interface presents a list of user commands or options, such as a 110 | menu, a prominent item in the list meets this criterion. 111 | 112 | 1. Source Code. 113 | 114 | The "source code" for a work means the preferred form of the work 115 | for making modifications to it. "Object code" means any non-source 116 | form of a work. 117 | 118 | A "Standard Interface" means an interface that either is an official 119 | standard defined by a recognized standards body, or, in the case of 120 | interfaces specified for a particular programming language, one that 121 | is widely used among developers working in that language. 
122 | 123 | The "System Libraries" of an executable work include anything, other 124 | than the work as a whole, that (a) is included in the normal form of 125 | packaging a Major Component, but which is not part of that Major 126 | Component, and (b) serves only to enable use of the work with that 127 | Major Component, or to implement a Standard Interface for which an 128 | implementation is available to the public in source code form. A 129 | "Major Component", in this context, means a major essential component 130 | (kernel, window system, and so on) of the specific operating system 131 | (if any) on which the executable work runs, or a compiler used to 132 | produce the work, or an object code interpreter used to run it. 133 | 134 | The "Corresponding Source" for a work in object code form means all 135 | the source code needed to generate, install, and (for an executable 136 | work) run the object code and to modify the work, including scripts to 137 | control those activities. However, it does not include the work's 138 | System Libraries, or general-purpose tools or generally available free 139 | programs which are used unmodified in performing those activities but 140 | which are not part of the work. For example, Corresponding Source 141 | includes interface definition files associated with source files for 142 | the work, and the source code for shared libraries and dynamically 143 | linked subprograms that the work is specifically designed to require, 144 | such as by intimate data communication or control flow between those 145 | subprograms and other parts of the work. 146 | 147 | The Corresponding Source need not include anything that users 148 | can regenerate automatically from other parts of the Corresponding 149 | Source. 150 | 151 | The Corresponding Source for a work in source code form is that 152 | same work. 153 | 154 | 2. Basic Permissions. 
155 | 156 | All rights granted under this License are granted for the term of 157 | copyright on the Program, and are irrevocable provided the stated 158 | conditions are met. This License explicitly affirms your unlimited 159 | permission to run the unmodified Program. The output from running a 160 | covered work is covered by this License only if the output, given its 161 | content, constitutes a covered work. This License acknowledges your 162 | rights of fair use or other equivalent, as provided by copyright law. 163 | 164 | You may make, run and propagate covered works that you do not 165 | convey, without conditions so long as your license otherwise remains 166 | in force. You may convey covered works to others for the sole purpose 167 | of having them make modifications exclusively for you, or provide you 168 | with facilities for running those works, provided that you comply with 169 | the terms of this License in conveying all material for which you do 170 | not control copyright. Those thus making or running the covered works 171 | for you must do so exclusively on your behalf, under your direction 172 | and control, on terms that prohibit them from making any copies of 173 | your copyrighted material outside their relationship with you. 174 | 175 | Conveying under any other circumstances is permitted solely under 176 | the conditions stated below. Sublicensing is not allowed; section 10 177 | makes it unnecessary. 178 | 179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 180 | 181 | No covered work shall be deemed part of an effective technological 182 | measure under any applicable law fulfilling obligations under article 183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 184 | similar laws prohibiting or restricting circumvention of such 185 | measures. 
186 | 187 | When you convey a covered work, you waive any legal power to forbid 188 | circumvention of technological measures to the extent such circumvention 189 | is effected by exercising rights under this License with respect to 190 | the covered work, and you disclaim any intention to limit operation or 191 | modification of the work as a means of enforcing, against the work's 192 | users, your or third parties' legal rights to forbid circumvention of 193 | technological measures. 194 | 195 | 4. Conveying Verbatim Copies. 196 | 197 | You may convey verbatim copies of the Program's source code as you 198 | receive it, in any medium, provided that you conspicuously and 199 | appropriately publish on each copy an appropriate copyright notice; 200 | keep intact all notices stating that this License and any 201 | non-permissive terms added in accord with section 7 apply to the code; 202 | keep intact all notices of the absence of any warranty; and give all 203 | recipients a copy of this License along with the Program. 204 | 205 | You may charge any price or no price for each copy that you convey, 206 | and you may offer support or warranty protection for a fee. 207 | 208 | 5. Conveying Modified Source Versions. 209 | 210 | You may convey a work based on the Program, or the modifications to 211 | produce it from the Program, in the form of source code under the 212 | terms of section 4, provided that you also meet all of these conditions: 213 | 214 | a) The work must carry prominent notices stating that you modified 215 | it, and giving a relevant date. 216 | 217 | b) The work must carry prominent notices stating that it is 218 | released under this License and any conditions added under section 219 | 7. This requirement modifies the requirement in section 4 to 220 | "keep intact all notices". 221 | 222 | c) You must license the entire work, as a whole, under this 223 | License to anyone who comes into possession of a copy. 
This 224 | License will therefore apply, along with any applicable section 7 225 | additional terms, to the whole of the work, and all its parts, 226 | regardless of how they are packaged. This License gives no 227 | permission to license the work in any other way, but it does not 228 | invalidate such permission if you have separately received it. 229 | 230 | d) If the work has interactive user interfaces, each must display 231 | Appropriate Legal Notices; however, if the Program has interactive 232 | interfaces that do not display Appropriate Legal Notices, your 233 | work need not make them do so. 234 | 235 | A compilation of a covered work with other separate and independent 236 | works, which are not by their nature extensions of the covered work, 237 | and which are not combined with it such as to form a larger program, 238 | in or on a volume of a storage or distribution medium, is called an 239 | "aggregate" if the compilation and its resulting copyright are not 240 | used to limit the access or legal rights of the compilation's users 241 | beyond what the individual works permit. Inclusion of a covered work 242 | in an aggregate does not cause this License to apply to the other 243 | parts of the aggregate. 244 | 245 | 6. Conveying Non-Source Forms. 246 | 247 | You may convey a covered work in object code form under the terms 248 | of sections 4 and 5, provided that you also convey the 249 | machine-readable Corresponding Source under the terms of this License, 250 | in one of these ways: 251 | 252 | a) Convey the object code in, or embodied in, a physical product 253 | (including a physical distribution medium), accompanied by the 254 | Corresponding Source fixed on a durable physical medium 255 | customarily used for software interchange. 
256 | 257 | b) Convey the object code in, or embodied in, a physical product 258 | (including a physical distribution medium), accompanied by a 259 | written offer, valid for at least three years and valid for as 260 | long as you offer spare parts or customer support for that product 261 | model, to give anyone who possesses the object code either (1) a 262 | copy of the Corresponding Source for all the software in the 263 | product that is covered by this License, on a durable physical 264 | medium customarily used for software interchange, for a price no 265 | more than your reasonable cost of physically performing this 266 | conveying of source, or (2) access to copy the 267 | Corresponding Source from a network server at no charge. 268 | 269 | c) Convey individual copies of the object code with a copy of the 270 | written offer to provide the Corresponding Source. This 271 | alternative is allowed only occasionally and noncommercially, and 272 | only if you received the object code with such an offer, in accord 273 | with subsection 6b. 274 | 275 | d) Convey the object code by offering access from a designated 276 | place (gratis or for a charge), and offer equivalent access to the 277 | Corresponding Source in the same way through the same place at no 278 | further charge. You need not require recipients to copy the 279 | Corresponding Source along with the object code. If the place to 280 | copy the object code is a network server, the Corresponding Source 281 | may be on a different server (operated by you or a third party) 282 | that supports equivalent copying facilities, provided you maintain 283 | clear directions next to the object code saying where to find the 284 | Corresponding Source. Regardless of what server hosts the 285 | Corresponding Source, you remain obligated to ensure that it is 286 | available for as long as needed to satisfy these requirements. 
287 | 288 | e) Convey the object code using peer-to-peer transmission, provided 289 | you inform other peers where the object code and Corresponding 290 | Source of the work are being offered to the general public at no 291 | charge under subsection 6d. 292 | 293 | A separable portion of the object code, whose source code is excluded 294 | from the Corresponding Source as a System Library, need not be 295 | included in conveying the object code work. 296 | 297 | A "User Product" is either (1) a "consumer product", which means any 298 | tangible personal property which is normally used for personal, family, 299 | or household purposes, or (2) anything designed or sold for incorporation 300 | into a dwelling. In determining whether a product is a consumer product, 301 | doubtful cases shall be resolved in favor of coverage. For a particular 302 | product received by a particular user, "normally used" refers to a 303 | typical or common use of that class of product, regardless of the status 304 | of the particular user or of the way in which the particular user 305 | actually uses, or expects or is expected to use, the product. A product 306 | is a consumer product regardless of whether the product has substantial 307 | commercial, industrial or non-consumer uses, unless such uses represent 308 | the only significant mode of use of the product. 309 | 310 | "Installation Information" for a User Product means any methods, 311 | procedures, authorization keys, or other information required to install 312 | and execute modified versions of a covered work in that User Product from 313 | a modified version of its Corresponding Source. The information must 314 | suffice to ensure that the continued functioning of the modified object 315 | code is in no case prevented or interfered with solely because 316 | modification has been made. 
317 | 318 | If you convey an object code work under this section in, or with, or 319 | specifically for use in, a User Product, and the conveying occurs as 320 | part of a transaction in which the right of possession and use of the 321 | User Product is transferred to the recipient in perpetuity or for a 322 | fixed term (regardless of how the transaction is characterized), the 323 | Corresponding Source conveyed under this section must be accompanied 324 | by the Installation Information. But this requirement does not apply 325 | if neither you nor any third party retains the ability to install 326 | modified object code on the User Product (for example, the work has 327 | been installed in ROM). 328 | 329 | The requirement to provide Installation Information does not include a 330 | requirement to continue to provide support service, warranty, or updates 331 | for a work that has been modified or installed by the recipient, or for 332 | the User Product in which it has been modified or installed. Access to a 333 | network may be denied when the modification itself materially and 334 | adversely affects the operation of the network or violates the rules and 335 | protocols for communication across the network. 336 | 337 | Corresponding Source conveyed, and Installation Information provided, 338 | in accord with this section must be in a format that is publicly 339 | documented (and with an implementation available to the public in 340 | source code form), and must require no special password or key for 341 | unpacking, reading or copying. 342 | 343 | 7. Additional Terms. 344 | 345 | "Additional permissions" are terms that supplement the terms of this 346 | License by making exceptions from one or more of its conditions. 347 | Additional permissions that are applicable to the entire Program shall 348 | be treated as though they were included in this License, to the extent 349 | that they are valid under applicable law. 
If additional permissions 350 | apply only to part of the Program, that part may be used separately 351 | under those permissions, but the entire Program remains governed by 352 | this License without regard to the additional permissions. 353 | 354 | When you convey a copy of a covered work, you may at your option 355 | remove any additional permissions from that copy, or from any part of 356 | it. (Additional permissions may be written to require their own 357 | removal in certain cases when you modify the work.) You may place 358 | additional permissions on material, added by you to a covered work, 359 | for which you have or can give appropriate copyright permission. 360 | 361 | Notwithstanding any other provision of this License, for material you 362 | add to a covered work, you may (if authorized by the copyright holders of 363 | that material) supplement the terms of this License with terms: 364 | 365 | a) Disclaiming warranty or limiting liability differently from the 366 | terms of sections 15 and 16 of this License; or 367 | 368 | b) Requiring preservation of specified reasonable legal notices or 369 | author attributions in that material or in the Appropriate Legal 370 | Notices displayed by works containing it; or 371 | 372 | c) Prohibiting misrepresentation of the origin of that material, or 373 | requiring that modified versions of such material be marked in 374 | reasonable ways as different from the original version; or 375 | 376 | d) Limiting the use for publicity purposes of names of licensors or 377 | authors of the material; or 378 | 379 | e) Declining to grant rights under trademark law for use of some 380 | trade names, trademarks, or service marks; or 381 | 382 | f) Requiring indemnification of licensors and authors of that 383 | material by anyone who conveys the material (or modified versions of 384 | it) with contractual assumptions of liability to the recipient, for 385 | any liability that these contractual assumptions directly impose on 
386 | those licensors and authors. 387 | 388 | All other non-permissive additional terms are considered "further 389 | restrictions" within the meaning of section 10. If the Program as you 390 | received it, or any part of it, contains a notice stating that it is 391 | governed by this License along with a term that is a further 392 | restriction, you may remove that term. If a license document contains 393 | a further restriction but permits relicensing or conveying under this 394 | License, you may add to a covered work material governed by the terms 395 | of that license document, provided that the further restriction does 396 | not survive such relicensing or conveying. 397 | 398 | If you add terms to a covered work in accord with this section, you 399 | must place, in the relevant source files, a statement of the 400 | additional terms that apply to those files, or a notice indicating 401 | where to find the applicable terms. 402 | 403 | Additional terms, permissive or non-permissive, may be stated in the 404 | form of a separately written license, or stated as exceptions; 405 | the above requirements apply either way. 406 | 407 | 8. Termination. 408 | 409 | You may not propagate or modify a covered work except as expressly 410 | provided under this License. Any attempt otherwise to propagate or 411 | modify it is void, and will automatically terminate your rights under 412 | this License (including any patent licenses granted under the third 413 | paragraph of section 11). 414 | 415 | However, if you cease all violation of this License, then your 416 | license from a particular copyright holder is reinstated (a) 417 | provisionally, unless and until the copyright holder explicitly and 418 | finally terminates your license, and (b) permanently, if the copyright 419 | holder fails to notify you of the violation by some reasonable means 420 | prior to 60 days after the cessation. 
421 | 422 | Moreover, your license from a particular copyright holder is 423 | reinstated permanently if the copyright holder notifies you of the 424 | violation by some reasonable means, this is the first time you have 425 | received notice of violation of this License (for any work) from that 426 | copyright holder, and you cure the violation prior to 30 days after 427 | your receipt of the notice. 428 | 429 | Termination of your rights under this section does not terminate the 430 | licenses of parties who have received copies or rights from you under 431 | this License. If your rights have been terminated and not permanently 432 | reinstated, you do not qualify to receive new licenses for the same 433 | material under section 10. 434 | 435 | 9. Acceptance Not Required for Having Copies. 436 | 437 | You are not required to accept this License in order to receive or 438 | run a copy of the Program. Ancillary propagation of a covered work 439 | occurring solely as a consequence of using peer-to-peer transmission 440 | to receive a copy likewise does not require acceptance. However, 441 | nothing other than this License grants you permission to propagate or 442 | modify any covered work. These actions infringe copyright if you do 443 | not accept this License. Therefore, by modifying or propagating a 444 | covered work, you indicate your acceptance of this License to do so. 445 | 446 | 10. Automatic Licensing of Downstream Recipients. 447 | 448 | Each time you convey a covered work, the recipient automatically 449 | receives a license from the original licensors, to run, modify and 450 | propagate that work, subject to this License. You are not responsible 451 | for enforcing compliance by third parties with this License. 452 | 453 | An "entity transaction" is a transaction transferring control of an 454 | organization, or substantially all assets of one, or subdividing an 455 | organization, or merging organizations. 
If propagation of a covered 456 | work results from an entity transaction, each party to that 457 | transaction who receives a copy of the work also receives whatever 458 | licenses to the work the party's predecessor in interest had or could 459 | give under the previous paragraph, plus a right to possession of the 460 | Corresponding Source of the work from the predecessor in interest, if 461 | the predecessor has it or can get it with reasonable efforts. 462 | 463 | You may not impose any further restrictions on the exercise of the 464 | rights granted or affirmed under this License. For example, you may 465 | not impose a license fee, royalty, or other charge for exercise of 466 | rights granted under this License, and you may not initiate litigation 467 | (including a cross-claim or counterclaim in a lawsuit) alleging that 468 | any patent claim is infringed by making, using, selling, offering for 469 | sale, or importing the Program or any portion of it. 470 | 471 | 11. Patents. 472 | 473 | A "contributor" is a copyright holder who authorizes use under this 474 | License of the Program or a work on which the Program is based. The 475 | work thus licensed is called the contributor's "contributor version". 476 | 477 | A contributor's "essential patent claims" are all patent claims 478 | owned or controlled by the contributor, whether already acquired or 479 | hereafter acquired, that would be infringed by some manner, permitted 480 | by this License, of making, using, or selling its contributor version, 481 | but do not include claims that would be infringed only as a 482 | consequence of further modification of the contributor version. For 483 | purposes of this definition, "control" includes the right to grant 484 | patent sublicenses in a manner consistent with the requirements of 485 | this License. 
486 | 487 | Each contributor grants you a non-exclusive, worldwide, royalty-free 488 | patent license under the contributor's essential patent claims, to 489 | make, use, sell, offer for sale, import and otherwise run, modify and 490 | propagate the contents of its contributor version. 491 | 492 | In the following three paragraphs, a "patent license" is any express 493 | agreement or commitment, however denominated, not to enforce a patent 494 | (such as an express permission to practice a patent or covenant not to 495 | sue for patent infringement). To "grant" such a patent license to a 496 | party means to make such an agreement or commitment not to enforce a 497 | patent against the party. 498 | 499 | If you convey a covered work, knowingly relying on a patent license, 500 | and the Corresponding Source of the work is not available for anyone 501 | to copy, free of charge and under the terms of this License, through a 502 | publicly available network server or other readily accessible means, 503 | then you must either (1) cause the Corresponding Source to be so 504 | available, or (2) arrange to deprive yourself of the benefit of the 505 | patent license for this particular work, or (3) arrange, in a manner 506 | consistent with the requirements of this License, to extend the patent 507 | license to downstream recipients. "Knowingly relying" means you have 508 | actual knowledge that, but for the patent license, your conveying the 509 | covered work in a country, or your recipient's use of the covered work 510 | in a country, would infringe one or more identifiable patents in that 511 | country that you have reason to believe are valid. 
512 | 513 | If, pursuant to or in connection with a single transaction or 514 | arrangement, you convey, or propagate by procuring conveyance of, a 515 | covered work, and grant a patent license to some of the parties 516 | receiving the covered work authorizing them to use, propagate, modify 517 | or convey a specific copy of the covered work, then the patent license 518 | you grant is automatically extended to all recipients of the covered 519 | work and works based on it. 520 | 521 | A patent license is "discriminatory" if it does not include within 522 | the scope of its coverage, prohibits the exercise of, or is 523 | conditioned on the non-exercise of one or more of the rights that are 524 | specifically granted under this License. You may not convey a covered 525 | work if you are a party to an arrangement with a third party that is 526 | in the business of distributing software, under which you make payment 527 | to the third party based on the extent of your activity of conveying 528 | the work, and under which the third party grants, to any of the 529 | parties who would receive the covered work from you, a discriminatory 530 | patent license (a) in connection with copies of the covered work 531 | conveyed by you (or copies made from those copies), or (b) primarily 532 | for and in connection with specific products or compilations that 533 | contain the covered work, unless you entered into that arrangement, 534 | or that patent license was granted, prior to 28 March 2007. 535 | 536 | Nothing in this License shall be construed as excluding or limiting 537 | any implied license or other defenses to infringement that may 538 | otherwise be available to you under applicable patent law. 539 | 540 | 12. No Surrender of Others' Freedom. 541 | 542 | If conditions are imposed on you (whether by court order, agreement or 543 | otherwise) that contradict the conditions of this License, they do not 544 | excuse you from the conditions of this License. 
If you cannot convey a 545 | covered work so as to satisfy simultaneously your obligations under this 546 | License and any other pertinent obligations, then as a consequence you may 547 | not convey it at all. For example, if you agree to terms that obligate you 548 | to collect a royalty for further conveying from those to whom you convey 549 | the Program, the only way you could satisfy both those terms and this 550 | License would be to refrain entirely from conveying the Program. 551 | 552 | 13. Use with the GNU Affero General Public License. 553 | 554 | Notwithstanding any other provision of this License, you have 555 | permission to link or combine any covered work with a work licensed 556 | under version 3 of the GNU Affero General Public License into a single 557 | combined work, and to convey the resulting work. The terms of this 558 | License will continue to apply to the part which is the covered work, 559 | but the special requirements of the GNU Affero General Public License, 560 | section 13, concerning interaction through a network will apply to the 561 | combination as such. 562 | 563 | 14. Revised Versions of this License. 564 | 565 | The Free Software Foundation may publish revised and/or new versions of 566 | the GNU General Public License from time to time. Such new versions will 567 | be similar in spirit to the present version, but may differ in detail to 568 | address new problems or concerns. 569 | 570 | Each version is given a distinguishing version number. If the 571 | Program specifies that a certain numbered version of the GNU General 572 | Public License "or any later version" applies to it, you have the 573 | option of following the terms and conditions either of that numbered 574 | version or of any later version published by the Free Software 575 | Foundation. If the Program does not specify a version number of the 576 | GNU General Public License, you may choose any version ever published 577 | by the Free Software Foundation. 
578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 583 | 584 | Later license versions may give you additional or different 585 | permissions. However, no additional obligations are imposed on any 586 | author or copyright holder as a result of your choosing to follow a 587 | later version. 588 | 589 | 15. Disclaimer of Warranty. 590 | 591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 599 | 600 | 16. Limitation of Liability. 601 | 602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 610 | SUCH DAMAGES. 611 | 612 | 17. Interpretation of Sections 15 and 16. 
613 | 614 | If the disclaimer of warranty and limitation of liability provided 615 | above cannot be given local legal effect according to their terms, 616 | reviewing courts shall apply local law that most closely approximates 617 | an absolute waiver of all civil liability in connection with the 618 | Program, unless a warranty or assumption of liability accompanies a 619 | copy of the Program in return for a fee. 620 | 621 | END OF TERMS AND CONDITIONS 622 | 623 | How to Apply These Terms to Your New Programs 624 | 625 | If you develop a new program, and you want it to be of the greatest 626 | possible use to the public, the best way to achieve this is to make it 627 | free software which everyone can redistribute and change under these terms. 628 | 629 | To do so, attach the following notices to the program. It is safest 630 | to attach them to the start of each source file to most effectively 631 | state the exclusion of warranty; and each file should have at least 632 | the "copyright" line and a pointer to where the full notice is found. 633 | 634 | 635 | Copyright (C) 636 | 637 | This program is free software: you can redistribute it and/or modify 638 | it under the terms of the GNU General Public License as published by 639 | the Free Software Foundation, either version 3 of the License, or 640 | (at your option) any later version. 641 | 642 | This program is distributed in the hope that it will be useful, 643 | but WITHOUT ANY WARRANTY; without even the implied warranty of 644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 645 | GNU General Public License for more details. 646 | 647 | You should have received a copy of the GNU General Public License 648 | along with this program. If not, see . 649 | 650 | Also add information on how to contact you by electronic and paper mail. 
651 | 652 | If the program does terminal interaction, make it output a short 653 | notice like this when it starts in an interactive mode: 654 | 655 | Copyright (C) 656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 657 | This is free software, and you are welcome to redistribute it 658 | under certain conditions; type `show c' for details. 659 | 660 | The hypothetical commands `show w' and `show c' should show the appropriate 661 | parts of the General Public License. Of course, your program's commands 662 | might be different; for a GUI interface, you would use an "about box". 663 | 664 | You should also get your employer (if you work as a programmer) or school, 665 | if any, to sign a "copyright disclaimer" for the program, if necessary. 666 | For more information on this, and how to apply and follow the GNU GPL, see 667 | . 668 | 669 | The GNU General Public License does not permit incorporating your program 670 | into proprietary programs. If your program is a subroutine library, you 671 | may consider it more useful to permit linking proprietary applications with 672 | the library. If this is what you want to do, use the GNU Lesser General 673 | Public License instead of this License. But first, please read 674 | . 
675 | -------------------------------------------------------------------------------- /datasets/CRVD_seq.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import os 4 | import torch 5 | from torch.utils.data.dataset import Dataset 6 | 7 | iso_list = [1600, 3200, 6400, 12800, 25600] 8 | a_list = [3.513262, 6.955588, 13.486051, 26.585953, 52.032536] 9 | g_noise_var_list = [11.917691, 38.117816, 130.818508, 484.539790, 1819.818657] 10 | 11 | def pack_gbrg_raw_torch(raw): # T H W 12 | T, H, W = raw.shape 13 | im = raw.unsqueeze(1) 14 | 15 | out = torch.cat((im[:, :, 1:H:2, 0:W:2], 16 | im[:, :, 1:H:2, 1:W:2], 17 | im[:, :, 0:H:2, 0:W:2], 18 | im[:, :, 0:H:2, 1:W:2]), dim=1) 19 | return out 20 | 21 | def normalize_raw_torch(raw): 22 | black_level = 240 23 | white_level = 2 ** 12 - 1 24 | raw = torch.clamp(raw.type(torch.float32) - black_level, 0) / (white_level - black_level) 25 | return raw 26 | 27 | def open_CRVD_seq_raw(seq_path, file_pattern='frame%d_noisy0.tiff'): 28 | frame_list = [] 29 | for i in range(7): 30 | raw = cv2.imread(os.path.join(seq_path, file_pattern % (i+1)), -1) 31 | raw = np.asarray(raw) 32 | raw = np.expand_dims(raw, axis=0) 33 | frame_list.append(raw) 34 | seq = np.concatenate(frame_list, axis=0) 35 | return seq 36 | 37 | def open_CRVD_seq_raw_outdoor(seq_path, file_pattern='frame%d_noisy0.tiff'): 38 | frame_list = [] 39 | for i in range(50): 40 | raw = cv2.imread(os.path.join(seq_path, file_pattern % i), -1) 41 | raw = np.asarray(raw) 42 | raw = np.expand_dims(raw, axis=0) 43 | frame_list.append(raw) 44 | seq = np.concatenate(frame_list, axis=0) 45 | return seq 46 | 47 | def crop_position(patch_size, H, W): 48 | position_h = np.random.randint(0, (H - patch_size)//2 - 1) * 2 49 | position_w = np.random.randint(0, (W - patch_size)//2 - 1) * 2 50 | aug = np.random.randint(0, 8) 51 | return position_h, position_w, aug 52 | 53 | def aug_crop(img, patch_size, position_h, 
position_w, aug): 54 | patch = img[:, position_h:position_h + patch_size + 2, position_w:position_w + patch_size + 2] 55 | 56 | if aug == 0: 57 | patch = patch[:, :-2, :-2] 58 | elif aug == 1: 59 | patch = np.flip(patch, axis=1) 60 | patch = patch[:, 1:-1, :-2] 61 | elif aug == 2: 62 | patch = np.flip(np.flip(patch, axis=1), axis=2) 63 | patch = patch[:, 1:-1, 1:-1] 64 | elif aug == 3: 65 | patch = np.flip(patch, axis=2) 66 | patch = patch[:, :-2, 1:-1] 67 | elif aug == 4: 68 | patch = np.transpose(np.flip(patch, axis=2), (0, 2, 1)) 69 | patch = patch[:, :-2, 1:-1] 70 | elif aug == 5: 71 | patch = np.transpose(np.flip(np.flip(patch, axis=1), axis=2), (0, 2, 1)) 72 | patch = patch[:, :-2, :-2] 73 | elif aug == 6: 74 | patch = np.transpose(patch, (0, 2, 1)) 75 | patch = patch[:, 1:-1, 1:-1] 76 | elif aug == 7: 77 | patch = np.transpose(np.flip(patch, axis=1), (0, 2, 1)) 78 | patch = patch[:, 1:-1, :-2] 79 | return patch 80 | 81 | 82 | class CRVDTrainDataset(Dataset): 83 | def __init__(self, CRVD_path, patch_size, patches_per_epoch, mirror_seq=True): 84 | self.CRVD_path = CRVD_path 85 | self.patches_per_epoch = patches_per_epoch 86 | self.patch_size = patch_size * 2 87 | self.mirror_seq = mirror_seq 88 | self.scene_id_list = [1, 2, 3, 4, 5, 6] 89 | self.seqs = {} 90 | 91 | for iso in iso_list: 92 | for scene_id in self.scene_id_list: 93 | self.seqs['%d_%d_clean' % (iso, scene_id)] = open_CRVD_seq_raw(os.path.join(self.CRVD_path, 'indoor_raw_gt/scene%d/ISO%d' % (scene_id, iso)), 94 | 'frame%d_clean_and_slightly_denoised.tiff') 95 | for i in range(10): 96 | self.seqs['%d_%d_noisy_%d' % (iso, scene_id, i)] = open_CRVD_seq_raw(os.path.join(self.CRVD_path, 'indoor_raw_noisy/scene%d/ISO%d' % (scene_id, iso)), 97 | 'frame%d_noisy{}.tiff'.format(i)) 98 | 99 | def __getitem__(self, index): 100 | index = index % (len(iso_list) * len(self.scene_id_list) * 10) 101 | iso_index = index // (len(self.scene_id_list) * 10) 102 | scene_index = (index - iso_index * 
len(self.scene_id_list) * 10) // 10 103 | noisy_index = index % 10 104 | iso = iso_list[iso_index] 105 | scene_id = self.scene_id_list[scene_index] 106 | 107 | seq = self.seqs['%d_%d_clean' % (iso, scene_id)] 108 | seqn = self.seqs['%d_%d_noisy_%d' % (iso, scene_id, noisy_index)] 109 | T, H, W = seq.shape 110 | position_h, position_w, aug = crop_position(self.patch_size, H, W) 111 | seq = aug_crop(seq, self.patch_size, position_h, position_w, aug) 112 | seqn = aug_crop(seqn, self.patch_size, position_h, position_w, aug) 113 | clean_list, noisy_list = [], [] 114 | for i in range(T): 115 | clean_list.append(np.expand_dims(seq[i], axis=0)) 116 | noisy_list.append(np.expand_dims(seqn[i], axis=0)) 117 | seq = torch.from_numpy(np.concatenate(clean_list, axis=0).astype(np.int32)) 118 | seqn = torch.from_numpy(np.concatenate(noisy_list, axis=0).astype(np.int32)) 119 | seq = normalize_raw_torch(pack_gbrg_raw_torch(seq)) 120 | seqn = normalize_raw_torch(pack_gbrg_raw_torch(seqn)) 121 | 122 | if self.mirror_seq: 123 | seq = torch.cat((seq, torch.flip(seq, dims=[0])), dim=0) 124 | seqn = torch.cat((seqn, torch.flip(seqn, dims=[0])), dim=0) 125 | 126 | a = torch.tensor(a_list[iso_index], dtype=torch.float32).view((1, 1, 1, 1)) / (2 ** 12 - 1 - 240) 127 | b = torch.tensor(g_noise_var_list[iso_index], dtype=torch.float32).view((1, 1, 1, 1)) / ((2 ** 12 - 1 - 240) ** 2) 128 | 129 | return {'seq': seq, 130 | 'seqn': seqn, 131 | 'a': a, 'b': b} 132 | 133 | def __len__(self): 134 | return self.patches_per_epoch 135 | 136 | class CRVDTestDataset(Dataset): 137 | def __init__(self, CRVD_path): 138 | self.CRVD_path = CRVD_path 139 | self.scene_id_list = [7, 8, 9, 10, 11] 140 | self.seqs = {} 141 | 142 | for iso in iso_list: 143 | for scene_id in self.scene_id_list: 144 | self.seqs['%d_%d_clean' % (iso, scene_id)] = open_CRVD_seq_raw(os.path.join(self.CRVD_path, 'indoor_raw_gt/scene%d/ISO%d' % (scene_id, iso)), 145 | 'frame%d_clean_and_slightly_denoised.tiff') 146 | 
self.seqs['%d_%d_noisy' % (iso, scene_id)] = open_CRVD_seq_raw(os.path.join(self.CRVD_path, 'indoor_raw_noisy/scene%d/ISO%d' % (scene_id, iso)), 147 | 'frame%d_noisy0.tiff') 148 | 149 | def __getitem__(self, index): 150 | iso = iso_list[index // len(self.scene_id_list)] 151 | scene_id = self.scene_id_list[index % len(self.scene_id_list)] 152 | 153 | seq = torch.from_numpy(self.seqs['%d_%d_clean' % (iso, scene_id)].astype(np.float32)) 154 | seqn = torch.from_numpy(self.seqs['%d_%d_noisy' % (iso, scene_id)].astype(np.float32)) 155 | seq = normalize_raw_torch(pack_gbrg_raw_torch(seq)) 156 | seqn = normalize_raw_torch(pack_gbrg_raw_torch(seqn)) 157 | a = torch.tensor(a_list[index // len(self.scene_id_list)], dtype=torch.float32).view((1, 1, 1, 1)) / (2 ** 12 - 1 - 240) 158 | b = torch.tensor(g_noise_var_list[index // len(self.scene_id_list)], dtype=torch.float32).view((1, 1, 1, 1)) / ((2 ** 12 - 1 - 240) ** 2) 159 | 160 | return {'seq': seq, 161 | 'seqn': seqn, 162 | 'iso': iso, 'a': a, 'b': b, 'scene_id': scene_id} 163 | 164 | def __len__(self): 165 | return len(iso_list) * len(self.scene_id_list) 166 | 167 | class CRVDOurdoorDataset(Dataset): 168 | def __init__(self, CRVD_path): 169 | self.CRVD_path = CRVD_path 170 | self.scene_id_list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] 171 | self.seqs = {} 172 | 173 | self.iso = 25600 174 | for scene_id in self.scene_id_list: 175 | self.seqs['%d_%d_noisy' % (self.iso, scene_id)] = open_CRVD_seq_raw_outdoor(os.path.join(self.CRVD_path, 'outdoor_raw_noisy/scene%d/iso%d' % (scene_id, self.iso)), 176 | 'frame%d.tiff') 177 | 178 | def __getitem__(self, index): 179 | scene_id = self.scene_id_list[index] 180 | 181 | seqn = torch.from_numpy(self.seqs['%d_%d_noisy' % (self.iso, scene_id)].astype(np.float32)) 182 | seqn = normalize_raw_torch(pack_gbrg_raw_torch(seqn)) 183 | a = torch.tensor(a_list[4], dtype=torch.float32).view((1, 1, 1, 1)) / (2 ** 12 - 1 - 240) 184 | b = torch.tensor(g_noise_var_list[4], dtype=torch.float32).view((1, 1, 1, 
1)) / ((2 ** 12 - 1 - 240) ** 2) 185 | 186 | return {'seqn': seqn, 187 | 'iso': self.iso, 'a': a, 'b': b, 'scene_id': scene_id} 188 | 189 | def __len__(self): 190 | return len(self.scene_id_list) -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from datasets.CRVD_seq import CRVDTrainDataset, CRVDTestDataset, CRVDOurdoorDataset 2 | from datasets.sRGB_seq import SrgbTrainDataset, SrgbValDataset -------------------------------------------------------------------------------- /datasets/sRGB_seq.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import numpy as np 3 | import os 4 | import torch 5 | from torch.utils.data.dataset import Dataset 6 | from utils.fastdvdnet_utils import open_sequence 7 | from utils.io import list_dir, open_images_uint8 8 | 9 | 10 | 11 | class SrgbTrainDataset(Dataset): 12 | def __init__(self, seq_dir, train_length, patch_size, patches_per_epoch, temp_stride=3, image_postfix='png', pin_memory=False): 13 | self.seq_dir = seq_dir 14 | self.train_length = train_length 15 | self.patch_size = patch_size 16 | self.patches_per_epoch = patches_per_epoch 17 | self.temp_stride = temp_stride 18 | self.pin_memory = pin_memory 19 | 20 | self.seq_names = list_dir(seq_dir) 21 | self.seqs = {} 22 | for seq_name in self.seq_names: 23 | self.seqs[seq_name] = {} 24 | self.seqs[seq_name]['clean_image_files'] = list_dir(os.path.join(self.seq_dir, seq_name), 25 | postfix=image_postfix, full_path=True) 26 | if self.pin_memory: 27 | self.seqs[seq_name]['clean_images'] = open_images_uint8(self.seqs[seq_name]['clean_image_files']) 28 | 29 | self.seq_count = [] 30 | for i in range(len(self.seq_names)): 31 | count = (len(self.seqs[self.seq_names[i]]['clean_image_files']) - self.train_length + self.temp_stride) // self.temp_stride 32 | self.seq_count.append(count) 33 | 
self.seq_count_cum = np.cumsum(self.seq_count) 34 | 35 | def __getitem__(self, index): 36 | if self.patches_per_epoch is not None: 37 | index = index % self.seq_count_cum[-1] 38 | for i in range(len(self.seq_count_cum)): 39 | if index < self.seq_count_cum[i]: 40 | seq_name = self.seq_names[i] 41 | seq_index = index if i == 0 else index - self.seq_count_cum[i - 1] 42 | break 43 | center_frame_index = seq_index * self.temp_stride + (self.train_length//2) 44 | if self.pin_memory: 45 | clean_images = self.seqs[seq_name]['clean_images'] 46 | else: 47 | clean_images = open_images_uint8(self.seqs[seq_name]['clean_image_files']) 48 | data = clean_images[center_frame_index - (self.train_length // 2):center_frame_index + 49 | (self.train_length // 2) + (self.train_length % 2)] 50 | 51 | # crop patches 52 | num_frames, C, H, W = data.shape 53 | position_H = np.random.randint(0, H - self.patch_size + 1) 54 | position_W = np.random.randint(0, W - self.patch_size + 1) 55 | data = data[:, :, position_H:position_H+self.patch_size, position_W:position_W+self.patch_size] 56 | 57 | return_dict = {'data':data} 58 | return return_dict 59 | 60 | def __len__(self): 61 | if self.patches_per_epoch is None: 62 | return self.seq_count_cum[-1] 63 | else: 64 | return self.patches_per_epoch 65 | 66 | """ 67 | Dataset related functions 68 | Copyright (C) 2018, Matias Tassano 69 | This program is free software: you can use, modify and/or 70 | redistribute it under the terms of the GNU General Public 71 | License as published by the Free Software Foundation, either 72 | version 3 of the License, or (at your option) any later 73 | version. You should have received a copy of this license along 74 | this program. If not, see . 75 | """ 76 | 77 | NUMFRXSEQ_VAL = 85 # number of frames of each sequence to include in validation dataset 78 | VALSEQPATT = '*' # pattern for name of validation sequence 79 | 80 | class SrgbValDataset(Dataset): 81 | """Validation dataset. 
Loads all the images in the dataset folder on memory. 82 | """ 83 | def __init__(self, valsetdir, gray_mode=False, num_input_frames=NUMFRXSEQ_VAL): 84 | self.gray_mode = gray_mode 85 | 86 | # Look for subdirs with individual sequences 87 | seqs_dirs = sorted(glob.glob(os.path.join(valsetdir, VALSEQPATT))) 88 | 89 | # open individual sequences and append them to the sequence list 90 | sequences = [] 91 | for seq_dir in seqs_dirs: 92 | seq, _, _ = open_sequence(seq_dir, gray_mode, expand_if_needed=False, \ 93 | max_num_fr=num_input_frames) 94 | # seq is [num_frames, C, H, W] 95 | sequences.append(seq) 96 | 97 | self.seqs_dirs = seqs_dirs 98 | self.sequences = sequences 99 | 100 | def __getitem__(self, index): 101 | return {'seq':torch.from_numpy(self.sequences[index]), 'name':self.seqs_dirs[index]} 102 | 103 | def __len__(self): 104 | return len(self.sequences) -------------------------------------------------------------------------------- /imgs/figure_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nagejacob/FloRNN/5419715af261bf1d619818baaf26708b81781f4a/imgs/figure_overview.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from models.rvidenet.isp import ISP 2 | from models.basicvsr_plusplus import BasicVSRPlusPlus 3 | from models.birnn import BiRNN 4 | from models.flornn import FloRNN 5 | from models.flornn_raw import FloRNNRaw 6 | from models.forwardrnn import ForwardRNN -------------------------------------------------------------------------------- /models/basicvsr_plusplus.py: -------------------------------------------------------------------------------- 1 | from mmcv.cnn import constant_init 2 | from mmcv.ops import ModulatedDeformConv2d, modulated_deform_conv2d 3 | from models.components import ResBlocks, D 4 | from 
pytorch_pwc.extract_flow import extract_flow_torch 5 | from pytorch_pwc.pwc import PWCNet 6 | import torch 7 | import torch.nn as nn 8 | from utils.warp import warp 9 | 10 | class BasicVSRPlusPlus(nn.Module): 11 | def __init__(self, img_channels=3, spatial_blocks=-1, temporal_blocks=-1, num_channels=64): 12 | super(BasicVSRPlusPlus, self).__init__() 13 | self.num_channels = num_channels 14 | self.pwcnet = PWCNet() 15 | 16 | self.feat_extract = ResBlocks(input_channels=img_channels * 2, num_resblocks=spatial_blocks, num_channels=num_channels) 17 | 18 | self.backbone = nn.ModuleDict() 19 | self.deform_align = nn.ModuleDict() 20 | self.module_names = ['forward_1', 'backward_1', 'forward_2', 'backward_2'] 21 | for i, module_name in enumerate(self.module_names): 22 | self.backbone[module_name] = ResBlocks(input_channels=num_channels * (i+2), num_resblocks=temporal_blocks, num_channels=num_channels) 23 | self.deform_align[module_name] = SecondOrderDeformableAlignment( 24 | 2 * num_channels, 25 | num_channels, 26 | 3, 27 | padding=1, 28 | deform_groups=16, 29 | max_residue_magnitude=10) 30 | 31 | self.d = D(in_channels=num_channels * 4, mid_channels=num_channels * 2, out_channels=img_channels) 32 | self.device = torch.device('cuda') 33 | 34 | def trainable_parameters(self): 35 | return [{'params':self.feat_extract.parameters()}, {'params':self.backbone.parameters()}, 36 | {'params':self.deform_align.parameters()}, {'params':self.d.parameters()}] 37 | 38 | def spatial_feature(self, seqn, noise_level_map): 39 | spatial_hs = [] 40 | for i in range(seqn.shape[1]): 41 | spatial_h = self.feat_extract(torch.cat((seqn[:, i].cuda(), noise_level_map[:, i].cuda()), dim=1)) 42 | if not self.training: 43 | spatial_h = spatial_h.cpu() 44 | spatial_hs.append(spatial_h) 45 | return spatial_hs 46 | 47 | def extract_flows(self, seqn): 48 | N, T, C, H, W = seqn.shape 49 | forward_flows, backward_flows = [], [] 50 | for i in range(T-1): 51 | forward_flow = extract_flow_torch(self.pwcnet, 
seqn[:, i+1].cuda(), seqn[:, i].cuda()) 52 | backward_flow = extract_flow_torch(self.pwcnet, seqn[:, i].cuda(), seqn[:, i+1].cuda()) 53 | if not self.training: 54 | forward_flow = forward_flow.cpu() 55 | backward_flow = backward_flow.cpu() 56 | forward_flows.append(forward_flow) 57 | backward_flows.append(backward_flow) 58 | return forward_flows, backward_flows 59 | 60 | def forward(self, seqn, noise_level_map): 61 | if self.training: 62 | self.device = torch.device('cuda') 63 | return self.forward_train(seqn, noise_level_map) 64 | else: 65 | self.device = torch.device('cpu') 66 | return self.forward_test(seqn, noise_level_map) 67 | 68 | def forward_train(self, seqn, noise_level_map): 69 | N, T, C, H, W = seqn.shape 70 | hs = {} 71 | for module_name in self.module_names: 72 | hs[module_name] = [None] * T 73 | zeros_h = torch.zeros((N, self.num_channels, H, W), device=seqn.device) 74 | zeros_flow = torch.zeros((N, 2, H, W), device=seqn.device) 75 | seqdn = torch.empty_like(seqn) 76 | 77 | # extract flows 78 | forward_flows, backward_flows = self.extract_flows(seqn) 79 | 80 | # extract spatial features 81 | hs['spatial'] = self.spatial_feature(seqn, noise_level_map) 82 | 83 | # extract forward features 84 | spatial_h = hs['spatial'][0] 85 | forward_h = self.backbone['forward_1'](torch.cat((spatial_h, zeros_h), dim=1)) 86 | hs['forward_1'][0] = forward_h 87 | 88 | spatial_h = hs['spatial'][1] 89 | flow_n1 = forward_flows[0] 90 | forward_h_n1 = forward_h 91 | aligned_forward_h_n1, _ = warp(forward_h_n1, flow_n1) 92 | feat_prop = self.deform_align['forward_1'](torch.cat((forward_h_n1, zeros_h), dim=1), 93 | torch.cat((aligned_forward_h_n1, spatial_h, zeros_h), dim=1), 94 | flow_n1, zeros_flow) 95 | forward_h = self.backbone['forward_1'](torch.cat((spatial_h, feat_prop), dim=1)) 96 | hs['forward_1'][1] = forward_h 97 | 98 | for i in range(2, T): 99 | spatial_h = hs['spatial'][i] 100 | flow_n1 = forward_flows[i - 1] 101 | forward_h_n1 = forward_h 102 | 
aligned_forward_h_n1, _ = warp(forward_h_n1, flow_n1) 103 | flow_n2 = flow_n1 + warp(forward_flows[i - 2], flow_n1)[0] 104 | forward_h_n2 = hs['forward_1'][i - 2] 105 | aligned_forward_h_n2, _ = warp(forward_h_n2, flow_n2) 106 | feat_prop = self.deform_align['forward_1'](torch.cat((forward_h_n1, forward_h_n2), dim=1), torch.cat( 107 | (aligned_forward_h_n1, spatial_h, aligned_forward_h_n2), dim=1), flow_n1, flow_n2) 108 | forward_h = self.backbone['forward_1'](torch.cat((spatial_h, feat_prop), dim=1)) 109 | hs['forward_1'][i] = forward_h 110 | 111 | # extract backward features 112 | spatial_h = hs['spatial'][-1] 113 | backward_h = self.backbone['backward_1'](torch.cat((spatial_h, zeros_h, hs['forward_1'][-1]), dim=1)) 114 | hs['backward_1'][-1] = backward_h 115 | 116 | spatial_h = hs['spatial'][-2] 117 | flow_p1 = backward_flows[-1] 118 | backward_h_p1 = backward_h 119 | aligned_backward_h_p1, _ = warp(backward_h_p1, flow_p1) 120 | feat_prop = self.deform_align['backward_1'](torch.cat((backward_h_p1, zeros_h), dim=1), 121 | torch.cat((aligned_backward_h_p1, spatial_h, zeros_h), dim=1), 122 | flow_p1, zeros_flow) 123 | backward_h = self.backbone['backward_1'](torch.cat((spatial_h, feat_prop, hs['forward_1'][-2]), dim=1)) 124 | hs['backward_1'][-2] = backward_h 125 | 126 | for i in range(3, T + 1): 127 | spatial_h = hs['spatial'][T - i] 128 | flow_p1 = backward_flows[T - i] 129 | backward_h_p1 = backward_h 130 | aligned_backward_h_p1, _ = warp(backward_h_p1, flow_p1) 131 | flow_p2 = flow_p1 + warp(backward_flows[T - i + 1], flow_p1)[0] 132 | backward_h_p2 = hs['backward_1'][T - i + 1] 133 | aligned_backward_h_p2, _ = warp(backward_h_p2, flow_p2) 134 | feat_prop = self.deform_align['backward_1'](torch.cat((backward_h_p1, backward_h_p2), dim=1), 135 | torch.cat((aligned_backward_h_p1, spatial_h, backward_h_p2), 136 | dim=1), flow_p1, flow_p2) 137 | backward_h = self.backbone['backward_1']( 138 | torch.cat((spatial_h, feat_prop, hs['forward_1'][T - i]), dim=1)) 139 | 
hs['backward_1'][T - i] = backward_h 140 | 141 | # extract forward features 142 | spatial_h = hs['spatial'][0] 143 | forward_h = self.backbone['forward_2'](torch.cat((spatial_h, zeros_h, 144 | hs['forward_1'][0], 145 | hs['backward_1'][0]), dim=1)) 146 | hs['forward_2'][0] = forward_h 147 | 148 | spatial_h = hs['spatial'][1] 149 | flow_n1 = forward_flows[0] 150 | forward_h_n1 = forward_h 151 | aligned_forward_h_n1, _ = warp(forward_h_n1, flow_n1) 152 | feat_prop = self.deform_align['forward_2'](torch.cat((forward_h_n1, zeros_h), dim=1), 153 | torch.cat((aligned_forward_h_n1, spatial_h, zeros_h), dim=1), 154 | flow_n1, zeros_flow) 155 | forward_h = self.backbone['forward_2']( 156 | torch.cat((spatial_h, feat_prop, hs['forward_1'][1], hs['backward_1'][1]), dim=1)) 157 | hs['forward_2'][1] = forward_h 158 | 159 | for i in range(2, T): 160 | spatial_h = hs['spatial'][i] 161 | flow_n1 = forward_flows[i - 1] 162 | forward_h_n1 = forward_h 163 | aligned_forward_h_n1, _ = warp(forward_h_n1, flow_n1) 164 | flow_n2 = flow_n1 + warp(forward_flows[i - 2], flow_n1)[0] 165 | forward_h_n2 = hs['forward_2'][i - 2] 166 | aligned_forward_h_n2, _ = warp(forward_h_n2, flow_n2) 167 | feat_prop = self.deform_align['forward_2'](torch.cat((forward_h_n1, forward_h_n2), dim=1), torch.cat( 168 | (aligned_forward_h_n1, spatial_h, aligned_forward_h_n2), dim=1), flow_n1, flow_n2) 169 | forward_h = self.backbone['forward_2']( 170 | torch.cat((spatial_h, feat_prop, hs['forward_1'][i], hs['backward_1'][i]), dim=1)) 171 | hs['forward_2'][i] = forward_h 172 | 173 | # extract backward features 174 | spatial_h = hs['spatial'][-1] 175 | backward_h = self.backbone['backward_2']( 176 | torch.cat((spatial_h, zeros_h, hs['forward_1'][-1], hs['backward_1'][-1], hs['forward_2'][-1]), 177 | dim=1)) 178 | hs['backward_2'][-1] = backward_h 179 | 180 | spatial_h = hs['spatial'][-2] 181 | flow_p1 = backward_flows[-1] 182 | backward_h_p1 = backward_h 183 | aligned_backward_h_p1, _ = warp(backward_h_p1, flow_p1) 
184 | feat_prop = self.deform_align['backward_2'](torch.cat((backward_h_p1, zeros_h), dim=1), 185 | torch.cat((aligned_backward_h_p1, spatial_h, zeros_h), dim=1), 186 | flow_p1, zeros_flow) 187 | backward_h = self.backbone['backward_2']( 188 | torch.cat((spatial_h, feat_prop, hs['forward_1'][-2], hs['backward_1'][-2], hs['forward_2'][-2]), 189 | dim=1)) 190 | hs['backward_2'][-2] = backward_h 191 | 192 | for i in range(3, T + 1): 193 | spatial_h = hs['spatial'][T - i] 194 | flow_p1 = backward_flows[T - i] 195 | backward_h_p1 = backward_h 196 | aligned_backward_h_p1, _ = warp(backward_h_p1, flow_p1) 197 | flow_p2 = flow_p1 + warp(backward_flows[T - i + 1], flow_p1)[0] 198 | backward_h_p2 = hs['backward_2'][T - i + 1] 199 | aligned_backward_h_p2, _ = warp(backward_h_p2, flow_p2) 200 | feat_prop = self.deform_align['backward_2'](torch.cat((backward_h_p1, backward_h_p2), dim=1), 201 | torch.cat((aligned_backward_h_p1, spatial_h, backward_h_p2), 202 | dim=1), flow_p1, flow_p2) 203 | backward_h = self.backbone['backward_2'](torch.cat((spatial_h, feat_prop, hs['forward_1'][T - i], 204 | hs['backward_1'][T - i], hs['forward_2'][T - i]), 205 | dim=1)) 206 | hs['backward_2'][T - i] = backward_h 207 | 208 | # generate results 209 | for i in range(T): 210 | seqdn[:, i] = self.d(torch.cat((hs['forward_1'][i], hs['backward_1'][i], hs['forward_2'][i], hs['backward_2'][i]), dim=1)) 211 | 212 | return seqdn 213 | 214 | def forward_test(self, seqn, noise_level_map): 215 | N, T, C, H, W = seqn.shape 216 | hs = {} 217 | for module_name in self.module_names: 218 | hs[module_name] = [None] * T 219 | zeros_h = torch.zeros((N, self.num_channels, H, W), device=seqn.device) 220 | zeros_flow = torch.zeros((1, 2, H, W), device=seqn.device) 221 | seqdn = torch.empty_like(seqn) 222 | 223 | # extract flows 224 | forward_flows, backward_flows = self.extract_flows(seqn) 225 | 226 | # extract spatial features 227 | hs['spatial'] = self.spatial_feature(seqn, noise_level_map) 228 | 229 | # extract 
forward features 230 | spatial_h = hs['spatial'][0].cuda() 231 | forward_h = self.backbone['forward_1'](torch.cat((spatial_h, zeros_h.cuda()), dim=1)) 232 | hs['forward_1'][0] = forward_h.cpu() 233 | 234 | spatial_h = hs['spatial'][1].cuda() 235 | flow_n1 = forward_flows[0].cuda() 236 | forward_h_n1 = forward_h 237 | aligned_forward_h_n1, _ = warp(forward_h_n1, flow_n1) 238 | feat_prop = self.deform_align['forward_1'](torch.cat((forward_h_n1, zeros_h.cuda()), dim=1), 239 | torch.cat((aligned_forward_h_n1, spatial_h, zeros_h.cuda()), dim=1), 240 | flow_n1, zeros_flow.cuda()) 241 | forward_h = self.backbone['forward_1'](torch.cat((spatial_h, feat_prop), dim=1)) 242 | hs['forward_1'][1] = forward_h.cpu() 243 | 244 | for i in range(2, T): 245 | spatial_h = hs['spatial'][i].cuda() 246 | flow_n1 = forward_flows[i - 1].cuda() 247 | forward_h_n1 = forward_h 248 | aligned_forward_h_n1, _ = warp(forward_h_n1, flow_n1) 249 | flow_n2 = flow_n1 + warp(forward_flows[i - 2].cuda(), flow_n1)[0] 250 | forward_h_n2 = hs['forward_1'][i - 2].cuda() 251 | aligned_forward_h_n2, _ = warp(forward_h_n2, flow_n2) 252 | feat_prop = self.deform_align['forward_1'](torch.cat((forward_h_n1, forward_h_n2), dim=1), torch.cat( 253 | (aligned_forward_h_n1, spatial_h, aligned_forward_h_n2), dim=1), flow_n1, flow_n2) 254 | forward_h = self.backbone['forward_1'](torch.cat((spatial_h, feat_prop), dim=1)) 255 | hs['forward_1'][i] = forward_h.cpu() 256 | 257 | # extract backward features 258 | spatial_h = hs['spatial'][-1].cuda() 259 | backward_h = self.backbone['backward_1'](torch.cat((spatial_h, zeros_h.cuda(), hs['forward_1'][-1].cuda()), dim=1)) 260 | hs['backward_1'][-1] = backward_h.cpu() 261 | 262 | spatial_h = hs['spatial'][-2].cuda() 263 | flow_p1 = backward_flows[-1].cuda() 264 | backward_h_p1 = backward_h 265 | aligned_backward_h_p1, _ = warp(backward_h_p1, flow_p1) 266 | feat_prop = self.deform_align['backward_1'](torch.cat((backward_h_p1, zeros_h.cuda()), dim=1), 267 | 
torch.cat((aligned_backward_h_p1, spatial_h, zeros_h.cuda()), dim=1), 268 | flow_p1, zeros_flow.cuda()) 269 | backward_h = self.backbone['backward_1'](torch.cat((spatial_h, feat_prop, hs['forward_1'][-2].cuda()), dim=1)) 270 | hs['backward_1'][-2] = backward_h.cpu() 271 | 272 | for i in range(3, T + 1): 273 | spatial_h = hs['spatial'][T - i].cuda() 274 | flow_p1 = backward_flows[T - i].cuda() 275 | backward_h_p1 = backward_h 276 | aligned_backward_h_p1, _ = warp(backward_h_p1, flow_p1) 277 | flow_p2 = flow_p1 + warp(backward_flows[T - i + 1].cuda(), flow_p1)[0] 278 | backward_h_p2 = hs['backward_1'][T - i + 1].cuda() 279 | aligned_backward_h_p2, _ = warp(backward_h_p2, flow_p2) 280 | feat_prop = self.deform_align['backward_1'](torch.cat((backward_h_p1, backward_h_p2), dim=1), 281 | torch.cat((aligned_backward_h_p1, spatial_h, backward_h_p2), 282 | dim=1), flow_p1, flow_p2) 283 | backward_h = self.backbone['backward_1']( 284 | torch.cat((spatial_h, feat_prop, hs['forward_1'][T - i].cuda()), dim=1)) 285 | hs['backward_1'][T - i] = backward_h.cpu() 286 | 287 | # extract forward features 288 | spatial_h = hs['spatial'][0].cuda() 289 | forward_h = self.backbone['forward_2'](torch.cat((spatial_h, zeros_h.cuda(), 290 | hs['forward_1'][0].cuda(), 291 | hs['backward_1'][0].cuda()), dim=1)) 292 | hs['forward_2'][0] = forward_h.cpu() 293 | 294 | spatial_h = hs['spatial'][1].cuda() 295 | flow_n1 = forward_flows[0].cuda() 296 | forward_h_n1 = forward_h 297 | aligned_forward_h_n1, _ = warp(forward_h_n1, flow_n1) 298 | feat_prop = self.deform_align['forward_2'](torch.cat((forward_h_n1, zeros_h.cuda()), dim=1), 299 | torch.cat((aligned_forward_h_n1, spatial_h, zeros_h.cuda()), dim=1), 300 | flow_n1, zeros_flow.cuda()) 301 | forward_h = self.backbone['forward_2']( 302 | torch.cat((spatial_h, feat_prop, hs['forward_1'][1].cuda(), hs['backward_1'][1].cuda()), dim=1)) 303 | hs['forward_2'][1] = forward_h.cpu() 304 | 305 | for i in range(2, T): 306 | spatial_h = hs['spatial'][i].cuda() 
307 | flow_n1 = forward_flows[i - 1].cuda() 308 | forward_h_n1 = forward_h 309 | aligned_forward_h_n1, _ = warp(forward_h_n1, flow_n1) 310 | flow_n2 = flow_n1 + warp(forward_flows[i - 2].cuda(), flow_n1)[0] 311 | forward_h_n2 = hs['forward_2'][i - 2].cuda() 312 | aligned_forward_h_n2, _ = warp(forward_h_n2, flow_n2) 313 | feat_prop = self.deform_align['forward_2'](torch.cat((forward_h_n1, forward_h_n2), dim=1), torch.cat( 314 | (aligned_forward_h_n1, spatial_h, aligned_forward_h_n2), dim=1), flow_n1, flow_n2) 315 | forward_h = self.backbone['forward_2']( 316 | torch.cat((spatial_h, feat_prop, hs['forward_1'][i].cuda(), hs['backward_1'][i].cuda()), dim=1)) 317 | hs['forward_2'][i] = forward_h.cpu() 318 | 319 | # extract backward features 320 | spatial_h = hs['spatial'][-1].cuda() 321 | backward_h = self.backbone['backward_2']( 322 | torch.cat((spatial_h, zeros_h.cuda(), hs['forward_1'][-1].cuda(), hs['backward_1'][-1].cuda(), hs['forward_2'][-1].cuda()), 323 | dim=1)) 324 | hs['backward_2'][-1] = backward_h.cpu() 325 | 326 | spatial_h = hs['spatial'][-2].cuda() 327 | flow_p1 = backward_flows[-1].cuda() 328 | backward_h_p1 = backward_h 329 | aligned_backward_h_p1, _ = warp(backward_h_p1, flow_p1) 330 | feat_prop = self.deform_align['backward_2'](torch.cat((backward_h_p1, zeros_h.cuda()), dim=1), 331 | torch.cat((aligned_backward_h_p1, spatial_h, zeros_h.cuda()), dim=1), 332 | flow_p1, zeros_flow.cuda()) 333 | backward_h = self.backbone['backward_2']( 334 | torch.cat((spatial_h, feat_prop, hs['forward_1'][-2].cuda(), hs['backward_1'][-2].cuda(), hs['forward_2'][-2].cuda()), 335 | dim=1)) 336 | hs['backward_2'][-2] = backward_h.cpu() 337 | 338 | for i in range(3, T + 1): 339 | spatial_h = hs['spatial'][T - i].cuda() 340 | flow_p1 = backward_flows[T - i].cuda() 341 | backward_h_p1 = backward_h 342 | aligned_backward_h_p1, _ = warp(backward_h_p1, flow_p1) 343 | flow_p2 = flow_p1 + warp(backward_flows[T - i + 1].cuda(), flow_p1)[0] 344 | backward_h_p2 = hs['backward_2'][T 
- i + 1].cuda() 345 | aligned_backward_h_p2, _ = warp(backward_h_p2, flow_p2) 346 | feat_prop = self.deform_align['backward_2'](torch.cat((backward_h_p1, backward_h_p2), dim=1), 347 | torch.cat((aligned_backward_h_p1, spatial_h, backward_h_p2), 348 | dim=1), flow_p1, flow_p2) 349 | backward_h = self.backbone['backward_2'](torch.cat((spatial_h, feat_prop, hs['forward_1'][T - i].cuda(), 350 | hs['backward_1'][T - i].cuda(), hs['forward_2'][T - i].cuda()), 351 | dim=1)) 352 | hs['backward_2'][T - i] = backward_h.cpu() 353 | 354 | # generate results 355 | for i in range(T): 356 | seqdn[:, i] = self.d( 357 | torch.cat((hs['forward_1'][i].cuda(), hs['backward_1'][i].cuda(), hs['forward_2'][i].cuda(), hs['backward_2'][i].cuda()), dim=1)).cpu() 358 | 359 | return seqdn 360 | 361 | 362 | class SecondOrderDeformableAlignment(ModulatedDeformConv2d): 363 | """Second-order deformable alignment module. 364 | Args: 365 | in_channels (int): Same as nn.Conv2d. 366 | out_channels (int): Same as nn.Conv2d. 367 | kernel_size (int or tuple[int]): Same as nn.Conv2d. 368 | stride (int or tuple[int]): Same as nn.Conv2d. 369 | padding (int or tuple[int]): Same as nn.Conv2d. 370 | dilation (int or tuple[int]): Same as nn.Conv2d. 371 | groups (int): Same as nn.Conv2d. 372 | bias (bool or str): If specified as `auto`, it will be decided by the 373 | norm_cfg. Bias will be set as True if norm_cfg is None, otherwise 374 | False. 375 | max_residue_magnitude (int): The maximum magnitude of the offset 376 | residue (Eq. 6 in paper). Default: 10. 
377 | """ 378 | 379 | def __init__(self, *args, **kwargs): 380 | self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 10) 381 | 382 | super(SecondOrderDeformableAlignment, self).__init__(*args, **kwargs) 383 | 384 | self.conv_offset = nn.Sequential( 385 | nn.Conv2d(3 * self.out_channels + 4, self.out_channels, 3, 1, 1), 386 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 387 | nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1), 388 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 389 | nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1), 390 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 391 | nn.Conv2d(self.out_channels, 27 * self.deform_groups, 3, 1, 1), 392 | ) 393 | 394 | self.init_offset() 395 | 396 | def init_offset(self): 397 | constant_init(self.conv_offset[-1], val=0, bias=0) 398 | 399 | def forward(self, x, extra_feat, flow_1, flow_2): 400 | extra_feat = torch.cat([extra_feat, flow_1, flow_2], dim=1) 401 | out = self.conv_offset(extra_feat) 402 | o1, o2, mask = torch.chunk(out, 3, dim=1) 403 | 404 | # offset 405 | offset = self.max_residue_magnitude * torch.tanh( 406 | torch.cat((o1, o2), dim=1)) 407 | offset_1, offset_2 = torch.chunk(offset, 2, dim=1) 408 | offset_1 = offset_1 + flow_1.flip(1).repeat(1, 409 | offset_1.size(1) // 2, 1, 410 | 1) 411 | offset_2 = offset_2 + flow_2.flip(1).repeat(1, 412 | offset_2.size(1) // 2, 1, 413 | 1) 414 | offset = torch.cat([offset_1, offset_2], dim=1) 415 | 416 | # mask 417 | mask = torch.sigmoid(mask) 418 | 419 | return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias, 420 | self.stride, self.padding, 421 | self.dilation, self.groups, 422 | self.deform_groups) -------------------------------------------------------------------------------- /models/birnn.py: -------------------------------------------------------------------------------- 1 | from models.components import ResBlocks, D 2 | from pytorch_pwc.extract_flow import extract_flow_torch 3 | from pytorch_pwc.pwc import 
PWCNet 4 | import torch 5 | import torch.nn as nn 6 | from utils.warp import warp 7 | 8 | class BiRNN(nn.Module): 9 | def __init__(self, img_channels=3, num_resblocks=6, num_channels=64): 10 | super(BiRNN, self).__init__() 11 | self.num_channels = num_channels 12 | self.pwcnet = PWCNet() 13 | self.forward_rnn = ResBlocks(input_channels=img_channels + img_channels + num_channels, num_resblocks=num_resblocks, num_channels=num_channels) 14 | self.backward_rnn = ResBlocks(input_channels=img_channels + img_channels + num_channels, num_resblocks=num_resblocks, num_channels=num_channels) 15 | self.d = D(in_channels=num_channels * 2, mid_channels=num_channels * 2, out_channels=img_channels) 16 | 17 | def trainable_parameters(self): 18 | return [{'params':self.forward_rnn.parameters()}, {'params':self.backward_rnn.parameters()}, {'params':self.d.parameters()}] 19 | 20 | def forward(self, seqn, noise_level_map): 21 | if self.training: 22 | feature_device = torch.device('cuda') 23 | else: 24 | feature_device = torch.device('cpu') 25 | N, T, C, H, W = seqn.shape 26 | forward_hs = torch.empty((N, T, self.num_channels, H, W), device=feature_device) 27 | backward_hs = torch.empty((N, T, self.num_channels, H, W), device=feature_device) 28 | seqdn = torch.empty_like(seqn) 29 | 30 | # extract forward features 31 | init_forward_h = torch.zeros((N, self.num_channels, H, W), device=seqn.device) 32 | forward_h = self.forward_rnn(torch.cat((seqn[:, 0], noise_level_map[:, 0], init_forward_h), dim=1)) 33 | forward_hs[:, 0] = forward_h.to(feature_device) 34 | for i in range(1, T): 35 | flow = extract_flow_torch(self.pwcnet, seqn[:, i], seqn[:, i-1]) 36 | aligned_forward_h, _ = warp(forward_h, flow) 37 | forward_h = self.forward_rnn(torch.cat((seqn[:, i], noise_level_map[:, i], aligned_forward_h), dim=1)) 38 | forward_hs[:, i] = forward_h.to(feature_device) 39 | 40 | # extract backward features 41 | init_backward_h = torch.zeros((N, self.num_channels, H, W), device=seqn.device) 42 | 
backward_h = self.backward_rnn(torch.cat((seqn[:, -1], noise_level_map[:, -1], init_backward_h), dim=1)) 43 | backward_hs[:, -1] = backward_h.to(feature_device) 44 | for i in range(2, T+1): 45 | flow = extract_flow_torch(self.pwcnet, seqn[:, T-i], seqn[:, T-i+1]) 46 | aligned_backward_h, _ = warp(backward_h, flow) 47 | backward_h = self.backward_rnn(torch.cat((seqn[:, T-i], noise_level_map[:, T-i], aligned_backward_h), dim=1)) 48 | backward_hs[:, T-i] = backward_h.to(feature_device) 49 | 50 | # generate results 51 | for i in range(T): 52 | seqdn[:, i] = self.d(torch.cat((forward_hs[:, i].to(seqn.device), backward_hs[:, i].to(seqn.device)), dim=1)) 53 | 54 | return seqdn 55 | -------------------------------------------------------------------------------- /models/components.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from models.init import init_fn 3 | import torch 4 | import torch.nn as nn 5 | 6 | class ResBlock(nn.Module): 7 | def __init__(self, in_channels, mid_channels, out_channels): 8 | super(ResBlock, self).__init__() 9 | self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=mid_channels, kernel_size=3, stride=1, padding=1, bias=False) 10 | self.relu = nn.ReLU(inplace=True) 11 | self.conv2 = nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, bias=False) 12 | 13 | def forward(self, x): 14 | output = self.conv2(self.relu(self.conv1(x))) 15 | output = torch.add(output, x) 16 | return output 17 | 18 | class ResBlocks(nn.Module): 19 | def __init__(self, input_channels, num_resblocks, num_channels): 20 | super(ResBlocks, self).__init__() 21 | self.input_channels = input_channels 22 | self.first_conv = nn.Conv2d(in_channels=self.input_channels, out_channels=num_channels, kernel_size=3, stride=1, padding=1, bias=False) 23 | 24 | modules = [] 25 | for _ in range(num_resblocks): 26 | modules.append(ResBlock(in_channels=num_channels, 
mid_channels=num_channels, out_channels=num_channels)) 27 | self.resblocks = nn.Sequential(*modules) 28 | 29 | fn = functools.partial(init_fn, init_type='kaiming_normal', init_bn_type='uniform', gain=0.2) 30 | self.apply(fn) 31 | 32 | def forward(self, h): 33 | shallow_feature = self.first_conv(h) 34 | new_h = self.resblocks(shallow_feature) 35 | return new_h 36 | 37 | class D(nn.Module): 38 | def __init__(self, in_channels, mid_channels, out_channels): 39 | super(D, self).__init__() 40 | layers = [] 41 | layers.append(nn.Conv2d(in_channels=in_channels, out_channels=mid_channels, kernel_size=3, stride=1, padding=1, bias=False)) 42 | layers.append(nn.ReLU()) 43 | layers.append(nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, bias=False)) 44 | self.convs = nn.Sequential(*layers) 45 | 46 | fn = functools.partial(init_fn, init_type='kaiming_normal', init_bn_type='uniform', gain=0.2) 47 | self.apply(fn) 48 | 49 | def forward(self, x): 50 | x = self.convs(x) 51 | return x -------------------------------------------------------------------------------- /models/flornn.py: -------------------------------------------------------------------------------- 1 | from models.components import ResBlocks, D 2 | from pytorch_pwc.extract_flow import extract_flow_torch 3 | from pytorch_pwc.pwc import PWCNet 4 | from softmax_splatting.softsplat import FunctionSoftsplat 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from utils.warp import warp 9 | 10 | def expand(ten, size_h, size_w, value=0): 11 | return F.pad(ten, pad=[size_w, size_w, size_h, size_h], mode='constant', value=value) 12 | 13 | def split_border(ten, size_h, size_w): 14 | img = ten[:, :, size_h:-size_h, size_w:-size_w] 15 | return img, ten 16 | 17 | def merge_border(img, border, size_h, size_w): 18 | expanded_img = F.pad(img, pad=[size_w, size_w, size_h, size_h], mode='constant') 19 | expanded_img[:, :, :size_h, :] = border[:, :, :size_h, 
:] 20 | expanded_img[:, :, -size_h:, :] = border[:, :, -size_h:, :] 21 | expanded_img[:, :, :, :size_w] = border[:, :, :, :size_w] 22 | expanded_img[:, :, :, -size_w:] = border[:, :, :, -size_w:] 23 | return expanded_img 24 | 25 | class FloRNN(nn.Module): 26 | def __init__(self, img_channels, num_resblocks=6, num_channels=64, forward_count=2, border_ratio=0.3): 27 | super(FloRNN, self).__init__() 28 | self.num_channels = num_channels 29 | self.forward_count = forward_count 30 | self.pwcnet = PWCNet() 31 | self.forward_rnn = ResBlocks(input_channels=img_channels + img_channels + num_channels, num_resblocks=num_resblocks, num_channels=num_channels) 32 | self.backward_rnn = ResBlocks(input_channels=img_channels + img_channels + num_channels, num_resblocks=num_resblocks, num_channels=num_channels) 33 | self.d = D(in_channels=num_channels * 2, mid_channels=num_channels * 2, out_channels=img_channels) 34 | self.border_ratio = border_ratio 35 | 36 | def trainable_parameters(self): 37 | return [{'params':self.forward_rnn.parameters()}, {'params':self.backward_rnn.parameters()}, {'params':self.d.parameters()}] 38 | 39 | def forward(self, seqn_not_pad, noise_level_map_not_pad): 40 | N, T, C, H, W = seqn_not_pad.shape 41 | seqdn = torch.empty_like(seqn_not_pad) 42 | expanded_forward_flow_queue = [] 43 | border_queue = [] 44 | size_h, size_w = int(H * self.border_ratio), int(W * self.border_ratio) 45 | 46 | # reflect pad seqn and noise_level_map 47 | seqn = torch.empty((N, T+self.forward_count, C, H, W), device=seqn_not_pad.device) 48 | noise_level_map = torch.empty((N, T+self.forward_count, C, H, W), device=noise_level_map_not_pad.device) 49 | seqn[:, :T] = seqn_not_pad 50 | noise_level_map[:, :T] = noise_level_map_not_pad 51 | for i in range(self.forward_count): 52 | seqn[:, T+i] = seqn_not_pad[:, T-2-i] 53 | noise_level_map[:, T+i] = noise_level_map_not_pad[:, T-2-i] 54 | 55 | init_backward_h = torch.zeros((N, self.num_channels, H, W), device=seqn.device) 56 | backward_h = 
self.backward_rnn(torch.cat((seqn[:, 0], noise_level_map[:, 0], init_backward_h), dim=1)) 57 | init_forward_h = torch.zeros((N, self.num_channels, H, W), device=seqn.device) 58 | forward_h = self.forward_rnn(torch.cat((seqn[:, 0], noise_level_map[:, 0], init_forward_h), dim=1)) 59 | 60 | for i in range(1, T+self.forward_count): 61 | forward_flow = extract_flow_torch(self.pwcnet, seqn[:, i-1], seqn[:, i]) 62 | 63 | expanded_backward_h, expanded_forward_flow = expand(backward_h, size_h, size_w), expand(forward_flow, size_h, size_w) 64 | expanded_forward_flow_queue.append(expanded_forward_flow) 65 | aligned_expanded_backward_h = FunctionSoftsplat(expanded_backward_h, expanded_forward_flow, None, 'average') 66 | aligned_backward_h, border = split_border(aligned_expanded_backward_h, size_h, size_w) 67 | border_queue.append(border) 68 | 69 | backward_h = self.backward_rnn(torch.cat((seqn[:, i], noise_level_map[:, i], aligned_backward_h), dim=1)) 70 | 71 | if i >= self.forward_count: 72 | aligned_backward_h = backward_h 73 | for j in reversed(range(self.forward_count)): 74 | aligned_backward_h = merge_border(aligned_backward_h, border_queue[j], size_h, size_w) 75 | aligned_backward_h, _ = warp(aligned_backward_h, expanded_forward_flow_queue[j]) 76 | aligned_backward_h, _ = split_border(aligned_backward_h, size_h, size_w) 77 | 78 | seqdn[:, i - self.forward_count] = self.d(torch.cat((forward_h, aligned_backward_h), dim=1)) 79 | 80 | backward_flow = extract_flow_torch(self.pwcnet, seqn[:, i-self.forward_count+1], seqn[:, i-self.forward_count]) 81 | aligned_forward_h, _ = warp(forward_h, backward_flow) 82 | forward_h = self.forward_rnn(torch.cat((seqn[:, i-self.forward_count+1], noise_level_map[:, i-self.forward_count+1], aligned_forward_h), dim=1)) 83 | expanded_forward_flow_queue.pop(0) 84 | border_queue.pop(0) 85 | 86 | return seqdn 87 | 88 | -------------------------------------------------------------------------------- /models/flornn_raw.py: 
-------------------------------------------------------------------------------- 1 | from models.components import ResBlocks, D 2 | from pytorch_pwc.extract_flow import extract_flow_torch 3 | from pytorch_pwc.pwc import PWCNet 4 | from softmax_splatting.softsplat import FunctionSoftsplat 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from utils.raw import demosaic 9 | from utils.warp import warp 10 | 11 | def expand(ten, size_h, size_w, value=0): 12 | return F.pad(ten, pad=[size_w, size_w, size_h, size_h], mode='constant', value=value) 13 | 14 | def split_border(ten, size_h, size_w): 15 | img = ten[:, :, size_h:-size_h, size_w:-size_w] 16 | return img, ten 17 | 18 | def merge_border(img, border, size_h, size_w): 19 | expanded_img = F.pad(img, pad=[size_w, size_w, size_h, size_h], mode='constant') 20 | expanded_img[:, :, :size_h, :] = border[:, :, :size_h, :] 21 | expanded_img[:, :, -size_h:, :] = border[:, :, -size_h:, :] 22 | expanded_img[:, :, :, :size_w] = border[:, :, :, :size_w] 23 | expanded_img[:, :, :, -size_w:] = border[:, :, :, -size_w:] 24 | return expanded_img 25 | 26 | class FloRNNRaw(nn.Module): 27 | def __init__(self, img_channels, num_resblocks=6, num_channels=64, forward_count=2, border_ratio=0.1): 28 | super(FloRNNRaw, self).__init__() 29 | self.num_channels = num_channels 30 | self.forward_count = forward_count 31 | self.pwcnet = PWCNet() 32 | self.forward_rnn = ResBlocks(input_channels=img_channels + 2 + num_channels, num_resblocks=num_resblocks, num_channels=num_channels) 33 | self.backward_rnn = ResBlocks(input_channels=img_channels + 2 + num_channels, num_resblocks=num_resblocks, num_channels=num_channels) 34 | self.d = D(in_channels=num_channels * 2, mid_channels=num_channels * 2, out_channels=img_channels) 35 | self.border_ratio = border_ratio 36 | 37 | def trainable_parameters(self): 38 | return [{'params':self.forward_rnn.parameters()}, {'params':self.backward_rnn.parameters()}, 
{'params':self.d.parameters()}] 39 | 40 | def forward(self, seqn_not_pad, a_not_pad, b_not_pad): 41 | N, T, C, H, W = seqn_not_pad.shape 42 | seqdn = torch.empty_like(seqn_not_pad) 43 | expanded_forward_flow_queue = [] 44 | border_queue = [] 45 | size_h, size_w = int(H * self.border_ratio), int(W * self.border_ratio) 46 | 47 | # reflect pad seqn and noise_level_map 48 | seqn = torch.empty((N, T+self.forward_count, C, H, W), device=seqn_not_pad.device) 49 | a = torch.empty((N, T + self.forward_count, 1, H, W), device=a_not_pad.device) 50 | b = torch.empty((N, T + self.forward_count, 1, H, W), device=b_not_pad.device) 51 | seqn[:, :T] = seqn_not_pad 52 | a[:, :T] = a_not_pad 53 | b[:, :T] = b_not_pad 54 | for i in range(self.forward_count): 55 | seqn[:, T+i] = seqn_not_pad[:, T-2-i] 56 | a[:, T + i] = a_not_pad[:, T - 2 - i] 57 | b[:, T + i] = b_not_pad[:, T - 2 - i] 58 | srgb_seqn = demosaic(seqn) 59 | 60 | init_backward_h = torch.zeros((N, self.num_channels, H, W), device=seqn.device) 61 | backward_h = self.backward_rnn(torch.cat((seqn[:, 0], a[:, 0], b[:, 0], init_backward_h), dim=1)) 62 | init_forward_h = torch.zeros((N, self.num_channels, H, W), device=seqn.device) 63 | forward_h = self.forward_rnn(torch.cat((seqn[:, 0], a[:, 0], b[:, 0], init_forward_h), dim=1)) 64 | 65 | for i in range(1, T+self.forward_count): 66 | forward_flow = extract_flow_torch(self.pwcnet, srgb_seqn[:, i-1], srgb_seqn[:, i]) 67 | 68 | expanded_backward_h, expanded_forward_flow = expand(backward_h, size_h, size_w), expand(forward_flow, size_h, size_w) 69 | expanded_forward_flow_queue.append(expanded_forward_flow) 70 | aligned_expanded_backward_h = FunctionSoftsplat(expanded_backward_h, expanded_forward_flow, None, 'average') 71 | aligned_backward_h, border = split_border(aligned_expanded_backward_h, size_h, size_w) 72 | border_queue.append(border) 73 | 74 | backward_h = self.backward_rnn(torch.cat((seqn[:, i], a[:, i], b[:, i], aligned_backward_h), dim=1)) 75 | 76 | if i >= 
self.forward_count: 77 | aligned_backward_h = backward_h 78 | for j in reversed(range(self.forward_count)): 79 | aligned_backward_h = merge_border(aligned_backward_h, border_queue[j], size_h, size_w) 80 | aligned_backward_h, _ = warp(aligned_backward_h, expanded_forward_flow_queue[j]) 81 | aligned_backward_h, _ = split_border(aligned_backward_h, size_h, size_w) 82 | 83 | seqdn[:, i - self.forward_count] = self.d(torch.cat((forward_h, aligned_backward_h), dim=1)) 84 | 85 | backward_flow = extract_flow_torch(self.pwcnet, srgb_seqn[:, i-self.forward_count+1], srgb_seqn[:, i-self.forward_count]) 86 | aligned_forward_h, _ = warp(forward_h, backward_flow) 87 | forward_h = self.forward_rnn(torch.cat((seqn[:, i-self.forward_count+1], a[:, i-self.forward_count+1], b[:, i-self.forward_count+1], aligned_forward_h), dim=1)) 88 | expanded_forward_flow_queue.pop(0) 89 | border_queue.pop(0) 90 | 91 | return seqdn 92 | 93 | -------------------------------------------------------------------------------- /models/forwardrnn.py: -------------------------------------------------------------------------------- 1 | from models.components import ResBlocks, D 2 | from pytorch_pwc.extract_flow import extract_flow_torch 3 | from pytorch_pwc.pwc import PWCNet 4 | import torch 5 | import torch.nn as nn 6 | from utils.warp import warp 7 | 8 | class ForwardRNN(nn.Module): 9 | def __init__(self, img_channels=3, num_resblocks=6, num_channels=64): 10 | super(ForwardRNN, self).__init__() 11 | self.num_channels = num_channels 12 | self.pwcnet = PWCNet() 13 | self.forward_rnn = ResBlocks(input_channels=img_channels + img_channels + num_channels, num_resblocks=num_resblocks, num_channels=num_channels) 14 | self.d = D(in_channels=num_channels, mid_channels=num_channels, out_channels=img_channels) 15 | 16 | def trainable_parameters(self): 17 | return [{'params':self.forward_rnn.parameters()}, {'params':self.d.parameters()}] 18 | 19 | def forward(self, seqn, noise_level_map): 20 | N, T, C, H, W = 
seqn.shape 21 | seqdn = torch.empty_like(seqn) 22 | 23 | init_h = torch.zeros((N, self.num_channels, H, W), device=seqn.device) 24 | h = self.forward_rnn(torch.cat((seqn[:, 0], noise_level_map[:, 0], init_h), dim=1)) 25 | seqdn[:, 0] = self.d(h) 26 | 27 | for i in range(1, T): 28 | flow = extract_flow_torch(self.pwcnet, seqn[:, i], seqn[:, i-1]) 29 | aligned_h, _ = warp(h, flow) 30 | h = self.forward_rnn(torch.cat((seqn[:, i], noise_level_map[:, i], aligned_h), dim=1)) 31 | seqdn[:, i] = self.d(h) 32 | 33 | return seqdn 34 | -------------------------------------------------------------------------------- /models/init.py: -------------------------------------------------------------------------------- 1 | """ 2 | # -------------------------------------------- 3 | # weights initialization 4 | # -------------------------------------------- 5 | """ 6 | from torch.nn import init 7 | 8 | """ 9 | # Kai Zhang, https://github.com/cszn/KAIR 10 | # 11 | # Args: 12 | # init_type: 13 | # normal; normal; xavier_normal; xavier_uniform; 14 | # kaiming_normal; kaiming_uniform; orthogonal 15 | # init_bn_type: 16 | # uniform; constant 17 | # gain: 18 | # 0.2 19 | """ 20 | 21 | def init_fn(m, init_type='kaiming_normal', init_bn_type='uniform', gain=0.2): 22 | classname = m.__class__.__name__ 23 | 24 | if classname.find('Conv') != -1 or classname.find('Linear') != -1: 25 | 26 | if init_type == 'normal': 27 | init.normal_(m.weight.data, 0, 0.1) 28 | m.weight.data.clamp_(-1, 1).mul_(gain) 29 | 30 | elif init_type == 'uniform': 31 | init.uniform_(m.weight.data, -0.2, 0.2) 32 | m.weight.data.mul_(gain) 33 | 34 | elif init_type == 'xavier_normal': 35 | init.xavier_normal_(m.weight.data, gain=gain) 36 | m.weight.data.clamp_(-1, 1) 37 | 38 | elif init_type == 'xavier_uniform': 39 | init.xavier_uniform_(m.weight.data, gain=gain) 40 | 41 | elif init_type == 'kaiming_normal': 42 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in', nonlinearity='relu') 43 | m.weight.data.clamp_(-1, 
1).mul_(gain) 44 | 45 | elif init_type == 'kaiming_uniform': 46 | init.kaiming_uniform_(m.weight.data, a=0, mode='fan_in', nonlinearity='relu') 47 | m.weight.data.mul_(gain) 48 | 49 | elif init_type == 'orthogonal': 50 | init.orthogonal_(m.weight.data, gain=gain) 51 | 52 | else: 53 | raise NotImplementedError('Initialization method [{:s}] is not implemented'.format(init_type)) 54 | 55 | if m.bias is not None: 56 | m.bias.data.zero_() 57 | 58 | elif classname.find('BatchNorm2d') != -1: 59 | 60 | if init_bn_type == 'uniform': # preferred 61 | if m.affine: 62 | init.uniform_(m.weight.data, 0.1, 1.0) 63 | init.constant_(m.bias.data, 0.0) 64 | elif init_bn_type == 'constant': 65 | if m.affine: 66 | init.constant_(m.weight.data, 1.0) 67 | init.constant_(m.bias.data, 0.0) 68 | else: 69 | raise NotImplementedError('Initialization method [{:s}] is not implemented'.format(init_bn_type)) -------------------------------------------------------------------------------- /models/rvidenet/isp.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nagejacob/FloRNN/5419715af261bf1d619818baaf26708b81781f4a/models/rvidenet/isp.pth -------------------------------------------------------------------------------- /models/rvidenet/isp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | 5 | class ISP(nn.Module): 6 | 7 | def __init__(self): 8 | super(ISP, self).__init__() 9 | 10 | self.conv1_1 = nn.Conv2d(4, 32, kernel_size=3, stride=1, padding=1) 11 | self.conv1_2 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1) 12 | self.pool1 = nn.MaxPool2d(kernel_size=2) 13 | 14 | self.conv2_1 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1) 15 | self.conv2_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) 16 | self.pool2 = nn.MaxPool2d(kernel_size=2) 17 | 18 | self.conv3_1 = nn.Conv2d(64, 128, 
kernel_size=3, stride=1, padding=1) 19 | self.conv3_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) 20 | 21 | self.upv4 = nn.ConvTranspose2d(128, 64, 2, stride=2) 22 | self.conv4_1 = nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1) 23 | self.conv4_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) 24 | 25 | self.upv5 = nn.ConvTranspose2d(64, 32, 2, stride=2) 26 | self.conv5_1 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1) 27 | self.conv5_2 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1) 28 | 29 | self.conv6_1 = nn.Conv2d(32, 12, kernel_size=1, stride=1) 30 | 31 | def forward(self, x): 32 | conv1 = self.lrelu(self.conv1_1(x)) 33 | conv1 = self.lrelu(self.conv1_2(conv1)) 34 | pool1 = self.pool1(conv1) 35 | 36 | conv2 = self.lrelu(self.conv2_1(pool1)) 37 | conv2 = self.lrelu(self.conv2_2(conv2)) 38 | pool2 = self.pool1(conv2) 39 | 40 | conv3 = self.lrelu(self.conv3_1(pool2)) 41 | conv3 = self.lrelu(self.conv3_2(conv3)) 42 | 43 | up4 = self.upv4(conv3) 44 | up4 = torch.cat([up4, conv2], 1) 45 | conv4 = self.lrelu(self.conv4_1(up4)) 46 | conv4 = self.lrelu(self.conv4_2(conv4)) 47 | 48 | up5 = self.upv5(conv4) 49 | up5 = torch.cat([up5, conv1], 1) 50 | conv5 = self.lrelu(self.conv5_1(up5)) 51 | conv5 = self.lrelu(self.conv5_2(conv5)) 52 | 53 | conv6 = self.conv6_1(conv5) 54 | out = nn.functional.pixel_shuffle(conv6, 2) 55 | return out 56 | 57 | def _initialize_weights(self): 58 | for m in self.modules(): 59 | if isinstance(m, nn.Conv2d): 60 | m.weight.data.normal_(0.0, 0.02) 61 | if m.bias is not None: 62 | m.bias.data.normal_(0.0, 0.02) 63 | if isinstance(m, nn.ConvTranspose2d): 64 | m.weight.data.normal_(0.0, 0.02) 65 | 66 | def lrelu(self, x): 67 | outt = torch.max(0.2 * x, x) 68 | return outt 69 | 70 | 71 | def initialize_weights(net_l, scale=1): 72 | if not isinstance(net_l, list): 73 | net_l = [net_l] 74 | for net in net_l: 75 | for m in net.modules(): 76 | if isinstance(m, nn.Conv2d): 77 | init.kaiming_normal_(m.weight, 
a=0, mode='fan_in') 78 | m.weight.data *= scale 79 | if m.bias is not None: 80 | m.bias.data.zero_() 81 | elif isinstance(m, nn.Linear): 82 | init.kaiming_normal_(m.weight, a=0, mode='fan_in') 83 | m.weight.data *= scale 84 | if m.bias is not None: 85 | m.bias.data.zero_() 86 | elif isinstance(m, nn.BatchNorm2d): 87 | init.constant_(m.weight, 1) 88 | init.constant_(m.bias.data, 0.0) 89 | 90 | 91 | def make_layer(block, n_layers): 92 | layers = [] 93 | for _ in range(n_layers): 94 | layers.append(block()) 95 | return nn.Sequential(*layers) -------------------------------------------------------------------------------- /pytorch_pwc/correlation/correlation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import torch 4 | 5 | import cupy 6 | import re 7 | 8 | kernel_Correlation_rearrange = ''' 9 | extern "C" __global__ void kernel_Correlation_rearrange( 10 | const int n, 11 | const float* input, 12 | float* output 13 | ) { 14 | int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; 15 | 16 | if (intIndex >= n) { 17 | return; 18 | } 19 | 20 | int intSample = blockIdx.z; 21 | int intChannel = blockIdx.y; 22 | 23 | float fltValue = input[(((intSample * SIZE_1(input)) + intChannel) * SIZE_2(input) * SIZE_3(input)) + intIndex]; 24 | 25 | __syncthreads(); 26 | 27 | int intPaddedY = (intIndex / SIZE_3(input)) + 4; 28 | int intPaddedX = (intIndex % SIZE_3(input)) + 4; 29 | int intRearrange = ((SIZE_3(input) + 8) * intPaddedY) + intPaddedX; 30 | 31 | output[(((intSample * SIZE_1(output) * SIZE_2(output)) + intRearrange) * SIZE_1(input)) + intChannel] = fltValue; 32 | } 33 | ''' 34 | 35 | kernel_Correlation_updateOutput = ''' 36 | extern "C" __global__ void kernel_Correlation_updateOutput( 37 | const int n, 38 | const float* rbot0, 39 | const float* rbot1, 40 | float* top 41 | ) { 42 | extern __shared__ char patch_data_char[]; 43 | 44 | float *patch_data = (float *)patch_data_char; 45 | 46 | // First (upper 
left) position of kernel upper-left corner in current center position of neighborhood in image 1 47 | int x1 = blockIdx.x + 4; 48 | int y1 = blockIdx.y + 4; 49 | int item = blockIdx.z; 50 | int ch_off = threadIdx.x; 51 | 52 | // Load 3D patch into shared shared memory 53 | for (int j = 0; j < 1; j++) { // HEIGHT 54 | for (int i = 0; i < 1; i++) { // WIDTH 55 | int ji_off = (j + i) * SIZE_3(rbot0); 56 | for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS 57 | int idx1 = ((item * SIZE_1(rbot0) + y1+j) * SIZE_2(rbot0) + x1+i) * SIZE_3(rbot0) + ch; 58 | int idxPatchData = ji_off + ch; 59 | patch_data[idxPatchData] = rbot0[idx1]; 60 | } 61 | } 62 | } 63 | 64 | __syncthreads(); 65 | 66 | __shared__ float sum[32]; 67 | 68 | // Compute correlation 69 | for (int top_channel = 0; top_channel < SIZE_1(top); top_channel++) { 70 | sum[ch_off] = 0; 71 | 72 | int s2o = top_channel % 9 - 4; 73 | int s2p = top_channel / 9 - 4; 74 | 75 | for (int j = 0; j < 1; j++) { // HEIGHT 76 | for (int i = 0; i < 1; i++) { // WIDTH 77 | int ji_off = (j + i) * SIZE_3(rbot0); 78 | for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS 79 | int x2 = x1 + s2o; 80 | int y2 = y1 + s2p; 81 | 82 | int idxPatchData = ji_off + ch; 83 | int idx2 = ((item * SIZE_1(rbot0) + y2+j) * SIZE_2(rbot0) + x2+i) * SIZE_3(rbot0) + ch; 84 | 85 | sum[ch_off] += patch_data[idxPatchData] * rbot1[idx2]; 86 | } 87 | } 88 | } 89 | 90 | __syncthreads(); 91 | 92 | if (ch_off == 0) { 93 | float total_sum = 0; 94 | for (int idx = 0; idx < 32; idx++) { 95 | total_sum += sum[idx]; 96 | } 97 | const int sumelems = SIZE_3(rbot0); 98 | const int index = ((top_channel*SIZE_2(top) + blockIdx.y)*SIZE_3(top))+blockIdx.x; 99 | top[index + item*SIZE_1(top)*SIZE_2(top)*SIZE_3(top)] = total_sum / (float)sumelems; 100 | } 101 | } 102 | } 103 | ''' 104 | 105 | kernel_Correlation_updateGradFirst = ''' 106 | #define ROUND_OFF 50000 107 | 108 | extern "C" __global__ void kernel_Correlation_updateGradFirst( 109 | const 
int n, 110 | const int intSample, 111 | const float* rbot0, 112 | const float* rbot1, 113 | const float* gradOutput, 114 | float* gradFirst, 115 | float* gradSecond 116 | ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { 117 | int n = intIndex % SIZE_1(gradFirst); // channels 118 | int l = (intIndex / SIZE_1(gradFirst)) % SIZE_3(gradFirst) + 4; // w-pos 119 | int m = (intIndex / SIZE_1(gradFirst) / SIZE_3(gradFirst)) % SIZE_2(gradFirst) + 4; // h-pos 120 | 121 | // round_off is a trick to enable integer division with ceil, even for negative numbers 122 | // We use a large offset, for the inner part not to become negative. 123 | const int round_off = ROUND_OFF; 124 | const int round_off_s1 = round_off; 125 | 126 | // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior: 127 | int xmin = (l - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4) 128 | int ymin = (m - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4) 129 | 130 | // Same here: 131 | int xmax = (l - 4 + round_off_s1) - round_off; // floor (l - 4) 132 | int ymax = (m - 4 + round_off_s1) - round_off; // floor (m - 4) 133 | 134 | float sum = 0; 135 | if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) { 136 | xmin = max(0,xmin); 137 | xmax = min(SIZE_3(gradOutput)-1,xmax); 138 | 139 | ymin = max(0,ymin); 140 | ymax = min(SIZE_2(gradOutput)-1,ymax); 141 | 142 | for (int p = -4; p <= 4; p++) { 143 | for (int o = -4; o <= 4; o++) { 144 | // Get rbot1 data: 145 | int s2o = o; 146 | int s2p = p; 147 | int idxbot1 = ((intSample * SIZE_1(rbot0) + (m+s2p)) * SIZE_2(rbot0) + (l+s2o)) * SIZE_3(rbot0) + n; 148 | float bot1tmp = rbot1[idxbot1]; // rbot1[l+s2o,m+s2p,n] 149 | 150 | // Index offset for gradOutput in following loops: 151 | int op = (p+4) * 9 + (o+4); // index[o,p] 152 | int idxopoffset = (intSample * SIZE_1(gradOutput) + op); 153 | 
154 | for (int y = ymin; y <= ymax; y++) { 155 | for (int x = xmin; x <= xmax; x++) { 156 | int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p] 157 | sum += gradOutput[idxgradOutput] * bot1tmp; 158 | } 159 | } 160 | } 161 | } 162 | } 163 | const int sumelems = SIZE_1(gradFirst); 164 | const int bot0index = ((n * SIZE_2(gradFirst)) + (m-4)) * SIZE_3(gradFirst) + (l-4); 165 | gradFirst[bot0index + intSample*SIZE_1(gradFirst)*SIZE_2(gradFirst)*SIZE_3(gradFirst)] = sum / (float)sumelems; 166 | } } 167 | ''' 168 | 169 | kernel_Correlation_updateGradSecond = ''' 170 | #define ROUND_OFF 50000 171 | 172 | extern "C" __global__ void kernel_Correlation_updateGradSecond( 173 | const int n, 174 | const int intSample, 175 | const float* rbot0, 176 | const float* rbot1, 177 | const float* gradOutput, 178 | float* gradFirst, 179 | float* gradSecond 180 | ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { 181 | int n = intIndex % SIZE_1(gradSecond); // channels 182 | int l = (intIndex / SIZE_1(gradSecond)) % SIZE_3(gradSecond) + 4; // w-pos 183 | int m = (intIndex / SIZE_1(gradSecond) / SIZE_3(gradSecond)) % SIZE_2(gradSecond) + 4; // h-pos 184 | 185 | // round_off is a trick to enable integer division with ceil, even for negative numbers 186 | // We use a large offset, for the inner part not to become negative. 
187 | const int round_off = ROUND_OFF; 188 | const int round_off_s1 = round_off; 189 | 190 | float sum = 0; 191 | for (int p = -4; p <= 4; p++) { 192 | for (int o = -4; o <= 4; o++) { 193 | int s2o = o; 194 | int s2p = p; 195 | 196 | //Get X,Y ranges and clamp 197 | // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior: 198 | int xmin = (l - 4 - s2o + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o) 199 | int ymin = (m - 4 - s2p + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o) 200 | 201 | // Same here: 202 | int xmax = (l - 4 - s2o + round_off_s1) - round_off; // floor (l - 4 - s2o) 203 | int ymax = (m - 4 - s2p + round_off_s1) - round_off; // floor (m - 4 - s2p) 204 | 205 | if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) { 206 | xmin = max(0,xmin); 207 | xmax = min(SIZE_3(gradOutput)-1,xmax); 208 | 209 | ymin = max(0,ymin); 210 | ymax = min(SIZE_2(gradOutput)-1,ymax); 211 | 212 | // Get rbot0 data: 213 | int idxbot0 = ((intSample * SIZE_1(rbot0) + (m-s2p)) * SIZE_2(rbot0) + (l-s2o)) * SIZE_3(rbot0) + n; 214 | float bot0tmp = rbot0[idxbot0]; // rbot1[l+s2o,m+s2p,n] 215 | 216 | // Index offset for gradOutput in following loops: 217 | int op = (p+4) * 9 + (o+4); // index[o,p] 218 | int idxopoffset = (intSample * SIZE_1(gradOutput) + op); 219 | 220 | for (int y = ymin; y <= ymax; y++) { 221 | for (int x = xmin; x <= xmax; x++) { 222 | int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p] 223 | sum += gradOutput[idxgradOutput] * bot0tmp; 224 | } 225 | } 226 | } 227 | } 228 | } 229 | const int sumelems = SIZE_1(gradSecond); 230 | const int bot1index = ((n * SIZE_2(gradSecond)) + (m-4)) * SIZE_3(gradSecond) + (l-4); 231 | gradSecond[bot1index + intSample*SIZE_1(gradSecond)*SIZE_2(gradSecond)*SIZE_3(gradSecond)] = sum / (float)sumelems; 232 | } } 233 | ''' 234 | 235 | 236 | def 
cupy_kernel(strFunction, objVariables): 237 | strKernel = globals()[strFunction] 238 | 239 | while True: 240 | objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) 241 | 242 | if objMatch is None: 243 | break 244 | # end 245 | 246 | intArg = int(objMatch.group(2)) 247 | 248 | strTensor = objMatch.group(4) 249 | intSizes = objVariables[strTensor].size() 250 | 251 | strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg])) 252 | # end 253 | 254 | while True: 255 | objMatch = re.search('(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel) 256 | 257 | if objMatch is None: 258 | break 259 | # end 260 | 261 | intArgs = int(objMatch.group(2)) 262 | strArgs = objMatch.group(4).split(',') 263 | 264 | strTensor = strArgs[0] 265 | intStrides = objVariables[strTensor].stride() 266 | strIndex = ['((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str( 267 | intStrides[intArg]) + ')' for intArg in range(intArgs)] 268 | 269 | strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']') 270 | # end 271 | 272 | return strKernel 273 | 274 | 275 | # end 276 | 277 | @cupy.memoize(for_each_device=True) 278 | def cupy_launch(strFunction, strKernel): 279 | return cupy.cuda.compile_with_cache(strKernel).get_function(strFunction) 280 | 281 | 282 | # end 283 | 284 | class _FunctionCorrelation(torch.autograd.Function): 285 | @staticmethod 286 | def forward(self, first, second): 287 | rbot0 = first.new_zeros([first.shape[0], first.shape[2] + 8, first.shape[3] + 8, first.shape[1]]) 288 | rbot1 = first.new_zeros([first.shape[0], first.shape[2] + 8, first.shape[3] + 8, first.shape[1]]) 289 | 290 | self.save_for_backward(first, second, rbot0, rbot1) 291 | 292 | assert (first.is_contiguous() == True) 293 | assert (second.is_contiguous() == True) 294 | 295 | output = first.new_zeros([first.shape[0], 81, first.shape[2], first.shape[3]]) 296 | 297 | if first.is_cuda == True: 298 | n = first.shape[2] * first.shape[3] 
299 | cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', { 300 | 'input': first, 301 | 'output': rbot0 302 | }))( 303 | grid=tuple([int((n + 16 - 1) / 16), first.shape[1], first.shape[0]]), 304 | block=tuple([16, 1, 1]), 305 | args=[n, first.data_ptr(), rbot0.data_ptr()] 306 | ) 307 | 308 | n = second.shape[2] * second.shape[3] 309 | cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', { 310 | 'input': second, 311 | 'output': rbot1 312 | }))( 313 | grid=tuple([int((n + 16 - 1) / 16), second.shape[1], second.shape[0]]), 314 | block=tuple([16, 1, 1]), 315 | args=[n, second.data_ptr(), rbot1.data_ptr()] 316 | ) 317 | 318 | n = output.shape[1] * output.shape[2] * output.shape[3] 319 | cupy_launch('kernel_Correlation_updateOutput', cupy_kernel('kernel_Correlation_updateOutput', { 320 | 'rbot0': rbot0, 321 | 'rbot1': rbot1, 322 | 'top': output 323 | }))( 324 | grid=tuple([output.shape[3], output.shape[2], output.shape[0]]), 325 | block=tuple([32, 1, 1]), 326 | shared_mem=first.shape[1] * 4, 327 | args=[n, rbot0.data_ptr(), rbot1.data_ptr(), output.data_ptr()] 328 | ) 329 | 330 | elif first.is_cuda == False: 331 | raise NotImplementedError() 332 | 333 | # end 334 | 335 | return output 336 | 337 | # end 338 | 339 | @staticmethod 340 | def backward(self, gradOutput): 341 | first, second, rbot0, rbot1 = self.saved_tensors 342 | 343 | assert (gradOutput.is_contiguous() == True) 344 | 345 | gradFirst = first.new_zeros([first.shape[0], first.shape[1], first.shape[2], first.shape[3]]) if \ 346 | self.needs_input_grad[0] == True else None 347 | gradSecond = first.new_zeros([first.shape[0], first.shape[1], first.shape[2], first.shape[3]]) if \ 348 | self.needs_input_grad[1] == True else None 349 | 350 | if first.is_cuda == True: 351 | if gradFirst is not None: 352 | for intSample in range(first.shape[0]): 353 | n = first.shape[1] * first.shape[2] * first.shape[3] 354 | 
cupy_launch('kernel_Correlation_updateGradFirst', 355 | cupy_kernel('kernel_Correlation_updateGradFirst', { 356 | 'rbot0': rbot0, 357 | 'rbot1': rbot1, 358 | 'gradOutput': gradOutput, 359 | 'gradFirst': gradFirst, 360 | 'gradSecond': None 361 | }))( 362 | grid=tuple([int((n + 512 - 1) / 512), 1, 1]), 363 | block=tuple([512, 1, 1]), 364 | args=[n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), 365 | gradFirst.data_ptr(), None] 366 | ) 367 | # end 368 | # end 369 | 370 | if gradSecond is not None: 371 | for intSample in range(first.shape[0]): 372 | n = first.shape[1] * first.shape[2] * first.shape[3] 373 | cupy_launch('kernel_Correlation_updateGradSecond', 374 | cupy_kernel('kernel_Correlation_updateGradSecond', { 375 | 'rbot0': rbot0, 376 | 'rbot1': rbot1, 377 | 'gradOutput': gradOutput, 378 | 'gradFirst': None, 379 | 'gradSecond': gradSecond 380 | }))( 381 | grid=tuple([int((n + 512 - 1) / 512), 1, 1]), 382 | block=tuple([512, 1, 1]), 383 | args=[n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), None, 384 | gradSecond.data_ptr()] 385 | ) 386 | # end 387 | # end 388 | 389 | elif first.is_cuda == False: 390 | raise NotImplementedError() 391 | 392 | # end 393 | 394 | return gradFirst, gradSecond 395 | 396 | 397 | # end 398 | # end 399 | 400 | def FunctionCorrelation(tenFirst, tenSecond): 401 | return _FunctionCorrelation.apply(tenFirst, tenSecond) 402 | 403 | 404 | # end 405 | 406 | class ModuleCorrelation(torch.nn.Module): 407 | def __init__(self): 408 | super(ModuleCorrelation, self).__init__() 409 | 410 | # end 411 | 412 | def forward(self, tenFirst, tenSecond): 413 | return _FunctionCorrelation.apply(tenFirst, tenSecond) 414 | # end 415 | # end -------------------------------------------------------------------------------- /pytorch_pwc/extract_flow.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | # im1_torch, im2_torch in shape (N, C, H, W) 5 
| def extract_flow_torch(model, im1_torch, im2_torch): 6 | # Resize both frames so new_H and new_W are multiples of 64 (the PWC-Net extractor halves the resolution six times); the flow is resized back to (H, W) below. 7 | assert im1_torch.shape == im2_torch.shape 8 | N, C, H, W = im1_torch.shape 9 | device = im1_torch.device # NOTE(review): unused below 10 | new_H = int(math.floor(math.ceil(H / 64.0) * 64.0)) 11 | new_W = int(math.floor(math.ceil(W / 64.0) * 64.0)) 12 | im1_torch = torch.nn.functional.interpolate(input=im1_torch, size=(new_H, new_W), mode='bilinear', 13 | align_corners=False) 14 | im2_torch = torch.nn.functional.interpolate(input=im2_torch, size=(new_H, new_W), mode='bilinear', 15 | align_corners=False) 16 | model.eval() # NOTE: switches the caller's model to eval mode as a side effect 17 | with torch.no_grad(): 18 | flo12 = model(im1_torch, im2_torch) 19 | flo12 = 20.0 * torch.nn.functional.interpolate(input=flo12, size=(H, W), mode='bilinear', # 20.0 presumably undoes the network's internal flow scaling (cf. fltBackwarp in pwc.py) -- TODO confirm 20 | align_corners=False) 21 | flo12[:, 0, :, :] *= float(W) / float(new_W) # rescale horizontal flow from the resized width new_W back to W 22 | flo12[:, 1, :, :] *= float(H) / float(new_H) # rescale vertical flow from the resized height new_H back to H 23 | return flo12 24 | 25 | # im1_np, im2_np in shape (C, H, W); both are moved to the 'cuda' device, and the flow comes back as a numpy array with the batch dim squeezed out 26 | def extract_flow_np(model, im1_np, im2_np): 27 | im1_torch = torch.from_numpy(im1_np).unsqueeze(0).to(torch.device('cuda')) 28 | im2_torch = torch.from_numpy(im2_np).unsqueeze(0).to(torch.device('cuda')) 29 | flo12_torch = extract_flow_torch(model, im1_torch, im2_torch) 30 | flo12_np = flo12_torch.detach().cpu().squeeze(0).numpy() 31 | return flo12_np 32 | 33 | -------------------------------------------------------------------------------- /pytorch_pwc/pwc.py: -------------------------------------------------------------------------------- 1 | from .correlation import correlation # the custom cost volume layer 2 | import torch 3 | 4 | ########################################################## 5 | 6 | assert(int(str('').join(torch.__version__.split('.')[0:2])) >= 13) # requires at least pytorch version 1.3.0 7 | 8 | # torch.set_grad_enabled(False) 9 | 10 | torch.backends.cudnn.enabled = True # make sure to use cudnn for computational performance 11 | 12 | arguments_strModel = 'default' 13 | 14 | backwarp_tenGrid = 
{} 15 | backwarp_tenPartial = {} 16 | 17 | def backwarp(tenInput, tenFlow): 18 | if (str(tenFlow.shape)+str(tenFlow.device)) not in backwarp_tenGrid: 19 | tenHor = torch.linspace(-1.0 + (1.0 / tenFlow.shape[3]), 1.0 - (1.0 / tenFlow.shape[3]), tenFlow.shape[3]).view(1, 1, 1, -1).expand(-1, -1, tenFlow.shape[2], -1) 20 | tenVer = torch.linspace(-1.0 + (1.0 / tenFlow.shape[2]), 1.0 - (1.0 / tenFlow.shape[2]), tenFlow.shape[2]).view(1, 1, -1, 1).expand(-1, -1, -1, tenFlow.shape[3]) 21 | 22 | backwarp_tenGrid[str(tenFlow.shape) + str(tenFlow.device)] = torch.cat([ tenHor, tenVer ], 1).to(tenFlow.device) 23 | # end 24 | 25 | if (str(tenFlow.shape)+str(tenFlow.device)) not in backwarp_tenPartial: 26 | backwarp_tenPartial[str(tenFlow.shape)+str(tenFlow.device)] = tenFlow.new_ones([ tenFlow.shape[0], 1, tenFlow.shape[2], tenFlow.shape[3] ]) 27 | # end 28 | 29 | tenFlow = torch.cat([ tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0) ], 1) 30 | tenInput = torch.cat([ tenInput, backwarp_tenPartial[str(tenFlow.shape)+str(tenFlow.device)] ], 1) 31 | 32 | tenOutput = torch.nn.functional.grid_sample(input=tenInput, grid=(backwarp_tenGrid[str(tenFlow.shape) + str(tenFlow.device)] + tenFlow).permute(0, 2, 3, 1), mode='bilinear', padding_mode='zeros', align_corners=False) 33 | 34 | tenMask = tenOutput[:, -1:, :, :]; tenMask[tenMask > 0.999] = 1.0; tenMask[tenMask < 1.0] = 0.0 35 | 36 | return tenOutput[:, :-1, :, :] * tenMask 37 | # end 38 | 39 | ########################################################## 40 | 41 | class PWCNet(torch.nn.Module): 42 | def __init__(self): 43 | super(PWCNet, self).__init__() 44 | 45 | class Extractor(torch.nn.Module): 46 | def __init__(self): 47 | super(Extractor, self).__init__() 48 | 49 | self.netOne = torch.nn.Sequential( 50 | torch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=2, padding=1), 51 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 52 | 
torch.nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1), 53 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 54 | torch.nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1), 55 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 56 | ) 57 | 58 | self.netTwo = torch.nn.Sequential( 59 | torch.nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1), 60 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 61 | torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1), 62 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 63 | torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1), 64 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 65 | ) 66 | 67 | self.netThr = torch.nn.Sequential( 68 | torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1), 69 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 70 | torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), 71 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 72 | torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), 73 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 74 | ) 75 | 76 | self.netFou = torch.nn.Sequential( 77 | torch.nn.Conv2d(in_channels=64, out_channels=96, kernel_size=3, stride=2, padding=1), 78 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 79 | torch.nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1), 80 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 81 | torch.nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1), 82 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 83 | ) 84 | 85 | self.netFiv = torch.nn.Sequential( 86 | torch.nn.Conv2d(in_channels=96, out_channels=128, kernel_size=3, stride=2, padding=1), 87 | torch.nn.LeakyReLU(inplace=False, 
negative_slope=0.1), 88 | torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), 89 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 90 | torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), 91 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 92 | ) 93 | 94 | self.netSix = torch.nn.Sequential( 95 | torch.nn.Conv2d(in_channels=128, out_channels=196, kernel_size=3, stride=2, padding=1), 96 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 97 | torch.nn.Conv2d(in_channels=196, out_channels=196, kernel_size=3, stride=1, padding=1), 98 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 99 | torch.nn.Conv2d(in_channels=196, out_channels=196, kernel_size=3, stride=1, padding=1), 100 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 101 | ) 102 | # end 103 | 104 | def forward(self, tenInput): 105 | tenOne = self.netOne(tenInput) 106 | tenTwo = self.netTwo(tenOne) 107 | tenThr = self.netThr(tenTwo) 108 | tenFou = self.netFou(tenThr) 109 | tenFiv = self.netFiv(tenFou) 110 | tenSix = self.netSix(tenFiv) 111 | 112 | return [ tenOne, tenTwo, tenThr, tenFou, tenFiv, tenSix ] 113 | # end 114 | # end 115 | 116 | class Decoder(torch.nn.Module): 117 | def __init__(self, intLevel): 118 | super(Decoder, self).__init__() 119 | 120 | intPrevious = [ None, None, 81 + 32 + 2 + 2, 81 + 64 + 2 + 2, 81 + 96 + 2 + 2, 81 + 128 + 2 + 2, 81, None ][intLevel + 1] 121 | intCurrent = [ None, None, 81 + 32 + 2 + 2, 81 + 64 + 2 + 2, 81 + 96 + 2 + 2, 81 + 128 + 2 + 2, 81, None ][intLevel + 0] 122 | 123 | if intLevel < 6: self.netUpflow = torch.nn.ConvTranspose2d(in_channels=2, out_channels=2, kernel_size=4, stride=2, padding=1) 124 | if intLevel < 6: self.netUpfeat = torch.nn.ConvTranspose2d(in_channels=intPrevious + 128 + 128 + 96 + 64 + 32, out_channels=2, kernel_size=4, stride=2, padding=1) 125 | if intLevel < 6: self.fltBackwarp = [ None, None, None, 5.0, 2.5, 1.25, 0.625, None ][intLevel + 1] 126 
| 127 | self.netOne = torch.nn.Sequential( 128 | torch.nn.Conv2d(in_channels=intCurrent, out_channels=128, kernel_size=3, stride=1, padding=1), 129 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 130 | ) 131 | 132 | self.netTwo = torch.nn.Sequential( 133 | torch.nn.Conv2d(in_channels=intCurrent + 128, out_channels=128, kernel_size=3, stride=1, padding=1), 134 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 135 | ) 136 | 137 | self.netThr = torch.nn.Sequential( 138 | torch.nn.Conv2d(in_channels=intCurrent + 128 + 128, out_channels=96, kernel_size=3, stride=1, padding=1), 139 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 140 | ) 141 | 142 | self.netFou = torch.nn.Sequential( 143 | torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96, out_channels=64, kernel_size=3, stride=1, padding=1), 144 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 145 | ) 146 | 147 | self.netFiv = torch.nn.Sequential( 148 | torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96 + 64, out_channels=32, kernel_size=3, stride=1, padding=1), 149 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 150 | ) 151 | 152 | self.netSix = torch.nn.Sequential( 153 | torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96 + 64 + 32, out_channels=2, kernel_size=3, stride=1, padding=1) 154 | ) 155 | # end 156 | 157 | def forward(self, tenFirst, tenSecond, objPrevious): 158 | tenFlow = None 159 | tenFeat = None 160 | 161 | if objPrevious is None: 162 | tenFlow = None 163 | tenFeat = None 164 | tenVolume = torch.nn.functional.leaky_relu(input=correlation.FunctionCorrelation(tenFirst=tenFirst, tenSecond=tenSecond), negative_slope=0.1, inplace=False) 165 | 166 | tenFeat = torch.cat([ tenVolume ], 1) 167 | 168 | elif objPrevious is not None: 169 | tenFlow = self.netUpflow(objPrevious['tenFlow']) 170 | tenFeat = self.netUpfeat(objPrevious['tenFeat']) 171 | 172 | tenVolume = torch.nn.functional.leaky_relu(input=correlation.FunctionCorrelation(tenFirst=tenFirst, 
tenSecond=backwarp(tenInput=tenSecond, tenFlow=tenFlow * self.fltBackwarp)), negative_slope=0.1, inplace=False) 173 | 174 | tenFeat = torch.cat([ tenVolume, tenFirst, tenFlow, tenFeat ], 1) 175 | 176 | # end 177 | 178 | tenFeat = torch.cat([ self.netOne(tenFeat), tenFeat ], 1) 179 | tenFeat = torch.cat([ self.netTwo(tenFeat), tenFeat ], 1) 180 | tenFeat = torch.cat([ self.netThr(tenFeat), tenFeat ], 1) 181 | tenFeat = torch.cat([ self.netFou(tenFeat), tenFeat ], 1) 182 | tenFeat = torch.cat([ self.netFiv(tenFeat), tenFeat ], 1) 183 | 184 | tenFlow = self.netSix(tenFeat) 185 | 186 | return { 187 | 'tenFlow': tenFlow, 188 | 'tenFeat': tenFeat 189 | } 190 | # end 191 | # end 192 | 193 | class Refiner(torch.nn.Module): 194 | def __init__(self): 195 | super(Refiner, self).__init__() 196 | 197 | self.netMain = torch.nn.Sequential( 198 | torch.nn.Conv2d(in_channels=81 + 32 + 2 + 2 + 128 + 128 + 96 + 64 + 32, out_channels=128, kernel_size=3, stride=1, padding=1, dilation=1), 199 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 200 | torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=2, dilation=2), 201 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 202 | torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=4, dilation=4), 203 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 204 | torch.nn.Conv2d(in_channels=128, out_channels=96, kernel_size=3, stride=1, padding=8, dilation=8), 205 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 206 | torch.nn.Conv2d(in_channels=96, out_channels=64, kernel_size=3, stride=1, padding=16, dilation=16), 207 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 208 | torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1, dilation=1), 209 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 210 | torch.nn.Conv2d(in_channels=32, out_channels=2, kernel_size=3, stride=1, padding=1, dilation=1) 211 | ) 212 | # end 
213 | 214 | def forward(self, tenInput): 215 | return self.netMain(tenInput) 216 | # end 217 | # end 218 | 219 | self.netExtractor = Extractor() 220 | 221 | self.netTwo = Decoder(2) 222 | self.netThr = Decoder(3) 223 | self.netFou = Decoder(4) 224 | self.netFiv = Decoder(5) 225 | self.netSix = Decoder(6) 226 | 227 | self.netRefiner = Refiner() 228 | 229 | self.load_state_dict({ strKey.replace('module', 'net'): tenWeight for strKey, tenWeight in torch.hub.load_state_dict_from_url(url='http://content.sniklaus.com/github/pytorch-pwc/network-' + arguments_strModel + '.pytorch', file_name='pwc-' + arguments_strModel).items() }) 230 | # end 231 | 232 | def forward(self, tenFirst, tenSecond): 233 | tenFirst = self.netExtractor(tenFirst) 234 | tenSecond = self.netExtractor(tenSecond) 235 | 236 | objEstimate = self.netSix(tenFirst[-1], tenSecond[-1], None) 237 | objEstimate = self.netFiv(tenFirst[-2], tenSecond[-2], objEstimate) 238 | objEstimate = self.netFou(tenFirst[-3], tenSecond[-3], objEstimate) 239 | objEstimate = self.netThr(tenFirst[-4], tenSecond[-4], objEstimate) 240 | objEstimate = self.netTwo(tenFirst[-5], tenSecond[-5], objEstimate) 241 | 242 | return objEstimate['tenFlow'] + self.netRefiner(objEstimate['tenFeat']) -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Unidirectional Video Denoising by Mimicking Backward Recurrent Modules with Look-ahead Forward Ones 2 | This is the source code for our paper "Unidirectional Video Denoising by Mimicking Backward Recurrent Modules with Look-ahead Forward Ones" (ECCV 2022). 3 | ![overview](./imgs/figure_overview.png) 4 | 5 | ## Usage 6 | ### Dependencies 7 | You can create a conda environment with all the dependencies by running 8 | 9 | ```conda env create -f requirements.yaml -n <env_name>``` 10 | 11 | ### Datasets 12 | For synthetic Gaussian noise, 
the [DAVIS-2017-trainval-480p](https://data.vision.ee.ethz.ch/csergi/share/davis/DAVIS-2017-trainval-480p.zip) dataset is used for training, while 13 | [DAVIS-2017-test-dev-480p](https://data.vision.ee.ethz.ch/csergi/share/davis/DAVIS-2017-test-dev-480p.zip) and [Set8](https://www.dropbox.com/sh/20n4cscqkqsfgoj/AABGftyJuJDwuCLGczL-fKvBa/test_sequences?dl=0&subfolder_nav_tracking=1) are used for testing. 14 | For real-world raw noise, the [CRVD](https://github.com/cao-cong/RViDeNet#captured-raw-video-denoising-dataset-crvd-dataset) dataset is used for both training and testing. 15 | 16 | ### Testing 17 | Download the pretrained models from [Google Drive](https://drive.google.com/drive/folders/1A854tOA6_qB14ax3JZ7bb7tLo0UovkyI?usp=sharing) or [Baidu Netdisk](https://pan.baidu.com/s/1YomcegvdtoxVPr96odCo8w?pwd=ogua). 18 | We also provide denoised results (the tractor sequence from DAVIS-2017-test-dev-480p) for visual comparison. 19 | 1. For synthetic Gaussian noise, 20 | ``` 21 | cd test_models 22 | python sRGB_test.py \ 23 | --model_file <model_file> \ 24 | --test_path <test_path> 25 | ``` 26 | 2. For real-world raw noise, 27 | ``` 28 | cd test_models 29 | python CRVD_test.py \ 30 | --model_file <model_file> \ 31 | --test_path <test_path> 32 | ``` 33 | 34 | ### Training 35 | 1. For synthetic Gaussian noise, 36 | ``` 37 | cd train_models 38 | python sRGB_train.py \ 39 | --trainset_dir <trainset_dir> \ 40 | --valset_dir <valset_dir> \ 41 | --log_dir <log_dir> 42 | ``` 43 | 2. For real-world raw noise, 44 | ``` 45 | cd train_models 46 | python CRVD_train.py \ 47 | --CRVD_dir <CRVD_dir> \ 48 | --log_dir <log_dir> 49 | ``` 50 | 3.
For distributed training with synthetic Gaussian noise, 51 | ``` 52 | cd train_models 53 | python -m torch.distributed.launch --nproc_per_node=4 sRGB_train_distributed.py \ 54 | --trainset_dir <trainset_dir> \ 55 | --valset_dir <valset_dir> \ 56 | --log_dir <log_dir> 57 | ``` 58 | 59 | ## Citation 60 | 61 | If you find our work useful in your research or publication, please cite: 62 | ``` 63 | @inproceedings{li2022unidirectional, 64 | title={Unidirectional Video Denoising by Mimicking Backward Recurrent Modules with Look-ahead Forward Ones}, 65 | author={Li, Junyi and Wu, Xiaohe and Niu, Zhenxing and Zuo, Wangmeng}, 66 | booktitle={ECCV}, 67 | year={2022} 68 | } 69 | ``` 70 | -------------------------------------------------------------------------------- /requirements.yaml: -------------------------------------------------------------------------------- 1 | name: video_denoising 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - cudatoolkit=10.2.89 8 | - cupy=8.2.0 9 | - imageio=2.9.0 10 | - mkl=2019.4 11 | - more-itertools=8.6.0 12 | - opencv=3.4.2 13 | - pip=20.0.2 14 | - pypng=0.0.20 15 | - python=3.7 16 | - pytorch=1.7.0 17 | - scikit-image=0.16.2 18 | - scipy=1.5.2 19 | - torchvision=0.8.1 20 | - pip: 21 | - future==0.18.2 22 | - tensorboardx==2.0 -------------------------------------------------------------------------------- /softmax_splatting/softsplat.py: -------------------------------------------------------------------------------- 1 | # borrowed from https://github.com/sniklaus/softmax-splatting 2 | 3 | import torch 4 | import cupy 5 | import re 6 | 7 | kernel_Softsplat_updateOutput = ''' 8 | extern "C" __global__ void kernel_Softsplat_updateOutput( 9 | const int n, 10 | const float* input, 11 | const float* flow, 12 | float* output 13 | ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { 14 | const int intN = ( intIndex / SIZE_3(output) / SIZE_2(output) / SIZE_1(output) ) % SIZE_0(output); 15 |
const int intC = ( intIndex / SIZE_3(output) / SIZE_2(output) ) % SIZE_1(output); 16 | const int intY = ( intIndex / SIZE_3(output) ) % SIZE_2(output); 17 | const int intX = ( intIndex ) % SIZE_3(output); 18 | 19 | float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX); 20 | float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX); 21 | 22 | int intNorthwestX = (int) (floor(fltOutputX)); 23 | int intNorthwestY = (int) (floor(fltOutputY)); 24 | int intNortheastX = intNorthwestX + 1; 25 | int intNortheastY = intNorthwestY; 26 | int intSouthwestX = intNorthwestX; 27 | int intSouthwestY = intNorthwestY + 1; 28 | int intSoutheastX = intNorthwestX + 1; 29 | int intSoutheastY = intNorthwestY + 1; 30 | 31 | float fltNorthwest = ((float) (intSoutheastX) - fltOutputX) * ((float) (intSoutheastY) - fltOutputY); 32 | float fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (intSouthwestY) - fltOutputY); 33 | float fltSouthwest = ((float) (intNortheastX) - fltOutputX) * (fltOutputY - (float) (intNortheastY)); 34 | float fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * (fltOutputY - (float) (intNorthwestY)); 35 | 36 | if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(output)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(output))) { 37 | atomicAdd(&output[OFFSET_4(output, intN, intC, intNorthwestY, intNorthwestX)], VALUE_4(input, intN, intC, intY, intX) * fltNorthwest); 38 | } 39 | 40 | if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(output)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(output))) { 41 | atomicAdd(&output[OFFSET_4(output, intN, intC, intNortheastY, intNortheastX)], VALUE_4(input, intN, intC, intY, intX) * fltNortheast); 42 | } 43 | 44 | if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(output)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(output))) { 45 | atomicAdd(&output[OFFSET_4(output, intN, intC, intSouthwestY, intSouthwestX)], VALUE_4(input, intN, intC, intY, intX) * fltSouthwest); 46 | } 47 
| 48 | if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(output)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(output))) { 49 | atomicAdd(&output[OFFSET_4(output, intN, intC, intSoutheastY, intSoutheastX)], VALUE_4(input, intN, intC, intY, intX) * fltSoutheast); 50 | } 51 | } } 52 | ''' 53 | 54 | kernel_Softsplat_updateGradInput = ''' 55 | extern "C" __global__ void kernel_Softsplat_updateGradInput( 56 | const int n, 57 | const float* input, 58 | const float* flow, 59 | const float* gradOutput, 60 | float* gradInput, 61 | float* gradFlow 62 | ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { 63 | const int intN = ( intIndex / SIZE_3(gradInput) / SIZE_2(gradInput) / SIZE_1(gradInput) ) % SIZE_0(gradInput); 64 | const int intC = ( intIndex / SIZE_3(gradInput) / SIZE_2(gradInput) ) % SIZE_1(gradInput); 65 | const int intY = ( intIndex / SIZE_3(gradInput) ) % SIZE_2(gradInput); 66 | const int intX = ( intIndex ) % SIZE_3(gradInput); 67 | 68 | float fltGradInput = 0.0; 69 | 70 | float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX); 71 | float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX); 72 | 73 | int intNorthwestX = (int) (floor(fltOutputX)); 74 | int intNorthwestY = (int) (floor(fltOutputY)); 75 | int intNortheastX = intNorthwestX + 1; 76 | int intNortheastY = intNorthwestY; 77 | int intSouthwestX = intNorthwestX; 78 | int intSouthwestY = intNorthwestY + 1; 79 | int intSoutheastX = intNorthwestX + 1; 80 | int intSoutheastY = intNorthwestY + 1; 81 | 82 | float fltNorthwest = ((float) (intSoutheastX) - fltOutputX) * ((float) (intSoutheastY) - fltOutputY); 83 | float fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (intSouthwestY) - fltOutputY); 84 | float fltSouthwest = ((float) (intNortheastX) - fltOutputX) * (fltOutputY - (float) (intNortheastY)); 85 | float fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * (fltOutputY - (float) 
(intNorthwestY)); 86 | 87 | if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(gradOutput)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(gradOutput))) { 88 | fltGradInput += VALUE_4(gradOutput, intN, intC, intNorthwestY, intNorthwestX) * fltNorthwest; 89 | } 90 | 91 | if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(gradOutput)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(gradOutput))) { 92 | fltGradInput += VALUE_4(gradOutput, intN, intC, intNortheastY, intNortheastX) * fltNortheast; 93 | } 94 | 95 | if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(gradOutput)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(gradOutput))) { 96 | fltGradInput += VALUE_4(gradOutput, intN, intC, intSouthwestY, intSouthwestX) * fltSouthwest; 97 | } 98 | 99 | if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(gradOutput)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(gradOutput))) { 100 | fltGradInput += VALUE_4(gradOutput, intN, intC, intSoutheastY, intSoutheastX) * fltSoutheast; 101 | } 102 | 103 | gradInput[intIndex] = fltGradInput; 104 | } } 105 | ''' 106 | 107 | kernel_Softsplat_updateGradFlow = ''' 108 | extern "C" __global__ void kernel_Softsplat_updateGradFlow( 109 | const int n, 110 | const float* input, 111 | const float* flow, 112 | const float* gradOutput, 113 | float* gradInput, 114 | float* gradFlow 115 | ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { 116 | float fltGradFlow = 0.0; 117 | 118 | const int intN = ( intIndex / SIZE_3(gradFlow) / SIZE_2(gradFlow) / SIZE_1(gradFlow) ) % SIZE_0(gradFlow); 119 | const int intC = ( intIndex / SIZE_3(gradFlow) / SIZE_2(gradFlow) ) % SIZE_1(gradFlow); 120 | const int intY = ( intIndex / SIZE_3(gradFlow) ) % SIZE_2(gradFlow); 121 | const int intX = ( intIndex ) % SIZE_3(gradFlow); 122 | 123 | float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX); 124 | float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX); 125 | 
126 | int intNorthwestX = (int) (floor(fltOutputX)); 127 | int intNorthwestY = (int) (floor(fltOutputY)); 128 | int intNortheastX = intNorthwestX + 1; 129 | int intNortheastY = intNorthwestY; 130 | int intSouthwestX = intNorthwestX; 131 | int intSouthwestY = intNorthwestY + 1; 132 | int intSoutheastX = intNorthwestX + 1; 133 | int intSoutheastY = intNorthwestY + 1; 134 | 135 | float fltNorthwest = 0.0; 136 | float fltNortheast = 0.0; 137 | float fltSouthwest = 0.0; 138 | float fltSoutheast = 0.0; 139 | 140 | if (intC == 0) { 141 | fltNorthwest = ((float) (-1.0)) * ((float) (intSoutheastY) - fltOutputY); 142 | fltNortheast = ((float) (+1.0)) * ((float) (intSouthwestY) - fltOutputY); 143 | fltSouthwest = ((float) (-1.0)) * (fltOutputY - (float) (intNortheastY)); 144 | fltSoutheast = ((float) (+1.0)) * (fltOutputY - (float) (intNorthwestY)); 145 | 146 | } else if (intC == 1) { 147 | fltNorthwest = ((float) (intSoutheastX) - fltOutputX) * ((float) (-1.0)); 148 | fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (-1.0)); 149 | fltSouthwest = ((float) (intNortheastX) - fltOutputX) * ((float) (+1.0)); 150 | fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * ((float) (+1.0)); 151 | 152 | } 153 | 154 | for (int intChannel = 0; intChannel < SIZE_1(gradOutput); intChannel += 1) { 155 | float fltInput = VALUE_4(input, intN, intChannel, intY, intX); 156 | 157 | if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(gradOutput)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(gradOutput))) { 158 | fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intNorthwestY, intNorthwestX) * fltNorthwest; 159 | } 160 | 161 | if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(gradOutput)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(gradOutput))) { 162 | fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intNortheastY, intNortheastX) * fltNortheast; 163 | } 164 | 165 | if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(gradOutput)) & 
(intSouthwestY >= 0) & (intSouthwestY < SIZE_2(gradOutput))) { 166 | fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intSouthwestY, intSouthwestX) * fltSouthwest; 167 | } 168 | 169 | if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(gradOutput)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(gradOutput))) { 170 | fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intSoutheastY, intSoutheastX) * fltSoutheast; 171 | } 172 | } 173 | 174 | gradFlow[intIndex] = fltGradFlow; 175 | } } 176 | ''' 177 | 178 | def cupy_kernel(strFunction, objVariables): 179 | strKernel = globals()[strFunction] 180 | 181 | while True: 182 | objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) 183 | 184 | if objMatch is None: 185 | break 186 | # end 187 | 188 | intArg = int(objMatch.group(2)) 189 | 190 | strTensor = objMatch.group(4) 191 | intSizes = objVariables[strTensor].size() 192 | 193 | strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg])) 194 | # end 195 | 196 | while True: 197 | objMatch = re.search('(OFFSET_)([0-4])(\()([^\)]+)(\))', strKernel) 198 | 199 | if objMatch is None: 200 | break 201 | # end 202 | 203 | intArgs = int(objMatch.group(2)) 204 | strArgs = objMatch.group(4).split(',') 205 | 206 | strTensor = strArgs[0] 207 | intStrides = objVariables[strTensor].stride() 208 | strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg]) + ')' for intArg in range(intArgs) ] 209 | 210 | strKernel = strKernel.replace(objMatch.group(0), '(' + str.join('+', strIndex) + ')') 211 | # end 212 | 213 | while True: 214 | objMatch = re.search('(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel) 215 | 216 | if objMatch is None: 217 | break 218 | # end 219 | 220 | intArgs = int(objMatch.group(2)) 221 | strArgs = objMatch.group(4).split(',') 222 | 223 | strTensor = strArgs[0] 224 | intStrides = objVariables[strTensor].stride() 225 | strIndex = [ '((' + strArgs[intArg + 1].replace('{', 
'(').replace('}', ')').strip() + ')*' + str(intStrides[intArg]) + ')' for intArg in range(intArgs) ] 226 | 227 | strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']') 228 | # end 229 | 230 | return strKernel 231 | # end 232 | 233 | @cupy.memoize(for_each_device=True) 234 | def cupy_launch(strFunction, strKernel): 235 | return cupy.cuda.compile_with_cache(strKernel).get_function(strFunction) 236 | # end 237 | 238 | class _FunctionSoftsplat(torch.autograd.Function): 239 | @staticmethod 240 | def forward(self, input, flow): 241 | self.save_for_backward(input, flow) 242 | 243 | intSamples = input.shape[0] 244 | intInputDepth, intInputHeight, intInputWidth = input.shape[1], input.shape[2], input.shape[3] 245 | intFlowDepth, intFlowHeight, intFlowWidth = flow.shape[1], flow.shape[2], flow.shape[3] 246 | 247 | assert(intFlowDepth == 2) 248 | assert(intInputHeight == intFlowHeight) 249 | assert(intInputWidth == intFlowWidth) 250 | 251 | input = input.contiguous(); assert(input.is_cuda == True) 252 | flow = flow.contiguous(); assert(flow.is_cuda == True) 253 | 254 | output = input.new_zeros([ intSamples, intInputDepth, intInputHeight, intInputWidth ]) 255 | 256 | if input.is_cuda == True: 257 | n = output.nelement() 258 | cupy_launch('kernel_Softsplat_updateOutput', cupy_kernel('kernel_Softsplat_updateOutput', { 259 | 'input': input, 260 | 'flow': flow, 261 | 'output': output 262 | }))( 263 | grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), 264 | block=tuple([ 512, 1, 1 ]), 265 | args=[ n, input.data_ptr(), flow.data_ptr(), output.data_ptr() ] 266 | ) 267 | 268 | elif input.is_cuda == False: 269 | raise NotImplementedError() 270 | 271 | # end 272 | 273 | return output 274 | # end 275 | 276 | @staticmethod 277 | def backward(self, gradOutput): 278 | input, flow = self.saved_tensors 279 | 280 | intSamples = input.shape[0] 281 | intInputDepth, intInputHeight, intInputWidth = input.shape[1], input.shape[2], input.shape[3] 282 | 
intFlowDepth, intFlowHeight, intFlowWidth = flow.shape[1], flow.shape[2], flow.shape[3] 283 | 284 | assert(intFlowDepth == 2) 285 | assert(intInputHeight == intFlowHeight) 286 | assert(intInputWidth == intFlowWidth) 287 | 288 | gradOutput = gradOutput.contiguous(); assert(gradOutput.is_cuda == True) 289 | 290 | gradInput = input.new_zeros([ intSamples, intInputDepth, intInputHeight, intInputWidth ]) if self.needs_input_grad[0] == True else None 291 | gradFlow = input.new_zeros([ intSamples, intFlowDepth, intFlowHeight, intFlowWidth ]) if self.needs_input_grad[1] == True else None 292 | 293 | if input.is_cuda == True: 294 | if gradInput is not None: 295 | n = gradInput.nelement() 296 | cupy_launch('kernel_Softsplat_updateGradInput', cupy_kernel('kernel_Softsplat_updateGradInput', { 297 | 'input': input, 298 | 'flow': flow, 299 | 'gradOutput': gradOutput, 300 | 'gradInput': gradInput, 301 | 'gradFlow': gradFlow 302 | }))( 303 | grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), 304 | block=tuple([ 512, 1, 1 ]), 305 | args=[ n, input.data_ptr(), flow.data_ptr(), gradOutput.data_ptr(), gradInput.data_ptr(), None ] 306 | ) 307 | # end 308 | 309 | if gradFlow is not None: 310 | n = gradFlow.nelement() 311 | cupy_launch('kernel_Softsplat_updateGradFlow', cupy_kernel('kernel_Softsplat_updateGradFlow', { 312 | 'input': input, 313 | 'flow': flow, 314 | 'gradOutput': gradOutput, 315 | 'gradInput': gradInput, 316 | 'gradFlow': gradFlow 317 | }))( 318 | grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), 319 | block=tuple([ 512, 1, 1 ]), 320 | args=[ n, input.data_ptr(), flow.data_ptr(), gradOutput.data_ptr(), None, gradFlow.data_ptr() ] 321 | ) 322 | # end 323 | 324 | elif input.is_cuda == False: 325 | raise NotImplementedError() 326 | 327 | # end 328 | 329 | return gradInput, gradFlow 330 | # end 331 | # end 332 | 333 | def FunctionSoftsplat(tenInput, tenFlow, tenMetric, strType): 334 | assert(tenMetric is None or tenMetric.shape[1] == 1) 335 | assert(strType in ['summation', 
'average', 'linear', 'softmax']) 336 | 337 | if strType == 'average': 338 | tenInput = torch.cat([ tenInput, tenInput.new_ones(tenInput.shape[0], 1, tenInput.shape[2], tenInput.shape[3]) ], 1) 339 | 340 | elif strType == 'linear': 341 | tenInput = torch.cat([ tenInput * tenMetric, tenMetric ], 1) 342 | 343 | elif strType == 'softmax': 344 | tenInput = torch.cat([ tenInput * tenMetric.exp(), tenMetric.exp() ], 1) 345 | 346 | # end 347 | 348 | tenOutput = _FunctionSoftsplat.apply(tenInput, tenFlow) 349 | 350 | if strType != 'summation': 351 | tenNormalize = tenOutput[:, -1:, :, :] 352 | 353 | tenNormalize[tenNormalize == 0.0] = 1.0 354 | 355 | tenOutput = tenOutput[:, :-1, :, :] / tenNormalize 356 | # end 357 | 358 | return tenOutput 359 | # end 360 | 361 | class ModuleSoftsplat(torch.nn.Module): 362 | def __init__(self, strType): 363 | super(ModuleSoftsplat, self).__init__() 364 | 365 | self.strType = strType 366 | # end 367 | 368 | def forward(self, tenInput, tenFlow, tenMetric=None): 369 | return FunctionSoftsplat(tenInput, tenFlow, tenMetric, self.strType) 370 | # end 371 | # end 372 | -------------------------------------------------------------------------------- /test_models/CRVD_test.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('..') 3 | import argparse 4 | from datasets import CRVDTestDataset 5 | import numpy as np 6 | import os 7 | from models import ISP, FloRNNRaw 8 | from skimage.measure.simple_metrics import compare_psnr 9 | from skimage.metrics import structural_similarity 10 | import torch 11 | import torch.nn as nn 12 | from utils.io import np2image_bgr 13 | 14 | def raw_ssim(pack1, pack2): 15 | test_raw_ssim = 0 16 | for i in range(4): 17 | test_raw_ssim += structural_similarity(pack1[i], pack2[i], data_range=1.0) 18 | return test_raw_ssim / 4 19 | 20 | def denoise_seq(seqn, a, b, model): 21 | T, C, H, W = seqn.shape 22 | a = a.expand((1, T, 1, H, W)).cuda() 23 | b = 
b.expand((1, T, 1, H, W)).cuda() 24 | seqdn = model(seqn.unsqueeze(0), a, b)[0] 25 | seqdn = torch.clamp(seqdn, 0, 1) 26 | return seqdn 27 | 28 | def main(**args): 29 | dataset_val = CRVDTestDataset(CRVD_path=args['crvd_dir']) 30 | isp = ISP().cuda() 31 | isp.load_state_dict(torch.load(args['isp_path'])['state_dict']) 32 | 33 | if args['model'] == 'FloRNNRaw': 34 | model = FloRNNRaw(img_channels=4, num_resblocks=args['num_resblocks'], forward_count=args['forward_count'], border_ratio=args['border_ratio']) 35 | 36 | state_temp_dict = torch.load(args['model_file'])['state_dict'] 37 | model = nn.DataParallel(model).cuda() 38 | model.load_state_dict(state_temp_dict) 39 | model.eval() 40 | 41 | iso_psnr, iso_ssim = {}, {} 42 | for data in dataset_val: 43 | 44 | # our channels: RGGB, RViDeNet channels: RGBG. we must pass RGBG pack to ISP as it's pretrained by RViDeNet 45 | seq = data['seq'].cuda() 46 | seqn = data['seqn'].cuda() 47 | 48 | with torch.no_grad(): 49 | seqdn = denoise_seq(seqn, data['a'], data['b'], model) 50 | seqn[:, 2:] = torch.flip(seqn[:, 2:], dims=[1]) 51 | seqdn[:, 2:] = torch.flip(seqdn[:, 2:], dims=[1]) 52 | seq[:, 2:] = torch.flip(seq[:, 2:], dims=[1]) 53 | 54 | seq_raw_psnr, seq_srgb_psnr, seq_raw_ssim, seq_srgb_ssim = 0, 0, 0, 0 55 | for i in range(seq.shape[0]): 56 | gt_raw_frame = seq[i].cpu().numpy() 57 | denoised_raw_frame = (np.uint16(seqdn[i].cpu().numpy() * (2 ** 12 - 1 - 240) + 240).astype(np.float32) - 240) / (2 ** 12 - 1 - 240) 58 | with torch.no_grad(): 59 | gt_srgb_frame = np.uint8(np.clip(isp(seq[i:i+1]).cpu().numpy()[0], 0, 1) * 255).astype(np.float32) / 255. 60 | denoised_srgb_frame = np.uint8(np.clip(isp(seqdn[i:i+1]).cpu().numpy()[0], 0, 1) * 255).astype(np.float32) / 255. 
61 | 62 | seq_raw_psnr += compare_psnr(gt_raw_frame, denoised_raw_frame, data_range=1.0) 63 | seq_srgb_psnr += compare_psnr(gt_srgb_frame, denoised_srgb_frame, data_range=1.0) 64 | seq_raw_ssim += raw_ssim(gt_raw_frame, denoised_raw_frame) 65 | seq_srgb_ssim += structural_similarity(np.transpose(gt_srgb_frame, (1, 2, 0)), np.transpose(denoised_srgb_frame, (1, 2, 0)), 66 | data_range=1.0, multichannel=True) 67 | 68 | seq_raw_psnr /= seq.shape[0] 69 | seq_srgb_psnr /= seq.shape[0] 70 | seq_raw_ssim /= seq.shape[0] 71 | seq_srgb_ssim /= seq.shape[0] 72 | 73 | if (str(data['iso'])+'raw') not in iso_psnr.keys(): 74 | iso_psnr[str(data['iso'])+'raw'] = seq_raw_psnr / 5 75 | iso_psnr[str(data['iso'])+'srgb'] = seq_srgb_psnr / 5 76 | iso_ssim[str(data['iso'])+'raw'] = seq_raw_ssim / 5 77 | iso_ssim[str(data['iso'])+'srgb'] = seq_srgb_ssim / 5 78 | else: 79 | iso_psnr[str(data['iso'])+'raw'] += seq_raw_psnr / 5 80 | iso_psnr[str(data['iso']) + 'srgb'] += seq_srgb_psnr / 5 81 | iso_ssim[str(data['iso']) + 'raw'] += seq_raw_ssim / 5 82 | iso_ssim[str(data['iso']) + 'srgb'] += seq_srgb_ssim / 5 83 | 84 | dataset_raw_psnr, dataset_srgb_psnr, dataset_raw_ssim, dataset_srgb_ssim = 0, 0, 0, 0 85 | for iso in [1600, 3200, 6400, 12800, 25600]: 86 | print('iso %d, raw: %6.4f/%6.4f, srgb: %6.4f/%6.4f' % (iso, iso_psnr[str(iso)+'raw'], iso_ssim[str(iso)+'raw'], 87 | iso_psnr[str(iso)+'srgb'], iso_ssim[str(iso)+'srgb'])) 88 | dataset_raw_psnr += iso_psnr[str(iso)+'raw'] 89 | dataset_srgb_psnr += iso_psnr[str(iso)+'srgb'] 90 | dataset_raw_ssim += iso_ssim[str(iso)+'raw'] 91 | dataset_srgb_ssim += iso_ssim[str(iso)+'srgb'] 92 | 93 | print('CRVD, raw: %6.4f/%6.4f, srgb: %6.4f/%6.4f' % (dataset_raw_psnr / 5, dataset_raw_ssim / 5, dataset_srgb_psnr / 5, dataset_srgb_ssim / 5)) 94 | 95 | if __name__ == '__main__': 96 | parser = argparse.ArgumentParser(description="test raw model") 97 | parser.add_argument("--model", type=str, default='FloRNNRaw') # model in ['FloRNNRaw'] 98 | 
parser.add_argument("--num_resblocks", type=int, default=15) 99 | parser.add_argument("--forward_count", type=int, default=3) 100 | parser.add_argument("--border_ratio", type=float, default=0.1) 101 | parser.add_argument("--model_file", type=str, default='/home/nagejacob/Documents/codes/VDN/logs/ours_raw/ckpt_e12.pth') 102 | parser.add_argument("--crvd_dir", type=str, default="/hdd/Documents/datasets/CRVD") 103 | parser.add_argument("--isp_path", type=str, default="../models/rvidenet/isp.pth") 104 | argspar = parser.parse_args() 105 | 106 | print("\n### Testing model ###") 107 | print("> Parameters:") 108 | for p, v in zip(argspar.__dict__.keys(), argspar.__dict__.values()): 109 | print('\t{}: {}'.format(p, v)) 110 | print('\n') 111 | 112 | main(**vars(argspar)) -------------------------------------------------------------------------------- /test_models/sRGB_test.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('..') 3 | import argparse 4 | from datasets import SrgbValDataset 5 | from models import ForwardRNN, BiRNN, FloRNN, BasicVSRPlusPlus 6 | import time 7 | import torch 8 | import torch.nn as nn 9 | from utils.fastdvdnet_utils import fastdvdnet_batch_psnr 10 | from utils.ssim import batch_ssim 11 | 12 | def count_params(model): 13 | params = sum(p.numel() for p in model.parameters()) 14 | print(params / 1000 / 1000) 15 | return params 16 | 17 | def denoise_seq(seqn, noise_std, model): 18 | 19 | # init arrays to handle contiguous frames and related patches 20 | numframes, C, H, W = seqn.shape 21 | 22 | # build noise map from noise std---assuming Gaussian noise 23 | noise_level_map = noise_std.expand((numframes, C, H, W)).cuda() 24 | 25 | with torch.no_grad(): 26 | denframes = model(seqn.unsqueeze(0), noise_level_map.unsqueeze(0)) 27 | 28 | denframes = torch.clamp(denframes.squeeze(0), 0., 1.) 
29 | 30 | # free memory up 31 | del noise_level_map 32 | torch.cuda.empty_cache() 33 | 34 | # convert to appropiate type and return 35 | return denframes 36 | 37 | def test(**args): 38 | test_set = SrgbValDataset(args['test_path'], num_input_frames=args['max_num_fr_per_seq']) 39 | 40 | if args['model'] == 'ForwardRNN': 41 | model = ForwardRNN(img_channels=3, num_resblocks=args['num_resblocks']) 42 | elif args['model'] == 'BiRNN': 43 | model = BiRNN(img_channels=3, num_resblocks=args['num_resblocks']) 44 | elif args['model'] == 'FloRNN': 45 | model = FloRNN(img_channels=3, num_resblocks=args['num_resblocks'], forward_count=args['forward_count'], border_ratio=args['border_ratio']) 46 | elif args['model'] == 'BasicVSRPlusPlus': 47 | model = BasicVSRPlusPlus(img_channels=3, spatial_blocks=6, temporal_blocks=6, num_channels=64) 48 | 49 | state_temp_dict = torch.load(args['model_file'])['state_dict'] 50 | model = nn.DataParallel(model).cuda() 51 | model.load_state_dict(state_temp_dict) 52 | model = model.module 53 | model.eval() 54 | 55 | dataset_psnr, dataset_ssim, seq_count = 0, 0, 0 56 | total_time, total_frames = 0, 0 57 | for data in test_set: 58 | seq = data['seq'] 59 | 60 | # Add noise 61 | torch.manual_seed(0) 62 | noise = torch.empty_like(seq).normal_(mean=0, std=args['noise_sigma']) 63 | seqn = seq + noise 64 | noise_std = torch.FloatTensor([args['noise_sigma']]) 65 | seqn = seqn.contiguous() 66 | 67 | torch.cuda.synchronize() 68 | start_time = time.time() 69 | with torch.no_grad(): 70 | denframes = denoise_seq(seqn, noise_std=noise_std, model=model) 71 | 72 | torch.cuda.synchronize() 73 | total_time += time.time() - start_time 74 | total_frames += seqn.shape[0] 75 | 76 | psnr = fastdvdnet_batch_psnr(denframes, seq, 1.) 77 | ssim = batch_ssim(denframes, seq, 1.) 
78 | dataset_psnr += psnr 79 | dataset_ssim += ssim 80 | seq_count += 1 81 | name = data['name'].split('/')[-2] + '/' + data['name'].split('/')[-1] 82 | print('{0:50}:, PSNR: {1:.4f}dB, SSIM: {2:.4f}'.format(name, psnr, ssim)) 83 | if args['display_time']: 84 | print('frames: %d, time/frame: %6.4f s' % (total_frames, total_time / total_frames)) 85 | 86 | print('sigma %d, PSNR: %6.4f, SSIM: %6.4f' % (int(round(args['noise_sigma'] * 255.)), 87 | dataset_psnr/seq_count, dataset_ssim/seq_count)) 88 | 89 | if __name__ == "__main__": 90 | # Parse arguments 91 | parser = argparse.ArgumentParser(description="test sRGB model") 92 | parser.add_argument("--model", type=str, default='BasicVSRPlusPlus') # model in ['ForwardRNN', 'BiRNN', 'FloRNN', 'BasciVSRPlusPlus'] 93 | parser.add_argument("--num_resblocks", type=int, default=15) 94 | parser.add_argument("--forward_count", type=int, default=3) 95 | parser.add_argument("--border_ratio", type=float, default=0.1) 96 | parser.add_argument("--model_file", type=str, default='/home/nagejacob/Documents/codes/VDN/logs/basicvsr_plusplus/ckpt_e12.pth') 97 | parser.add_argument("--test_path", type=str, default="/hdd/Documents/datasets/Set8") 98 | parser.add_argument("--max_num_fr_per_seq", type=int, default=85) 99 | parser.add_argument("--noise_sigma", type=float, default=20, help='noise level used on test_models set') 100 | parser.add_argument("--display_time", type=bool, default=False) 101 | argspar = parser.parse_args() 102 | # Normalize noises ot [0, 1] 103 | argspar.noise_sigma /= 255. 104 | 105 | 106 | print("\n### Testing model ###") 107 | print("> Parameters:") 108 | for p, v in zip(argspar.__dict__.keys(), argspar.__dict__.values()): 109 | print('\t{}: {}'.format(p, v)) 110 | print('\n') 111 | 112 | for sigma in [10, 20, 30, 40, 50]: 113 | argspar.noise_sigma = sigma / 255. 
114 | print('sigma=%d' % sigma) 115 | dataset_psnr = test(**vars(argspar)) 116 | -------------------------------------------------------------------------------- /train_models/CRVD_train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('..') 3 | import argparse 4 | from datasets import CRVDTrainDataset, CRVDTestDataset 5 | from models import FloRNNRaw 6 | import numpy as np 7 | import os 8 | from skimage.measure.simple_metrics import compare_psnr 9 | import time 10 | import torch 11 | from torch.utils.data import DataLoader 12 | from train_models.base_functions import batch_psnr, resume_training, save_model 13 | from utils.io import log 14 | 15 | torch.backends.cudnn.benchmark = True 16 | 17 | def main(**args): 18 | dataset_train = CRVDTrainDataset(CRVD_path=args['CRVD_dir'], 19 | patch_size=args['patch_size'], 20 | patches_per_epoch=args['patches_per_epoch'], 21 | mirror_seq=args['mirror_seq']) 22 | loader_train = DataLoader(dataset=dataset_train, batch_size=args['batch_size'], num_workers=4, shuffle=True, drop_last=True) 23 | dataset_val = CRVDTestDataset(CRVD_path=args['CRVD_dir']) 24 | 25 | if args['model'] == 'FloRNNRaw': 26 | model = FloRNNRaw(img_channels=4, num_resblocks=args['num_resblocks'], forward_count=args['forward_count'], 27 | border_ratio=args['border_ratio']) 28 | model = torch.nn.DataParallel(model).cuda() 29 | 30 | criterion = torch.nn.MSELoss(reduction='sum').cuda() 31 | optimizer = torch.optim.Adam(model.module.trainable_parameters(), lr=args['lr']) 32 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args['milestones'], gamma=0.1) 33 | 34 | start_epoch = resume_training(args, model, optimizer, scheduler) 35 | for epoch in range(start_epoch, args['epochs']): 36 | start_time = time.time() 37 | 38 | # training 39 | model.train() 40 | for i, data in enumerate(loader_train): 41 | seq = data['seq'].cuda() 42 | N, T, C, H, W = seq.shape 43 | 44 | seqn = 
data['seqn'].cuda() 45 | a = data['a'].expand((N, T, 1, H, W)).cuda() 46 | b = data['b'].expand((N, T, 1, H, W)).cuda() 47 | 48 | seqdn = model(seqn, a, b) 49 | 50 | if args['model'] in ['FloRNNRaw']: 51 | end_index = -1 if (args['forward_count'] == -1) else (-args['forward_count']) 52 | loss = criterion(seq[:, 1:end_index], seqdn[:, 1:end_index]) / (N * 2) 53 | else: 54 | loss = criterion(seq, seqdn) / (N * 2) 55 | 56 | optimizer.zero_grad() 57 | loss.backward() 58 | optimizer.step() 59 | 60 | if (i+1) % args['print_every'] == 0: 61 | train_psnr = torch.mean(batch_psnr(seq, seqdn)).item() 62 | log(args["log_file"], "[epoch {}][{}/{}] loss: {:1.4f} PSNR_train: {:1.4f}\n". \ 63 | format(epoch + 1, i + 1, int(args['patches_per_epoch'] // args['batch_size']), loss.item(), train_psnr)) 64 | 65 | scheduler.step() 66 | 67 | # evaluating 68 | model.eval() 69 | iso_psnr = {} 70 | for data in dataset_val: 71 | seq = data['seq'] 72 | T, C, H, W = seq.shape 73 | 74 | seqn = data['seqn'].cuda() 75 | a = data['a'].expand((T, 1, H, W)).cuda() 76 | b = data['b'].expand((T, 1, H, W)).cuda() 77 | 78 | with torch.no_grad(): 79 | seqdn = torch.clamp(model(seqn.unsqueeze(0), a.unsqueeze(0), b.unsqueeze(0)).squeeze(0), 0., 1.) 
80 | 81 | # calculate psnr the same as RViDeNet 82 | seq_psnr = 0 83 | for i in range(T): 84 | seq_psnr += compare_psnr(seq[i].numpy(), 85 | (np.uint16(seqdn[i].cpu().numpy() * (2 ** 12 - 1 - 240) + 240).astype(np.float32) - 240) / (2 ** 12 - 1 - 240), 86 | data_range=1.0) 87 | seq_psnr /= T 88 | 89 | if str(data['iso']) not in iso_psnr.keys(): 90 | iso_psnr[str(data['iso'])] = seq_psnr 91 | else: 92 | iso_psnr[str(data['iso'])] += seq_psnr 93 | dataset_psnr = 0 94 | for iso in [1600, 3200, 6400, 12800, 25600]: 95 | log(args['log_file'], 'iso %d, %6.4f\n' % (iso, iso_psnr[str(iso)] / 5)) 96 | dataset_psnr += iso_psnr[str(iso)] / 5 97 | dataset_psnr = dataset_psnr / 5 98 | 99 | log(args["log_file"], "\n[epoch %d] PSNR_val: %.4f, %0.2f hour/epoch\n\n" % (epoch + 1, dataset_psnr, (time.time()-start_time)/3600)) 100 | 101 | # save model 102 | save_model(args, model, optimizer, scheduler, epoch + 1) 103 | 104 | 105 | if __name__ == '__main__': 106 | parser = argparse.ArgumentParser(description="Train the denoiser") 107 | 108 | # Model parameters 109 | parser.add_argument("--model", type=str, default='FloRNNRaw') 110 | parser.add_argument("--num_resblocks", type=int, default=15) 111 | parser.add_argument("--forward_count", type=int, default=3) 112 | parser.add_argument("--border_ratio", type=float, default=0.1) 113 | 114 | # Training parameters 115 | parser.add_argument("--batch_size", type=int, default=16) 116 | parser.add_argument("--epochs", "--e", type=int, default=12) 117 | parser.add_argument("--milestones", nargs=1, type=int, default=[11]) 118 | parser.add_argument("--lr", type=float, default=1e-4) 119 | parser.add_argument("--print_every", type=int, default=100) 120 | parser.add_argument("--patch_size", "--p", type=int, default=96, help="Patch size") 121 | parser.add_argument("--patches_per_epoch", "--n", type=int, default=256000, help="Number of patches") 122 | parser.add_argument("--mirror_seq", type=bool, default=True) 123 | 124 | # Paths 125 | 
parser.add_argument("--CRVD_dir", type=str, default='/hdd/Documents/datasets/CRVD') 126 | parser.add_argument("--log_dir", type=str, default="../logs/FloRNNRaw") 127 | argspar = parser.parse_args() 128 | 129 | argspar.log_file = os.path.join(argspar.log_dir, 'log.out') 130 | 131 | if not os.path.exists(argspar.log_dir): 132 | os.makedirs(argspar.log_dir) 133 | log(argspar.log_file, "\n### Training the denoiser ###\n") 134 | log(argspar.log_file, "> Parameters:\n") 135 | for p, v in zip(argspar.__dict__.keys(), argspar.__dict__.values()): 136 | log(argspar.log_file, '\t{}: {}\n'.format(p, v)) 137 | 138 | main(**vars(argspar)) -------------------------------------------------------------------------------- /train_models/base_functions.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import math 3 | import os 4 | import re 5 | import torch 6 | from utils.io import log 7 | 8 | def resume_training(args, model, optimizer, scheduler): 9 | """ Resumes previous training or starts anew 10 | """ 11 | model_files = glob.glob(os.path.join(args['log_dir'], '*.pth')) 12 | 13 | if len(model_files) == 0: 14 | start_epoch = 0 15 | else: 16 | log(args.log_file, "> Resuming previous training\n") 17 | epochs_exist = [] 18 | for model_file in model_files: 19 | result = re.findall('ckpt_e(.*).pth', model_file) 20 | epochs_exist.append(int(result[0])) 21 | max_epoch = max(epochs_exist) 22 | max_epoch_model_file = os.path.join(args['log_dir'], 'ckpt_e%d.pth' % max_epoch) 23 | checkpoint = torch.load(max_epoch_model_file) 24 | model.load_state_dict(checkpoint['state_dict']) 25 | optimizer.load_state_dict(checkpoint['optimizer']) 26 | scheduler.load_state_dict(checkpoint['scheduler']) 27 | 28 | start_epoch = max_epoch 29 | 30 | return start_epoch 31 | 32 | def save_model(args, model, optimizer, scheduler, epoch): 33 | save_dict = { 34 | 'args': args, 35 | 'state_dict': model.state_dict(), 36 | 'optimizer' : optimizer.state_dict(), 37 | 
'scheduler': scheduler.state_dict()} 38 | 39 | torch.save(save_dict, os.path.join(args['log_dir'], 'ckpt_e{}.pth'.format(epoch))) 40 | 41 | # the same as skimage.metrics.peak_signal_noise_ratio 42 | def batch_psnr(a, b): 43 | a = torch.clamp(a, 0, 1) 44 | b = torch.clamp(b, 0, 1) 45 | x = torch.mean((a - b) ** 2, dim=[-3, -2, -1]) 46 | return 20 * torch.log(1 / torch.sqrt(x)) / math.log(10) 47 | -------------------------------------------------------------------------------- /train_models/sRGB_train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('..') 3 | import argparse 4 | from datasets import SrgbTrainDataset, SrgbValDataset 5 | from models import ForwardRNN, BiRNN, FloRNN 6 | import os 7 | import time 8 | import torch 9 | from torch.utils.data import DataLoader 10 | from train_models.base_functions import resume_training, save_model 11 | from utils.fastdvdnet_utils import fastdvdnet_batch_psnr, normalize_augment 12 | from utils.io import log 13 | 14 | torch.backends.cudnn.benchmark = True 15 | 16 | def main(**args): 17 | dataset_train = SrgbTrainDataset(seq_dir=args['trainset_dir'], 18 | train_length=args['train_length'], 19 | patch_size=args['patch_size'], 20 | patches_per_epoch=args['patches_per_epoch'], 21 | image_postfix='jpg', 22 | pin_memory=True) 23 | loader_train = DataLoader(dataset=dataset_train, batch_size=args['batch_size'], num_workers=4, shuffle=True, drop_last=True) 24 | dataset_val = SrgbValDataset(valsetdir=args['valset_dir']) 25 | loader_val = DataLoader(dataset=dataset_val, batch_size=1) 26 | 27 | if args['model'] == 'ForwardRNN': 28 | model = ForwardRNN(img_channels=3, num_resblocks=args['num_resblocks']) 29 | elif args['model'] == 'BiRNN': 30 | model = BiRNN(img_channels=3, num_resblocks=args['num_resblocks']) 31 | elif args['model'] == 'FloRNN': 32 | model = FloRNN(img_channels=3, num_resblocks=args['num_resblocks'], forward_count=args['forward_count'], 33 | 
border_ratio=args['border_ratio']) 34 | model = torch.nn.DataParallel(model).cuda() 35 | 36 | criterion = torch.nn.MSELoss(reduction='sum').cuda() 37 | optimizer = torch.optim.Adam(model.module.trainable_parameters(), lr=args['lr']) 38 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args['milestones'], gamma=0.1) 39 | 40 | start_epoch = resume_training(args, model, optimizer, scheduler) 41 | for epoch in range(start_epoch, args['epochs']): 42 | start_time = time.time() 43 | 44 | # training 45 | model.train() 46 | for i, data in enumerate(loader_train): 47 | seq = data['data'].cuda() 48 | seq = normalize_augment(seq) 49 | 50 | N, T, C, H, W = seq.shape 51 | stdn = torch.empty((N, 1, 1, 1, 1)).cuda().uniform_(args['noise_ival'][0], to=args['noise_ival'][1]) 52 | noise_level_map = stdn.expand_as(seq) 53 | 54 | noise = torch.normal(mean=torch.zeros_like(seq), std=noise_level_map) 55 | seqn = seq + noise 56 | seqdn = model(seqn, noise_level_map) 57 | 58 | if args['model'] in ['FloRNN']: 59 | end_index = -1 if (args['forward_count'] == -1) else (-args['forward_count']) 60 | loss = criterion(seq[:, 1:end_index], seqdn[:, 1:end_index]) / (N * 2) 61 | else: 62 | loss = criterion(seq, seqdn) / (N * 2) 63 | 64 | loss.backward() 65 | optimizer.step() 66 | optimizer.zero_grad() 67 | 68 | if (i+1) % args['print_every'] == 0: 69 | train_psnr = fastdvdnet_batch_psnr(seq, seqdn) 70 | log(args["log_file"], "[epoch {}][{}/{}] loss: {:1.4f} PSNR_train: {:1.4f}\n". 
\ 71 | format(epoch + 1, i + 1, int(args['patches_per_epoch'] // args['batch_size']), loss.item(), train_psnr)) 72 | 73 | scheduler.step() 74 | 75 | # evaluating 76 | model.eval() 77 | psnr_val = 0 78 | for i, data in enumerate(loader_val): 79 | seq = data['seq'].cuda() 80 | 81 | torch.manual_seed(0) 82 | stdn = torch.FloatTensor([args['val_noiseL']]) 83 | noise_level_map = stdn.expand_as(seq) 84 | noise = torch.empty_like(seq).normal_(mean=0, std=args['val_noiseL']) 85 | seqn = seq + noise 86 | 87 | with torch.no_grad(): 88 | seqdn = model(seqn, noise_level_map) 89 | psnr_val += fastdvdnet_batch_psnr(seq, seqdn) 90 | 91 | psnr_val = psnr_val / len(dataset_val) 92 | log(args["log_file"], "\n[epoch %d] PSNR_val: %.4f, %0.2f hour/epoch\n\n" % (epoch + 1, psnr_val, (time.time()-start_time)/3600)) 93 | 94 | # save model 95 | save_model(args, model, optimizer, scheduler, epoch + 1) 96 | 97 | 98 | if __name__ == '__main__': 99 | parser = argparse.ArgumentParser(description="Train the denoiser") 100 | 101 | # Model parameters 102 | parser.add_argument("--model", type=str, default='FloRNN') 103 | parser.add_argument("--num_resblocks", type=int, default=15) 104 | parser.add_argument("--forward_count", type=int, default=3) 105 | parser.add_argument("--border_ratio", type=float, default=0.1) 106 | 107 | # Training parameters 108 | parser.add_argument("--batch_size", type=int, default=8) 109 | parser.add_argument("--epochs", "--e", type=int, default=12) 110 | parser.add_argument("--milestones", nargs=1, type=int, default=[11]) 111 | parser.add_argument("--lr", type=float, default=1e-4) 112 | parser.add_argument("--print_every", type=int, default=100) 113 | parser.add_argument("--noise_ival", nargs=2, type=int, default=[0, 55]) 114 | parser.add_argument("--val_noiseL", type=float, default=20) 115 | parser.add_argument("--patch_size", "--p", type=int, default=96, help="Patch size") 116 | parser.add_argument("--patches_per_epoch", "--n", type=int, default=128000, help="Number of 
patches") 117 | 118 | # Paths 119 | parser.add_argument("--trainset_dir", type=str, default='/hdd/Documents/datasets/DAVIS-2017-trainval-480p') 120 | parser.add_argument("--valset_dir", type=str, default='/hdd/Documents/datasets/Set8') 121 | parser.add_argument("--log_dir", type=str, default="../logs/FloRNN") 122 | argspar = parser.parse_args() 123 | 124 | argspar.log_file = os.path.join(argspar.log_dir, 'log.out') 125 | argspar.train_length = 10 if (argspar.forward_count == -1) else (8 + argspar.forward_count) 126 | 127 | # Normalize noise between [0, 1] 128 | argspar.val_noiseL /= 255. 129 | argspar.noise_ival[0] /= 255. 130 | argspar.noise_ival[1] /= 255. 131 | 132 | if not os.path.exists(argspar.log_dir): 133 | os.makedirs(argspar.log_dir) 134 | log(argspar.log_file, "\n### Training the denoiser ###\n") 135 | log(argspar.log_file, "> Parameters:\n") 136 | for p, v in zip(argspar.__dict__.keys(), argspar.__dict__.values()): 137 | log(argspar.log_file, '\t{}: {}\n'.format(p, v)) 138 | 139 | main(**vars(argspar)) -------------------------------------------------------------------------------- /train_models/sRGB_train_distributed.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('..') 3 | import argparse 4 | from datasets import SrgbTrainDataset, SrgbValDataset 5 | from models import BasicVSRPlusPlus 6 | import os 7 | import time 8 | import torch 9 | from torch.utils.data import DataLoader 10 | from train_models.base_functions import resume_training, save_model 11 | from utils.fastdvdnet_utils import fastdvdnet_batch_psnr, normalize_augment 12 | from utils.io import log 13 | 14 | torch.backends.cudnn.benchmark = True 15 | 16 | def main(**args): 17 | torch.cuda.set_device(args['local_rank']) 18 | torch.distributed.init_process_group(backend='nccl', init_method=args['init_method'], rank=args['local_rank'], world_size=args['world_size']) 19 | 20 | dataset_train = 
SrgbTrainDataset(seq_dir=args['trainset_dir'], 21 | train_length=args['train_length'], 22 | patch_size=args['patch_size'], 23 | patches_per_epoch=args['patches_per_epoch'], 24 | image_postfix='jpg', 25 | pin_memory=True) 26 | sampler_train = torch.utils.data.distributed.DistributedSampler(dataset=dataset_train, shuffle=True) 27 | loader_train = DataLoader(dataset=dataset_train, batch_size=args['batch_size'], sampler=sampler_train, num_workers=4, drop_last=True) 28 | dataset_val = SrgbValDataset(valsetdir=args['valset_dir']) 29 | loader_val = DataLoader(dataset=dataset_val, batch_size=1) 30 | 31 | if args['model'] == 'BasicVSRPlusPlus': 32 | model = BasicVSRPlusPlus(img_channels=3, spatial_blocks=6, temporal_blocks=6, num_channels=64) 33 | model = model.to(torch.device('cuda', args['local_rank'])) 34 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args['local_rank']], output_device=args['local_rank'], find_unused_parameters=True) 35 | 36 | criterion = torch.nn.MSELoss(reduction='sum').to(torch.device('cuda', args['local_rank'])) 37 | optimizer = torch.optim.Adam(model.module.trainable_parameters(), lr=args['lr']) 38 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args['milestones'], gamma=0.1) 39 | 40 | start_epoch = resume_training(args, model, optimizer, scheduler) 41 | for epoch in range(start_epoch, args['epochs']): 42 | sampler_train.set_epoch(epoch) 43 | start_time = time.time() 44 | 45 | # training 46 | model.train() 47 | for i, data in enumerate(loader_train): 48 | seq = data['data'].to(torch.device('cuda', args['local_rank'])) 49 | seq = normalize_augment(seq) 50 | 51 | N, T, C, H, W = seq.shape 52 | stdn = torch.empty((N, 1, 1, 1, 1)).to(torch.device('cuda', args['local_rank'])).uniform_(args['noise_ival'][0], to=args['noise_ival'][1]) 53 | noise_level_map = stdn.expand_as(seq) 54 | 55 | noise = torch.normal(mean=torch.zeros_like(seq), std=noise_level_map) 56 | seqn = seq + noise 57 | seqdn = model(seqn, 
noise_level_map) 58 | 59 | if args['model'] in ['FloRNN']: 60 | end_index = -1 if (args['forward_count'] == -1) else (-args['forward_count']) 61 | loss = criterion(seq[:, 1:end_index], seqdn[:, 1:end_index]) / (N * 2) 62 | else: 63 | loss = criterion(seq, seqdn) / (N * 2) 64 | 65 | loss.backward() 66 | optimizer.step() 67 | optimizer.zero_grad() 68 | 69 | if (i+1) % args['print_every'] == 0 and args['local_rank'] == 0: 70 | train_psnr = fastdvdnet_batch_psnr(seq, seqdn) 71 | log(args["log_file"], "[epoch {}][{}/{}] loss: {:1.4f} PSNR_train: {:1.4f}\n". \ 72 | format(epoch + 1, i + 1, int(args['patches_per_epoch'] // args['batch_size'] // args['world_size']), loss.item(), train_psnr)) 73 | 74 | scheduler.step() 75 | 76 | # evaluating 77 | if args['local_rank'] == 0: 78 | model.eval() 79 | psnr_val = 0 80 | for i, data in enumerate(loader_val): 81 | seq = data['seq'] 82 | 83 | torch.manual_seed(0) 84 | stdn = torch.FloatTensor([args['val_noiseL']]) 85 | noise_level_map = stdn.expand_as(seq) 86 | noise = torch.empty_like(seq).normal_(mean=0, std=args['val_noiseL']) 87 | seqn = seq + noise 88 | 89 | with torch.no_grad(): 90 | seqdn = model(seqn, noise_level_map) 91 | psnr_val += fastdvdnet_batch_psnr(seq, seqdn) 92 | 93 | psnr_val = psnr_val / len(dataset_val) 94 | log(args["log_file"], "\n[epoch %d] PSNR_val: %.4f, %0.2f hour/epoch\n\n" % (epoch + 1, psnr_val, (time.time()-start_time)/3600)) 95 | 96 | # save model 97 | save_model(args, model, optimizer, scheduler, epoch + 1) 98 | 99 | 100 | if __name__ == '__main__': 101 | parser = argparse.ArgumentParser(description="Train the denoiser") 102 | parser.add_argument("--local_rank", type=int, default=0) 103 | 104 | # Model parameters 105 | parser.add_argument("--model", type=str, default='BasicVSRPlusPlus') 106 | 107 | # Training parameters 108 | parser.add_argument("--batch_size", type=int, default=8) 109 | parser.add_argument("--world_size", type=int, default=4) 110 | parser.add_argument("--init_method", 
default='tcp://127.0.0.1:25000') 111 | parser.add_argument("--epochs", "--e", type=int, default=12) 112 | parser.add_argument("--milestones", nargs=1, type=int, default=[11]) 113 | parser.add_argument("--lr", type=float, default=1e-4) 114 | parser.add_argument("--print_every", type=int, default=100) 115 | parser.add_argument("--noise_ival", nargs=2, type=int, default=[0, 55]) 116 | parser.add_argument("--val_noiseL", type=float, default=20) 117 | parser.add_argument("--patch_size", "--p", type=int, default=96, help="Patch size") 118 | parser.add_argument("--patches_per_epoch", "--n", type=int, default=128000, help="Number of patches") 119 | 120 | # Paths 121 | parser.add_argument("--trainset_dir", type=str, default='/mnt/disk10T/Documents/datasets/DAVIS-2017-trainval-480p') 122 | parser.add_argument("--valset_dir", type=str, default='/mnt/disk10T/Documents/datasets/Set8') 123 | parser.add_argument("--log_dir", type=str, default="../logs/BiRNN_plusplus") 124 | argspar = parser.parse_args() 125 | 126 | argspar.log_file = os.path.join(argspar.log_dir, 'log.out') 127 | argspar.train_length = 10 128 | argspar.batch_size = argspar.batch_size // argspar.world_size 129 | 130 | # Normalize noise between [0, 1] 131 | argspar.val_noiseL /= 255. 132 | argspar.noise_ival[0] /= 255. 133 | argspar.noise_ival[1] /= 255. 
134 | 135 | if argspar.local_rank == 0: 136 | if not os.path.exists(argspar.log_dir): 137 | os.makedirs(argspar.log_dir) 138 | log(argspar.log_file, "\n### Training the denoiser ###\n") 139 | log(argspar.log_file, "> Parameters:\n") 140 | for p, v in zip(argspar.__dict__.keys(), argspar.__dict__.values()): 141 | log(argspar.log_file, '\t{}: {}\n'.format(p, v)) 142 | 143 | main(**vars(argspar)) -------------------------------------------------------------------------------- /utils/fastdvdnet_utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import glob 3 | import numpy as np 4 | import os 5 | from random import choices 6 | from skimage.metrics import peak_signal_noise_ratio as compare_psnr 7 | import torch 8 | 9 | IMAGETYPES = ('*.bmp', '*.png', '*.jpg', '*.jpeg', '*.tif') # Supported image types 10 | 11 | def fastdvdnet_batch_psnr(img, imclean, data_range=1.): 12 | r""" 13 | Computes the PSNR along the batch dimension (not pixel-wise) 14 | 15 | Args: 16 | img: a `torch.Tensor` containing the restored image 17 | imclean: a `torch.Tensor` containing the reference image 18 | data_range: The data range of the input image (distance between 19 | minimum and maximum possible values). By default, this is estimated 20 | from the image data-type. 
21 | """ 22 | img_cpu = img.data.cpu().numpy().astype(np.float32) 23 | imgclean = imclean.data.cpu().numpy().astype(np.float32) 24 | psnr = 0 25 | for i in range(img_cpu.shape[0]): 26 | psnr += compare_psnr(imgclean[i, :, :, :], img_cpu[i, :, :, :], 27 | data_range=data_range) 28 | return psnr/img_cpu.shape[0] 29 | 30 | def get_imagenames(seq_dir, pattern=None): 31 | """ Get ordered list of filenames 32 | """ 33 | files = [] 34 | for typ in IMAGETYPES: 35 | files.extend(glob.glob(os.path.join(seq_dir, typ))) 36 | 37 | # filter filenames 38 | if not pattern is None: 39 | ffiltered = [f for f in files if pattern in os.path.split(f)[-1]] 40 | files = ffiltered 41 | del ffiltered 42 | 43 | # sort filenames alphabetically 44 | files.sort(key=lambda f: int(''.join(filter(str.isdigit, f)))) 45 | return files 46 | 47 | def open_sequence(seq_dir, gray_mode, expand_if_needed=False, max_num_fr=85): 48 | r""" Opens a sequence of images and expands it to even sizes if necesary 49 | Args: 50 | fpath: string, path to image sequence 51 | gray_mode: boolean, True indicating if images is to be open are in grayscale mode 52 | expand_if_needed: if True, the spatial dimensions will be expanded if 53 | size is odd 54 | expand_axis0: if True, output will have a fourth dimension 55 | max_num_fr: maximum number of frames to load 56 | Returns: 57 | seq: array of dims [num_frames, C, H, W], C=1 grayscale or C=3 RGB, H and W are even. 58 | The image gets normalized gets normalized to the range [0, 1]. 59 | expanded_h: True if original dim H was odd and image got expanded in this dimension. 60 | expanded_w: True if original dim W was odd and image got expanded in this dimension. 
61 | """ 62 | # Get ordered list of filenames 63 | files = get_imagenames(seq_dir) 64 | 65 | seq_list = [] 66 | # print("\tOpen sequence in folder: ", seq_dir) 67 | for fpath in files[0:max_num_fr]: 68 | 69 | img, expanded_h, expanded_w = open_image(fpath,\ 70 | gray_mode=gray_mode,\ 71 | expand_if_needed=expand_if_needed,\ 72 | expand_axis0=False) 73 | seq_list.append(img) 74 | seq = np.stack(seq_list, axis=0) 75 | return seq, expanded_h, expanded_w 76 | 77 | def open_image(fpath, gray_mode, expand_if_needed=False, expand_axis0=True, normalize_data=True): 78 | r""" Opens an image and expands it if necesary 79 | Args: 80 | fpath: string, path of image file 81 | gray_mode: boolean, True indicating if image is to be open 82 | in grayscale mode 83 | expand_if_needed: if True, the spatial dimensions will be expanded if 84 | size is odd 85 | expand_axis0: if True, output will have a fourth dimension 86 | Returns: 87 | img: image of dims NxCxHxW, N=1, C=1 grayscale or C=3 RGB, H and W are even. 88 | if expand_axis0=False, the output will have a shape CxHxW. 89 | The image gets normalized gets normalized to the range [0, 1]. 90 | expanded_h: True if original dim H was odd and image got expanded in this dimension. 91 | expanded_w: True if original dim W was odd and image got expanded in this dimension. 
92 | """ 93 | if not gray_mode: 94 | # Open image as a CxHxW torch.Tensor 95 | img = cv2.imread(fpath) 96 | # from HxWxC to CxHxW, RGB image 97 | img = (cv2.cvtColor(img, cv2.COLOR_BGR2RGB)).transpose(2, 0, 1) 98 | else: 99 | # from HxWxC to CxHxW grayscale image (C=1) 100 | img = cv2.imread(fpath, cv2.IMREAD_GRAYSCALE) 101 | img = np.expand_dims(img, 0) 102 | 103 | if expand_axis0: 104 | img = np.expand_dims(img, 0) 105 | 106 | # Handle odd sizes 107 | expanded_h = False 108 | expanded_w = False 109 | sh_im = img.shape 110 | if expand_if_needed: 111 | if sh_im[-2]%2 == 1: 112 | expanded_h = True 113 | if expand_axis0: 114 | img = np.concatenate((img, \ 115 | img[:, :, -1, :][:, :, np.newaxis, :]), axis=2) 116 | else: 117 | img = np.concatenate((img, \ 118 | img[:, -1, :][:, np.newaxis, :]), axis=1) 119 | 120 | 121 | if sh_im[-1]%2 == 1: 122 | expanded_w = True 123 | if expand_axis0: 124 | img = np.concatenate((img, \ 125 | img[:, :, :, -1][:, :, :, np.newaxis]), axis=3) 126 | else: 127 | img = np.concatenate((img, \ 128 | img[:, :, -1][:, :, np.newaxis]), axis=2) 129 | 130 | if normalize_data: 131 | img = normalize(img) 132 | return img, expanded_h, expanded_w 133 | 134 | def normalize(data): 135 | r"""Normalizes a unit8 image to a float32 image in the range [0, 1] 136 | 137 | Args: 138 | data: a unint8 numpy array to normalize from [0, 255] to [0, 1] 139 | """ 140 | return np.float32(data/255.) 141 | 142 | def normalize_augment(img_train): 143 | '''Normalizes and augments an input patch of dim [N, num_frames, C. H, W] in [0., 255.] to \ 144 | [N, num_frames*C. H, W] in [0., 1.]. It also returns the central frame of the temporal \ 145 | patch as a ground truth. 
146 | ''' 147 | def transform(sample): 148 | # define transformations 149 | do_nothing = lambda x: x 150 | do_nothing.__name__ = 'do_nothing' 151 | flipud = lambda x: torch.flip(x, dims=[2]) 152 | flipud.__name__ = 'flipup' 153 | rot90 = lambda x: torch.rot90(x, k=1, dims=[2, 3]) 154 | rot90.__name__ = 'rot90' 155 | rot90_flipud = lambda x: torch.flip(torch.rot90(x, k=1, dims=[2, 3]), dims=[2]) 156 | rot90_flipud.__name__ = 'rot90_flipud' 157 | rot180 = lambda x: torch.rot90(x, k=2, dims=[2, 3]) 158 | rot180.__name__ = 'rot180' 159 | rot180_flipud = lambda x: torch.flip(torch.rot90(x, k=2, dims=[2, 3]), dims=[2]) 160 | rot180_flipud.__name__ = 'rot180_flipud' 161 | rot270 = lambda x: torch.rot90(x, k=3, dims=[2, 3]) 162 | rot270.__name__ = 'rot270' 163 | rot270_flipud = lambda x: torch.flip(torch.rot90(x, k=3, dims=[2, 3]), dims=[2]) 164 | rot270_flipud.__name__ = 'rot270_flipud' 165 | add_csnt = lambda x: x + torch.normal(mean=torch.zeros(x.size()[0], 1, 1, 1), \ 166 | std=(5/255.)).expand_as(x).to(x.device) 167 | add_csnt.__name__ = 'add_csnt' 168 | 169 | # define transformations and their frequency, then pick one. 170 | aug_list = [do_nothing, flipud, rot90, rot90_flipud, \ 171 | rot180, rot180_flipud, rot270, rot270_flipud, add_csnt] 172 | w_aug = [32, 12, 12, 12, 12, 12, 12, 12, 12] # one fourth chances to do_nothing 173 | transf = choices(aug_list, w_aug) 174 | 175 | # transform all images in array 176 | return transf[0](sample) 177 | 178 | N, T, C, H, W = img_train.shape 179 | # convert to [N, num_frames*C. H, W] in [0., 1.] from [N, num_frames, C. H, W] in [0., 255.] 180 | img_train = img_train.type(torch.float32).view(N, -1, H, W) / 255. 181 | 182 | # augment 183 | img_train = transform(img_train) 184 | 185 | # view back 186 | img_train = img_train.view(N, T, C, H, W) 187 | 188 | return img_train 189 | 190 | def remove_dataparallel_wrapper(state_dict): 191 | r"""Converts a DataParallel models to a normal one by removing the "module." 
192 | wrapper in the module dictionary 193 | 194 | 195 | Args: 196 | state_dict: a torch.nn.DataParallel state dictionary 197 | """ 198 | from collections import OrderedDict 199 | 200 | new_state_dict = OrderedDict() 201 | for k, v in state_dict.items(): 202 | name = k[7:] # remove 'module.' of DataParallel 203 | new_state_dict[name] = v 204 | 205 | return new_state_dict 206 | -------------------------------------------------------------------------------- /utils/io.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import imageio 3 | import numpy as np 4 | import os 5 | import torch 6 | 7 | def list_dir(dir, postfix=None, full_path=False): 8 | if full_path: 9 | if postfix is None: 10 | names = sorted([name for name in os.listdir(dir) if not name.startswith('.')]) 11 | return sorted([os.path.join(dir, name) for name in names]) 12 | else: 13 | names = sorted([name for name in os.listdir(dir) if (not name.startswith('.') and name.endswith(postfix))]) 14 | return sorted([os.path.join(dir, name) for name in names]) 15 | else: 16 | if postfix is None: 17 | return sorted([name for name in os.listdir(dir) if not name.startswith('.')]) 18 | else: 19 | return sorted([name for name in os.listdir(dir) if (not name.startswith('.') and name.endswith(postfix))]) 20 | 21 | def open_images_uint8(image_files): 22 | image_list = [] 23 | for image_file in image_files: 24 | image = imageio.imread(image_file).astype(np.uint8) 25 | if len(image.shape) == 3: 26 | image = np.transpose(image, (2, 0, 1)) 27 | image_list.append(image) 28 | seq = np.stack(image_list, axis=0) 29 | return seq 30 | 31 | def log(log_file, str, also_print=True): 32 | with open(log_file, 'a+') as F: 33 | F.write(str) 34 | if also_print: 35 | print(str, end='') 36 | 37 | # return pytorch image in shape 1x3xHxW 38 | def image2tensor(image_file): 39 | image = imageio.imread(image_file).astype(np.float32) / np.float32(255.0) 40 | if len(image.shape) == 3: 41 | image 
= np.transpose(image, (2, 0, 1)) 42 | elif len(image.shape) == 2: 43 | image = np.expand_dims(image, 0) 44 | image = np.asarray(image, dtype=np.float32) 45 | image = torch.from_numpy(image).unsqueeze(0) 46 | return image 47 | 48 | # save numpy image in shape 3xHxW 49 | def np2image(image, image_file): 50 | image = np.transpose(image, (1, 2, 0)) 51 | image = np.clip(image, 0., 1.) 52 | image = image * 255. 53 | image = image.astype(np.uint8) 54 | imageio.imwrite(image_file, image) 55 | 56 | def np2image_bgr(image, image_file): 57 | image = np.transpose(image, (1, 2, 0)) 58 | image = np.clip(image, 0., 1.) 59 | image = image * 255. 60 | image = image.astype(np.uint8) 61 | cv2.imwrite(image_file, image) 62 | 63 | # save tensor image in shape 1x3xHxW 64 | def tensor2image(image, image_file): 65 | image = image.detach().cpu().squeeze(0).numpy() 66 | np2image(image, image_file) 67 | 68 | -------------------------------------------------------------------------------- /utils/raw.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | # simply convert raw seq to rgb seq for computing optical flow 4 | def demosaic(raw_seq): 5 | N, T, C, H, W = raw_seq.shape 6 | rgb_seq = torch.empty((N, T, 3, H, W), dtype=raw_seq.dtype, device=raw_seq.device) 7 | rgb_seq[:, :, 0] = raw_seq[:, :, 0] 8 | rgb_seq[:, :, 1] = (raw_seq[:, :, 1] + raw_seq[:, :, 2]) / 2 9 | rgb_seq[:, :, 2] = (raw_seq[:, :, 3]) 10 | return rgb_seq -------------------------------------------------------------------------------- /utils/ssim.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from skimage.metrics import structural_similarity as compare_ssim 3 | 4 | # img: T, C, H, W, imclean: T, C, H, W 5 | def batch_ssim(img, imclean, data_range): 6 | 7 | img = img.data.cpu().numpy().astype(np.float32) 8 | img = np.transpose(img, (0, 2, 3, 1)) 9 | img_clean = imclean.data.cpu().numpy().astype(np.float32) 
10 | img_clean = np.transpose(img_clean, (0, 2, 3, 1)) 11 | 12 | ssim = 0 13 | for i in range(img.shape[0]): 14 | origin_i = img_clean[i, :, :, :] 15 | denoised_i = img[i, :, :, :] 16 | ssim += compare_ssim(origin_i.astype(float), denoised_i.astype(float), multichannel=True, win_size=11, K1=0.01, 17 | K2=0.03, sigma=1.5, gaussian_weights=True, data_range=1) 18 | return ssim/img.shape[0] -------------------------------------------------------------------------------- /utils/warp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def warp(x, flo): 4 | ''' 5 | warp an image/tensor (im2) back to im1, according to the optical flow 6 | x: [B, C, H, W] (im2) 7 | flo: [B, 2, H, W] (flow) 8 | ''' 9 | B, C, H, W = x.size() 10 | # mesh grid 11 | xx = torch.arange(0, W, device=x.device).view(1, -1).repeat(H, 1) 12 | yy = torch.arange(0, H, device=x.device).view(-1, 1).repeat(1, W) 13 | xx = xx.view(1, 1, H, W).repeat(B, 1, 1, 1) 14 | yy = yy.view(1, 1, H, W).repeat(B, 1, 1, 1) 15 | grid = torch.cat((xx, yy), 1).float() 16 | 17 | if x.is_cuda: 18 | grid = grid.to(x.device) 19 | vgrid = torch.autograd.Variable(grid) + flo 20 | 21 | # scale grid to [-1, 1] 22 | vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :].clone() / max(W - 1, 1) - 1.0 23 | vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :].clone() / max(H - 1, 1) - 1.0 24 | 25 | vgrid = vgrid.permute(0, 2, 3, 1) 26 | output = torch.nn.functional.grid_sample(x, vgrid, align_corners=True) 27 | mask = torch.autograd.Variable(torch.ones((B, C, H, W), device=x.device)) 28 | mask = torch.nn.functional.grid_sample(mask, vgrid, align_corners=True) 29 | 30 | mask[mask < 0.9999] = 0 31 | mask[mask > 0] = 1 32 | 33 | return output * mask, mask --------------------------------------------------------------------------------