├── LICENSE ├── README.md ├── images ├── deblur_obs.png ├── deblur_sample_1.png ├── deblur_sample_2.png ├── deblur_y_or.png ├── inpainting_obs.png ├── inpainting_sample_1.png ├── inpainting_sample_2.png ├── inpainting_y_or.png ├── sr_obs.png ├── sr_sample_1.png ├── sr_sample_2.png └── sr_y_or.png ├── mcg_diff ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-311.pyc │ ├── particle_filter.cpython-311.pyc │ ├── sgm.cpython-311.pyc │ └── utils.cpython-311.pyc ├── particle_filter.py ├── sgm.py └── utils.py ├── requirements.txt ├── requirements_scripts.txt ├── scripts ├── __init__.py ├── configs │ ├── config.yaml │ ├── dataset │ │ ├── bedroom.yaml │ │ ├── cats.yaml │ │ ├── celeb.yaml │ │ ├── churches.yaml │ │ ├── cifar_10.yaml │ │ ├── flowers.yaml │ │ └── mnist.yaml │ ├── diffusion │ │ ├── ddim_10.yaml │ │ ├── ddim_100.yaml │ │ └── ddim_250.yaml │ ├── mcg_diff │ │ ├── colorization.yaml │ │ ├── default.yaml │ │ ├── empty.yaml │ │ └── outpainting.yaml │ └── task │ │ ├── colorization.yaml │ │ ├── deblur_2d.yaml │ │ ├── inpainting.yaml │ │ ├── motion_blur.yaml │ │ ├── outpainting.yaml │ │ └── super_resolution.yaml ├── hugging_faces_models.py ├── inverse_problems_operators.py └── viz_gaussian.py ├── setup.py └── tests └── test_particle_filter.py /LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | Preamble 9 | 10 | The GNU General Public License is a free, copyleft license for 11 | software and other kinds of works. 12 | 13 | The licenses for most software and other practical works are designed 14 | to take away your freedom to share and change the works. 
By contrast, 15 | the GNU General Public License is intended to guarantee your freedom to 16 | share and change all versions of a program--to make sure it remains free 17 | software for all its users. We, the Free Software Foundation, use the 18 | GNU General Public License for most of our software; it applies also to 19 | any other work released this way by its authors. You can apply it to 20 | your programs, too. 21 | 22 | When we speak of free software, we are referring to freedom, not 23 | price. Our General Public Licenses are designed to make sure that you 24 | have the freedom to distribute copies of free software (and charge for 25 | them if you wish), that you receive source code or can get it if you 26 | want it, that you can change the software or use pieces of it in new 27 | free programs, and that you know you can do these things. 28 | 29 | To protect your rights, we need to prevent others from denying you 30 | these rights or asking you to surrender the rights. Therefore, you have 31 | certain responsibilities if you distribute copies of the software, or if 32 | you modify it: responsibilities to respect the freedom of others. 33 | 34 | For example, if you distribute copies of such a program, whether 35 | gratis or for a fee, you must pass on to the recipients the same 36 | freedoms that you received. You must make sure that they, too, receive 37 | or can get the source code. And you must show them these terms so they 38 | know their rights. 39 | 40 | Developers that use the GNU GPL protect your rights with two steps: 41 | (1) assert copyright on the software, and (2) offer you this License 42 | giving you legal permission to copy, distribute and/or modify it. 43 | 44 | For the developers' and authors' protection, the GPL clearly explains 45 | that there is no warranty for this free software. 
For both users' and 46 | authors' sake, the GPL requires that modified versions be marked as 47 | changed, so that their problems will not be attributed erroneously to 48 | authors of previous versions. 49 | 50 | Some devices are designed to deny users access to install or run 51 | modified versions of the software inside them, although the manufacturer 52 | can do so. This is fundamentally incompatible with the aim of 53 | protecting users' freedom to change the software. The systematic 54 | pattern of such abuse occurs in the area of products for individuals to 55 | use, which is precisely where it is most unacceptable. Therefore, we 56 | have designed this version of the GPL to prohibit the practice for those 57 | products. If such problems arise substantially in other domains, we 58 | stand ready to extend this provision to those domains in future versions 59 | of the GPL, as needed to protect the freedom of users. 60 | 61 | Finally, every program is threatened constantly by software patents. 62 | States should not allow patents to restrict development and use of 63 | software on general-purpose computers, but in those that do, we wish to 64 | avoid the special danger that patents applied to a free program could 65 | make it effectively proprietary. To prevent this, the GPL assures that 66 | patents cannot be used to render the program non-free. 67 | 68 | The precise terms and conditions for copying, distribution and 69 | modification follow. 70 | 71 | TERMS AND CONDITIONS 72 | 73 | 0. Definitions. 74 | 75 | "This License" refers to version 3 of the GNU General Public License. 76 | 77 | "Copyright" also means copyright-like laws that apply to other kinds of 78 | works, such as semiconductor masks. 79 | 80 | "The Program" refers to any copyrightable work licensed under this 81 | License. Each licensee is addressed as "you". "Licensees" and 82 | "recipients" may be individuals or organizations. 
83 | 84 | To "modify" a work means to copy from or adapt all or part of the work 85 | in a fashion requiring copyright permission, other than the making of an 86 | exact copy. The resulting work is called a "modified version" of the 87 | earlier work or a work "based on" the earlier work. 88 | 89 | A "covered work" means either the unmodified Program or a work based 90 | on the Program. 91 | 92 | To "propagate" a work means to do anything with it that, without 93 | permission, would make you directly or secondarily liable for 94 | infringement under applicable copyright law, except executing it on a 95 | computer or modifying a private copy. Propagation includes copying, 96 | distribution (with or without modification), making available to the 97 | public, and in some countries other activities as well. 98 | 99 | To "convey" a work means any kind of propagation that enables other 100 | parties to make or receive copies. Mere interaction with a user through 101 | a computer network, with no transfer of a copy, is not conveying. 102 | 103 | An interactive user interface displays "Appropriate Legal Notices" 104 | to the extent that it includes a convenient and prominently visible 105 | feature that (1) displays an appropriate copyright notice, and (2) 106 | tells the user that there is no warranty for the work (except to the 107 | extent that warranties are provided), that licensees may convey the 108 | work under this License, and how to view a copy of this License. If 109 | the interface presents a list of user commands or options, such as a 110 | menu, a prominent item in the list meets this criterion. 111 | 112 | 1. Source Code. 113 | 114 | The "source code" for a work means the preferred form of the work 115 | for making modifications to it. "Object code" means any non-source 116 | form of a work. 
117 | 118 | A "Standard Interface" means an interface that either is an official 119 | standard defined by a recognized standards body, or, in the case of 120 | interfaces specified for a particular programming language, one that 121 | is widely used among developers working in that language. 122 | 123 | The "System Libraries" of an executable work include anything, other 124 | than the work as a whole, that (a) is included in the normal form of 125 | packaging a Major Component, but which is not part of that Major 126 | Component, and (b) serves only to enable use of the work with that 127 | Major Component, or to implement a Standard Interface for which an 128 | implementation is available to the public in source code form. A 129 | "Major Component", in this context, means a major essential component 130 | (kernel, window system, and so on) of the specific operating system 131 | (if any) on which the executable work runs, or a compiler used to 132 | produce the work, or an object code interpreter used to run it. 133 | 134 | The "Corresponding Source" for a work in object code form means all 135 | the source code needed to generate, install, and (for an executable 136 | work) run the object code and to modify the work, including scripts to 137 | control those activities. However, it does not include the work's 138 | System Libraries, or general-purpose tools or generally available free 139 | programs which are used unmodified in performing those activities but 140 | which are not part of the work. For example, Corresponding Source 141 | includes interface definition files associated with source files for 142 | the work, and the source code for shared libraries and dynamically 143 | linked subprograms that the work is specifically designed to require, 144 | such as by intimate data communication or control flow between those 145 | subprograms and other parts of the work. 
146 | 147 | The Corresponding Source need not include anything that users 148 | can regenerate automatically from other parts of the Corresponding 149 | Source. 150 | 151 | The Corresponding Source for a work in source code form is that 152 | same work. 153 | 154 | 2. Basic Permissions. 155 | 156 | All rights granted under this License are granted for the term of 157 | copyright on the Program, and are irrevocable provided the stated 158 | conditions are met. This License explicitly affirms your unlimited 159 | permission to run the unmodified Program. The output from running a 160 | covered work is covered by this License only if the output, given its 161 | content, constitutes a covered work. This License acknowledges your 162 | rights of fair use or other equivalent, as provided by copyright law. 163 | 164 | You may make, run and propagate covered works that you do not 165 | convey, without conditions so long as your license otherwise remains 166 | in force. You may convey covered works to others for the sole purpose 167 | of having them make modifications exclusively for you, or provide you 168 | with facilities for running those works, provided that you comply with 169 | the terms of this License in conveying all material for which you do 170 | not control copyright. Those thus making or running the covered works 171 | for you must do so exclusively on your behalf, under your direction 172 | and control, on terms that prohibit them from making any copies of 173 | your copyrighted material outside their relationship with you. 174 | 175 | Conveying under any other circumstances is permitted solely under 176 | the conditions stated below. Sublicensing is not allowed; section 10 177 | makes it unnecessary. 178 | 179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 
180 | 181 | No covered work shall be deemed part of an effective technological 182 | measure under any applicable law fulfilling obligations under article 183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 184 | similar laws prohibiting or restricting circumvention of such 185 | measures. 186 | 187 | When you convey a covered work, you waive any legal power to forbid 188 | circumvention of technological measures to the extent such circumvention 189 | is effected by exercising rights under this License with respect to 190 | the covered work, and you disclaim any intention to limit operation or 191 | modification of the work as a means of enforcing, against the work's 192 | users, your or third parties' legal rights to forbid circumvention of 193 | technological measures. 194 | 195 | 4. Conveying Verbatim Copies. 196 | 197 | You may convey verbatim copies of the Program's source code as you 198 | receive it, in any medium, provided that you conspicuously and 199 | appropriately publish on each copy an appropriate copyright notice; 200 | keep intact all notices stating that this License and any 201 | non-permissive terms added in accord with section 7 apply to the code; 202 | keep intact all notices of the absence of any warranty; and give all 203 | recipients a copy of this License along with the Program. 204 | 205 | You may charge any price or no price for each copy that you convey, 206 | and you may offer support or warranty protection for a fee. 207 | 208 | 5. Conveying Modified Source Versions. 209 | 210 | You may convey a work based on the Program, or the modifications to 211 | produce it from the Program, in the form of source code under the 212 | terms of section 4, provided that you also meet all of these conditions: 213 | 214 | a) The work must carry prominent notices stating that you modified 215 | it, and giving a relevant date. 
216 | 217 | b) The work must carry prominent notices stating that it is 218 | released under this License and any conditions added under section 219 | 7. This requirement modifies the requirement in section 4 to 220 | "keep intact all notices". 221 | 222 | c) You must license the entire work, as a whole, under this 223 | License to anyone who comes into possession of a copy. This 224 | License will therefore apply, along with any applicable section 7 225 | additional terms, to the whole of the work, and all its parts, 226 | regardless of how they are packaged. This License gives no 227 | permission to license the work in any other way, but it does not 228 | invalidate such permission if you have separately received it. 229 | 230 | d) If the work has interactive user interfaces, each must display 231 | Appropriate Legal Notices; however, if the Program has interactive 232 | interfaces that do not display Appropriate Legal Notices, your 233 | work need not make them do so. 234 | 235 | A compilation of a covered work with other separate and independent 236 | works, which are not by their nature extensions of the covered work, 237 | and which are not combined with it such as to form a larger program, 238 | in or on a volume of a storage or distribution medium, is called an 239 | "aggregate" if the compilation and its resulting copyright are not 240 | used to limit the access or legal rights of the compilation's users 241 | beyond what the individual works permit. Inclusion of a covered work 242 | in an aggregate does not cause this License to apply to the other 243 | parts of the aggregate. 244 | 245 | 6. Conveying Non-Source Forms. 
246 | 247 | You may convey a covered work in object code form under the terms 248 | of sections 4 and 5, provided that you also convey the 249 | machine-readable Corresponding Source under the terms of this License, 250 | in one of these ways: 251 | 252 | a) Convey the object code in, or embodied in, a physical product 253 | (including a physical distribution medium), accompanied by the 254 | Corresponding Source fixed on a durable physical medium 255 | customarily used for software interchange. 256 | 257 | b) Convey the object code in, or embodied in, a physical product 258 | (including a physical distribution medium), accompanied by a 259 | written offer, valid for at least three years and valid for as 260 | long as you offer spare parts or customer support for that product 261 | model, to give anyone who possesses the object code either (1) a 262 | copy of the Corresponding Source for all the software in the 263 | product that is covered by this License, on a durable physical 264 | medium customarily used for software interchange, for a price no 265 | more than your reasonable cost of physically performing this 266 | conveying of source, or (2) access to copy the 267 | Corresponding Source from a network server at no charge. 268 | 269 | c) Convey individual copies of the object code with a copy of the 270 | written offer to provide the Corresponding Source. This 271 | alternative is allowed only occasionally and noncommercially, and 272 | only if you received the object code with such an offer, in accord 273 | with subsection 6b. 274 | 275 | d) Convey the object code by offering access from a designated 276 | place (gratis or for a charge), and offer equivalent access to the 277 | Corresponding Source in the same way through the same place at no 278 | further charge. You need not require recipients to copy the 279 | Corresponding Source along with the object code. 
If the place to 280 | copy the object code is a network server, the Corresponding Source 281 | may be on a different server (operated by you or a third party) 282 | that supports equivalent copying facilities, provided you maintain 283 | clear directions next to the object code saying where to find the 284 | Corresponding Source. Regardless of what server hosts the 285 | Corresponding Source, you remain obligated to ensure that it is 286 | available for as long as needed to satisfy these requirements. 287 | 288 | e) Convey the object code using peer-to-peer transmission, provided 289 | you inform other peers where the object code and Corresponding 290 | Source of the work are being offered to the general public at no 291 | charge under subsection 6d. 292 | 293 | A separable portion of the object code, whose source code is excluded 294 | from the Corresponding Source as a System Library, need not be 295 | included in conveying the object code work. 296 | 297 | A "User Product" is either (1) a "consumer product", which means any 298 | tangible personal property which is normally used for personal, family, 299 | or household purposes, or (2) anything designed or sold for incorporation 300 | into a dwelling. In determining whether a product is a consumer product, 301 | doubtful cases shall be resolved in favor of coverage. For a particular 302 | product received by a particular user, "normally used" refers to a 303 | typical or common use of that class of product, regardless of the status 304 | of the particular user or of the way in which the particular user 305 | actually uses, or expects or is expected to use, the product. A product 306 | is a consumer product regardless of whether the product has substantial 307 | commercial, industrial or non-consumer uses, unless such uses represent 308 | the only significant mode of use of the product. 
309 | 310 | "Installation Information" for a User Product means any methods, 311 | procedures, authorization keys, or other information required to install 312 | and execute modified versions of a covered work in that User Product from 313 | a modified version of its Corresponding Source. The information must 314 | suffice to ensure that the continued functioning of the modified object 315 | code is in no case prevented or interfered with solely because 316 | modification has been made. 317 | 318 | If you convey an object code work under this section in, or with, or 319 | specifically for use in, a User Product, and the conveying occurs as 320 | part of a transaction in which the right of possession and use of the 321 | User Product is transferred to the recipient in perpetuity or for a 322 | fixed term (regardless of how the transaction is characterized), the 323 | Corresponding Source conveyed under this section must be accompanied 324 | by the Installation Information. But this requirement does not apply 325 | if neither you nor any third party retains the ability to install 326 | modified object code on the User Product (for example, the work has 327 | been installed in ROM). 328 | 329 | The requirement to provide Installation Information does not include a 330 | requirement to continue to provide support service, warranty, or updates 331 | for a work that has been modified or installed by the recipient, or for 332 | the User Product in which it has been modified or installed. Access to a 333 | network may be denied when the modification itself materially and 334 | adversely affects the operation of the network or violates the rules and 335 | protocols for communication across the network. 
336 | 337 | Corresponding Source conveyed, and Installation Information provided, 338 | in accord with this section must be in a format that is publicly 339 | documented (and with an implementation available to the public in 340 | source code form), and must require no special password or key for 341 | unpacking, reading or copying. 342 | 343 | 7. Additional Terms. 344 | 345 | "Additional permissions" are terms that supplement the terms of this 346 | License by making exceptions from one or more of its conditions. 347 | Additional permissions that are applicable to the entire Program shall 348 | be treated as though they were included in this License, to the extent 349 | that they are valid under applicable law. If additional permissions 350 | apply only to part of the Program, that part may be used separately 351 | under those permissions, but the entire Program remains governed by 352 | this License without regard to the additional permissions. 353 | 354 | When you convey a copy of a covered work, you may at your option 355 | remove any additional permissions from that copy, or from any part of 356 | it. (Additional permissions may be written to require their own 357 | removal in certain cases when you modify the work.) You may place 358 | additional permissions on material, added by you to a covered work, 359 | for which you have or can give appropriate copyright permission. 
360 | 361 | Notwithstanding any other provision of this License, for material you 362 | add to a covered work, you may (if authorized by the copyright holders of 363 | that material) supplement the terms of this License with terms: 364 | 365 | a) Disclaiming warranty or limiting liability differently from the 366 | terms of sections 15 and 16 of this License; or 367 | 368 | b) Requiring preservation of specified reasonable legal notices or 369 | author attributions in that material or in the Appropriate Legal 370 | Notices displayed by works containing it; or 371 | 372 | c) Prohibiting misrepresentation of the origin of that material, or 373 | requiring that modified versions of such material be marked in 374 | reasonable ways as different from the original version; or 375 | 376 | d) Limiting the use for publicity purposes of names of licensors or 377 | authors of the material; or 378 | 379 | e) Declining to grant rights under trademark law for use of some 380 | trade names, trademarks, or service marks; or 381 | 382 | f) Requiring indemnification of licensors and authors of that 383 | material by anyone who conveys the material (or modified versions of 384 | it) with contractual assumptions of liability to the recipient, for 385 | any liability that these contractual assumptions directly impose on 386 | those licensors and authors. 387 | 388 | All other non-permissive additional terms are considered "further 389 | restrictions" within the meaning of section 10. If the Program as you 390 | received it, or any part of it, contains a notice stating that it is 391 | governed by this License along with a term that is a further 392 | restriction, you may remove that term. 
If a license document contains 393 | a further restriction but permits relicensing or conveying under this 394 | License, you may add to a covered work material governed by the terms 395 | of that license document, provided that the further restriction does 396 | not survive such relicensing or conveying. 397 | 398 | If you add terms to a covered work in accord with this section, you 399 | must place, in the relevant source files, a statement of the 400 | additional terms that apply to those files, or a notice indicating 401 | where to find the applicable terms. 402 | 403 | Additional terms, permissive or non-permissive, may be stated in the 404 | form of a separately written license, or stated as exceptions; 405 | the above requirements apply either way. 406 | 407 | 8. Termination. 408 | 409 | You may not propagate or modify a covered work except as expressly 410 | provided under this License. Any attempt otherwise to propagate or 411 | modify it is void, and will automatically terminate your rights under 412 | this License (including any patent licenses granted under the third 413 | paragraph of section 11). 414 | 415 | However, if you cease all violation of this License, then your 416 | license from a particular copyright holder is reinstated (a) 417 | provisionally, unless and until the copyright holder explicitly and 418 | finally terminates your license, and (b) permanently, if the copyright 419 | holder fails to notify you of the violation by some reasonable means 420 | prior to 60 days after the cessation. 421 | 422 | Moreover, your license from a particular copyright holder is 423 | reinstated permanently if the copyright holder notifies you of the 424 | violation by some reasonable means, this is the first time you have 425 | received notice of violation of this License (for any work) from that 426 | copyright holder, and you cure the violation prior to 30 days after 427 | your receipt of the notice. 
428 | 429 | Termination of your rights under this section does not terminate the 430 | licenses of parties who have received copies or rights from you under 431 | this License. If your rights have been terminated and not permanently 432 | reinstated, you do not qualify to receive new licenses for the same 433 | material under section 10. 434 | 435 | 9. Acceptance Not Required for Having Copies. 436 | 437 | You are not required to accept this License in order to receive or 438 | run a copy of the Program. Ancillary propagation of a covered work 439 | occurring solely as a consequence of using peer-to-peer transmission 440 | to receive a copy likewise does not require acceptance. However, 441 | nothing other than this License grants you permission to propagate or 442 | modify any covered work. These actions infringe copyright if you do 443 | not accept this License. Therefore, by modifying or propagating a 444 | covered work, you indicate your acceptance of this License to do so. 445 | 446 | 10. Automatic Licensing of Downstream Recipients. 447 | 448 | Each time you convey a covered work, the recipient automatically 449 | receives a license from the original licensors, to run, modify and 450 | propagate that work, subject to this License. You are not responsible 451 | for enforcing compliance by third parties with this License. 452 | 453 | An "entity transaction" is a transaction transferring control of an 454 | organization, or substantially all assets of one, or subdividing an 455 | organization, or merging organizations. 
If propagation of a covered 456 | work results from an entity transaction, each party to that 457 | transaction who receives a copy of the work also receives whatever 458 | licenses to the work the party's predecessor in interest had or could 459 | give under the previous paragraph, plus a right to possession of the 460 | Corresponding Source of the work from the predecessor in interest, if 461 | the predecessor has it or can get it with reasonable efforts. 462 | 463 | You may not impose any further restrictions on the exercise of the 464 | rights granted or affirmed under this License. For example, you may 465 | not impose a license fee, royalty, or other charge for exercise of 466 | rights granted under this License, and you may not initiate litigation 467 | (including a cross-claim or counterclaim in a lawsuit) alleging that 468 | any patent claim is infringed by making, using, selling, offering for 469 | sale, or importing the Program or any portion of it. 470 | 471 | 11. Patents. 472 | 473 | A "contributor" is a copyright holder who authorizes use under this 474 | License of the Program or a work on which the Program is based. The 475 | work thus licensed is called the contributor's "contributor version". 476 | 477 | A contributor's "essential patent claims" are all patent claims 478 | owned or controlled by the contributor, whether already acquired or 479 | hereafter acquired, that would be infringed by some manner, permitted 480 | by this License, of making, using, or selling its contributor version, 481 | but do not include claims that would be infringed only as a 482 | consequence of further modification of the contributor version. For 483 | purposes of this definition, "control" includes the right to grant 484 | patent sublicenses in a manner consistent with the requirements of 485 | this License. 
486 | 487 | Each contributor grants you a non-exclusive, worldwide, royalty-free 488 | patent license under the contributor's essential patent claims, to 489 | make, use, sell, offer for sale, import and otherwise run, modify and 490 | propagate the contents of its contributor version. 491 | 492 | In the following three paragraphs, a "patent license" is any express 493 | agreement or commitment, however denominated, not to enforce a patent 494 | (such as an express permission to practice a patent or covenant not to 495 | sue for patent infringement). To "grant" such a patent license to a 496 | party means to make such an agreement or commitment not to enforce a 497 | patent against the party. 498 | 499 | If you convey a covered work, knowingly relying on a patent license, 500 | and the Corresponding Source of the work is not available for anyone 501 | to copy, free of charge and under the terms of this License, through a 502 | publicly available network server or other readily accessible means, 503 | then you must either (1) cause the Corresponding Source to be so 504 | available, or (2) arrange to deprive yourself of the benefit of the 505 | patent license for this particular work, or (3) arrange, in a manner 506 | consistent with the requirements of this License, to extend the patent 507 | license to downstream recipients. "Knowingly relying" means you have 508 | actual knowledge that, but for the patent license, your conveying the 509 | covered work in a country, or your recipient's use of the covered work 510 | in a country, would infringe one or more identifiable patents in that 511 | country that you have reason to believe are valid. 
512 | 513 | If, pursuant to or in connection with a single transaction or 514 | arrangement, you convey, or propagate by procuring conveyance of, a 515 | covered work, and grant a patent license to some of the parties 516 | receiving the covered work authorizing them to use, propagate, modify 517 | or convey a specific copy of the covered work, then the patent license 518 | you grant is automatically extended to all recipients of the covered 519 | work and works based on it. 520 | 521 | A patent license is "discriminatory" if it does not include within 522 | the scope of its coverage, prohibits the exercise of, or is 523 | conditioned on the non-exercise of one or more of the rights that are 524 | specifically granted under this License. You may not convey a covered 525 | work if you are a party to an arrangement with a third party that is 526 | in the business of distributing software, under which you make payment 527 | to the third party based on the extent of your activity of conveying 528 | the work, and under which the third party grants, to any of the 529 | parties who would receive the covered work from you, a discriminatory 530 | patent license (a) in connection with copies of the covered work 531 | conveyed by you (or copies made from those copies), or (b) primarily 532 | for and in connection with specific products or compilations that 533 | contain the covered work, unless you entered into that arrangement, 534 | or that patent license was granted, prior to 28 March 2007. 535 | 536 | Nothing in this License shall be construed as excluding or limiting 537 | any implied license or other defenses to infringement that may 538 | otherwise be available to you under applicable patent law. 539 | 540 | 12. No Surrender of Others' Freedom. 541 | 542 | If conditions are imposed on you (whether by court order, agreement or 543 | otherwise) that contradict the conditions of this License, they do not 544 | excuse you from the conditions of this License. 
If you cannot convey a 545 | covered work so as to satisfy simultaneously your obligations under this 546 | License and any other pertinent obligations, then as a consequence you may 547 | not convey it at all. For example, if you agree to terms that obligate you 548 | to collect a royalty for further conveying from those to whom you convey 549 | the Program, the only way you could satisfy both those terms and this 550 | License would be to refrain entirely from conveying the Program. 551 | 552 | 13. Use with the GNU Affero General Public License. 553 | 554 | Notwithstanding any other provision of this License, you have 555 | permission to link or combine any covered work with a work licensed 556 | under version 3 of the GNU Affero General Public License into a single 557 | combined work, and to convey the resulting work. The terms of this 558 | License will continue to apply to the part which is the covered work, 559 | but the special requirements of the GNU Affero General Public License, 560 | section 13, concerning interaction through a network will apply to the 561 | combination as such. 562 | 563 | 14. Revised Versions of this License. 564 | 565 | The Free Software Foundation may publish revised and/or new versions of 566 | the GNU General Public License from time to time. Such new versions will 567 | be similar in spirit to the present version, but may differ in detail to 568 | address new problems or concerns. 569 | 570 | Each version is given a distinguishing version number. If the 571 | Program specifies that a certain numbered version of the GNU General 572 | Public License "or any later version" applies to it, you have the 573 | option of following the terms and conditions either of that numbered 574 | version or of any later version published by the Free Software 575 | Foundation. If the Program does not specify a version number of the 576 | GNU General Public License, you may choose any version ever published 577 | by the Free Software Foundation. 
578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 583 | 584 | Later license versions may give you additional or different 585 | permissions. However, no additional obligations are imposed on any 586 | author or copyright holder as a result of your choosing to follow a 587 | later version. 588 | 589 | 15. Disclaimer of Warranty. 590 | 591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 599 | 600 | 16. Limitation of Liability. 601 | 602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 610 | SUCH DAMAGES. 611 | 612 | 17. Interpretation of Sections 15 and 16. 
613 | 614 | If the disclaimer of warranty and limitation of liability provided 615 | above cannot be given local legal effect according to their terms, 616 | reviewing courts shall apply local law that most closely approximates 617 | an absolute waiver of all civil liability in connection with the 618 | Program, unless a warranty or assumption of liability accompanies a 619 | copy of the Program in return for a fee. 620 | 621 | END OF TERMS AND CONDITIONS 622 | 623 | How to Apply These Terms to Your New Programs 624 | 625 | If you develop a new program, and you want it to be of the greatest 626 | possible use to the public, the best way to achieve this is to make it 627 | free software which everyone can redistribute and change under these terms. 628 | 629 | To do so, attach the following notices to the program. It is safest 630 | to attach them to the start of each source file to most effectively 631 | state the exclusion of warranty; and each file should have at least 632 | the "copyright" line and a pointer to where the full notice is found. 633 | 634 | <one line to give the program's name and a brief idea of what it does.> 635 | Copyright (C) <year> <name of author> 636 | 637 | This program is free software: you can redistribute it and/or modify 638 | it under the terms of the GNU General Public License as published by 639 | the Free Software Foundation, either version 3 of the License, or 640 | (at your option) any later version. 641 | 642 | This program is distributed in the hope that it will be useful, 643 | but WITHOUT ANY WARRANTY; without even the implied warranty of 644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 645 | GNU General Public License for more details. 646 | 647 | You should have received a copy of the GNU General Public License 648 | along with this program. If not, see <https://www.gnu.org/licenses/>. 649 | 650 | Also add information on how to contact you by electronic and paper mail. 
651 | 652 | If the program does terminal interaction, make it output a short 653 | notice like this when it starts in an interactive mode: 654 | 655 | <program> Copyright (C) <year> <name of author> 656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 657 | This is free software, and you are welcome to redistribute it 658 | under certain conditions; type `show c' for details. 659 | 660 | The hypothetical commands `show w' and `show c' should show the appropriate 661 | parts of the General Public License. Of course, your program's commands 662 | might be different; for a GUI interface, you would use an "about box". 663 | 664 | You should also get your employer (if you work as a programmer) or school, 665 | if any, to sign a "copyright disclaimer" for the program, if necessary. 666 | For more information on this, and how to apply and follow the GNU GPL, see 667 | <https://www.gnu.org/licenses/>. 668 | 669 | The GNU General Public License does not permit incorporating your program 670 | into proprietary programs. If your program is a subroutine library, you 671 | may consider it more useful to permit linking proprietary applications with 672 | the library. If this is what you want to do, use the GNU Lesser General 673 | Public License instead of this License. But first, please read 674 | <https://www.gnu.org/licenses/why-not-lgpl.html>. 675 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MCG-Diff: Monte Carlo guided diffusion for Bayesian linear inverse problems 2 | 3 | This repository contains the code of the algorithm proposed in https://arxiv.org/abs/2308.07983 and accepted for oral presentation at ICLR 2024. 4 | 5 | This repository can be installed as a python package by running 6 | `pip install -I .` on the root folder of this git. 
7 | 8 | You can try this algorithm either with the available scripts in this project or in the benchopts, with comparison to other algorithms at https://github.com/gabrielvc/benchopts_inverse_problem_diffusion_prior/tree/master 9 | 10 | 11 | ## Results using the hugging face models on CelebA 12 | 13 | 14 | The following table was produced by running the `scripts/hugging_faces_models.py` with the configurations described below. 15 | The configuration for MCG DIFF is defined in `scripts/configs/mcg_diff/default.yaml`. Running this script took 3 minutes to generate each image, but 16 | this can be made faster by using parallelization. 17 | 18 | | Original image | Observation | Sample | Sample | Changes to `config.yaml` | 19 | |--------------------------------------| ---- | ---- | ---- |----------------------------------| 20 | | ![image](images/deblur_y_or.png) | ![image](images/deblur_obs.png)| ![image](images/deblur_sample_1.png)| ![image](images/deblur_sample_2.png)| `seed = 32`, `task = deblur_2d` | 21 | | ![image](images/inpainting_y_or.png) | ![image](images/inpainting_obs.png)| ![image](images/inpainting_sample_1.png)| ![image](images/inpainting_sample_2.png)| `seed = 15`, `task = inpainting` | 22 | | ![image](images/sr_y_or.png) | ![image](images/sr_obs.png)| ![image](images/sr_sample_1.png)| ![image](images/sr_sample_2.png)| `seed = 10`, `task = sr` | 23 | -------------------------------------------------------------------------------- /images/deblur_obs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gabrielvc/mcg_diff/498a8830998c3c84c7c6cb6bf78e7d79dc99bd62/images/deblur_obs.png -------------------------------------------------------------------------------- /images/deblur_sample_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gabrielvc/mcg_diff/498a8830998c3c84c7c6cb6bf78e7d79dc99bd62/images/deblur_sample_1.png 
-------------------------------------------------------------------------------- /images/deblur_sample_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gabrielvc/mcg_diff/498a8830998c3c84c7c6cb6bf78e7d79dc99bd62/images/deblur_sample_2.png -------------------------------------------------------------------------------- /images/deblur_y_or.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gabrielvc/mcg_diff/498a8830998c3c84c7c6cb6bf78e7d79dc99bd62/images/deblur_y_or.png -------------------------------------------------------------------------------- /images/inpainting_obs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gabrielvc/mcg_diff/498a8830998c3c84c7c6cb6bf78e7d79dc99bd62/images/inpainting_obs.png -------------------------------------------------------------------------------- /images/inpainting_sample_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gabrielvc/mcg_diff/498a8830998c3c84c7c6cb6bf78e7d79dc99bd62/images/inpainting_sample_1.png -------------------------------------------------------------------------------- /images/inpainting_sample_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gabrielvc/mcg_diff/498a8830998c3c84c7c6cb6bf78e7d79dc99bd62/images/inpainting_sample_2.png -------------------------------------------------------------------------------- /images/inpainting_y_or.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gabrielvc/mcg_diff/498a8830998c3c84c7c6cb6bf78e7d79dc99bd62/images/inpainting_y_or.png -------------------------------------------------------------------------------- /images/sr_obs.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/gabrielvc/mcg_diff/498a8830998c3c84c7c6cb6bf78e7d79dc99bd62/images/sr_obs.png -------------------------------------------------------------------------------- /images/sr_sample_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gabrielvc/mcg_diff/498a8830998c3c84c7c6cb6bf78e7d79dc99bd62/images/sr_sample_1.png -------------------------------------------------------------------------------- /images/sr_sample_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gabrielvc/mcg_diff/498a8830998c3c84c7c6cb6bf78e7d79dc99bd62/images/sr_sample_2.png -------------------------------------------------------------------------------- /images/sr_y_or.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gabrielvc/mcg_diff/498a8830998c3c84c7c6cb6bf78e7d79dc99bd62/images/sr_y_or.png -------------------------------------------------------------------------------- /mcg_diff/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gabrielvc/mcg_diff/498a8830998c3c84c7c6cb6bf78e7d79dc99bd62/mcg_diff/__init__.py -------------------------------------------------------------------------------- /mcg_diff/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gabrielvc/mcg_diff/498a8830998c3c84c7c6cb6bf78e7d79dc99bd62/mcg_diff/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /mcg_diff/__pycache__/particle_filter.cpython-311.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/gabrielvc/mcg_diff/498a8830998c3c84c7c6cb6bf78e7d79dc99bd62/mcg_diff/__pycache__/particle_filter.cpython-311.pyc -------------------------------------------------------------------------------- /mcg_diff/__pycache__/sgm.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gabrielvc/mcg_diff/498a8830998c3c84c7c6cb6bf78e7d79dc99bd62/mcg_diff/__pycache__/sgm.cpython-311.pyc -------------------------------------------------------------------------------- /mcg_diff/__pycache__/utils.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gabrielvc/mcg_diff/498a8830998c3c84c7c6cb6bf78e7d79dc99bd62/mcg_diff/__pycache__/utils.cpython-311.pyc -------------------------------------------------------------------------------- /mcg_diff/particle_filter.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, List 2 | 3 | import torch 4 | from torch.distributions import Categorical 5 | 6 | from mcg_diff.sgm import ScoreModel, generate_coefficients_ddim 7 | from mcg_diff.utils import get_taus_from_singular_values 8 | 9 | 10 | def predict(score_model: ScoreModel, 11 | particles: torch.Tensor, 12 | t: float, 13 | t_prev: float, 14 | eta: float, 15 | n_samples_per_gpu: int = 1) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 16 | noise, coeff_sample, coeff_score = generate_coefficients_ddim( 17 | alphas_cumprod=score_model.alphas_cumprod.to(particles.device), 18 | time_step=t, 19 | prev_time_step=t_prev, 20 | eta=eta 21 | ) 22 | if hasattr(score_model.net, 'device_ids'): 23 | batch_size = n_samples_per_gpu * len(score_model.net.device_ids) 24 | epsilon_predicted = [] 25 | n_batches = particles.shape[0] // batch_size + int(particles.shape[0] % batch_size > 0) 26 | for batch_idx in range(n_batches): 27 | 
def gauss_loglik(x, mean, diag_std):
    '''
    Unnormalised log-density of a diagonal Gaussian, evaluated row-wise.

    Only the quadratic term ``-1/2 * ||(x - mean) / std||^2`` is computed; no
    log-normalisation constant is added.
    :param x: Points to evaluate; the squared norm is taken over the last dimension
    :param mean: Mean of the Gaussian, broadcast against ``x`` through a new leading axis
    :param diag_std: Per-coordinate standard deviations, broadcast the same way
    :return: One value per leading index of ``x``
    '''
    # Keep the scale inside [1e-10, 1e10] so the division below stays finite.
    safe_std = diag_std[None].clip(1e-10, 1e10)
    standardized = (x - mean[None, :]) / safe_std
    return - 1/2 * (torch.linalg.norm(standardized, dim=-1)**2)
