├── 3rd-party-licenses.txt ├── 80mn_cifar_idxs.txt ├── LICENSE ├── README.md ├── SerraReplicationCode ├── ReferenceGlowVsDirectPng.ipynb └── ReferenceGlowVsDirectPng.py ├── cifar_indexes └── invglow ├── __init__.py ├── create_tiny.py ├── datasets.py ├── evaluate.py ├── exp.py ├── folder_locations.py ├── invertible ├── __init__.py ├── actnorm.py ├── affine.py ├── branching.py ├── categorical_mixture.py ├── conditional.py ├── coupling.py ├── distribution.py ├── expression.py ├── gaussian.py ├── graph.py ├── identity.py ├── init.py ├── inv_permute.py ├── inverse.py ├── noise.py ├── pure_model.py ├── sequential.py ├── split_merge.py ├── splitter.py └── view_as.py ├── load_data.py ├── losses.py ├── main.py ├── models ├── __init__.py ├── class_conditional.py ├── glow.py └── patch_glow.py ├── scheduler.py └── util.py /3rd-party-licenses.txt: -------------------------------------------------------------------------------- 1 | Third Party Licenses 2 | ==================== 3 | 4 | Benchmarks includes material from the projects listed below (Third Party 5 | IP). The original copyright notice and the license under which we received 6 | such Third Party IP, are set forth below. 7 | 8 | 9 | -------------------------------------------------------------------------- 10 | Overview 11 | -------------------------------------------------------------------------- 12 | 13 | Glow-PyTorch: 14 | 15 | Name: Glow-PyTorch 16 | Version: - 17 | URL: https://github.com/y0ast/Glow-PyTorch 18 | License: MIT 19 | Copyright: Copyright (c) 2019 Joost van Amersfoort 20 | Copyright (c) 2019 Yuki-Chai 21 | 22 | 23 | Lasagne: 24 | 25 | Name: Lasagne 26 | Version: - 27 | URL: https://github.com/Lasagne/Lasagne 28 | License: MIT 29 | Copyright: Copyright (c) 2014-2015 Lasagne contributors 30 | 31 | 32 | glow-pytorch: 33 | 34 | Name: glow-pytorch 35 | Version: - 36 | URL: https://github.com/rosinality/glow-pytorch/blob/master/LICENSE 37 | License: MIT 38 | Copyright: Copyright (c) 2018 Kim Seonghyeon 39 | 40 | 41 | reversible2: 42 | 43 | Name: reversible2 44 | Version: - 45 | URL: https://github.com/robintibor/reversible2 46 | License: MIT 47 | Copyright: Copyright (c) 2019 Neuromedical AI Lab Freiburg 48 | 49 | 50 | tensorflow/tensor2tensor: 51 | 52 | Name: tensorflow/tensor2tensor: 53 | Version: 2.0 54 | URL: https://github.com/tensorflow/tensor2tensor 55 | License: Apache 2.0 56 | Copyright: Copyright 2020 The TensorFlow/tensor2tensor Authors. All rights reserved. 57 | 58 | 59 | outlier-exposure: 60 | 61 | Name: outlier-exposure 62 | Version: - 63 | URL: https://github.com/hendrycks/outlier-exposure 64 | License: Apache 2.0 65 | Copyright: Copyright 2018 Dan Hendrycks 66 | 67 | 68 | -------------------------------------------------------------------------- 69 | Licenses 70 | -------------------------------------------------------------------------- 71 | 72 | a. 
Glow-PyTorch 73 | 74 | MIT License 75 | 76 | Copyright (c) 2019 Joost van Amersfoort 77 | 78 | Permission is hereby granted, free of charge, to any person obtaining a copy 79 | of this software and associated documentation files (the "Software"), to deal 80 | in the Software without restriction, including without limitation the rights 81 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 82 | copies of the Software, and to permit persons to whom the Software is 83 | furnished to do so, subject to the following conditions: 84 | 85 | The above copyright notice and this permission notice shall be included in all 86 | copies or substantial portions of the Software. 87 | 88 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 89 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 90 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 91 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 92 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 93 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 94 | SOFTWARE. 95 | 96 | MIT License 97 | 98 | Copyright (c) 2019 Yuki-Chai 99 | 100 | Permission is hereby granted, free of charge, to any person obtaining a copy 101 | of this software and associated documentation files (the "Software"), to deal 102 | in the Software without restriction, including without limitation the rights 103 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 104 | copies of the Software, and to permit persons to whom the Software is 105 | furnished to do so, subject to the following conditions: 106 | 107 | The above copyright notice and this permission notice shall be included in all 108 | copies or substantial portions of the Software. 109 | 110 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 111 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 112 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 113 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 114 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 115 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 116 | SOFTWARE. 117 | 118 | b. Lasagne 119 | 120 | The MIT License (MIT) 121 | 122 | Copyright (c) 2014-2015 Lasagne contributors 123 | 124 | Lasagne uses a shared copyright model: each contributor holds copyright over 125 | their contributions to Lasagne. The project versioning records all such 126 | contribution and copyright details. 127 | By contributing to the Lasagne repository through pull-request, comment, 128 | or otherwise, the contributor releases their content to the license and 129 | copyright terms herein. 130 | 131 | Permission is hereby granted, free of charge, to any person obtaining a copy 132 | of this software and associated documentation files (the "Software"), to deal 133 | in the Software without restriction, including without limitation the rights 134 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 135 | copies of the Software, and to permit persons to whom the Software is 136 | furnished to do so, subject to the following conditions: 137 | 138 | The above copyright notice and this permission notice shall be included in all 139 | copies or substantial portions of the Software. 
140 | 141 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 142 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 143 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 144 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 145 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 146 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 147 | SOFTWARE. 148 | 149 | c. glow-pytorch 150 | 151 | MIT License 152 | 153 | Copyright (c) 2018 Kim Seonghyeon 154 | 155 | Permission is hereby granted, free of charge, to any person obtaining a copy 156 | of this software and associated documentation files (the "Software"), to deal 157 | in the Software without restriction, including without limitation the rights 158 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 159 | copies of the Software, and to permit persons to whom the Software is 160 | furnished to do so, subject to the following conditions: 161 | 162 | The above copyright notice and this permission notice shall be included in all 163 | copies or substantial portions of the Software. 164 | 165 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 166 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 167 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 168 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 169 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 170 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 171 | SOFTWARE. 172 | 173 | d. reversible2 174 | 175 | MIT License 176 | 177 | Copyright (c) 2019 Neuromedical AI Lab Freiburg 178 | 179 | Permission is hereby granted, free of charge, to any person obtaining a copy 180 | of this software and associated documentation files (the "Software"), to deal 181 | in the Software without restriction, including without limitation the rights 182 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 183 | copies of the Software, and to permit persons to whom the Software is 184 | furnished to do so, subject to the following conditions: 185 | 186 | The above copyright notice and this permission notice shall be included in all 187 | copies or substantial portions of the Software. 188 | 189 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 190 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 191 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 192 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 193 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 194 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 195 | SOFTWARE. 196 | 197 | e. tensorflow/tensor2tensor 198 | 199 | Apache License 200 | Version 2.0, January 2004 201 | http://www.apache.org/licenses/ 202 | 203 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 204 | 205 | 1. Definitions. 206 | 207 | "License" shall mean the terms and conditions for use, reproduction, 208 | and distribution as defined by Sections 1 through 9 of this document. 209 | 210 | "Licensor" shall mean the copyright owner or entity authorized by 211 | the copyright owner that is granting the License. 
212 | 213 | "Legal Entity" shall mean the union of the acting entity and all 214 | other entities that control, are controlled by, or are under common 215 | control with that entity. For the purposes of this definition, 216 | "control" means (i) the power, direct or indirect, to cause the 217 | direction or management of such entity, whether by contract or 218 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 219 | outstanding shares, or (iii) beneficial ownership of such entity. 220 | 221 | "You" (or "Your") shall mean an individual or Legal Entity 222 | exercising permissions granted by this License. 223 | 224 | "Source" form shall mean the preferred form for making modifications, 225 | including but not limited to software source code, documentation 226 | source, and configuration files. 227 | 228 | "Object" form shall mean any form resulting from mechanical 229 | transformation or translation of a Source form, including but 230 | not limited to compiled object code, generated documentation, 231 | and conversions to other media types. 232 | 233 | "Work" shall mean the work of authorship, whether in Source or 234 | Object form, made available under the License, as indicated by a 235 | copyright notice that is included in or attached to the work 236 | (an example is provided in the Appendix below). 237 | 238 | "Derivative Works" shall mean any work, whether in Source or Object 239 | form, that is based on (or derived from) the Work and for which the 240 | editorial revisions, annotations, elaborations, or other modifications 241 | represent, as a whole, an original work of authorship. For the purposes 242 | of this License, Derivative Works shall not include works that remain 243 | separable from, or merely link (or bind by name) to the interfaces of, 244 | the Work and Derivative Works thereof. 245 | 246 | "Contribution" shall mean any work of authorship, including 247 | the original version of the Work and any modifications or additions 248 | to that Work or Derivative Works thereof, that is intentionally 249 | submitted to Licensor for inclusion in the Work by the copyright owner 250 | or by an individual or Legal Entity authorized to submit on behalf of 251 | the copyright owner. For the purposes of this definition, "submitted" 252 | means any form of electronic, verbal, or written communication sent 253 | to the Licensor or its representatives, including but not limited to 254 | communication on electronic mailing lists, source code control systems, 255 | and issue tracking systems that are managed by, or on behalf of, the 256 | Licensor for the purpose of discussing and improving the Work, but 257 | excluding communication that is conspicuously marked or otherwise 258 | designated in writing by the copyright owner as "Not a Contribution." 259 | 260 | "Contributor" shall mean Licensor and any individual or Legal Entity 261 | on behalf of whom a Contribution has been received by Licensor and 262 | subsequently incorporated within the Work. 263 | 264 | 2. Grant of Copyright License. Subject to the terms and conditions of 265 | this License, each Contributor hereby grants to You a perpetual, 266 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 267 | copyright license to reproduce, prepare Derivative Works of, 268 | publicly display, publicly perform, sublicense, and distribute the 269 | Work and such Derivative Works in Source or Object form. 270 | 271 | 3. Grant of Patent License. 
Subject to the terms and conditions of 272 | this License, each Contributor hereby grants to You a perpetual, 273 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 274 | (except as stated in this section) patent license to make, have made, 275 | use, offer to sell, sell, import, and otherwise transfer the Work, 276 | where such license applies only to those patent claims licensable 277 | by such Contributor that are necessarily infringed by their 278 | Contribution(s) alone or by combination of their Contribution(s) 279 | with the Work to which such Contribution(s) was submitted. If You 280 | institute patent litigation against any entity (including a 281 | cross-claim or counterclaim in a lawsuit) alleging that the Work 282 | or a Contribution incorporated within the Work constitutes direct 283 | or contributory patent infringement, then any patent licenses 284 | granted to You under this License for that Work shall terminate 285 | as of the date such litigation is filed. 286 | 287 | 4. Redistribution. You may reproduce and distribute copies of the 288 | Work or Derivative Works thereof in any medium, with or without 289 | modifications, and in Source or Object form, provided that You 290 | meet the following conditions: 291 | 292 | (a) You must give any other recipients of the Work or 293 | Derivative Works a copy of this License; and 294 | 295 | (b) You must cause any modified files to carry prominent notices 296 | stating that You changed the files; and 297 | 298 | (c) You must retain, in the Source form of any Derivative Works 299 | that You distribute, all copyright, patent, trademark, and 300 | attribution notices from the Source form of the Work, 301 | excluding those notices that do not pertain to any part of 302 | the Derivative Works; and 303 | 304 | (d) If the Work includes a "NOTICE" text file as part of its 305 | distribution, then any Derivative Works that You distribute must 306 | include a readable copy of the attribution notices contained 307 | within such NOTICE file, excluding those notices that do not 308 | pertain to any part of the Derivative Works, in at least one 309 | of the following places: within a NOTICE text file distributed 310 | as part of the Derivative Works; within the Source form or 311 | documentation, if provided along with the Derivative Works; or, 312 | within a display generated by the Derivative Works, if and 313 | wherever such third-party notices normally appear. The contents 314 | of the NOTICE file are for informational purposes only and 315 | do not modify the License. You may add Your own attribution 316 | notices within Derivative Works that You distribute, alongside 317 | or as an addendum to the NOTICE text from the Work, provided 318 | that such additional attribution notices cannot be construed 319 | as modifying the License. 320 | 321 | You may add Your own copyright statement to Your modifications and 322 | may provide additional or different license terms and conditions 323 | for use, reproduction, or distribution of Your modifications, or 324 | for any such Derivative Works as a whole, provided Your use, 325 | reproduction, and distribution of the Work otherwise complies with 326 | the conditions stated in this License. 327 | 328 | 5. Submission of Contributions. Unless You explicitly state otherwise, 329 | any Contribution intentionally submitted for inclusion in the Work 330 | by You to the Licensor shall be under the terms and conditions of 331 | this License, without any additional terms or conditions. 
332 | Notwithstanding the above, nothing herein shall supersede or modify 333 | the terms of any separate license agreement you may have executed 334 | with Licensor regarding such Contributions. 335 | 336 | 6. Trademarks. This License does not grant permission to use the trade 337 | names, trademarks, service marks, or product names of the Licensor, 338 | except as required for reasonable and customary use in describing the 339 | origin of the Work and reproducing the content of the NOTICE file. 340 | 341 | 7. Disclaimer of Warranty. Unless required by applicable law or 342 | agreed to in writing, Licensor provides the Work (and each 343 | Contributor provides its Contributions) on an "AS IS" BASIS, 344 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 345 | implied, including, without limitation, any warranties or conditions 346 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 347 | PARTICULAR PURPOSE. You are solely responsible for determining the 348 | appropriateness of using or redistributing the Work and assume any 349 | risks associated with Your exercise of permissions under this License. 350 | 351 | 8. Limitation of Liability. In no event and under no legal theory, 352 | whether in tort (including negligence), contract, or otherwise, 353 | unless required by applicable law (such as deliberate and grossly 354 | negligent acts) or agreed to in writing, shall any Contributor be 355 | liable to You for damages, including any direct, indirect, special, 356 | incidental, or consequential damages of any character arising as a 357 | result of this License or out of the use or inability to use the 358 | Work (including but not limited to damages for loss of goodwill, 359 | work stoppage, computer failure or malfunction, or any and all 360 | other commercial damages or losses), even if such Contributor 361 | has been advised of the possibility of such damages. 362 | 363 | 9. Accepting Warranty or Additional Liability. While redistributing 364 | the Work or Derivative Works thereof, You may choose to offer, 365 | and charge a fee for, acceptance of support, warranty, indemnity, 366 | or other liability obligations and/or rights consistent with this 367 | License. However, in accepting such obligations, You may act only 368 | on Your own behalf and on Your sole responsibility, not on behalf 369 | of any other Contributor, and only if You agree to indemnify, 370 | defend, and hold each Contributor harmless for any liability 371 | incurred by, or claims asserted against, such Contributor by reason 372 | of your accepting any such warranty or additional liability. 373 | 374 | END OF TERMS AND CONDITIONS 375 | 376 | APPENDIX: How to apply the Apache License to your work. 377 | 378 | To apply the Apache License to your work, attach the following 379 | boilerplate notice, with the fields enclosed by brackets "[]" 380 | replaced with your own identifying information. (Don't include 381 | the brackets!) The text should be enclosed in the appropriate 382 | comment syntax for the file format. We also recommend that a 383 | file or class name and description of purpose be included on the 384 | same "printed page" as the copyright notice for easier 385 | identification within third-party archives. 386 | 387 | Copyright [yyyy] [name of copyright owner] 388 | 389 | Licensed under the Apache License, Version 2.0 (the "License"); 390 | you may not use this file except in compliance with the License. 
391 | You may obtain a copy of the License at 392 | 393 | http://www.apache.org/licenses/LICENSE-2.0 394 | 395 | Unless required by applicable law or agreed to in writing, software 396 | distributed under the License is distributed on an "AS IS" BASIS, 397 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 398 | See the License for the specific language governing permissions and 399 | limitations under the License. 400 | 401 | f. outlier-exposure 402 | 403 | Apache License 404 | Version 2.0, January 2004 405 | http://www.apache.org/licenses/ 406 | 407 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 408 | 409 | 1. Definitions. 410 | 411 | "License" shall mean the terms and conditions for use, reproduction, 412 | and distribution as defined by Sections 1 through 9 of this document. 413 | 414 | "Licensor" shall mean the copyright owner or entity authorized by 415 | the copyright owner that is granting the License. 416 | 417 | "Legal Entity" shall mean the union of the acting entity and all 418 | other entities that control, are controlled by, or are under common 419 | control with that entity. For the purposes of this definition, 420 | "control" means (i) the power, direct or indirect, to cause the 421 | direction or management of such entity, whether by contract or 422 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 423 | outstanding shares, or (iii) beneficial ownership of such entity. 424 | 425 | "You" (or "Your") shall mean an individual or Legal Entity 426 | exercising permissions granted by this License. 427 | 428 | "Source" form shall mean the preferred form for making modifications, 429 | including but not limited to software source code, documentation 430 | source, and configuration files. 431 | 432 | "Object" form shall mean any form resulting from mechanical 433 | transformation or translation of a Source form, including but 434 | not limited to compiled object code, generated documentation, 435 | and conversions to other media types. 436 | 437 | "Work" shall mean the work of authorship, whether in Source or 438 | Object form, made available under the License, as indicated by a 439 | copyright notice that is included in or attached to the work 440 | (an example is provided in the Appendix below). 441 | 442 | "Derivative Works" shall mean any work, whether in Source or Object 443 | form, that is based on (or derived from) the Work and for which the 444 | editorial revisions, annotations, elaborations, or other modifications 445 | represent, as a whole, an original work of authorship. For the purposes 446 | of this License, Derivative Works shall not include works that remain 447 | separable from, or merely link (or bind by name) to the interfaces of, 448 | the Work and Derivative Works thereof. 449 | 450 | "Contribution" shall mean any work of authorship, including 451 | the original version of the Work and any modifications or additions 452 | to that Work or Derivative Works thereof, that is intentionally 453 | submitted to Licensor for inclusion in the Work by the copyright owner 454 | or by an individual or Legal Entity authorized to submit on behalf of 455 | the copyright owner. 
For the purposes of this definition, "submitted" 456 | means any form of electronic, verbal, or written communication sent 457 | to the Licensor or its representatives, including but not limited to 458 | communication on electronic mailing lists, source code control systems, 459 | and issue tracking systems that are managed by, or on behalf of, the 460 | Licensor for the purpose of discussing and improving the Work, but 461 | excluding communication that is conspicuously marked or otherwise 462 | designated in writing by the copyright owner as "Not a Contribution." 463 | 464 | "Contributor" shall mean Licensor and any individual or Legal Entity 465 | on behalf of whom a Contribution has been received by Licensor and 466 | subsequently incorporated within the Work. 467 | 468 | 2. Grant of Copyright License. Subject to the terms and conditions of 469 | this License, each Contributor hereby grants to You a perpetual, 470 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 471 | copyright license to reproduce, prepare Derivative Works of, 472 | publicly display, publicly perform, sublicense, and distribute the 473 | Work and such Derivative Works in Source or Object form. 474 | 475 | 3. Grant of Patent License. Subject to the terms and conditions of 476 | this License, each Contributor hereby grants to You a perpetual, 477 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 478 | (except as stated in this section) patent license to make, have made, 479 | use, offer to sell, sell, import, and otherwise transfer the Work, 480 | where such license applies only to those patent claims licensable 481 | by such Contributor that are necessarily infringed by their 482 | Contribution(s) alone or by combination of their Contribution(s) 483 | with the Work to which such Contribution(s) was submitted. If You 484 | institute patent litigation against any entity (including a 485 | cross-claim or counterclaim in a lawsuit) alleging that the Work 486 | or a Contribution incorporated within the Work constitutes direct 487 | or contributory patent infringement, then any patent licenses 488 | granted to You under this License for that Work shall terminate 489 | as of the date such litigation is filed. 490 | 491 | 4. Redistribution. 
You may reproduce and distribute copies of the 492 | Work or Derivative Works thereof in any medium, with or without 493 | modifications, and in Source or Object form, provided that You 494 | meet the following conditions: 495 | 496 | (a) You must give any other recipients of the Work or 497 | Derivative Works a copy of this License; and 498 | 499 | (b) You must cause any modified files to carry prominent notices 500 | stating that You changed the files; and 501 | 502 | (c) You must retain, in the Source form of any Derivative Works 503 | that You distribute, all copyright, patent, trademark, and 504 | attribution notices from the Source form of the Work, 505 | excluding those notices that do not pertain to any part of 506 | the Derivative Works; and 507 | 508 | (d) If the Work includes a "NOTICE" text file as part of its 509 | distribution, then any Derivative Works that You distribute must 510 | include a readable copy of the attribution notices contained 511 | within such NOTICE file, excluding those notices that do not 512 | pertain to any part of the Derivative Works, in at least one 513 | of the following places: within a NOTICE text file distributed 514 | as part of the Derivative Works; within the Source form or 515 | documentation, if provided along with the Derivative Works; or, 516 | within a display generated by the Derivative Works, if and 517 | wherever such third-party notices normally appear. The contents 518 | of the NOTICE file are for informational purposes only and 519 | do not modify the License. You may add Your own attribution 520 | notices within Derivative Works that You distribute, alongside 521 | or as an addendum to the NOTICE text from the Work, provided 522 | that such additional attribution notices cannot be construed 523 | as modifying the License. 524 | 525 | You may add Your own copyright statement to Your modifications and 526 | may provide additional or different license terms and conditions 527 | for use, reproduction, or distribution of Your modifications, or 528 | for any such Derivative Works as a whole, provided Your use, 529 | reproduction, and distribution of the Work otherwise complies with 530 | the conditions stated in this License. 531 | 532 | 5. Submission of Contributions. Unless You explicitly state otherwise, 533 | any Contribution intentionally submitted for inclusion in the Work 534 | by You to the Licensor shall be under the terms and conditions of 535 | this License, without any additional terms or conditions. 536 | Notwithstanding the above, nothing herein shall supersede or modify 537 | the terms of any separate license agreement you may have executed 538 | with Licensor regarding such Contributions. 539 | 540 | 6. Trademarks. This License does not grant permission to use the trade 541 | names, trademarks, service marks, or product names of the Licensor, 542 | except as required for reasonable and customary use in describing the 543 | origin of the Work and reproducing the content of the NOTICE file. 544 | 545 | 7. Disclaimer of Warranty. Unless required by applicable law or 546 | agreed to in writing, Licensor provides the Work (and each 547 | Contributor provides its Contributions) on an "AS IS" BASIS, 548 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 549 | implied, including, without limitation, any warranties or conditions 550 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 551 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 552 | appropriateness of using or redistributing the Work and assume any 553 | risks associated with Your exercise of permissions under this License. 554 | 555 | 8. Limitation of Liability. In no event and under no legal theory, 556 | whether in tort (including negligence), contract, or otherwise, 557 | unless required by applicable law (such as deliberate and grossly 558 | negligent acts) or agreed to in writing, shall any Contributor be 559 | liable to You for damages, including any direct, indirect, special, 560 | incidental, or consequential damages of any character arising as a 561 | result of this License or out of the use or inability to use the 562 | Work (including but not limited to damages for loss of goodwill, 563 | work stoppage, computer failure or malfunction, or any and all 564 | other commercial damages or losses), even if such Contributor 565 | has been advised of the possibility of such damages. 566 | 567 | 9. Accepting Warranty or Additional Liability. While redistributing 568 | the Work or Derivative Works thereof, You may choose to offer, 569 | and charge a fee for, acceptance of support, warranty, indemnity, 570 | or other liability obligations and/or rights consistent with this 571 | License. However, in accepting such obligations, You may act only 572 | on Your own behalf and on Your sole responsibility, not on behalf 573 | of any other Contributor, and only if You agree to indemnify, 574 | defend, and hold each Contributor harmless for any liability 575 | incurred by, or claims asserted against, such Contributor by reason 576 | of your accepting any such warranty or additional liability. 577 | 578 | END OF TERMS AND CONDITIONS 579 | 580 | APPENDIX: How to apply the Apache License to your work. 581 | 582 | To apply the Apache License to your work, attach the following 583 | boilerplate notice, with the fields enclosed by brackets "[]" 584 | replaced with your own identifying information. (Don't include 585 | the brackets!) The text should be enclosed in the appropriate 586 | comment syntax for the file format. We also recommend that a 587 | file or class name and description of purpose be included on the 588 | same "printed page" as the copyright notice for easier 589 | identification within third-party archives. 590 | 591 | Copyright [yyyy] [name of copyright owner] 592 | 593 | Licensed under the Apache License, Version 2.0 (the "License"); 594 | you may not use this file except in compliance with the License. 595 | You may obtain a copy of the License at 596 | 597 | http://www.apache.org/licenses/LICENSE-2.0 598 | 599 | Unless required by applicable law or agreed to in writing, software 600 | distributed under the License is distributed on an "AS IS" BASIS, 601 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 602 | See the License for the specific language governing permissions and 603 | limitations under the License. 604 | 605 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # hierarchical_anomaly_detection 2 | 3 | Pytorch implementation of the NeurIPS 2020 paper [Understanding anomaly detection with deep invertible networks through hierarchies of distributions and features](https://proceedings.neurips.cc/paper/2020/hash/f106b7f99d2cb30c3db1c3cc0fde9ccb-Abstract.html). The code allows the users to reproduce and extend the results reported in the study. 
Please cite the above paper when reporting, reproducing or extending the results. 4 | 5 | ## Purpose of the project 6 | 7 | This software is a research prototype, solely developed for and published as part of the publication. It will neither be maintained nor monitored in any way. 8 | 9 | ## Requirements 10 | 11 | This is a Python 3 codebase. 12 | 13 | You will need the following libraries: 14 | - PyTorch 15 | - ignite 16 | - tensorboardX 17 | - numpy 18 | - scipy 19 | - scikit-learn 20 | - torchvision 21 | - tqdm 22 | - OpenCV (only for the Serra replication code) 23 | 24 | You also need to add this folder to your Python path. 25 | 26 | ## Data 27 | 28 | You first have to set the dataset folder locations in invglow/folder_locations.py (see the example sketch at the end of this README) and download 80 Million Tiny Images, LSUN, etc. 29 | 30 | Also copy the supplied files 80mn_cifar_idxs.txt and cifar_indexes to your Tiny Images folder (only needed if you want to test excluding CIFAR images from Tiny Images, which we did not do in the main manuscript). 31 | 32 | Then run `python invglow/create_tiny.py` to create the Tiny Images dataset chunks. 33 | 34 | ## Structure 35 | 36 | The invglow folder contains the code for the invertible network experiments. 37 | 38 | main.py shows examples of how the code should be used to obtain the results in the manuscript. 39 | 40 | In general, you first need to train a model on the Tiny Images dataset and then use the saved model to finetune on the other datasets; main.py follows this logic. 41 | 42 | ## Pretrained models 43 | 44 | We provide some pretrained models at: 45 | 46 | https://osf.io/ces72/?view_only=cc58b057ac084d25862b2f5f7fc056df 47 | 48 | ## License 49 | 50 | hierarchical_anomaly_detection is open-sourced under the AGPL-3.0 license. See the [LICENSE](LICENSE) file for details. 51 | 52 | For a list of other open source components included in hierarchical_anomaly_detection, see the file [3rd-party-licenses.txt](3rd-party-licenses.txt).
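
## Example folder_locations.py (sketch)

For orientation, here is a minimal sketch of what setting the folder locations in invglow/folder_locations.py could look like. Only the attribute names `tiny_data` and `lsun_data` are confirmed by the code in this repository (create_tiny.py and evaluate.py reference them); the concrete paths below are placeholders, and the actual file may define additional dataset locations (e.g. for CelebA or Tiny ImageNet) needed by other modules such as datasets.py.

```python
# invglow/folder_locations.py -- example only, adapt the paths to your system.
import os

# Folder containing tiny_images.bin, tiny_index.mat, the copied
# 80mn_cifar_idxs.txt / cifar_indexes files, and a chunks/ subfolder
# into which create_tiny.py writes its .npy batch files.
tiny_data = os.path.join(os.environ['HOME'], 'data/80mn_tiny_images/')

# Folder containing the LSUN validation LMDB databases used in evaluate.py.
lsun_data = os.path.join(os.environ['HOME'], 'data/lsun/')
```

Note that create_tiny.py saves its .npy chunk files into `<tiny_data>/chunks/` and does not appear to create that subfolder itself, so create it before running the script.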
53 | 54 | 55 | 56 | 57 | -------------------------------------------------------------------------------- /SerraReplicationCode/ReferenceGlowVsDirectPng.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Installation" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "You need:\n", 15 | "\n", 16 | "* PyTorch\n", 17 | "* Torchvision\n", 18 | "* Scikit-Learn\n", 19 | "* OpenCV\n", 20 | "* Clone https://github.com/y0ast/Glow-PyTorch/tree/181daaffcd0f3561f08c32d5b3846874bcc0481a" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "metadata": {}, 26 | "source": [ 27 | "## Define your folders" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 1, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "import os\n", 37 | "\n", 38 | "# Set your folder where you cloned the glow-repo linked above\n", 39 | "glow_code_folder = os.path.join(os.environ['HOME'], 'code/glow-do-deep/glow_do_deep/')\n", 40 | "\n", 41 | "# Set your folder where you downloaded pretrained glow model from \n", 42 | "# https://github.com/y0ast/Glow-PyTorch\n", 43 | "# http://www.cs.ox.ac.uk/people/joost.vanamersfoort/glow.zip\n", 44 | "output_folder = os.path.join(os.environ['HOME'], 'code/glow-do-deep/glow/')\n", 45 | "\n", 46 | "\n", 47 | "# Set here path to your PyTorch-CIFAR10/SVHN datasets\n", 48 | "cifar10_path = os.path.join(os.environ['HOME'], 'data/pytorch-datasets/data/CIFAR10/')\n", 49 | "svhn_path = os.path.join(os.environ['HOME'], 'data/pytorch-datasets/data/SVHN/')\n" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": {}, 55 | "source": [ 56 | "## Some imports" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 2, 62 | "metadata": {}, 63 | "outputs": [ 64 | { 65 | "name": "stderr", 66 | "output_type": "stream", 67 | "text": [ 68 | "..anonymized/ipykernel_launcher.py:11: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. 
in jupyter console)\n", 69 | " # This is added back by InteractiveShellApp.init_path()\n" 70 | ] 71 | } 72 | ], 73 | "source": [ 74 | "import torch\n", 75 | "from torchvision import datasets\n", 76 | "import numpy as np\n", 77 | "torch.backends.cudnn.benchmark = True\n", 78 | "\n", 79 | "import json\n", 80 | "\n", 81 | "# Load tqdm if available for progress bar\n", 82 | "# otherwise just no progress bar\n", 83 | "try:\n", 84 | " from tqdm.autonotebook import tqdm\n", 85 | "except ModuleNotFoundError:\n", 86 | " def tqdm(x):\n", 87 | " return x\n" 88 | ] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "metadata": {}, 93 | "source": [ 94 | "## Load pretrained model" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 3, 100 | "metadata": {}, 101 | "outputs": [ 102 | { 103 | "name": "stdout", 104 | "output_type": "stream", 105 | "text": [ 106 | "{'K': 32, 'L': 3, 'LU_decomposed': True, 'actnorm_scale': 1.0, 'augment': True, 'batch_size': 64, 'cuda': True, 'dataroot': './', 'dataset': 'cifar10', 'download': False, 'epochs': 1500, 'eval_batch_size': 512, 'flow_coupling': 'affine', 'flow_permutation': 'invconv', 'fresh': True, 'hidden_channels': 512, 'learn_top': True, 'lr': 0.0005, 'max_grad_clip': 0, 'max_grad_norm': 0, 'n_init_batches': 8, 'n_workers': 6, 'output_dir': 'output/', 'saved_model': '', 'saved_optimizer': '', 'seed': 0, 'warmup_steps': 4000, 'y_condition': False, 'y_weight': 0.01}\n" 107 | ] 108 | } 109 | ], 110 | "source": [ 111 | "os.chdir(glow_code_folder)\n", 112 | "from model import Glow\n", 113 | "model_name = 'glow_affine_coupling.pt'\n", 114 | "\n", 115 | "with open(output_folder + 'hparams.json') as json_file: \n", 116 | " hparams = json.load(json_file)\n", 117 | " \n", 118 | "print(hparams)\n", 119 | "image_shape = (32,32,3)\n", 120 | "num_classes = 10\n", 121 | "model = Glow(image_shape, hparams['hidden_channels'], hparams['K'], hparams['L'], hparams['actnorm_scale'],\n", 122 | " hparams['flow_permutation'], hparams['flow_coupling'], hparams['LU_decomposed'], num_classes,\n", 123 | " hparams['learn_top'], hparams['y_condition'])\n", 124 | "\n", 125 | "model.load_state_dict(torch.load(output_folder + model_name))\n", 126 | "model.set_actnorm_init()\n", 127 | "\n", 128 | "device = torch.device(\"cuda\")\n", 129 | "model = model.to(device)\n", 130 | "\n", 131 | "model = model.eval()" 132 | ] 133 | }, 134 | { 135 | "cell_type": "markdown", 136 | "metadata": {}, 137 | "source": [ 138 | "## Load datasets\n", 139 | "\n", 140 | "Set `download=True` in case you don't have them yet." 
141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": 4, 146 | "metadata": {}, 147 | "outputs": [], 148 | "source": [ 149 | "\n", 150 | "test_svhn = datasets.SVHN(\n", 151 | " svhn_path,\n", 152 | " split='test',\n", 153 | " download=False)\n", 154 | "\n", 155 | "\n", 156 | "test_cifar10 = datasets.CIFAR10(\n", 157 | " cifar10_path,\n", 158 | " train=False,\n", 159 | " download=False)\n", 160 | "\n", 161 | "pytorch_datasets = dict(test_cifar10=test_cifar10,\n", 162 | " test_svhn=test_svhn)\n", 163 | "\n", 164 | "np_arrays = dict()\n", 165 | "loaders = dict()\n", 166 | "for name, dataset in pytorch_datasets.items():\n", 167 | " np_arr = np.stack([np.array(x) for x,y in dataset])\n", 168 | " np_arrays[name] = np_arr\n", 169 | " # Ensure we are working on exactly same data.\n", 170 | " loader = torch.utils.data.DataLoader(\n", 171 | " torch.utils.data.TensorDataset(torch.Tensor(np_arr)),\n", 172 | " batch_size=512, drop_last=False)\n", 173 | " loaders[name] = loader\n", 174 | "\n", 175 | " " 176 | ] 177 | }, 178 | { 179 | "cell_type": "markdown", 180 | "metadata": {}, 181 | "source": [ 182 | "## Compute PNG BPDs" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": 5, 188 | "metadata": {}, 189 | "outputs": [ 190 | { 191 | "data": { 192 | "application/vnd.jupyter.widget-view+json": { 193 | "model_id": "782ba2dc77244423b208b996a8830e04", 194 | "version_major": 2, 195 | "version_minor": 0 196 | }, 197 | "text/plain": [ 198 | "HBox(children=(IntProgress(value=0, max=10000), HTML(value='')))" 199 | ] 200 | }, 201 | "metadata": {}, 202 | "output_type": "display_data" 203 | }, 204 | { 205 | "name": "stdout", 206 | "output_type": "stream", 207 | "text": [ 208 | "\n" 209 | ] 210 | }, 211 | { 212 | "data": { 213 | "application/vnd.jupyter.widget-view+json": { 214 | "model_id": "6071a32f4184443b90ae9c8ee110c81e", 215 | "version_major": 2, 216 | "version_minor": 0 217 | }, 218 | "text/plain": [ 219 | "HBox(children=(IntProgress(value=0, max=26032), HTML(value='')))" 220 | ] 221 | }, 222 | "metadata": {}, 223 | "output_type": "display_data" 224 | }, 225 | { 226 | "name": "stdout", 227 | "output_type": "stream", 228 | "text": [ 229 | "\n" 230 | ] 231 | } 232 | ], 233 | "source": [ 234 | "import cv2\n", 235 | "def create_png_bpds(np_im_arr,):\n", 236 | " all_bpds = []\n", 237 | " for i_file, a_x in enumerate(tqdm(np_im_arr)):\n", 238 | " # This code was written using an author reply to our mails\n", 239 | " # Use highest compression level (9)\n", 240 | " img_encoded = cv2.imencode('.png', a_x, [int(cv2.IMWRITE_PNG_COMPRESSION),9])[1]\n", 241 | " assert img_encoded.shape[1] == 1\n", 242 | " all_bpds.append((len(img_encoded) * 8)/np.prod(a_x.shape))\n", 243 | " return all_bpds\n", 244 | "\n", 245 | "\n", 246 | "png_bpds = dict([\n", 247 | " (name, create_png_bpds(np_im_arr))\n", 248 | " for name, np_im_arr in np_arrays.items()])" 249 | ] 250 | }, 251 | { 252 | "cell_type": "markdown", 253 | "metadata": {}, 254 | "source": [ 255 | "## Compute BPDs of Glow Model" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": 6, 261 | "metadata": {}, 262 | "outputs": [], 263 | "source": [ 264 | "def preprocess(x, ):\n", 265 | " # Preprocess from tensor with:\n", 266 | " # dim ordering: B,H,W,C\n", 267 | " # values: 0-255 \n", 268 | " # \n", 269 | " # to tensor with\n", 270 | " # dim ordering: B,C,H,W\n", 271 | " # values: -0.5 to +0.5\n", 272 | " # Follows:\n", 273 | " # 
https://github.com/tensorflow/tensor2tensor/blob/e48cf23c505565fd63378286d9722a1632f4bef7/tensor2tensor/models/research/glow.py#L78\n", 274 | " x = x.permute(0,3,1,2)\n", 275 | " n_bits = 8\n", 276 | " n_bins = 2**n_bits\n", 277 | " x = x / n_bins - 0.5\n", 278 | " return x.cuda()" 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": 7, 284 | "metadata": {}, 285 | "outputs": [], 286 | "source": [ 287 | "def compute_glow_bpds(model, loader):\n", 288 | " all_bpds = []\n", 289 | " for x, in tqdm(loader):\n", 290 | " with torch.no_grad():\n", 291 | " preproced_x = preprocess(x)\n", 292 | " _, bpd, _ = model(preproced_x)\n", 293 | " all_bpds.append(bpd.cpu().numpy())\n", 294 | "\n", 295 | " all_bpds = np.concatenate(all_bpds)\n", 296 | " return all_bpds" 297 | ] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "execution_count": 8, 302 | "metadata": {}, 303 | "outputs": [ 304 | { 305 | "data": { 306 | "application/vnd.jupyter.widget-view+json": { 307 | "model_id": "001047b95bd6487896f89d8f9ed858c9", 308 | "version_major": 2, 309 | "version_minor": 0 310 | }, 311 | "text/plain": [ 312 | "HBox(children=(IntProgress(value=0, max=20), HTML(value='')))" 313 | ] 314 | }, 315 | "metadata": {}, 316 | "output_type": "display_data" 317 | }, 318 | { 319 | "name": "stdout", 320 | "output_type": "stream", 321 | "text": [ 322 | "\n" 323 | ] 324 | }, 325 | { 326 | "data": { 327 | "application/vnd.jupyter.widget-view+json": { 328 | "model_id": "fa513b78876b422aa834e3395b095576", 329 | "version_major": 2, 330 | "version_minor": 0 331 | }, 332 | "text/plain": [ 333 | "HBox(children=(IntProgress(value=0, max=51), HTML(value='')))" 334 | ] 335 | }, 336 | "metadata": {}, 337 | "output_type": "display_data" 338 | }, 339 | { 340 | "name": "stdout", 341 | "output_type": "stream", 342 | "text": [ 343 | "\n" 344 | ] 345 | } 346 | ], 347 | "source": [ 348 | "glow_bpds = dict([\n", 349 | " (name, compute_glow_bpds(model, loader))\n", 350 | " for name, loader in loaders.items()])" 351 | ] 352 | }, 353 | { 354 | "cell_type": "markdown", 355 | "metadata": {}, 356 | "source": [ 357 | "## Results" 358 | ] 359 | }, 360 | { 361 | "cell_type": "code", 362 | "execution_count": 9, 363 | "metadata": {}, 364 | "outputs": [], 365 | "source": [ 366 | "from sklearn.metrics import roc_auc_score\n", 367 | "def compute_auc_for_s_scores(scores_ood, scores_itd):\n", 368 | " # Assumes scores a should be higher\n", 369 | " auc = roc_auc_score(\n", 370 | " np.concatenate((np.ones_like(scores_ood),\n", 371 | " np.zeros_like(scores_itd)),\n", 372 | " axis=0),\n", 373 | " np.concatenate((scores_ood,\n", 374 | " scores_itd,),\n", 375 | " axis=0))\n", 376 | " return auc" 377 | ] 378 | }, 379 | { 380 | "cell_type": "markdown", 381 | "metadata": {}, 382 | "source": [ 383 | "We reach substantially different values for S-Score (78.4% vs 95.0%), see https://arxiv.org/pdf/1909.11480.pdf Supplementary D, p.14, Table 6." 
384 | ] 385 | }, 386 | { 387 | "cell_type": "code", 388 | "execution_count": 10, 389 | "metadata": {}, 390 | "outputs": [ 391 | { 392 | "data": { 393 | "text/plain": [ 394 | "0.7837171385218193" 395 | ] 396 | }, 397 | "execution_count": 10, 398 | "metadata": {}, 399 | "output_type": "execute_result" 400 | } 401 | ], 402 | "source": [ 403 | "s_score_cifar10 = glow_bpds['test_cifar10'] - png_bpds['test_cifar10']\n", 404 | "s_score_svhn = glow_bpds['test_svhn'] - png_bpds['test_svhn']\n", 405 | "\n", 406 | "compute_auc_for_s_scores(s_score_svhn, s_score_cifar10)" 407 | ] 408 | }, 409 | { 410 | "cell_type": "markdown", 411 | "metadata": {}, 412 | "source": [ 413 | "We reach similar numbers for PNG only (7.7% in paper vs 7.9% here)" 414 | ] 415 | }, 416 | { 417 | "cell_type": "code", 418 | "execution_count": 11, 419 | "metadata": {}, 420 | "outputs": [ 421 | { 422 | "data": { 423 | "text/plain": [ 424 | "0.07867029233251382" 425 | ] 426 | }, 427 | "execution_count": 11, 428 | "metadata": {}, 429 | "output_type": "execute_result" 430 | } 431 | ], 432 | "source": [ 433 | "# We reach si\n", 434 | "compute_auc_for_s_scores(png_bpds['test_svhn'], png_bpds['test_cifar10'])" 435 | ] 436 | } 437 | ], 438 | "metadata": { 439 | "kernelspec": { 440 | "display_name": "Python 3", 441 | "language": "python", 442 | "name": "python3" 443 | }, 444 | "language_info": { 445 | "codemirror_mode": { 446 | "name": "ipython", 447 | "version": 3 448 | }, 449 | "file_extension": ".py", 450 | "mimetype": "text/x-python", 451 | "name": "python", 452 | "nbconvert_exporter": "python", 453 | "pygments_lexer": "ipython3", 454 | "version": "3.7.5" 455 | } 456 | }, 457 | "nbformat": 4, 458 | "nbformat_minor": 4 459 | } 460 | -------------------------------------------------------------------------------- /SerraReplicationCode/ReferenceGlowVsDirectPng.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # # Installation 5 | 6 | # You need: 7 | # 8 | # * PyTorch 9 | # * Torchvision 10 | # * Scikit-Learn 11 | # * OpenCV 12 | # * Clone https://github.com/y0ast/Glow-PyTorch/tree/181daaffcd0f3561f08c32d5b3846874bcc0481a 13 | 14 | # ## Define your folders 15 | 16 | # In[1]: 17 | 18 | 19 | import os 20 | 21 | # Set your folder where you cloned the glow-repo linked above 22 | glow_code_folder = os.path.join(os.environ['HOME'], 'code/glow-do-deep/glow_do_deep/') 23 | 24 | # Set your folder where you downloaded pretrained glow model from 25 | # https://github.com/y0ast/Glow-PyTorch 26 | # http://www.cs.ox.ac.uk/people/joost.vanamersfoort/glow.zip 27 | output_folder = os.path.join(os.environ['HOME'], 'code/glow-do-deep/glow/') 28 | 29 | 30 | # Set here path to your PyTorch-CIFAR10/SVHN datasets 31 | cifar10_path = os.path.join(os.environ['HOME'], 'data/pytorch-datasets/data/CIFAR10/') 32 | svhn_path = os.path.join(os.environ['HOME'], 'data/pytorch-datasets/data/SVHN/') 33 | 34 | 35 | # ## Some imports 36 | 37 | # In[2]: 38 | 39 | 40 | import torch 41 | from torchvision import datasets 42 | import numpy as np 43 | torch.backends.cudnn.benchmark = True 44 | 45 | import json 46 | 47 | # Load tqdm if available for progress bar 48 | # otherwise just no progress bar 49 | try: 50 | from tqdm.autonotebook import tqdm 51 | except ModuleNotFoundError: 52 | def tqdm(x): 53 | return x 54 | 55 | 56 | # ## Load pretrained model 57 | 58 | # In[3]: 59 | 60 | 61 | os.chdir(glow_code_folder) 62 | from model import Glow 63 | model_name = 'glow_affine_coupling.pt' 64 | 65 | 
with open(output_folder + 'hparams.json') as json_file: 66 | hparams = json.load(json_file) 67 | 68 | print(hparams) 69 | image_shape = (32,32,3) 70 | num_classes = 10 71 | model = Glow(image_shape, hparams['hidden_channels'], hparams['K'], hparams['L'], hparams['actnorm_scale'], 72 | hparams['flow_permutation'], hparams['flow_coupling'], hparams['LU_decomposed'], num_classes, 73 | hparams['learn_top'], hparams['y_condition']) 74 | 75 | model.load_state_dict(torch.load(output_folder + model_name)) 76 | model.set_actnorm_init() 77 | 78 | device = torch.device("cuda") 79 | model = model.to(device) 80 | 81 | model = model.eval() 82 | 83 | 84 | # ## Load datasets 85 | # 86 | # Set `download=True` in case you don't have them yet. 87 | 88 | # In[4]: 89 | 90 | 91 | 92 | test_svhn = datasets.SVHN( 93 | svhn_path, 94 | split='test', 95 | download=False) 96 | 97 | 98 | test_cifar10 = datasets.CIFAR10( 99 | cifar10_path, 100 | train=False, 101 | download=False) 102 | 103 | pytorch_datasets = dict(test_cifar10=test_cifar10, 104 | test_svhn=test_svhn) 105 | 106 | np_arrays = dict() 107 | loaders = dict() 108 | for name, dataset in pytorch_datasets.items(): 109 | np_arr = np.stack([np.array(x) for x,y in dataset]) 110 | np_arrays[name] = np_arr 111 | # Ensure we are working on exactly same data. 112 | loader = torch.utils.data.DataLoader( 113 | torch.utils.data.TensorDataset(torch.Tensor(np_arr)), 114 | batch_size=512, drop_last=False) 115 | loaders[name] = loader 116 | 117 | 118 | 119 | 120 | # ## Compute PNG BPDs 121 | 122 | # In[5]: 123 | 124 | 125 | import cv2 126 | def create_png_bpds(np_im_arr,): 127 | all_bpds = [] 128 | for i_file, a_x in enumerate(tqdm(np_im_arr)): 129 | # This code was written using an author reply to our mails 130 | # Use highest compression level (9) 131 | img_encoded = cv2.imencode('.png', a_x, [int(cv2.IMWRITE_PNG_COMPRESSION),9])[1] 132 | assert img_encoded.shape[1] == 1 133 | all_bpds.append((len(img_encoded) * 8)/np.prod(a_x.shape)) 134 | return all_bpds 135 | 136 | 137 | png_bpds = dict([ 138 | (name, create_png_bpds(np_im_arr)) 139 | for name, np_im_arr in np_arrays.items()]) 140 | 141 | 142 | # ## Compute BPDs of Glow Model 143 | 144 | # In[6]: 145 | 146 | 147 | def preprocess(x, ): 148 | # Preprocess from tensor with: 149 | # dim ordering: B,H,W,C 150 | # values: 0-255 151 | # 152 | # to tensor with 153 | # dim ordering: B,C,H,W 154 | # values: -0.5 to +0.5 155 | # Follows: 156 | # https://github.com/tensorflow/tensor2tensor/blob/e48cf23c505565fd63378286d9722a1632f4bef7/tensor2tensor/models/research/glow.py#L78 157 | x = x.permute(0,3,1,2) 158 | n_bits = 8 159 | n_bins = 2**n_bits 160 | x = x / n_bins - 0.5 161 | return x.cuda() 162 | 163 | 164 | # In[7]: 165 | 166 | 167 | def compute_glow_bpds(model, loader): 168 | all_bpds = [] 169 | for x, in tqdm(loader): 170 | with torch.no_grad(): 171 | preproced_x = preprocess(x) 172 | _, bpd, _ = model(preproced_x) 173 | all_bpds.append(bpd.cpu().numpy()) 174 | 175 | all_bpds = np.concatenate(all_bpds) 176 | return all_bpds 177 | 178 | 179 | # In[8]: 180 | 181 | 182 | glow_bpds = dict([ 183 | (name, compute_glow_bpds(model, loader)) 184 | for name, loader in loaders.items()]) 185 | 186 | 187 | # ## Results 188 | 189 | # In[9]: 190 | 191 | 192 | from sklearn.metrics import roc_auc_score 193 | def compute_auc_for_s_scores(scores_ood, scores_itd): 194 | # Assumes scores a should be higher 195 | auc = roc_auc_score( 196 | np.concatenate((np.ones_like(scores_ood), 197 | np.zeros_like(scores_itd)), 198 | axis=0), 199 | 
np.concatenate((scores_ood, 200 | scores_itd,), 201 | axis=0)) 202 | return auc 203 | 204 | 205 | # We reach substantially different values for S-Score (78.4% vs 95.0%), see https://arxiv.org/pdf/1909.11480.pdf Supplementary D, p.14, Table 6. 206 | 207 | # In[10]: 208 | 209 | 210 | s_score_cifar10 = glow_bpds['test_cifar10'] - png_bpds['test_cifar10'] 211 | s_score_svhn = glow_bpds['test_svhn'] - png_bpds['test_svhn'] 212 | 213 | compute_auc_for_s_scores(s_score_svhn, s_score_cifar10) 214 | 215 | 216 | # We reach similar numbers for PNG only (7.7% in paper vs 7.9% here) 217 | 218 | # In[11]: 219 | 220 | 221 | # We reach si 222 | compute_auc_for_s_scores(png_bpds['test_svhn'], png_bpds['test_cifar10']) 223 | 224 | -------------------------------------------------------------------------------- /invglow/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boschresearch/hierarchical_anomaly_detection/ca2f1d84615c2ef140a74f4e1515352abff9e938/invglow/__init__.py -------------------------------------------------------------------------------- /invglow/create_tiny.py: -------------------------------------------------------------------------------- 1 | from invglow import folder_locations 2 | import numpy as np 3 | from scipy.io import loadmat 4 | from tqdm import tqdm 5 | import os 6 | from collections import OrderedDict 7 | 8 | 9 | 10 | data_file = open(os.path.join(folder_locations.tiny_data, 'tiny_images.bin'), "rb") 11 | 12 | def load_image(idx): 13 | data_file.seek(idx * 3072) 14 | data = data_file.read(3072) 15 | return np.frombuffer(data, dtype='uint8').reshape(32, 32, 3, order="F") 16 | 17 | cifar10_idxs = [int(l) - 1 for l in open(os.path.join(folder_locations.tiny_data, '80mn_cifar_idxs.txt'), 'r').readlines()] 18 | cifar100_idxs = [int(l) - 1 for l in open(os.path.join(folder_locations.tiny_data, 'cifar_indexes'), 'r').readlines() 19 | if int(l) != 0] 20 | 21 | 22 | 23 | n_total_images = 79302017 24 | metadata = loadmat(os.path.join(folder_locations.tiny_data, 'tiny_index.mat')) 25 | words = [a[0] for a in metadata['word'][0]] 26 | num_imgs = metadata['num_imgs'][0] 27 | offset = metadata['offset'][0] 28 | w_to_n = OrderedDict(zip(words, num_imgs)) 29 | w_to_o = OrderedDict(zip(words, offset)) 30 | number_indices = list(range(w_to_o['number'], w_to_o['number'] + w_to_n['offset'])) 31 | 32 | batch_starts = np.arange(0, n_total_images, 64)[:-1] # last one is an incomplete batch, drop it 33 | 34 | rng = np.random.RandomState(20191203) 35 | i_rand_starts = rng.choice(batch_starts, len(batch_starts), replace=False) 36 | 37 | this_i_rand_starts = i_rand_starts[:24415] 38 | 39 | 40 | 41 | folder = os.path.join(folder_locations.tiny_data, 'chunks/') 42 | for i_start in tqdm(this_i_rand_starts): 43 | arrs = [load_image(i) for i in range(i_start, i_start+64)] 44 | np.save(os.path.join(folder, f"{i_start:08d}_{i_start+64:08d}.all.npy"), np.stack(arrs)) 45 | 46 | # Recheck 47 | 48 | rng = np.random.RandomState(20191203) 49 | i_rand_starts = rng.choice(batch_starts, len(batch_starts), replace=False) 50 | 51 | this_i_rand_starts = i_rand_starts[:24415] 52 | 53 | for i_start in tqdm(this_i_rand_starts): 54 | filename = os.path.join(folder, f"{i_start:08d}_{i_start+64:08d}.all.npy") 55 | assert os.path.exists(filename) 56 | 57 | 58 | # no-cifar set 59 | rng = np.random.RandomState(20191203) 60 | i_rand_starts = rng.choice(batch_starts, len(batch_starts), replace=False) 61 | 62 | all_cifar_idxs = np.sort(cifar10_idxs + 
cifar100_idxs) 63 | set_cifar = set(all_cifar_idxs) 64 | i_selected_rand_starts = [] 65 | for i_start in tqdm(i_rand_starts): 66 | if not any([i in set_cifar for i in range(i_start, i_start+ 64)]): 67 | i_selected_rand_starts.append(i_start) 68 | this_i_rand_starts = i_selected_rand_starts[:24415] 69 | folder = os.path.join(folder_locations.tiny_data, 'chunks/') 70 | for i_start in tqdm(this_i_rand_starts): 71 | arrs = [load_image(i) for i in range(i_start, i_start+64)] 72 | np.save(os.path.join(folder, f"{i_start:08d}_{i_start+64:08d}.exclude_cifar.npy"), np.stack(arrs)) 73 | -------------------------------------------------------------------------------- /invglow/evaluate.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from collections import OrderedDict 3 | 4 | import numpy as np 5 | import torch as th 6 | from sklearn.metrics import roc_auc_score, average_precision_score 7 | from torchvision import transforms 8 | from tqdm import tqdm 9 | 10 | from invglow.datasets import load_tiny_imagenet, load_train_test, LSUN, load_celeb_a 11 | from invglow.invertible.graph import IntermediateResultsNode 12 | from invglow.invertible.graph import get_nodes_by_names 13 | from invglow.models.class_conditional import \ 14 | convert_class_model_for_multi_scale_nll 15 | from invglow.util import var_to_np, set_random_seeds 16 | from invglow.datasets import PreprocessedLoader 17 | from invglow.invertible.expression import Expression 18 | from invglow import folder_locations 19 | 20 | log = logging.getLogger(__name__) 21 | 22 | 23 | def get_nlls(loader, wanted_nodes, node_names): 24 | rs = [] 25 | with th.no_grad(): 26 | for x, y in tqdm(loader): 27 | outs = IntermediateResultsNode(wanted_nodes)( 28 | x.cuda(),fixed=dict(y=None)) 29 | lps = outs[1] 30 | fixed_lps = [] 31 | for lp in lps: 32 | if len(lp.shape) > 1: 33 | assert len(lp.shape) == 2 34 | n_components = lp.shape[1] 35 | lp = th.logsumexp(lp, dim=1) - np.log(n_components) 36 | fixed_lps.append(lp) 37 | node_to_lps = dict(zip(node_names, fixed_lps)) 38 | lp0 = node_to_lps['m0-flow-0'] / 2 + node_to_lps['m0-dist-0'] 39 | lp1 = node_to_lps['m0-flow-1'] / 2 + node_to_lps['m0-dist-1'] - \ 40 | node_to_lps['m0-flow-0'] / 2 41 | lp2 = node_to_lps['m0-dist-2'] - node_to_lps['m0-flow-1'] / 2 42 | lpz0 = node_to_lps['m0-dist-0'] - node_to_lps['m0-act-0'] 43 | lpz1 = node_to_lps['m0-dist-1'] - node_to_lps['m0-act-1'] 44 | lpz2 = node_to_lps['m0-dist-2'] - node_to_lps['m0-act-2'] 45 | lp0 = lp0.cpu().numpy() 46 | lp1 = lp1.cpu().numpy() 47 | lp2 = lp2.cpu().numpy() 48 | lpz0 = lpz0.cpu().numpy() 49 | lpz1 = lpz1.cpu().numpy() 50 | lpz2 = lpz2.cpu().numpy() 51 | lprob = lp0 + lp1 + lp2 52 | lprobz = lpz0+lpz1+lpz2 53 | bpd = np.log2(256) - ((lprob / np.log(2)) / np.prod(x.shape[1:])) 54 | rs.append(dict(lp0=lp0, lp1=lp1, lp2=lp2, lprob=lprob, bpd=bpd, 55 | lpz0=lpz0, lpz1=lpz1, lpz2=lpz2, lprobz=lprobz)) 56 | full_r = {} 57 | for key in rs[0].keys(): 58 | full_r[key] = np.concatenate([r[key] for r in rs]) 59 | return full_r 60 | 61 | 62 | def get_nlls_only_final(loader, model): 63 | rs = [] 64 | with th.no_grad(): 65 | for x, y in tqdm(loader): 66 | _, lp = model(x.cuda(), fixed=dict(y=None)) 67 | lprob = var_to_np(lp) 68 | bpd = np.log2(256) - ((lprob / np.log(2)) / np.prod(x.shape[1:])) 69 | rs.append(dict(lprob=lprob, bpd=bpd,)) 70 | full_r = {} 71 | for key in rs[0].keys(): 72 | full_r[key] = np.concatenate([r[key] for r in rs]) 73 | return full_r 74 | 75 | 76 | 77 | def 
get_rgb_loaders(first_n=None): 78 | train_cifar10, test_cifar10 = load_train_test('cifar10', shuffle_train=False, 79 | drop_last_train=False, 80 | batch_size=512, 81 | eval_batch_size=512, 82 | n_workers=6, 83 | first_n=first_n, 84 | augment=False, 85 | exclude_cifar_from_tiny=None,) 86 | 87 | train_svhn, test_svhn = load_train_test('svhn', shuffle_train=False, 88 | drop_last_train=False, 89 | batch_size=512, eval_batch_size=512, 90 | n_workers=6, 91 | first_n=first_n, 92 | augment=False, 93 | exclude_cifar_from_tiny=None,) 94 | 95 | categories = ['bedroom', 'bridge', 'church_outdoor', 'classroom', 96 | 'conference_room', 'dining_room', 'kitchen', 97 | 'living_room', 'restaurant', 'tower'] 98 | 99 | def final_preproc_lsun(x): 100 | # make same as cifar tiny etc. 101 | return (x * 255 / 256) - 0.5 102 | 103 | lsun_set = LSUN(folder_locations.lsun_data, 104 | classes=[c + '_val' for c in categories], 105 | transform=transforms.Compose([ 106 | transforms.Resize(32), 107 | transforms.CenterCrop(32), 108 | transforms.ToTensor(), 109 | final_preproc_lsun])) 110 | 111 | test_lsun = th.utils.data.DataLoader(lsun_set, batch_size=512, 112 | num_workers=0) 113 | train_cifar100, test_cifar100 = load_train_test('cifar100', 114 | shuffle_train=False, 115 | drop_last_train=False, 116 | batch_size=512, 117 | eval_batch_size=512, 118 | n_workers=6, 119 | first_n=first_n, 120 | augment=False, 121 | exclude_cifar_from_tiny=None,) 122 | #train_tiny, test_tiny = load_train_test( 123 | # 'tiny', 124 | # shuffle_train=False, 125 | # drop_last_train=False, 126 | # batch_size=512, eval_batch_size=512, 127 | # n_workers=6, 128 | # first_n=120000 if first_n is None else first_n, 129 | # augment=False, 130 | # exclude_cifar_from_tiny=False, 131 | # shuffle_tiny_chunks=False) 132 | train_celeba, test_celeba = load_celeb_a(shuffle_train=False, 133 | drop_last_train=False, 134 | batch_size=512, 135 | eval_batch_size=512, 136 | n_workers=6, 137 | first_n=first_n, ) 138 | train_tiny_imagenet, test_tiny_imagenet = load_tiny_imagenet( 139 | shuffle_train=False, drop_last_train=False, 140 | batch_size=512, eval_batch_size=512, 141 | n_workers=6, 142 | first_n=first_n, ) 143 | 144 | 145 | 146 | loaders = dict( 147 | train_cifar10=train_cifar10, 148 | test_cifar10=test_cifar10, 149 | train_svhn=train_svhn, 150 | test_svhn=test_svhn, 151 | #train_tiny=train_tiny, 152 | #test_tiny=test_tiny, 153 | test_lsun=test_lsun, 154 | train_cifar100=train_cifar100, 155 | test_cifar100=test_cifar100, 156 | train_celeba=train_celeba, 157 | test_celeba=test_celeba, 158 | train_tiny_imagenet=train_tiny_imagenet, 159 | test_tiny_imagenet=test_tiny_imagenet, 160 | ) 161 | return loaders 162 | 163 | 164 | def get_grey_loaders(first_n): 165 | train_mnist, test_mnist = load_train_test('mnist', 166 | shuffle_train=False, 167 | drop_last_train=False, 168 | batch_size=512, 169 | eval_batch_size=512, 170 | n_workers=6, 171 | first_n=first_n, 172 | augment=False, 173 | exclude_cifar_from_tiny=None,) 174 | train_fashion_mnist, test_fashion_mnist = load_train_test('fashion-mnist', 175 | shuffle_train=False, 176 | drop_last_train=False, 177 | batch_size=512, 178 | eval_batch_size=512, 179 | n_workers=6, 180 | first_n=first_n, 181 | augment=False, 182 | exclude_cifar_from_tiny=None,) 183 | 184 | #train_tiny, test_tiny = load_train_test( 185 | # 'tiny', 186 | # shuffle_train=False, 187 | # drop_last_train=False, 188 | # batch_size=512, eval_batch_size=512, 189 | # n_workers=6, 190 | # first_n=120000 if first_n is None else first_n, 191 | # augment=False, 192 
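# Illustration (numeric check only, not part of the loaders): final_preproc_lsun
# defined earlier in this file rescales the [0, 1] output of
# transforms.ToTensor() onto the grid used for the other datasets: an 8-bit
# pixel value k ends up as k/256 - 0.5, i.e. in [-0.5, 0.5) with spacing 1/256.
import numpy as np

k = np.arange(256)                    # possible 8-bit pixel values
to_tensor = k / 255.0                 # what torchvision's ToTensor produces
rescaled = to_tensor * 255 / 256 - 0.5
assert np.isclose(rescaled[0], -0.5)
assert np.isclose(rescaled[-1], 255 / 256 - 0.5)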
| # exclude_cifar_from_tiny=False, 193 | # shuffle_tiny_chunks=False, 194 | # tiny_grey=True) 195 | 196 | 197 | loaders = dict( 198 | train_mnist=train_mnist, 199 | test_mnist=test_mnist, 200 | #train_tiny=train_tiny, 201 | #test_tiny=test_tiny, 202 | train_fashion_mnist=train_fashion_mnist, 203 | test_fashion_mnist=test_fashion_mnist, 204 | ) 205 | return loaders 206 | 207 | 208 | def set_non_finite_to(arr, val): 209 | arr = arr.copy() 210 | arr[~np.isfinite(arr)] = val 211 | return arr 212 | 213 | 214 | def evaluate_without_noise(fine_model, base_model, on_top_class, 215 | first_n, noise_factor, 216 | in_dist_name, rgb_or_grey, 217 | only_full_nll, ): 218 | fine_results = _evaluate_without_noise( 219 | fine_model, on_top_class=on_top_class, 220 | first_n=first_n, noise_factor=noise_factor, 221 | rgb_or_grey=rgb_or_grey, 222 | only_full_nll=only_full_nll, ) 223 | base_results = _evaluate_without_noise( 224 | base_model, on_top_class=False, 225 | first_n=first_n, noise_factor=noise_factor, 226 | rgb_or_grey=rgb_or_grey, 227 | only_full_nll=only_full_nll, ) 228 | in_dist_diff = set_non_finite_to( 229 | fine_results['test_' + in_dist_name]['lprob'], 230 | -3000000) - set_non_finite_to( 231 | base_results['test_' + in_dist_name]['lprob'], 232 | -3000000) 233 | if not only_full_nll: 234 | in_dist_4x4 = set_non_finite_to( 235 | fine_results['test_' + in_dist_name]['lp2'], -3000000) 236 | 237 | if rgb_or_grey == 'rgb': 238 | ood_sets = ['cifar10', 'cifar100', 'svhn', 'lsun', 'celeba', 'tiny_imagenet'] 239 | else: 240 | ood_sets = ['fashion_mnist', 'mnist'] 241 | ood_sets = [s for s in ood_sets if s != in_dist_name] 242 | for ood_set in ood_sets: 243 | folds = ('test',) 244 | if ood_set == 'celeba': 245 | folds = ('train', 'test',) 246 | ood_diff = np.concatenate([set_non_finite_to( 247 | fine_results[f'{fold}_{ood_set}']['lprob'], 248 | -3000000) - set_non_finite_to( 249 | base_results[f'{fold}_{ood_set}']['lprob'], 250 | -3000000) for fold in folds]) 251 | auc = compute_auc_for_scores(ood_diff, in_dist_diff) 252 | print(f"{ood_set}: {auc:.1%} ratio AUC") 253 | if not only_full_nll: 254 | ood_4x4 = np.concatenate([set_non_finite_to( 255 | fine_results[f'{fold}_{ood_set}']['lp2'], -3000000) 256 | for fold in folds]) 257 | auc_4x4 = compute_auc_for_scores(ood_4x4, in_dist_4x4) 258 | print(f"{ood_set}: {auc_4x4:.1%} 4x4 AUC") 259 | 260 | 261 | def _evaluate_without_noise(model, on_top_class, 262 | first_n, noise_factor, rgb_or_grey, 263 | only_full_nll, ): 264 | assert rgb_or_grey in ["rgb", "grey"] 265 | if rgb_or_grey == 'rgb': 266 | loaders = get_rgb_loaders(first_n=first_n) 267 | else: 268 | assert rgb_or_grey == 'grey' 269 | loaders = get_grey_loaders(first_n=first_n) 270 | 271 | if on_top_class: 272 | model = convert_class_model_for_multi_scale_nll(model) 273 | 274 | if not only_full_nll: 275 | node_names = ('m0-flow-0', 'm0-act-0', 'm0-dist-0', 276 | 'm0-flow-1', 'm0-act-1', 'm0-dist-1', 277 | 'm0-flow-2', 'm0-act-2', 'm0-dist-2') 278 | try: 279 | wanted_nodes = get_nodes_by_names(model, *node_names) 280 | except: 281 | wanted_nodes = get_nodes_by_names(model.module, *node_names) 282 | 283 | loaders_to_results = {} 284 | for set_name, loader in loaders.items(): 285 | set_random_seeds(20191120, True) 286 | # add half of noise interval 287 | loader = PreprocessedLoader(loader, 288 | Expression(lambda x: x + noise_factor / 2), 289 | to_cuda=True) 290 | if not only_full_nll: 291 | print(set_name) 292 | result = get_nlls(loader, wanted_nodes, node_names) 293 | else: 294 | result = 
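# Illustration (made-up numbers, not part of the module): evaluate_without_noise
# above scores every example by the difference between the fine/specific
# model's log-likelihood and the base/general model's log-likelihood, then
# reports OOD detection as an AUC in which in-distribution examples should get
# the higher ratio score:
import numpy as np
from sklearn.metrics import roc_auc_score

rng = np.random.RandomState(0)
in_dist_ratio = rng.normal(loc=50.0, scale=20.0, size=1000)   # fine minus base, in-distribution
ood_ratio = rng.normal(loc=-30.0, scale=20.0, size=1000)      # fine minus base, out-of-distribution
labels = np.concatenate([np.zeros_like(ood_ratio), np.ones_like(in_dist_ratio)])
scores = np.concatenate([ood_ratio, in_dist_ratio])
print(f"ratio AUC: {roc_auc_score(labels, scores):.1%}")       # close to 100% for this toy separation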
get_nlls_only_final(loader, model) 295 | loaders_to_results[set_name] = result 296 | return loaders_to_results 297 | 298 | 299 | def compute_func_for_sets(name_to_train_test_loaders, model_dist, func): 300 | results_per_set = OrderedDict() 301 | for setname, train_loader, test_loader in name_to_train_test_loaders: 302 | for name, loader in (('train', train_loader), ('test', test_loader)): 303 | if loader is None: continue 304 | results = func(loader, model_dist) 305 | results_per_set[name + '_' + setname] = results 306 | return results_per_set 307 | 308 | 309 | def get_log_dets_probs_for_set(loader, model_dist): 310 | with th.no_grad(): 311 | n_examples = sum([len(x) for x,y in loader]) 312 | n_components = len(model_dist.dist.class_means) 313 | all_log_probs = np.ones((n_examples,n_components)) * np.nan 314 | all_log_dets = np.ones((n_examples,)) * np.nan 315 | i_example = 0 316 | for x,y in loader: 317 | x = x.cuda() 318 | out, log_det = model_dist.model(x) 319 | log_probs_per_class = model_dist.dist.log_probs_per_class(out) 320 | all_log_probs[i_example:i_example + len(x)] = var_to_np(log_probs_per_class) 321 | all_log_dets[i_example:i_example + len(x)] = var_to_np(log_det) 322 | i_example += len(x) 323 | assert not np.any(np.isnan(all_log_probs)) 324 | assert not np.any(np.isnan(all_log_dets)) 325 | return all_log_probs, all_log_dets 326 | 327 | 328 | def get_in_diffs_per_set(loader, model_dist): 329 | n_examples = sum([len(x) for x, y in loader]) 330 | all_diffs = np.ones(n_examples, dtype=np.float32) * np.nan 331 | i_example = 0 332 | for x, y in loader: 333 | out = model_dist.model(x.cuda())[0] 334 | out_perturbation = th.rand_like(out) 335 | out_perturbation = 0.01 * out_perturbation / th.norm(out_perturbation, 336 | p=2, dim=1, 337 | keepdim=True) 338 | out_perturbed = out + out_perturbation 339 | inverted_perturbed = model_dist.model.invert(out_perturbed)[0] 340 | in_diffs = th.norm((x.cuda() - inverted_perturbed).view(x.shape[0], -1), 341 | dim=1, p=2) 342 | assert not np.any(np.isnan(var_to_np(in_diffs))) 343 | all_diffs[i_example:i_example + len(x)] = var_to_np(in_diffs) 344 | i_example += len(x) 345 | assert i_example == len(all_diffs) 346 | assert not np.any(np.isnan(all_diffs)), "nan diff exists" 347 | 348 | return all_diffs 349 | 350 | 351 | def get_out_diffs_per_set(loader, model_dist): 352 | n_examples = sum([len(x) for x, y in loader]) 353 | all_diffs = np.ones(n_examples, dtype=np.float32) * np.nan 354 | i_example = 0 355 | for x, y in loader: 356 | x = x.cuda() 357 | in_perturbation = th.rand_like(x) 358 | in_perturbation = 0.01 * in_perturbation / th.norm( 359 | in_perturbation.view(x.shape[0],-1), 360 | p=2, dim=1, keepdim=True).unsqueeze(-1).unsqueeze(-1) 361 | in_perturbed = x + in_perturbation 362 | out = model_dist.model(x)[0] 363 | out_perturbed = model_dist.model(in_perturbed)[0] 364 | out_diffs = th.norm((out - out_perturbed).view(x.shape[0], -1), 365 | dim=1, p=2) 366 | assert not np.any(np.isnan(var_to_np(out_diffs))) 367 | all_diffs[i_example:i_example + len(x)] = var_to_np(out_diffs) 368 | i_example += len(x) 369 | assert i_example == len(all_diffs) 370 | assert not np.any(np.isnan(all_diffs)), "nan diff exists" 371 | return all_diffs 372 | 373 | 374 | def compute_auc_for_scores(scores_a, scores_b): 375 | auc = roc_auc_score( 376 | np.concatenate((np.zeros_like(scores_a), 377 | np.ones_like(scores_b)), 378 | axis=0), 379 | np.concatenate((scores_a, 380 | scores_b,), 381 | axis=0)) 382 | return auc 383 | 384 | def compute_aupr_for_scores(scores_a, 
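# Illustration (toy sketch, not the repo's model): get_in_diffs_per_set and
# get_out_diffs_per_set above probe numerical invertibility by adding a
# perturbation of fixed L2 norm (0.01) in latent (respectively input) space and
# measuring the induced change on the other side. For a toy invertible linear
# map y = x A^T the same measurement looks like this:
import torch as th

th.manual_seed(0)
A = th.randn(8, 8)
x = th.randn(32, 8)
y = x @ A.T
delta = th.randn_like(y)
delta = 0.01 * delta / th.norm(delta, p=2, dim=1, keepdim=True)   # norm-0.01 latent perturbation
x_back = th.linalg.solve(A, (y + delta).T).T                      # invert the perturbed latents
in_diffs = th.norm(x - x_back, p=2, dim=1)                        # analogue of get_in_diffs_per_set
print(in_diffs.mean().item())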
scores_b): 385 | auc = average_precision_score( 386 | np.concatenate((np.zeros_like(scores_a), 387 | np.ones_like(scores_b)), 388 | axis=0), 389 | np.concatenate((scores_a, 390 | scores_b,), 391 | axis=0)) 392 | return auc 393 | 394 | 395 | def collect_for_dataloader(dataloader, step_fn, show_tqdm=False): 396 | with th.no_grad(): 397 | all_outs = [] 398 | if show_tqdm: 399 | dataloader = tqdm(dataloader) 400 | for batch in dataloader: 401 | outs = step_fn(*batch) 402 | all_outs.append(outs) 403 | return all_outs 404 | 405 | 406 | def collect_for_loaders(name_to_loader,step_fn, show_tqdm=False): 407 | results = {} 408 | for name in name_to_loader:# 409 | result = collect_for_dataloader(name_to_loader[name], step_fn, 410 | show_tqdm=show_tqdm) 411 | try: 412 | result = np.concatenate(result) 413 | except: 414 | pass 415 | results[name] = result 416 | return results 417 | 418 | 419 | def collect_log_dets(loader, model, node_names, use_y=False): 420 | results_model = IntermediateResultsNode(get_nodes_by_names(model, *node_names)) 421 | def get_log_dets(x,y): 422 | # return in examples x modules logic 423 | if use_y: 424 | this_y = y 425 | else: 426 | this_y = None 427 | return np.array([var_to_np(logdet) for logdet in results_model( 428 | x,fixed=dict(y=this_y))[1]]).T 429 | log_dets = collect_for_dataloader(loader, get_log_dets, show_tqdm=True) 430 | log_dets_per_node = np.concatenate(log_dets, axis=0).T#np.array(log_dets).reshape(-1, log_dets[0].shape[-1]).T 431 | # now modules x examples 432 | name_to_log_det = dict(zip(node_names, log_dets_per_node)) 433 | return name_to_log_det 434 | 435 | def collect_log_dets_for_loaders(loaders, model, node_names, use_y=False): 436 | loaders_to_log_dets = dict([(name, collect_log_dets(loader, model, node_names, 437 | use_y=use_y)) 438 | for name, loader in loaders.items()]) 439 | return loaders_to_log_dets 440 | 441 | 442 | def compute_bpds(dataloader, model, use_y, n_batches=None, show_tqdm=True): 443 | 444 | bpds = [] 445 | if show_tqdm: 446 | dataloader = tqdm(dataloader) 447 | for i_batch, (x, y) in enumerate(dataloader): 448 | if not use_y: 449 | y = None 450 | fixed = dict(y=y) 451 | with th.no_grad(): 452 | n_dims = np.prod(x.shape[1:]) 453 | nll = -(model(x, fixed=fixed)[1] - np.log(256) * n_dims) 454 | bpd = nll / (n_dims * np.log(2)) 455 | bpds.append(bpd) 456 | if n_batches is not None and i_batch >= (n_batches - 1): 457 | break 458 | 459 | return th.cat(bpds).cpu() 460 | 461 | 462 | def identity(x): 463 | return x 464 | 465 | 466 | def collect_outputs(loader, model, process_fn=identity, use_y=False): 467 | all_outputs = [] 468 | with th.no_grad(): 469 | for batch in loader: 470 | x = batch[0] 471 | if use_y: 472 | y = batch[1] 473 | else: 474 | y = None 475 | outputs = model(x, fixed=dict(y=y)) 476 | outputs = process_fn(outputs) 477 | all_outputs.append(outputs) 478 | return all_outputs 479 | -------------------------------------------------------------------------------- /invglow/exp.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os.path 3 | from copy import deepcopy 4 | 5 | import numpy as np 6 | import torch as th 7 | from ignite.engine import Engine, Events 8 | from ignite.handlers import ModelCheckpoint 9 | from tensorboardX import SummaryWriter 10 | 11 | from invglow.load_data import load_data 12 | from invglow.evaluate import compute_bpds, compute_auc_for_scores 13 | from invglow.invertible.categorical_mixture import InvertibleClassConditional 14 | from 
invglow.invertible.branching import CatChans 15 | from invglow.invertible.distribution import MergeLogDets 16 | from invglow.invertible.graph import CatChansNode 17 | from invglow.invertible.graph import Node 18 | from invglow.invertible.graph import get_nodes_by_names 19 | from invglow.invertible.init import init_all_modules 20 | from invglow.invertible.sequential import InvertibleSequential 21 | from invglow.losses import nll_class_loss 22 | from invglow.models.class_conditional import latent_model 23 | from invglow.models.glow import create_glow_model 24 | from invglow.models.patch_glow import create_patch_glow_model 25 | from invglow.scheduler import ScheduledOptimizer 26 | from invglow.util import check_gradients_clear, step_and_clear_gradients 27 | from invglow.util import grads_all_finite 28 | from invglow.util import np_to_var 29 | from invglow.util import set_random_seeds, var_to_np 30 | 31 | log = logging.getLogger(__name__) 32 | 33 | 34 | class BaseFineIndependent(object): 35 | """Compute NLLs from base/general model and fine/specific model, 36 | recompute in case some NLLs are not finite due to numerical instability.""" 37 | def get_nlls(self, model, base_model, x, y, prev_valid_mask=None): 38 | # remember which nlls were finite (not inf or nan) 39 | # then recompute forward for only those examples 40 | valid_mask = np_to_var([True] * len(x), device=x.device) 41 | n_dims = np.prod(x.shape[1:]) 42 | if base_model is not None: 43 | with th.no_grad(): 44 | base_nll = -(base_model( 45 | x, fixed=dict(y=y))[1] - np.log(256) * n_dims) 46 | mask = th.isfinite(base_nll.detach()) 47 | valid_mask = mask & valid_mask 48 | else: 49 | base_nll = None 50 | nll = -(model(x, fixed=dict(y=y))[1] - np.log(256) * n_dims) 51 | # deal with full label multi class pred case 52 | if len(nll.shape) == 2: 53 | nll_for_mask = th.sum(nll, dim=1).detach() 54 | else: 55 | nll_for_mask = nll.detach() 56 | mask = th.isfinite(nll_for_mask.detach()) 57 | valid_mask = mask & valid_mask 58 | n_valid = th.sum(valid_mask) 59 | if (n_valid < len(x)) and (n_valid > (len(x) // 4)): 60 | # n valid too small don't handle to prevent too small batch, 61 | # catch later nans on backward 62 | del nll 63 | del base_nll 64 | del mask 65 | x = x[valid_mask] 66 | if y is not None: 67 | y = y[valid_mask] 68 | if prev_valid_mask is None: 69 | prev_valid_mask = valid_mask 70 | return self.get_nlls(model, base_model, x, y, prev_valid_mask=prev_valid_mask) 71 | if prev_valid_mask is None: 72 | prev_valid_mask = valid_mask 73 | return dict(base_nll=base_nll, fine_nll=nll, valid_mask=prev_valid_mask) 74 | 75 | 76 | def apply_inlier_losses( 77 | model, base_model, x,y, nll_computer, 78 | add_full_label_loss, 79 | temperature, 80 | weight, 81 | outlier_batches, 82 | outlier_loss,): 83 | # Need to compute outputs for all classes for the full label loss 84 | y_for_outs = None if add_full_label_loss else y 85 | model_outs = nll_computer.get_nlls(model, base_model, x, y_for_outs) 86 | if len(model_outs['fine_nll']) < len(y): 87 | y = y[model_outs['valid_mask']] 88 | model_outs['masked_y'] = y 89 | apply_inlier_loss_from_outs( 90 | model_outs, y, 91 | add_full_label_loss, 92 | temperature=temperature, weight=weight, 93 | outlier_loss=outlier_loss,) 94 | return model_outs 95 | 96 | 97 | def apply_inlier_loss_from_outs( 98 | model_outs, 99 | y, 100 | add_full_label_loss, 101 | temperature, 102 | weight, 103 | outlier_loss): 104 | nll = model_outs['fine_nll'] 105 | total_loss = 0 106 | if add_full_label_loss: 107 | if len(y) == 
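# Illustration (toy tensors, simplified from BaseFineIndependent.get_nlls
# above): when some per-example NLLs come out as inf/NaN due to numerical
# instability, the code keeps a boolean validity mask, drops the offending
# examples and, if enough of the batch survives, recomputes the forward pass on
# the reduced batch. The masking step in isolation:
import torch as th

nll = th.tensor([2.3, float('inf'), 1.7, float('nan'), 3.1])
valid_mask = th.isfinite(nll)
if valid_mask.sum() > len(nll) // 4:   # only shrink the batch if enough examples remain
    nll = nll[valid_mask]
loss = nll.mean()                      # finite now: mean of [2.3, 1.7, 3.1]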
len(model_outs['base_nll']): 108 | assert len(nll.shape) == 2 109 | # only correct_labels 110 | nll_for_loss = th.gather( 111 | nll, dim=1, index=th.argmax(y, dim=1).unsqueeze(1)).squeeze(1) 112 | assert outlier_loss == 'class' 113 | # Now compute wrong class loss, sum over all wrong classes 114 | wrong_class_loss = nll_class_loss( 115 | model_outs['base_nll'].unsqueeze(1), 116 | model_outs['fine_nll'], 117 | target_val=0, temperature=temperature, 118 | weight=weight, reduction='none') 119 | # Mask out correct class 120 | wrong_class_loss = wrong_class_loss * (1 - y.type_as(wrong_class_loss)) 121 | wrong_class_loss = th.mean(wrong_class_loss) 122 | total_loss = total_loss + wrong_class_loss 123 | else: 124 | log.warning("Did not apply label loss because of nonfinite outputs") 125 | return 126 | else: 127 | nll_for_loss = nll 128 | assert len(nll_for_loss.shape) == 1 129 | nll_loss = th.mean(nll_for_loss) 130 | total_loss = total_loss + nll_loss 131 | total_loss.backward() 132 | 133 | 134 | def apply_outlier_losses( 135 | model, base_model, inlier_results, 136 | ood_x, ood_y, 137 | nll_computer, temperature, weight, outlier_loss): 138 | outlier_results = nll_computer.get_nlls(model, base_model, ood_x, ood_y) 139 | apply_outlier_losses_from_outs(inlier_results, outlier_results, 140 | outlier_loss, temperature=temperature, 141 | weight=weight, 142 | n_dims=np.prod(ood_x.shape[1:])) 143 | return outlier_results 144 | 145 | 146 | def apply_outlier_losses_from_outs( 147 | inlier_results, 148 | outlier_results, 149 | outlier_loss, 150 | temperature, 151 | weight, 152 | n_dims): 153 | if outlier_loss == 'margin': 154 | n_min = min(len(inlier_results['fine_nll']), len(outlier_results['fine_nll'])) 155 | diff = th.nn.functional.relu( 156 | n_dims + inlier_results['fine_nll'][:n_min].detach() - 157 | outlier_results['fine_nll'][:n_min]) 158 | ood_loss = th.mean(diff) 159 | elif outlier_loss == 'class': 160 | assert outlier_loss == 'class' 161 | ood_loss = nll_class_loss( 162 | outlier_results['base_nll'], outlier_results['fine_nll'], 163 | target_val=0, temperature=temperature, weight=weight, 164 | reduction='mean') 165 | ood_loss.backward() 166 | 167 | 168 | 169 | def run_exp(dataset, 170 | first_n, 171 | lr, 172 | weight_decay, 173 | np_th_seed, 174 | debug, 175 | output_dir, 176 | n_epochs, 177 | exclude_cifar_from_tiny, 178 | saved_base_model_path, 179 | saved_model_path, 180 | saved_optimizer_path, 181 | reinit, 182 | base_set_name, 183 | outlier_weight, 184 | outlier_loss, 185 | ood_set_name, 186 | outlier_temperature, 187 | noise_factor, 188 | K, 189 | flow_coupling, 190 | init_class_model, 191 | on_top_class_model_name, 192 | batch_size, 193 | add_full_label_loss, 194 | augment, 195 | tiny_grey, 196 | warmup_steps, 197 | local_patches, 198 | block_type, 199 | flow_permutation, 200 | LU_decomposed, 201 | ): 202 | hidden_channels=512 203 | L=3 204 | if debug: 205 | first_n = 512 206 | n_epochs = 3 207 | if dataset == 'tiny': 208 | # pretrain a bit longer 209 | first_n = 5120 210 | n_epochs = 10 211 | 212 | set_random_seeds(np_th_seed, True) 213 | 214 | log.info("Loading data...") 215 | loaders, base_train_loader = load_data( 216 | dataset=dataset, 217 | first_n=first_n, 218 | exclude_cifar_from_tiny=exclude_cifar_from_tiny, 219 | base_set_name=base_set_name, 220 | ood_set_name=ood_set_name, 221 | noise_factor=noise_factor, 222 | batch_size=batch_size, 223 | augment=augment, 224 | tiny_grey=tiny_grey, 225 | ) 226 | 227 | n_chans = next(loaders['test'].__iter__())[0].shape[1] 228 | 229 | if 
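# Illustration (made-up values): the 'margin' outlier loss defined in
# apply_outlier_losses_from_outs above is a hinge that penalizes outlier
# batches whose NLL is not at least n_dims nats above the (detached) inlier NLL:
import torch as th

n_dims = 3 * 32 * 32
inlier_nll = th.tensor([7000.0, 7100.0])                # would be detached during training
outlier_nll = th.tensor([7500.0, 12000.0], requires_grad=True)
margin_loss = th.nn.functional.relu(n_dims + inlier_nll - outlier_nll).mean()
# first pair violates the margin (7000 + 3072 > 7500), second one does not
margin_loss.backward()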
saved_model_path is None: 230 | log.info("Creating model...") 231 | if not local_patches: 232 | model = create_glow_model( 233 | hidden_channels=hidden_channels, 234 | K=K, 235 | L=L, 236 | flow_permutation=flow_permutation, 237 | flow_coupling=flow_coupling, 238 | LU_decomposed=LU_decomposed, 239 | n_chans=n_chans, 240 | block_type=block_type, 241 | ) 242 | else: 243 | model = create_patch_glow_model( 244 | hidden_channels=hidden_channels, 245 | K=K, 246 | flow_permutation=flow_permutation, 247 | flow_coupling=flow_coupling, 248 | LU_decomposed=LU_decomposed, 249 | n_chans=n_chans, 250 | ) 251 | model = model.cuda(); 252 | 253 | if saved_base_model_path is not None: 254 | log.info("Loading base model...") 255 | base_model = th.load(saved_base_model_path) 256 | init_all_modules(base_model, None) 257 | else: 258 | base_model = None 259 | 260 | if saved_model_path is not None: 261 | log.info("Loading pretrained model...") 262 | model = th.load(saved_model_path) 263 | if on_top_class_model_name is not None: 264 | # Check if we have a on top class model already contained 265 | if hasattr(model, 'sequential') and len(list(model.sequential.children())) == 2: 266 | log.info("Extracting on top class model...") 267 | model_log_det_node = model.sequential[0] 268 | class_model = model.sequential[1].module 269 | model = InvertibleSequential( 270 | model_log_det_node, MergeLogDets(class_model)) 271 | class_model_loaded = True 272 | else: 273 | class_model_loaded = False 274 | else: 275 | class_model_loaded = False 276 | 277 | 278 | if (on_top_class_model_name is not None) and (not class_model_loaded): 279 | # remove references to previous dist node 280 | model_log_act_nodes = get_nodes_by_names( 281 | model, 'm0-act-0', 'm0-act-1', 'm0-act-2') 282 | for a in model_log_act_nodes: 283 | a.next = [] 284 | model_log_det_node = CatChansNode( 285 | model_log_act_nodes, 286 | notify_prev_nodes=True) 287 | 288 | if on_top_class_model_name == 'latent': 289 | top_single_class_model = latent_model(n_chans) 290 | if dataset in ['cifar10', 'svhn', 'fashion-mnist', 'mnist']: 291 | n_classes = 10 292 | else: 293 | assert dataset == 'cifar100' 294 | n_classes = 100 295 | 296 | i_classes = list(range(n_classes)) 297 | class_model = InvertibleClassConditional( 298 | [Node(deepcopy(top_single_class_model), CatChans()) for _ in 299 | i_classes], 300 | i_classes) 301 | del top_single_class_model 302 | class_model.cuda(); 303 | 304 | if init_class_model: 305 | from itertools import islice 306 | with th.no_grad(): 307 | init_x_y = [(model_log_det_node(x)[0], y) for x, y in 308 | islice(loaders['train'], 10)] 309 | init_x = th.cat([xy[0] for xy in init_x_y], dim=0) 310 | init_y = th.cat([xy[1] for xy in init_x_y], dim=0) 311 | init_all_modules(class_model, (init_x, init_y), use_y=True) 312 | else: 313 | init_all_modules(class_model, None) 314 | 315 | model = InvertibleSequential(model_log_det_node, MergeLogDets(class_model)) 316 | model.cuda(); 317 | 318 | if add_full_label_loss: 319 | train_inlier_model = InvertibleSequential(model_log_det_node, 320 | class_model) 321 | train_inlier_model.cuda(); 322 | else: 323 | train_inlier_model = model 324 | 325 | if (saved_model_path is None) or reinit: 326 | init_all_modules(model, loaders['train'], use_y=False) 327 | else: 328 | init_all_modules(model, None) 329 | 330 | optimizer = th.optim.Adamax( 331 | [p for p in model.parameters() if p.requires_grad], 332 | lr=lr, weight_decay=weight_decay) 333 | 334 | if saved_optimizer_path is not None: 335 | 
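# Illustration (standalone sketch with a dummy parameter instead of the Glow
# model; the lr / weight_decay / warmup values are placeholders): the Adamax
# optimizer created above can be wrapped in a linear learning-rate warmup via
# LambdaLR, as done just below with the ScheduledOptimizer helper:
import torch as th

param = th.nn.Parameter(th.zeros(10))
optimizer = th.optim.Adamax([param], lr=5e-4, weight_decay=5e-5)
warmup_steps = 10
scheduler = th.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda step: min(1.0, (step + 1) / warmup_steps))
for step in range(3):
    optimizer.step()      # normally preceded by loss.backward()
    scheduler.step()      # lr ramps linearly up to its nominal value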
optimizer.load_state_dict(th.load(saved_optimizer_path)) 336 | 337 | if (warmup_steps is not None) and (warmup_steps > 0): 338 | lr_lambda = lambda epoch: min(1.0, (epoch + 1) / warmup_steps) # noqa 339 | scheduler = th.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda) 340 | optimizer = ScheduledOptimizer(scheduler, optimizer) 341 | 342 | 343 | if (ood_set_name is not None): 344 | log.info("Compute BPDs of base model...") 345 | # Remember bpds of base model 346 | base_model_bpds = {} 347 | assert base_model is not None 348 | for loader_name, loader in loaders.items(): 349 | if (base_model is not None): 350 | bpds = compute_bpds(loader, base_model, use_y=False, 351 | show_tqdm=False) 352 | base_model_bpds[loader_name] = var_to_np(bpds) 353 | mean_bpd = th.mean(bpds).item() 354 | print(f"base_{loader_name} BPD: {mean_bpd:.2f}") 355 | 356 | if outlier_loss is not None: 357 | train_base_model = base_model 358 | else: 359 | train_base_model = None 360 | 361 | nll_computer = BaseFineIndependent() 362 | train_loader = deepcopy(loaders['train']) 363 | 364 | def get_outlier_batches(): 365 | outlier_batches = [] 366 | x, _ = next(base_train_loader.__iter__()) 367 | outlier_batches.append((x, None)) 368 | return outlier_batches 369 | 370 | # Only need to compute base nll on inliers if needed 371 | # (full label loss or 372 | if add_full_label_loss: 373 | train_base_model_inlier = train_base_model 374 | else: 375 | train_base_model_inlier = None 376 | 377 | 378 | def train(engine, batch): 379 | x, y = batch 380 | check_gradients_clear(optimizer) 381 | model.train() 382 | if outlier_loss is not None: 383 | outlier_batches = get_outlier_batches() 384 | else: 385 | outlier_batches = None 386 | inlier_results = apply_inlier_losses( 387 | train_inlier_model, 388 | train_base_model_inlier, 389 | x, 390 | y, 391 | nll_computer, 392 | add_full_label_loss=add_full_label_loss, 393 | temperature=outlier_temperature, 394 | weight=outlier_weight, 395 | outlier_loss=outlier_loss, 396 | outlier_batches=outlier_batches,) 397 | 398 | if outlier_loss is not None: 399 | for out_x, out_y in outlier_batches: 400 | apply_outlier_losses(model, train_base_model, 401 | inlier_results, out_x, out_y, 402 | nll_computer, 403 | outlier_loss=outlier_loss, 404 | temperature=outlier_temperature, 405 | weight=outlier_weight,) 406 | if grads_all_finite(optimizer): 407 | step_and_clear_gradients(optimizer) 408 | else: 409 | log.warning("NaNs or Infs in grad! 
Not all grads finite") 410 | optimizer.zero_grad() 411 | n_dims = np.prod(x.shape[1:]) 412 | bpd = th.mean(inlier_results['fine_nll']).item() / (n_dims * np.log(2)) 413 | return dict(bpd=bpd) 414 | 415 | eval_model = model 416 | 417 | def evaluate(engine, ): 418 | eval_model.eval() 419 | results = {} 420 | print(f"Epoch {engine.state.epoch:d}") 421 | all_bpds_per_set = {} 422 | for loader_name, loader in loaders.items(): 423 | bpds = compute_bpds(loader, eval_model, use_y=False, 424 | show_tqdm=False) 425 | # some stabilizations for later evaluation 426 | # bpds[~np.isnan(bpds)] = np.nanmax(bpds) 427 | # bpds = np.clip(bpds, -100000,100000) 428 | mean_bpd = th.mean(bpds).item() 429 | result_key_name = f"{loader_name:s}" 430 | print(f"{result_key_name} BPD: {mean_bpd:.2f}") 431 | all_bpds_per_set[f"{result_key_name}"] = var_to_np(bpds) 432 | results[f"{result_key_name}_bpd"] = mean_bpd 433 | writer.add_scalar(f"{result_key_name}_bpd", mean_bpd, 434 | engine.state.epoch) 435 | 436 | if ood_set_name is not None: 437 | # AUC computation 438 | for fold in ('train', 'test'): 439 | itd_diffs = all_bpds_per_set[f"{fold}"] - base_model_bpds[ 440 | f"{fold}"] 441 | itd_diffs[~np.isfinite(itd_diffs)] = 300000 442 | 443 | ood_sets = ['ood_test'] 444 | if tiny_grey == False: 445 | ood_sets.extend(["ood_cifar", "lsun"]) 446 | for ood_name in ood_sets: 447 | ood_diffs = all_bpds_per_set[ood_name] - base_model_bpds[ 448 | ood_name] 449 | # set to very high numbers in cas of not finite 450 | ood_diffs[~np.isfinite(ood_diffs)] = 300000 451 | auc = compute_auc_for_scores( 452 | itd_diffs[np.isfinite(itd_diffs)], 453 | ood_diffs[np.isfinite(ood_diffs)]) * 100 454 | results[f"{fold}_vs_{ood_name}_auc"] = auc 455 | print(f"{fold}_vs_{ood_name} AUC: {auc:.1f} %") 456 | writer.add_scalar(f"{fold}_vs_{ood_name}_auc", auc, 457 | engine.state.epoch) 458 | 459 | engine.state.results = results 460 | writer.flush() 461 | if ((engine.state.epoch % max(n_epochs // 5, 1) == 0) 462 | or engine.state.epoch == n_epochs) and (not debug): 463 | model_path = os.path.join(output_dir, 464 | f"{engine.state.epoch:d}_model.th") 465 | th.save(model, open(model_path, 'wb')) 466 | 467 | writer = SummaryWriter(output_dir) 468 | trainer = Engine(train) 469 | trainer.add_event_handler(Events.STARTED, evaluate) 470 | trainer.add_event_handler(Events.EPOCH_COMPLETED, evaluate) 471 | if not debug: 472 | checkpoint_state_dicts_handler = ModelCheckpoint(output_dir, 'state_dicts', 473 | save_interval=1, 474 | n_saved=1, 475 | require_empty=False, 476 | save_as_state_dict=True) 477 | 478 | models = dict(model=model) 479 | optimizers = dict(optimizer=optimizer) 480 | trainer.add_event_handler(Events.EPOCH_COMPLETED, 481 | checkpoint_state_dicts_handler, 482 | {**models, **optimizers}) 483 | trainer.run(train_loader, n_epochs) 484 | return trainer, model 485 | 486 | -------------------------------------------------------------------------------- /invglow/folder_locations.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | # Just set your paths here explicitly, no need 4 | # to use system environment as we did, this just made it easier for us 5 | 6 | pytorch_data = os.environ['pytorch_data'] 7 | # Also copy the 80mn_cifar_idxs.txt and cifar_indexes file there 8 | tiny_data = os.environ['tiny_data'] 9 | lsun_data = os.environ['lsun_data'] 10 | # only necessary for MRI experiment: 11 | brats_data = os.environ['brats_data'] 12 | # only necessary for additional OOD dataset evaluation: 13 | 
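# Illustration (example only; the directory names are placeholders): as the
# comment above notes, the environment-variable lookups can simply be replaced
# by explicit paths, e.g.
#
#   pytorch_data = '/data/pytorch_datasets'
#   tiny_data = '/data/80mn_tiny_images'   # also copy 80mn_cifar_idxs.txt and cifar_indexes here
#   lsun_data = '/data/lsun'
#
# or the variables can be exported in the shell before running, e.g.
#   export pytorch_data=/data/pytorch_datasets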
celeba_data = os.environ['celeba_data'] 14 | tiny_imagenet_data = os.environ['tiny_imagenet_data'] -------------------------------------------------------------------------------- /invglow/invertible/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boschresearch/hierarchical_anomaly_detection/ca2f1d84615c2ef140a74f4e1515352abff9e938/invglow/invertible/__init__.py -------------------------------------------------------------------------------- /invglow/invertible/actnorm.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | from torch import nn 3 | import numpy as np 4 | 5 | def inverse_elu(y): 6 | mask = y > 1 7 | x = th.zeros_like(y) 8 | x.data[mask] = y.data[mask] - 1 9 | x.data[1-mask] = th.log(y.data[1-mask]) 10 | return x 11 | 12 | class ActNorm(nn.Module): 13 | def __init__(self, in_channel, scale_fn, eps=1e-8, verbose_init=True, 14 | init_eps=None): 15 | super().__init__() 16 | 17 | self.loc = nn.Parameter(th.zeros(1, in_channel, 1, 1)) 18 | self.log_scale = nn.Parameter(th.zeros(1, in_channel, 1, 1)) 19 | 20 | self.initialize_this_forward = False 21 | self.initialized = False 22 | self.scale_fn = scale_fn 23 | self.eps = eps 24 | self.verbose_init = verbose_init 25 | if init_eps is None: 26 | if scale_fn == 'exp': 27 | self.init_eps = 1e-6 28 | else: 29 | assert scale_fn == 'elu' 30 | self.init_eps = 1e-1 31 | else: 32 | self.init_eps = init_eps 33 | 34 | def initialize(self, x): 35 | with th.no_grad(): 36 | flatten = x.permute(1, 0, 2, 3).contiguous().view(x.shape[1], -1) 37 | mean = ( 38 | flatten.mean(1) 39 | .unsqueeze(1) 40 | .unsqueeze(2) 41 | .unsqueeze(3) 42 | .permute(1, 0, 2, 3) 43 | ) 44 | std = ( 45 | flatten.std(1) 46 | .unsqueeze(1) 47 | .unsqueeze(2) 48 | .unsqueeze(3) 49 | .permute(1, 0, 2, 3) 50 | ) 51 | self.loc.data.copy_(-mean) 52 | if self.scale_fn == 'exp': 53 | self.log_scale.data.copy_(th.log(1 / th.clamp_min(std, self.init_eps))) 54 | elif self.scale_fn == 'elu': 55 | self.log_scale.data.copy_(inverse_elu(1 / th.clamp_min(std, self.init_eps))) 56 | else: 57 | assert False 58 | 59 | if self.scale_fn == 'exp': 60 | multipliers = th.exp(self.log_scale.squeeze()) 61 | elif self.scale_fn == 'elu': 62 | multipliers = th.nn.functional.elu(self.log_scale) + 1 63 | if self.verbose_init: 64 | print(f"Multiplier init to (log10) " 65 | f"min: {np.log10(th.min(multipliers).item()):3.0f} " 66 | f"max: {np.log10(th.max(multipliers).item()):3.0f} " 67 | f"mean: {np.log10(th.mean(multipliers).item()):3.0f}") 68 | 69 | def forward(self, x, fixed=None): 70 | was_2d = False 71 | if len (x.shape) == 2: 72 | was_2d = True 73 | x = x.unsqueeze(-1).unsqueeze(-1) 74 | _, _, height, width = x.shape 75 | 76 | if not self.initialized: 77 | assert self.initialize_this_forward, ( 78 | "Please first initialize by setting initialize_this_forward to True" 79 | " and forwarding appropriate data") 80 | if self.initialize_this_forward: 81 | self.initialize(x) 82 | self.initialized = True 83 | self.initialize_this_forward = False 84 | 85 | scale, log_det_px = self.scale_and_logdet_per_pixel() 86 | y = scale * (x + self.loc) 87 | if was_2d: 88 | y = y.squeeze(-1).squeeze(-1) 89 | 90 | logdet = height * width * log_det_px 91 | logdet = logdet.repeat(len( 92 | x)) 93 | 94 | return y, logdet 95 | 96 | def scale_and_logdet_per_pixel(self): 97 | if self.scale_fn == 'exp': 98 | scale = th.exp(self.log_scale) + self.eps 99 | if self.eps == 0: 100 | logdet = 
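# Illustration (toy check outside the class, with made-up activations): the
# data-dependent initialization above sets loc = -mean and scale ~ 1/std per
# channel, so the first forward pass yields roughly zero-mean, unit-variance
# activations, and the log-determinant is height * width * sum(log scale):
import torch as th

x = th.randn(64, 3, 8, 8) * 2.0 + 5.0
flat = x.permute(1, 0, 2, 3).reshape(3, -1)
loc, scale = -flat.mean(1), 1.0 / flat.std(1)
y = scale.view(1, 3, 1, 1) * (x + loc.view(1, 3, 1, 1))
print(y.mean().item(), y.std().item())              # ~0 and ~1
logdet_per_example = 8 * 8 * th.sum(th.log(scale))  # same formula as scale_and_logdet_per_pixel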
th.sum(self.log_scale) 101 | else: 102 | logdet = th.sum(th.log(scale)) 103 | elif self.scale_fn == 'elu': 104 | scale = th.nn.functional.elu(self.log_scale) + 1 + self.eps 105 | logdet = th.sum(th.log(scale)) 106 | else: 107 | assert False 108 | 109 | return scale, logdet 110 | 111 | def invert(self, y, fixed=None): 112 | was_2d = False 113 | if len (y.shape) == 2: 114 | was_2d = True 115 | y = y.unsqueeze(-1).unsqueeze(-1) 116 | _, _, height, width = y.shape 117 | scale, log_det_px = self.scale_and_logdet_per_pixel() 118 | x = y / scale - self.loc 119 | logdet = height * width * log_det_px 120 | if was_2d: 121 | x = x.squeeze(-1).squeeze(-1) 122 | # repeat per example in batch 123 | logdet = logdet.repeat(len( 124 | x)) 125 | return x, logdet 126 | 127 | 128 | def init_act_norm(net, trainloader, n_batches=10, uni_noise_factor=1/255.0): 129 | if trainloader is not None: 130 | all_x = [] 131 | for i_batch, (x, y) in enumerate(trainloader): 132 | all_x.append(x) 133 | if i_batch >= n_batches: 134 | break 135 | 136 | init_x = th.cat(all_x, dim=0) 137 | init_x = init_x.cuda() 138 | init_x = init_x + th.rand_like(init_x) * uni_noise_factor 139 | 140 | for m in net.modules(): 141 | if hasattr(m, 'initialize_this_forward'): 142 | m.initialize_this_forward = True 143 | 144 | _ = net(init_x) 145 | else: 146 | for m in net.modules(): 147 | if hasattr(m, 'initialize_this_forward'): 148 | m.initialized = True 149 | 150 | 151 | class PureActNorm(nn.Module): 152 | def __init__(self, in_channel,): 153 | super().__init__() 154 | self.loc = nn.Parameter(th.zeros(in_channel)) 155 | self.scale = nn.Parameter(th.zeros(in_channel)) 156 | self.initialize_this_forward = False 157 | self.initialized = False 158 | 159 | def forward(self, x): 160 | if not self.initialized: 161 | assert self.initialize_this_forward, ( 162 | "Please first initialize by setting initialize_this_forward to True" 163 | " and forwarding appropriate data") 164 | if self.initialize_this_forward: 165 | self.initialize(x) 166 | self.initialized = True 167 | self.initialize_this_forward = False 168 | 169 | loc = self.loc.unsqueeze(0) 170 | scale = self.scale.unsqueeze(0) 171 | if len(x.shape) == 4: 172 | loc = loc.unsqueeze(2).unsqueeze(3) 173 | scale = scale.unsqueeze(2).unsqueeze(3) 174 | y = scale * (x + loc) 175 | return y 176 | 177 | def initialize(self, x): 178 | with th.no_grad(): 179 | flatten = x.transpose(0,1).contiguous().view(x.shape[1], -1) 180 | mean = ( 181 | flatten.mean(1) 182 | ) 183 | std = ( 184 | flatten.std(1) 185 | ) 186 | self.loc.data.copy_(-mean) 187 | self.scale.data.copy_(1 / (std + 1e-4)) 188 | print("Multiplier initialized to \n", self.scale.squeeze()) 189 | -------------------------------------------------------------------------------- /invglow/invertible/affine.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | from torch import nn 3 | 4 | 5 | class AffineCoefs(nn.Module): 6 | def __init__(self, module, splitter): 7 | super().__init__() 8 | self.module = module 9 | self.splitter = splitter 10 | 11 | def forward(self, x): 12 | coefs = self.module(x) 13 | add, raw_scale = self.splitter.split(coefs) 14 | return (add, raw_scale) 15 | 16 | class AdditiveCoefs(nn.Module): 17 | def __init__(self, module, ): 18 | super().__init__() 19 | self.module = module 20 | 21 | def forward(self, x): 22 | add = self.module(x) 23 | raw_scale = None 24 | return (add, raw_scale) 25 | 26 | 27 | class AffineModifier(nn.Module): 28 | def __init__(self, sigmoid_or_exp_scale, 
add_first, eps, ): 29 | super().__init__() 30 | self.sigmoid_or_exp_scale = sigmoid_or_exp_scale 31 | self.add_first = add_first 32 | self.eps = eps 33 | 34 | def forward(self, x2, coefs): 35 | add, raw_scale = coefs 36 | 37 | if raw_scale is not None: 38 | if self.sigmoid_or_exp_scale == 'sigmoid': 39 | s = th.sigmoid(raw_scale + 2.) + self.eps 40 | else: 41 | assert self.sigmoid_or_exp_scale == 'exp' 42 | s = th.exp(raw_scale) + self.eps 43 | logdet = th.sum(th.log(s).view(s.shape[0], -1), 1) 44 | else: 45 | logdet = 0 46 | 47 | if self.add_first and (add is not None): 48 | x2 = x2 + add 49 | if raw_scale is not None: 50 | x2 = x2 * s 51 | if (not self.add_first) and (add is not None): 52 | x2 = x2 + add 53 | return x2, logdet 54 | 55 | def invert(self, x2, coefs): 56 | add, raw_scale = coefs 57 | if raw_scale is not None: 58 | if self.sigmoid_or_exp_scale == 'sigmoid': 59 | s = th.sigmoid(raw_scale + 2) + self.eps 60 | else: 61 | assert self.sigmoid_or_exp_scale == 'exp' 62 | s = th.exp(raw_scale) + self.eps 63 | logdet = th.sum(th.log(s).view(s.shape[0], -1), 1) 64 | else: 65 | logdet = 0 66 | 67 | if (not self.add_first) and (add is not None): 68 | x2 = x2 - add 69 | 70 | if raw_scale is not None: 71 | x2 = x2 / s 72 | 73 | if (self.add_first) and (add is not None): 74 | x2 = x2 - add 75 | return x2, logdet 76 | 77 | 78 | class AffineBlock(th.nn.Module): 79 | def __init__(self, FA, FM, single_affine_block, 80 | split_merger, 81 | sigmoid_or_exp_scale=None, 82 | eps=1e-2, 83 | condition_merger=None, 84 | add_first=None): 85 | super().__init__() 86 | if add_first is None: 87 | print("warning add first is None, setting to False!!") 88 | add_first = False 89 | # first G before F, only to have consistent ordering of 90 | # parameter list compared to other code 91 | self.FA = FA 92 | self.FM = FM 93 | self.split_merger = split_merger 94 | self.single_affine_block = single_affine_block 95 | if self.single_affine_block: 96 | assert self.FM is None 97 | if (self.FM is not None) or self.single_affine_block: 98 | assert sigmoid_or_exp_scale is not None 99 | else: 100 | assert sigmoid_or_exp_scale is None 101 | self.sigmoid_or_exp_scale = sigmoid_or_exp_scale 102 | self.eps = eps 103 | self.condition_merger = condition_merger 104 | self.accepts_condition = (condition_merger is not None) 105 | self.add_first = add_first 106 | 107 | def forward(self, x, condition=None): 108 | logdet = 0 109 | x1, x2 = self.split_merger.split(x) 110 | y1 = x1 111 | y2 = x2 112 | if condition is not None: 113 | assert self.accepts_condition 114 | y2 = self.condition_merger(y2,condition) 115 | 116 | raw_scale_F = None 117 | add_F = None 118 | if self.single_affine_block: 119 | add_F, raw_scale_F = th.chunk(self.FA(y2), 2, dim=1) 120 | #h = self.FA(y2) 121 | #add_F, raw_scale_F = h[:,0::2], h[:,1::2] 122 | else: 123 | if self.FA is not None: 124 | add_F = self.FA(y2) 125 | if self.FM is not None: 126 | raw_scale_F = self.FM(y2) 127 | 128 | if raw_scale_F is not None: 129 | if self.sigmoid_or_exp_scale == 'sigmoid': 130 | s = th.sigmoid(raw_scale_F + 2.) + self.eps 131 | else: 132 | assert self.sigmoid_or_exp_scale == 'exp' 133 | s = th.exp(raw_scale_F) + self.eps 134 | logdet = logdet + th.sum(th.log(s).view(s.shape[0], -1), 1) 135 | 136 | if self.add_first and (add_F is not None): 137 | y1 = y1 + add_F 138 | if raw_scale_F is not None: 139 | y1 = y1 * s 140 | if (not self.add_first) and (add_F is not None): 141 | y1 = y1 + add_F 142 | 143 | 144 | # x2 should be unchanged!! 
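# Illustration (self-contained toy version of the coupling transform above, not
# using the repo's classes; the Linear layer stands in for the coupling
# network): one half of the channels parametrizes an affine map of the other
# half, the log-determinant is the sum of log scales, and the inverse undoes
# the operations in reverse order:
import torch as th

def toy_coupling_forward(x1, x2, net, eps=1e-2):
    add, raw_scale = th.chunk(net(x2), 2, dim=1)
    s = th.sigmoid(raw_scale + 2.0) + eps        # 'sigmoid' scale variant, as above
    y1 = (x1 + add) * s                          # add_first=True ordering
    logdet = th.log(s).flatten(1).sum(1)
    return y1, logdet

def toy_coupling_invert(y1, x2, net, eps=1e-2):
    add, raw_scale = th.chunk(net(x2), 2, dim=1)
    s = th.sigmoid(raw_scale + 2.0) + eps
    return y1 / s - add

net = th.nn.Linear(4, 8)
x1, x2 = th.randn(5, 4), th.randn(5, 4)
y1, logdet = toy_coupling_forward(x1, x2, net)
assert th.allclose(toy_coupling_invert(y1, x2, net), x1, atol=1e-5)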
145 | y = self.split_merger.merge(y1,x2) 146 | return y , logdet 147 | 148 | def invert(self, y, condition=None): 149 | y1, y2 = self.split_merger.split(y) 150 | x1 = y1 151 | x2 = y2 152 | if condition is not None: 153 | assert self.accepts_condition 154 | x2 = self.condition_merger(x2,condition) 155 | logdet = 0 156 | 157 | raw_scale_F = None 158 | add_F = None 159 | if self.single_affine_block: 160 | add_F, raw_scale_F = th.chunk(self.FA(x2), 2, dim=1) 161 | else: 162 | if self.FA is not None: 163 | add_F = self.FA(x2) 164 | if self.FM is not None: 165 | raw_scale_F = self.FM(x2) 166 | 167 | if (not self.add_first) and (add_F is not None): 168 | x1 = x1 - add_F 169 | if raw_scale_F is not None: 170 | if self.sigmoid_or_exp_scale == 'sigmoid': 171 | s = th.sigmoid(raw_scale_F + 2) + self.eps 172 | else: 173 | assert self.sigmoid_or_exp_scale == 'exp' 174 | s = th.exp(raw_scale_F) + self.eps 175 | logdet = logdet + th.sum(th.log(s).view(s.shape[0], -1), 1) 176 | 177 | if raw_scale_F is not None: 178 | x1 = x1 / s 179 | 180 | if (self.add_first) and (add_F is not None): 181 | x1 = x1 - add_F 182 | 183 | # y2 should be unchanged!! 184 | x = self.split_merger.merge(x1,y2) 185 | return x, logdet 186 | 187 | 188 | class AdditiveBlock(AffineBlock): 189 | def __init__(self, FA, eps=0): 190 | super(AdditiveBlock, self).__init__( 191 | FA=FA, FM=None, eps=eps) 192 | 193 | class MultiplicativeBlock(AffineBlock): 194 | def __init__(self, FM, eps=0): 195 | super(AdditiveBlock, self).__init__( 196 | FA=None, FM=FM, eps=eps) 197 | -------------------------------------------------------------------------------- /invglow/invertible/branching.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | from torch import nn 4 | 5 | 6 | class ChunkChans(nn.Module): 7 | def __init__(self, n_parts): 8 | super(ChunkChans, self).__init__() 9 | self.n_parts = n_parts 10 | 11 | def forward(self, x, fixed=None): 12 | xs = th.chunk(x, chunks=self.n_parts, dim=1, ) 13 | # for debug 14 | self.my_x_sizes = [x.size() for x in xs] 15 | return xs, 0 16 | 17 | def invert(self, y, fixed=None): 18 | y = th.cat(y, dim=1) 19 | return y, 0 20 | 21 | 22 | class SwitchX1X2(nn.Module): 23 | def forward(self, x): 24 | x1, x2 = th.chunk(x, 2, dim=1) 25 | return th.cat([x2, x1], dim=1) 26 | 27 | def invert(self, y): 28 | return self.forward(y) 29 | 30 | 31 | class ChunkByIndex(nn.Module): 32 | def __init__(self, index): 33 | super(ChunkByIndex, self).__init__() 34 | self.index = index 35 | 36 | def forward(self, x, fixed=None): 37 | xs = [x[:, :self.index], x[:,self.index:]] 38 | return xs, 0 39 | 40 | def invert(self, y, fixed=None): 41 | y = th.cat(y, dim=1) 42 | return y, 0 43 | 44 | class ChunkByIndices(nn.Module): 45 | def __init__(self, indices): 46 | super().__init__() 47 | self.indices = tuple(indices) 48 | 49 | def forward(self, x, fixed=None): 50 | indices = (0,) + self.indices + (x.shape[1],) 51 | xs = [x[:, start:stop] 52 | for start, stop in zip(indices[:-1], indices[1:])] 53 | return xs, 0 54 | 55 | def invert(self, y, fixed=None): 56 | y = th.cat(y, dim=1) 57 | return y, 0 58 | 59 | 60 | class CatChans(nn.Module): 61 | def __init__(self,): 62 | super().__init__() 63 | self.n_chans = None 64 | 65 | def forward(self, xs, fixed=None): 66 | n_chans = tuple([a_x.size()[1] for a_x in xs]) 67 | if self.n_chans is None: 68 | self.n_chans = n_chans 69 | else: 70 | assert n_chans == self.n_chans 71 | return th.cat(xs, dim=1), 0 72 | 73 | def invert(self, 
ys, fixed=None): 74 | assert self.n_chans is not None, "please do forward first" 75 | if ys is not None: 76 | xs = [] 77 | bounds = np.insert(np.cumsum(self.n_chans), 0, 0) 78 | for i_b in range(len(bounds) - 1): 79 | xs.append(ys[:, bounds[i_b]:bounds[i_b + 1]]) 80 | else: 81 | xs = [None] * len(self.n_chans) 82 | return xs, 0 83 | -------------------------------------------------------------------------------- /invglow/invertible/categorical_mixture.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch as th 3 | 4 | class InvertibleClassConditional(nn.Module): 5 | def __init__(self, modules, i_classes): 6 | super().__init__() 7 | self.module_list = nn.ModuleList(modules) 8 | self.i_classes = i_classes 9 | 10 | def forward(self, x, fixed): 11 | return self.compute(x, fixed=fixed, mode='forward') 12 | 13 | def invert(self, y, fixed): 14 | return self.compute(y, fixed=fixed, mode='invert') 15 | 16 | def compute(self, x, fixed, mode): 17 | y = fixed['y'] 18 | if y is None: # just compute all mixture components for all examples 19 | if hasattr(x, 'shape'): # if it was list already 20 | # then no need to split up again, we already duplicated it... 21 | xs = [x] * len(self.module_list) 22 | else: 23 | xs = x 24 | else: 25 | masks = [y[:, i_class] == 1 for i_class in self.i_classes] 26 | xs = [x[m] for m in masks] 27 | 28 | outs = [] 29 | log_dets = [] 30 | assert len(xs) == len(self.module_list) 31 | for a_x, module in zip(xs, self.module_list): 32 | if len(a_x) > 0: 33 | if mode == 'forward': 34 | this_out, this_log_det = module(a_x, 35 | fixed=fixed) 36 | o_shape = this_out.shape # used below 37 | else: 38 | assert mode == 'invert' 39 | this_out, this_log_det = module.invert(a_x, 40 | fixed=fixed) 41 | else: 42 | this_out, this_log_det = None, None 43 | outs.append(this_out) 44 | log_dets.append(this_log_det) 45 | 46 | if y is None: 47 | return outs, th.stack(log_dets, dim=-1) 48 | else: 49 | assert len(outs) == len(masks) == len(log_dets), ( 50 | f"n_outs: {len(outs)}, n_masks: {len(masks)}, n_dets: {len(log_dets)}") 51 | outs_full = th.zeros(len(x), *o_shape[1:], dtype=x.dtype, 52 | device=x.device) 53 | log_dets_full = th.zeros(len(x), dtype=x.dtype, device=x.device) 54 | for out, log_det, mask in zip(outs, log_dets, masks): 55 | if out is not None: 56 | #outs_full[mask] = outs_full[mask] + out 57 | #log_dets_full[mask] = log_dets_full[mask] + log_det 58 | log_dets_full = log_dets_full.masked_scatter(mask, log_det) 59 | while len(mask.shape) < len(outs_full.shape): 60 | mask = mask.unsqueeze(-1).repeat( 61 | (1,) * len(mask.shape) + ( 62 | outs_full.shape[len(mask.shape)],)) 63 | outs_full = outs_full.masked_scatter(mask, out) 64 | 65 | # counts = th.zeros(len(self.i_classes), dtype=th.int64) 66 | # y_label = th.argmax(y, dim=1) 67 | # all_outs = [] 68 | # all_log_dets = [] 69 | # for i in range(len(x)): 70 | # i_class = y_label[i] 71 | # i_in_class = counts[i_class] 72 | # all_outs.append(outs[i_class][i_in_class]) 73 | # all_log_dets.append(log_dets[i_class][i_in_class]) 74 | # counts[i_class] += 1 75 | # 76 | # outs_full = th.stack(all_outs, axis=0) 77 | # log_dets_full = th.stack(all_log_dets, axis=0) 78 | # 79 | 80 | # outs_full = th.cat(outs, axis=0) 81 | # log_dets_full = th.cat(log_dets, axis=0) 82 | return outs_full, log_dets_full 83 | -------------------------------------------------------------------------------- /invglow/invertible/conditional.py: 
-------------------------------------------------------------------------------- 1 | import torch as th 2 | from torch import nn 3 | 4 | class CatChansMerger(nn.Module): 5 | def __init__(self, cond_preproc=None): 6 | super(CatChansMerger, self).__init__() 7 | self.cond_preproc = cond_preproc 8 | 9 | def forward(self, x,cond, fixed=None): 10 | if self.cond_preproc is not None: 11 | cond_processed = self.cond_preproc(cond) 12 | else: 13 | cond_processed = cond 14 | return th.cat((x, cond_processed), dim=1) 15 | 16 | 17 | class ConditionTransformWrapper(nn.Module): 18 | def __init__(self, module, cond_preproc): 19 | super().__init__() 20 | self.module = module 21 | self.cond_preproc = cond_preproc 22 | self.accepts_condition=True 23 | 24 | def forward(self, x, condition, fixed=None): 25 | cond_processed = self.cond_preproc(condition) 26 | return self.module(x, condition=cond_processed, fixed=fixed) 27 | 28 | def invert(self, y, condition, fixed=None): 29 | cond_processed = self.cond_preproc(condition) 30 | return self.module.invert(y, condition=cond_processed, fixed=fixed) 31 | 32 | 33 | class ApplyAndCat(nn.Module): 34 | """Apply different modules to different inputs. 35 | First module will be applied to first input, etc. 36 | So this module expects to receive a list of inputs 37 | in the forward.""" 38 | 39 | def __init__(self, *modules): 40 | super().__init__() 41 | self.module_list = nn.ModuleList(modules) 42 | 43 | def forward(self, xs): 44 | assert len(xs) == len(self.module_list), ( 45 | f"{len(xs)} xs and {len(self.module_list)} modules") 46 | 47 | ys = [m(x) for m, x in zip(self.module_list, xs)] 48 | return th.cat(ys, dim=1) 49 | -------------------------------------------------------------------------------- /invglow/invertible/coupling.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class CouplingLayer(nn.Module): 5 | def __init__(self, split_merger, coef_extractor, 6 | modifier, 7 | condition_merger=None, 8 | ): 9 | super().__init__() 10 | self.split_merger = split_merger 11 | self.coef_extractor = coef_extractor 12 | self.modifier = modifier 13 | self.condition_merger = condition_merger 14 | self.accepts_condition = (condition_merger is not None) 15 | 16 | def forward(self, x, condition=None, fixed=None): 17 | x1, x2 = self.split_merger.split(x) 18 | y1 = x1 19 | y2 = x2 20 | if condition is not None: 21 | assert self.accepts_condition 22 | y2 = self.condition_merger(y2, condition) 23 | coefs = self.coef_extractor(y2) 24 | y1, log_det = self.modifier(y1, coefs) 25 | y = self.split_merger.merge(y1, x2) # x2 should be unchanged 26 | return y, log_det 27 | 28 | def invert(self, y, condition=None, fixed=None): 29 | y1, y2 = self.split_merger.split(y) 30 | x1 = y1 31 | x2 = y2 32 | if condition is not None: 33 | assert self.accepts_condition 34 | x2 = self.condition_merger(x2, condition) 35 | coefs = self.coef_extractor(x2) 36 | x1, log_det = self.modifier.invert(x1, coefs) 37 | x = self.split_merger.merge(x1, y2) # y2 should be unchanged 38 | return x, log_det -------------------------------------------------------------------------------- /invglow/invertible/distribution.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | from invglow.invertible.gaussian import get_gauss_samples, \ 4 | get_mixture_gaussian_log_probs 5 | import torch.nn.functional as F 6 | from torch import nn 7 | 8 | from invglow.invertible.gaussian import 
get_gaussian_log_probs 9 | 10 | 11 | class MergeLogDets(nn.Module): 12 | def __init__(self, module): 13 | super().__init__() 14 | self.module = module 15 | 16 | def forward(self, x, fixed): 17 | y, logdets = self.module(x, fixed=fixed) 18 | if fixed['y'] is None: 19 | n_components = logdets.shape[1] 20 | logdets = th.logsumexp(logdets, dim=1) - np.log(n_components) 21 | return y, logdets 22 | 23 | def invert(self, y, fixed): 24 | return self.module.invert(y, fixed=fixed) 25 | 26 | 27 | 28 | class PerClass(nn.Module): 29 | def __init__(self, dist): 30 | super().__init__() 31 | self.dist = dist 32 | 33 | def forward(self, x, fixed=None): 34 | logdet = self.dist.log_probs_per_class(x) 35 | if hasattr(fixed, '__getitem__') and 'y' in fixed and fixed['y'] is not None: 36 | y = fixed['y'] 37 | logdet = logdet.gather( 38 | dim=1, index=y.argmax(dim=1, keepdim=True)).squeeze(1) 39 | return x, logdet 40 | 41 | def invert(self, y, fixed=None): 42 | if hasattr(fixed, '__getitem__') and 'y' in fixed: 43 | assert fixed['y'] == None, "other not implemented" 44 | if y is None: 45 | assert 'n_samples' in fixed 46 | y = self.dist.get_unlabeled_samples(fixed['n_samples'], 47 | std_factor=1) 48 | logdet = self.dist.log_probs_per_class(y) 49 | return y, logdet 50 | 51 | 52 | class Unlabeled(nn.Module): 53 | def __init__(self, dist): 54 | super().__init__() 55 | self.dist = dist 56 | 57 | def forward(self, x, fixed=None): 58 | logdet = self.dist.log_prob_unlabeled(x) 59 | return x, logdet 60 | 61 | def invert(self, y, fixed=None): 62 | if y is None: 63 | assert 'n_samples' in fixed 64 | y = self.dist.get_unlabeled_samples(fixed['n_samples'], 65 | std_factor=1) 66 | logdet = self.dist.log_prob_unlabeled(y) 67 | return y, logdet 68 | 69 | 70 | class ZeroDist(nn.Module): 71 | def log_prob_unlabeled(self, x): 72 | return 0 73 | 74 | 75 | class NClassIndependentDist(nn.Module): 76 | def __init__(self, n_classes=None, n_dims=None, optimize_mean_std=True, truncate_to=None, 77 | means=None, log_stds=None): 78 | super().__init__() 79 | if means is not None: 80 | assert log_stds is not None 81 | self.class_means = means 82 | self.class_log_stds = log_stds 83 | else: 84 | if optimize_mean_std: 85 | self.class_means = nn.Parameter( 86 | th.zeros(n_classes, n_dims, requires_grad=True)) 87 | self.class_log_stds = nn.Parameter( 88 | th.zeros(n_classes, n_dims, requires_grad=True)) 89 | else: 90 | self.register_buffer('class_means', th.zeros(n_classes, n_dims,)) 91 | self.register_buffer('class_log_stds', th.zeros(n_classes, n_dims, )) 92 | 93 | self.truncate_to = truncate_to 94 | 95 | def forward(self, x, fixed=None): 96 | fixed = fixed or {} 97 | logdet = self.log_probs_per_class(x, sum_dims=fixed.get('sum_dims', True)) 98 | if 'y' in fixed and fixed['y'] is not None: 99 | y = fixed['y'] 100 | if y.ndim > 1: 101 | # assume one hot encoding 102 | y = y.argmax(dim=1, keepdim=True) 103 | else: 104 | y = y.unsqueeze(1) 105 | logdet = logdet.gather( 106 | dim=1, index=y).squeeze(1) 107 | return x, logdet 108 | 109 | def invert(self, y, fixed=None): 110 | if y is None: 111 | assert 'n_samples' in fixed 112 | if hasattr(fixed, '__getitem__') and 'y' in fixed: 113 | i_class = fixed['y'] 114 | assert isinstance(i_class, int) 115 | y = self.get_samples(i_class, fixed['n_samples'], std_factor=1) 116 | logdet = self.log_prob_class(i_class, y) 117 | else: 118 | y = self.get_unlabeled_samples(fixed['n_samples'], 119 | std_factor=1) 120 | logdet = self.log_probs_per_class(y) 121 | 122 | else: 123 | if hasattr(fixed, '__getitem__') and 'y' 
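# Illustration (numeric example with made-up per-class log-probabilities): with
# K equally weighted mixture components, the unlabeled log-likelihood is
#     log p(x) = logsumexp_k log p(x | k) - log K,
# which is what MergeLogDets above and log_prob_unlabeled below compute:
import numpy as np
import torch as th

log_probs_per_class = th.tensor([[-10.0, -12.0, -30.0]])   # log p(x|k) for K = 3
log_p = th.logsumexp(log_probs_per_class, dim=1) - np.log(3)
# logsumexp is dominated by the best-fitting component: log_p ~= -11.0 here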
in fixed: 124 | assert fixed['y'] is None, "not implemented" 125 | logdet = self.log_probs_per_class(y) 126 | return y, logdet 127 | 128 | def get_mean_std(self, i_class): 129 | cur_mean, cur_log_std = self.get_mean_log_std(i_class) 130 | return cur_mean, th.exp(cur_log_std) 131 | 132 | def get_mean_log_std(self, i_class): 133 | cur_mean = self.class_means[i_class] 134 | cur_log_std = self.class_log_stds[i_class] 135 | return cur_mean, cur_log_std 136 | 137 | def get_samples(self, i_class, n_samples, std_factor=1): 138 | cur_mean, cur_std = self.get_mean_std(i_class) 139 | samples = get_gauss_samples( 140 | n_samples, cur_mean, cur_std * std_factor, 141 | truncate_to=self.truncate_to 142 | ) 143 | return samples 144 | 145 | def get_unlabeled_samples(self, n_samples, std_factor=1): 146 | choices = np.random.choice(range(len(self.class_means)), 147 | size=n_samples,) 148 | bincounts = np.bincount(choices) 149 | all_samples = th.cat([self.get_samples( 150 | i_mixture, bincounts[i_mixture], std_factor=std_factor) 151 | for i_mixture in np.flatnonzero(bincounts)], dim=0) 152 | return all_samples 153 | 154 | def change_to_other_class(self, outs, i_class_from, i_class_to, eps=1e-6): 155 | mean_from, std_from = self.get_mean_std(i_class_from) 156 | mean_to, std_to = self.get_mean_std(i_class_to) 157 | normed = (outs - mean_from.unsqueeze(0)) / (std_from.unsqueeze(0) + eps) 158 | transformed = (normed * std_to.unsqueeze(0)) + mean_to.unsqueeze(0) 159 | return transformed 160 | 161 | def log_prob_class(self, i_class, outs, clamp_max_sigma=None): 162 | mean, log_std = self.get_mean_log_std(i_class) 163 | log_probs = get_gaussian_log_probs(mean, log_std, outs, 164 | clamp_max_sigma=clamp_max_sigma) 165 | return log_probs 166 | 167 | def log_probs_per_class(self, y, clamp_max_sigma=None, sum_dims=True): 168 | log_probs = get_mixture_gaussian_log_probs( 169 | self.class_means, self.class_log_stds, y, 170 | clamp_max_sigma=clamp_max_sigma, sum_dims=sum_dims) 171 | return log_probs 172 | 173 | def log_probs_per_weighted_class(self, y, clamp_max_sigma=None): 174 | n_classes = len(self.class_means) 175 | log_probs = get_mixture_gaussian_log_probs( 176 | self.class_means, self.class_log_stds, y, 177 | clamp_max_sigma=clamp_max_sigma) - np.log(n_classes) 178 | return log_probs 179 | 180 | def log_prob_unlabeled(self, outs, clamp_max_sigma=None): 181 | weighted_log_probs = self.log_probs_per_weighted_class( 182 | outs, clamp_max_sigma=clamp_max_sigma) 183 | return th.logsumexp(weighted_log_probs, dim=-1) 184 | 185 | def set_mean_std(self, i_class, mean, std): 186 | if mean is not None: 187 | self.class_means.data[i_class] = mean.data 188 | if std is not None: 189 | self.class_log_stds.data[i_class] = th.log(std).data 190 | 191 | def log_softmax(self, outs): 192 | log_probs = self.log_probs_per_weighted_class( 193 | outs, clamp_max_sigma=None) 194 | log_softmaxed = F.log_softmax(log_probs, dim=-1) 195 | return log_softmaxed 196 | -------------------------------------------------------------------------------- /invglow/invertible/expression.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | class Expression(nn.Module): 4 | def __init__(self, expression_fn): 5 | super().__init__() 6 | self.expression_fn = expression_fn 7 | 8 | def forward(self, x): 9 | return self.expression_fn(x) -------------------------------------------------------------------------------- /invglow/invertible/gaussian.py: 
-------------------------------------------------------------------------------- 1 | import torch as th 2 | import numpy as np 3 | 4 | # For truncated logic see: 5 | # https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/12 6 | # torch.fmod(torch.randn(size),2) 7 | def get_gauss_samples(n_samples, mean, std, truncate_to=None): 8 | if mean.is_cuda: 9 | orig_samples = th.cuda.FloatTensor(n_samples, len(mean)).normal_(0, 1) 10 | else: 11 | orig_samples = th.FloatTensor(n_samples, len(mean)).normal_(0, 1) 12 | if truncate_to is not None: 13 | orig_samples = th.fmod(orig_samples, truncate_to) 14 | orig_samples = th.autograd.Variable(orig_samples) 15 | samples = (orig_samples * std.unsqueeze(0)) + mean.unsqueeze(0) 16 | return samples 17 | 18 | 19 | def get_mixture_gaussian_log_probs(means, log_stds, outs, sum_dims:bool=True, 20 | clamp_max_sigma=None): 21 | """ 22 | Returns #examples x #mixture components 23 | """ 24 | demeaned = outs.unsqueeze(1) - means.unsqueeze(0) 25 | 26 | if clamp_max_sigma is not None: 27 | # unsqueeze over batch dim 28 | clamp_vals = (th.exp(log_stds.unsqueeze(0)) * clamp_max_sigma) 29 | # with straight through gradient estimation 30 | clamped = th.max(th.min(demeaned, clamp_vals), -clamp_vals).detach() + ( 31 | demeaned - demeaned.detach()) 32 | else: 33 | clamped = demeaned 34 | 35 | unnormed_log_probs = -(clamped ** 2) / (2 * (th.exp(log_stds.unsqueeze(0)) ** 2)) 36 | log_probs = unnormed_log_probs - np.log(np.sqrt(2 * np.pi)) - log_stds.unsqueeze(0) 37 | if sum_dims: 38 | log_probs = th.sum(log_probs, dim=2) 39 | return log_probs 40 | 41 | 42 | def get_gaussian_log_probs(mean, log_std, outs, sum_dims:bool=True, clamp_max_sigma=None): 43 | demeaned = outs - mean.unsqueeze(0) 44 | if clamp_max_sigma is not None: 45 | # unsqueeze over batch dim 46 | clamp_vals = (th.exp(log_std) * clamp_max_sigma).unsqueeze(0) 47 | # with straight through gradient estimation 48 | clamped = th.max(th.min(demeaned, clamp_vals), -clamp_vals).detach() + ( 49 | demeaned - demeaned.detach()) 50 | else: 51 | clamped = demeaned 52 | 53 | unnormed_log_probs = -(clamped ** 2) / (2 * (th.exp(log_std) ** 2)) 54 | log_probs = unnormed_log_probs - np.log(np.sqrt(2 * np.pi)) - log_std 55 | if sum_dims: 56 | log_probs = th.sum(log_probs, dim=1) 57 | return log_probs 58 | 59 | -------------------------------------------------------------------------------- /invglow/invertible/graph.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | from torch import nn 3 | import numpy as np 4 | import logging 5 | 6 | log = logging.getLogger(__name__) 7 | 8 | 9 | class AbstractNode(nn.Module): 10 | def __init__(self, prev, module, notify_prev_nodes=True, **tags): 11 | super().__init__() 12 | # Always make into List 13 | self.change_prev(prev, notify_prev_nodes=notify_prev_nodes) 14 | self.module = module 15 | self.cur_out = None 16 | self.cur_out_log_det = None 17 | self.cur_in = None 18 | self.cur_in_log_det = None 19 | self.tags = tags 20 | 21 | def change_prev(self, prev, notify_prev_nodes): 22 | if prev is not None: 23 | if not hasattr(prev, "__len__"): 24 | prev = [prev] 25 | prev = nn.ModuleList(prev) 26 | self.prev = prev 27 | self.next = [] 28 | if self.prev is not None and notify_prev_nodes: 29 | for p in self.prev: 30 | p.register_next(self) 31 | 32 | def register_next(self, next_module): 33 | self.next.append(next_module) 34 | 35 | def forward(self, x, fixed=None): 36 | assert self.cur_out is None, "Please remove cur out 
before forward" 37 | out, log_det = self._forward(x, fixed=fixed) 38 | self.remove_cur_in() 39 | self.remove_cur_out() 40 | return out, log_det 41 | 42 | def _forward(self, x, fixed=None): 43 | if self.cur_out is None: 44 | # Collect incoming results 45 | if self.prev is not None: 46 | xs, prev_log_dets = list(zip(*[ 47 | p._forward(x, fixed=fixed) 48 | for p in self.prev])) 49 | # if has condition node, than here make forward of x already 50 | if hasattr(self, 'condition_nodes'): 51 | for c in self.condition_nodes: 52 | _ = c._forward(x, fixed=fixed) # (just ignore output) 53 | else: 54 | xs = [x] 55 | prev_log_dets = [0] 56 | 57 | y, logdet = self._forward_myself(prev_log_dets, *xs, 58 | fixed=fixed) 59 | self.cur_out = y 60 | self.cur_out_log_det = logdet 61 | return self.cur_out, self.cur_out_log_det 62 | 63 | def invert(self, x, fixed=None): 64 | # determine starting module 65 | # ps are predecessors 66 | starting_m = self.find_starting_node() 67 | inverted = starting_m._invert(x, fixed=fixed) 68 | self.remove_cur_in() 69 | self.remove_cur_out() 70 | return inverted 71 | 72 | def find_starting_node(self): 73 | cur_ps = [self] 74 | starting_m = None 75 | while starting_m is None: 76 | new_cur_ps = [] 77 | for p in cur_ps: 78 | if p.prev is None: 79 | starting_m = p 80 | break 81 | else: 82 | new_cur_ps.extend(p.prev) 83 | cur_ps = new_cur_ps 84 | # log.debug("Starting Node" + str(starting_m)) 85 | return starting_m 86 | 87 | def _invert(self, y, fixed=None): 88 | if self.cur_in is None: 89 | # Collect incoming results 90 | 91 | if len(self.next) > 0: 92 | ys = [] 93 | log_dets = [] 94 | for n in self.next: 95 | this_y, this_log_det = n._invert(y, fixed=fixed) 96 | # Only take those ys belonging to you 97 | if len(this_y) > 1 and len(n.prev) > 1: 98 | assert len(this_y) == len(n.prev) 99 | filtered_y = [] 100 | for p, a_y in zip(n.prev, this_y): 101 | if p == self: 102 | filtered_y.append(a_y) 103 | this_y = filtered_y 104 | if len(this_y) == 1: 105 | this_y = this_y[0] 106 | ys.append(this_y) 107 | log_dets.append(this_log_det) 108 | 109 | # If has condition node, than here make invert of y already 110 | 111 | if hasattr(self, 'condition_nodes'): 112 | for c in self.condition_nodes: 113 | _ = c._invert(y, fixed=fixed) # (just ignore output) 114 | 115 | # Try to automatically correct ordering in case 116 | # next nodes are select nodes 117 | next_class_names = [n.__class__.__name__ for n in self.next] 118 | if all([n == 'SelectNode' for n in next_class_names]): 119 | indices = [n.index for n in self.next] 120 | assert np.array_equal(sorted(indices), range(len(ys))) 121 | ys = [ys[indices.index(i)] for i in range(len(ys))] 122 | 123 | if (len(ys) == 1) and (self.__class__.__name__ != 'SelectNode' or ( 124 | not ('no_squeeze' in self.tags and self.tags[ 125 | 'no_squeeze'] == True) 126 | )): 127 | ys = ys[0] 128 | 129 | next_log_dets = log_dets 130 | else: 131 | ys = y 132 | next_log_dets = [0] 133 | # log.debug("Now inverting " + str(self.__class__.__name__)) 134 | # if 'name' in self.tags: 135 | # log.debug("name: " + self.tags['name']) 136 | # if self.module is not None: 137 | # log.debug("module: " + str(self.module.__class__.__name__)) 138 | # log.debug("len(self.next) " + str(len(self.next))) 139 | # log.debug("len(ys) " + str(len(ys))) 140 | 141 | x, log_det = self._invert_myself(next_log_dets, ys, fixed=fixed) 142 | # Now we save cur out for conditional 143 | # WARNING: THIS ONLY WORKS IF THE CONDITIONAL NODE ITSELF 144 | # IS PART OF THE COMPUTATION GRAPH OF THE RESULT _WITHOUT_ 
145 | # THE CONDITIONAL NODE PART, SO IT MUST BE USED SOMEWHERE 146 | # ELSE 147 | # OTHERWISE THIS CODE NEEDS TO BE ADAPTED SMARTLY 148 | self.cur_out = ys # possibly necessary for conditional 149 | self.cur_in = x 150 | self.cur_in_log_det = log_det 151 | return self.cur_in, self.cur_in_log_det 152 | 153 | 154 | def remove_cur_out(self,): 155 | if self.prev is not None: 156 | for p in self.prev: 157 | p.remove_cur_out() 158 | if hasattr(self, 'condition_nodes'): 159 | for c in self.condition_nodes: 160 | c.remove_cur_out() 161 | # not sure if necessary 162 | c.remove_cur_in() 163 | self.cur_out = None 164 | self.cur_out_log_det = None 165 | 166 | def remove_cur_in(self,): 167 | if self.prev is not None: 168 | for p in self.prev: 169 | p.remove_cur_in() 170 | if hasattr(self, 'condition_nodes'): 171 | for c in self.condition_nodes: 172 | c.remove_cur_out() 173 | # not sure if necessary 174 | c.remove_cur_in() 175 | self.cur_in = None 176 | self.cur_in_log_det = None 177 | 178 | def remove_cur_in_out(self): 179 | self.remove_cur_in() 180 | self.remove_cur_out() 181 | 182 | 183 | class Node(AbstractNode): 184 | def _forward_myself(self, prev_log_dets, *xs, fixed=None): 185 | y, logdet = self.module(*xs, fixed=fixed) 186 | prev_sum = sum(prev_log_dets) 187 | if hasattr(logdet, 'shape') and hasattr(prev_sum, 'shape'): 188 | if len(logdet.shape) > 1 and len(prev_sum.shape) > 1: 189 | if logdet.shape[1] == 1 and prev_sum.shape[1] > 1: 190 | logdet = logdet.squeeze(1).unsqueeze(1) 191 | if logdet.shape[1] > 1 and prev_sum.shape[1] == 1: 192 | prev_sum = prev_sum.squeeze(1).unsqueeze(1) 193 | if len(prev_sum.shape) == 1 and len(logdet.shape) == 2: 194 | prev_sum = prev_sum.unsqueeze(1) 195 | if len(prev_sum.shape) == 2 and len(logdet.shape) == 1: 196 | logdet = logdet.unsqueeze(1) 197 | 198 | new_log_det = prev_sum + logdet 199 | return y, new_log_det 200 | 201 | def _invert_myself(self, next_log_dets, ys, fixed=None): 202 | # hacky fix 203 | for i_y in range(len(ys)): 204 | if isinstance(ys[i_y], tuple): 205 | ys[i_y] = ys[i_y][0] 206 | 207 | x, log_det = self.module.invert(ys, fixed=fixed) 208 | return x, sum(next_log_dets) + log_det 209 | 210 | 211 | class SelectNode(AbstractNode): 212 | def __init__(self, prev, index, **tags): 213 | super().__init__(prev, None, notify_prev_nodes=True, **tags) 214 | self.index = index 215 | 216 | def _forward_myself(self, prev_log_dets, *xs, fixed=None): 217 | # don't understand reason for next two lines 218 | assert len(xs) == 1 219 | xs = xs[0] 220 | n_parts = len(xs) 221 | assert n_parts > self.index 222 | return xs[self.index], sum(prev_log_dets) / n_parts 223 | 224 | def _invert_myself(self, next_log_dets, ys, fixed=None): 225 | return ys, sum(next_log_dets) 226 | 227 | 228 | class CatChansNode(AbstractNode): 229 | def __init__(self, prev, notify_prev_nodes=True, **tags): 230 | self.n_chans = None 231 | super(CatChansNode, self).__init__(prev, None, 232 | notify_prev_nodes=notify_prev_nodes, 233 | **tags) 234 | 235 | def _forward_myself(self, prev_log_dets, *xs, fixed=None): 236 | n_chans = tuple([a_x.size()[1] for a_x in xs]) 237 | if self.n_chans is None: 238 | self.n_chans = n_chans 239 | else: 240 | assert n_chans == self.n_chans 241 | return th.cat(xs, dim=1), sum(prev_log_dets) 242 | 243 | def _invert_myself(self, next_log_dets, ys, fixed=None): 244 | if self.n_chans is None: 245 | n_parts = len(self.prev) 246 | xs = th.chunk(ys, chunks=n_parts, dim=1, ) 247 | self.n_chans = tuple([a_x.size()[1] for a_x in xs]) 248 | else: 249 | xs = [] 250 | bounds 
= np.insert(np.cumsum(self.n_chans), 0, 0) 251 | for i_b in range(len(bounds) - 1): 252 | xs.append(ys[:, bounds[i_b]:bounds[i_b + 1]]) 253 | return xs, sum(next_log_dets) / len(xs) 254 | 255 | 256 | class ConditionalNode(AbstractNode): 257 | def __init__(self, prev, module, condition_nodes, **tags): 258 | super().__init__(prev, module, notify_prev_nodes=True, 259 | **tags) 260 | assert any([hasattr(m, 'accepts_condition') and m.accepts_condition 261 | for m in module.modules()]) 262 | if not hasattr(condition_nodes, '__len__'): 263 | condition_nodes = [condition_nodes] 264 | self.condition_nodes = condition_nodes 265 | 266 | def get_condition(self): 267 | for c in self.condition_nodes: 268 | assert c.cur_out is not None 269 | condition = [c.cur_out for c in self.condition_nodes] 270 | if len(condition) == 1: 271 | condition = condition[0] 272 | return condition 273 | 274 | def _forward_myself(self, prev_log_dets, *xs, fixed=None): 275 | condition = self.get_condition() 276 | y, logdet = self.module( 277 | *xs, condition=condition, fixed=fixed) 278 | return y, sum(prev_log_dets) + logdet 279 | 280 | def _invert_myself(self, next_log_dets, ys, fixed=None): 281 | condition = self.get_condition() 282 | x, log_det = self.module.invert( 283 | ys, condition=condition, fixed=fixed) 284 | return x, sum(next_log_dets) + log_det 285 | 286 | 287 | class IntermediateResultsNode(AbstractNode): 288 | def __init__(self, prev, **tags): 289 | self.n_chans = None 290 | super().__init__(prev, None, notify_prev_nodes=False, **tags) 291 | 292 | def _forward_myself(self, prev_log_dets, *xs, fixed=None): 293 | return xs, prev_log_dets 294 | 295 | def _invert_myself(self, next_log_dets, ys, fixed=None): 296 | return ys, next_log_dets 297 | 298 | 299 | class CatAsListNode(AbstractNode): 300 | def __init__(self, prev, notify_prev_nodes=True, **tags): 301 | super().__init__(prev, None, notify_prev_nodes=notify_prev_nodes, **tags) 302 | 303 | def _forward_myself(self, prev_log_dets, *xs, fixed=None): 304 | max_len_shape = max([len(l.shape) for l in prev_log_dets]) 305 | new_prev_log_dets = [] 306 | for p in prev_log_dets: 307 | if hasattr(p, 'shape') and len(p.shape) < max_len_shape: 308 | p = p.unsqueeze(1) 309 | new_prev_log_dets.append(p) 310 | return xs, sum(new_prev_log_dets) 311 | 312 | def _invert_myself(self, next_log_dets, ys, fixed=None): 313 | # log.debug("Inverting cat as list node NOW,,,,,") 314 | # log.debug("in cat as list len(ys)" + str(len(ys))) 315 | return ys, sum(next_log_dets) / len(ys) 316 | 317 | 318 | class MergeLogDetsNode(AbstractNode): 319 | def __init__(self, prev, notify_prev_nodes=True, **tags): 320 | super().__init__(prev, None, notify_prev_nodes=notify_prev_nodes, 321 | **tags) 322 | 323 | def _forward_myself(self, prev_log_dets, *xs, fixed=None): 324 | if fixed['y'] is None: 325 | if not hasattr(prev_log_dets, 'shape'): 326 | prev_log_dets = prev_log_dets[0] 327 | n_components = prev_log_dets.shape[1] 328 | logdets = th.logsumexp(prev_log_dets, dim=1) - np.log(n_components) 329 | 330 | 331 | return xs, logdets 332 | 333 | def _invert_myself(self, next_log_dets, ys, fixed=None): 334 | raise NotImplementedError("Check if you can just pass it through") 335 | return ys, next_log_dets 336 | 337 | 338 | def get_all_nodes(final_node): 339 | cur_ps = [final_node] 340 | all_nodes = [] 341 | while len(cur_ps) > 0: 342 | new_cur_ps = [] 343 | for p in cur_ps: 344 | if p.prev is not None: 345 | new_cur_ps.extend(p.prev) 346 | if p not in all_nodes: 347 | all_nodes.append(p) 348 | cur_ps = 
new_cur_ps 349 | return all_nodes[::-1] 350 | 351 | 352 | def get_nodes_by_tags(full_model, **tags): 353 | nodes = [] 354 | for n in get_all_nodes(full_model): 355 | put_inside = True 356 | for tag in tags: 357 | if (tag not in n.tags) or n.tags[tag] != tags[tag]: 358 | put_inside = False 359 | if put_inside: 360 | nodes.append(n) 361 | return nodes 362 | 363 | 364 | def get_nodes_by_names(full_model, *names): 365 | name_to_node = dict() 366 | for n in get_all_nodes(full_model): 367 | if 'name' in n.tags and n.tags['name'] in names: 368 | name_to_node[n.tags['name']] = n 369 | nodes = [name_to_node[name] for name in names] 370 | return nodes 371 | -------------------------------------------------------------------------------- /invglow/invertible/identity.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class Identity(nn.Module): 5 | def forward(self, x, fixed=None): 6 | return x, 0 7 | 8 | def invert(self, y, fixed=None): 9 | return y, 0 10 | -------------------------------------------------------------------------------- /invglow/invertible/init.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | 3 | def init_all_modules(net, trainloader, n_batches=10, use_y=False, 4 | verbose_init=True): 5 | with th.no_grad(): 6 | if trainloader is not None: 7 | if hasattr(trainloader, 'shape') or ( 8 | hasattr(trainloader, '__getitem__') and 9 | hasattr(trainloader[0], 'shape')): 10 | # then it is/ they are tensors! 11 | if use_y: 12 | init_x = trainloader[0] 13 | init_y = trainloader[1] 14 | else: 15 | init_x = trainloader 16 | assert not use_y 17 | init_y = None 18 | else: 19 | all_x = [] 20 | all_y = [] 21 | for i_batch, batch, in enumerate(trainloader): 22 | x,y = batch[:2] 23 | all_x.append(x) 24 | all_y.append(y) 25 | if i_batch >= n_batches: 26 | break 27 | 28 | init_x = th.cat(all_x, dim=0) 29 | init_x = init_x.cuda() 30 | if use_y: 31 | init_y = th.cat(all_y) 32 | else: 33 | init_y = None 34 | 35 | for m in net.modules(): 36 | if hasattr(m, 'initialize_this_forward'): 37 | m.initialize_this_forward = True 38 | m.verbose_init = verbose_init 39 | 40 | _ = net(init_x,fixed=dict(y=init_y)) 41 | else: 42 | for m in net.modules(): 43 | if hasattr(m, 'initialize_this_forward'): 44 | m.initialized = True 45 | 46 | 47 | def prepare_init(model, verbose_init=True): 48 | for m in model.modules(): 49 | if hasattr(m, 'initialize_this_forward'): 50 | m.initialize_this_forward = True 51 | m.verbose_init = verbose_init 52 | -------------------------------------------------------------------------------- /invglow/invertible/inv_permute.py: -------------------------------------------------------------------------------- 1 | # Partly from https://github.com/rosinality/glow-pytorch/ 2 | 3 | import numpy as np 4 | 5 | from scipy import linalg as la 6 | import torch.nn.functional as F 7 | import torch as th 8 | from torch import nn 9 | 10 | 11 | class InvPermute(nn.Module): 12 | def __init__(self, in_channel, fixed, use_lu): 13 | super().__init__() 14 | self.use_lu = use_lu 15 | self.fixed = fixed 16 | if not use_lu: 17 | weight = th.randn(in_channel, in_channel) 18 | q, _ = th.qr(weight) 19 | weight = q 20 | if fixed: 21 | self.register_buffer('weight', weight.data) 22 | self.register_buffer('weight_inverse', weight.data.inverse()) 23 | self.register_buffer('fixed_log_det', 24 | th.slogdet(self.weight.double())[1].float()) 25 | else: 26 | self.weight = nn.Parameter(weight) 27 | if 
use_lu: 28 | assert not fixed 29 | #weight = np.random.randn(in_channel, in_channel) 30 | weight = th.randn(in_channel, in_channel) 31 | #q, _ = la.qr(weight) 32 | q, _ = th.qr(weight) 33 | 34 | # w_p, w_l, w_u = la.lu(q.astype(np.float32)) 35 | w_p, w_l, w_u = th.lu_unpack(*th.lu(q)) 36 | 37 | #w_s = np.diag(w_u) 38 | w_s = th.diag(w_u) 39 | #w_u = np.triu(w_u, 1) 40 | w_u = th.triu(w_u, 1) 41 | #u_mask = np.triu(np.ones_like(w_u), 1) 42 | u_mask = th.triu(th.ones_like(w_u), 1) 43 | #l_mask = u_mask.T 44 | l_mask = u_mask.t() 45 | 46 | #w_p = th.from_numpy(w_p) 47 | #w_l = th.from_numpy(w_l) 48 | #w_s = th.from_numpy(w_s) 49 | #w_u = th.from_numpy(w_u) 50 | 51 | self.register_buffer('w_p', w_p) 52 | self.register_buffer('u_mask', u_mask) 53 | self.register_buffer('l_mask', l_mask) 54 | self.register_buffer('s_sign', th.sign(w_s)) 55 | self.register_buffer('l_eye', th.eye(l_mask.shape[0])) 56 | self.w_l = nn.Parameter(w_l) 57 | self.w_s = nn.Parameter(th.log(th.abs(w_s))) 58 | self.w_u = nn.Parameter(w_u) 59 | 60 | def reset_to_identity(self): 61 | def eye_like(w): 62 | return th.eye( 63 | len(w), device=w.device, 64 | dtype=w.dtype) 65 | if self.use_lu: 66 | self.w_p.data.copy_(eye_like(self.w_p)) 67 | self.s_sign.data.copy_(th.ones_like((self.s_sign))) 68 | self.w_l.data.copy_(eye_like(self.w_l)) 69 | self.w_s.data.copy_(th.ones_like((self.w_s))) 70 | self.w_u.data.zero_() 71 | 72 | else: 73 | self.weight.data.copy_(eye_like(self.weight)) 74 | if self.fixed: 75 | self.weight_inverse.data.copy_(eye_like(self.weight)) 76 | self.fixed_log_det.copy_(th.zeros_like(self.weight[0,0])) 77 | 78 | def forward(self, x, fixed=None): 79 | weight = self.calc_weight() 80 | if len(x.shape) == 2: 81 | y = F.linear(x, weight) 82 | else: 83 | assert len(x.shape) == 4 84 | y = F.conv2d(x, weight.unsqueeze(2).unsqueeze(3)) 85 | 86 | logdet = self.compute_log_det(x.shape) 87 | return y, logdet 88 | 89 | def calc_weight(self): 90 | if self.use_lu: 91 | weight = ( 92 | self.w_p 93 | @ (self.w_l * self.l_mask + self.l_eye) 94 | @ ((self.w_u * self.u_mask) + th.diag( 95 | self.s_sign * th.exp(self.w_s))) 96 | ) 97 | else: 98 | weight = self.weight 99 | return weight 100 | 101 | def invert(self, y, fixed=None): 102 | if self.fixed: 103 | weight_inverse = self.weight_inverse 104 | else: 105 | weight = self.calc_weight() 106 | weight_inverse = weight.inverse() 107 | if len(y.shape) == 2: 108 | x = F.linear(y, weight_inverse) 109 | else: 110 | assert len(y.shape) == 4 111 | x = F.conv2d(y, weight_inverse.unsqueeze(2).unsqueeze(3)) 112 | logdet = self.compute_log_det(x.shape) 113 | return x, logdet 114 | 115 | def compute_log_det(self, x_shape): 116 | logdet = self.compute_log_det_per_px() 117 | if len(x_shape) == 4: 118 | _, _, height, width = x_shape 119 | logdet = logdet * height * width 120 | else: 121 | assert len(x_shape) == 2 122 | return logdet 123 | 124 | def compute_log_det_per_px(self): 125 | if self.fixed: 126 | logdet = self.fixed_log_det 127 | else: 128 | if self.use_lu: 129 | logdet = th.sum(self.w_s) 130 | else: 131 | logdet = th.slogdet(self.weight.double())[1].float() 132 | return logdet 133 | 134 | class Shuffle(nn.Module): 135 | def __init__(self, in_channel): 136 | super().__init__() 137 | indices = th.randperm(in_channel) 138 | invert_inds = th.sort(indices)[1] 139 | self.register_buffer('indices', indices) 140 | self.register_buffer('invert_inds', invert_inds) 141 | 142 | def forward(self, x, fixed=None): 143 | assert x.shape[1] == len(self.indices) 144 | y = x[:, self.indices] 145 | return y,0 
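# Editor's note (a minimal round-trip sketch, not part of the original file,
# assuming torch and invglow are importable with a torch version compatible with
# this code): both channel permutations defined in this module are exactly
# invertible.  Shuffle applies a fixed channel permutation with zero
# log-determinant; InvPermute is the (optionally LU-parametrized) invertible
# 1x1 convolution, whose log-determinant for 4-D input is H*W times the sum of
# the log-scales.
import torch as th
from invglow.invertible.inv_permute import InvPermute, Shuffle

x = th.randn(2, 12, 8, 8)
shuffle = Shuffle(12)
y, logdet = shuffle(x)                           # logdet == 0 (pure reordering)
x_rec, _ = shuffle.invert(y)
assert th.equal(x, x_rec)

perm = InvPermute(12, fixed=False, use_lu=True)  # learnable LU-decomposed 1x1 conv
y, logdet = perm(x)                              # scalar tensor: H*W * sum(w_s)
x_rec, _ = perm.invert(y)
assert th.allclose(x, x_rec, atol=1e-4)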
146 | 147 | def invert(self, y, fixed=None): 148 | assert y.shape[1] == len(self.indices) 149 | x = y[:, self.invert_inds] 150 | return x,0 -------------------------------------------------------------------------------- /invglow/invertible/inverse.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class Inverse(nn.Module): 5 | def __init__(self, module): 6 | super(Inverse, self).__init__() 7 | self.module = module 8 | 9 | def forward(self, *args, **kwargs): 10 | return self.module.invert(*args, **kwargs) 11 | 12 | def invert(self, *args, **kwargs): 13 | return self.module.forward(*args, **kwargs) 14 | -------------------------------------------------------------------------------- /invglow/invertible/noise.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | from torch import nn 3 | import math 4 | 5 | 6 | class UniNoise(nn.Module): 7 | def __init__(self, noise_level=1 / 255.0, center=False): 8 | super(UniNoise, self).__init__() 9 | self.noise_level = noise_level 10 | self.center = center 11 | 12 | def forward(self, x, fixed=None): 13 | noise = th.rand_like(x) 14 | if self.center: 15 | noise = noise - 0.5 16 | noise = noise * self.noise_level 17 | 18 | return x + noise, 0 19 | 20 | def invert(self, y): 21 | # can't undo 22 | return y, 0 23 | 24 | 25 | class UniformBins(nn.Module): 26 | def __init__(self, n_bins): 27 | super().__init__() 28 | self.n_bins = n_bins 29 | 30 | def forward(self, x): 31 | x = x + th.zeros_like(x).uniform_(0, 1.0 / self.n_bins) 32 | log_det = self.compute_log_det(x) 33 | return x, log_det 34 | 35 | def invert(self, y): 36 | # can't undo 37 | log_det = self.compute_log_det(y) 38 | return y, log_det 39 | 40 | def compute_log_det(self, x): 41 | b, c, h, w = x.size() 42 | chw = c * h * w 43 | log_det = -math.log(self.n_bins) * chw * th.ones(b, device=x.device) 44 | return log_det 45 | 46 | 47 | class GaussianNoise(nn.Module): 48 | def __init__(self, noise_factor=None, means=None, stds=None,): 49 | super().__init__() 50 | assert (noise_factor is None) != ( 51 | ((means is None) or (stds is None)) 52 | ) 53 | assert (means is None) == (stds is None) 54 | if means is not None: 55 | self.register_buffer('means', means) 56 | self.register_buffer('stds', stds) 57 | self.noise_factor = None 58 | else: 59 | assert noise_factor is not None 60 | self.noise_factor = noise_factor 61 | 62 | def forward(self, x): 63 | if self.noise_factor is not None: 64 | return x + (th.randn_like(x) * self.noise_factor) 65 | else: 66 | return x + (th.randn_like(x) * self.stds.unsqueeze( 67 | 0)) + self.means.unsqueeze(0) 68 | 69 | 70 | class GaussianNoiseGates(nn.Module): 71 | def __init__(self, n_dims): 72 | super().__init__() 73 | self.gates = nn.Parameter(th.ones(n_dims).fill_(2)) 74 | 75 | def forward(self, x): 76 | alphas = th.sigmoid(self.gates) 77 | rands = th.randn_like(x) 78 | expanded_alphas = alphas.unsqueeze(0) 79 | while len(expanded_alphas.shape) < len(x.shape): 80 | expanded_alphas = expanded_alphas.unsqueeze(-1) 81 | y = expanded_alphas * x + (1 - expanded_alphas) * rands 82 | return y, 0 83 | 84 | def invert(self, y): 85 | return y, 0 # cannot undo noise 86 | -------------------------------------------------------------------------------- /invglow/invertible/pure_model.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class ModelThrowAwayLogDet(nn.Module): 5 | def __init__(self, 
model): 6 | super(ModelThrowAwayLogDet, self).__init__() 7 | self.model = model 8 | 9 | def forward(self, x): 10 | x, logdet = self.model(x) 11 | return x 12 | 13 | def invert(self, y): 14 | x, logdet = self.model.invert(y) 15 | return x 16 | 17 | #Alias 18 | NoLogDet = ModelThrowAwayLogDet -------------------------------------------------------------------------------- /invglow/invertible/sequential.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from inspect import signature 3 | 4 | 5 | def get_arg_names(func, excludes=('kwargs',)): 6 | """ 7 | from https://github.com/Lasagne/Lasagne/blob/a61b76fd991f84c50acdb7bea02118899b5fefe1/lasagne/utils.py 8 | 9 | Inspects a callable and returns a list of all optional keyword arguments. 10 | Parameters 11 | ---------- 12 | func : callable 13 | The callable to inspect 14 | Returns 15 | ------- 16 | args : list of str 17 | Names of all arguments of `func` 18 | """ 19 | params = signature(func).parameters 20 | return [p.name for p in params.values() if p.name not in excludes] 21 | 22 | 23 | class InvertibleSequential(nn.Module): 24 | def __init__(self, *modules): 25 | super().__init__() 26 | self.sequential = nn.Sequential(*modules) 27 | # just always true, in case any submodule, 28 | # including possibly later added ones accepts the condition 29 | self.accepts_condition = True 30 | 31 | def forward(self, x, condition=None, fixed=None): 32 | sum_logdet = 0 33 | for child in self.sequential.children(): 34 | #needed_kwargs = get_arg_names(child) 35 | #needed_kwargs = needed_kwargs[1:] 36 | #print("sequential py", child.__class__.__name__) 37 | 38 | if condition is not None and hasattr( 39 | child, 'accepts_condition') and child.accepts_condition: 40 | x, logdet = child(x, condition, fixed=fixed) 41 | else: 42 | x, logdet = child(x, fixed=fixed) 43 | 44 | if hasattr(logdet, 'shape') and hasattr(sum_logdet, 'shape'): 45 | if len(logdet.shape) > 1 and len(sum_logdet.shape) > 1: 46 | if logdet.shape[1] == 1 and sum_logdet.shape[1] > 1: 47 | logdet = logdet.squeeze(1).unsqueeze(1) 48 | if logdet.shape[1] > 1 and sum_logdet.shape[1] == 1: 49 | sum_logdet = sum_logdet.squeeze(1).unsqueeze(1) 50 | if len(sum_logdet.shape) == 1 and len(logdet.shape) == 2: 51 | sum_logdet = sum_logdet.unsqueeze(1) 52 | if len(sum_logdet.shape) == 2 and len(logdet.shape) == 1: 53 | logdet = logdet.unsqueeze(1) 54 | sum_logdet = logdet + sum_logdet 55 | return x, sum_logdet 56 | 57 | def invert(self, y, condition=None, fixed=None): 58 | sum_logdet = 0 59 | for child in reversed(list(self.sequential.children())): 60 | assert hasattr(child, 'invert'), ( 61 | "Class {:s} has no method invert".format( 62 | child.__class__.__name__)) 63 | if condition is not None and hasattr( 64 | child, 'accepts_condition') and child.accepts_condition: 65 | y, logdet = child.invert(y, condition, 66 | fixed=fixed) 67 | else: 68 | y, logdet = child.invert(y, 69 | fixed=fixed) 70 | if hasattr(logdet, "shape") and hasattr(sum_logdet, "shape"): 71 | if logdet.ndim < sum_logdet.ndim: 72 | logdet = logdet.unsqueeze(-1) 73 | if logdet.ndim > sum_logdet.ndim: 74 | sum_logdet = sum_logdet.unsqueeze(0) 75 | sum_logdet = logdet + sum_logdet 76 | return y, sum_logdet 77 | -------------------------------------------------------------------------------- /invglow/invertible/split_merge.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import numpy as np 3 | 4 | class ChunkChansIn2(object): 5 
| def __init__(self, swap_dims): 6 | self.swap_dims = swap_dims 7 | 8 | def split(self, x): 9 | n_chans = x.size()[1] 10 | assert n_chans % 2 == 0 11 | if self.swap_dims: 12 | x1 = x[:, n_chans // 2:] 13 | x2 = x[:, :n_chans // 2] 14 | else: 15 | x1 = x[:, :n_chans // 2] 16 | x2 = x[:, n_chans // 2:] 17 | return x1, x2 18 | 19 | def merge(self, y1, x2): 20 | if self.swap_dims: 21 | y = th.cat((x2, y1), dim=1) 22 | else: 23 | y = th.cat((y1, x2), dim=1) 24 | return y 25 | 26 | 27 | class ChansFraction(object): 28 | def __init__(self, swap_dims, n_unchanged=None, fraction_unchanged=None): 29 | assert (n_unchanged is None) != (fraction_unchanged is None), ( 30 | "Supply one of n_unchanged or fraction_unchanged") 31 | self.n_unchanged = n_unchanged 32 | self.fraction_unchanged = fraction_unchanged 33 | self.swap_dims = swap_dims 34 | 35 | def split(self, x): 36 | n_chans = x.size()[1] 37 | if self.n_unchanged is not None: 38 | n_unchanged = self.n_unchanged 39 | else: 40 | n_unchanged = int(np.round(n_chans * self.fraction_unchanged)) 41 | assert n_unchanged > 0 42 | assert n_unchanged < n_chans 43 | if self.swap_dims: 44 | x1 = x[:, n_unchanged:] 45 | x2 = x[:, :n_unchanged] 46 | 47 | else: 48 | x1 = x[:, :-n_unchanged] 49 | x2 = x[:, -n_unchanged:] 50 | return x1, x2 51 | 52 | def merge(self, y1, x2): 53 | if self.swap_dims: 54 | y = th.cat((x2, y1), dim=1) 55 | else: 56 | y = th.cat((y1, x2), dim=1) 57 | return y 58 | 59 | 60 | class EverySecondChan(object): 61 | def split(self, x): 62 | x1 = x[:,0::2] 63 | x2 = x[:,1::2] 64 | return x1, x2 65 | 66 | def merge(self, y1, x2): 67 | # see also https://discuss.pytorch.org/t/how-to-interleave-two-tensors-along-certain-dimension/11332/4 68 | full_shape = (y1.shape[0], y1.shape[1] + x2.shape[1]) + y1.shape[2:] 69 | y = th.stack((y1,x2), dim=2).view(full_shape) 70 | return y -------------------------------------------------------------------------------- /invglow/invertible/splitter.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import numpy as np 3 | 4 | 5 | # from https://github.com/y0ast/Glow-PyTorch 6 | def squeeze2d(input, factor): 7 | if factor == 1: 8 | return input 9 | 10 | B, C, H, W = input.size() 11 | 12 | assert H % factor == 0 and W % factor == 0, "H or W modulo factor is not 0" 13 | 14 | x = input.view(B, C, H // factor, factor, W // factor, factor) 15 | x = x.permute(0, 1, 3, 5, 2, 4).contiguous() 16 | x = x.view(B, C * factor * factor, H // factor, W // factor) 17 | 18 | return x 19 | 20 | # from https://github.com/y0ast/Glow-PyTorch 21 | def unsqueeze2d(input, factor): 22 | if factor == 1: 23 | return input 24 | 25 | factor2 = factor ** 2 26 | 27 | B, C, H, W = input.size() 28 | 29 | assert C % (factor2) == 0, "C module factor squared is not 0" 30 | 31 | x = input.view(B, C // factor2, factor, factor, H, W) 32 | x = x.permute(0, 1, 4, 2, 5, 3).contiguous() 33 | x = x.view(B, C // (factor2), H * factor, W * factor) 34 | 35 | return x 36 | 37 | 38 | class SubsampleSplitter(th.nn.Module): 39 | def __init__(self, stride, chunk_chans_first=True, checkerboard=False, 40 | cat_at_end=True, via_reshape=False): 41 | super(SubsampleSplitter, self).__init__() 42 | if not hasattr(stride, '__len__'): 43 | stride = (stride, stride) 44 | self.stride = stride 45 | self.chunk_chans_first = chunk_chans_first 46 | self.checkerboard = checkerboard 47 | self.cat_at_end = cat_at_end 48 | self.via_reshape = via_reshape 49 | if checkerboard: 50 | assert stride[0] == 2 51 | assert stride[1] == 
2 52 | if self.via_reshape: 53 | assert stride[0] == stride[1] 54 | 55 | def forward(self, x, fixed=None): 56 | # Chunk chans first to ensure that each of the two streams in the 57 | # reversible network will see a subsampled version of the whole input 58 | # (in case the preceding blocks would not alter the input) 59 | # and not one half of the input 60 | if self.via_reshape: 61 | y = squeeze2d(x, self.stride[0]) 62 | return y, 0 63 | else: 64 | new_x = [] 65 | if self.chunk_chans_first: 66 | xs = th.chunk(x, 2, dim=1) 67 | else: 68 | xs = [x] 69 | for one_x in xs: 70 | if not self.checkerboard: 71 | for i_stride in range(self.stride[0]): 72 | for j_stride in range(self.stride[1]): 73 | new_x.append( 74 | one_x[:, :, i_stride::self.stride[0], 75 | j_stride::self.stride[1]]) 76 | else: 77 | new_x.append(one_x[:,:,0::2,0::2]) 78 | new_x.append(one_x[:,:,1::2,1::2]) 79 | new_x.append(one_x[:,:,0::2,1::2]) 80 | new_x.append(one_x[:,:,1::2,0::2]) 81 | 82 | if self.cat_at_end: 83 | new_x = th.cat(new_x, dim=1) 84 | return new_x, 0 #logdet 85 | 86 | 87 | def invert(self, features, fixed=None): 88 | if self.via_reshape: 89 | x = unsqueeze2d(features, self.stride[0]) 90 | return x, 0 91 | else: 92 | # after splitting the input into two along channel dimension if possible 93 | # for i_stride in range(self.stride): 94 | # for j_stride in range(self.stride): 95 | # new_x.append(one_x[:,:,i_stride::self.stride, j_stride::self.stride]) 96 | if self.cat_at_end: 97 | n_all_chans_before = features.size()[1] // ( 98 | self.stride[0] * self.stride[1]) 99 | else: 100 | n_all_chans_before = sum([f.shape[1] for f in features]) // ( 101 | self.stride[0] * self.stride[1]) 102 | 103 | # if there was only one chan before, chunk had no effect 104 | if self.chunk_chans_first and (n_all_chans_before > 1): 105 | if self.cat_at_end: 106 | chan_features = th.chunk(features, 2, dim=1) 107 | else: 108 | chan_features = [features[: len(features) // 2], 109 | features[len(features) // 2:]] 110 | else: 111 | chan_features = [features] 112 | all_previous_features = [] 113 | for one_chan_features in chan_features: 114 | if self.cat_at_end: 115 | n_examples = one_chan_features.size()[0] 116 | n_chans = one_chan_features.size()[1] // ( 117 | self.stride[0] * self.stride[1]) 118 | n_0 = one_chan_features.size()[2] * self.stride[0] 119 | n_1 = one_chan_features.size()[3] * self.stride[1] 120 | else: 121 | n_examples = one_chan_features[0].size()[0] 122 | n_chans = sum([f.shape[1] for f in one_chan_features]) // ( 123 | self.stride[0] * self.stride[1]) 124 | n_0 = int( 125 | np.mean([f.size()[2] for f in one_chan_features]) * 126 | self.stride[0]) 127 | n_1 = int( 128 | np.mean([f.size()[3] for f in one_chan_features]) * 129 | self.stride[0]) 130 | 131 | previous_features = th.zeros( 132 | n_examples, 133 | n_chans, 134 | n_0, 135 | n_1, 136 | device=features[0].device) 137 | 138 | n_chans_before = previous_features.size()[1] 139 | cur_chan = 0 140 | if not self.checkerboard: 141 | for i_stride in range(self.stride[0]): 142 | for j_stride in range(self.stride[1]): 143 | if self.cat_at_end: 144 | previous_features[:, :, i_stride::self.stride[0], 145 | j_stride::self.stride[1]] = ( 146 | one_chan_features[:, 147 | cur_chan * n_chans_before: 148 | cur_chan * n_chans_before + n_chans_before]) 149 | else: 150 | previous_features[:, :, i_stride::self.stride[0], 151 | j_stride::self.stride[1]] = one_chan_features[cur_chan] 152 | cur_chan += 1 153 | else: 154 | # Manually go through 4 checkerboard positions 155 | assert self.stride[0] == 2 
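# Editor's note (a minimal round-trip sketch, not part of the original file,
# assuming torch and invglow are importable): the checkerboard branch handled in
# this method rearranges (B, C, H, W) into (B, 4*C, H/2, W/2) in forward with a
# zero log-determinant, and invert restores the input exactly.
import torch as th
from invglow.invertible.splitter import SubsampleSplitter

splitter = SubsampleSplitter(stride=2, chunk_chans_first=True, checkerboard=True)
x = th.randn(2, 8, 16, 16)
y, logdet = splitter(x)            # y: (2, 32, 8, 8), logdet == 0 (pure reordering)
x_rec, _ = splitter.invert(y)
assert th.equal(x, x_rec)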
156 | assert self.stride[1] == 2 157 | if self.cat_at_end: 158 | previous_features[:, :, 0::2, 0::2] = ( 159 | one_chan_features[:, 160 | 0 * n_chans_before:0 * n_chans_before + n_chans_before]) 161 | previous_features[:, :, 1::2, 1::2] = ( 162 | one_chan_features[:, 163 | 1 * n_chans_before:1 * n_chans_before + n_chans_before]) 164 | previous_features[:, :, 0::2, 1::2] = ( 165 | one_chan_features[:, 166 | 2 * n_chans_before:2 * n_chans_before + n_chans_before]) 167 | previous_features[:, :, 1::2, 0::2] = ( 168 | one_chan_features[:, 169 | 3 * n_chans_before:3 * n_chans_before + n_chans_before]) 170 | else: 171 | previous_features[:, :, 0::2, 0::2] = one_chan_features[0] 172 | previous_features[:, :, 1::2, 1::2] = one_chan_features[1] 173 | previous_features[:, :, 0::2, 1::2] = one_chan_features[2] 174 | previous_features[:, :, 1::2, 0::2] = one_chan_features[3] 175 | all_previous_features.append(previous_features) 176 | features = th.cat(all_previous_features, dim=1) 177 | return features, 0 178 | 179 | def __repr__(self): 180 | return ("SubsampleSplitter(stride={:s}, chunk_chans_first={:s}, " 181 | "checkerboard={:s})").format(str(self.stride), 182 | str(self.chunk_chans_first), 183 | str(self.checkerboard)) 184 | 185 | -------------------------------------------------------------------------------- /invglow/invertible/view_as.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import numpy as np 3 | 4 | 5 | class ViewAs(th.nn.Module): 6 | def __init__(self, dims_before, dims_after): 7 | super().__init__() 8 | self.dims_before = dims_before 9 | self.dims_after = dims_after 10 | 11 | def forward(self, x, fixed=None): 12 | for i_dim in range(len(x.size())): 13 | expected = self.dims_before[i_dim] 14 | if expected != -1: 15 | assert x.size()[i_dim] == expected, ( 16 | "Expected size {:s}, Actual: {:s}".format( 17 | str(self.dims_before), str(x.size())) 18 | ) 19 | return x.view(self.dims_after), 0 20 | 21 | def invert(self, features, fixed=None): 22 | for i_dim in range(len(features.size())): 23 | expected = self.dims_after[i_dim] 24 | if expected != -1: 25 | assert features.size()[i_dim] == expected, ( 26 | "Expected size {:s}, Actual: {:s}".format( 27 | str(self.dims_after), str(features.size())) 28 | ) 29 | features = features.view(self.dims_before) 30 | return features, 0 31 | 32 | def __repr__(self): 33 | return "ViewAs({:s}, {:s})".format( 34 | str(self.dims_before), str(self.dims_after)) 35 | 36 | 37 | class Flatten2d(th.nn.Module): 38 | def __init__(self, ): 39 | super().__init__() 40 | self.dims_before = None 41 | 42 | def forward(self, x, fixed=None): 43 | self.dims_before = x.size() 44 | y = x.view(x.size()[0], -1) 45 | return y, 0 46 | 47 | def invert(self, features, fixed=None): 48 | assert self.dims_before is not None, ( 49 | "Please call forward first") 50 | features = features.view(-1, 51 | *self.dims_before[1:]) 52 | return features, 0 53 | 54 | def __repr__(self): 55 | return "Flatten2d({:s}".format( 56 | str(self.dims_before)) 57 | 58 | 59 | class Flatten2dAndCat(th.nn.Module): 60 | def __init__(self, ): 61 | super().__init__() 62 | self.dims_before = None 63 | 64 | def forward(self, x, fixed=None): 65 | self.dims_before = [a_x.shape for a_x in x] 66 | y = th.cat([a_x.contiguous().view(a_x.size()[0], -1) for a_x in x], dim=1) 67 | return y, 0 68 | 69 | def invert(self, features, fixed=None): 70 | assert self.dims_before is not None, ( 71 | "Please call forward first") 72 | xs = [] 73 | i_start = 0 74 | for shape 
in self.dims_before: 75 | n_len = int(np.prod(shape[1:])) 76 | part_f = features[:, i_start:i_start+n_len] 77 | xs.append(part_f.view(-1,*shape[1:])) 78 | i_start += n_len 79 | return xs, 0 80 | 81 | def __repr__(self): 82 | return "Flatten2dAndCat({:s})".format( 83 | str(self.dims_before)) 84 | -------------------------------------------------------------------------------- /invglow/load_data.py: -------------------------------------------------------------------------------- 1 | from invglow.datasets import load_train_test, PreprocessedLoader 2 | import torch as th 3 | 4 | from invglow.invertible.noise import UniNoise 5 | from invglow.invertible.pure_model import NoLogDet 6 | from invglow.datasets import LSUN 7 | from torchvision import transforms 8 | from invglow import folder_locations 9 | 10 | def load_data( 11 | dataset, 12 | first_n, 13 | exclude_cifar_from_tiny, 14 | base_set_name, 15 | ood_set_name, 16 | noise_factor, 17 | augment, 18 | batch_size=64, 19 | eval_batch_size=512, 20 | shuffle_train=True, 21 | drop_last_train=True, 22 | tiny_grey=False, 23 | ): 24 | n_workers = 6 25 | 26 | base_first_n = first_n 27 | train_loader, test_loader = load_train_test( 28 | dataset, 29 | shuffle_train=shuffle_train, 30 | drop_last_train=drop_last_train, 31 | batch_size=batch_size, 32 | eval_batch_size=eval_batch_size, 33 | n_workers=n_workers, 34 | first_n=first_n, 35 | augment=augment, 36 | exclude_cifar_from_tiny=exclude_cifar_from_tiny, 37 | tiny_grey=tiny_grey) 38 | 39 | if base_set_name is not None: 40 | base_train_loader, base_test_loader = load_train_test( 41 | base_set_name, 42 | shuffle_train=shuffle_train, 43 | drop_last_train=drop_last_train, 44 | batch_size=batch_size, 45 | eval_batch_size=eval_batch_size, 46 | n_workers=n_workers, 47 | first_n=base_first_n, 48 | augment=augment, 49 | exclude_cifar_from_tiny=exclude_cifar_from_tiny, 50 | tiny_grey=tiny_grey) 51 | else: 52 | base_train_loader = None 53 | if ood_set_name is not None: 54 | ood_train_loader, ood_test_loader = load_train_test( 55 | ood_set_name, 56 | shuffle_train=shuffle_train, 57 | drop_last_train=drop_last_train, 58 | batch_size=batch_size, 59 | eval_batch_size=eval_batch_size, 60 | n_workers=n_workers, 61 | first_n=first_n, 62 | augment=augment, 63 | exclude_cifar_from_tiny=exclude_cifar_from_tiny, 64 | tiny_grey=tiny_grey) 65 | 66 | if tiny_grey == False: 67 | categories = ['bedroom', 'bridge', 'church_outdoor', 'classroom', 68 | 'conference_room', 'dining_room', 'kitchen', 69 | 'living_room', 'restaurant', 'tower'] 70 | 71 | def final_preproc_lsun(x): 72 | # make same as cifar tiny etc. 
73 | return (x * 255 / 256) - 0.5 74 | 75 | lsun_set = LSUN(folder_locations.lsun_data, 76 | classes=[c + '_val' for c in categories], 77 | transform=transforms.Compose([ 78 | transforms.Resize(32), 79 | transforms.CenterCrop(32), 80 | transforms.ToTensor(), 81 | final_preproc_lsun])) 82 | test_lsun = th.utils.data.DataLoader(lsun_set, batch_size=512, 83 | num_workers=0) 84 | def preproced(loader): 85 | return PreprocessedLoader(loader, NoLogDet(UniNoise(noise_factor)), 86 | to_cuda=True) 87 | loaders = dict( 88 | train=preproced(train_loader), 89 | test=preproced(test_loader), 90 | ) 91 | if ood_set_name is not None: 92 | loaders['ood_test'] = preproced(ood_test_loader) 93 | if tiny_grey == False: 94 | loaders['lsun'] = preproced(test_lsun) 95 | if dataset == 'cifar10': 96 | other_cifar_name = 'cifar100' 97 | else: 98 | other_cifar_name = 'cifar10' 99 | _, other_cifar_test_loader = load_train_test( 100 | other_cifar_name, 101 | shuffle_train=shuffle_train, 102 | drop_last_train=drop_last_train, 103 | batch_size=batch_size, 104 | eval_batch_size=eval_batch_size, 105 | n_workers=n_workers, 106 | first_n=first_n, 107 | augment=augment, 108 | exclude_cifar_from_tiny=exclude_cifar_from_tiny, 109 | tiny_grey=tiny_grey) 110 | loaders['ood_cifar'] = preproced(other_cifar_test_loader) 111 | 112 | 113 | if base_set_name is not None: 114 | base_train_loader = preproced(base_train_loader) 115 | 116 | return loaders, base_train_loader 117 | -------------------------------------------------------------------------------- /invglow/losses.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import torch.nn.functional as F 3 | 4 | def nll_class_loss(base_nll, fine_nll, target_val, 5 | temperature, weight, reduction): 6 | assert target_val in [0, 1] 7 | ldiff = -(fine_nll - base_nll.detach()) 8 | return nll_diff_loss(ldiff, target_val=target_val, 9 | temperature=temperature, 10 | weight=weight, 11 | reduction=reduction) 12 | 13 | 14 | def nll_diff_loss(lp_diff, target_val, temperature, weight, 15 | reduction): 16 | class_loss = F.binary_cross_entropy_with_logits( 17 | lp_diff / temperature, th.zeros_like(lp_diff) + target_val, 18 | reduction=reduction) 19 | class_loss = weight * class_loss 20 | return class_loss -------------------------------------------------------------------------------- /invglow/main.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch as th 4 | 5 | from invglow.exp import run_exp 6 | from invglow.evaluate import evaluate_without_noise 7 | 8 | log = logging.getLogger(__name__) 9 | 10 | th.backends.cudnn.benchmark = True 11 | default_args = dict( 12 | lr=5e-4, 13 | weight_decay=5e-5, 14 | first_n=None, 15 | exclude_cifar_from_tiny=False, 16 | output_dir='.', 17 | reinit=False, 18 | saved_optimizer_path=None, 19 | noise_factor=1/256.0, 20 | flow_coupling='affine', 21 | init_class_model = False, 22 | batch_size=64, 23 | augment=True, 24 | warmup_steps=None, 25 | n_epochs=250, 26 | np_th_seed=20200610 27 | ) 28 | 29 | # to run properly, set debug to False 30 | debug = True 31 | 32 | dataset = 'tiny' 33 | base_set_name = 'tiny' 34 | tiny_grey = False 35 | saved_base_model_path = None 36 | saved_model_path = None 37 | outlier_loss = None 38 | outlier_weight = None 39 | outlier_temperature = None 40 | ood_set_name = None 41 | add_full_label_loss = False 42 | on_top_class_model_name = None 43 | K = 32 44 | local_patches = False 45 | block_type = 'conv' 46 | 
flow_permutation = 'invconv' 47 | LU_decomposed=True 48 | 49 | 50 | ## Pretrain Tiny 51 | 52 | 53 | trainer, model = run_exp( 54 | dataset=dataset, 55 | debug=debug, 56 | saved_base_model_path=saved_base_model_path, 57 | saved_model_path=saved_model_path, 58 | base_set_name=base_set_name, 59 | outlier_weight=outlier_weight, 60 | outlier_loss=outlier_loss, 61 | ood_set_name=ood_set_name, 62 | outlier_temperature=outlier_temperature, 63 | K=K, 64 | on_top_class_model_name=on_top_class_model_name, 65 | add_full_label_loss=add_full_label_loss, 66 | tiny_grey=tiny_grey, 67 | local_patches=local_patches, 68 | block_type=block_type, 69 | flow_permutation=flow_permutation, 70 | LU_decomposed=LU_decomposed, 71 | **default_args) 72 | 73 | 74 | saved_base_model_path = './tiny_model.th' 75 | th.save(model, saved_base_model_path) 76 | del trainer, model 77 | 78 | ## Finetune 79 | 80 | dataset = 'cifar10' 81 | base_set_name = 'tiny' 82 | tiny_grey = False 83 | saved_base_model_path = './tiny_model.th' 84 | saved_model_path = './tiny_model.th' 85 | outlier_loss = None 86 | outlier_temperature = None 87 | outlier_weight = None 88 | ood_set_name = 'svhn' # just for eval 89 | add_full_label_loss = False 90 | on_top_class_model_name = None 91 | K = 32 92 | local_patches = False 93 | block_type = 'conv' 94 | flow_permutation = 'invconv' 95 | LU_decomposed = True 96 | 97 | trainer, model = run_exp( 98 | dataset=dataset, 99 | debug=debug, 100 | saved_base_model_path=saved_base_model_path, 101 | saved_model_path=saved_model_path, 102 | base_set_name=base_set_name, 103 | outlier_weight=outlier_weight, 104 | outlier_loss=outlier_loss, 105 | ood_set_name=ood_set_name, 106 | outlier_temperature=outlier_temperature, 107 | K=K, 108 | on_top_class_model_name=on_top_class_model_name, 109 | add_full_label_loss=add_full_label_loss, 110 | tiny_grey=tiny_grey, 111 | local_patches=local_patches, 112 | block_type=block_type, 113 | flow_permutation=flow_permutation, 114 | LU_decomposed=LU_decomposed, 115 | **default_args) 116 | # We create some helper function to show how the evaluation wihtout noise works, 117 | # please look inside to see how it works, 118 | # this function was not used directly for the manuscript 119 | # but should yield same results, unless new bugs were introduced :) 120 | 121 | base_model = th.load(saved_base_model_path) 122 | 123 | evaluate_without_noise(model, base_model, on_top_class=False, 124 | first_n=512, # set to none for proper eval 125 | noise_factor=1/256.0, 126 | in_dist_name='cifar10', 127 | rgb_or_grey='rgb', 128 | only_full_nll=False) 129 | del trainer, model 130 | 131 | 132 | ## Finetune with outlier loss 133 | 134 | dataset = 'cifar10' 135 | base_set_name = 'tiny' 136 | tiny_grey = False 137 | saved_base_model_path = './tiny_model.th' 138 | saved_model_path = './tiny_model.th' 139 | outlier_loss = 'class' 140 | outlier_weight = 6000 141 | outlier_temperature = 1000 142 | ood_set_name = 'svhn' # just for eval 143 | add_full_label_loss = False 144 | on_top_class_model_name = None 145 | K = 32 146 | local_patches = False 147 | block_type = 'conv' 148 | flow_permutation = 'invconv' 149 | LU_decomposed = True 150 | 151 | trainer, model = run_exp( 152 | dataset=dataset, 153 | debug=debug, 154 | saved_base_model_path=saved_base_model_path, 155 | saved_model_path=saved_model_path, 156 | base_set_name=base_set_name, 157 | outlier_weight=outlier_weight, 158 | outlier_loss=outlier_loss, 159 | ood_set_name=ood_set_name, 160 | outlier_temperature=outlier_temperature, 161 | K=K, 162 | 
on_top_class_model_name=on_top_class_model_name, 163 | add_full_label_loss=add_full_label_loss, 164 | tiny_grey=tiny_grey, 165 | local_patches=local_patches, 166 | block_type=block_type, 167 | flow_permutation=flow_permutation, 168 | LU_decomposed=LU_decomposed, 169 | **default_args) 170 | 171 | 172 | ## Finetune with outlier loss supervised 173 | 174 | del trainer, model 175 | 176 | dataset = 'cifar10' 177 | base_set_name = 'tiny' 178 | tiny_grey = False 179 | saved_base_model_path = './tiny_model.th' 180 | saved_model_path = './tiny_model.th' 181 | outlier_loss = 'class' 182 | outlier_weight = 6000 183 | outlier_temperature = 1000 184 | ood_set_name = 'svhn' # just for eval 185 | add_full_label_loss = True 186 | on_top_class_model_name = 'latent' 187 | K = 32 188 | local_patches = False 189 | block_type = 'conv' 190 | flow_permutation = 'invconv' 191 | LU_decomposed = True 192 | 193 | trainer, model = run_exp( 194 | dataset=dataset, 195 | debug=debug, 196 | saved_base_model_path=saved_base_model_path, 197 | saved_model_path=saved_model_path, 198 | base_set_name=base_set_name, 199 | outlier_weight=outlier_weight, 200 | outlier_loss=outlier_loss, 201 | ood_set_name=ood_set_name, 202 | outlier_temperature=outlier_temperature, 203 | K=K, 204 | on_top_class_model_name=on_top_class_model_name, 205 | add_full_label_loss=add_full_label_loss, 206 | tiny_grey=tiny_grey, 207 | local_patches=local_patches, 208 | block_type=block_type, 209 | flow_permutation=flow_permutation, 210 | LU_decomposed=LU_decomposed, 211 | **default_args) 212 | 213 | ## Local model 214 | 215 | del trainer, model 216 | 217 | dataset = 'cifar10' 218 | base_set_name = None 219 | tiny_grey = False 220 | saved_base_model_path = None 221 | saved_model_path = None 222 | outlier_loss = None 223 | outlier_weight = None 224 | outlier_temperature = None 225 | ood_set_name = None 226 | add_full_label_loss = False 227 | on_top_class_model_name = None 228 | K = 32 229 | local_patches = True 230 | block_type = 'conv' 231 | flow_permutation = 'invconv' 232 | LU_decomposed = True 233 | 234 | trainer, model = run_exp( 235 | dataset=dataset, 236 | debug=debug, 237 | saved_base_model_path=saved_base_model_path, 238 | saved_model_path=saved_model_path, 239 | base_set_name=base_set_name, 240 | outlier_weight=outlier_weight, 241 | outlier_loss=outlier_loss, 242 | ood_set_name=ood_set_name, 243 | outlier_temperature=outlier_temperature, 244 | K=K, 245 | on_top_class_model_name=on_top_class_model_name, 246 | add_full_label_loss=add_full_label_loss, 247 | tiny_grey=tiny_grey, 248 | local_patches=local_patches, 249 | block_type=block_type, 250 | flow_permutation=flow_permutation, 251 | LU_decomposed=LU_decomposed, 252 | **default_args) 253 | 254 | ## Fully Connected Model 255 | 256 | del trainer, model 257 | 258 | dataset = 'cifar10' 259 | base_set_name = None 260 | tiny_grey = False 261 | saved_base_model_path = None 262 | saved_model_path = None 263 | outlier_loss = None 264 | outlier_weight = None 265 | outlier_temperature = None 266 | ood_set_name = None 267 | add_full_label_loss = False 268 | on_top_class_model_name = None 269 | K = 8 270 | local_patches = False 271 | block_type = 'dense' 272 | flow_permutation = 'invconvfixed' 273 | LU_decomposed = False 274 | 275 | trainer, model = run_exp( 276 | dataset=dataset, 277 | debug=debug, 278 | saved_base_model_path=saved_base_model_path, 279 | saved_model_path=saved_model_path, 280 | base_set_name=base_set_name, 281 | outlier_weight=outlier_weight, 282 | outlier_loss=outlier_loss, 283 | 
ood_set_name=ood_set_name, 284 | outlier_temperature=outlier_temperature, 285 | K=K, 286 | on_top_class_model_name=on_top_class_model_name, 287 | add_full_label_loss=add_full_label_loss, 288 | tiny_grey=tiny_grey, 289 | local_patches=local_patches, 290 | block_type=block_type, 291 | flow_permutation=flow_permutation, 292 | LU_decomposed=LU_decomposed, 293 | **default_args) -------------------------------------------------------------------------------- /invglow/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boschresearch/hierarchical_anomaly_detection/ca2f1d84615c2ef140a74f4e1515352abff9e938/invglow/models/__init__.py -------------------------------------------------------------------------------- /invglow/models/class_conditional.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from invglow.invertible.actnorm import ActNorm 3 | from invglow.invertible.branching import ChunkByIndices 4 | from invglow.invertible.distribution import Unlabeled, NClassIndependentDist 5 | from invglow.invertible.graph import CatAsListNode 6 | from invglow.invertible.graph import Node, SelectNode 7 | from invglow.invertible.sequential import InvertibleSequential 8 | from invglow.invertible.view_as import Flatten2d 9 | 10 | log = logging.getLogger('__name__') 11 | 12 | def latent_model(n_chans): 13 | n_dims_per_scale = [n_chans*2*16*16, n_chans*4*8*8, n_chans*16*4*4] 14 | rechunker = ChunkByIndices((n_dims_per_scale[0], sum(n_dims_per_scale[:2]))) 15 | nd_in_split_again = Node(None, rechunker) 16 | dist_nodes = [] 17 | for i_scale in range(3): 18 | n_dims_scale = n_dims_per_scale[i_scale] 19 | nd_in_class = SelectNode(nd_in_split_again, i_scale) 20 | act_class = InvertibleSequential( 21 | Flatten2d(), 22 | ActNorm(n_dims_scale, scale_fn='exp')) 23 | dist_class = Unlabeled( 24 | NClassIndependentDist(1, n_dims_scale)) 25 | nd_act_class = Node(nd_in_class,act_class) 26 | nd_dist_class = Node(nd_act_class, dist_class) 27 | dist_nodes.append(nd_dist_class) 28 | top_model = CatAsListNode(dist_nodes) 29 | return top_model 30 | 31 | 32 | 33 | def convert_class_model_for_multi_scale_nll(model): 34 | log.warning("Please be aware that these results are not mathematically correct") 35 | pre_class_nodes_per_scale = model.sequential[0].prev 36 | for p in pre_class_nodes_per_scale: 37 | p.next = [] 38 | 39 | class_model = model.sequential[1].module 40 | per_scale_act_dist_mods = [[], [], []] 41 | for single_model in class_model.module_list: 42 | per_scale_nodes = single_model.prev[0].prev 43 | for i_scale, per_scale_node in enumerate(per_scale_nodes): 44 | act_mod = per_scale_node.prev[0].module 45 | dist_mod = per_scale_node.module 46 | per_scale_act_dist_mods[i_scale].append( 47 | InvertibleSequential(act_mod, dist_mod)) 48 | 49 | act_dist_nodes_per_scale = [ 50 | MergeLogDetsNode(Node( 51 | pre_class_nodes_per_scale[i_scale], 52 | InvertibleClassConditional( 53 | per_scale_act_dist_mods[i_scale], 54 | i_classes=list(range(len(per_scale_act_dist_mods[i_scale])))), 55 | name=f'm0-dist-{i_scale}')) 56 | for i_scale in range(3)] 57 | model = CatAsListNode(act_dist_nodes_per_scale) 58 | return model -------------------------------------------------------------------------------- /invglow/models/glow.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch as th 3 | 4 | from invglow.invertible.actnorm import ActNorm 
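# Editor's note (a standalone usage sketch, not part of this module): how the
# create_glow_model factory defined further below in this file is typically
# called.  K, flow_permutation, flow_coupling and LU_decomposed mirror the
# CIFAR-10 settings in invglow/main.py; hidden_channels=64 is an arbitrary value
# chosen only to keep the sketch light and is NOT a value taken from this
# repository.
import torch as th
from invglow.invertible.init import init_all_modules
from invglow.models.glow import create_glow_model

model = create_glow_model(
    hidden_channels=64, K=32, L=3,
    flow_permutation='invconv', flow_coupling='affine',
    LU_decomposed=True, n_chans=3, block_type='conv')

x = th.rand(2, 3, 32, 32) - 0.5               # inputs roughly in [-0.5, 0.5), as produced by load_data.py
init_all_modules(model, x, use_y=False)       # data-dependent ActNorm init (see invglow/invertible/init.py)
outs, logdets = model(x, fixed=dict(y=None))  # per-scale flattened latents and per-example summed log terms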
5 | from invglow.invertible.affine import AffineCoefs, AffineModifier, AdditiveCoefs 6 | from invglow.invertible.branching import ChunkChans, ChunkByIndices 7 | from invglow.invertible.coupling import CouplingLayer 8 | from invglow.invertible.distribution import Unlabeled, NClassIndependentDist 9 | from invglow.invertible.graph import CatChansNode 10 | from invglow.invertible.graph import Node, SelectNode, CatAsListNode 11 | from invglow.invertible.graph import get_nodes_by_names 12 | from invglow.invertible.identity import Identity 13 | from invglow.invertible.inv_permute import InvPermute, Shuffle 14 | from invglow.invertible.sequential import InvertibleSequential 15 | from invglow.invertible.split_merge import ChunkChansIn2, EverySecondChan 16 | from invglow.invertible.splitter import SubsampleSplitter 17 | from invglow.invertible.view_as import Flatten2d, ViewAs 18 | 19 | 20 | def convert_glow_to_pre_dist_model(model): 21 | model_log_act_nodes = get_nodes_by_names( 22 | model, 'm0-act-0', 'm0-act-1', 'm0-act-2') 23 | for a in model_log_act_nodes: 24 | a.next = [] 25 | model_log_det_node = CatChansNode( 26 | model_log_act_nodes, 27 | notify_prev_nodes=True) 28 | return model_log_det_node 29 | 30 | 31 | def split_glow_into_pre_dist_and_dist(model): 32 | # remove references to previous dist node 33 | model_log_act_nodes = get_nodes_by_names( 34 | model, 'm0-act-0', 'm0-act-1', 'm0-act-2') 35 | for a in model_log_act_nodes: 36 | a.next = [] 37 | model_log_det_node = CatChansNode( 38 | model_log_act_nodes, 39 | notify_prev_nodes=True) 40 | 41 | model_dist_nodes = get_nodes_by_names(model, 'm0-dist-0', 42 | 'm0-dist-1', 43 | 'm0-dist-2') 44 | rechunker = ChunkByIndices((6 * 16 * 16, 6 * 16 * 16 + 48 * 4 * 4)) 45 | nd_in_split_for_dist = Node(None, rechunker) 46 | dist_node = CatChansNode( 47 | [Node(SelectNode(nd_in_split_for_dist, i), 48 | model_dist_nodes[i].module) 49 | for i in range(len(model_dist_nodes))]) 50 | return model_log_det_node, dist_node 51 | 52 | 53 | def create_glow_model( 54 | hidden_channels, 55 | K, 56 | L, 57 | flow_permutation, 58 | flow_coupling, 59 | LU_decomposed, 60 | n_chans, 61 | block_type='conv', 62 | use_act_norm=True 63 | 64 | ): 65 | image_shape = (32, 32, n_chans) 66 | 67 | H, W, C = image_shape 68 | flows_per_scale = [] 69 | act_norms_per_scale = [] 70 | dists_per_scale = [] 71 | for i in range(L): 72 | 73 | C, H, W = C * 4, H // 2, W // 2 74 | 75 | splitter = SubsampleSplitter( 76 | 2, via_reshape=True, chunk_chans_first=True, checkerboard=False, 77 | cat_at_end=True) 78 | 79 | if block_type == 'dense': 80 | pre_flow_layers = [Flatten2d()] 81 | in_channels = C * H * W 82 | else: 83 | assert block_type == 'conv' 84 | pre_flow_layers = [] 85 | in_channels = C 86 | 87 | flow_layers = [flow_block(in_channels=in_channels, 88 | hidden_channels=hidden_channels, 89 | flow_permutation=flow_permutation, 90 | flow_coupling=flow_coupling, 91 | LU_decomposed=LU_decomposed, 92 | cond_channels=0, 93 | cond_merger=None, 94 | block_type=block_type, 95 | use_act_norm=use_act_norm) for _ in range(K)] 96 | 97 | if block_type == 'dense': 98 | post_flow_layers = [ViewAs((-1, C * H * W), (-1, C, H, W))] 99 | else: 100 | assert block_type == 'conv' 101 | post_flow_layers = [] 102 | flow_layers = pre_flow_layers + flow_layers + post_flow_layers 103 | flow_this_scale = InvertibleSequential(splitter, *flow_layers) 104 | flows_per_scale.append(flow_this_scale) 105 | 106 | if i < L - 1: 107 | # there will be a chunking here 108 | C = C // 2 109 | # act norms for distribution 
(mean/std as actnorm instead of integrated
110 |             # into dist)
111 |             act_norms_per_scale.append(InvertibleSequential(Flatten2d(),
112 |                                                              ActNorm((C * H * W),
113 |                                                                      scale_fn='exp')))
114 |             dists_per_scale.append(Unlabeled(
115 |                 NClassIndependentDist(1, C * H * W, optimize_mean_std=False)))
116 | 
117 |     assert len(flows_per_scale) == 3
118 | 
119 |     nd_1_o = Node(None, flows_per_scale[0], name='m0-flow-0')
120 |     nd_1_ab = Node(nd_1_o, ChunkChans(2))
121 |     nd_1_a = SelectNode(nd_1_ab, 0)
122 |     nd_1_an = Node(nd_1_a, act_norms_per_scale[0], name='m0-act-0')
123 |     nd_1_ad = Node(nd_1_an, dists_per_scale[0], name='m0-dist-0')
124 | 
125 |     nd_1_b = SelectNode(nd_1_ab, 1, name='m0-in-flow-1')
126 |     nd_2_o = Node(nd_1_b, flows_per_scale[1], name='m0-flow-1')
127 |     nd_2_ab = Node(nd_2_o, ChunkChans(2))
128 | 
129 |     nd_2_a = SelectNode(nd_2_ab, 0)
130 |     nd_2_an = Node(nd_2_a, act_norms_per_scale[1], name='m0-act-1')
131 |     nd_2_ad = Node(nd_2_an, dists_per_scale[1], name='m0-dist-1')
132 | 
133 |     nd_2_b = SelectNode(nd_2_ab, 1, name='m0-in-flow-2')
134 |     nd_3_o = Node(nd_2_b, flows_per_scale[2], name='m0-flow-2')
135 |     nd_3_n = Node(nd_3_o, act_norms_per_scale[2], name='m0-act-2')
136 |     nd_3_d = Node(nd_3_n, dists_per_scale[2], name='m0-dist-2')
137 | 
138 |     model = CatAsListNode([nd_1_ad, nd_2_ad, nd_3_d],
139 |                           name='m0-full') # changed to pre-full
140 |     return model
141 | 
142 | 
143 | 
144 | 
145 | class Conv2dZeros(nn.Module):
146 |     def __init__(self, in_channels, out_channels,
147 |                  kernel_size=(3, 3), stride=(1, 1),
148 |                  padding="same", logscale_factor=3):
149 |         super().__init__()
150 | 
151 |         if padding == "same":
152 |             padding = compute_same_pad(kernel_size, stride)
153 |         elif padding == "valid":
154 |             padding = 0
155 | 
156 |         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride,
157 |                               padding)
158 | 
159 |         self.conv.weight.data.zero_()
160 |         self.conv.bias.data.zero_()
161 | 
162 |         self.logscale_factor = logscale_factor
163 |         self.logs = nn.Parameter(th.zeros(out_channels, 1, 1))
164 | 
165 |     def forward(self, input):
166 |         output = self.conv(input)
167 |         return output * th.exp(self.logs * self.logscale_factor)
168 | 
169 | 
170 | 
171 | class Conv2d(nn.Module):
172 |     def __init__(self, in_channels, out_channels,
173 |                  kernel_size=(3, 3), stride=(1, 1),
174 |                  padding="same", do_actnorm=True, weight_std=0.05):
175 |         super().__init__()
176 | 
177 |         if padding == "same":
178 |             padding = compute_same_pad(kernel_size, stride)
179 |         elif padding == "valid":
180 |             padding = 0
181 | 
182 |         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride,
183 |                               padding, bias=(not do_actnorm))
184 | 
185 |         # init weight with std
186 |         self.conv.weight.data.normal_(mean=0.0, std=weight_std)
187 | 
188 |         if not do_actnorm:
189 |             self.conv.bias.data.zero_()
190 |         else:
191 |             self.actnorm = ActNorm(out_channels, scale_fn='exp', eps=0)
192 | 
193 |         self.do_actnorm = do_actnorm
194 | 
195 |     def forward(self, input):
196 |         x = self.conv(input)
197 |         if self.do_actnorm:
198 |             x, _ = self.actnorm(x)
199 |         return x
200 | 
201 | 
202 | def get_conv_block(in_channels, out_channels, hidden_channels, nonlin_name):
203 |     assert nonlin_name in ['elu', 'relu']
204 |     nonlin = {'elu': nn.ELU(inplace=False),
205 |               'relu': nn.ReLU(inplace=False)}[nonlin_name]
206 |     block = nn.Sequential(Conv2d(in_channels, hidden_channels),
207 |                           nonlin,
208 |                           Conv2d(hidden_channels, hidden_channels,
209 |                                  kernel_size=(1, 1)),
210 |                           nonlin,
211 |                           Conv2dZeros(hidden_channels, out_channels))
212 |     return block
213 | 
214 | 
215 | def 
get_dense_block(in_channels, out_channels, hidden_channels, nonlin_name): 216 | nonlin = {'elu': nn.ELU(inplace=False), 217 | 'relu': nn.ReLU(inplace=False)}[nonlin_name] 218 | block = nn.Sequential(nn.Linear(in_channels, hidden_channels), 219 | nonlin, 220 | nn.Linear(hidden_channels, hidden_channels), 221 | nonlin, 222 | nn.Linear(hidden_channels, out_channels)) 223 | return block 224 | 225 | 226 | def flow_block(in_channels, hidden_channels, 227 | flow_permutation, flow_coupling, LU_decomposed, 228 | cond_channels, cond_merger, block_type, use_act_norm, 229 | nonlin_name='relu'): 230 | if use_act_norm: 231 | actnorm = ActNorm(in_channels, scale_fn='exp', eps=0) 232 | # 2. permute 233 | if flow_permutation == "invconv": 234 | flow_permutation = InvPermute( 235 | in_channels, fixed=False, use_lu=LU_decomposed) 236 | elif flow_permutation == 'invconvfixed': 237 | flow_permutation = InvPermute(in_channels, 238 | fixed=True, 239 | use_lu=LU_decomposed) 240 | elif flow_permutation == "identity": 241 | flow_permutation = Identity() 242 | else: 243 | assert flow_permutation == 'shuffle' 244 | flow_permutation = Shuffle(in_channels) 245 | 246 | if flow_coupling == "additive": 247 | out_channels = in_channels // 2 248 | else: 249 | out_channels = in_channels 250 | 251 | if type(block_type) is str: 252 | if block_type == 'conv': 253 | block_fn = get_conv_block 254 | else: 255 | assert block_type == 'dense' 256 | block_fn = get_dense_block 257 | else: 258 | block_fn = block_type 259 | 260 | block = block_fn(in_channels // 2 + cond_channels, 261 | out_channels, 262 | hidden_channels, 263 | nonlin_name=nonlin_name) 264 | 265 | if flow_coupling == "additive": 266 | coupling = CouplingLayer( 267 | ChunkChansIn2(swap_dims=True), 268 | AdditiveCoefs(block,), 269 | AffineModifier(sigmoid_or_exp_scale='sigmoid', 270 | eps=0, add_first=True, ), 271 | condition_merger=cond_merger 272 | ) 273 | elif flow_coupling == "affine": 274 | coupling = CouplingLayer( 275 | ChunkChansIn2(swap_dims=True), 276 | AffineCoefs(block, EverySecondChan()), 277 | AffineModifier(sigmoid_or_exp_scale='sigmoid', eps=0, add_first=True), 278 | condition_merger=cond_merger, 279 | ) 280 | else: 281 | assert False, f"unknown flow_coupling {flow_coupling}" 282 | if use_act_norm: 283 | sequential = InvertibleSequential(actnorm, flow_permutation, coupling) 284 | else: 285 | sequential = InvertibleSequential(flow_permutation, coupling) 286 | return sequential 287 | 288 | 289 | 290 | def compute_same_pad(kernel_size, stride): 291 | if isinstance(kernel_size, int): 292 | kernel_size = [kernel_size] 293 | 294 | if isinstance(stride, int): 295 | stride = [stride] 296 | 297 | assert len(stride) == len(kernel_size),\ 298 | "Pass kernel size and stride both as int, or both as equal length iterable" 299 | 300 | return [((k - 1) * s + 1) // 2 for k, s in zip(kernel_size, stride)] -------------------------------------------------------------------------------- /invglow/models/patch_glow.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | from torch import nn 3 | from invglow.invertible.actnorm import ActNorm 4 | from invglow.invertible.distribution import Unlabeled, NClassIndependentDist 5 | from invglow.invertible.sequential import InvertibleSequential 6 | from invglow.invertible.splitter import SubsampleSplitter 7 | from invglow.invertible.view_as import Flatten2d 8 | from invglow.models.glow import flow_block 9 | 10 | 11 | def unfold_patches_batch(x, size): 12 | unfolded = 
th.nn.functional.unfold( 13 | x, 14 | (size, size), stride=(size, size)) 15 | patches = unfolded.reshape(x.shape[0],x.shape[1],size,size,-1) 16 | patches_batch = patches.permute(0,4,1,2,3).reshape(-1, *patches.shape[1:-1]) 17 | return patches_batch 18 | 19 | def fold_to_images_batch(x, n_orig_x, image_size, patch_size, ): 20 | patches = x.reshape( 21 | n_orig_x,-1,x.shape[1], patch_size, patch_size).permute(0,2,3,4,1) 22 | unfolded = patches.reshape(patches.shape[0],-1,patches.shape[-1]) 23 | folded = th.nn.functional.fold( 24 | unfolded, (image_size,image_size), 25 | (patch_size,patch_size), stride=(patch_size,patch_size)) 26 | return folded 27 | 28 | 29 | class WrapForPatches(nn.Module): 30 | def __init__(self, model, patch_size): 31 | super().__init__() 32 | self.model = model 33 | self.patch_size = patch_size 34 | 35 | def forward(self, x, fixed=None): 36 | patches = unfold_patches_batch(x, self.patch_size) 37 | out, lp = self.model(patches) 38 | image_lp = th.sum(lp.reshape(x.shape[0], -1, *lp.shape[1:]), dim=1) 39 | return out, image_lp 40 | 41 | def invert(self, z, fixed=None): 42 | raise ValueError("not implemented") 43 | 44 | def create_patch_glow_model( 45 | hidden_channels, 46 | K, 47 | flow_permutation, 48 | flow_coupling, 49 | LU_decomposed, 50 | n_chans, 51 | use_act_norm=True): 52 | C = n_chans * 4 53 | H = 4 54 | W = 4 55 | 56 | splitter = SubsampleSplitter( 57 | 2, via_reshape=True, chunk_chans_first=True, checkerboard=False, 58 | cat_at_end=True) 59 | 60 | flow_layers = [flow_block(in_channels=C, 61 | hidden_channels=hidden_channels, 62 | flow_permutation=flow_permutation, 63 | flow_coupling=flow_coupling, 64 | LU_decomposed=LU_decomposed, 65 | cond_channels=0, 66 | cond_merger=None, 67 | block_type="conv", 68 | use_act_norm=use_act_norm) for _ in range(K)] 69 | flow_this_scale = InvertibleSequential(splitter, *flow_layers) 70 | flow_this_scale.cuda(); 71 | act_norm = InvertibleSequential( 72 | Flatten2d(), 73 | ActNorm((C * H * W), 74 | scale_fn='exp')) 75 | dist = Unlabeled( 76 | NClassIndependentDist(1, C * H * W, optimize_mean_std=False)) 77 | model = InvertibleSequential(flow_this_scale, act_norm, dist) 78 | model = WrapForPatches(model, 8) 79 | return model -------------------------------------------------------------------------------- /invglow/scheduler.py: -------------------------------------------------------------------------------- 1 | class ScheduledOptimizer(object): 2 | def __init__(self, scheduler, optimizer): 3 | self.scheduler = scheduler 4 | self.optimizer = optimizer 5 | 6 | 7 | def step(self): 8 | self.optimizer.step() 9 | self.scheduler.step() 10 | 11 | def zero_grad(self): 12 | self.optimizer.zero_grad() 13 | 14 | @property 15 | def param_groups(self): 16 | return self.optimizer.param_groups 17 | 18 | def add_param_group(self, *args, **kwargs): 19 | self.optimizer.add_param_group(*args, **kwargs) 20 | 21 | 22 | def state_dict(self,): 23 | return self.optimizer.state_dict() 24 | 25 | def load_state_dict(self, *args, **kwargs): 26 | return self.optimizer.load_state_dict(*args, **kwargs) 27 | 28 | -------------------------------------------------------------------------------- /invglow/util.py: -------------------------------------------------------------------------------- 1 | import random 2 | from copy import deepcopy 3 | import logging 4 | import numbers 5 | import torch as th 6 | import numpy as np 7 | 8 | log = logging.getLogger(__name__) 9 | 10 | def step_and_clear_gradients(optimizer): 11 | optimizer.step() 12 | optimizer.zero_grad() 13 | 14 | 
15 | def check_gradients_clear(optimizer): 16 | for g in optimizer.param_groups: 17 | for p in g['params']: 18 | assert p.grad is None or th.all(p.grad.data == 0).item(), ( 19 | "Gradient not none or zero!") 20 | 21 | 22 | def grads_all_finite(optimizer): 23 | for g in optimizer.param_groups: 24 | for p in g['params']: 25 | if p.grad is None: 26 | log.warning("Gradient was none on check of finite grads") 27 | elif not th.all(th.isfinite(p.grad)).item(): 28 | return False 29 | return True 30 | 31 | 32 | def clip_to_finite_max(arr, ): 33 | arr = deepcopy(arr) 34 | arr[np.isnan(arr)] = np.nanmax(arr) 35 | arr[~np.isfinite(arr)] = np.max(arr[np.isfinite(arr)]) 36 | return arr 37 | 38 | 39 | def enforce_2d(outs): 40 | while len(outs.size()) > 2: 41 | n_dims = len(outs.size()) 42 | outs = outs.squeeze(2) 43 | assert len(outs.size()) == n_dims - 1 44 | return outs 45 | 46 | 47 | def view_2d(outs): 48 | return outs.view(outs.size()[0], -1) 49 | 50 | 51 | def ensure_on_same_device(*variables): 52 | any_cuda = np.any([v.is_cuda for v in variables]) 53 | if any_cuda: 54 | variables = [ensure_cuda(v) for v in variables] 55 | return variables 56 | 57 | 58 | def ensure_cuda(v): 59 | if not v.is_cuda: 60 | v = v.cuda() 61 | return v 62 | 63 | 64 | 65 | def log_sum_exp(value, dim=None, keepdim=False): 66 | # https://github.com/pytorch/pytorch/issues/2591#issuecomment-338980717 67 | """Numerically stable implementation of the operation 68 | 69 | value.exp().sum(dim, keepdim).log() 70 | """ 71 | # TODO: torch.max(value, dim=None) threw an error at time of writing 72 | if dim is not None: 73 | m, _ = th.max(value, dim=dim, keepdim=True) 74 | value0 = value - m 75 | if keepdim is False: 76 | m = m.squeeze(dim) 77 | return m + th.log(th.sum(th.exp(value0), 78 | dim=dim, keepdim=keepdim)) 79 | else: 80 | m = th.max(value) 81 | sum_exp = th.sum(th.exp(value - m)) 82 | return m + th.log(sum_exp) 83 | 84 | 85 | def set_random_seeds(seed, cuda): 86 | """ 87 | Set seeds for python random module numpy.random and torch. 88 | 89 | Parameters 90 | ---------- 91 | seed: int 92 | Random seed. 93 | cuda: bool 94 | Whether to set cuda seed with torch. 95 | 96 | """ 97 | random.seed(seed) 98 | th.manual_seed(seed) 99 | if cuda: 100 | th.cuda.manual_seed_all(seed) 101 | np.random.seed(seed) 102 | 103 | 104 | def np_to_var(X, requires_grad=False, dtype=None, pin_memory=False, 105 | **tensor_kwargs): 106 | """ 107 | Convenience function to transform numpy array to `torch.Tensor`. 108 | 109 | Converts `X` to ndarray using asarray if necessary. 110 | 111 | Parameters 112 | ---------- 113 | X: ndarray or list or number 114 | Input arrays 115 | requires_grad: bool 116 | passed on to Variable constructor 117 | dtype: numpy dtype, optional 118 | var_kwargs: 119 | passed on to Variable constructor 120 | 121 | Returns 122 | ------- 123 | var: `torch.Tensor` 124 | """ 125 | if not hasattr(X, '__len__'): 126 | X = [X] 127 | X = np.asarray(X) 128 | if dtype is not None: 129 | X = X.astype(dtype) 130 | X_tensor = th.tensor(X, requires_grad=requires_grad, **tensor_kwargs) 131 | if pin_memory: 132 | X_tensor = X_tensor.pin_memory() 133 | return X_tensor 134 | 135 | def var_to_np(var): 136 | """Convenience function to transform `torch.Tensor` to numpy 137 | array. 
138 | 139 | Should work both for CPU and GPU.""" 140 | if hasattr(var, 'cpu'): 141 | return var.cpu().data.numpy() 142 | else: 143 | # might happen that you get just a number, in that case nothing to do 144 | assert isinstance(var, numbers.Number) 145 | return var 146 | 147 | 148 | def interpolate_nans_in_df(df): 149 | df = df.copy() 150 | for row in df: 151 | series = np.array(df[row]) 152 | mask = np.isnan(series) 153 | series[mask] = np.interp(np.flatnonzero(mask), np.flatnonzero(~mask), 154 | series[~mask]) 155 | df.loc[:, row] = series 156 | return df 157 | 158 | 159 | def flatten_2d(a): 160 | return a.view(len(a), -1) 161 | --------------------------------------------------------------------------------
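Usage note (not part of the repository): a minimal sketch of how the helpers defined in invglow/util.py above fit together, seeding all random number generators and round-tripping a numpy batch through torch. The seed value and array shape are arbitrary illustrations; only numpy and torch are assumed to be installed.

import numpy as np
from invglow.util import set_random_seeds, np_to_var, var_to_np

# seed python's random module, numpy.random and torch (no cuda seeding here)
set_random_seeds(seed=0, cuda=False)

# a dummy batch in the CIFAR-like 32x32 layout used throughout the repository
x = np.random.rand(4, 3, 32, 32)

# numpy -> torch.Tensor (cast to float32), then back to a cpu numpy array
x_th = np_to_var(x, dtype=np.float32)
x_np = var_to_np(x_th)
assert np.allclose(x.astype(np.float32), x_np)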