65 | ''' 66 | #Initialization 67 | n_particles, dim = initial_particles.shape 68 | alphas_cumprod = score_model.alphas_cumprod.to(initial_particles.device) 69 | particles = initial_particles 70 | taus, taus_indices = get_taus_from_singular_values(alphas_cumprod=alphas_cumprod, 71 | timesteps=timesteps, 72 | singular_values=likelihood_diagonal, 73 | var=var_observation) 74 | 75 | coordinates_in_state = torch.where(coordinates_mask == 1)[0] 76 | always_free_coordinates = torch.where(coordinates_mask == 0)[0] 77 | rescaled_observations = ((alphas_cumprod[taus]**.5)*observation / likelihood_diagonal) 78 | 79 | #Splitting timesteps at after Tau_1 and before tau_1 80 | filtering_timesteps = timesteps[taus_indices.min().item():] 81 | propagation_timesteps = timesteps[:taus_indices.min().item()+1] 82 | 83 | pbar = enumerate(zip(filtering_timesteps.tolist()[1:][::-1], 84 | filtering_timesteps.tolist()[:-1][::-1])) 85 | 86 | for i, (t, t_prev) in pbar: 87 | predicted_mean, predicted_noise, eps = predict(score_model=score_model, 88 | particles=particles, 89 | t=t, 90 | t_prev=t_prev, 91 | eta=eta, 92 | n_samples_per_gpu=n_samples_per_gpu_inference) 93 | active_coordinates_in_obs = torch.where(t_prev >= taus)[0] 94 | previously_active_coordinates_in_obs = torch.where(t >= taus)[0] 95 | active_coordinates_in_x = coordinates_in_state[active_coordinates_in_obs] 96 | inactive_coordinates_in_x = torch.cat((coordinates_in_state[t_prev < taus], always_free_coordinates), dim=0) 97 | previously_active_coordinates_in_x = coordinates_in_state[previously_active_coordinates_in_obs] 98 | 99 | #Calculation of weights 100 | previous_log_likelihood = gauss_loglik( 101 | x=particles[:, previously_active_coordinates_in_x], 102 | mean=rescaled_observations[previously_active_coordinates_in_obs] * (alphas_cumprod[t] / alphas_cumprod[taus[previously_active_coordinates_in_obs]])**.5, 103 | diag_std=(1 - (1 - gaussian_var) * (alphas_cumprod[t] / 
alphas_cumprod[taus[previously_active_coordinates_in_obs]]))**.5) 104 | log_integration_constant = gauss_loglik( 105 | x=predicted_mean[:, active_coordinates_in_x], 106 | mean=rescaled_observations[active_coordinates_in_obs] * ((alphas_cumprod[t_prev] / alphas_cumprod[taus[active_coordinates_in_obs]])**.5), 107 | diag_std=(predicted_noise ** 2 + 1 - (1 - gaussian_var)*(alphas_cumprod[t_prev] / alphas_cumprod[taus[active_coordinates_in_obs]]))**.5 108 | ) 109 | log_weights = log_integration_constant - previous_log_likelihood 110 | 111 | #Ancestor sampling 112 | ancestors = Categorical(logits=log_weights, validate_args=False).sample((n_particles,)) 113 | #Update 114 | z = torch.randn_like(particles) 115 | Kprev = (predicted_noise**2 / (predicted_noise**2 + 1 - (1 - gaussian_var)*(alphas_cumprod[t_prev] / alphas_cumprod[taus[active_coordinates_in_obs]])).clip(1e-10, 1e10)) 116 | new_particles = particles.clone() 117 | new_particles[:, inactive_coordinates_in_x] = z[:, inactive_coordinates_in_x] * predicted_noise + predicted_mean[ancestors][:, inactive_coordinates_in_x] 118 | new_particles[:, active_coordinates_in_x] = Kprev * rescaled_observations[active_coordinates_in_obs][None,:] * ((alphas_cumprod[t_prev] / alphas_cumprod[taus[active_coordinates_in_obs]])**.5) + \ 119 | (1 - Kprev)*predicted_mean[ancestors][:, active_coordinates_in_x] + \ 120 | ((1 - (1 - gaussian_var)*(alphas_cumprod[t_prev] / alphas_cumprod[taus[active_coordinates_in_obs]]))*Kprev)**.5 * z[:, active_coordinates_in_x] 121 | 122 | particles = new_particles 123 | 124 | t = filtering_timesteps[0] 125 | previously_active_coordinates_in_obs = torch.where(t >= taus)[0] 126 | previously_active_coordinates_in_x = coordinates_in_state[previously_active_coordinates_in_obs] 127 | previous_log_likelihood = gauss_loglik( 128 | x=particles[:, previously_active_coordinates_in_x], 129 | mean=rescaled_observations[previously_active_coordinates_in_obs] * ( 130 | alphas_cumprod[t] / 
@dataclass
class ScoreModel:
    """Bundle a network with the diffusion schedule it was trained with.

    The samplers in this package evaluate the network as ``net(x, t)`` and
    index ``alphas_cumprod`` by integer timestep.
    """
    # Network evaluated as ``net(x, t)``.
    net: torch.nn.Module
    # Cumulative product of the alphas, indexed by timestep.
    alphas_cumprod: torch.Tensor
    # Device the model was last moved to.
    device: device

    def to(self, device):
        """Move the network and the schedule to ``device``, updating this object in place."""
        # Re-bind the declared ``net`` field rather than creating a new attribute,
        # so the dataclass fields always describe the object's actual state.
        self.net = self.net.to(device)
        self.alphas_cumprod = self.alphas_cumprod.to(device)
        self.device = device

    def cpu(self):
        """Move everything to the CPU and record a ``torch.device`` for it."""
        self.to('cpu')
        # ``device`` here is the ``torch.device`` class imported at module level.
        self.device = device('cpu')

    def cuda(self):
        """Move everything to the first CUDA device (``cuda:0``)."""
        self.to('cuda:0')
        self.device = device('cuda:0')
def ddim_marginal_logprob(
        x0: torch.Tensor,
        alphas_cumprod: List[float],
        timesteps: List[int],
        score_model: ScoreModel,
        n_samples: int,
        eta: float = 1) -> torch.Tensor:
    """
    Estimate the log marginal of ``x0`` under the DDIM sampler by importance sampling.

    Steps:
        1. Draw ``n_samples`` noisy paths conditioned on ``x0`` (the original note
           refers to "eq. (7)", presumably of the DDIM paper -- TODO confirm).
        2. At every step, add the log-density of the new state under the DDIM
           transition and subtract the log-density of the noise that produced it.

    :param x0: Clean sample(s). Sums run over dimensions ``2..x0.dim()`` of the
        ``(n_samples, *x0.shape)`` paths, so the first dimension of ``x0`` is kept
        (assumed to be a batch dimension -- TODO confirm).
    :param alphas_cumprod: Schedule used for the conditional path; its last entry
        is the starting noise level.
    :param timesteps: Timesteps consumed as consecutive ``(current, previous)``
        pairs, i.e. ``(timesteps[i], timesteps[i + 1])``.
    :param score_model: Provides ``net`` and the ``alphas_cumprod`` used for the
        DDIM coefficients.
    :param n_samples: Number of paths to average over.
    :param eta: DDIM parameter forwarded to ``generate_coefficients_ddim``.
    :return: ``logsumexp`` of the per-path log-weights over the sample dimension,
        minus ``log(n_samples)``.
    """
    dim_range = tuple(range(2, x0.dim() + 1))
    alpha_T = alphas_cumprod[-1]
    # Create the noise on the same device as x0; a default-device tensor cannot
    # be combined with a GPU-resident x0.
    noise_sample = torch.randn((n_samples, *x0.shape), device=x0.device)
    x = (alpha_T ** .5) * x0 + (1 - alpha_T) ** .5 * noise_sample
    log_weights = ((noise_sample ** 2).sum(dim_range) / 2) - (x**2).sum(dim_range) / 2
    # Plain iteration: no progress bar, since this module does not import one.
    for prev_time_step, time_step in zip(timesteps[1:],
                                         timesteps[:-1]):
        alphas_cumprod_t_1 = alphas_cumprod[prev_time_step] if prev_time_step >= 0 else 1
        alphas_cumprod_t = alphas_cumprod[time_step]
        noise_std, coeff_sample, coeff_score = generate_coefficients_ddim(
            alphas_cumprod=score_model.alphas_cumprod,
            time_step=time_step,
            prev_time_step=prev_time_step,
            eta=eta
        )
        epsilon_predicted = score_model.net(x, time_step)
        # Mean of the DDIM transition out of the current state.
        mean = coeff_sample * x + coeff_score * epsilon_predicted
        if prev_time_step != 0:
            # Move the path one step towards x0, conditioned on x0, then add fresh noise.
            x = (alphas_cumprod_t_1 ** .5) * x0 \
                + (1 - alphas_cumprod_t_1 - noise_std ** 2)**.5 \
                * (x - (alphas_cumprod_t ** .5) * x0) / ((1 - alphas_cumprod_t) ** .5)
            noise_sample = torch.randn_like(x)
            x += noise_std * noise_sample
            log_prob_ddim = - ((x - mean)**2).sum(dim_range) / (2 * noise_std**2)
            log_prob_fwd_ddim = - (noise_sample ** 2).sum(dim_range) / 2
            log_weights += log_prob_ddim - log_prob_fwd_ddim
        else:
            # Last step: the path ends at x0 itself, so only the DDIM term contributes.
            log_prob_ddim = - ((x0 - mean)**2).sum(dim_range) / (2 * noise_std**2)
            log_weights += log_prob_ddim
    return log_weights.logsumexp(0) - log(n_samples)
def ddim_trajectory(initial_noise_sample: torch.Tensor,
                    timesteps: List[int],
                    score_model: ScoreModel,
                    eta: float = 1) -> torch.Tensor:
    '''
    Run the (subsampled) DDIM generation from https://arxiv.org/pdf/2010.02502.pdf (eqs 9,10, 12)
    and keep every intermediate state.
    :param initial_noise_sample: Initial "noise"
    :param timesteps: List containing the timesteps. Should start by 999 and end by 0
    :param score_model: The score model
    :param eta: the parameter eta from https://arxiv.org/pdf/2010.02502.pdf (eq 16)
    :return: All visited states stacked along a new leading dimension; entry 0 is
        ``initial_noise_sample`` and there is one further entry per transition.
    '''
    sample = initial_noise_sample
    # Collect the states and stack once at the end; concatenating inside the
    # loop would copy the whole trajectory again at every step.
    trajectory = [sample]
    for prev_time_step, time_step in zip(timesteps[1:],
                                         timesteps[:-1]):
        mean, noise = ddim_parameters(x=sample,
                                      score_model=score_model,
                                      t=time_step,
                                      t_prev=prev_time_step,
                                      eta=eta)
        sample = mean + noise * torch.randn_like(mean)
        trajectory.append(sample)
    return torch.stack(trajectory, dim=0)
def generate_inpainting(anchor_left_top: torch.Tensor,
                        sizes: torch.Tensor,
                        original_shape: Tuple[int, int, int]):
    '''
    Build the selection matrices of a rectangular inpainting mask for one channel.

    :param anchor_left_top: First index of the hidden rectangle along each of the two image axes
    :param sizes: Extent of the hidden rectangle along each of the two image axes
    :param original_shape: (x, y, n_channels); only the first two entries are used
    :return: (rows of the identity for pixels whose mask is 1,
              rows of the identity for pixels whose mask is 0,
              the 2D mask itself, 1 = kept and 0 = hidden)
    '''
    n_pixels = original_shape[0] * original_shape[1]
    identity = torch.eye(n_pixels)

    # Start from an all-ones mask and zero out the hidden rectangle.
    mask = torch.ones(original_shape[:2])
    first_start, second_start = anchor_left_top[0], anchor_left_top[1]
    mask[first_start:first_start + sizes[0], second_start:second_start + sizes[1]] = 0

    flat_mask = mask.flatten()
    kept_rows = flat_mask == 1
    hidden_rows = flat_mask == 0
    return identity[kept_rows, :], identity[hidden_rows], mask
posterior_covariance_matrix @ (likelihood_A.T @ likelihood_precision @ (y - likelihood_bias) + prior_precision_matrix @ prior_loc) 70 | try: 71 | posterior_covariance_matrix = (posterior_covariance_matrix + posterior_covariance_matrix.T) / 2 72 | return MultivariateNormal(loc=posterior_mean, covariance_matrix=posterior_covariance_matrix) 73 | except ValueError: 74 | u, s, v = torch.linalg.svd(posterior_covariance_matrix, full_matrices=False) 75 | s = s.clip(1e-12, 1e6).real 76 | posterior_covariance_matrix = u.real @ torch.diag_embed(s) @ v.real 77 | posterior_covariance_matrix = (posterior_covariance_matrix + posterior_covariance_matrix.T) / 2 78 | return MultivariateNormal(loc=posterior_mean, covariance_matrix=posterior_covariance_matrix) 79 | 80 | 81 | def gaussian_posterior_batch(y, 82 | likelihood_A, 83 | likelihood_bias, 84 | likelihood_precision, 85 | prior_loc, 86 | prior_covar): 87 | prior_precision_matrix = torch.linalg.inv(prior_covar) 88 | posterior_precision_matrix = prior_precision_matrix + likelihood_A.T @ likelihood_precision @ likelihood_A 89 | posterior_covariance_matrix = torch.linalg.inv(posterior_precision_matrix) 90 | posterior_mean = (posterior_covariance_matrix @ (likelihood_A.T @ (likelihood_precision @ (y[None, ] - likelihood_bias).T) + (prior_precision_matrix @ prior_loc.T))).T 91 | try: 92 | posterior_covariance_matrix = (posterior_covariance_matrix + posterior_covariance_matrix.T) / 2 93 | return MultivariateNormal(loc=posterior_mean, covariance_matrix=posterior_covariance_matrix.unsqueeze(0).repeat(posterior_mean.shape[0], 1, 1)) 94 | except ValueError: 95 | u, s, v = torch.linalg.svd(posterior_covariance_matrix, full_matrices=False) 96 | s = s.clip(1e-6, 1e6).real 97 | posterior_covariance_matrix = u.real @ torch.diag_embed(s) @ v.real 98 | posterior_covariance_matrix = (posterior_covariance_matrix + posterior_covariance_matrix.T) / 2 99 | return MultivariateNormal(loc=posterior_mean, 
def get_optimal_timesteps_from_singular_values(alphas_cumprod, singular_value, n_timesteps, var, jump=1, mode='equal'):
    """Build a timestep schedule that contains the best-matching step of each singular value.

    For every distinct singular value ``s`` the step minimising
    ``|var * alphas_cumprod[t] - (1 - alphas_cumprod[t]) * s**2|`` is forced
    into the schedule. The remaining steps are chosen so that
    ``sqrt(alphas_cumprod)`` decreases by roughly equal amounts, and the tail
    of the schedule is filled linearly up to the last available step.

    :param alphas_cumprod: 1-D tensor of cumulative alpha products.
    :param singular_value: 1-D tensor of singular values (duplicates allowed).
    :param n_timesteps: requested schedule length.
    :param var: observation variance.
    :param jump: unused, kept for backward compatibility.
    :param mode: unused, kept for backward compatibility.
    :return: 1-D ``long`` tensor of timestep indices, starting at 0.
    """
    # One row per distinct singular value, one column per diffusion step.
    # (Calling torch.unique on the distance matrix itself would flatten and
    # sort it, so argmin would no longer be a step index.)
    distinct_singulars = torch.unique(singular_value)
    distances = (var * alphas_cumprod[None, :]
                 - (1 - alphas_cumprod)[None, :] * distinct_singulars[:, None] ** 2)
    optimal_steps = set(distances.abs().argmin(dim=-1).tolist())
    optimal_steps.discard(0)  # step 0 is always the first entry of the schedule

    sqrt_alphas = alphas_cumprod ** .5
    nearly_pure_noise = torch.where(sqrt_alphas < 1e-2)[0]
    # Fall back to the last step when the schedule never gets that noisy.
    end = nearly_pure_noise[0].item() if len(nearly_pure_noise) > 0 else len(alphas_cumprod) - 1

    # NOTE: if n_timesteps - 1 == len(optimal_steps) this is a division by
    # zero on a tensor, i.e. +inf: only the forced steps are then selected.
    target_increase = (sqrt_alphas[0] - sqrt_alphas[end]) / (n_timesteps - 1 - len(optimal_steps))

    timesteps = [0]
    last_value = sqrt_alphas[0]
    for i in range(1, end):
        if last_value - sqrt_alphas[i] >= target_increase or i in optimal_steps:
            timesteps.append(i)
            last_value = sqrt_alphas[i]

    # Spend whatever is left of the budget linearly up to the final step.
    tail = torch.linspace(timesteps[-1], len(alphas_cumprod) - 1, n_timesteps - len(timesteps) + 1)
    timesteps += torch.ceil(tail).tolist()[1:]
    return torch.tensor(timesteps).long()
https://raw.githubusercontent.com/gabrielvc/mcg_diff/498a8830998c3c84c7c6cb6bf78e7d79dc99bd62/scripts/__init__.py -------------------------------------------------------------------------------- /scripts/configs/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - dataset: celeb 4 | - diffusion: ddim_100 5 | - mcg_diff: default 6 | - task: super_resolution 7 | 8 | save_folder: /mnt/data/gabriel/mcg_diff/images 9 | seed: 10 10 | plot: true 11 | save_fig: false 12 | save_data: false -------------------------------------------------------------------------------- /scripts/configs/dataset/bedroom.yaml: -------------------------------------------------------------------------------- 1 | hf_model_tag: google/ddpm-ema-bedroom-256 2 | N_MAX_GPU_MCG_DIFF: 64 3 | N_MAX_GPU_DDRM: 64 4 | N_MAX_GPU_DPS: 16 5 | -------------------------------------------------------------------------------- /scripts/configs/dataset/cats.yaml: -------------------------------------------------------------------------------- 1 | hf_model_tag: samwit/ddpm-afhq-cats-128 2 | N_MAX_GPU_MCG_DIFF: 256 3 | N_MAX_GPU_DDRM: 256 4 | N_MAX_GPU_DPS: 50 -------------------------------------------------------------------------------- /scripts/configs/dataset/celeb.yaml: -------------------------------------------------------------------------------- 1 | hf_model_tag: google/ddpm-ema-celebahq-256 2 | N_MAX_GPU_MCG_DIFF: 64 3 | N_MAX_GPU_DDRM: 64 4 | N_MAX_GPU_DPS: 6 5 | -------------------------------------------------------------------------------- /scripts/configs/dataset/churches.yaml: -------------------------------------------------------------------------------- 1 | hf_model_tag: google/ddpm-ema-church-256 2 | N_MAX_GPU_MCG_DIFF: 64 3 | N_MAX_GPU_DDRM: 64 4 | N_MAX_GPU_DPS: 6 5 | -------------------------------------------------------------------------------- /scripts/configs/dataset/cifar_10.yaml: 
-------------------------------------------------------------------------------- 1 | hf_model_tag: google/ddpm-cifar10-32 2 | N_MAX_GPU_MCG_DIFF: 4096 3 | N_MAX_GPU_DDRM: 4096 4 | N_MAX_GPU_DPS: 400 5 | -------------------------------------------------------------------------------- /scripts/configs/dataset/flowers.yaml: -------------------------------------------------------------------------------- 1 | hf_model_tag: anton-l/ddpm-ema-flowers-64 2 | N_MAX_GPU_MCG_DIFF: 2000 3 | N_MAX_GPU_DDRM: 2000 4 | N_MAX_GPU_DPS: 200 5 | -------------------------------------------------------------------------------- /scripts/configs/dataset/mnist.yaml: -------------------------------------------------------------------------------- 1 | hf_model_tag: nabdan/mnist_20_epoch 2 | N_MAX_GPU_MCG_DIFF: 8912 3 | N_MAX_GPU_DDRM: 8912 4 | N_MAX_GPU_DPS: 800 5 | -------------------------------------------------------------------------------- /scripts/configs/diffusion/ddim_10.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gabrielvc/mcg_diff/498a8830998c3c84c7c6cb6bf78e7d79dc99bd62/scripts/configs/diffusion/ddim_10.yaml -------------------------------------------------------------------------------- /scripts/configs/diffusion/ddim_100.yaml: -------------------------------------------------------------------------------- 1 | n_steps: 100 2 | eta: 1 -------------------------------------------------------------------------------- /scripts/configs/diffusion/ddim_250.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gabrielvc/mcg_diff/498a8830998c3c84c7c6cb6bf78e7d79dc99bd62/scripts/configs/diffusion/ddim_250.yaml -------------------------------------------------------------------------------- /scripts/configs/mcg_diff/colorization.yaml: -------------------------------------------------------------------------------- 1 | N_total: 100 2 | N_particles: 
1024 3 | gaussian_var: 1e-4 -------------------------------------------------------------------------------- /scripts/configs/mcg_diff/default.yaml: -------------------------------------------------------------------------------- 1 | N_total: 4 2 | N_particles: 64 3 | gaussian_var: 1e-4 -------------------------------------------------------------------------------- /scripts/configs/mcg_diff/empty.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gabrielvc/mcg_diff/498a8830998c3c84c7c6cb6bf78e7d79dc99bd62/scripts/configs/mcg_diff/empty.yaml -------------------------------------------------------------------------------- /scripts/configs/mcg_diff/outpainting.yaml: -------------------------------------------------------------------------------- 1 | N_total: 100000 2 | N_particles: 10000 3 | gaussian_var: 1e-4 -------------------------------------------------------------------------------- /scripts/configs/task/colorization.yaml: -------------------------------------------------------------------------------- 1 | name: colorization 2 | sigma_y: 0 -------------------------------------------------------------------------------- /scripts/configs/task/deblur_2d.yaml: -------------------------------------------------------------------------------- 1 | name: deblur_2d 2 | sigma_y: 0.1 3 | kernel_size: 0.2 4 | kernel_std: 0.04 -------------------------------------------------------------------------------- /scripts/configs/task/inpainting.yaml: -------------------------------------------------------------------------------- 1 | name: inpainting 2 | center: [0.5, 0.5] 3 | width: 0.3 4 | height: 0.3 5 | sigma_y: 0 -------------------------------------------------------------------------------- /scripts/configs/task/motion_blur.yaml: -------------------------------------------------------------------------------- 1 | name: motion_blur 2 | sigma_y: 0.1 3 | kernel_size: 0.2 
def display_sample(sample):
    """Render an image tensor in a borderless matplotlib figure.

    :param sample: tensor that is permuted with ``(1, 2, 0)`` before display
        and rescaled with ``(x + 1) * 127.5`` — presumably a channel-first
        image with values in [-1, 1]; confirm against callers.
    :return: the matplotlib figure holding the image.
    """
    image_processed = sample.cpu().permute(1, 2, 0)
    # Clip before the uint8 cast: values outside [-1, 1] (e.g. a noisy
    # measurement) would otherwise wrap around instead of saturating.
    image_processed = ((image_processed + 1.0) * 127.5).clip(0, 255)
    image_processed = image_processed.numpy().astype(np.uint8)

    image_pil = PIL.Image.fromarray(image_processed)
    fig, ax = plt.subplots(1, 1, figsize=(6, 6))
    ax.imshow(image_pil)
    fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
    return fig
def _relative_box_to_indices(center, width, height, size_first, size_second):
    """Convert a box given in relative coordinates into index ranges along both mask axes."""
    range_first = (math.floor((center[0] - width / 2) * size_first),
                   math.ceil((center[0] + width / 2) * size_first))
    # The upper bound uses ``height``; it used to be computed from ``width``,
    # which gave a wrong box whenever width != height.
    range_second = (math.floor((center[1] - height / 2) * size_second),
                    math.ceil((center[1] + height / 2) * size_second))
    return range_first, range_second


def _pad_mask_with_zeros(coordinates_mask, total_dim):
    """Append integer zeros to ``coordinates_mask`` so that it has ``total_dim`` entries."""
    padding = torch.zeros(total_dim - len(coordinates_mask),
                          dtype=torch.long,
                          device=coordinates_mask.device)
    return torch.cat((coordinates_mask, padding))


def _load_box_masking_operator(task_cfg, D_OR, x_origin, sigma_y, hide_box):
    """Build the Inpainting operator and observation shared by inpainting and outpainting.

    ``hide_box=True`` removes the pixels inside the configured box
    (inpainting); ``hide_box=False`` removes everything outside it
    (outpainting).
    """
    range_first, range_second = _relative_box_to_indices(task_cfg.center,
                                                         task_cfg.width,
                                                         task_cfg.height,
                                                         D_OR[1],
                                                         D_OR[2])
    # ``mask`` is non-zero exactly on the pixels that are missing.
    if hide_box:
        mask = torch.zeros(*D_OR[1:])
        mask[range_first[0]:range_first[1], range_second[0]:range_second[1]] = 1
    else:
        mask = torch.ones(*D_OR[1:])
        mask[range_first[0]:range_first[1], range_second[0]:range_second[1]] = 0

    # Three consecutive flat indices per missing pixel (assumes 3 channels).
    missing_r = torch.nonzero(mask.flatten()).long().reshape(-1) * 3
    missing_g = missing_r + 1
    missing_b = missing_g + 1
    missing = torch.cat([missing_r, missing_g, missing_b], dim=0)

    H_funcs = Inpainting(channels=D_OR[0], img_dim=D_OR[1], missing_indices=missing, device=x_origin.device)
    y_0_origin = H_funcs.H(x_origin[None, ...])
    y_0 = (y_0_origin + sigma_y * torch.randn_like(y_0_origin)).clip(-1., 1.)

    # Image used for display: observed values first, -1 everywhere else,
    # then mapped back to pixel order with V.
    y_0_img = -torch.ones(math.prod(D_OR), device=y_0.device)
    y_0_img[:y_0.shape[-1]] = y_0[0]
    y_0_img = H_funcs.V(y_0_img[None, ...])
    y_0_img = y_0_img.reshape(*D_OR)

    U_t_y_0 = H_funcs.Ut(y_0[None, ...]).flatten().cpu()
    diag = H_funcs.singulars()
    # The first ``n_kept`` coordinates of the SVD basis are the observed ones.
    n_kept = H_funcs.kept_indices.shape[0]
    coordinates_mask = torch.arange(math.prod(D_OR), device=H_funcs.kept_indices.device) < n_kept
    D_OBS = (math.prod(D_OR) - len(missing),)
    return H_funcs, y_0, y_0_origin, y_0_img, U_t_y_0, diag, coordinates_mask, D_OBS


def load_operator(task_cfg, D_OR, x_origin):
    """Instantiate the degradation operator of a task and simulate its observation.

    :param task_cfg: task configuration; ``name`` selects the operator and
        ``sigma_y`` is the standard deviation of the observation noise. Other
        fields (kernel_size, kernel_std, ratio, center, width, height) are
        read depending on the task.
    :param D_OR: shape of the original image, indexed as (channels, d1, d2).
    :param x_origin: original image tensor.
    :return: ``(H_funcs, y_0, y_0_origin, y_0_img, U_t_y_0, diag,
        coordinates_mask, D_OBS)``: the operator, the noisy and noiseless
        observations, an image-shaped version of the observation for
        plotting, the observation multiplied by ``U^T``, the singular values,
        the mask of observed coordinates and the observation shape.
    :raises NotImplementedError: if ``task_cfg.name`` is not a supported task.
    """
    sigma_y = task_cfg.sigma_y
    if task_cfg.name == 'deblur_2d':
        kernel_size = math.ceil(D_OR[2] * task_cfg.kernel_size) * (3 // D_OR[0])
        sigma = math.ceil(D_OR[2] * task_cfg.kernel_std)
        pdf = lambda x: torch.exp(-0.5 * (x / sigma) ** 2)
        kernel1 = pdf(torch.arange(-kernel_size, kernel_size + 1)).cuda()
        kernel2 = pdf(torch.arange(-kernel_size, kernel_size + 1)).cuda()
        kernel1 = kernel1 / kernel1.sum()
        kernel2 = kernel2 / kernel2.sum()

        H_funcs = Deblurring2D(kernel1,
                               kernel2,
                               D_OR[0],
                               D_OR[1], 0)

        y_0_origin = H_funcs.H(x_origin[None, ...])
        y_0_origin = y_0_origin.reshape(*D_OR)
        y_0 = y_0_origin + sigma_y * torch.randn_like(y_0_origin)
        y_0_img = y_0
        diag = H_funcs.singulars()
        coordinates_mask = diag != 0
        U_t_y_0 = H_funcs.Ut(y_0[None, ...]).flatten()[coordinates_mask].cpu()
        diag = diag[coordinates_mask].cpu()
        D_OBS = D_OR

    elif task_cfg.name == 'super_resolution':
        ratio = task_cfg.ratio
        H_funcs = SuperResolution(channels=D_OR[0], img_dim=D_OR[2], ratio=ratio, device='cuda:0')
        D_OBS = (D_OR[0], int(D_OR[1] / ratio), int(D_OR[2] / ratio))
        y_0_origin = H_funcs.H(x_origin[None, ...])
        y_0_origin = y_0_origin.reshape(*D_OBS)
        y_0 = (y_0_origin + sigma_y * torch.randn_like(y_0_origin)).clip(-1., 1.)
        y_0_img = y_0

        U_t_y_0 = H_funcs.Ut(y_0[None, ...]).flatten().cpu()
        diag = H_funcs.singulars()
        coordinates_mask = _pad_mask_with_zeros(diag != 0, math.prod(D_OR))

    elif task_cfg.name == 'outpainting':
        return _load_box_masking_operator(task_cfg, D_OR, x_origin, sigma_y, hide_box=False)

    elif task_cfg.name == 'inpainting':
        return _load_box_masking_operator(task_cfg, D_OR, x_origin, sigma_y, hide_box=True)

    elif task_cfg.name == 'colorization':
        H_funcs = Colorization(D_OR[1], x_origin.device)

        y_0_origin = H_funcs.H(x_origin[None, ...])
        y_0 = y_0_origin + sigma_y * torch.randn_like(y_0_origin)
        # The displayed image is built from the noiseless observation.
        y_0_img = H_funcs.H_pinv(y_0_origin).reshape(D_OR)
        diag = H_funcs.singulars()
        coordinates_mask = diag != 0
        U_t_y_0 = H_funcs.Ut(y_0[None, ...]).flatten()[coordinates_mask].cpu()
        diag = diag[coordinates_mask].cpu()
        coordinates_mask = _pad_mask_with_zeros(coordinates_mask, math.prod(D_OR))
        D_OBS = (y_0.shape[-1],)
    else:
        raise NotImplementedError(f"Unknown task: {task_cfg.name}")

    return H_funcs, y_0, y_0_origin, y_0_img, U_t_y_0, diag, coordinates_mask, D_OBS
def _show_save_close(fig, cfg, path):
    """Show and/or save ``fig`` according to ``cfg.plot`` / ``cfg.save_fig``, then close it."""
    if cfg.plot:
        fig.show()
    if cfg.save_fig:
        fig.savefig(path)
    plt.close(fig)


@hydra.main(version_base=None, config_path="configs/", config_name="config")
def main(cfg: DictConfig) -> None:
    """Run MCG-DIFF on one sample drawn from a pretrained Hugging Face DDPM.

    Draws a reference image from the pipeline, degrades it with the configured
    task operator, samples from the posterior with MCG-DIFF, and optionally
    plots/saves the reference, the measurement and the two samples that are
    furthest apart.

    :param cfg: Hydra configuration (see ``configs/``).
    """
    print(OmegaConf.to_yaml(cfg))
    OmegaConf.set_struct(cfg, False)

    # <save_folder>/<task>/<model tag>/<seed>/{images,data}
    run_folder = os.path.join(cfg.save_folder,
                              cfg.task.name,
                              cfg.dataset.hf_model_tag.replace('-', '_').replace('/', '_'),
                              str(cfg.seed))
    full_path_images = os.path.join(run_folder, 'images')
    full_path_data = os.path.join(run_folder, 'data')
    Path(full_path_images).mkdir(parents=True, exist_ok=True)
    Path(full_path_data).mkdir(parents=True, exist_ok=True)
    torch.manual_seed(cfg.seed)

    # Loading HF model
    pipeline, x_origin, D_OR, D_FLAT = load_hf_model(cfg.dataset)
    _show_save_close(plot(x_origin), cfg, f'{full_path_images}/sample.pdf')

    H_funcs, y_0, y_0_origin, y_0_img, U_t_y_0, diag, coordinates_mask, D_OBS = load_operator(task_cfg=cfg.task,
                                                                                              D_OR=D_OR,
                                                                                              x_origin=x_origin)
    _show_save_close(plot(y_0_img), cfg, f'{full_path_images}/measure.pdf')

    # Diffusion stuff
    alphas_cumprod = pipeline.scheduler.alphas_cumprod.cuda().clip(1e-6, 1)
    # Span the whole training schedule instead of assuming 1000 steps.
    timesteps = torch.linspace(0, len(alphas_cumprod) - 1, cfg.diffusion.n_steps).long().cuda()
    eta = cfg.diffusion.eta

    model = pipeline.unet
    model = model.requires_grad_(False)
    model = model.eval()

    # MCG_DIFF
    score_model = ScoreModel(net=torch.nn.DataParallel(EpsilonNetSVD(H_funcs, model, dim=D_OR).requires_grad_(False)),
                             alphas_cumprod=alphas_cumprod,
                             device='cuda:0')
    particles_mcg_diff = run_mcg_diff(
        mcg_diff_config=cfg.mcg_diff,
        n_max_gpu=cfg.dataset.N_MAX_GPU_MCG_DIFF,
        dim=D_FLAT,
        U_t_y_0=U_t_y_0,
        diag=diag,
        coordinates_mask=coordinates_mask == 1,
        sigma_y=cfg.task.sigma_y,
        timesteps=timesteps,
        eta=eta,
        H_funcs=H_funcs,
        score_model=score_model,
    )
    particles_mcg_diff = particles_mcg_diff.reshape(-1, *D_OR)

    # Display the two samples that are the furthest apart.
    furthest = find_furthest_particles_in_clound(particles_mcg_diff)
    for i, particle in enumerate(particles_mcg_diff[furthest]):
        _show_save_close(plot(particle), cfg, f'{full_path_images}/furthest_{i}_mcg_diff.pdf')

    if cfg.save_data:
        np.save(file=f'{full_path_data}/particles_mcg_diff.npy',
                arr=particles_mcg_diff.cpu().numpy())
        np.save(file=f'{full_path_data}/noisy_obs.npy', arr=y_0.cpu().numpy())
        np.save(file=f'{full_path_data}/sample.npy', arr=x_origin.cpu().numpy())
        np.save(file=f'{full_path_data}/noiseless_obs.npy', arr=y_0_origin.cpu().numpy())
class H_functions(torch.nn.Module):
    """
    A class replacing the SVD of a matrix H, perhaps efficiently.
    All input vectors are of shape (Batch, ...).
    All output vectors are of shape (Batch, DataDimension).

    Subclasses implement ``V``, ``Vt``, ``U``, ``Ut``, ``singulars`` and
    ``add_zeros``; ``H``, ``Ht`` and ``H_pinv`` are derived from them.
    """

    def __init__(self):
        super().__init__()

    def V(self, vec):
        """
        Multiplies the input vector by V
        """
        raise NotImplementedError()

    def Vt(self, vec, for_H=True):
        """
        Multiplies the input vector by V transposed
        """
        raise NotImplementedError()

    def U(self, vec):
        """
        Multiplies the input vector by U
        """
        raise NotImplementedError()

    def Ut(self, vec):
        """
        Multiplies the input vector by U transposed
        """
        raise NotImplementedError()

    def singulars(self):
        """
        Returns a vector containing the singular values. The shape of the vector should be the same as the smaller dimension (like U)
        """
        raise NotImplementedError()

    def add_zeros(self, vec):
        """
        Adds trailing zeros to turn a vector from the small dimension (U) to the big dimension (V)
        """
        raise NotImplementedError()

    def H(self, vec):
        """
        Multiplies the input vector by H, i.e. computes U diag(s) V^T vec
        """
        temp = self.Vt(vec)
        singulars = self.singulars()
        return self.U(singulars * temp[:, :singulars.shape[0]])

    def Ht(self, vec):
        """
        Multiplies the input vector by H transposed, i.e. computes V diag(s) U^T vec
        """
        temp = self.Ut(vec)
        singulars = self.singulars()
        return self.V(self.add_zeros(singulars * temp[:, :singulars.shape[0]]))

    def H_pinv(self, vec):
        """
        Multiplies the input vector by the pseudo inverse of H

        Coordinates whose singular value is exactly zero are mapped to zero
        instead of being divided by zero, which would yield inf/NaN.
        """
        temp = self.Ut(vec)
        singulars = self.singulars()
        n_singulars = singulars.shape[0]
        observed = temp[:, :n_singulars]
        temp[:, :n_singulars] = torch.where(singulars != 0,
                                            observed / singulars,
                                            torch.zeros_like(observed))
        return self.V(self.add_zeros(temp))
# Inpainting
class Inpainting(H_functions):
    """SVD view of a masking operator that drops a set of flat coordinates.

    In the SVD basis the kept coordinates come first (in increasing index
    order), followed by the missing ones; every singular value is 1.
    """

    def __init__(self, channels, img_dim, missing_indices, device):
        """
        :param channels: number of image channels.
        :param img_dim: side length of the (square) image.
        :param missing_indices: 1-D integer tensor of dropped flat coordinates.
        :param device: device on which the singular values and kept indices live.
        """
        super().__init__()
        self.channels = channels
        self.img_dim = img_dim
        n_coordinates = channels * img_dim ** 2
        self._singulars = torch.nn.Parameter(torch.ones(n_coordinates - missing_indices.shape[0]).to(device),
                                             requires_grad=False)
        self.missing_indices = torch.nn.Parameter(missing_indices, requires_grad=False)

        # Boolean-mask complement instead of a Python loop with
        # ``i not in tensor`` (quadratic) and a float32 round trip.
        flat_missing = missing_indices.detach().reshape(-1).long()
        # Indices outside the image never matched anything before; ignore them.
        flat_missing = flat_missing[(flat_missing >= 0) & (flat_missing < n_coordinates)]
        is_kept = torch.ones(n_coordinates, dtype=torch.bool, device=flat_missing.device)
        is_kept[flat_missing] = False
        self.kept_indices = torch.nn.Parameter(torch.nonzero(is_kept).reshape(-1).to(device),
                                               requires_grad=False)

    def V(self, vec):
        """Scatter SVD-basis coordinates back to their flat positions, then swap the last two axes of the (-1, channels) layout."""
        temp = vec.clone().reshape(vec.shape[0], -1)
        out = torch.zeros_like(temp)
        out[:, self.kept_indices] = temp[:, :self.kept_indices.shape[0]]
        out[:, self.missing_indices] = temp[:, self.kept_indices.shape[0]:]
        return out.reshape(vec.shape[0], -1, self.channels).permute(0, 2, 1).reshape(vec.shape[0], -1)

    def Vt(self, vec, for_H=True):
        """Inverse of ``V``: gather kept coordinates first, missing ones last."""
        temp = vec.clone().reshape(vec.shape[0], self.channels, -1).permute(0, 2, 1).reshape(vec.shape[0], -1)
        out = torch.zeros_like(temp)
        out[:, :self.kept_indices.shape[0]] = temp[:, self.kept_indices]
        out[:, self.kept_indices.shape[0]:] = temp[:, self.missing_indices]
        return out

    def U(self, vec):
        """U is the identity: return a flattened copy."""
        return vec.clone().reshape(vec.shape[0], -1)

    def Ut(self, vec):
        """U^T is the identity: return a flattened copy."""
        return vec.clone().reshape(vec.shape[0], -1)

    def singulars(self):
        """Return the singular values (all ones, one per kept coordinate)."""
        return self._singulars

    def add_zeros(self, vec):
        """Pad a flattened vector with trailing zeros up to ``channels * img_dim ** 2`` entries."""
        temp = torch.zeros((vec.shape[0], self.channels * self.img_dim ** 2), device=vec.device)
        reshaped = vec.clone().reshape(vec.shape[0], -1)
        temp[:, :reshaped.shape[1]] = reshaped
        return temp
:self.channels * self.y_dim ** 2].view(vec.shape[0], self.channels, -1) 203 | for idx in range(self.ratio ** 2 - 1): 204 | patches[:, :, :, idx + 1] = temp[:, (self.channels * self.y_dim ** 2 + idx)::self.ratio ** 2 - 1].view( 205 | vec.shape[0], self.channels, -1) 206 | # multiply each patch by the small V 207 | patches = torch.matmul(self.V_small, patches.reshape(-1, self.ratio ** 2, 1)).reshape(vec.shape[0], 208 | self.channels, -1, 209 | self.ratio ** 2) 210 | # repatch the patches into an image 211 | patches_orig = patches.reshape(vec.shape[0], self.channels, self.y_dim, self.y_dim, self.ratio, self.ratio) 212 | recon = patches_orig.permute(0, 1, 2, 4, 3, 5).contiguous() 213 | recon = recon.reshape(vec.shape[0], self.channels * self.img_dim ** 2) 214 | return recon 215 | 216 | def Vt(self, vec, for_H=True): 217 | # extract flattened patches 218 | patches = vec.clone().reshape(vec.shape[0], self.channels, self.img_dim, self.img_dim) 219 | patches = patches.unfold(2, self.ratio, self.ratio).unfold(3, self.ratio, self.ratio) 220 | unfold_shape = patches.shape 221 | patches = patches.contiguous().reshape(vec.shape[0], self.channels, -1, self.ratio ** 2) 222 | # multiply each by the small V transposed 223 | patches = torch.matmul(self.Vt_small, patches.reshape(-1, self.ratio ** 2, 1)).reshape(vec.shape[0], 224 | self.channels, -1, 225 | self.ratio ** 2) 226 | # reorder the vector to have the first entry first (because singulars are ordered descendingly) 227 | recon = torch.zeros(vec.shape[0], self.channels * self.img_dim ** 2, device=vec.device) 228 | recon[:, :self.channels * self.y_dim ** 2] = patches[:, :, :, 0].view(vec.shape[0], 229 | self.channels * self.y_dim ** 2) 230 | for idx in range(self.ratio ** 2 - 1): 231 | recon[:, (self.channels * self.y_dim ** 2 + idx)::self.ratio ** 2 - 1] = patches[:, :, :, idx + 1].view( 232 | vec.shape[0], self.channels * self.y_dim ** 2) 233 | return recon 234 | 235 | def U(self, vec): 236 | return self.U_small[0, 0] * 
vec.clone().reshape(vec.shape[0], -1) 237 | 238 | def Ut(self, vec): # U is 1x1, so U^T = U 239 | return self.U_small[0, 0] * vec.clone().reshape(vec.shape[0], -1) 240 | 241 | def singulars(self): 242 | return self.singulars_small.repeat(self.channels * self.y_dim ** 2) 243 | 244 | def add_zeros(self, vec): 245 | reshaped = vec.clone().reshape(vec.shape[0], -1) 246 | temp = torch.zeros((vec.shape[0], reshaped.shape[1] * self.ratio ** 2), device=vec.device) 247 | temp[:, :reshaped.shape[1]] = reshaped 248 | return temp 249 | 250 | 251 | # Colorization 252 | class Colorization(H_functions): 253 | def __init__(self, img_dim, device): 254 | super(Colorization, self).__init__() 255 | self.channels = 3 256 | self.img_dim = img_dim 257 | # Do the SVD for the per-pixel matrix 258 | H = torch.nn.Parameter(torch.Tensor([[0.3333, 0.3333, 0.3333]]), requires_grad=False).to(device) 259 | self.U_small, self.singulars_small, self.V_small = torch.svd(H, some=False) 260 | self.Vt_small = self.V_small.transpose(0, 1) 261 | self.Vt_small = torch.nn.Parameter(self.Vt_small, requires_grad=False) 262 | self.V_small = torch.nn.Parameter(self.V_small, requires_grad=False) 263 | self.singulars_small = torch.nn.Parameter(self.singulars_small, requires_grad=False) 264 | self.U_small = torch.nn.Parameter(self.U_small, requires_grad=False) 265 | 266 | def V(self, vec): 267 | # get the needles 268 | needles = vec.clone().reshape(vec.shape[0], self.channels, -1).permute(0, 2, 1) # shape: B, WH, C' 269 | # multiply each needle by the small V 270 | needles = torch.matmul(self.V_small, needles.reshape(-1, self.channels, 1)).reshape(vec.shape[0], -1, 271 | self.channels) # shape: B, WH, C 272 | # permute back to vector representation 273 | recon = needles.permute(0, 2, 1) # shape: B, C, WH 274 | return recon.reshape(vec.shape[0], -1) 275 | 276 | def Vt(self, vec, for_H=True): 277 | # get the needles 278 | needles = vec.clone().reshape(vec.shape[0], self.channels, -1).permute(0, 2, 1) # shape: B, WH, 
# Walsh-Hadamard Compressive Sensing
class WalshHadamardCS(H_functions):
    """SVD-style operator built on a per-channel fast Walsh-Hadamard transform.

    ``perm`` indexes the ``img_dim ** 2`` transform coefficients of each
    channel (presumably a permutation of ``range(img_dim ** 2)`` -- verify
    against the caller). ``singulars()`` is a vector of
    ``channels * img_dim ** 2 // ratio`` ones.
    """

    def fwht(self, vec):  # the Fast Walsh Hadamard Transform is the same as its inverse
        """Apply the butterfly transform along the flattened pixel axis of each channel.

        Returns a new ``(batch, channels, img_dim ** 2)`` tensor divided by
        ``img_dim``. ``vec`` itself is left unmodified: the butterfly results
        are written into a clone, never into a view of the input.
        """
        a = vec.reshape(vec.shape[0], self.channels, self.img_dim ** 2)
        h = 1
        while h < self.img_dim ** 2:
            a = a.reshape(vec.shape[0], self.channels, -1, h * 2)
            # Write into the copy and read from the source, so the caller's
            # storage (which `a` may alias on the first pass) is never touched.
            b = a.clone()
            b[:, :, :, :h] = a[:, :, :, :h] + a[:, :, :, h:2 * h]
            b[:, :, :, h:2 * h] = a[:, :, :, :h] - a[:, :, :, h:2 * h]
            a = b
            h *= 2
        a = a.reshape(vec.shape[0], self.channels, self.img_dim ** 2) / self.img_dim
        return a

    def __init__(self, channels, img_dim, ratio, perm, device):
        # Initialise the base class, as SuperResolution, Colorization and
        # Deblurring2D in this file already do.
        super(WalshHadamardCS, self).__init__()
        self.channels = channels
        self.img_dim = img_dim
        self.ratio = ratio
        self.perm = perm
        self._singulars = torch.ones(channels * img_dim ** 2 // ratio, device=device)

    def V(self, vec):
        """Scatter coefficients to the positions given by ``perm`` and transform them."""
        temp = torch.zeros(vec.shape[0], self.channels, self.img_dim ** 2, device=vec.device)
        # vec is read as (batch, coefficient, channel) and moved to (batch, channel, coefficient).
        temp[:, :, self.perm] = vec.clone().reshape(vec.shape[0], -1, self.channels).permute(0, 2, 1)
        return self.fwht(temp).reshape(vec.shape[0], -1)

    def Vt(self, vec, for_H=True):
        """Transform ``vec`` and gather coefficients in ``perm`` order (``for_H`` is unused)."""
        # No defensive clone is needed: fwht does not modify its argument.
        return self.fwht(vec)[:, :, self.perm].permute(0, 2, 1).reshape(vec.shape[0], -1)

    def U(self, vec):
        """Return a flattened ``(batch, -1)`` copy of ``vec``."""
        return vec.clone().reshape(vec.shape[0], -1)

    def Ut(self, vec):
        """Return a flattened ``(batch, -1)`` copy of ``vec``."""
        return vec.clone().reshape(vec.shape[0], -1)

    def singulars(self):
        """Return the vector of ones built in ``__init__``."""
        return self._singulars

    def add_zeros(self, vec):
        """Zero-pad the flattened ``vec`` to ``channels * img_dim ** 2`` entries per batch item."""
        out = torch.zeros(vec.shape[0], self.channels * self.img_dim ** 2, device=vec.device)
        out[:, :self.channels * self.img_dim ** 2 // self.ratio] = vec.clone().reshape(vec.shape[0], -1)
        return out
self.singulars_small, self.V_small = torch.svd(H_small, some=False) 374 | ZERO = 3e-2 375 | self.singulars_small[self.singulars_small < ZERO] = 0 376 | # calculate the singular values of the big matrix 377 | self._singulars = torch.matmul(self.singulars_small.reshape(small_dim, 1), 378 | self.singulars_small.reshape(1, small_dim)).reshape(small_dim ** 2) 379 | # permutation for matching the singular values. See P_1 in Appendix D.5. 380 | self._perm = torch.Tensor([self.img_dim * i + j for i in range(self.small_dim) for j in range(self.small_dim)] + \ 381 | [self.img_dim * i + j for i in range(self.small_dim) for j in 382 | range(self.small_dim, self.img_dim)]).to(device).long() 383 | 384 | def V(self, vec): 385 | # invert the permutation 386 | temp = torch.zeros(vec.shape[0], self.img_dim ** 2, self.channels, device=vec.device) 387 | temp[:, self._perm, :] = vec.clone().reshape(vec.shape[0], self.img_dim ** 2, self.channels)[:, 388 | :self._perm.shape[0], :] 389 | temp[:, self._perm.shape[0]:, :] = vec.clone().reshape(vec.shape[0], self.img_dim ** 2, self.channels)[:, 390 | self._perm.shape[0]:, :] 391 | temp = temp.permute(0, 2, 1) 392 | # multiply the image by V from the left and by V^T from the right 393 | out = self.mat_by_img(self.V_small, temp, self.img_dim) 394 | out = self.img_by_mat(out, self.V_small.transpose(0, 1), self.img_dim).reshape(vec.shape[0], -1) 395 | return out 396 | 397 | def Vt(self, vec, for_H=True): 398 | # multiply the image by V^T from the left and by V from the right 399 | temp = self.mat_by_img(self.V_small.transpose(0, 1), vec.clone(), self.img_dim) 400 | temp = self.img_by_mat(temp, self.V_small, self.img_dim).reshape(vec.shape[0], self.channels, -1) 401 | # permute the entries 402 | temp[:, :, :self._perm.shape[0]] = temp[:, :, self._perm] 403 | temp = temp.permute(0, 2, 1) 404 | return temp.reshape(vec.shape[0], -1) 405 | 406 | def U(self, vec): 407 | # invert the permutation 408 | temp = torch.zeros(vec.shape[0], self.small_dim ** 
def _banded_conv_matrix(kernel, img_dim, device):
    """Return the ``(img_dim, img_dim)`` matrix that applies ``kernel`` as a 1D filter.

    Entry ``[i, j]`` is ``kernel[j - i + half]`` for ``i - half <= j < i + half``
    with ``half = kernel.shape[0] // 2``. Columns outside the image are
    skipped, so out-of-range taps contribute nothing.
    """
    half = kernel.shape[0] // 2
    mat = torch.zeros(img_dim, img_dim, device=device)
    for i in range(img_dim):
        # NOTE(review): the window is half-open, so for an odd-length kernel the
        # last tap (index 2 * half) is never written. Kept unchanged on purpose.
        for j in range(max(i - half, 0), min(i + half, img_dim)):
            mat[i, j] = kernel[j - i + half]
    return mat


# Deblurring
class Deblurring(H_functions):
    """Separable blur: the same 1D kernel matrix is applied to rows and columns.

    The SVD of the 1D matrix is computed once; the singular values of the 2D
    operator are the pairwise products of the 1D ones, sorted in descending
    order, and ``_perm`` stores the sorting permutation.
    """

    def mat_by_img(self, M, v):
        """Left-multiply every ``img_dim x img_dim`` channel image in ``v`` by ``M``."""
        return torch.matmul(M, v.reshape(v.shape[0] * self.channels, self.img_dim,
                                         self.img_dim)).reshape(v.shape[0], self.channels, M.shape[0], self.img_dim)

    def img_by_mat(self, v, M):
        """Right-multiply every ``img_dim x img_dim`` channel image in ``v`` by ``M``."""
        return torch.matmul(v.reshape(v.shape[0] * self.channels, self.img_dim,
                                      self.img_dim), M).reshape(v.shape[0], self.channels, self.img_dim, M.shape[1])

    def __init__(self, kernel, channels, img_dim, device, ZERO=3e-2):
        """Build the operator.

        ``kernel`` is indexed along its first dimension; 1D singular values
        below ``ZERO`` are set to exactly 0.
        """
        super(Deblurring, self).__init__()
        self.img_dim = img_dim
        self.channels = channels
        # build 1D conv matrix
        H_small = _banded_conv_matrix(kernel, img_dim, device)
        # get the svd of the 1D conv
        U_small, singulars_small, V_small = torch.svd(H_small, some=False)
        # Threshold while this is still a plain tensor: in-place writes to a
        # Parameter that requires grad raise a RuntimeError.
        singulars_small[singulars_small < ZERO] = 0
        # calculate the singular values of the big matrix
        singulars = torch.matmul(singulars_small.reshape(img_dim, 1),
                                 singulars_small.reshape(1, img_dim)).reshape(img_dim ** 2)
        # sort the big matrix singulars and save the permutation
        singulars, perm = singulars.sort(descending=True)  # , stable=True)
        # Nothing here is trained, and `perm` is an integer tensor (which cannot
        # require grad), so every Parameter is created with requires_grad=False.
        self.U_small = torch.nn.Parameter(U_small, requires_grad=False)
        self.singulars_small = torch.nn.Parameter(singulars_small, requires_grad=False)
        self.V_small = torch.nn.Parameter(V_small, requires_grad=False)
        self._singulars = torch.nn.Parameter(singulars, requires_grad=False)
        self._perm = torch.nn.Parameter(perm, requires_grad=False)

    def V(self, vec):
        """Undo the singular-value ordering, then compute ``V X V^T`` per channel."""
        # invert the permutation
        temp = torch.zeros(vec.shape[0], self.img_dim ** 2, self.channels, device=vec.device)
        temp[:, self._perm, :] = vec.clone().reshape(vec.shape[0], self.img_dim ** 2, self.channels)
        temp = temp.permute(0, 2, 1)
        # multiply the image by V from the left and by V^T from the right
        out = self.mat_by_img(self.V_small, temp)
        out = self.img_by_mat(out, self.V_small.transpose(0, 1)).reshape(vec.shape[0], -1)
        return out

    def Vt(self, vec, for_H=True):
        """Compute ``V^T X V`` per channel, then order entries by singular value (``for_H`` is unused)."""
        # multiply the image by V^T from the left and by V from the right
        temp = self.mat_by_img(self.V_small.transpose(0, 1), vec.clone())
        temp = self.img_by_mat(temp, self.V_small).reshape(vec.shape[0], self.channels, -1)
        # permute the entries according to the singular values
        temp = temp[:, :, self._perm].permute(0, 2, 1)
        return temp.reshape(vec.shape[0], -1)

    def U(self, vec):
        """Undo the singular-value ordering, then compute ``U X U^T`` per channel."""
        # invert the permutation
        temp = torch.zeros(vec.shape[0], self.img_dim ** 2, self.channels, device=vec.device)
        temp[:, self._perm, :] = vec.clone().reshape(vec.shape[0], self.img_dim ** 2, self.channels)
        temp = temp.permute(0, 2, 1)
        # multiply the image by U from the left and by U^T from the right
        out = self.mat_by_img(self.U_small, temp)
        out = self.img_by_mat(out, self.U_small.transpose(0, 1)).reshape(vec.shape[0], -1)
        return out

    def Ut(self, vec):
        """Compute ``U^T X U`` per channel, then order entries by singular value."""
        # multiply the image by U^T from the left and by U from the right
        temp = self.mat_by_img(self.U_small.transpose(0, 1), vec.clone())
        temp = self.img_by_mat(temp, self.U_small).reshape(vec.shape[0], self.channels, -1)
        # permute the entries according to the singular values
        temp = temp[:, :, self._perm].permute(0, 2, 1)
        return temp.reshape(vec.shape[0], -1)

    def singulars(self):
        """Return the sorted singular values tiled ``channels`` times."""
        # NOTE(review): Vt/U lay entries out as (pixel, channel), which would call
        # for repeat_interleave(channels) as SRConv.singulars does; this tiles the
        # whole vector instead. Layout kept as-is -- confirm against the base class.
        return self._singulars.repeat(1, self.channels).reshape(-1)

    def add_zeros(self, vec):
        """Return a flattened copy of ``vec``; no padding is added."""
        return vec.clone().reshape(vec.shape[0], -1)


# Anisotropic Deblurring
class Deblurring2D(H_functions):
    """Separable blur with different 1D kernels: ``kernel1`` on the left, ``kernel2`` on the right.

    Attributes suffixed ``1`` / ``2`` hold the SVD factors of the two 1D
    matrices. The singular values of the 2D operator are their pairwise
    products, sorted in descending order, with ``_perm`` the sorting permutation.
    """

    def mat_by_img(self, M, v):
        """Left-multiply every ``img_dim x img_dim`` channel image in ``v`` by ``M``."""
        return torch.matmul(M, v.reshape(v.shape[0] * self.channels, self.img_dim,
                                         self.img_dim)).reshape(v.shape[0], self.channels, M.shape[0], self.img_dim)

    def img_by_mat(self, v, M):
        """Right-multiply every ``img_dim x img_dim`` channel image in ``v`` by ``M``."""
        return torch.matmul(v.reshape(v.shape[0] * self.channels, self.img_dim,
                                      self.img_dim), M).reshape(v.shape[0], self.channels, self.img_dim, M.shape[1])

    def __init__(self, kernel1, kernel2, channels, img_dim, device, ZERO=3e-2):
        """Build the operator; 1D singular values below ``ZERO`` are set to exactly 0."""
        super(Deblurring2D, self).__init__()
        self.img_dim = img_dim
        self.channels = channels
        # build 1D conv matrices
        H_small1 = _banded_conv_matrix(kernel1, img_dim, device)
        H_small2 = _banded_conv_matrix(kernel2, img_dim, device)
        # get the svd of the 1D conv
        U_small1, singulars_small1, V_small1 = torch.svd(H_small1, some=False)
        U_small2, singulars_small2, V_small2 = torch.svd(H_small2, some=False)
        singulars_small1[singulars_small1 < ZERO] = 0
        singulars_small2[singulars_small2 < ZERO] = 0

        self.U_small1 = torch.nn.Parameter(U_small1, requires_grad=False)
        self.U_small2 = torch.nn.Parameter(U_small2, requires_grad=False)
        self.singulars_small1 = torch.nn.Parameter(singulars_small1, requires_grad=False)
        self.singulars_small2 = torch.nn.Parameter(singulars_small2, requires_grad=False)
        self.V_small1 = torch.nn.Parameter(V_small1, requires_grad=False)
        self.V_small2 = torch.nn.Parameter(V_small2, requires_grad=False)

        # calculate the singular values of the big matrix
        singulars = torch.matmul(singulars_small1.reshape(img_dim, 1),
                                 singulars_small2.reshape(1, img_dim)).reshape(img_dim ** 2)
        # sort the big matrix singulars and save the permutation
        singulars, perm = singulars.sort(descending=True)  # , stable=True)
        self._singulars = torch.nn.Parameter(singulars, requires_grad=False)
        self._perm = torch.nn.Parameter(perm, requires_grad=False)

    def V(self, vec):
        """Undo the singular-value ordering, then compute ``V1 X V2^T`` per channel."""
        # invert the permutation
        temp = torch.zeros(vec.shape[0], self.img_dim ** 2, self.channels, device=vec.device)
        temp[:, self._perm, :] = vec.clone().reshape(vec.shape[0], self.img_dim ** 2, self.channels)
        temp = temp.permute(0, 2, 1)
        # multiply the image by V1 from the left and by V2^T from the right
        out = self.mat_by_img(self.V_small1, temp)
        out = self.img_by_mat(out, self.V_small2.transpose(0, 1)).reshape(vec.shape[0], -1)
        return out

    def Vt(self, vec, for_H=True):
        """Compute ``V1^T X V2`` per channel, then order entries by singular value (``for_H`` is unused)."""
        # multiply the image by V1^T from the left and by V2 from the right
        temp = self.mat_by_img(self.V_small1.transpose(0, 1), vec.clone())
        temp = self.img_by_mat(temp, self.V_small2).reshape(vec.shape[0], self.channels, -1)
        # permute the entries according to the singular values
        temp = temp[:, :, self._perm].permute(0, 2, 1)
        return temp.reshape(vec.shape[0], -1)

    def U(self, vec):
        """Undo the singular-value ordering, then compute ``U1 X U2^T`` per channel."""
        # invert the permutation
        temp = torch.zeros(vec.shape[0], self.img_dim ** 2, self.channels, device=vec.device)
        temp[:, self._perm, :] = vec.clone().reshape(vec.shape[0], self.img_dim ** 2, self.channels)
        temp = temp.permute(0, 2, 1)
        # multiply the image by U1 from the left and by U2^T from the right
        out = self.mat_by_img(self.U_small1, temp)
        out = self.img_by_mat(out, self.U_small2.transpose(0, 1)).reshape(vec.shape[0], -1)
        return out

    def Ut(self, vec):
        """Compute ``U1^T X U2`` per channel, then order entries by singular value."""
        # multiply the image by U1^T from the left and by U2 from the right
        temp = self.mat_by_img(self.U_small1.transpose(0, 1), vec.clone())
        temp = self.img_by_mat(temp, self.U_small2).reshape(vec.shape[0], self.channels, -1)
        # permute the entries according to the singular values
        temp = temp[:, :, self._perm].permute(0, 2, 1)
        return temp.reshape(vec.shape[0], -1)

    def singulars(self):
        """Return the sorted singular values tiled ``channels`` times."""
        # NOTE(review): Vt/U lay entries out as (pixel, channel), which would call
        # for repeat_interleave(channels) as SRConv.singulars does; this tiles the
        # whole vector instead. Layout kept as-is -- confirm against the base class.
        return self._singulars.repeat(1, self.channels).reshape(-1)

    def add_zeros(self, vec):
        """Return a flattened copy of ``vec``; no padding is added."""
        return vec.clone().reshape(vec.shape[0], -1)
def ou_mixt(alpha_t, means, dim, weights):
    """Return a Gaussian mixture with identity covariances and means scaled by ``sqrt(alpha_t)``.

    ``means`` is a sequence of ``dim``-sized tensors and ``weights`` the
    mixture weights passed to ``Categorical``.
    """
    cat = torch.distributions.Categorical(weights, validate_args=False)

    ou_norm = torch.distributions.MultivariateNormal(
        torch.vstack(tuple((alpha_t ** .5) * m for m in means)),
        torch.eye(dim).repeat(len(means), 1, 1), validate_args=False)
    return torch.distributions.MixtureSameFamily(cat, ou_norm, validate_args=False)


def get_posterior(obs, prior, A, Sigma_y):
    """Return the posterior mixture of a Gaussian-mixture ``prior`` given ``obs = A x + noise``.

    ``Sigma_y`` is the observation-noise covariance and ``obs`` a 1-D tensor.
    Each component is updated with ``gaussian_posterior``. Its new log-weight
    is the prior log-weight plus
    ``log p(obs | x*) + log p(x*) - log q(x*)`` evaluated at the component's
    posterior mean ``x*`` (up to a constant shared by all components).
    """
    modified_means = []
    modified_covars = []
    log_weights = []
    precision = torch.linalg.inv(Sigma_y)
    for loc, cov, weight in zip(prior.component_distribution.loc,
                                prior.component_distribution.covariance_matrix,
                                prior.mixture_distribution.probs):
        new_dist = gaussian_posterior(obs,
                                      A,
                                      torch.zeros_like(obs),
                                      precision,
                                      loc,
                                      cov)
        modified_means.append(new_dist.loc)
        modified_covars.append(new_dist.covariance_matrix)
        prior_x = torch.distributions.MultivariateNormal(loc=loc, covariance_matrix=cov)
        residue = obs - A @ new_dist.loc
        # Scalar quadratic form residue^T P residue.
        log_constant = -(residue @ precision @ residue) / 2 + \
                       prior_x.log_prob(new_dist.loc) - \
                       new_dist.log_prob(new_dist.loc)
        log_weights.append(torch.log(weight) + log_constant.reshape(()))
    # Stack 0-dim tensors; avoids torch.tensor() on a list of tensors.
    log_weights = torch.stack(log_weights)
    log_weights = log_weights - torch.logsumexp(log_weights, dim=0)
    cat = torch.distributions.Categorical(logits=log_weights)
    ou_norm = torch.distributions.MultivariateNormal(loc=torch.stack(modified_means, dim=0),
                                                     covariance_matrix=torch.stack(modified_covars, dim=0))
    return torch.distributions.MixtureSameFamily(cat, ou_norm)


def gaussian_posterior(y,
                       likelihood_A,
                       likelihood_bias,
                       likelihood_precision,
                       prior_loc,
                       prior_covar):
    """Return the Gaussian posterior of ``x`` for a linear-Gaussian observation.

    Model: ``x ~ N(prior_loc, prior_covar)`` and
    ``y ~ N(likelihood_A x + likelihood_bias, likelihood_precision^{-1})``.

    If the symmetrized posterior covariance cannot be Cholesky-factorized, its
    eigenvalues are clamped to ``[1e-12, 1e6]`` and the factorization is
    retried once; a second failure propagates to the caller.
    """
    prior_precision_matrix = torch.linalg.inv(prior_covar)
    posterior_precision_matrix = prior_precision_matrix + likelihood_A.T @ likelihood_precision @ likelihood_A
    posterior_covariance_matrix = torch.linalg.inv(posterior_precision_matrix)
    posterior_mean = posterior_covariance_matrix @ (
        likelihood_A.T @ likelihood_precision @ (y - likelihood_bias) + prior_precision_matrix @ prior_loc)
    posterior_covariance_matrix = (posterior_covariance_matrix + posterior_covariance_matrix.T) / 2
    try:
        return torch.distributions.MultivariateNormal(loc=posterior_mean,
                                                      covariance_matrix=posterior_covariance_matrix,
                                                      validate_args=False)
    except (ValueError, RuntimeError):
        # With validate_args=False the failure comes from the Cholesky
        # factorization, which raises a RuntimeError subclass, not ValueError.
        # Clamping eigenvalues (rather than singular values) also removes
        # negative eigenvalues.
        eigvals, eigvecs = torch.linalg.eigh(posterior_covariance_matrix)
        eigvals = eigvals.clip(1e-12, 1e6)
        posterior_covariance_matrix = eigvecs @ torch.diag_embed(eigvals) @ eigvecs.T
        posterior_covariance_matrix = (posterior_covariance_matrix + posterior_covariance_matrix.T) / 2
        return torch.distributions.MultivariateNormal(loc=posterior_mean,
                                                      covariance_matrix=posterior_covariance_matrix,
                                                      validate_args=False)


def build_extended_svd(A: torch.Tensor):
    """Return ``(U, d, coordinate_mask, Vh)`` from the full SVD of ``A``.

    ``Vh`` is the third output of ``torch.linalg.svd``, so its rows are the
    right singular vectors. ``coordinate_mask`` has one entry per row of
    ``Vh``: 1 for the first ``len(d)`` rows and 0 for the rest.
    """
    U, d, V = torch.linalg.svd(A, full_matrices=True)
    coordinate_mask = torch.ones_like(V[0])
    coordinate_mask[len(d):] = 0
    return U, d, coordinate_mask, V


def generate_measurement_equations(dim, dim_y, mixt):
    """Draw a random linear measurement model and one noisy observation.

    Returns ``(A, var_observations, init_obs)``. ``A`` is ``(dim_y, dim)``
    with singular vectors taken from a Gaussian matrix and singular values
    drawn uniformly in ``[0, 1)``. ``init_obs`` is ``A @ mixt.sample()`` plus
    Gaussian noise of variance ``var_observations``. The order of the random
    draws is part of the behaviour under a fixed seed.
    """
    A = torch.randn((dim_y, dim))

    u, diag, coordinate_mask, v = build_extended_svd(A)
    diag = torch.sort(torch.rand_like(diag), descending=True).values

    A = u @ (torch.diag(diag) @ v[coordinate_mask == 1, :])
    init_sample = mixt.sample()
    # Noise std is a uniform fraction of the largest singular value.
    std = (torch.rand((1,)))[0] * diag.max()
    var_observations = std ** 2

    init_obs = A @ init_sample
    init_obs += torch.randn_like(init_obs) * std
    return A, var_observations, init_obs
n_samples = n_samples 94 | dim_y, dim_x = dims 95 | # setup of the inverse problem 96 | means = [] 97 | for i in range(-2, 3): 98 | means += [torch.tensor([-8. * i, -8. * j] * (dim_x // 2)) for j in range(-2, 3)] 99 | weights = torch.randn(len(means)) ** 2 100 | weights = weights / weights.sum() 101 | ou_mixt_fun = partial(ou_mixt, 102 | means=means, 103 | dim=dim_x, 104 | weights=weights) 105 | 106 | mixt = ou_mixt_fun(1) 107 | 108 | A, var_observations, init_obs = generate_measurement_equations(dim_x, dim_y, mixt) 109 | posterior = get_posterior(init_obs, mixt, A, torch.eye(dim_y)*var_observations) 110 | target_samples = posterior.sample((n_samples,)) 111 | betas = torch.linspace(.02, 1e-4, steps=999) 112 | alphas_cumprod = torch.cumprod(torch.tensor([1, ] + [1 - beta for beta in betas]), dim=0) 113 | 114 | 115 | observation = init_obs 116 | forward_operator = A 117 | observation_noise = var_observations 118 | score_network = lambda x, alpha_t: torch.func.grad(lambda y: ou_mixt_fun(alpha_t).log_prob(y).sum())(x) 119 | reference_samples = target_samples 120 | alphas_cumprod = alphas_cumprod 121 | 122 | u, diag, coordinate_mask, v = build_extended_svd(forward_operator) 123 | score_model = ScoreModel(NetReparametrized( 124 | base_score_module=lambda x, t: - score_network(x, alphas_cumprod[t]) * ((1 - alphas_cumprod[t]) ** .5), 125 | orthogonal_transformation=v), 126 | alphas_cumprod=alphas_cumprod, 127 | device='cpu') 128 | 129 | n_steps = 100 130 | adapted_timesteps = get_optimal_timesteps_from_singular_values(alphas_cumprod=alphas_cumprod, 131 | singular_value=diag, 132 | n_timesteps=n_steps, 133 | var=observation_noise, 134 | mode='else') 135 | 136 | 137 | def mcg_diff_fun(initial_samples): 138 | samples, log_weights = mcg_diff( 139 | initial_particles=initial_samples, 140 | observation=(u.T @ observation), 141 | score_model=score_model, 142 | likelihood_diagonal=diag, 143 | coordinates_mask=coordinate_mask.bool(), 144 | var_observation=observation_noise, 145 | 
from setuptools import setup

# Packaging metadata for the `mcg_diff` library. Only the `mcg_diff` package
# is listed in `packages`; `scripts/` and `tests/` are not installed.
setup(
    name='mcg_diff',
    version='0.0',
    packages=['mcg_diff'],
    url='',
    # NOTE(review): the repository's LICENSE file is the GNU GPL v3 text, but
    # this field says Apache 2.0 -- confirm which one is intended.
    license='Apache 2.0',
    author='G. Cardoso and Y. Janati',
    author_email='gabriel.victorino-cardoso@polytechnique.edu',
    description=''
)
def test_vmap_particle_filter_inpainting():
    """Run 10 independent particle filters through ``torch.func.vmap``.

    With a noiseless observation of 0 on the first coordinate, every particle
    of every batch element must end up with that coordinate close to 0.
    """
    beta_min = 0.1
    beta_max = 30
    t = torch.linspace(0, 1, steps=1000)
    alphas_cumprod = torch.exp(-.5 * (beta_max - beta_min) * (t ** 2) - beta_min * t)
    timesteps = torch.arange(0, 1001, 10)
    # The last step would be 1000, one past the end of the 1000-entry schedule.
    timesteps[-1] -= 1
    samples, lw = torch.func.vmap(mcg_diff, in_dims=(0,), randomness='different')(
        torch.randn(size=(10, 100, 2)),
        observation=torch.tensor([0., ]),
        var_observation=0.,
        score_model=ScoreModel(
            net=lambda x, t: ((1 - alphas_cumprod[t]) ** .5) * x,
            alphas_cumprod=alphas_cumprod,
            device='cpu'
        ),
        likelihood_diagonal=torch.tensor([1., ]),
        coordinates_mask=torch.tensor([True, False]),
        timesteps=timesteps,
        gaussian_var=1e-6
    )
    assert samples.shape == (10, 100, 2)
    # Check the observed coordinate (last axis, index 0) of every particle in
    # every batch; `samples[:, 0]` would select particle 0 with both coordinates.
    assert (samples[..., 0] ** 2).max() < 1e-5
    assert lw.shape == (10, 100,)