├── EDA.ipynb
├── LICENSE
├── README.md
├── model
│   └── README.md
└── ndv
    ├── dataloader.ipynb
    ├── model.ipynb
    ├── modules
    │   ├── dataloader.py
    │   └── model.py
    ├── training-1cycle.ipynb
    ├── training.ipynb
    ├── utils
    │   ├── notebook2script.py
    │   └── run_notebook.py
    └── valid.txt

/LICENSE:
--------------------------------------------------------------------------------

                    GNU GENERAL PUBLIC LICENSE
                       Version 3, 29 June 2007

Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
Everyone is permitted to copy and distribute verbatim copies of this license document, but changing it is not allowed.

                            Preamble

The GNU General Public License is a free, copyleft license for software and other kinds of works.

The licenses for most software and other practical works are designed to take away your freedom to share and change the works. By contrast, the GNU General Public License is intended to guarantee your freedom to share and change all versions of a program--to make sure it remains free software for all its users. We, the Free Software Foundation, use the GNU General Public License for most of our software; it applies also to any other work released this way by its authors. You can apply it to your programs, too.

When we speak of free software, we are referring to freedom, not price. Our General Public Licenses are designed to make sure that you have the freedom to distribute copies of free software (and charge for them if you wish), that you receive source code or can get it if you want it, that you can change the software or use pieces of it in new free programs, and that you know you can do these things.

To protect your rights, we need to prevent others from denying you these rights or asking you to surrender the rights.
Therefore, you have certain responsibilities if you distribute copies of the software, or if you modify it: responsibilities to respect the freedom of others.

For example, if you distribute copies of such a program, whether gratis or for a fee, you must pass on to the recipients the same freedoms that you received. You must make sure that they, too, receive or can get the source code. And you must show them these terms so they know their rights.

Developers that use the GNU GPL protect your rights with two steps: (1) assert copyright on the software, and (2) offer you this License giving you legal permission to copy, distribute and/or modify it.

For the developers' and authors' protection, the GPL clearly explains that there is no warranty for this free software. For both users' and authors' sake, the GPL requires that modified versions be marked as changed, so that their problems will not be attributed erroneously to authors of previous versions.

Some devices are designed to deny users access to install or run modified versions of the software inside them, although the manufacturer can do so. This is fundamentally incompatible with the aim of protecting users' freedom to change the software. The systematic pattern of such abuse occurs in the area of products for individuals to use, which is precisely where it is most unacceptable. Therefore, we have designed this version of the GPL to prohibit the practice for those products. If such problems arise substantially in other domains, we stand ready to extend this provision to those domains in future versions of the GPL, as needed to protect the freedom of users.

Finally, every program is threatened constantly by software patents.
States should not allow patents to restrict development and use of software on general-purpose computers, but in those that do, we wish to avoid the special danger that patents applied to a free program could make it effectively proprietary. To prevent this, the GPL assures that patents cannot be used to render the program non-free.

The precise terms and conditions for copying, distribution and modification follow.

                       TERMS AND CONDITIONS

  0. Definitions.

"This License" refers to version 3 of the GNU General Public License.

"Copyright" also means copyright-like laws that apply to other kinds of works, such as semiconductor masks.

"The Program" refers to any copyrightable work licensed under this License. Each licensee is addressed as "you". "Licensees" and "recipients" may be individuals or organizations.

To "modify" a work means to copy from or adapt all or part of the work in a fashion requiring copyright permission, other than the making of an exact copy. The resulting work is called a "modified version" of the earlier work or a work "based on" the earlier work.

A "covered work" means either the unmodified Program or a work based on the Program.

To "propagate" a work means to do anything with it that, without permission, would make you directly or secondarily liable for infringement under applicable copyright law, except executing it on a computer or modifying a private copy. Propagation includes copying, distribution (with or without modification), making available to the public, and in some countries other activities as well.

To "convey" a work means any kind of propagation that enables other parties to make or receive copies. Mere interaction with a user through a computer network, with no transfer of a copy, is not conveying.
An interactive user interface displays "Appropriate Legal Notices" to the extent that it includes a convenient and prominently visible feature that (1) displays an appropriate copyright notice, and (2) tells the user that there is no warranty for the work (except to the extent that warranties are provided), that licensees may convey the work under this License, and how to view a copy of this License. If the interface presents a list of user commands or options, such as a menu, a prominent item in the list meets this criterion.

  1. Source Code.

The "source code" for a work means the preferred form of the work for making modifications to it. "Object code" means any non-source form of a work.

A "Standard Interface" means an interface that either is an official standard defined by a recognized standards body, or, in the case of interfaces specified for a particular programming language, one that is widely used among developers working in that language.

The "System Libraries" of an executable work include anything, other than the work as a whole, that (a) is included in the normal form of packaging a Major Component, but which is not part of that Major Component, and (b) serves only to enable use of the work with that Major Component, or to implement a Standard Interface for which an implementation is available to the public in source code form. A "Major Component", in this context, means a major essential component (kernel, window system, and so on) of the specific operating system (if any) on which the executable work runs, or a compiler used to produce the work, or an object code interpreter used to run it.
The "Corresponding Source" for a work in object code form means all the source code needed to generate, install, and (for an executable work) run the object code and to modify the work, including scripts to control those activities. However, it does not include the work's System Libraries, or general-purpose tools or generally available free programs which are used unmodified in performing those activities but which are not part of the work. For example, Corresponding Source includes interface definition files associated with source files for the work, and the source code for shared libraries and dynamically linked subprograms that the work is specifically designed to require, such as by intimate data communication or control flow between those subprograms and other parts of the work.

The Corresponding Source need not include anything that users can regenerate automatically from other parts of the Corresponding Source.

The Corresponding Source for a work in source code form is that same work.

  2. Basic Permissions.

All rights granted under this License are granted for the term of copyright on the Program, and are irrevocable provided the stated conditions are met. This License explicitly affirms your unlimited permission to run the unmodified Program. The output from running a covered work is covered by this License only if the output, given its content, constitutes a covered work. This License acknowledges your rights of fair use or other equivalent, as provided by copyright law.

You may make, run and propagate covered works that you do not convey, without conditions so long as your license otherwise remains in force.
You may convey covered works to others for the sole purpose of having them make modifications exclusively for you, or provide you with facilities for running those works, provided that you comply with the terms of this License in conveying all material for which you do not control copyright. Those thus making or running the covered works for you must do so exclusively on your behalf, under your direction and control, on terms that prohibit them from making any copies of your copyrighted material outside their relationship with you.

Conveying under any other circumstances is permitted solely under the conditions stated below. Sublicensing is not allowed; section 10 makes it unnecessary.

  3. Protecting Users' Legal Rights From Anti-Circumvention Law.

No covered work shall be deemed part of an effective technological measure under any applicable law fulfilling obligations under article 11 of the WIPO copyright treaty adopted on 20 December 1996, or similar laws prohibiting or restricting circumvention of such measures.

When you convey a covered work, you waive any legal power to forbid circumvention of technological measures to the extent such circumvention is effected by exercising rights under this License with respect to the covered work, and you disclaim any intention to limit operation or modification of the work as a means of enforcing, against the work's users, your or third parties' legal rights to forbid circumvention of technological measures.

  4. Conveying Verbatim Copies.
You may convey verbatim copies of the Program's source code as you receive it, in any medium, provided that you conspicuously and appropriately publish on each copy an appropriate copyright notice; keep intact all notices stating that this License and any non-permissive terms added in accord with section 7 apply to the code; keep intact all notices of the absence of any warranty; and give all recipients a copy of this License along with the Program.

You may charge any price or no price for each copy that you convey, and you may offer support or warranty protection for a fee.

  5. Conveying Modified Source Versions.

You may convey a work based on the Program, or the modifications to produce it from the Program, in the form of source code under the terms of section 4, provided that you also meet all of these conditions:

    a) The work must carry prominent notices stating that you modified it, and giving a relevant date.

    b) The work must carry prominent notices stating that it is released under this License and any conditions added under section 7. This requirement modifies the requirement in section 4 to "keep intact all notices".

    c) You must license the entire work, as a whole, under this License to anyone who comes into possession of a copy. This License will therefore apply, along with any applicable section 7 additional terms, to the whole of the work, and all its parts, regardless of how they are packaged. This License gives no permission to license the work in any other way, but it does not invalidate such permission if you have separately received it.
    d) If the work has interactive user interfaces, each must display Appropriate Legal Notices; however, if the Program has interactive interfaces that do not display Appropriate Legal Notices, your work need not make them do so.

A compilation of a covered work with other separate and independent works, which are not by their nature extensions of the covered work, and which are not combined with it such as to form a larger program, in or on a volume of a storage or distribution medium, is called an "aggregate" if the compilation and its resulting copyright are not used to limit the access or legal rights of the compilation's users beyond what the individual works permit. Inclusion of a covered work in an aggregate does not cause this License to apply to the other parts of the aggregate.

  6. Conveying Non-Source Forms.

You may convey a covered work in object code form under the terms of sections 4 and 5, provided that you also convey the machine-readable Corresponding Source under the terms of this License, in one of these ways:

    a) Convey the object code in, or embodied in, a physical product (including a physical distribution medium), accompanied by the Corresponding Source fixed on a durable physical medium customarily used for software interchange.
    b) Convey the object code in, or embodied in, a physical product (including a physical distribution medium), accompanied by a written offer, valid for at least three years and valid for as long as you offer spare parts or customer support for that product model, to give anyone who possesses the object code either (1) a copy of the Corresponding Source for all the software in the product that is covered by this License, on a durable physical medium customarily used for software interchange, for a price no more than your reasonable cost of physically performing this conveying of source, or (2) access to copy the Corresponding Source from a network server at no charge.

    c) Convey individual copies of the object code with a copy of the written offer to provide the Corresponding Source. This alternative is allowed only occasionally and noncommercially, and only if you received the object code with such an offer, in accord with subsection 6b.

    d) Convey the object code by offering access from a designated place (gratis or for a charge), and offer equivalent access to the Corresponding Source in the same way through the same place at no further charge. You need not require recipients to copy the Corresponding Source along with the object code. If the place to copy the object code is a network server, the Corresponding Source may be on a different server (operated by you or a third party) that supports equivalent copying facilities, provided you maintain clear directions next to the object code saying where to find the Corresponding Source. Regardless of what server hosts the Corresponding Source, you remain obligated to ensure that it is available for as long as needed to satisfy these requirements.
    e) Convey the object code using peer-to-peer transmission, provided you inform other peers where the object code and Corresponding Source of the work are being offered to the general public at no charge under subsection 6d.

A separable portion of the object code, whose source code is excluded from the Corresponding Source as a System Library, need not be included in conveying the object code work.

A "User Product" is either (1) a "consumer product", which means any tangible personal property which is normally used for personal, family, or household purposes, or (2) anything designed or sold for incorporation into a dwelling. In determining whether a product is a consumer product, doubtful cases shall be resolved in favor of coverage. For a particular product received by a particular user, "normally used" refers to a typical or common use of that class of product, regardless of the status of the particular user or of the way in which the particular user actually uses, or expects or is expected to use, the product. A product is a consumer product regardless of whether the product has substantial commercial, industrial or non-consumer uses, unless such uses represent the only significant mode of use of the product.

"Installation Information" for a User Product means any methods, procedures, authorization keys, or other information required to install and execute modified versions of a covered work in that User Product from a modified version of its Corresponding Source. The information must suffice to ensure that the continued functioning of the modified object code is in no case prevented or interfered with solely because modification has been made.
If you convey an object code work under this section in, or with, or specifically for use in, a User Product, and the conveying occurs as part of a transaction in which the right of possession and use of the User Product is transferred to the recipient in perpetuity or for a fixed term (regardless of how the transaction is characterized), the Corresponding Source conveyed under this section must be accompanied by the Installation Information. But this requirement does not apply if neither you nor any third party retains the ability to install modified object code on the User Product (for example, the work has been installed in ROM).

The requirement to provide Installation Information does not include a requirement to continue to provide support service, warranty, or updates for a work that has been modified or installed by the recipient, or for the User Product in which it has been modified or installed. Access to a network may be denied when the modification itself materially and adversely affects the operation of the network or violates the rules and protocols for communication across the network.

Corresponding Source conveyed, and Installation Information provided, in accord with this section must be in a format that is publicly documented (and with an implementation available to the public in source code form), and must require no special password or key for unpacking, reading or copying.

  7. Additional Terms.

"Additional permissions" are terms that supplement the terms of this License by making exceptions from one or more of its conditions. Additional permissions that are applicable to the entire Program shall be treated as though they were included in this License, to the extent that they are valid under applicable law.
If additional permissions apply only to part of the Program, that part may be used separately under those permissions, but the entire Program remains governed by this License without regard to the additional permissions.

When you convey a copy of a covered work, you may at your option remove any additional permissions from that copy, or from any part of it. (Additional permissions may be written to require their own removal in certain cases when you modify the work.) You may place additional permissions on material, added by you to a covered work, for which you have or can give appropriate copyright permission.

Notwithstanding any other provision of this License, for material you add to a covered work, you may (if authorized by the copyright holders of that material) supplement the terms of this License with terms:

    a) Disclaiming warranty or limiting liability differently from the terms of sections 15 and 16 of this License; or

    b) Requiring preservation of specified reasonable legal notices or author attributions in that material or in the Appropriate Legal Notices displayed by works containing it; or

    c) Prohibiting misrepresentation of the origin of that material, or requiring that modified versions of such material be marked in reasonable ways as different from the original version; or

    d) Limiting the use for publicity purposes of names of licensors or authors of the material; or

    e) Declining to grant rights under trademark law for use of some trade names, trademarks, or service marks; or

    f) Requiring indemnification of licensors and authors of that material by anyone who conveys the material (or modified versions of it) with contractual assumptions of liability to the recipient, for any liability that these contractual assumptions directly impose on those licensors and authors.

All other non-permissive additional terms are considered "further restrictions" within the meaning of section 10. If the Program as you received it, or any part of it, contains a notice stating that it is governed by this License along with a term that is a further restriction, you may remove that term. If a license document contains a further restriction but permits relicensing or conveying under this License, you may add to a covered work material governed by the terms of that license document, provided that the further restriction does not survive such relicensing or conveying.

If you add terms to a covered work in accord with this section, you must place, in the relevant source files, a statement of the additional terms that apply to those files, or a notice indicating where to find the applicable terms.

Additional terms, permissive or non-permissive, may be stated in the form of a separately written license, or stated as exceptions; the above requirements apply either way.

  8. Termination.

You may not propagate or modify a covered work except as expressly provided under this License. Any attempt otherwise to propagate or modify it is void, and will automatically terminate your rights under this License (including any patent licenses granted under the third paragraph of section 11).

However, if you cease all violation of this License, then your license from a particular copyright holder is reinstated (a) provisionally, unless and until the copyright holder explicitly and finally terminates your license, and (b) permanently, if the copyright holder fails to notify you of the violation by some reasonable means prior to 60 days after the cessation.
Moreover, your license from a particular copyright holder is reinstated permanently if the copyright holder notifies you of the violation by some reasonable means, this is the first time you have received notice of violation of this License (for any work) from that copyright holder, and you cure the violation prior to 30 days after your receipt of the notice.

Termination of your rights under this section does not terminate the licenses of parties who have received copies or rights from you under this License. If your rights have been terminated and not permanently reinstated, you do not qualify to receive new licenses for the same material under section 10.

  9. Acceptance Not Required for Having Copies.

You are not required to accept this License in order to receive or run a copy of the Program. Ancillary propagation of a covered work occurring solely as a consequence of using peer-to-peer transmission to receive a copy likewise does not require acceptance. However, nothing other than this License grants you permission to propagate or modify any covered work. These actions infringe copyright if you do not accept this License. Therefore, by modifying or propagating a covered work, you indicate your acceptance of this License to do so.

  10. Automatic Licensing of Downstream Recipients.

Each time you convey a covered work, the recipient automatically receives a license from the original licensors, to run, modify and propagate that work, subject to this License. You are not responsible for enforcing compliance by third parties with this License.

An "entity transaction" is a transaction transferring control of an organization, or substantially all assets of one, or subdividing an organization, or merging organizations.
If propagation of a covered work results from an entity transaction, each party to that transaction who receives a copy of the work also receives whatever licenses to the work the party's predecessor in interest had or could give under the previous paragraph, plus a right to possession of the Corresponding Source of the work from the predecessor in interest, if the predecessor has it or can get it with reasonable efforts.

You may not impose any further restrictions on the exercise of the rights granted or affirmed under this License. For example, you may not impose a license fee, royalty, or other charge for exercise of rights granted under this License, and you may not initiate litigation (including a cross-claim or counterclaim in a lawsuit) alleging that any patent claim is infringed by making, using, selling, offering for sale, or importing the Program or any portion of it.

  11. Patents.

A "contributor" is a copyright holder who authorizes use under this License of the Program or a work on which the Program is based. The work thus licensed is called the contributor's "contributor version".

A contributor's "essential patent claims" are all patent claims owned or controlled by the contributor, whether already acquired or hereafter acquired, that would be infringed by some manner, permitted by this License, of making, using, or selling its contributor version, but do not include claims that would be infringed only as a consequence of further modification of the contributor version. For purposes of this definition, "control" includes the right to grant patent sublicenses in a manner consistent with the requirements of this License.
Each contributor grants you a non-exclusive, worldwide, royalty-free patent license under the contributor's essential patent claims, to make, use, sell, offer for sale, import and otherwise run, modify and propagate the contents of its contributor version.

In the following three paragraphs, a "patent license" is any express agreement or commitment, however denominated, not to enforce a patent (such as an express permission to practice a patent or covenant not to sue for patent infringement). To "grant" such a patent license to a party means to make such an agreement or commitment not to enforce a patent against the party.

If you convey a covered work, knowingly relying on a patent license, and the Corresponding Source of the work is not available for anyone to copy, free of charge and under the terms of this License, through a publicly available network server or other readily accessible means, then you must either (1) cause the Corresponding Source to be so available, or (2) arrange to deprive yourself of the benefit of the patent license for this particular work, or (3) arrange, in a manner consistent with the requirements of this License, to extend the patent license to downstream recipients. "Knowingly relying" means you have actual knowledge that, but for the patent license, your conveying the covered work in a country, or your recipient's use of the covered work in a country, would infringe one or more identifiable patents in that country that you have reason to believe are valid.
If, pursuant to or in connection with a single transaction or arrangement, you convey, or propagate by procuring conveyance of, a covered work, and grant a patent license to some of the parties receiving the covered work authorizing them to use, propagate, modify or convey a specific copy of the covered work, then the patent license you grant is automatically extended to all recipients of the covered work and works based on it.

A patent license is "discriminatory" if it does not include within the scope of its coverage, prohibits the exercise of, or is conditioned on the non-exercise of one or more of the rights that are specifically granted under this License. You may not convey a covered work if you are a party to an arrangement with a third party that is in the business of distributing software, under which you make payment to the third party based on the extent of your activity of conveying the work, and under which the third party grants, to any of the parties who would receive the covered work from you, a discriminatory patent license (a) in connection with copies of the covered work conveyed by you (or copies made from those copies), or (b) primarily for and in connection with specific products or compilations that contain the covered work, unless you entered into that arrangement, or that patent license was granted, prior to 28 March 2007.

Nothing in this License shall be construed as excluding or limiting any implied license or other defenses to infringement that may otherwise be available to you under applicable patent law.

  12. No Surrender of Others' Freedom.

If conditions are imposed on you (whether by court order, agreement or otherwise) that contradict the conditions of this License, they do not excuse you from the conditions of this License.
If you cannot convey a covered work so as to satisfy simultaneously your obligations under this License and any other pertinent obligations, then as a consequence you may not convey it at all. For example, if you agree to terms that obligate you to collect a royalty for further conveying from those to whom you convey the Program, the only way you could satisfy both those terms and this License would be to refrain entirely from conveying the Program.

  13. Use with the GNU Affero General Public License.

Notwithstanding any other provision of this License, you have permission to link or combine any covered work with a work licensed under version 3 of the GNU Affero General Public License into a single combined work, and to convey the resulting work. The terms of this License will continue to apply to the part which is the covered work, but the special requirements of the GNU Affero General Public License, section 13, concerning interaction through a network will apply to the combination as such.

  14. Revised Versions of this License.

The Free Software Foundation may publish revised and/or new versions of the GNU General Public License from time to time. Such new versions will be similar in spirit to the present version, but may differ in detail to address new problems or concerns.

Each version is given a distinguishing version number. If the Program specifies that a certain numbered version of the GNU General Public License "or any later version" applies to it, you have the option of following the terms and conditions either of that numbered version or of any later version published by the Free Software Foundation. If the Program does not specify a version number of the GNU General Public License, you may choose any version ever published by the Free Software Foundation.
578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 583 | 584 | Later license versions may give you additional or different 585 | permissions. However, no additional obligations are imposed on any 586 | author or copyright holder as a result of your choosing to follow a 587 | later version. 588 | 589 | 15. Disclaimer of Warranty. 590 | 591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 599 | 600 | 16. Limitation of Liability. 601 | 602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 610 | SUCH DAMAGES. 611 | 612 | 17. Interpretation of Sections 15 and 16. 
613 | 614 | If the disclaimer of warranty and limitation of liability provided 615 | above cannot be given local legal effect according to their terms, 616 | reviewing courts shall apply local law that most closely approximates 617 | an absolute waiver of all civil liability in connection with the 618 | Program, unless a warranty or assumption of liability accompanies a 619 | copy of the Program in return for a fee. 620 | 621 | END OF TERMS AND CONDITIONS 622 | 623 | How to Apply These Terms to Your New Programs 624 | 625 | If you develop a new program, and you want it to be of the greatest 626 | possible use to the public, the best way to achieve this is to make it 627 | free software which everyone can redistribute and change under these terms. 628 | 629 | To do so, attach the following notices to the program. It is safest 630 | to attach them to the start of each source file to most effectively 631 | state the exclusion of warranty; and each file should have at least 632 | the "copyright" line and a pointer to where the full notice is found. 633 | 634 | <one line to give the program's name and a brief idea of what it does.> 635 | Copyright (C) <year> <name of author> 636 | 637 | This program is free software: you can redistribute it and/or modify 638 | it under the terms of the GNU General Public License as published by 639 | the Free Software Foundation, either version 3 of the License, or 640 | (at your option) any later version. 641 | 642 | This program is distributed in the hope that it will be useful, 643 | but WITHOUT ANY WARRANTY; without even the implied warranty of 644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 645 | GNU General Public License for more details. 646 | 647 | You should have received a copy of the GNU General Public License 648 | along with this program. If not, see <https://www.gnu.org/licenses/>. 649 | 650 | Also add information on how to contact you by electronic and paper mail. 
651 | 652 | If the program does terminal interaction, make it output a short 653 | notice like this when it starts in an interactive mode: 654 | 655 | <program> Copyright (C) <year> <name of author> 656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 657 | This is free software, and you are welcome to redistribute it 658 | under certain conditions; type `show c' for details. 659 | 660 | The hypothetical commands `show w' and `show c' should show the appropriate 661 | parts of the General Public License. Of course, your program's commands 662 | might be different; for a GUI interface, you would use an "about box". 663 | 664 | You should also get your employer (if you work as a programmer) or school, 665 | if any, to sign a "copyright disclaimer" for the program, if necessary. 666 | For more information on this, and how to apply and follow the GNU GPL, see 667 | <https://www.gnu.org/licenses/>. 668 | 669 | The GNU General Public License does not permit incorporating your program 670 | into proprietary programs. If your program is a subroutine library, you 671 | may consider it more useful to permit linking proprietary applications with 672 | the library. If this is what you want to do, use the GNU Lesser General 673 | Public License instead of this License. But first, please read 674 | <https://www.gnu.org/licenses/why-not-lgpl.html>. 675 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 3D-MRI-brain-tumor-segmentation-using-autoencoder-regularization 2 | PyTorch implementation of the paper by Myronenko A. (https://arxiv.org/abs/1810.11654) 3 | -------------------------------------------------------------------------------- /model/README.md: -------------------------------------------------------------------------------- 1 | # 3D-MRI-brain-tumor-segmentation-using-autoencoder-regularization 2 | 3 | The PyTorch model is saved under the current folder. 
4 | -------------------------------------------------------------------------------- /ndv/model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "colab_type": "text", 7 | "id": "b6KgSfpahUZ1" 8 | }, 9 | "source": [ 10 | "# Setup" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 1, 16 | "metadata": { 17 | "colab": {}, 18 | "colab_type": "code", 19 | "id": "G75A7ey4hUZ3" 20 | }, 21 | "outputs": [], 22 | "source": [ 23 | "import torch\n", 24 | "from fastai.callbacks import *\n", 25 | "from fastai.vision import *" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": { 31 | "colab_type": "text", 32 | "id": "C77ffh0whUZ7" 33 | }, 34 | "source": [ 35 | "## GPU " 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 0, 41 | "metadata": { 42 | "colab": { 43 | "base_uri": "https://localhost:8080/", 44 | "height": 34 45 | }, 46 | "colab_type": "code", 47 | "id": "syfiMzschUZ8", 48 | "outputId": "a555f3d9-91cb-43e1-8302-a037fbb5efe9" 49 | }, 50 | "outputs": [ 51 | { 52 | "data": { 53 | "text/plain": [ 54 | "True" 55 | ] 56 | }, 57 | "execution_count": 2, 58 | "metadata": { 59 | "tags": [] 60 | }, 61 | "output_type": "execute_result" 62 | } 63 | ], 64 | "source": [ 65 | "# Check GPU availability\n", 66 | "torch.cuda.is_available()" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 0, 72 | "metadata": { 73 | "colab": { 74 | "base_uri": "https://localhost:8080/", 75 | "height": 34 76 | }, 77 | "colab_type": "code", 78 | "id": "TOznfYKmhUaB", 79 | "outputId": "6b060a7f-72ae-43c1-c0b0-e521b770ffad" 80 | }, 81 | "outputs": [ 82 | { 83 | "data": { 84 | "text/plain": [ 85 | "1" 86 | ] 87 | }, 88 | "execution_count": 4, 89 | "metadata": { 90 | "tags": [] 91 | }, 92 | "output_type": "execute_result" 93 | } 94 | ], 95 | "source": [ 96 | "# Check mounted GPU devices\n", 97 | 
"torch.cuda.device_count()" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": 0, 103 | "metadata": { 104 | "colab": { 105 | "base_uri": "https://localhost:8080/", 106 | "height": 34 107 | }, 108 | "colab_type": "code", 109 | "id": "BfqqvsjvhUaF", 110 | "outputId": "5afac64a-4654-41df-ed6d-6bc968cc8ef7" 111 | }, 112 | "outputs": [ 113 | { 114 | "data": { 115 | "text/plain": [ 116 | "0" 117 | ] 118 | }, 119 | "execution_count": 16, 120 | "metadata": { 121 | "tags": [] 122 | }, 123 | "output_type": "execute_result" 124 | } 125 | ], 126 | "source": [ 127 | "# Current device you're using\n", 128 | "# * 0-indexed *\n", 129 | "torch.cuda.current_device()" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 2, 135 | "metadata": { 136 | "code_folding": [], 137 | "colab": { 138 | "base_uri": "https://localhost:8080/", 139 | "height": 289 140 | }, 141 | "colab_type": "code", 142 | "id": "hIcrVEJWhUaK", 143 | "outputId": "52eb889d-49ef-433b-9528-0ac9b514dc76" 144 | }, 145 | "outputs": [ 146 | { 147 | "name": "stdout", 148 | "output_type": "stream", 149 | "text": [ 150 | "Wed Jun 10 22:31:28 2020 \n", 151 | "+-----------------------------------------------------------------------------+\n", 152 | "| NVIDIA-SMI 418.67 Driver Version: 418.67 CUDA Version: 10.1 |\n", 153 | "|-------------------------------+----------------------+----------------------+\n", 154 | "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", 155 | "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", 156 | "|===============================+======================+======================|\n", 157 | "| 0 Tesla V100-PCIE... On | 00000000:00:06.0 Off | 0 |\n", 158 | "| N/A 36C P0 41W / 250W | 4990MiB / 32480MiB | 0% Default |\n", 159 | "+-------------------------------+----------------------+----------------------+\n", 160 | "| 1 Tesla V100-PCIE... 
On | 00000000:00:07.0 Off | 0 |\n", 161 | "| N/A 37C P0 43W / 250W | 936MiB / 32480MiB | 0% Default |\n", 162 | "+-------------------------------+----------------------+----------------------+\n", 163 | "| 2 Tesla V100-PCIE... On | 00000000:00:08.0 Off | 0 |\n", 164 | "| N/A 34C P0 39W / 250W | 936MiB / 32480MiB | 0% Default |\n", 165 | "+-------------------------------+----------------------+----------------------+\n", 166 | " \n", 167 | "+-----------------------------------------------------------------------------+\n", 168 | "| Processes: GPU Memory |\n", 169 | "| GPU PID Type Process name Usage |\n", 170 | "|=============================================================================|\n", 171 | "| 0 5744 C ...u/anaconda3/envs/pytorch_p36/bin/python 941MiB |\n", 172 | "| 0 8881 C ...u/anaconda3/envs/pytorch_p36/bin/python 941MiB |\n", 173 | "+-----------------------------------------------------------------------------+\n" 174 | ] 175 | } 176 | ], 177 | "source": [ 178 | "# Check workloads of your GPU(s)\n", 179 | "!nvidia-smi" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": 4, 185 | "metadata": { 186 | "colab": {}, 187 | "colab_type": "code", 188 | "id": "2BmajeAXhUaO" 189 | }, 190 | "outputs": [], 191 | "source": [ 192 | "# Reset your current device (if necessary)\n", 193 | "torch.cuda.set_device(1)" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": 3, 199 | "metadata": { 200 | "colab": {}, 201 | "colab_type": "code", 202 | "id": "tOVolw1UhUaS", 203 | "outputId": "eab8d203-31ee-41b2-822c-3a0abb058fe5" 204 | }, 205 | "outputs": [ 206 | { 207 | "data": { 208 | "text/plain": [ 209 | "1" 210 | ] 211 | }, 212 | "execution_count": 3, 213 | "metadata": {}, 214 | "output_type": "execute_result" 215 | } 216 | ], 217 | "source": [ 218 | "# Check change's been made\n", 219 | "torch.cuda.current_device()\n" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": 5, 225 | "metadata": { 226 | 
"colab": { 227 | "base_uri": "https://localhost:8080/", 228 | "height": 34 229 | }, 230 | "colab_type": "code", 231 | "id": "wXALsotzhUaY", 232 | "outputId": "c2118aaa-bf6c-4d23-db4b-bfc0667ed5a2" 233 | }, 234 | "outputs": [ 235 | { 236 | "data": { 237 | "text/plain": [ 238 | "'Tesla V100-PCIE-32GB'" 239 | ] 240 | }, 241 | "execution_count": 5, 242 | "metadata": {}, 243 | "output_type": "execute_result" 244 | } 245 | ], 246 | "source": [ 247 | "# Check name of your device\n", 248 | "torch.cuda.get_device_name()" 249 | ] 250 | }, 251 | { 252 | "cell_type": "markdown", 253 | "metadata": { 254 | "colab_type": "text", 255 | "id": "CSmixVUShUac" 256 | }, 257 | "source": [ 258 | "# Model Prototyping" 259 | ] 260 | }, 261 | { 262 | "cell_type": "markdown", 263 | "metadata": { 264 | "colab_type": "text", 265 | "heading_collapsed": true, 266 | "id": "btBiurUPxx7t" 267 | }, 268 | "source": [ 269 | "## Helpers " 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "execution_count": 5, 275 | "metadata": { 276 | "code_folding": [ 277 | 0 278 | ], 279 | "colab": {}, 280 | "colab_type": "code", 281 | "hidden": true, 282 | "id": "jIljYlKohUad" 283 | }, 284 | "outputs": [], 285 | "source": [ 286 | "def conv_block(c_in, c_out, ks, num_groups=None, **conv_kwargs):\n", 287 | " \"A sequence of modules composed of Group Norm, ReLU and Conv3d in order\"\n", 288 | " if not num_groups : num_groups = int(c_in/2) if c_in%2 == 0 else None\n", 289 | " return nn.Sequential(nn.GroupNorm(num_groups, c_in),\n", 290 | " nn.ReLU(),\n", 291 | " nn.Conv3d(c_in, c_out, ks, **conv_kwargs))" 292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "execution_count": 6, 297 | "metadata": { 298 | "code_folding": [ 299 | 0 300 | ], 301 | "colab": {}, 302 | "colab_type": "code", 303 | "hidden": true, 304 | "id": "fLbUaZT7hUag" 305 | }, 306 | "outputs": [], 307 | "source": [ 308 | "def reslike_block(nf, num_groups=None, bottle_neck:bool=False, **conv_kwargs):\n", 309 | " \"A ResNet-like block with the 
GroupNorm normalization providing optional bottle-neck functionality"\n", 310 | "  nf_inner = nf // 2 if bottle_neck else nf  # integer division: channel counts must be ints\n", 311 | "  return SequentialEx(conv_block(num_groups=num_groups, c_in=nf, c_out=nf_inner, ks=3, stride=1, padding=1, **conv_kwargs),\n", 312 | "                      conv_block(num_groups=num_groups, c_in=nf_inner, c_out=nf, ks=3, stride=1, padding=1, **conv_kwargs),\n", 313 | "                      MergeLayer())" 314 | ] 315 | }, 316 | { 317 | "cell_type": "code", 318 | "execution_count": 7, 319 | "metadata": { 320 | "code_folding": [ 321 | 0 322 | ], 323 | "colab": {}, 324 | "colab_type": "code", 325 | "hidden": true, 326 | "id": "BKNJoGdsgbk2" 327 | }, 328 | "outputs": [], 329 | "source": [ 330 | "def upsize(c_in, c_out, ks=1, scale=2):\n", 331 | "  \"Reduce the number of features by 2 using Conv with kernel size 1x1x1 and double the spatial dimension using 3D trilinear upsampling\"\n", 332 | "  return nn.Sequential(nn.Conv3d(c_in, c_out, ks),\n", 333 | "                       nn.Upsample(scale_factor=scale, mode='trilinear'))" 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": 8, 339 | "metadata": { 340 | "code_folding": [ 341 | 0 342 | ], 343 | "colab": {}, 344 | "colab_type": "code", 345 | "hidden": true, 346 | "id": "6CjiBTnT8LFF" 347 | }, 348 | "outputs": [], 349 | "source": [ 350 | "def hook_debug(module, input, output):\n", 351 | "  \"\"\"\n", 352 | "  Print out what's been hooked, usually for debugging purposes\n", 353 | "  ----------------------------------------------------------\n", 354 | "  Example:\n", 355 | "  Hooks(ms, hook_debug, is_forward=True, detach=False)\n", 356 | "  \n", 357 | "  \"\"\"\n", 358 | "  print('Hooking ' + module.__class__.__name__)\n", 359 | "  print('output size:', output.data.size())\n", 360 | "  return output" 361 | ] 362 | }, 363 | { 364 | "cell_type": "markdown", 365 | "metadata": { 366 | "colab_type": "text", 367 | "id": "MY131WWbx3nN" 368 | }, 369 | "source": [ 370 | "## Encoder Part" 371 | ] 372 | }, 373 | { 374 | "cell_type": "code", 375 | 
"execution_count": 9, 376 | "metadata": { 377 | "code_folding": [ 378 | 0 379 | ], 380 | "colab": {}, 381 | "colab_type": "code", 382 | "id": "f_8ynHTavdvL" 383 | }, 384 | "outputs": [], 385 | "source": [ 386 | "class Encoder(nn.Module):\n", 387 | " \"Encoder part\"\n", 388 | " def __init__(self):\n", 389 | " super().__init__()\n", 390 | " self.conv1 = nn.Conv3d(4, 32, 3, stride=1, padding=1) \n", 391 | " self.res_block1 = reslike_block(32, num_groups=8)\n", 392 | " self.conv_block1 = conv_block(32, 64, 3, num_groups=8, stride=2, padding=1)\n", 393 | " self.res_block2 = reslike_block(64, num_groups=8)\n", 394 | " self.conv_block2 = conv_block(64, 64, 3, num_groups=8, stride=1, padding=1)\n", 395 | " self.res_block3 = reslike_block(64, num_groups=8)\n", 396 | " self.conv_block3 = conv_block(64, 128, 3, num_groups=8, stride=2, padding=1)\n", 397 | " self.res_block4 = reslike_block(128, num_groups=8)\n", 398 | " self.conv_block4 = conv_block(128, 128, 3, num_groups=8, stride=1, padding=1)\n", 399 | " self.res_block5 = reslike_block(128, num_groups=8)\n", 400 | " self.conv_block5 = conv_block(128, 256, 3, num_groups=8, stride=2, padding=1)\n", 401 | " self.res_block6 = reslike_block(256, num_groups=8)\n", 402 | " self.conv_block6 = conv_block(256, 256, 3, num_groups=8, stride=1, padding=1)\n", 403 | " self.res_block7 = reslike_block(256, num_groups=8)\n", 404 | " self.conv_block7 = conv_block(256, 256, 3, num_groups=8, stride=1, padding=1)\n", 405 | " self.res_block8 = reslike_block(256, num_groups=8)\n", 406 | " self.conv_block8 = conv_block(256, 256, 3, num_groups=8, stride=1, padding=1)\n", 407 | " self.res_block9 = reslike_block(256, num_groups=8)\n", 408 | " \n", 409 | " def forward(self, x):\n", 410 | " x = self.conv1(x) # Output size: (1, 32, 160, 192, 128)\n", 411 | " x = self.res_block1(x) # Output size: (1, 32, 160, 192, 128)\n", 412 | " x = self.conv_block1(x) # Output size: (1, 64, 80, 96, 64)\n", 413 | " x = self.res_block2(x) # Output size: (1, 64, 80, 
96, 64)\n", 414 | " x = self.conv_block2(x) # Output size: (1, 64, 80, 96, 64)\n", 415 | " x = self.res_block3(x) # Output size: (1, 64, 80, 96, 64)\n", 416 | " x = self.conv_block3(x) # Output size: (1, 128, 40, 48, 32)\n", 417 | " x = self.res_block4(x) # Output size: (1, 128, 40, 48, 32)\n", 418 | " x = self.conv_block4(x) # Output size: (1, 128, 40, 48, 32)\n", 419 | " x = self.res_block5(x) # Output size: (1, 128, 40, 48, 32)\n", 420 | " x = self.conv_block5(x) # Output size: (1, 256, 20, 24, 16)\n", 421 | " x = self.res_block6(x) # Output size: (1, 256, 20, 24, 16)\n", 422 | " x = self.conv_block6(x) # Output size: (1, 256, 20, 24, 16)\n", 423 | " x = self.res_block7(x) # Output size: (1, 256, 20, 24, 16)\n", 424 | " x = self.conv_block7(x) # Output size: (1, 256, 20, 24, 16)\n", 425 | " x = self.res_block8(x) # Output size: (1, 256, 20, 24, 16)\n", 426 | " x = self.conv_block8(x) # Output size: (1, 256, 20, 24, 16)\n", 427 | " x = self.res_block9(x) # Output size: (1, 256, 20, 24, 16)\n", 428 | " return x" 429 | ] 430 | }, 431 | { 432 | "cell_type": "code", 433 | "execution_count": 23, 434 | "metadata": { 435 | "code_folding": [ 436 | 0 437 | ], 438 | "colab": {}, 439 | "colab_type": "code", 440 | "id": "bgPmjWCq6n1F" 441 | }, 442 | "outputs": [], 443 | "source": [ 444 | "########## Sanity-check ############\n", 445 | "# input = torch.randn(1, 4, 160, 192, 128)\n", 446 | "# input = input.cuda()\n", 447 | "# encoder = Encoder()\n", 448 | "# encoder.cuda()\n", 449 | "# ms = [encoder.res_block1, encoder.res_block3, encoder.res_block5]\n", 450 | "# hooks = Hooks(ms, hook_debug, is_forward=True, detach=False)\n", 451 | "# output = encoder(input)" 452 | ] 453 | }, 454 | { 455 | "cell_type": "markdown", 456 | "metadata": { 457 | "colab_type": "text", 458 | "id": "BKmlf1qY74Fx" 459 | }, 460 | "source": [ 461 | "## Decoder Part" 462 | ] 463 | }, 464 | { 465 | "cell_type": "code", 466 | "execution_count": 10, 467 | "metadata": { 468 | "code_folding": [ 469 | 0 470 | 
], 471 | "colab": {}, 472 | "colab_type": "code", 473 | "id": "p9jCdQAeBTch" 474 | }, 475 | "outputs": [], 476 | "source": [ 477 | "class Decoder(nn.Module):\n", 478 | " \"Decoder Part\"\n", 479 | " def __init__(self):\n", 480 | " super().__init__()\n", 481 | " self.upsize1 = upsize(256, 128)\n", 482 | " self.reslike1 = reslike_block(128, num_groups=8)\n", 483 | " self.upsize2 = upsize(128, 64)\n", 484 | " self.reslike2 = reslike_block(64, num_groups=8)\n", 485 | " self.upsize3 = upsize(64, 32)\n", 486 | " self.reslike3 = reslike_block(32, num_groups=8)\n", 487 | " self.conv1 = nn.Conv3d(32, 3, 1) \n", 488 | " self.sigmoid1 = torch.nn.Sigmoid()\n", 489 | "\n", 490 | " def forward(self, x):\n", 491 | " x = self.upsize1(x) # Output size: (1, 128, 40, 48, 32)\n", 492 | " x = x + hooks.stored[2] # Output size: (1, 128, 40, 48, 32)\n", 493 | " x = self.reslike1(x) # Output size: (1, 128, 40, 48, 32)\n", 494 | " x = self.upsize2(x) # Output size: (1, 64, 80, 96, 64)\n", 495 | " x = x + hooks.stored[1] # Output size: (1, 64, 80, 96, 64)\n", 496 | " x = self.reslike2(x) # Output size: (1, 64, 80, 96, 64)\n", 497 | " x = self.upsize3(x) # Output size: (1, 32, 160, 192, 128)\n", 498 | " x = x + hooks.stored[0] # Output size: (1, 32, 160, 192, 128)\n", 499 | " x = self.reslike3(x) # Output size: (1, 32, 160, 192, 128)\n", 500 | " x = self.conv1(x) # Output size: (1, 3, 160, 192, 128)\n", 501 | " x = self.sigmoid1(x) # Output size: (1, 3, 160, 192, 128)\n", 502 | " return x" 503 | ] 504 | }, 505 | { 506 | "cell_type": "code", 507 | "execution_count": 0, 508 | "metadata": { 509 | "code_folding": [ 510 | 0 511 | ], 512 | "colab": {}, 513 | "colab_type": "code", 514 | "id": "54LhlCx7hOt6" 515 | }, 516 | "outputs": [], 517 | "source": [ 518 | "############ Sanity-check ############\n", 519 | "# input = torch.randn(1, 256, 20, 24, 16)\n", 520 | "# input = input.cuda()\n", 521 | "# decoder = Decoder()\n", 522 | "# decoder.cuda()\n", 523 | "# output = decoder(input)\n", 524 | "# 
output.shape" 525 | ] 526 | }, 527 | { 528 | "cell_type": "markdown", 529 | "metadata": { 530 | "colab_type": "text", 531 | "id": "Sq9kLEFbx8sF" 532 | }, 533 | "source": [ 534 | "## VAE Part" 535 | ] 536 | }, 537 | { 538 | "cell_type": "code", 539 | "execution_count": 11, 540 | "metadata": { 541 | "code_folding": [], 542 | "colab": {}, 543 | "colab_type": "code", 544 | "id": "KEpqknq3hUaq" 545 | }, 546 | "outputs": [], 547 | "source": [ 548 | "class VAEEncoder(nn.Module):\n", 549 | "  \"Variational auto-encoder encoder part\"\n", 550 | "  def __init__(self, latent_dim:int=128):\n", 551 | "    super().__init__()\n", 552 | "    self.latent_dim = latent_dim\n", 553 | "    self.conv_block = conv_block(256, 16, 3, num_groups=8, stride=2, padding=1)\n", 554 | "    self.linear1 = nn.Linear(60, 1)\n", 555 | "    \n", 556 | "    # Assumed latent variable's probability density function parameters\n", 557 | "    self.z_mean = nn.Linear(256, latent_dim)\n", 558 | "    self.z_log_var = nn.Linear(256, latent_dim)\n", 559 | "    # Registered as a buffer so it follows the module across devices (CPU or GPU)\n", 560 | "    self.register_buffer('epsilon', torch.randn(1, latent_dim))\n", 561 | "    \n", 562 | "  def forward(self, x):\n", 563 | "    x = self.conv_block(x)              # Output size: (1, 16, 10, 12, 8) \n", 564 | "    x = x.view(256, -1)                 # Output size: (256, 60) \n", 565 | "    x = self.linear1(x)                 # Output size: (256, 1)\n", 566 | "    x = x.view(1, 256)                  # Output size: (1, 256) \n", 567 | "    z_mean = self.z_mean(x)             # Output size: (1, 128)\n", 568 | "    z_var = self.z_log_var(x).exp()     # Output size: (1, 128) \n", 569 | "    \n", 570 | "    return z_mean + z_var * self.epsilon  # Output size: (1, 128) " 571 | ] 572 | }, 573 | { 574 | "cell_type": "code", 575 | "execution_count": 11, 576 | "metadata": { 577 | "code_folding": [ 578 | 0 579 | ], 580 | "colab": { 581 | "base_uri": "https://localhost:8080/", 582 | "height": 34 583 | }, 584 | "colab_type": "code", 585 | "id": "ll26pBm9tj7-", 586 | "outputId": "f1e9300e-8e79-4c66-8d0e-6897ce6b7f80" 587 | }, 588 | "outputs": [], 589 | 
"source": [ 590 | "############ Sanity-check ############\n", 591 | "# input = torch.randn(1, 256, 20, 24, 16)\n", 592 | "# input = input.cuda()\n", 593 | "# vae_encoder = VAEEncoder(latent_dim=128)\n", 594 | "# vae_encoder.cuda()\n", 595 | "# output = vae_encoder(output)\n", 596 | "# output.shape" 597 | ] 598 | }, 599 | { 600 | "cell_type": "code", 601 | "execution_count": 12, 602 | "metadata": { 603 | "code_folding": [ 604 | 0 605 | ], 606 | "colab": {}, 607 | "colab_type": "code", 608 | "id": "tl4tYTaXe1qw" 609 | }, 610 | "outputs": [], 611 | "source": [ 612 | "class VAEDecoder(nn.Module):\n", 613 | " \"Variational auto-encoder decoder part\"\n", 614 | " def __init__(self):\n", 615 | " super().__init__()\n", 616 | " self.linear1 = nn.Linear(128, 256*60)\n", 617 | " self.relu1 = nn.ReLU()\n", 618 | " self.upsize1 = upsize(16, 256)\n", 619 | " self.upsize2 = upsize(256, 128)\n", 620 | " self.reslike1 = reslike_block(128, num_groups=8)\n", 621 | " self.upsize3 = upsize(128, 64)\n", 622 | " self.reslike2 = reslike_block(64, num_groups=8)\n", 623 | " self.upsize4 = upsize(64, 32)\n", 624 | " self.reslike3 = reslike_block(32, num_groups=8)\n", 625 | " self.conv1 = nn.Conv3d(32, 4, 1)\n", 626 | " \n", 627 | " def forward(self, x):\n", 628 | " x = self.linear1(x) # Output size: (1, 256*60) \n", 629 | " x = self.relu1(x) # Output size: (1, 256*60)\n", 630 | " x = x.view(1, 16, 10, 12, 8) # Output size: (1, 16, 10, 12, 8)\n", 631 | " x = self.upsize1(x) # Output size: (1, 256, 20, 24, 16)\n", 632 | " x = self.upsize2(x) # Output size: (1, 128, 40, 48, 32)\n", 633 | " x = self.reslike1(x) # Output size: (1, 128, 40, 48, 32)\n", 634 | " x = self.upsize3(x) # Output size: (1, 64, 80, 96, 64)\n", 635 | " x = self.reslike2(x) # Output size: (1, 64, 80, 96, 64)\n", 636 | " x = self.upsize4(x) # Output size: (1, 32, 160, 192, 128)\n", 637 | " x = self.reslike3(x) # Output size: (1, 32, 160, 192, 128)\n", 638 | " x = self.conv1(x) # Output size: (1, 4, 160, 192, 128) \n", 639 | " 
return x" 640 | ] 641 | }, 642 | { 643 | "cell_type": "code", 644 | "execution_count": 0, 645 | "metadata": { 646 | "code_folding": [ 647 | 0 648 | ], 649 | "colab": {}, 650 | "colab_type": "code", 651 | "id": "RrusoNDpzPOk" 652 | }, 653 | "outputs": [], 654 | "source": [ 655 | "############ Sanity-check ############\n", 656 | "# input = torch.randn(1, 128)\n", 657 | "# input = input.cuda()\n", 658 | "# vae_decoder = VAEDecoder()\n", 659 | "# vae_decoder.cuda()\n", 660 | "# vae_decoder(output).shape" 661 | ] 662 | }, 663 | { 664 | "cell_type": "markdown", 665 | "metadata": { 666 | "colab_type": "text", 667 | "id": "dtLzCKAOEn6c" 668 | }, 669 | "source": [ 670 | "## AutoUNet" 671 | ] 672 | }, 673 | { 674 | "cell_type": "code", 675 | "execution_count": 13, 676 | "metadata": { 677 | "code_folding": [], 678 | "colab": {}, 679 | "colab_type": "code", 680 | "id": "9lhVuR2QExrp" 681 | }, 682 | "outputs": [], 683 | "source": [ 684 | "class AutoUNet(nn.Module):\n", 685 | " \"3D U-Net using autoencoder regularization\"\n", 686 | " def __init__(self):\n", 687 | " super().__init__()\n", 688 | " self.encoder = Encoder()\n", 689 | " self.decoder = Decoder()\n", 690 | " self.vencoder = VAEEncoder(latent_dim=128)\n", 691 | " self.vdecoder = VAEDecoder()\n", 692 | "\n", 693 | " def forward(self, input):\n", 694 | " interm_res = self.encoder(input)\n", 695 | " top_res = self.decoder(interm_res) # Output size: (1, 3, 160, 192, 128)\n", 696 | " bottom_res = self.vdecoder(self.vencoder(interm_res)) # Output size: (1, 4, 160, 192, 128)\n", 697 | " return top_res, bottom_res" 698 | ] 699 | }, 700 | { 701 | "cell_type": "code", 702 | "execution_count": null, 703 | "metadata": { 704 | "code_folding": [], 705 | "scrolled": true 706 | }, 707 | "outputs": [], 708 | "source": [ 709 | "############ Sanity-check ############\n", 710 | "input = torch.randn(1, 4, 160, 192, 128)\n", 711 | "input = input.cuda()\n", 712 | "model = AutoUNet()\n", 713 | "model.cuda()\n", 714 | "\n", 715 | "ms = 
[model.encoder.res_block1, \n", 716 | "      model.encoder.res_block3, \n", 717 | "      model.encoder.res_block5, \n", 718 | "      model.vencoder.z_mean, \n", 719 | "      model.vencoder.z_log_var]\n", 720 | "\n", 721 | "hooks = hook_outputs(ms, detach=False, grad=False) #check: overwrite for each iteration?\n", 722 | "#hooks = Hooks(ms, hook_debug, is_forward=True, detach=False)\n", 723 | "\n", 724 | "output = model(input)" 725 | ] 726 | }, 727 | { 728 | "cell_type": "markdown", 729 | "metadata": { 730 | "colab_type": "text", 731 | "id": "ZSPf7atqhOuG" 732 | }, 733 | "source": [ 734 | "## Custom Loss " 735 | ] 736 | }, 737 | { 738 | "cell_type": "code", 739 | "execution_count": null, 740 | "metadata": { 741 | "code_folding": [], 742 | "colab": { 743 | "base_uri": "https://localhost:8080/", 744 | "height": 85 745 | }, 746 | "colab_type": "code", 747 | "id": "OQ4vfaR-L9Wz", 748 | "outputId": "cd5cb780-4027-4e12-e0de-07485713db38", 749 | "scrolled": false 750 | }, 751 | "outputs": [], 752 | "source": [ 753 | "# Set the global variables\n", 754 | "_, C, H, W, D = input.shape\n", 755 | "c = output[0].shape[1]\n", 756 | "\n", 757 | "print(\"Channels:\", C)\n", 758 | "print(\"Height:\", H)\n", 759 | "print(\"Width:\", W)\n", 760 | "print(\"Depth:\", D)\n", 761 | "print(\"The Number Of Labels:\", c)" 762 | ] 763 | }, 764 | { 765 | "cell_type": "code", 766 | "execution_count": 0, 767 | "metadata": { 768 | "code_folding": [], 769 | "colab": {}, 770 | "colab_type": "code", 771 | "id": "j7cmXkIvhOuI" 772 | }, 773 | "outputs": [], 774 | "source": [ 775 | "class SoftDiceLoss(Module): \n", 776 | "  \"Soft dice loss based on a measure of overlap between prediction and ground truth\"\n", 777 | "  def __init__(self, epsilon=1e-6, c=c):\n", 778 | "    super().__init__()\n", 779 | "    self.epsilon = epsilon\n", 780 | "    self.c = c\n", 781 | "    \n", 782 | "  def forward(self, x:Tensor, y:Tensor):\n", 783 | "    intersection = 2 * ( (x*y).sum() )\n", 784 | "    union = (x**2).sum() 
+ (y**2).sum() \n", 785 | "    return 1 - ( ( intersection / (union + self.epsilon) ) / self.c )" 786 | ] 787 | }, 788 | { 789 | "cell_type": "code", 790 | "execution_count": null, 791 | "metadata": { 792 | "code_folding": [ 793 | 0 794 | ] 795 | }, 796 | "outputs": [], 797 | "source": [ 798 | "####### Sanity-check ############\n", 799 | "# Random stand-in target just to check shapes and backprop\n", "loss = SoftDiceLoss()(output[0], torch.rand_like(output[0]))\n", "print(loss)" 800 | ] 801 | }, 802 | { 803 | "cell_type": "code", 804 | "execution_count": 16, 805 | "metadata": { 806 | "code_folding": [], 807 | "colab": {}, 808 | "colab_type": "code", 809 | "id": "kOjrJ44uhOuK" 810 | }, 811 | "outputs": [], 812 | "source": [ 813 | "class KLDivergence(Module): \n", 814 | "  \"KL divergence between the estimated normal distribution and a prior distribution\"\n", 815 | "  N = H * W * D #hyperparameter check\n", 816 | "\n", 817 | "  def __init__(self):\n", 818 | "    super().__init__()\n", 819 | "    \n", 820 | "  def forward(self, z_mean:Tensor, z_log_var:Tensor):\n", 821 | "    z_var = z_log_var.exp()\n", 822 | "    return (1/self.N) * ( (z_mean**2 + z_var - z_log_var - 1).sum() )" 823 | ] 824 | }, 825 | { 826 | "cell_type": "code", 827 | "execution_count": null, 828 | "metadata": { 829 | "code_folding": [] 830 | }, 831 | "outputs": [], 832 | "source": [ 833 | "####### Sanity-check ############\n", 834 | "loss2 = KLDivergence()(z_mean=hooks.stored[3], z_log_var=hooks.stored[4])\n", 835 | "print(loss2)\n", 836 | "loss2.backward()" 837 | ] 838 | }, 839 | { 840 | "cell_type": "code", 841 | "execution_count": 18, 842 | "metadata": { 843 | "code_folding": [ 844 | 0 845 | ], 846 | "colab": {}, 847 | "colab_type": "code", 848 | "id": "HycYhLrohOuM" 849 | }, 850 | "outputs": [], 851 | "source": [ 852 | "class L2Loss(Module): \n", 853 | "  \"Measuring the squared `Euclidean distance` (`L2 Norm`) between prediction and ground truth\"\n", 854 | "  def __init__(self):\n", 855 | "    super().__init__()\n", 856 | "    \n", 857 | "  def forward(self, x:Tensor, y:Tensor):\n", 858 | "    return ( (x - y)**2 ).sum() " 859 | ] 860 | }, 861 | 862 
| "cell_type": "code", 863 | "execution_count": null, 864 | "metadata": { 865 | "code_folding": [ 866 | 0 867 | ] 868 | }, 869 | "outputs": [], 870 | "source": [ 871 | "####### Sanity-check ############\n", 872 | "loss3 = L2Loss()(output[1], input)\n", 873 | "print(loss3)\n", 874 | "loss3.backward()" 875 | ] 876 | }, 877 | { 878 | "cell_type": "markdown", 879 | "metadata": { 880 | "colab_type": "text", 881 | "id": "MsP_HOw2_6Jd" 882 | }, 883 | "source": [ 884 | "## Optimizer" 885 | ] 886 | }, 887 | { 888 | "cell_type": "code", 889 | "execution_count": 0, 890 | "metadata": { 891 | "colab": {}, 892 | "colab_type": "code", 893 | "id": "XYaFQ6nQ_8O4" 894 | }, 895 | "outputs": [], 896 | "source": [ 897 | "optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)" 898 | ] 899 | }, 900 | { 901 | "cell_type": "markdown", 902 | "metadata": { 903 | "colab_type": "text", 904 | "id": "GPK9Qfc0_tGL" 905 | }, 906 | "source": [ 907 | "## Training" 908 | ] 909 | }, 910 | { 911 | "cell_type": "code", 912 | "execution_count": 0, 913 | "metadata": { 914 | "code_folding": [], 915 | "colab": {}, 916 | "colab_type": "code", 917 | "id": "3YkqxURk_w8K" 918 | }, 919 | "outputs": [], 920 | "source": [ 921 | "for epoch in range(epochs):\n", 922 | "  \n", 923 | "  model.train()\n", 924 | "  for xb,yb in train_dl:\n", 925 | "    top_res, bottom_res = model(xb)\n", 926 | "    top_y, bottom_y = yb, xb\n", 927 | "    z_mean, z_log_var = hooks.stored[3], hooks.stored[4] \n", 928 | "    loss = SoftDiceLoss()(top_res, top_y) + \\\n", 929 | "           (0.1 * KLDivergence()(z_mean, z_log_var)) + \\\n", 930 | "           (0.1 * L2Loss()(bottom_res, bottom_y))\n", 931 | "    loss.backward()\n", 932 | "    optimizer.step()\n", 933 | "    optimizer.zero_grad()\n", 934 | "\n", 935 | "  model.eval()\n", 936 | "  with torch.no_grad():\n", 937 | "    tot_loss, tot_acc = 0., 0.\n", 938 | "    for xb, yb in valid_dl: \n", 939 | "      top_res, bottom_res = model(xb)\n", 940 | "      top_y, bottom_y = yb, xb\n", 941 | "      z_mean, 
z_log_var = hooks.stored[3], hooks.stored[4]\n", 942 | " loss = SoftDiceLoss()(top_res, top_y) + \\\n", 943 | " (0.1 * KLDivergence()(z_mean, z_log_var)) + \\\n", 944 | " (0.1 * L2Loss()(bottom_res, bottom_y)) \n", 945 | " tot_loss += loss\n", 946 | " tot_acc += dice_coeff(top_res, top_y)\n", 947 | "\n", 948 | " nv = len(valid_dl)\n", 949 | " print(tot_loss/nv, tot_acc/nv)" 950 | ] 951 | }, 952 | { 953 | "cell_type": "markdown", 954 | "metadata": { 955 | "colab_type": "text", 956 | "heading_collapsed": true, 957 | "id": "GXaVq0m5hUbO" 958 | }, 959 | "source": [ 960 | "## Memory-check" 961 | ] 962 | }, 963 | { 964 | "cell_type": "code", 965 | "execution_count": 21, 966 | "metadata": { 967 | "colab": {}, 968 | "colab_type": "code", 969 | "hidden": true, 970 | "id": "Xuy-W1NFhUbR", 971 | "outputId": "f9a8cc29-2291-488f-ea8f-44b63cb8bd29" 972 | }, 973 | "outputs": [ 974 | { 975 | "data": { 976 | "text/plain": [ 977 | "9884946432" 978 | ] 979 | }, 980 | "execution_count": 21, 981 | "metadata": {}, 982 | "output_type": "execute_result" 983 | } 984 | ], 985 | "source": [ 986 | "# Memory occupied by PyTorch `Tensors`\n", 987 | "torch.cuda.memory_allocated(device=None)" 988 | ] 989 | }, 990 | { 991 | "cell_type": "code", 992 | "execution_count": 22, 993 | "metadata": { 994 | "colab": { 995 | "base_uri": "https://localhost:8080/", 996 | "height": 697 997 | }, 998 | "colab_type": "code", 999 | "hidden": true, 1000 | "id": "-F1mbF44hUbO", 1001 | "outputId": "6b4ff5a9-766a-48d0-fc2c-bd0675e303e8", 1002 | "scrolled": true 1003 | }, 1004 | "outputs": [ 1005 | { 1006 | "name": "stdout", 1007 | "output_type": "stream", 1008 | "text": [ 1009 | "|===========================================================================|\n", 1010 | "| PyTorch CUDA memory summary, device ID 1 |\n", 1011 | "|---------------------------------------------------------------------------|\n", 1012 | "| CUDA OOMs: 0 | cudaMalloc retries: 0 |\n", 1013 |
"|===========================================================================|\n", 1014 | "| Metric | Cur Usage | Peak Usage | Tot Alloc | Tot Freed |\n", 1015 | "|---------------------------------------------------------------------------|\n", 1016 | "| Allocated memory | 9427 MB | 9855 MB | 10859 MB | 1432 MB |\n", 1017 | "| from large pool | 9423 MB | 9851 MB | 10847 MB | 1424 MB |\n", 1018 | "| from small pool | 3 MB | 3 MB | 11 MB | 8 MB |\n", 1019 | "|---------------------------------------------------------------------------|\n", 1020 | "| Active memory | 9427 MB | 9855 MB | 10859 MB | 1432 MB |\n", 1021 | "| from large pool | 9423 MB | 9851 MB | 10847 MB | 1424 MB |\n", 1022 | "| from small pool | 3 MB | 3 MB | 11 MB | 8 MB |\n", 1023 | "|---------------------------------------------------------------------------|\n", 1024 | "| GPU reserved memory | 9482 MB | 10012 MB | 10012 MB | 542720 KB |\n", 1025 | "| from large pool | 9478 MB | 10008 MB | 10008 MB | 542720 KB |\n", 1026 | "| from small pool | 4 MB | 4 MB | 4 MB | 0 KB |\n", 1027 | "|---------------------------------------------------------------------------|\n", 1028 | "| Non-releasable memory | 56300 KB | 168416 KB | 981 MB | 926 MB |\n", 1029 | "| from large pool | 55744 KB | 167872 KB | 970 MB | 915 MB |\n", 1030 | "| from small pool | 556 KB | 2034 KB | 11 MB | 11 MB |\n", 1031 | "|---------------------------------------------------------------------------|\n", 1032 | "| Allocations | 193 | 265 | 837 | 644 |\n", 1033 | "| from large pool | 76 | 124 | 145 | 69 |\n", 1034 | "| from small pool | 117 | 142 | 692 | 575 |\n", 1035 | "|---------------------------------------------------------------------------|\n", 1036 | "| Active allocs | 193 | 265 | 837 | 644 |\n", 1037 | "| from large pool | 76 | 124 | 145 | 69 |\n", 1038 | "| from small pool | 117 | 142 | 692 | 575 |\n", 1039 | "|---------------------------------------------------------------------------|\n", 1040 | "| GPU reserved segments | 66 | 
91 | 91 | 25 |\n", 1041 | "| from large pool | 64 | 89 | 89 | 25 |\n", 1042 | "| from small pool | 2 | 2 | 2 | 0 |\n", 1043 | "|---------------------------------------------------------------------------|\n", 1044 | "| Non-releasable allocs | 10 | 31 | 100 | 90 |\n", 1045 | "| from large pool | 8 | 29 | 42 | 34 |\n", 1046 | "| from small pool | 2 | 5 | 58 | 56 |\n", 1047 | "|===========================================================================|\n", 1048 | "\n" 1049 | ] 1050 | } 1051 | ], 1052 | "source": [ 1053 | "# Memory status\n", 1054 | "print(torch.cuda.memory_summary(device=None, abbreviated=False))" 1055 | ] 1056 | } 1057 | ], 1058 | "metadata": { 1059 | "accelerator": "GPU", 1060 | "colab": { 1061 | "collapsed_sections": [ 1062 | "b6KgSfpahUZ1", 1063 | "C77ffh0whUZ7", 1064 | "btBiurUPxx7t", 1065 | "MY131WWbx3nN", 1066 | "BKmlf1qY74Fx", 1067 | "Sq9kLEFbx8sF", 1068 | "dtLzCKAOEn6c", 1069 | "ZSPf7atqhOuG", 1070 | "MsP_HOw2_6Jd", 1071 | "GPK9Qfc0_tGL", 1072 | "GXaVq0m5hUbO", 1073 | "v-ODWm3ehUbG" 1074 | ], 1075 | "name": "model_prototype_1.ipynb", 1076 | "provenance": [] 1077 | }, 1078 | "kernelspec": { 1079 | "display_name": "Python 3", 1080 | "language": "python", 1081 | "name": "python3" 1082 | }, 1083 | "language_info": { 1084 | "codemirror_mode": { 1085 | "name": "ipython", 1086 | "version": 3 1087 | }, 1088 | "file_extension": ".py", 1089 | "mimetype": "text/x-python", 1090 | "name": "python", 1091 | "nbconvert_exporter": "python", 1092 | "pygments_lexer": "ipython3", 1093 | "version": "3.6.5" 1094 | } 1095 | }, 1096 | "nbformat": 4, 1097 | "nbformat_minor": 1 1098 | } 1099 | -------------------------------------------------------------------------------- /ndv/modules/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch, fastai, sys, os 2 | from fastai.vision import * 3 | from fastai.vision.data import SegmentationProcessor 4 | import ants 5 | from ants.core.ants_image import ANTsImage 6 | 
from jupyterthemes import jtplot 7 | sys.path.insert(0, './exp') 8 | jtplot.style(theme='gruvboxd') 9 | 10 | # Set a root directory 11 | path = Path('/home/ubuntu/MultiCampus/MICCAI_BraTS_2019_Data_Training') 12 | 13 | def is_mod(fn:str, mod:str)->bool: 14 | "Check if file path contains a specified name of modality used for MRI" 15 | import re 16 | r = re.compile('.*' + mod, re.IGNORECASE) 17 | return bool(r.match(fn)) 18 | 19 | def is_mods(fn:str, mods:Collection[str])->bool: 20 | "Check if file path contains specified names of modality used for MRI" 21 | import re 22 | return any(is_mod(fn, mod) for mod in mods) 23 | 24 | def _path_to_same_str(p_fn): 25 | "path -> str, but same on nt+posix, for alpha-sort only" 26 | s_fn = str(p_fn) 27 | s_fn = s_fn.replace('\\','.') 28 | s_fn = s_fn.replace('/','.') 29 | return s_fn 30 | 31 | def _get_files(path, file, modality): 32 | """ 33 | Internal implementation for `get_files` to combine a parent directory with a file 34 | to make a full path to file(s) 35 | """ 36 | p = Path(path) 37 | res = [p/o for o in file if not o.startswith('.') and is_mods(o, modality)] 38 | assert len(res)==len(modality), f"expected {len(modality)} modality file(s) in {p}, found {len(res)}" 39 | return res 40 | 41 | def get_files(path:PathOrStr, modality:Union[str, Collection[str]], 42 | presort:bool=False)->FilePathList: 43 | "Return a list of full file paths in `path` each of which contains modality in its name" 44 | file = [o.name for o in os.scandir(path) if o.is_file()] 45 | res = _get_files(path, file, modality) 46 | if presort: res = sorted(res, key=lambda p: _path_to_same_str(p), reverse=False) 47 | return res 48 | 49 | def _repr_antsimage(self): 50 | if self.dimension == 3: 51 | s = 'NiftiImage ({})\n'.format(self.orientation) 52 | else: 53 | s = 'NiftiImage\n' 54 | s = s +\ 55 | '\t {:<10} : {} ({})\n'.format('Pixel Type', self.pixeltype, self.dtype)+\ 56 | '\t {:<10} : {}{}\n'.format('Components', self.components, ' (RGB)' if 'RGB' in self._libsuffix else '')+\ 57 | '\t
{:<10} : {}\n'.format('Dimensions', self.shape)+\ 58 | '\t {:<10} : {}\n'.format('Spacing', tuple([round(s,4) for s in self.spacing]))+\ 59 | '\t {:<10} : {}\n'.format('Origin', tuple([round(o,4) for o in self.origin]))+\ 60 | '\t {:<10} : {}\n'.format('Direction', np.round(self.direction.flatten(),4)) 61 | return s 62 | 63 | # Modify the representation of `ANTsImage` object 64 | ANTsImage.__repr__ = _repr_antsimage 65 | 66 | class NiftiImage(ItemBase): 67 | "Support handling NIfTI image format" 68 | #TODO: Extend the code so as to support various Python (medical) libraries that can read NIfTI format 69 | def __init__(self, data:Union[Tensor,np.array], obj:ANTsImage, path:str): 70 | self.data = data 71 | self.obj = obj 72 | self.path = path 73 | # Only works for a specific folder tree 74 | self.mod = self.path.split(".")[0].split("_")[-1] 75 | 76 | def __repr__(self): return str(self.obj) + '\t {:<10} : {}\n\n'.format('Modality', str(self.mod)) 77 | 78 | def __getattr__(self, k:str): 79 | func = getattr(self.obj, k) 80 | if isinstance(func, Callable): return func 81 | 82 | def __setattr__(self, k, v): 83 | if k == 'obj': 84 | self.data = torch.tensor(v.numpy()) 85 | return super().__setattr__(k, v) 86 | 87 | # This wraps ANTsPy's `plot` method to show NIfTI image 88 | def show(self, **kwargs): 89 | ants.plot(self.obj) 90 | 91 | # This wraps ANTsPy's `image_read` method to read NIfTI format 92 | @classmethod 93 | def create(cls, path:PathOrStr): 94 | nimg = ants.image_read(str(path)) 95 | t = torch.tensor(nimg.numpy()) 96 | return cls(t, nimg, path) 97 | 98 | def apply_tfms(self, tfms:List[Transform], *args, order='order', **kwargs): 99 | key = lambda o : getattr(o, order, 0) 100 | for tfm in sorted(listify(tfms), key=key): self = tfm(self, *args, **kwargs) #ascending order eg. 
[3,2,1] -> [1,2,3] 101 | return self 102 | 103 | class MultiNiftiImage(ItemBase): 104 | "Support handling multi-channel NIfTI images" 105 | def __init__(self, obj:Tuple[NiftiImage]): 106 | self.obj = obj # type annotation violated when `subregionify` is used. Should be fixed. 107 | self.data = None 108 | 109 | def __repr__(self): 110 | return f"Inside {self.__class__.__name__}:\n {[self.obj[i] for i in range(len(self.obj))]}" 111 | 112 | def __getitem__(self, i): 113 | return self.obj[i] 114 | 115 | @classmethod 116 | def create(cls, paths:FilePathList): 117 | obj = tuple([NiftiImage.create(str(path)) for path in paths]) 118 | return cls(obj) 119 | 120 | def apply_tfms(self, tfms:List[Transform], *args, order='order', **kwargs): 121 | self.obj = tuple([self.obj[i].apply_tfms(tfms, *args, order=order, **kwargs) for i in range(len(self.obj))]) 122 | self.data = torch.stack([nft.data for nft in self.obj], dim=0) 123 | return self 124 | 125 | @property 126 | def data(self): 127 | return self._data 128 | 129 | @data.setter 130 | def data(self, _): 131 | self._data = ( torch.stack([nft.data for nft in self.obj], dim=0) 132 | if hasattr(self.obj[0], "data") 133 | else torch.stack([torch.tensor(nft.numpy()) for nft in self.obj], dim=0) ) 134 | 135 | class NiftiImageList(ItemList): 136 | 137 | def __repr__(self)->str: 138 | return '{} ({} items)\n{}\nPath: {}'.format(self.__class__.__name__, 139 | len(self.items), show_some(self.items, n_max=4, sep="\n"), 140 | self.path) 141 | def get(self, i)->NiftiImage: 142 | fn = str(self.items[i]) 143 | return NiftiImage.create(fn) 144 | 145 | class MultiNiftiImageList(ItemList): 146 | 147 | def __repr__(self)->str: 148 | return '{} ({} items)\n{}\nPath: {}'.format(self.__class__.__name__, 149 | len(self.items), show_some(self.items, n_max=4, sep="\n"), 150 | self.path) 151 | def get(self, i)->MultiNiftiImage: 152 | filepaths = [str(self.items[i][x]) for x in range(len(self.items[i]))] 153 | return MultiNiftiImage.create(filepaths) 154 |
155 | @classmethod 156 | def from_folder(cls, folderpaths:FilePathList, modality:Union[str, Collection[str]], 157 | presort:bool=False, **kwargs): 158 | """ 159 | This method assumes a list of full paths to the desired files' parent folders 160 | and returns a MultiNiftiImageList whose item is a nested list with each sublist 161 | belonging to its parent folder 162 | ------------------------------------------------------------------------- 163 | Test: 164 | assert len(filepaths) == len(folderpaths) 165 | 166 | """ 167 | filepaths=[] 168 | for fp in folderpaths: 169 | filepath = get_files(fp, modality=modality, presort=True) 170 | filepaths.append(filepath) 171 | 172 | return cls(items=filepaths, path=path, **kwargs) 173 | 174 | hgg_subdirs = (path/'HGG').ls() 175 | lgg_subdirs = (path/'LGG').ls() 176 | parent_folders = hgg_subdirs + lgg_subdirs 177 | 178 | def get_parents(path:Path, pname:str, shuffle:bool=True, pct=0.2): 179 | "List a certain percent of items under a specified parent directory randomly or not" 180 | import random 181 | ps = [d[i] for r,d,_ in os.walk(path) for i in range(len(d)) if Path(r).name==pname] 182 | if shuffle: random.shuffle(ps) 183 | return ps[:round((pct*len(ps)))] 184 | 185 | def write_val_list(fname:str='valid.txt', vals:List[str]=None): 186 | "Write a list of names into `fname` to be used for train/validation split" 187 | val_list = vals 188 | with open(fname, 'w') as f: 189 | f.write('\n'.join(val_list)) 190 | print("{} items written into {}.".format(len(val_list), fname)) 191 | 192 | val_list = get_parents(path, 'HGG', pct=0.15) + get_parents(path, 'LGG', pct=0.1) 193 | write_val_list('valid.txt', val_list) 194 | 195 | def split_by_parents(self, valid_names:'ItemList')->'ItemLists': 196 | "Split the data by using the parent names in `valid_names` for validation."
197 | return self.split_by_valid_func(lambda o: o.parent.name in valid_names) 198 | 199 | def split_by_pname_file(self, fname:PathOrStr, path:PathOrStr=None)->'ItemLists': 200 | "Split the data by using the parent names in `fname` for the validation set. `path` will override `self.path`." 201 | path = Path(ifnone(path, self.path)) 202 | valid_names = loadtxt_str(path/fname) 203 | return self.split_by_parents(valid_names) 204 | 205 | def split_by_valid_func(self, func:Callable)->'ItemLists': 206 | "Split the data by result of `func` (which returns `True` for validation set)." 207 | valid_idx = [i for i,o in enumerate(self.items) if func(o[0])] 208 | return self.split_by_idx(valid_idx) 209 | 210 | def _repr_labellist(self)->str: 211 | items = [self[i] for i in range(min(1,len(self.items)))] 212 | res = f'{self.__class__.__name__} ({len(self.items)} items)\n' 213 | res += f'x: {self.x.__class__.__name__}\n{show_some([i[0] for i in items], n_max=1)}\n' 214 | res += f'y: {self.y.__class__.__name__}\n{show_some([i[1] for i in items], n_max=1)}\n' 215 | return res + f'Path: {self.path}' 216 | 217 | # Modify the methods of `MultiNiftiImageList` object 218 | MultiNiftiImageList.split_by_parents = split_by_parents 219 | MultiNiftiImageList.split_by_pname_file = split_by_pname_file 220 | MultiNiftiImageList.split_by_valid_func = split_by_valid_func 221 | 222 | # Modify the representation of `LabelList` object 223 | LabelList.__repr__ = _repr_labellist 224 | 225 | class NiftiSegmentationLabelList(NiftiImageList): 226 | "`ItemList` for NIfTI segmentation masks" 227 | _processor=SegmentationProcessor 228 | 229 | def __init__(self, items:Iterator, classes:Collection=None, **kwargs): 230 | super().__init__(items, **kwargs) 231 | self.copy_new.append('classes') 232 | self.classes,self.loss_func = classes,None 233 | 234 | def reconstruct(self, t:Tensor): 235 | obj = ants.from_numpy(t.numpy()) 236 | path = self.path 237 | return NiftiImage(t, obj, path) 238 | 239 | get_y_fn = lambda
x: x[0].parent/Path(x[0].as_posix().split(os.sep)[-2]+'_seg.nii.gz') 240 | 241 | subregion = np.array(['WT', 'TC', 'ET']) 242 | 243 | def crop_3d(item:NiftiImage, do_resolve=False, *args, lowerind:Tuple, upperind:Tuple, **kwargs): 244 | "Crop 3-dimensional NIfTI image by slicing indices from lower to upper indices per image axis" 245 | cropped_item = item.obj.crop_indices(lowerind, upperind) 246 | item.obj = cropped_item 247 | return item 248 | 249 | def standardize(item:NiftiImage, do_resolve=False, *args, **kwargs): 250 | "Standardize our custom itembase `NiftiImage` to have zero mean and unit std based on non-zero voxels only" 251 | arr = item.obj.numpy() 252 | arr_nonzero = arr[arr!=0] 253 | arr_nonzero = (arr_nonzero - arr_nonzero.mean()) / arr_nonzero.std() 254 | arr[arr!=0] = arr_nonzero / arr_nonzero.max() 255 | item.obj = ants.from_numpy(arr) 256 | return item 257 | 258 | def subregionify(item:NiftiImage, do_resolve=False, *args, **kwargs): 259 | "Combine the three annotations into 3 nested subregions: Whole Tumor(WT), Tumor Core(TC), Enhancing Tumor(ET)" 260 | arr = item.obj.numpy() 261 | wt_arr = arr.copy() 262 | wt_arr[wt_arr==1.] = 1.; wt_arr[wt_arr==2.] = 1.; wt_arr[wt_arr==4.] = 1. 263 | tc_arr = arr.copy() 264 | tc_arr[tc_arr==1.] = 1.; tc_arr[tc_arr==2.] = 0.; tc_arr[tc_arr==4.] = 1. 265 | et_arr = arr.copy() 266 | et_arr[et_arr==1.] = 0.; et_arr[et_arr==2.] = 0.; et_arr[et_arr==4.] = 1. 
267 | return MultiNiftiImage(tuple(ants.from_numpy(arr) for arr in [wt_arr, tc_arr, et_arr])) 268 | 269 | crop_3d = Transform(crop_3d, order=0) # Applied to 'x' first then `y` for an implementation detail with overwrite 270 | standardize = Transform(standardize, order=1) # Only applied to 'x' 271 | subregionify = Transform(subregionify, order=1) # Only applied to 'y' 272 | 273 | x_transform = [crop_3d, standardize] 274 | y_transform = [crop_3d, subregionify] 275 | 276 | data = (MultiNiftiImageList.from_folder(parent_folders, modality=['Flair', 'T1', 'T2', 'T1ce']) 277 | .split_by_pname_file(fname='valid.txt', path=Path('.')) 278 | .label_from_func(get_y_fn, classes=subregion, label_cls=NiftiSegmentationLabelList) 279 | .transform((x_transform, x_transform), tfm_y=False, lowerind=(40,28,10), upperind=(200,220,138)) 280 | .transform_y((y_transform, y_transform), lowerind=(40,28,10), upperind=(200,220,138)) 281 | .databunch(bs=1, collate_fn=data_collate, num_workers=0)) -------------------------------------------------------------------------------- /ndv/modules/model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """01_model.ipynb 3 | 4 | Automatically generated by Colaboratory.
5 | 6 | Original file is located at 7 | https://colab.research.google.com/drive/1OWXPL8K-jKC4KgGmYXXkeBWOL2V1biuf 8 | 9 | # Setup 10 | """ 11 | 12 | import torch 13 | from fastai.callbacks import * 14 | from fastai.vision import * 15 | 16 | H = 160 17 | W = 192 18 | D = 128 19 | 20 | def conv_block(c_in, c_out, ks, num_groups=None, **conv_kwargs): 21 | "A sequence of modules composed of Group Norm, ReLU and Conv3d in order" 22 | if not num_groups : num_groups = int(c_in/2) if c_in%2 == 0 else None 23 | return nn.Sequential(nn.GroupNorm(num_groups, c_in), 24 | nn.ReLU(), 25 | nn.Conv3d(c_in, c_out, ks, **conv_kwargs)) 26 | 27 | def reslike_block(nf, num_groups=None, bottle_neck:bool=False, **conv_kwargs): 28 | "A ResNet-like block with the GroupNorm normalization providing optional bottle-neck functionality" 29 | nf_inner = nf // 2 if bottle_neck else nf 30 | return SequentialEx(conv_block(num_groups=num_groups, c_in=nf, c_out=nf_inner, ks=3, stride=1, padding=1, **conv_kwargs), 31 | conv_block(num_groups=num_groups, c_in=nf_inner, c_out=nf, ks=3, stride=1, padding=1, **conv_kwargs), 32 | MergeLayer()) 33 | 34 | def upsize(c_in, c_out, ks=1, scale=2): 35 | "Reduce the number of features by 2 using Conv with kernel size 1x1x1 and double the spatial dimension using 3D trilinear upsampling" 36 | return nn.Sequential(nn.Conv3d(c_in, c_out, ks), 37 | nn.Upsample(scale_factor=scale, mode='trilinear', align_corners=True)) 38 | 39 | def hook_debug(module, input, output): 40 | """ 41 | Print out what's been hooked, usually for debugging purposes 42 | ---------------------------------------------------------- 43 | Example: 44 | Hooks(ms, hook_debug, is_forward=True, detach=False) 45 | 46 | """ 47 | print('Hooking ' + module.__class__.__name__) 48 | print('output size:', output.data.size()) 49 | return output 50 | 51 | class Encoder(nn.Module): 52 | "Encoder part" 53 | def __init__(self): 54 | super().__init__() 55 | self.conv1 = nn.Conv3d(4, 32, 3, stride=1, padding=1) 56 |
self.res_block1 = reslike_block(32, num_groups=8) 57 | self.conv_block1 = conv_block(32, 64, 3, num_groups=8, stride=2, padding=1) 58 | self.res_block2 = reslike_block(64, num_groups=8) 59 | self.conv_block2 = conv_block(64, 64, 3, num_groups=8, stride=1, padding=1) 60 | self.res_block3 = reslike_block(64, num_groups=8) 61 | self.conv_block3 = conv_block(64, 128, 3, num_groups=8, stride=2, padding=1) 62 | self.res_block4 = reslike_block(128, num_groups=8) 63 | self.conv_block4 = conv_block(128, 128, 3, num_groups=8, stride=1, padding=1) 64 | self.res_block5 = reslike_block(128, num_groups=8) 65 | self.conv_block5 = conv_block(128, 256, 3, num_groups=8, stride=2, padding=1) 66 | self.res_block6 = reslike_block(256, num_groups=8) 67 | self.conv_block6 = conv_block(256, 256, 3, num_groups=8, stride=1, padding=1) 68 | self.res_block7 = reslike_block(256, num_groups=8) 69 | self.conv_block7 = conv_block(256, 256, 3, num_groups=8, stride=1, padding=1) 70 | self.res_block8 = reslike_block(256, num_groups=8) 71 | self.conv_block8 = conv_block(256, 256, 3, num_groups=8, stride=1, padding=1) 72 | self.res_block9 = reslike_block(256, num_groups=8) 73 | 74 | def forward(self, x): 75 | x = self.conv1(x) # Output size: (1, 32, 160, 192, 128) 76 | x = self.res_block1(x) # Output size: (1, 32, 160, 192, 128) 77 | x = self.conv_block1(x) # Output size: (1, 64, 80, 96, 64) 78 | x = self.res_block2(x) # Output size: (1, 64, 80, 96, 64) 79 | x = self.conv_block2(x) # Output size: (1, 64, 80, 96, 64) 80 | x = self.res_block3(x) # Output size: (1, 64, 80, 96, 64) 81 | x = self.conv_block3(x) # Output size: (1, 128, 40, 48, 32) 82 | x = self.res_block4(x) # Output size: (1, 128, 40, 48, 32) 83 | x = self.conv_block4(x) # Output size: (1, 128, 40, 48, 32) 84 | x = self.res_block5(x) # Output size: (1, 128, 40, 48, 32) 85 | x = self.conv_block5(x) # Output size: (1, 256, 20, 24, 16) 86 | x = self.res_block6(x) # Output size: (1, 256, 20, 24, 16) 87 | x = self.conv_block6(x) # Output size: 
(1, 256, 20, 24, 16) 88 | x = self.res_block7(x) # Output size: (1, 256, 20, 24, 16) 89 | x = self.conv_block7(x) # Output size: (1, 256, 20, 24, 16) 90 | x = self.res_block8(x) # Output size: (1, 256, 20, 24, 16) 91 | x = self.conv_block8(x) # Output size: (1, 256, 20, 24, 16) 92 | x = self.res_block9(x) # Output size: (1, 256, 20, 24, 16) 93 | return x 94 | 95 | class Decoder(nn.Module): 96 | "Decoder Part" 97 | def __init__(self): 98 | super().__init__() 99 | self.upsize1 = upsize(256, 128) 100 | self.reslike1 = reslike_block(128, num_groups=8) 101 | self.upsize2 = upsize(128, 64) 102 | self.reslike2 = reslike_block(64, num_groups=8) 103 | self.upsize3 = upsize(64, 32) 104 | self.reslike3 = reslike_block(32, num_groups=8) 105 | self.conv1 = nn.Conv3d(32, 3, 1) 106 | self.sigmoid1 = torch.nn.Sigmoid() 107 | 108 | def forward(self, x): 109 | x = self.upsize1(x) # Output size: (1, 128, 40, 48, 32) 110 | x = x + hooks.stored[2] # Output size: (1, 128, 40, 48, 32) 111 | x = self.reslike1(x) # Output size: (1, 128, 40, 48, 32) 112 | x = self.upsize2(x) # Output size: (1, 64, 80, 96, 64) 113 | x = x + hooks.stored[1] # Output size: (1, 64, 80, 96, 64) 114 | x = self.reslike2(x) # Output size: (1, 64, 80, 96, 64) 115 | x = self.upsize3(x) # Output size: (1, 32, 160, 192, 128) 116 | x = x + hooks.stored[0] # Output size: (1, 32, 160, 192, 128) 117 | x = self.reslike3(x) # Output size: (1, 32, 160, 192, 128) 118 | x = self.conv1(x) # Output size: (1, 3, 160, 192, 128) 119 | x = self.sigmoid1(x) # Output size: (1, 3, 160, 192, 128) 120 | return x 121 | 122 | class VAEEncoder(nn.Module): 123 | "Variational auto-encoder encoder part" 124 | def __init__(self, latent_dim:int=128): 125 | super().__init__() 126 | self.latent_dim = latent_dim 127 | self.conv_block = conv_block(256, 16, 3, num_groups=8, stride=2, padding=1) 128 | self.linear1 = nn.Linear(60, 1) 129 | 130 | # Assumed latent variable's probability density function parameters 131 | self.z_mean = nn.Linear(256, 
latent_dim) 132 | self.z_log_var = nn.Linear(256, latent_dim) 133 | # epsilon is drawn fresh in `forward` (reparameterization trick), not stored as a weight 134 | 135 | def forward(self, x): 136 | x = self.conv_block(x) # Output size: (1, 16, 10, 12, 8) 137 | x = x.view(256, -1) # Output size: (256, 60) 138 | x = self.linear1(x) # Output size: (256, 1) 139 | x = x.view(1, 256) # Output size: (1, 256) 140 | z_mean = self.z_mean(x) # Output size: (1, 128) 141 | z_std = (0.5 * self.z_log_var(x)).exp() # sigma, with z_log_var = log(sigma^2); Output size: (1, 128) 142 | return z_mean + z_std * torch.randn_like(z_std) # Output size: (1, 128) 143 | 144 | class VAEDecoder(nn.Module): 145 | "Variational auto-encoder decoder part" 146 | def __init__(self): 147 | super().__init__() 148 | self.linear1 = nn.Linear(128, 256*60) 149 | self.relu1 = nn.ReLU() 150 | self.upsize1 = upsize(16, 256) 151 | self.upsize2 = upsize(256, 128) 152 | self.reslike1 = reslike_block(128, num_groups=8) 153 | self.upsize3 = upsize(128, 64) 154 | self.reslike2 = reslike_block(64, num_groups=8) 155 | self.upsize4 = upsize(64, 32) 156 | self.reslike3 = reslike_block(32, num_groups=8) 157 | self.conv1 = nn.Conv3d(32, 4, 1) 158 | 159 | def forward(self, x): 160 | x = self.linear1(x) # Output size: (1, 256*60) 161 | x = self.relu1(x) # Output size: (1, 256*60) 162 | x = x.view(1, 16, 10, 12, 8) # Output size: (1, 16, 10, 12, 8) 163 | x = self.upsize1(x) # Output size: (1, 256, 20, 24, 16) 164 | x = self.upsize2(x) # Output size: (1, 128, 40, 48, 32) 165 | x = self.reslike1(x) # Output size: (1, 128, 40, 48, 32) 166 | x = self.upsize3(x) # Output size: (1, 64, 80, 96, 64) 167 | x = self.reslike2(x) # Output size: (1, 64, 80, 96, 64) 168 | x = self.upsize4(x) # Output size: (1, 32, 160, 192, 128) 169 | x = self.reslike3(x) # Output size: (1, 32, 160, 192, 128) 170 | x = self.conv1(x) # Output size: (1, 4, 160, 192, 128) 171 | return x 172 | 173 | class AutoUNet(nn.Module): 174 | "3D U-Net using autoencoder regularization" 175 | def __init__(self): 176 | super().__init__() 177 | self.encoder =
Encoder() 178 | self.decoder = Decoder() 179 | self.vencoder = VAEEncoder(latent_dim=128) 180 | self.vdecoder = VAEDecoder() 181 | 182 | def forward(self, input): 183 | interm_res = self.encoder(input) 184 | top_res = self.decoder(interm_res) # Output size: (1, 3, 160, 192, 128) 185 | bottom_res = self.vdecoder(self.vencoder(interm_res)) # Output size: (1, 4, 160, 192, 128) 186 | return top_res, bottom_res 187 | 188 | class SoftDiceLoss(nn.Module): 189 | "Soft dice loss based on a measure of overlap between prediction and ground truth" 190 | def __init__(self, epsilon=1e-6, c=3): 191 | super().__init__() 192 | self.epsilon = epsilon 193 | self.c = c 194 | 195 | def forward(self, x:Tensor, y:Tensor): 196 | intersection = 2 * ( (x*y).sum() ) 197 | union = (x**2).sum() + (y**2).sum() 198 | return 1 - ( ( intersection / (union + self.epsilon) ) / self.c ) 199 | 200 | class KLDivergence(nn.Module): 201 | "KL divergence between the estimated normal distribution and a prior distribution" 202 | N = H * W * D #hyperparameter check 203 | 204 | def __init__(self): 205 | super().__init__() 206 | 207 | def forward(self, z_mean:Tensor, z_log_var:Tensor): 208 | z_var = z_log_var.exp() 209 | return (1/self.N) * ( (z_mean**2 + z_var - z_log_var - 1).sum() ) 210 | 211 | class L2Loss(nn.Module): 212 | "Measuring the `Euclidean distance` between prediction and ground truth using the `L2 Norm`" 213 | def __init__(self): 214 | super().__init__() 215 | 216 | def forward(self, x:Tensor, y:Tensor): 217 | return ( (x - y)**2 ).sum() 218 | 219 | autounet = AutoUNet() 220 | ms = [autounet.encoder.res_block1, 221 | autounet.encoder.res_block3, 222 | autounet.encoder.res_block5, 223 | autounet.vencoder.z_mean, 224 | autounet.vencoder.z_log_var] 225 | hooks = hook_outputs(ms, detach=False, grad=False) 226 | 227 | lr = 1e-4 228 | optimizer = optim.Adam(autounet.parameters(), lr) 229 | -------------------------------------------------------------------------------- /ndv/training-1cycle.ipynb:
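Aside (not a file in this repository): the `SoftDiceLoss` defined in `model.py` above folds the channel average into a single global sum and then divides the Dice ratio by `c`, so even a perfect prediction cannot reach a loss of 0. A toy-scale sketch of the same formula makes its effective range visible (the shapes here are illustrative only; the real model sees `(1, 3, 160, 192, 128)` volumes):

```python
import torch

# Standalone re-statement of the SoftDiceLoss formula from model.py:
# loss = 1 - (2*sum(x*y) / (sum(x^2) + sum(y^2) + eps)) / c
def soft_dice_loss(x, y, epsilon=1e-6, c=3):
    intersection = 2 * (x * y).sum()
    union = (x ** 2).sum() + (y ** 2).sum()
    return 1 - ((intersection / (union + epsilon)) / c)

torch.manual_seed(0)
target = (torch.rand(1, 3, 8, 8, 8) > 0.5).float()

# A perfect binary prediction drives the Dice ratio to 1, so the loss
# bottoms out near 1 - 1/c = 2/3 rather than 0.
print(float(soft_dice_loss(target, target)))       # ~0.6667

# A fully disjoint prediction has zero intersection, so the loss tends to 1.
print(float(soft_dice_loss(1.0 - target, target))) # 1.0
```

In other words the usable range of this loss is roughly `[1 - 1/c, 1]`, which is worth keeping in mind when reading the training-loss curves produced by the notebooks below.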
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "47 items written into valid.txt.\n" 13 | ] 14 | } 15 | ], 16 | "source": [ 17 | "import torch, fastai, sys, os\n", 18 | "from fastai.vision import *\n", 19 | "import ants\n", 20 | "from ants.core.ants_image import ANTsImage\n", 21 | "from jupyterthemes import jtplot\n", 22 | "sys.path.insert(0, './exp')\n", 23 | "jtplot.style(theme='gruvboxd')\n", 24 | "\n", 25 | "import model\n", 26 | "from model import SoftDiceLoss, KLDivergence, L2Loss\n", 27 | "import dataloader \n", 28 | "from dataloader import data" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 2, 34 | "metadata": {}, 35 | "outputs": [ 36 | { 37 | "data": { 38 | "text/plain": [ 39 | "1" 40 | ] 41 | }, 42 | "execution_count": 2, 43 | "metadata": {}, 44 | "output_type": "execute_result" 45 | } 46 | ], 47 | "source": [ 48 | "torch.cuda.set_device(1)\n", 49 | "torch.cuda.current_device()" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 3, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "autounet = model.autounet.cuda()\n", 59 | "sdl = SoftDiceLoss()\n", 60 | "kld = KLDivergence()\n", 61 | "l2l = L2Loss()" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 4, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "class AutoUNetCallback(LearnerCallback):\n", 71 | " \"Custom callback for implementing `AutoUNet` training loop\"\n", 72 | " _order=0\n", 73 | " \n", 74 | " def __init__(self, learn:Learner):\n", 75 | " super().__init__(learn)\n", 76 | " \n", 77 | " def on_batch_begin(self, last_input:Tensor, last_target:Tensor, **kwargs):\n", 78 | " \"Store the states to be later used to calculate the loss\"\n", 79 | " self.top_y, self.bottom_y = 
last_target.data, last_input.data\n", 80 | " \n", 81 | " def on_loss_begin(self, last_output:Tuple[Tensor,Tensor], **kwargs):\n", 82 | " \"Store the states to be later used to calculate the loss\"\n", 83 | " self.top_res, self.bottom_res = last_output\n", 84 | " self.z_mean, self.z_log_var = model.hooks.stored[3], model.hooks.stored[4]\n", 85 | " return {'last_output': (self.top_res, self.bottom_res,\n", 86 | " self.z_mean, self.z_log_var,\n", 87 | " self.top_y, self.bottom_y)}" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 5, 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "class AutoUNetLoss(nn.Module):\n", 97 | " \"Combining all the loss functions defined for `AutoUNet`\"\n", 98 | " def __init__(self):\n", 99 | " super().__init__()\n", 100 | " \n", 101 | " def forward(self, top_res, bottom_res, z_mean, z_log_var, top_y, bottom_y):\n", 102 | " return sdl(top_res, top_y) + (0.1 * kld(z_mean, z_log_var)) + (0.1 * l2l(bottom_res, bottom_y))" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 6, 108 | "metadata": { 109 | "code_folding": [ 110 | 1 111 | ] 112 | }, 113 | "outputs": [], 114 | "source": [ 115 | "#monkey-patch\n", 116 | "def mp_loss_batch(model:nn.Module, xb:Tensor, yb:Tensor, loss_func:OptLossFunc=None, opt:OptOptimizer=None,\n", 117 | " cb_handler:Optional[CallbackHandler]=None)->Tuple[Union[Tensor,int,float,str]]:\n", 118 | " \"Calculate loss and metrics for a batch, call out to callbacks as necessary.\"\n", 119 | " cb_handler = ifnone(cb_handler, CallbackHandler())\n", 120 | " if not is_listy(xb): xb = [xb]\n", 121 | " if not is_listy(yb): yb = [yb]\n", 122 | " out = model(*xb)\n", 123 | " out = cb_handler.on_loss_begin(out)\n", 124 | "\n", 125 | " if not loss_func: return to_detach(out), to_detach(yb[0])\n", 126 | " loss = loss_func(*out) #modified\n", 127 | "\n", 128 | " if opt is not None:\n", 129 | " loss,skip_bwd = cb_handler.on_backward_begin(loss)\n", 130 | " if not skip_bwd:
loss.backward()\n", 131 | " if not cb_handler.on_backward_end(): opt.step()\n", 132 | " if not cb_handler.on_step_end(): opt.zero_grad()\n", 133 | "\n", 134 | " return loss.detach().cpu()" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 7, 140 | "metadata": { 141 | "code_folding": [ 142 | 1 143 | ] 144 | }, 145 | "outputs": [], 146 | "source": [ 147 | "#monkey-patch\n", 148 | "def mp_fit(epochs:int, learn:Learner, callbacks:Optional[CallbackList]=None, metrics:OptMetrics=None)->None:\n", 149 | " \"Fit the `model` on `data` and learn using `loss_func` and `opt`.\"\n", 150 | " assert len(learn.data.train_dl) != 0, f\"\"\"Your training dataloader is empty, can't train a model.\n", 151 | " Use a smaller batch size (batch size={learn.data.train_dl.batch_size} for {len(learn.data.train_dl.dataset)} elements).\"\"\"\n", 152 | " cb_handler = CallbackHandler(callbacks, metrics)\n", 153 | " pbar = master_bar(range(epochs))\n", 154 | " cb_handler.on_train_begin(epochs, pbar=pbar, metrics=metrics)\n", 155 | "\n", 156 | " exception=False\n", 157 | " try:\n", 158 | " for epoch in pbar:\n", 159 | " learn.model.train()\n", 160 | " cb_handler.set_dl(learn.data.train_dl)\n", 161 | " cb_handler.on_epoch_begin()\n", 162 | " for xb,yb in progress_bar(learn.data.train_dl, parent=pbar):\n", 163 | " xb, yb = cb_handler.on_batch_begin(xb, yb)\n", 164 | " loss = loss_batch(learn.model, xb, yb, learn.loss_func, learn.opt, cb_handler) #modified\n", 165 | " if cb_handler.on_batch_end(loss): break\n", 166 | "\n", 167 | " if not cb_handler.skip_validate and not learn.data.empty_val:\n", 168 | " val_loss = validate(learn.model, learn.data.valid_dl, loss_func=learn.loss_func,\n", 169 | " cb_handler=cb_handler, pbar=pbar)\n", 170 | " else: val_loss=None\n", 171 | " if cb_handler.on_epoch_end(val_loss): break\n", 172 | " except Exception as e:\n", 173 | " exception = e\n", 174 | " raise\n", 175 | " finally: cb_handler.on_train_end(exception)\n" 176 | ] 177 | }, 178 | { 
179 | "cell_type": "code", 180 | "execution_count": 8, 181 | "metadata": { 182 | "code_folding": [ 183 | 1 184 | ] 185 | }, 186 | "outputs": [], 187 | "source": [ 188 | "#monkey-patch\n", 189 | "def mp_learner_fit(self, epochs:int, lr:Union[Floats,slice]=defaults.lr,\n", 190 | " wd:Floats=None, callbacks:Collection[Callback]=None)->None:\n", 191 | " \"Fit the model on this learner with `lr` learning rate, `wd` weight decay for `epochs` with `callbacks`.\"\n", 192 | " lr = self.lr_range(lr)\n", 193 | " if wd is None: wd = self.wd\n", 194 | " if not getattr(self, 'opt', False): self.create_opt(lr, wd)\n", 195 | " else: self.opt.lr,self.opt.wd = lr,wd\n", 196 | " callbacks = [cb(self) for cb in self.callback_fns + listify(defaults.extra_callback_fns)] + listify(callbacks)\n", 197 | " fit(epochs, self, metrics=self.metrics, callbacks=self.callbacks+callbacks)" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": 9, 203 | "metadata": { 204 | "code_folding": [ 205 | 1 206 | ] 207 | }, 208 | "outputs": [], 209 | "source": [ 210 | "#monkey-patch\n", 211 | "def mp_validate(model:nn.Module, dl:DataLoader, loss_func:OptLossFunc=None, cb_handler:Optional[CallbackHandler]=None,\n", 212 | " pbar:Optional[PBar]=None, average=True, n_batch:Optional[int]=None)->Iterator[Tuple[Union[Tensor,int],...]]:\n", 213 | " \"Calculate `loss_func` of `model` on `dl` in evaluation mode.\"\n", 214 | " model.eval()\n", 215 | " with torch.no_grad():\n", 216 | " val_losses,nums = [],[]\n", 217 | " if cb_handler: cb_handler.set_dl(dl)\n", 218 | " for xb,yb in progress_bar(dl, parent=pbar, leave=(pbar is not None)):\n", 219 | " if cb_handler: xb, yb = cb_handler.on_batch_begin(xb, yb, train=False)\n", 220 | " val_loss = loss_batch(model, xb, yb, loss_func, cb_handler=cb_handler) #modified\n", 221 | " val_losses.append(val_loss)\n", 222 | " if not is_listy(yb): yb = [yb]\n", 223 | " nums.append(first_el(yb).shape[0])\n", 224 | " if cb_handler and
cb_handler.on_batch_end(val_losses[-1]): break\n", 225 | " if n_batch and (len(nums)>=n_batch): break\n", 226 | " nums = np.array(nums, dtype=np.float32)\n", 227 | " if average: return (to_np(torch.stack(val_losses)) * nums).sum() / nums.sum()\n", 228 | " else: return val_losses" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": 10, 234 | "metadata": { 235 | "code_folding": [ 236 | 1 237 | ] 238 | }, 239 | "outputs": [], 240 | "source": [ 241 | "#monkey-patch\n", 242 | "def mp_learner_validate(self, dl=None, callbacks=None, metrics=None):\n", 243 | " \"Validate on `dl` with potential `callbacks` and `metrics`.\"\n", 244 | " dl = ifnone(dl, self.data.valid_dl)\n", 245 | " metrics = ifnone(metrics, self.metrics)\n", 246 | " cb_handler = CallbackHandler(self.callbacks + ifnone(callbacks, []), metrics)\n", 247 | " cb_handler.on_train_begin(1, None, metrics); cb_handler.on_epoch_begin()\n", 248 | " val_metrics = validate(self.model, dl, self.loss_func, cb_handler)\n", 249 | " cb_handler.on_epoch_end(val_metrics)\n", 250 | " return cb_handler.state_dict['last_metrics']" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": 11, 256 | "metadata": {}, 257 | "outputs": [], 258 | "source": [ 259 | "from fastai.basic_train import loss_batch, fit, validate" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": 12, 265 | "metadata": {}, 266 | "outputs": [], 267 | "source": [ 268 | "loss_batch = mp_loss_batch\n", 269 | "fit = mp_fit\n", 270 | "validate = mp_validate\n", 271 | "Learner.fit = mp_learner_fit\n", 272 | "Learner.validate = mp_learner_validate" 273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": 13, 278 | "metadata": {}, 279 | "outputs": [], 280 | "source": [ 281 | "def dice_coefficient(last_output:Tensor, last_target:Tensor):\n", 282 | " \"Metric based on dice coefficient\"\n", 283 | " pred, targ = last_output[0], last_target\n", 284 | " return 2 * (pred * targ).sum() / 
((pred**2).sum() + (targ**2).sum())" 285 | ] 286 | }, 287 | { 288 | "cell_type": "code", 289 | "execution_count": 14, 290 | "metadata": {}, 291 | "outputs": [], 292 | "source": [ 293 | "auto_unet_loss = AutoUNetLoss()" 294 | ] 295 | }, 296 | { 297 | "cell_type": "code", 298 | "execution_count": 15, 299 | "metadata": {}, 300 | "outputs": [], 301 | "source": [ 302 | "learner = Learner(data, autounet, loss_func=auto_unet_loss)" 303 | ] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "execution_count": 16, 308 | "metadata": {}, 309 | "outputs": [], 310 | "source": [ 311 | "autounet_cb = AutoUNetCallback(learner)" 312 | ] 313 | }, 314 | { 315 | "cell_type": "code", 316 | "execution_count": 17, 317 | "metadata": {}, 318 | "outputs": [], 319 | "source": [ 320 | "learner.callbacks.append(autounet_cb)" 321 | ] 322 | }, 323 | { 324 | "cell_type": "code", 325 | "execution_count": 17, 326 | "metadata": {}, 327 | "outputs": [ 328 | { 329 | "data": { 330 | "text/html": [ 331 | "\n", 332 | "
\n", 333 | " \n", 345 | " \n", 346 | " 0.00% [0/1 00:00<00:00]\n", 347 | "
\n", 348 | " \n", 349 | "\n", 350 | " \n", 351 | " \n", 352 | " \n", 353 | " \n", 354 | " \n", 355 | " \n", 356 | " \n", 357 | " \n", 358 | " \n", 359 | " \n", 360 | "
epochtrain_lossvalid_losstime

\n", 361 | "\n", 362 | "

\n", 363 | " \n", 375 | " \n", 376 | " 20.14% [58/288 05:56<23:34 176342.1875]\n", 377 | "
\n", 378 | " " 379 | ], 380 | "text/plain": [ 381 | "" 382 | ] 383 | }, 384 | "metadata": {}, 385 | "output_type": "display_data" 386 | }, 387 | { 388 | "name": "stdout", 389 | "output_type": "stream", 390 | "text": [ 391 | "LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n" 392 | ] 393 | } 394 | ], 395 | "source": [ 396 | "learner.lr_find()" 397 | ] 398 | }, 399 | { 400 | "cell_type": "code", 401 | "execution_count": 18, 402 | "metadata": {}, 403 | "outputs": [ 404 | { 405 | "data": { 406 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaQAAAEMCAYAAACLA8K2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3XmYFNW5+PFvb9PLMMywqCiCKIJJqSyiiJQLZru52ZOb3OQmsRLjFqNEgyGJSuISRNSIid7kl7jEpIxZ1GuMWTQuiVsBiiKIFKtsA8g6zML03l2/P6ownZEZGKarq7rn/TxPPdNTp7rr7UMzb59Tp84JWJaFEEII4bWg1wEIIYQQIAlJCCGET0hCEkII4QuSkIQQQviCJCQhhBC+EPY6gGplGnoAOBpo9zoWIYSoIgOBzYqqvWuItySkQ3c0sMnrIIQQogqNBJq77pSEdOjaAS6e+b+k0llPAggAgxob2NPWgdxNVn5Sv+6TOnaX3+o3Hqvj7tsuh256liQh9VEqnSWZynhy7gAQj0ZJpjK++LDVGqlf90kdu6va6lcGNQghhPAFSUhCCCF8QRKSEEIIX6joNSTT0C8DvgKMAxYqqjatpOxG4JOAAjyoqNpXuzz3SOAe4FygFfiRomp3VKpcCCGEuyrdQnobmAvs7w/9OuAa4LfdPPdBYCdwBPAx4FrT0D9RwXIhhBAuqmgLSVG1RwFMQx+5n7JfOWUfBJpKy0xDPw44BzhKUbW9wOumod8DXAA87nZ5ueuBUB0UvBkqLoQQflUtw77HAVsVVdtesm8x8IUKlXcr4Gy9EZtwCVY+RWbZr6GY6+Wz//3cpT9FeUn9uk/q2F1+q98DxVEtCakBaOuyr9XZX4nybg1qbCAejR7osH9T3PoE+fdcSN2Zs4isvI9AtrVXz+9qcNPAPj1f9Ezq131Sx+7yS/3GYnU9lldLQuoAGrvsa3L2V6K8W3vaOnp/Y2xrO+y8gdiEiymcPIP04p9S3LOmd6+B/W1jcNNAWlrbq+Kmt2oj9es+qWN3+a1+E/Gev7xXy7DvN4CjTEM/vGTfRGBZhcq7ZR3qlk+RevVOcs0vED99JuGR0w7pdfoUg2xSvz7YpI77V/32pNLDvsPOOcNA0DT0GFBUVC1rGnoECO3bnLKComo5RdXWmYb+PHCzaejfBMYAF2EPOsDtcvdYZFc/SrF9E7FxFxAceAyZ5b8Bq+DuaYUQwocq3UKaBaSA24CznMdPOWX3OL9fDnzZeXxPyXO/BAwDdgB/A+YoqvZ4Bctdk9/2KskFNxEeqpA4/TsE6vzR3yuEEJUUsKwDNaLE/piG3gi0njd9XvkmV43UE594KcHE4SQXzcPq3Nbj4QFgSNNAdvukf7jWSP26T+rYXX6r30Q8ygN3zQBoUlSt60CyqrmG1D/kOkktuoPCntUkzriaYOOx
XkckhBAVIwnJb6wC6aX3kd9skDj9O4QOO9nriIQQoiIkIfmSRWblQ2TWPEZ80nTCw6d6HZAQQriuWu5D6pdy6/+OlWkjNu5rZOoGklv/pNchCSGEayQh+Vx+60JS2b3ET/kGwVgjmRUPceDR/EIIUX2ky64KFHa9SfLlWwkfNZXY+AshEPI6JCGEKDtJSFWi2LaB5MI5hAaPJT7pcghGvA5JCCHKShJSFbE6t5NcMIdA4nDip82AcNzrkIQQomwkIVUZK72H1MK5BMJR4qfPxArXex2SEEKUhSSkKmRlO0i+fBvkM+RO+iaB2CCvQxJCiD6ThFSt8ilSi+YRSO8ifsbVBOqP8DoiIYToE0lI1ayYI7zqPgota0hMuZrgwHetDC+EEFVDElKVC1hFMkvvJb/tVRKnzyTYOMrrkIQQ4pBIQqoJFpnlvyG3eT6Jyd8m2Hic1wEJIUSvSUKqIZkVvyPX/CKJyTMINo32OhwhhOgVSUg1JrPyD+Q2PU/iNElKQojqIgmpBmVWPUx20z9ITL6K0KAxXocjhBAHRRJSjcqu+j+yG54mftq3CA0a63U4QghxQJKQalh29R/Jrv+7nZQGv8frcIQQokeSkGpcds2fyK77G/HTvkV42GlehyOEEN2S9ZD6gezaP2Nl2olNuIjs6iFk18lCf0II/5GE1E/kmp+nmG4hPvFSAvHDyJgPglX0OiwhhHiHdNn1I4Wdy0guuJnwEROIT5oOoajXIQkhxDsq2kIyDf0y4CvAOGChomrTSsoGAD8HPgFkgPuAqxVVs/xQXiuKHc0k588mfuqVJKZ8j9SrP8HKtHodlhBCVLyF9DYwF7hjP2V3AocBxwCTgc8C031UXjOs9B6SC+diZTtITL2W4IDhXockhBCVTUiKqj2qqNqjwPbS/aahJ4AvAtcqqrZHUbX1wG3ABX4o70nA4+2QY8inSL/6Ewq7TBJnXE1o0FjP34sfNz/8G9f6JnXcv+q3J34Z1DAWqAOWlOxbDJxoGnrI63JF1QrdBT6osYF41NtrMYObBh7yc63m/6NAksTpVxFerRNqWVbGyGpDX+pXHBypY3f5pX5jsboey/2SkBqApKJq+ZJ9rUAIiPugfG93ge9p6yCZyvTirZZPAPuD1tLaTp8udLX+gUj7Tqz3foW9y39DvvmFMkVY3cpWv6JbUsfu8lv9JuI9f3n3yyi7DiBhGnppgmwCCkDKB+XdsjzeyhVDduM/SC+5h6jyJSKjP+b5+/LL5od/41rfpI77V/32xC8JaTWQBcaX7JsImE53mdfl/UJ+2yJSr/6YutEfIap8kQP3+AohRPlUNCGZhh42DT2G3VUYNA09Zhp6naJqSeC3wGzT0JtMQx8FfBu4F8Dr8v6ksHsFyYVzCR95GrEJl0DQL726QohaV+kW0izsLrDbgLOcx085ZVcAu4FNwKvAo8BdJc/1urzfKLZvIrngZkKNo4ifeqXcQCuEqIiAZR2oV0/sj2nojUDredPneTqoYUjTQHa7dMEyUDeQ+OSroJAhuegOyPd4Oa3muF2/QurYbX6r30Q8ygN3zQBoUlStrWu5X64hCR+ysu0kX74VAkESp3+HQF2D1yEJIWqYJCTRs1wnyZd/hJVPEz/9uwSiTV5HJISoUZKQxIEV0qQW3YGV2k3ijKsJxId6HZEQogZJQhIHp5gltfguCu0bSUz5HsH6YV5HJISoMZKQxMEr5km//nMKu1cSn/I9gg0jvI5ICFFDJCGJ3rGKpN+4j/y2V0lM+S7BptFeRySEqBGSkMQhsMgs/w3ZTc+RmPxtQkNP9DogIUQNkIQkDll21SNk1z5OfNI3CQ+b5HU4QogqJwlJ9El23RNkzN8Sm3AJkaPP8jocIUQVk4nKRJ/lmp/HyieJjb8IIgly6//udUhCiCokCUmURf7tRaTyaeKnfINAJEF29R+9DkkIUWWky06UTWHnMlKvzKPumPcTPfHLyPIVQojekIQkyqqwZw3Jl28l
PGwSsYmXQjDidUhCiCohCUmUXbF9E8n5cwg2DCdx+kwCkQFehySEqAKSkIQrrNROkgvmYFlFElOvJZA43OuQhBA+JwlJuCfXSeqVH1Fo20jijGsJNh7ndURCCB+ThCTcVcyTXvIL8lteIjFlJuEjJnodkRDCpyQhiQqwyKx8mMzKh4lNvJTIMe/3OiAhhA/JfUiiYnIb/4GVaiE28RJCg0aTXv4g5Dq9DksI0RuBkP3TKpT9paWFJCoqv2MJyZduJJg4gvqzbpSJWYWoMuFhk6g/+yZXXlsSkqi4YufbJBfMIbfpeeKnXmHfRBuq8zosIcRBCEQSWPmkK68tCUl4wyqQXfs4yQVzCA15D/Vn3iBrKwlRBQKRBJZLXe2+uoZkGvpo4E7gDCAH/BK4VlG1omnoYeB24DzsRPoIcJmiahnnua6WC3cU2zaQfOkGoid8lsSU75Jd9yTZNX9ypX9aCNF3gXA9Vq7GW0imoYeAx4E3gSOBU4GPADOdQ64BzgFOAsYAJwI3l7yE2+XCLcUcmRW/I7VoHpHhZ5A442oC8cO8jkoIsT8utpB8k5CAE5ztB4qqZRRVawbuAL7ulF8IzFZUbauiajuB64HzTUMPVqhcuKyweyWdL11PMd1K/ZnXER52qtchCSG6CETcayH5qctuf1NDB4BRpqGPBEYAr5eULQaagBGmobe5WQ5s7Clor+a0DnT5WRNynWQW/y/FY95HbPxF5IcqZMzfQzFb8VBqsn59RurYXW7UbyCSgFznIb3mgZ7jp4S0CngL+KFp6NcBRwBXOGWW87Ot5PhW52cDUHS5vFuDGhuIR6M9HeK6wU0DPT2/K9pepbjsbYInfIW6s68jvOpXBFPbPQmlJuvXZ6SO3VXO+s3GGohFLEKH8JqxWM+jaX2TkBRVy5uG/nHgx9gtkhbgPmAc/0oYjcAu53GT87PD2dws79aetg6SKW/GPQSwP2gtre3vZOya0toO268jeuKXKY6bQWb5b8lvfrFip6/5+vUBqWN3uVG/iWCMVNtu8q3tvX9uvOcv7766PqKo2mpF1T6iqNrhiqq9B0gCixRV2wI0AxNKDp+I3YppVlSt1c3ynmK2PN78EIOrWyFD+o37SL/5AFHlf4ic8Fmp3xrbpI6rq34DkQTFfGef4umOb1pIAKahjwPWAWngXGAWoDnF9wLXmoa+AHtI+PXA/YqqFStULjyU3zKfZOd2EpO/Dbm9ZNc96XVIQvRDAQjH+8WgBoDPAd8AYsBK4EJF1Z52yuYAQ4Hl2C27h7GHalOhcuGxYutbpBb/L/FJ38TKJck1v+B1SEL0L+E4gUDQtWHfAcs6UCNK7I9p6I1A63nT53l6DWlI00B297P+9/Cw04hNuMhe1mLba66dp7/WbyVJHbur3PUbiB/GgHNvoePp6Yc0MXIiHuWBu2YANCmq1ta13FfXkIQ4GPlti8gsf5DY+IsJDVW8DkeIfiMQSdgPan2mBiF6I9f8PNk1fyJ+ynSCTbISrRCV8K+bYt1pz0pCElUru+5v5Db+g8SpVxIcMNzrcISoefbEqu60jkASkqhymVUPk9u+mPjkqwgkDvc6HCFqmt1Ccm9RTUlIoupllv2aQssq6qfOIjT4BK/DEaJmBSLuDfkGSUiiJlikl9xNduOzxCdfRWTEOV4HJERtitRj5d1rIfntPiQhDpFFds2fKO7dSmzcBQQbhpNZ8Xuw5L5mIcrF7WtIkpBETcm/vYhkcifxSdMJ1g8j9frPwaXlloXobwLheqz0btdeX7rsRM0ptm0gafyQQCRB/dRZBOqP8DokIWqCjLIT4hBYmVaSC2+h0Lae+qmzCB95mtchCVH13FycDyQhiVpWzJFeeg+ZtX8hdvL5JGQUnhB9IteQhOij3Pq/k9+ygLrjP0Z88lUUdplkVj1CsWOz16EJUVXkPiQhysDKtpMxf0vnC7Ow8ikSZ15PbNzXCMQGex2aENVDriEJUT5WcgfpJb8gOX82gdhg
6s+5mbrjPuJ1WEL4XzhmLz0h9yEJUV7Ftg2kXvkRocPHEx9/MYH4YDLLH8StSSOFqHaBcL39wM8tJNPQB5QjECG8UNixlOTLtxAeNonYxEshKN/RhNiffUtP+KbLzjT0q0xD/3zJ7zrQZhr6etPQ31v26ISogGL7JpLz5xAaOIL4aTMgHPc6JCF8JxCpx8qnwSq4do7etpAuBbYBmIZ+FvAZ4IvAIuDW8oYmROVYqZ0kF8whEIqSmPI9AtEmr0MSwlfcHvINvU9IRwHrnccfBR5RVO0PwA3AlHIGJkSlWdkOki/fhpVpI3HGNTLDgxAl3B7yDb1PSJ1Ao/N4GvBP53EKSJQpJiG8U0iTevUnFPasITHlGooNshqtEIDrQ76h96PsngduNw39JWAi8KSz/wSguZyBCeEZq0B66b1ET/gvAiddTuyIZWTW/Ili2wavIxPCM3aXnb9aSFcAaexrR5coqrbd2f8R4JlyBiaEtyyyqx4h8vpcrFwniamziJ96BcHGY70OTAhPBCL1rs+c36sWkqJqW4BP7Gf/9HIEYxr6UcBdwDlAAHgJuExRtc2moYeB24HzsBPpI05Zxnmuq+Wifwqmd5BZei+ZtX8mOvrjJKZeS2HncjJr/0SxdZ3X4QlRMYFIAivT5uo5ejvsO2oaerTk9+GmoV9mGvq0MsXzMyACHAuMwL5mdY9Tdg12ojoJGAOcCNxc8ly3y0U/ZnVuJ/3GvXS+cC1Wtp3ElKuJTfwGhGNehyZERQTC7s70Db3vsnsMuATeuSH2FWA28LRp6F8tQzzHAX9QVK1DUbUk8CBwslN2ITBbUbWtiqrtBK4HzjcNPVihciGcxHQfnS/OIlh/uD0aLz7U67CEcJ0fh31PAp5zHn8K6ACOwE5SM8oQzzzgv01DbzINvQG7++yvpqE3YbeYXi85djHQBIxwu7yngAMeb36IoZa37uqXzu2kFtyM1bmD+qnfJzRojOexVuvWXR3L5rP6jdRDrrMs8XSnt6PsBgItzuP3A48pqpY1Df0Z4M5evtb+GMAFzjksYKlznganvLQDs9X52QAUXS7v1qDGBuLRaE+HuG5w00BPz1/reqpfa90DFEb+J4nTZxJe9zChHS9XMLLaIZ9hd5WjfjPRegZEIdSH14rF6nos721C2gKMMw39beBDwPnO/iagTxf/na6xp4GHgP9wdl8PPAF82Pm9EdhVck6wW2kdLpd3a09bB8mUN+MeAtgftJbWdpkS1AUHXb+tvye8awPRk88nGRxMduVDyCStB0c+w+4qZ/3Wh+K079lJsbX9kF8jEe/5y3tvu+zuA34HLMdOQPtujJ0MrOxtcF0MBo4B7lRULelcQ7oLOB07cTYDE0qOn4jdimlWVK3VzfKegrY83vwQQy1vB1u/ua0LSS68hfDwKcROvQIrHPc89mrZDraOZfOwfkN1BIJhrFyyLPF0p7fDvueYhr4SGAk8pKhazikqAj/qzWvt57V3mYa+FrjMNPQbnN3Tgc1O2b3AtaahLwBy2K2n+xVV29fd5na5ED0qtq0jadxIfNI3SUydRfr1n1PskPvFRfULROylJ9y+MbbXc+0rqvbofvb9sjzh8EngDuyuwQCwhH/d9zQHGIrdOgsCD2MP1aZC5UIckJXeQ3LhzUTf+z8kps4is+oRchue4cDfDYXwr3eWnnD5xtiAZfXuP4pp6GOB72Dfp2Nh/wG/VVG1NeUPz79MQ28EWs+bPs/Ta0hDmgayW/rfXdHX+g0Pm0TspK9QaFtPeul9WNlD73uvVfIZdle56jc0aCzxyTPY+/ev9ymeRDzKA3fNAGhSVO1dd9n29sbYDwLLsK+vLMS+D+kUYJlp6O/vU6RC1Jj8ttfofOk6CEZInHUDocNOPvCThPChStyDBL3vspsD/D9F1a4s3Wka+k+wZzWYXK7AhKgFVnoPqZdvo270R4lPmk5u4z/JrHoYinmvQxPi4FVg6QnofUI6Cfjyfvb/DLi47+EI
UYsssm/9hfxuk/j4iwkNPZHs2sfJv70IubYkqkGlWki9Hfbdwf5nLjgGkA5yIXpQbF1Hp3E9+W2vETvpPOrPmUPk6LMg2OuxRUJUVCCSAB+2kP4I3G0a+qXAi86+s7FbSO8afSeE6CKfJrvmj2TXP0HdyGnUnfAZ6sZ8kuz6v5Nrfh4KWa8jFOJd7NVi/XcN6SrgfuzZE0r7Gh4GZpYrKCFqXj5Ndt2TZDc8S+ToM6k77sPUHf8xcuufJrv+SbnGJHwlEElgZX3WQlJUbS/wOdPQRwOKs3s59pIRr5TsE0IcjGKO3KZ/kmt+gfCRk4mO/RSRo6eSXvYrCi2rvY5OCMBuIRWTO10/zyF1Xiuq9hbw1r7fTUMfj72MuRDiUFgF8lsXkN/+GtExnyI+eSa5zS+RWfmw66t0CnEggbA/BzUIIdxUyJJZ+RDJ+bMJNY6i/uzZhIdN8joq0d9VaNi3JCQhfKjYvpHk/NnkNjxFbPxFxCZNJxAb5HVYop/y67BvIUSlWAWy656k88UfEAhFqT9rNuGjpngdleiH9i3O57aDuoZkGvpTBzhkQBliEULsh5XcQeqVHxEZOY3YyV8hf9hJpJf/BvJpr0MT/UEwTCAU8dWw7y0HccyqvgQihOhZbtNzFFpWE5twCfVn3kBqyS8otq7zOixR4yq19AQcZEJSVO38Ax8lhHBbce9WkvN/SPSEz5KY8j2yax4n+9ZfkSmIhFv+lZD800ISQvhFMU9mxe/J71pObNzXCA09kfTSu7HSe7yOTNSgQCSBVcxD0f1ZRGRQgxBVqrBzGckXr4NClvqzbiR02DivQxK1qEJDvkESkhBVzcq2k3r1x2Tf+ivxSZdTd/wnsJdlE6I8KnVTLEiXnRA1wCK77kkKbRuJTfg6ocZRpJbeA/mU14GJGlCpId8gLSQhakZh9wqSxg0Eoo3Uqz8gOGC41yGJGlCpm2JBEpIQNcVKt5BceDOFltUkps4ifORpXockqpydkCrTQpIuOyFqTTFPetn9RFrXERt3IbnGY8msegSsoteRiSpUyRaSJCQhalSu+XkKHc3ET7mcQGwQ6aX3SFISvRepr9gtBdJlJ0QNK7auI7lwLqFBY4hN+DoEQl6HJKpMv2whmYa+t8uuKLBCUbVxTnkYuB04DzuRPgJcpqhaphLlQlQrK7mD5MJbSEyZSWzipaSX/FxWpBUHLdAf70NSVG1A6QasAH5fcsg1wDnAScAY4ETg5gqWC1G1rNROkgtuITRwBPFTLoOgb76LCp8LhBPQn0fZmYY+GXs59F+V7L4QmK2o2lZF1XYC1wPnm4YerFD5fgU83vwQQy1vtVS/pHeTWjiXYP0w4pO+SSAY8TymWqtjP259rl/nPqRyxtMdv35NugB4QlG1rQCmoTcBI4DXS45ZDDQBI0xDb3OzHNjYXaCDGhuIR6OH8h7LZnDTQE/PX+tqq34LWOZPyZ14GZEpM4isvJdAMed1UDVWx/5zqPVrBUJkw1EGxgIEy/BvFIvV9Vjuu4RkGnoC+AKglexucH62lexrLSkrulzerT1tHSRT3lxmCmB/0Fpa22WuZxfUbv22E5h/M7HTZ5IfcwGpV38CBfkM16K+1m+groF6oLVlB1a6vc/xJOI9f3n3Y5fdfwNJ4K8l+zqcn40l+5pKytwu75bl8eaHGGp5q9X6LWbbSb18K0TqiZ8+EytSL3Vco1uf6tdZeqKYS5Y1nu74MSFdCPxaUbV3hgEpqtYKNAMTSo6biN2KaXa7vEzvSwhfsbIdJBfeAsUCiSnfIxBtOvCTRL9iLz1RgEJlVif2VZedaegnAFOBr+2n+F7gWtPQFwA57EEH9yuqVqxQuRC1J58kueh24hMvI3HGNSQX3Y7Vud3rqIRPBCL1WPnKjLAD/7WQLgBeVFRt9X7K5gAvAcuBtYCJPVS7UuVC1KZCltRrd1JofYvElKsJDhzpdUTCJwKRyg35BghY1oF69cT+mIbeCLSeN32ep4Ma
hjQNZLdcEHZF/6vfANETv0TkqCmkXr2Twp79fS8s9xn7Wx1XVl/rN3LM+4gMn0py/uyyxJOIR3ngrhkATYqqtXUt91sLSQjhGYvM8t+Q3fAM8ckzCB0+3uuAhMcquTgfSEISQnSRXfMYmVWPED/lMsJHn+l1OMJDlZw2CHw2qEEI4Q+5Dc9gZTuInfw1svVHkF31KAcetCtqTgUnVgVJSEKIbuS3vkwqtZvYKdMJJo4gvfReKGa9DktUUCCSoNi5rWLnky47IUS3CnvWkpw/m2DDUSSmfJdAtPHATxI1wx5lV7kuO0lIQogeWamdJOffhJVPkpg6i2DDCK9DEhViX0OSQQ1CCD/Jp0gt+jH5nctInHE1ocPGeR2RqAB7cT4Z1CCE8BurQOZNnWLnNuKTLie34Vmy657AyvZ90k3hT4FwZVtIkpCEEL2SW/8UxfbNRN/zWeqPmUZu03Nk1z2JlXnXfY6imgWCBCJxaSEJIfytsNskadxI6LBxRMd8gvpp55Lb9LzdYsq0HvgFhP+F4wDSQhJCVIfCzjdI7nyD0NCTnMR0C7nmF8i+9VdJTFUu4Cw9IQlJCFFVCrveJLnrTUJDFaJjPkX92bPJrHyYXPMLyA211SkQSWBZRcinKnZOSUhCiLIp7DJJ7jKJHH0W0fd+nvBRk0kv+zVWcofXoYleCkTqnWRUuS8UMuxbCFF2uc0v0vnCLKxckvqzbqTuuA9DQP7cVJNAhacNAmkhCSFcYmVaSS/+KeFhk4ie+GXCR04m/cb9FDtkEeZqUOmJVUFaSEIIl+W3vUbnC7MotjeTUL9P3eiPeR2SOAjSQhJC1KZcJ+ll9xN6+xXiEy4h2HAU6Tfuh2LO68hEN+ybYqWFJISoUYVdy+lccBOhxlEkTp9JoG6g1yGJ7njQQpKEJISoKKtzO53zb8Iq5IirsygmjvQ6JLEf9kzfkpCEELUu10lq0TwKu0xyJ18hk7X6UCBSj5WXLjshRH9gFcgs+xWh5ieJTbqcyKgPeB2RKCGDGoQQ/U5463Ps3bWJ2ISLCQ0cSdr8LeTTXofV73kx7Nt3Cck09I8DNwBjgQ5gnqJqt5mGHgZuB87Dbtk9AlymqFrGeZ6r5UII9xR2LCG5YA6x8RdRf9YPSS+7n8Iu0+uw+jUvWki+6rIzDf3DwM+BmUATcALwhFN8DXAOcBIwBjgRuLnk6W6XCyFcVOzYTHL+D8ltnk/81CuJnqRBOOZ1WP1UAMKVXXoCfJaQgB8CsxVVe1ZRtbyiau2Kqr3plF3olG1VVG0ncD1wvmnowQqVCyHcVsyTXfNHkvNvIjToeOrPupHQUMXrqPqfcJxAINh/ryGZhl4PTAJ+bxq6CQwB5gNXAO3ACOD1kqcsxm5FjTANvc3NcmBjd3EHnM0LgS4/RXlJ/bqvuzq22jeSMm6k7viPEz/1SvKbXyKz8iG5ttRLh/oZDkQS9oNcsqyf/wO9lm8SEjAIO96vAB8BdgB3AP8HfMY5pnRJyn2LrTQARZfLuw+6sYF4NNrTIa4b3CQ3F7pJ6td93dbxjmcpdq4mePz/EJl2C6G3nyO0zSBQkMTUG739DBfrDycHDK4PE6B8n/9YrK7Hcj8lpA7n552Kqm0AMA39WmAnUHDKGoFdzuOmkud1uFzerT1tHSQK4K5jAAASLklEQVRT3ox7CGB/0Fpa22XFGRdI/brvoOq4dTlsvY7w8KnUHfef5I96P7lNz5Fb/xRWtr2C0VafQ/0Mh0IWsVyKltbyLrKYiPf85d0310cUVWvD7horrbfSx83AhJLfJ2K3YpoVVWt1s7ynuC2PNz/EUMub1K9P6tgqOEtaXGvPiTdUIXHubdSddB4kDvf8Pfh5O6j67fqcSAIr3+laPN3xUwsJ7BF2V5iG/hR2y+iHwKuKqm01Df1e4FrT0BcAOexBB/crqravu83tciGE5yzy214jv+01QkPeS93oj1B/zhzy2xaT2/wShV1vgiX/ZfvKvgepsgMa
wH8J6Vbsa0mLsVtvLwH/5ZTNAYYCy52yh7GHalOhciGEjxR2ryC1ewXBxlHUHfM+4hO/jlXIkN+6kNzm+bLuUh8EBwyn6MEqvwHLOlAjSuyPaeiNQOt50+d5eg1pSNNAdss1DldI/bqvrHUcqiN8xClEhk8lNFSh2LGZ3Jb55Lcs7LfXmg61fhPqdeS2GOQ2PFPWeBLxKA/cNQOgyblM82/81kISQohDU8iS37qQ/NaFBGKDiBx1BpERZxMd+1/kmp8n+9ZfsTLv+hsougrHCQ4cQeGNVZU/dcXPKIQQLrPSe8iu+xvZdX8jNEQhOvZT1E+7hdymf5J964l+22I6GKFBYyCfptixueLnloQkhKhphd0myQUmocNOJjrmU9RPm0Z247Pk1j2JldvrdXi+Exo8lsKeNRx4TFz5SUISQvQLhZ3LSO5cRujwCUTHfoq6Y95HdsPTZNc9ITNAlAgPHkt++2Jvzu3JWYUQwiOFHUtI7lhKeNgp1I39NJER55Bd8xi55hdkyHiojmDjKPLm7z05vSQkIUQ/5NzPtP11IiPOpm7Mp4iM+gCZlQ9T2LHU6+A8E2oaDcUCxfZup+90lSQkIUT/ZRXtaYi2LqTuuI8Qn3gphT1ryaz4Q7+8jyk0+AQKrWvBKhz4YBdIQhJCiHya7OpHyW16jugJnyFx5nXkt75MfvcKih1bKO7dAoWs11G6LjR4LIXdKzw7vyQkIYRwWOkW0kvvJbj+aeqO+w/qjv0wwfojCARDFDt3UNi7xU5Q7ZvIb3/ds5aEK4JhQk2jya55zLMQJCEJIUQXxfaNpJfcbf8SDBOsH0aw4WiCDcMJNRxNZOQ51I3+KOk37vPkfh03hBqPBaDQus6zGCQhCSFET4p5ih2b/z3xhBPElC+QUH9Adu1fyL7116pvLYUGj6XQtg6Kec9i8M3yE0IIUTXySdJv/JLUa3cRGXE2CfX7BBtGeB1Vn4QGn0ChZbWnMUhCEkKIQ1TYuYzOF79PoW0DCfX71I35JARCXofVe4EgoUHHe56QpMtOCCH6Ip8is+xX5N9eROzkrxI+4hQy5u8otKz0OrKDFhw4EoIRCnve8jYOT88uhBA1orBrud1aallFfPIM4qd9q2q68UKDT7Bvhi14O4WSJCQhhCiXfJqM+Vs6n78WK7uXxJnXERt/EYH4YV5H1qPQ4LEUWiq/3ERXkpCEEKLMrNRO0kvvIWncSKBuAPXn3ERU+SKBuoFeh7YfAcKDxpD3+PoRSEISQgjXFNs3kVp0B6lX5hFqOo76aXOJvufzBGJDvA7tHcGG4RBJUGhZ43UokpCEEMJthZaVJOfPJr3kboJNx1I/bS6xCV8n2Hic16ERGjzWvscqn/Q6FBllJ4QQlZLfsYT8jiUEG4+l7tgPkTjjaoqt68hueIr8tsV4sSieH+4/2kdaSEIIUWHFtvWkl/yCzue+S2HPWmInn0/9tLlERn0QwvGKxuKXAQ0gLSQhhPCMlW4hs+phMmv/TGTEmdQd836iYz9NbvNLZDc8g5Xc4er5A/XDCEYbfXH9CCQhCSGE9wppchueIbfhWUKHj6du1AeoP2cOhR1vkN3wtGtLQoQHj6Ww922sbLsrr99bvklIpqH/CvgiULroyLmKqi1yysPA7cB52F2NjwCXKaqWqUS5EEK4z6KwYwmpHUsINhxNZNQHiZ96JcXkdrLrnyK/ZUFZJ3H1U3cd+CghOX6mqNqV3ZRdA5wDnATkgMeBm4EZFSoXQoiKKXZsJrPsfrKrHiEy4hyiJ3yW6NhPk13/FLnm5yHf91kVQoPHkln1aBmiLY9qGtRwITBbUbWtiqrtBK4HzjcNPVih8v0KeLz5IYZa3qR+pY693sh2kHvrLyT/OZPs2j9TN/JcBpz7I6JjP0OwbuAh128wNoRgfCjFllUV//fujt9aSJpp6BrwNvBL4A5F1YqmoTcBI4DXS45dDDQBI0xDb3OzHNjYXcCDGhuIR6OH8l7LZnCTH+/+
rh1Sv+6TOj5I7Yuxlr5Occh4gsPfT91x/0FwxyuEts8n0Lml2z/4XevXCtZRGP4+CundDI4VIFaZ+o/F6nos91NCuhOYCbQAk4E/AEXgDqDBOaat5PhW52eDc5yb5d3a09ZBMuXNZaYA9getpbXdg7sXap/Ur/ukjg9R6wvw1guEhihEjvswoXFXYWXaKexcRn7nGxR2mZBP/Vv9UjeQ0OHjCR8xgdDQE6GQJbvmMfa2Vm5AQyLe85d33yQkRdUWl/y6wDT0uYCGnZA6nP2NwC7ncZPzs6MC5d2y8OJWNv/FUMukft0ndXxo8rtN8rtNiNQTHnoS4cNPJnqiRiASp7BnjZ2gYjFi73kvwUGjsVK7yW9fQnbRjynsWVPxVW4P9G/s52tI+1otKKrWCjQDE0rKJ2K3YprdLi/XGxJCCFfkOsm//TLppffS+eyVJBfMpdCymvCwUykOnUB+15skX7qezue+S2aFs1aTD5dc900LyTT0/waexG6RTAK+B/y05JB7gWtNQ1+APQrueuB+RdWKFSoXQogqYFFsW0e2bR25NX9iSNNAOqqkS9Q3CQm4HLgbO6YtwM+w7wvaZw4wFFiO3bJ7GHuodqXKhRBCuChgWdWQN/3HNPRGoPW86fM8HdQwpGkgu6vk20+1kfp1n9Sxu/xWv4l4lAfumgHQpKhaW9dyP19DEkII0Y9IQhJCCOELkpCEEEL4giQkIYQQviAJSQghhC/4adh3VYofYG4mNwWw54ZKxKO+GEFTa6R+3Sd17C6/1e+B/l5KQjp0AwHuvu1yr+MQQohqM5B/nzsUkITUF5uBkYA/lloUQojqMBD77+e7yI2xQgghfEEGNQghhPAFSUhCCCF8QRKSEEIIX5CEJIQQwhdklJ2HTEO/DPgKMA5YqKjatD6+3seBHwEjgDeBixVVW1JSfjT2kh4fwr5FYYWiamf05Zx+Vsn6NQ19GvBPoLPkKfcqqnZlX87pd5X+DJccdzP2mmmfVlTtsb6c088q/Bn+GDAXGI69uOtrwAxF1Zb15Zy9IS0kb72N/QG4o68vZBr68cDvgJnAIOBR4C+moced8nrsP5grgFHYaz/V9B9LKli/jjZF1QaUbLVev1D5OsY09PHAJ5xz17pK1u9i4IOKqg0CDgf+AlQ02UsLyUOKqj0KYBr6yK5lpqEfht2a+QB2a+ZPwFWKqnV2PdahAc8pqva48/xbgenAh4E/Al8FWhRVu77kOS+X5Y34VIXrt1+qdB2bhh4C7sFe0PP+sr4ZH6pk/SqqtrXk2AB2K+kY09AjiqrlyvSWeiQtJB8yDT2A/c2kBRgDvBf7Jty5PTxtHPD6vl+cpdeXOPsBzgFWm4b+qGnou01DX2Ia+qfdiN/vXKpfgAGmoW81DX2zaegPmoY+vOzBVwkX6/hKwFRU7Z/ljrmauFW/pqGPNA29FUgDPwbmVCoZgSQkvzoV+wN2laJqnYqqtQLXA1/q4TkNvHsqjlZnP8Bg4IvAb4FhwHeBB53uj/7GjfpdCUzA7ps/Ffsb5p9NQ++v/8fKXsemoR8LfBP4dtmjrT5ufIZRVG2TompNQBNwBSUJrBKky86fRgGNwG7T0PftCwBR09AHYF+U/LKz/0VF1f4T6HCeU6oJWO087gAWKKr2iPP7301D/wfwUWCpG2/Cx0ZR5vpVVG0bsM3Zv8009Iux//OPxU5W/c0oyv8Z/gXwA0XVdrkYd7UYRfnr9x2KqnWYhv5TYJdp6EsUVVtf/rfwbpKQ/GkjsFNRtWHdlH/d2Uq9AUzc94vTpB8P3OfsWgq8v8xxVis36rer/j4nlxt1/EFgnGnotzi/Hwbcbxr6uYqqXVG2yKtDJT7DASCKnfwkIdU609DD2P8GYSBoGnoMKAKvAuudoa1zsSdwHQ5MVFTtz928nA7McIZuPsW/RtA94fz8NTDTNPRPAn8GzgXeB1xb9jfmE5WsX9PQzwU2ONsQYB6wHFhT9jfmIxX+DB/Z5fhFwPep
4UElFf4Mfx57qPc67AlQZwNJ7NF3FdFf+7f9YhaQAm4DznIeP+VcbPwE9jfAN7G7fp4GTuzuhRRVW4t9jWiec/xngY8rqpZyytcDnwZuwv7w/hj4kqJqtdxdV7H6xf7m+QKwF1gGRICPKapWKP/b8pVKfoa3lW5AAWhVVO1dyxjUkEp+ho8FnsHu2luN3TL6YCXrV2b7FkII4QvSQhJCCOELkpCEEEL4giQkIYQQviAJSQghhC9IQhJCCOELkpCEEEL4gtwYK0SVMQ19FPad82cpqvaSx+EIUTaSkITowjT0XwFHK6r2Aa9j6UYz9qwFu90+UUny26cDWAvcoajaA718rVnAhYqqjSpbgKKmSJedED5hGnrdwRynqFrBma2gYssCAJ/EToKnYC/sppuG/qEKnl/0A9JCEqKXnPnFZmEvLX0k8BZwp6Jqvyg55grgfOB47OmEngO+paja2075NOwVfD8GXI29nMC3TUPfC9yLvX7VXcB7sOfE+7qiaq85zx1FSZddye+fB87DnkR3G3BdaSvGWb7hF8DZwA7gFuBzwFpF1S48wNtucabrAZhtGvq3gP/AnhNt30Sdd2PPj3gU9kqnvwduUFQtYxr6V4EfOsfumx7mBkXVrj+Y+hT9g7SQhOi9e4HPAJdgr0lzI3CLaegXdDnu28DJ2HMIjsT+A93V7cCtzuvsWy46CNyMvR7NKcAe4CHnD3dP5gIPYC+49hD2TNhj4J2E8Ufs5QfOxp4H7aOUzP58MExDD5mG/gXs9bWyJUUBYDv2XGnvxZ6483zgGqf8D9gJcDN20jkSe4kEOPj6FDVOWkhC9ILTytAARVG1fescrTcN/QTs5aDvA1BU7SclT1tvGvplwGLT0IcrqralpOymfUtKO68P9h/3KxVVW+zs+wGwABgNrOohvP9VVO0h5zmzsJf5fh/2jOMfwF5qYIwzySamoX8ZO0EcjKdMQy8CMSAE7MReShzn/RaxWzn7bDANfTTwDeyWWspp/RVKWloHXZ+if5CEJETv7FsN9tWShdHA/r/0zszeTpfc1YCCvQjavt6IY4DShPTKfs5h8e+LJu47/gh6TkhL9j1QVC1vGvp25zk4cezal4ycY1pMQ+/p9Uqdj700wbHYs0XfoKjautIDTEO/CLgQe5boepwlEw7wugdVn6J/kIQkRO/s+wM7FXutmFIWgGnoI4G/YXef3QjsAo7Gntq/68CFzv2co9hl2Yp911wO9Mc92+V3q8tz+jK1/xYnma11uuwWmob+5r5WjWnonwN+CnwPeB57iZPPYS930pMD1qfoPyQhCdE7rzk/Ryqq9pdujjkNiGN3u6UATEOfVIngemACh5mGfnxJl90g7CXWX+vxmV0oqrbcNPQ/Y6/R83Fn99nA64qqzXvnhPZgi1JZ7O6+UgdTn6KfkIQkxP4NMA19Qpd9aUXVVpqG/kvgHtPQv4N9bacemAQcpqjaLdjXbCzgKtPQH8S+dvODCsa+P89gdwPqzgjALHbrJc+htURuw74mpiqqZmB3JV7grEj8Jvbowc90ec56YJhp6Gdg11FSUbW1B1Gfop+QUXZC7N/pwOtdtn2j4C4G7sBe/t0EnsUesrwOQFG1N7AvyF/ilH+bfy0X7QlF1Szs0X6dwIvAX7CXrl4FpA/h9V7HTnJznV2/wO6ivB+7rk4Hru/ytMeAh4G/Yg+K+I6zv8f6FP2HrBgrRD9lGnoD9ii7WYqq3eV1PEJIl50Q/YRp6J/A7qJbARwOXIfdXfeQl3EJsY8kJCH6jwT2taxR2F13rwFnKqq23cughNhHuuyEEEL4ggxqEEII4QuSkIQQQviCJCQhhBC+IAlJCCGEL0hCEkII4QuSkIQQQvjC/wdV5HIuJYQVnQAAAABJRU5ErkJggg==\n", 407 | "text/plain": [ 408 | "
" 409 | ] 410 | }, 411 | "metadata": {}, 412 | "output_type": "display_data" 413 | } 414 | ], 415 | "source": [ 416 | "learner.recorder.plot()" 417 | ] 418 | }, 419 | { 420 | "cell_type": "code", 421 | "execution_count": 20, 422 | "metadata": {}, 423 | "outputs": [ 424 | { 425 | "data": { 426 | "text/plain": [ 427 | "PosixPath('/home/ubuntu/MultiCampus/MICCAI_BraTS_2019_Data_Training/models/trained_model.pth')" 428 | ] 429 | }, 430 | "execution_count": 20, 431 | "metadata": {}, 432 | "output_type": "execute_result" 433 | } 434 | ], 435 | "source": [ 436 | "learner.save(\"trained_model\", return_path=True)" 437 | ] 438 | }, 439 | { 440 | "cell_type": "code", 441 | "execution_count": 20, 442 | "metadata": {}, 443 | "outputs": [], 444 | "source": [ 445 | "learner = learner.load(\"trained_model\", device=1)" 446 | ] 447 | }, 448 | { 449 | "cell_type": "code", 450 | "execution_count": 20, 451 | "metadata": {}, 452 | "outputs": [], 453 | "source": [ 454 | "learner.metrics = [dice_coefficient]" 455 | ] 456 | }, 457 | { 458 | "cell_type": "code", 459 | "execution_count": 23, 460 | "metadata": {}, 461 | "outputs": [ 462 | { 463 | "data": { 464 | "text/html": [ 465 | "\n", 466 | " \n", 467 | " \n", 468 | " \n", 469 | " \n", 470 | " \n", 471 | " \n", 472 | " \n", 473 | " \n", 474 | " \n", 475 | " \n", 476 | " \n", 477 | " \n", 478 | " \n", 479 | " \n", 480 | " \n", 481 | " \n", 482 | " \n", 483 | " \n", 484 | "
epochtrain_lossvalid_lossdice_coefficienttime
011968.47656211397.2783200.72851431:11
" 485 | ], 486 | "text/plain": [ 487 | "" 488 | ] 489 | }, 490 | "metadata": {}, 491 | "output_type": "display_data" 492 | } 493 | ], 494 | "source": [ 495 | "learner.fit_one_cycle(1, max_lr=3e-04)" 496 | ] 497 | }, 498 | { 499 | "cell_type": "code", 500 | "execution_count": 24, 501 | "metadata": {}, 502 | "outputs": [ 503 | { 504 | "data": { 505 | "text/html": [ 506 | "\n", 507 | " \n", 508 | " \n", 509 | " \n", 510 | " \n", 511 | " \n", 512 | " \n", 513 | " \n", 514 | " \n", 515 | " \n", 516 | " \n", 517 | " \n", 518 | " \n", 519 | " \n", 520 | " \n", 521 | " \n", 522 | " \n", 523 | " \n", 524 | " \n", 525 | " \n", 526 | " \n", 527 | " \n", 528 | " \n", 529 | " \n", 530 | " \n", 531 | " \n", 532 | " \n", 533 | " \n", 534 | " \n", 535 | " \n", 536 | " \n", 537 | " \n", 538 | " \n", 539 | " \n", 540 | " \n", 541 | " \n", 542 | " \n", 543 | " \n", 544 | " \n", 545 | " \n", 546 | " \n", 547 | " \n", 548 | " \n", 549 | " \n", 550 | " \n", 551 | " \n", 552 | " \n", 553 | " \n", 554 | " \n", 555 | " \n", 556 | " \n", 557 | " \n", 558 | " \n", 559 | " \n", 560 | " \n", 561 | " \n", 562 | " \n", 563 | " \n", 564 | " \n", 565 | " \n", 566 | " \n", 567 | " \n", 568 | " \n", 569 | " \n", 570 | " \n", 571 | " \n", 572 | " \n", 573 | " \n", 574 | " \n", 575 | " \n", 576 | " \n", 577 | " \n", 578 | " \n", 579 | " \n", 580 | " \n", 581 | " \n", 582 | " \n", 583 | " \n", 584 | " \n", 585 | " \n", 586 | " \n", 587 | " \n", 588 | "
epochtrain_lossvalid_lossdice_coefficienttime
010191.50683611200.1494140.73753132:57
111503.08789111923.1884770.64409532:46
211965.60546911864.0839840.65834332:44
311030.50683611926.5810550.67901232:11
410746.77929711227.8847660.70608832:13
510609.99707011073.2656250.70922532:19
610128.07519510866.3945310.73071532:28
78688.51562510695.4970700.72107032:36
89033.50683610564.5283200.75518432:44
98694.73242210569.3300780.75543730:29
" 589 | ], 590 | "text/plain": [ 591 | "" 592 | ] 593 | }, 594 | "metadata": {}, 595 | "output_type": "display_data" 596 | } 597 | ], 598 | "source": [ 599 | "learner.fit_one_cycle(10, max_lr=3e-04)" 600 | ] 601 | }, 602 | { 603 | "cell_type": "code", 604 | "execution_count": 26, 605 | "metadata": {}, 606 | "outputs": [ 607 | { 608 | "data": { 609 | "text/plain": [ 610 | "PosixPath('/home/ubuntu/MultiCampus/MICCAI_BraTS_2019_Data_Training/models/trained_model_1cycle.pth')" 611 | ] 612 | }, 613 | "execution_count": 26, 614 | "metadata": {}, 615 | "output_type": "execute_result" 616 | } 617 | ], 618 | "source": [ 619 | "learner.save(\"trained_model_1cycle\", return_path=True)" 620 | ] 621 | }, 622 | { 623 | "cell_type": "code", 624 | "execution_count": 18, 625 | "metadata": {}, 626 | "outputs": [], 627 | "source": [ 628 | "learner = learner.load(\"trained_model_1cycle\", device=1)" 629 | ] 630 | }, 631 | { 632 | "cell_type": "code", 633 | "execution_count": null, 634 | "metadata": {}, 635 | "outputs": [ 636 | { 637 | "data": { 638 | "text/html": [ 639 | "\n", 640 | "
\n", 641 | " \n", 653 | " \n", 654 | " 80.00% [8/10 4:10:44<1:02:41]\n", 655 | "
\n", 656 | " \n", 657 | "\n", 658 | " \n", 659 | " \n", 660 | " \n", 661 | " \n", 662 | " \n", 663 | " \n", 664 | " \n", 665 | " \n", 666 | " \n", 667 | " \n", 668 | " \n", 669 | " \n", 670 | " \n", 671 | " \n", 672 | " \n", 673 | " \n", 674 | " \n", 675 | " \n", 676 | " \n", 677 | " \n", 678 | " \n", 679 | " \n", 680 | " \n", 681 | " \n", 682 | " \n", 683 | " \n", 684 | " \n", 685 | " \n", 686 | " \n", 687 | " \n", 688 | " \n", 689 | " \n", 690 | " \n", 691 | " \n", 692 | " \n", 693 | " \n", 694 | " \n", 695 | " \n", 696 | " \n", 697 | " \n", 698 | " \n", 699 | " \n", 700 | " \n", 701 | " \n", 702 | " \n", 703 | " \n", 704 | " \n", 705 | " \n", 706 | " \n", 707 | " \n", 708 | " \n", 709 | " \n", 710 | " \n", 711 | " \n", 712 | " \n", 713 | " \n", 714 | " \n", 715 | " \n", 716 | " \n", 717 | " \n", 718 | " \n", 719 | " \n", 720 | " \n", 721 | " \n", 722 | " \n", 723 | " \n", 724 | " \n", 725 | "
epochtrain_lossvalid_lossdice_coefficienttime
010392.3994149580.9570310.73688031:50
110042.61035210256.2548830.70132030:34
210119.89843810275.0361330.68950131:14
310711.19921910155.2695310.58403531:59
49853.23535210081.0664060.67310931:24
59889.5058599921.9316410.68896431:13
68726.8320319781.7431640.70829031:18
78652.9892589700.1904300.73205431:08

\n", 726 | "\n", 727 | "

\n", 728 | " \n", 740 | " \n", 741 | " 76.39% [220/288 22:29<06:57 8392.4023]\n", 742 | "
\n", 743 | " " 744 | ], 745 | "text/plain": [ 746 | "" 747 | ] 748 | }, 749 | "metadata": {}, 750 | "output_type": "display_data" 751 | } 752 | ], 753 | "source": [ 754 | "learner.fit_one_cycle(10, max_lr=3e-04)" 755 | ] 756 | } 757 | ], 758 | "metadata": { 759 | "kernelspec": { 760 | "display_name": "Python 3", 761 | "language": "python", 762 | "name": "python3" 763 | }, 764 | "language_info": { 765 | "codemirror_mode": { 766 | "name": "ipython", 767 | "version": 3 768 | }, 769 | "file_extension": ".py", 770 | "mimetype": "text/x-python", 771 | "name": "python", 772 | "nbconvert_exporter": "python", 773 | "pygments_lexer": "ipython3", 774 | "version": "3.6.5" 775 | } 776 | }, 777 | "nbformat": 4, 778 | "nbformat_minor": 4 779 | } 780 | -------------------------------------------------------------------------------- /ndv/training.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "47 items written into valid.txt.\n" 13 | ] 14 | } 15 | ], 16 | "source": [ 17 | "import torch, fastai, sys, os\n", 18 | "from fastai.vision import *\n", 19 | "import ants\n", 20 | "from ants.core.ants_image import ANTsImage\n", 21 | "from jupyterthemes import jtplot\n", 22 | "sys.path.insert(0, './exp')\n", 23 | "jtplot.style(theme='gruvboxd')\n", 24 | "\n", 25 | "import model\n", 26 | "from model import SoftDiceLoss, KLDivergence, L2Loss\n", 27 | "import dataloader \n", 28 | "from dataloader import data" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 2, 34 | "metadata": {}, 35 | "outputs": [ 36 | { 37 | "data": { 38 | "text/plain": [ 39 | "2" 40 | ] 41 | }, 42 | "execution_count": 2, 43 | "metadata": {}, 44 | "output_type": "execute_result" 45 | } 46 | ], 47 | "source": [ 48 | "torch.cuda.set_device(2)\n", 49 | 
"torch.cuda.current_device()" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 3, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "autounet = model.autounet.cuda()\n", 59 | "sdl = SoftDiceLoss()\n", 60 | "kld = KLDivergence()\n", 61 | "l2l = L2Loss()" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 4, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "class AutoUNetCallback(LearnerCallback):\n", 71 | " \"Custom callback for implementing `AutoUNet` training loop\"\n", 72 | " _order=0\n", 73 | " \n", 74 | " def __init__(self, learn:Learner):\n", 75 | " super().__init__(learn)\n", 76 | " \n", 77 | " def on_batch_begin(self, last_input:Tensor, last_target:Tensor, **kwargs):\n", 78 | " \"Store the states to be later used to calculate the loss\"\n", 79 | " self.top_y, self.bottom_y = last_target.data, last_input.data\n", 80 | " \n", 81 | " def on_loss_begin(self, last_output:Tuple[Tensor,Tensor], **kwargs):\n", 82 | " \"Store the states to be later used to calculate the loss\"\n", 83 | " self.top_res, self.bottom_res = last_output\n", 84 | " self.z_mean, self.z_log_var = model.hooks.stored[3], model.hooks.stored[4]\n", 85 | " return {'last_output': (self.top_res, self.bottom_res,\n", 86 | " self.z_mean, self.z_log_var,\n", 87 | " self.top_y, self.bottom_y)}" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 5, 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "class AutoUNetLoss(nn.Module):\n", 97 | " \"Combining all the loss functions defined for `AutoUNet`\"\n", 98 | " def __init__(self):\n", 99 | " super().__init__()\n", 100 | " \n", 101 | " def forward(self, top_res, bottom_res, z_mean, z_log_var, top_y, bottom_y):\n", 102 | " return sdl(top_res, top_y) + (0.1 * kld(z_mean, z_log_var)) + (0.1 * l2l(bottom_res, bottom_y))" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 6, 108 | "metadata": { 109 | "code_folding": [ 110 | 1 111 | ]
112 | }, 113 | "outputs": [], 114 | "source": [ 115 | "#monkey-patch\n", 116 | "def mp_loss_batch(model:nn.Module, xb:Tensor, yb:Tensor, loss_func:OptLossFunc=None, opt:OptOptimizer=None,\n", 117 | " cb_handler:Optional[CallbackHandler]=None)->Tuple[Union[Tensor,int,float,str]]:\n", 118 | " \"Calculate loss and metrics for a batch, call out to callbacks as necessary.\"\n", 119 | " cb_handler = ifnone(cb_handler, CallbackHandler())\n", 120 | " if not is_listy(xb): xb = [xb]\n", 121 | " if not is_listy(yb): yb = [yb]\n", 122 | " out = model(*xb)\n", 123 | " out = cb_handler.on_loss_begin(out)\n", 124 | "\n", 125 | " if not loss_func: return to_detach(out), to_detach(yb[0])\n", 126 | " loss = loss_func(*out) #modified\n", 127 | "\n", 128 | " if opt is not None:\n", 129 | " loss,skip_bwd = cb_handler.on_backward_begin(loss)\n", 130 | " if not skip_bwd: loss.backward()\n", 131 | " if not cb_handler.on_backward_end(): opt.step()\n", 132 | " if not cb_handler.on_step_end(): opt.zero_grad()\n", 133 | "\n", 134 | " return loss.detach().cpu()" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 7, 140 | "metadata": { 141 | "code_folding": [ 142 | 1 143 | ] 144 | }, 145 | "outputs": [], 146 | "source": [ 147 | "#monkey-patch\n", 148 | "def mp_fit(epochs:int, learn:Learner, callbacks:Optional[CallbackList]=None, metrics:OptMetrics=None)->None:\n", 149 | " \"Fit the `model` on `data` and learn using `loss_func` and `opt`.\"\n", 150 | " assert len(learn.data.train_dl) != 0, f\"\"\"Your training dataloader is empty, can't train a model.\n", 151 | " Use a smaller batch size (batch size={learn.data.train_dl.batch_size} for {len(learn.data.train_dl.dataset)} elements).\"\"\"\n", 152 | " cb_handler = CallbackHandler(callbacks, metrics)\n", 153 | " pbar = master_bar(range(epochs))\n", 154 | " cb_handler.on_train_begin(epochs, pbar=pbar, metrics=metrics)\n", 155 | "\n", 156 | " exception=False\n", 157 | " try:\n", 158 | " for epoch in pbar:\n", 159 | " 
learn.model.train()\n", 160 | " cb_handler.set_dl(learn.data.train_dl)\n", 161 | " cb_handler.on_epoch_begin()\n", 162 | " for xb,yb in progress_bar(learn.data.train_dl, parent=pbar):\n", 163 | " xb, yb = cb_handler.on_batch_begin(xb, yb)\n", 164 | " loss = loss_batch(learn.model, xb, yb, learn.loss_func, learn.opt, cb_handler) #modified\n", 165 | " if cb_handler.on_batch_end(loss): break\n", 166 | "\n", 167 | " if not cb_handler.skip_validate and not learn.data.empty_val:\n", 168 | " val_loss = validate(learn.model, learn.data.valid_dl, loss_func=learn.loss_func,\n", 169 | " cb_handler=cb_handler, pbar=pbar)\n", 170 | " else: val_loss=None\n", 171 | " if cb_handler.on_epoch_end(val_loss): break\n", 172 | " except Exception as e:\n", 173 | " exception = e\n", 174 | " raise\n", 175 | " finally: cb_handler.on_train_end(exception)\n" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": 8, 181 | "metadata": { 182 | "code_folding": [ 183 | 1 184 | ] 185 | }, 186 | "outputs": [], 187 | "source": [ 188 | " #monkey-patch\n", 189 | "def mp_learner_fit(self, epochs:int, lr:Union[Floats,slice]=defaults.lr,\n", 190 | " wd:Floats=None, callbacks:Collection[Callback]=None)->None:\n", 191 | " \"Fit the model on this learner with `lr` learning rate, `wd` weight decay for `epochs` with `callbacks`.\"\n", 192 | " lr = self.lr_range(lr)\n", 193 | " if wd is None: wd = self.wd\n", 194 | " if not getattr(self, 'opt', False): self.create_opt(lr, wd)\n", 195 | " else: self.opt.lr,self.opt.wd = lr,wd\n", 196 | " callbacks = [cb(self) for cb in self.callback_fns + listify(defaults.extra_callback_fns)] + listify(callbacks)\n", 197 | " fit(epochs, self, metrics=self.metrics, callbacks=self.callbacks+callbacks)" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": 9, 203 | "metadata": { 204 | "code_folding": [ 205 | 1 206 | ] 207 | }, 208 | "outputs": [], 209 | "source": [ 210 | "#monkey-patch\n", 211 | "def mp_validate(model:nn.Module, 
dl:DataLoader, loss_func:OptLossFunc=None, cb_handler:Optional[CallbackHandler]=None,\n", 212 | " pbar:Optional[PBar]=None, average=True, n_batch:Optional[int]=None)->Iterator[Tuple[Union[Tensor,int],...]]:\n", 213 | " \"Calculate `loss_func` of `model` on `dl` in evaluation mode.\"\n", 214 | " model.eval()\n", 215 | " with torch.no_grad():\n", 216 | " val_losses,nums = [],[]\n", 217 | " if cb_handler: cb_handler.set_dl(dl)\n", 218 | " for xb,yb in progress_bar(dl, parent=pbar, leave=(pbar is not None)):\n", 219 | " if cb_handler: xb, yb = cb_handler.on_batch_begin(xb, yb, train=False)\n", 220 | " val_loss = loss_batch(model, xb, yb, loss_func, cb_handler=cb_handler) #modified\n", 221 | " val_losses.append(val_loss)\n", 222 | " if not is_listy(yb): yb = [yb]\n", 223 | " nums.append(first_el(yb).shape[0])\n", 224 | " if cb_handler and cb_handler.on_batch_end(val_losses[-1]): break\n", 225 | " if n_batch and (len(nums)>=n_batch): break\n", 226 | " nums = np.array(nums, dtype=np.float32)\n", 227 | " if average: return (to_np(torch.stack(val_losses)) * nums).sum() / nums.sum()\n", 228 | " else: return val_losses" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": 10, 234 | "metadata": { 235 | "code_folding": [ 236 | 1 237 | ] 238 | }, 239 | "outputs": [], 240 | "source": [ 241 | "#monkey-patch\n", 242 | "def mp_learner_validate(self, dl=None, callbacks=None, metrics=None):\n", 243 | " \"Validate on `dl` with potential `callbacks` and `metrics`.\"\n", 244 | " dl = ifnone(dl, self.data.valid_dl)\n", 245 | " metrics = ifnone(metrics, self.metrics)\n", 246 | " cb_handler = CallbackHandler(self.callbacks + ifnone(callbacks, []), metrics)\n", 247 | " cb_handler.on_train_begin(1, None, metrics); cb_handler.on_epoch_begin()\n", 248 | " val_metrics = validate(self.model, dl, self.loss_func, cb_handler)\n", 249 | " cb_handler.on_epoch_end(val_metrics)\n", 250 | " return cb_handler.state_dict['last_metrics']" 251 | ] 252 | }, 253 | { 254 | "cell_type": 
"code", 255 | "execution_count": 11, 256 | "metadata": {}, 257 | "outputs": [], 258 | "source": [ 259 | "from fastai.basic_train import loss_batch, fit, validate" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": 12, 265 | "metadata": {}, 266 | "outputs": [], 267 | "source": [ 268 | "loss_batch = mp_loss_batch\n", 269 | "fit = mp_fit\n", 270 | "validate = mp_validate\n", 271 | "Learner.fit = mp_learner_fit\n", 272 | "Learner.validate = mp_learner_validate" 273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": 13, 278 | "metadata": {}, 279 | "outputs": [], 280 | "source": [ 281 | "def dice_coefficient(last_output:Tensor, last_target:Tensor):\n", 282 | " \"Metric based on dice coefficient\"\n", 283 | " pred, targ = last_output[0], last_target\n", 284 | " return 2 * (pred * targ).sum() / ((pred**2).sum() + (targ**2).sum())" 285 | ] 286 | }, 287 | { 288 | "cell_type": "code", 289 | "execution_count": 14, 290 | "metadata": {}, 291 | "outputs": [], 292 | "source": [ 293 | "auto_unet_loss = AutoUNetLoss()" 294 | ] 295 | }, 296 | { 297 | "cell_type": "code", 298 | "execution_count": 15, 299 | "metadata": {}, 300 | "outputs": [], 301 | "source": [ 302 | "learner = Learner(data, autounet, loss_func=auto_unet_loss)" 303 | ] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "execution_count": 16, 308 | "metadata": {}, 309 | "outputs": [], 310 | "source": [ 311 | "autounet_cb = AutoUNetCallback(learner)" 312 | ] 313 | }, 314 | { 315 | "cell_type": "code", 316 | "execution_count": 17, 317 | "metadata": {}, 318 | "outputs": [], 319 | "source": [ 320 | "learner.callbacks.append(autounet_cb)" 321 | ] 322 | }, 323 | { 324 | "cell_type": "code", 325 | "execution_count": 17, 326 | "metadata": {}, 327 | "outputs": [ 328 | { 329 | "data": { 330 | "text/html": [ 331 | "\n", 332 | "
\n", 333 | " \n", 345 | " \n", 346 | " 0.00% [0/1 00:00<00:00]\n", 347 | "
\n", 348 | " \n", 349 | "\n", 350 | " \n", 351 | " \n", 352 | " \n", 353 | " \n", 354 | " \n", 355 | " \n", 356 | " \n", 357 | " \n", 358 | " \n", 359 | " \n", 360 | "
epoch  train_loss  valid_loss  time

\n", 361 | "\n", 362 | "

\n", 363 | " \n", 375 | " \n", 376 | " 20.14% [58/288 05:56<23:34 176342.1875]\n", 377 | "
\n", 378 | " " 379 | ], 380 | "text/plain": [ 381 | "" 382 | ] 383 | }, 384 | "metadata": {}, 385 | "output_type": "display_data" 386 | }, 387 | { 388 | "name": "stdout", 389 | "output_type": "stream", 390 | "text": [ 391 | "LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n" 392 | ] 393 | } 394 | ], 395 | "source": [ 396 | "learner.lr_find()" 397 | ] 398 | }, 399 | { 400 | "cell_type": "code", 401 | "execution_count": 18, 402 | "metadata": {}, 403 | "outputs": [ 404 | { 405 | "data": { 406 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaQAAAEMCAYAAACLA8K2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3XmYFNW5+PFvb9PLMMywqCiCKIJJqSyiiJQLZru52ZOb3OQmsRLjFqNEgyGJSuISRNSIid7kl7jEpIxZ1GuMWTQuiVsBiiKIFKtsA8g6zML03l2/P6ownZEZGKarq7rn/TxPPdNTp7rr7UMzb59Tp84JWJaFEEII4bWg1wEIIYQQIAlJCCGET0hCEkII4QuSkIQQQviCJCQhhBC+EPY6gGplGnoAOBpo9zoWIYSoIgOBzYqqvWuItySkQ3c0sMnrIIQQogqNBJq77pSEdOjaAS6e+b+k0llPAggAgxob2NPWgdxNVn5Sv+6TOnaX3+o3Hqvj7tsuh256liQh9VEqnSWZynhy7gAQj0ZJpjK++LDVGqlf90kdu6va6lcGNQghhPAFSUhCCCF8QRKSEEIIX6joNSTT0C8DvgKMAxYqqjatpOxG4JOAAjyoqNpXuzz3SOAe4FygFfiRomp3VKpcCCGEuyrdQnobmAvs7w/9OuAa4LfdPPdBYCdwBPAx4FrT0D9RwXIhhBAuqmgLSVG1RwFMQx+5n7JfOWUfBJpKy0xDPw44BzhKUbW9wOumod8DXAA87nZ5ueuBUB0UvBkqLoQQflUtw77HAVsVVdtesm8x8IUKlXcr4Gy9EZtwCVY+RWbZr6GY6+Wz//3cpT9FeUn9uk/q2F1+q98DxVEtCakBaOuyr9XZX4nybg1qbCAejR7osH9T3PoE+fdcSN2Zs4isvI9AtrVXz+9qcNPAPj1f9Ezq131Sx+7yS/3GYnU9lldLQuoAGrvsa3L2V6K8W3vaOnp/Y2xrO+y8gdiEiymcPIP04p9S3LOmd6+B/W1jcNNAWlrbq+Kmt2oj9es+qWN3+a1+E/Gev7xXy7DvN4CjTEM/vGTfRGBZhcq7ZR3qlk+RevVOcs0vED99JuGR0w7pdfoUg2xSvz7YpI77V/32pNLDvsPOOcNA0DT0GFBUVC1rGnoECO3bnLKComo5RdXWmYb+PHCzaejfBMYAF2EPOsDtcvdYZFc/SrF9E7FxFxAceAyZ5b8Bq+DuaYUQwocq3UKaBaSA24CznMdPOWX3OL9fDnzZeXxPyXO/BAwDdgB/A+YoqvZ4Bctdk9/2KskFNxEeqpA4/TsE6vzR3yuEEJUUsKwDNaLE/piG3gi0njd9XvkmV43UE594KcHE4SQXzcPq3Nbj4QFgSNNAdvukf7jWSP26T+rYXX6r30Q8ygN3zQBoUlSt60CyqrmG1D/kOkktuoPCntUkzriaYOOx
XkckhBAVIwnJb6wC6aX3kd9skDj9O4QOO9nriIQQoiIkIfmSRWblQ2TWPEZ80nTCw6d6HZAQQriuWu5D6pdy6/+OlWkjNu5rZOoGklv/pNchCSGEayQh+Vx+60JS2b3ET/kGwVgjmRUPceDR/EIIUX2ky64KFHa9SfLlWwkfNZXY+AshEPI6JCGEKDtJSFWi2LaB5MI5hAaPJT7pcghGvA5JCCHKShJSFbE6t5NcMIdA4nDip82AcNzrkIQQomwkIVUZK72H1MK5BMJR4qfPxArXex2SEEKUhSSkKmRlO0i+fBvkM+RO+iaB2CCvQxJCiD6ThFSt8ilSi+YRSO8ifsbVBOqP8DoiIYToE0lI1ayYI7zqPgota0hMuZrgwHetDC+EEFVDElKVC1hFMkvvJb/tVRKnzyTYOMrrkIQQ4pBIQqoJFpnlvyG3eT6Jyd8m2Hic1wEJIUSvSUKqIZkVvyPX/CKJyTMINo32OhwhhOgVSUg1JrPyD+Q2PU/iNElKQojqIgmpBmVWPUx20z9ITL6K0KAxXocjhBAHRRJSjcqu+j+yG54mftq3CA0a63U4QghxQJKQalh29R/Jrv+7nZQGv8frcIQQokeSkGpcds2fyK77G/HTvkV42GlehyOEEN2S9ZD6gezaP2Nl2olNuIjs6iFk18lCf0II/5GE1E/kmp+nmG4hPvFSAvHDyJgPglX0OiwhhHiHdNn1I4Wdy0guuJnwEROIT5oOoajXIQkhxDsq2kIyDf0y4CvAOGChomrTSsoGAD8HPgFkgPuAqxVVs/xQXiuKHc0k588mfuqVJKZ8j9SrP8HKtHodlhBCVLyF9DYwF7hjP2V3AocBxwCTgc8C031UXjOs9B6SC+diZTtITL2W4IDhXockhBCVTUiKqj2qqNqjwPbS/aahJ4AvAtcqqrZHUbX1wG3ABX4o70nA4+2QY8inSL/6Ewq7TBJnXE1o0FjP34sfNz/8G9f6JnXcv+q3J34Z1DAWqAOWlOxbDJxoGnrI63JF1QrdBT6osYF41NtrMYObBh7yc63m/6NAksTpVxFerRNqWVbGyGpDX+pXHBypY3f5pX5jsboey/2SkBqApKJq+ZJ9rUAIiPugfG93ge9p6yCZyvTirZZPAPuD1tLaTp8udLX+gUj7Tqz3foW9y39DvvmFMkVY3cpWv6JbUsfu8lv9JuI9f3n3yyi7DiBhGnppgmwCCkDKB+XdsjzeyhVDduM/SC+5h6jyJSKjP+b5+/LL5od/41rfpI77V/32xC8JaTWQBcaX7JsImE53mdfl/UJ+2yJSr/6YutEfIap8kQP3+AohRPlUNCGZhh42DT2G3VUYNA09Zhp6naJqSeC3wGzT0JtMQx8FfBu4F8Dr8v6ksHsFyYVzCR95GrEJl0DQL726QohaV+kW0izsLrDbgLOcx085ZVcAu4FNwKvAo8BdJc/1urzfKLZvIrngZkKNo4ifeqXcQCuEqIiAZR2oV0/sj2nojUDredPneTqoYUjTQHa7dMEyUDeQ+OSroJAhuegOyPd4Oa3muF2/QurYbX6r30Q8ygN3zQBoUlStrWu5X64hCR+ysu0kX74VAkESp3+HQF2D1yEJIWqYJCTRs1wnyZd/hJVPEz/9uwSiTV5HJISoUZKQxIEV0qQW3YGV2k3ijKsJxId6HZEQogZJQhIHp5gltfguCu0bSUz5HsH6YV5HJISoMZKQxMEr5km//nMKu1cSn/I9gg0jvI5ICFFDJCGJ3rGKpN+4j/y2V0lM+S7BptFeRySEqBGSkMQhsMgs/w3ZTc+RmPxtQkNP9DogIUQNkIQkDll21SNk1z5OfNI3CQ+b5HU4QogqJwlJ9El23RNkzN8Sm3AJkaPP8jocIUQVk4nKRJ/lmp/HyieJjb8IIgly6//udUhCiCokCUmURf7tRaTyaeKnfINAJEF29R+9DkkIUWWky06UTWHnMlKvzKPumPcTPfHLyPIVQojekIQkyqqwZw3Jl28l
PGwSsYmXQjDidUhCiCohCUmUXbF9E8n5cwg2DCdx+kwCkQFehySEqAKSkIQrrNROkgvmYFlFElOvJZA43OuQhBA+JwlJuCfXSeqVH1Fo20jijGsJNh7ndURCCB+ThCTcVcyTXvIL8lteIjFlJuEjJnodkRDCpyQhiQqwyKx8mMzKh4lNvJTIMe/3OiAhhA/JfUiiYnIb/4GVaiE28RJCg0aTXv4g5Dq9DksI0RuBkP3TKpT9paWFJCoqv2MJyZduJJg4gvqzbpSJWYWoMuFhk6g/+yZXXlsSkqi4YufbJBfMIbfpeeKnXmHfRBuq8zosIcRBCEQSWPmkK68tCUl4wyqQXfs4yQVzCA15D/Vn3iBrKwlRBQKRBJZLXe2+uoZkGvpo4E7gDCAH/BK4VlG1omnoYeB24DzsRPoIcJmiahnnua6WC3cU2zaQfOkGoid8lsSU75Jd9yTZNX9ypX9aCNF3gXA9Vq7GW0imoYeAx4E3gSOBU4GPADOdQ64BzgFOAsYAJwI3l7yE2+XCLcUcmRW/I7VoHpHhZ5A442oC8cO8jkoIsT8utpB8k5CAE5ztB4qqZRRVawbuAL7ulF8IzFZUbauiajuB64HzTUMPVqhcuKyweyWdL11PMd1K/ZnXER52qtchCSG6CETcayH5qctuf1NDB4BRpqGPBEYAr5eULQaagBGmobe5WQ5s7Clor+a0DnT5WRNynWQW/y/FY95HbPxF5IcqZMzfQzFb8VBqsn59RurYXW7UbyCSgFznIb3mgZ7jp4S0CngL+KFp6NcBRwBXOGWW87Ot5PhW52cDUHS5vFuDGhuIR6M9HeK6wU0DPT2/K9pepbjsbYInfIW6s68jvOpXBFPbPQmlJuvXZ6SO3VXO+s3GGohFLEKH8JqxWM+jaX2TkBRVy5uG/nHgx9gtkhbgPmAc/0oYjcAu53GT87PD2dws79aetg6SKW/GPQSwP2gtre3vZOya0toO268jeuKXKY6bQWb5b8lvfrFip6/5+vUBqWN3uVG/iWCMVNtu8q3tvX9uvOcv7766PqKo2mpF1T6iqNrhiqq9B0gCixRV2wI0AxNKDp+I3YppVlSt1c3ynmK2PN78EIOrWyFD+o37SL/5AFHlf4ic8Fmp3xrbpI6rq34DkQTFfGef4umOb1pIAKahjwPWAWngXGAWoDnF9wLXmoa+AHtI+PXA/YqqFStULjyU3zKfZOd2EpO/Dbm9ZNc96XVIQvRDAQjH+8WgBoDPAd8AYsBK4EJF1Z52yuYAQ4Hl2C27h7GHalOhcuGxYutbpBb/L/FJ38TKJck1v+B1SEL0L+E4gUDQtWHfAcs6UCNK7I9p6I1A63nT53l6DWlI00B297P+9/Cw04hNuMhe1mLba66dp7/WbyVJHbur3PUbiB/GgHNvoePp6Yc0MXIiHuWBu2YANCmq1ta13FfXkIQ4GPlti8gsf5DY+IsJDVW8DkeIfiMQSdgPan2mBiF6I9f8PNk1fyJ+ynSCTbISrRCV8K+bYt1pz0pCElUru+5v5Db+g8SpVxIcMNzrcISoefbEqu60jkASkqhymVUPk9u+mPjkqwgkDvc6HCFqmt1Ccm9RTUlIoupllv2aQssq6qfOIjT4BK/DEaJmBSLuDfkGSUiiJlikl9xNduOzxCdfRWTEOV4HJERtitRj5d1rIfntPiQhDpFFds2fKO7dSmzcBQQbhpNZ8Xuw5L5mIcrF7WtIkpBETcm/vYhkcifxSdMJ1g8j9frPwaXlloXobwLheqz0btdeX7rsRM0ptm0gafyQQCRB/dRZBOqP8DokIWqCjLIT4hBYmVaSC2+h0Lae+qmzCB95mtchCVH13FycDyQhiVpWzJFeeg+ZtX8hdvL5JGQUnhB9IteQhOij3Pq/k9+ygLrjP0Z88lUUdplkVj1CsWOz16EJUVXkPiQhysDKtpMxf0vnC7Ow8ikSZ15PbNzXCMQGex2aENVDriEJUT5WcgfpJb8gOX82gdhg
6s+5mbrjPuJ1WEL4XzhmLz0h9yEJUV7Ftg2kXvkRocPHEx9/MYH4YDLLH8StSSOFqHaBcL39wM8tJNPQB5QjECG8UNixlOTLtxAeNonYxEshKN/RhNiffUtP+KbLzjT0q0xD/3zJ7zrQZhr6etPQ31v26ISogGL7JpLz5xAaOIL4aTMgHPc6JCF8JxCpx8qnwSq4do7etpAuBbYBmIZ+FvAZ4IvAIuDW8oYmROVYqZ0kF8whEIqSmPI9AtEmr0MSwlfcHvINvU9IRwHrnccfBR5RVO0PwA3AlHIGJkSlWdkOki/fhpVpI3HGNTLDgxAl3B7yDb1PSJ1Ao/N4GvBP53EKSJQpJiG8U0iTevUnFPasITHlGooNshqtEIDrQ76h96PsngduNw39JWAi8KSz/wSguZyBCeEZq0B66b1ET/gvAiddTuyIZWTW/Ili2wavIxPCM3aXnb9aSFcAaexrR5coqrbd2f8R4JlyBiaEtyyyqx4h8vpcrFwniamziJ96BcHGY70OTAhPBCL1rs+c36sWkqJqW4BP7Gf/9HIEYxr6UcBdwDlAAHgJuExRtc2moYeB24HzsBPpI05Zxnmuq+Wifwqmd5BZei+ZtX8mOvrjJKZeS2HncjJr/0SxdZ3X4QlRMYFIAivT5uo5ejvsO2oaerTk9+GmoV9mGvq0MsXzMyACHAuMwL5mdY9Tdg12ojoJGAOcCNxc8ly3y0U/ZnVuJ/3GvXS+cC1Wtp3ElKuJTfwGhGNehyZERQTC7s70Db3vsnsMuATeuSH2FWA28LRp6F8tQzzHAX9QVK1DUbUk8CBwslN2ITBbUbWtiqrtBK4HzjcNPVihciGcxHQfnS/OIlh/uD0aLz7U67CEcJ0fh31PAp5zHn8K6ACOwE5SM8oQzzzgv01DbzINvQG7++yvpqE3YbeYXi85djHQBIxwu7yngAMeb36IoZa37uqXzu2kFtyM1bmD+qnfJzRojOexVuvWXR3L5rP6jdRDrrMs8XSnt6PsBgItzuP3A48pqpY1Df0Z4M5evtb+GMAFzjksYKlznganvLQDs9X52QAUXS7v1qDGBuLRaE+HuG5w00BPz1/reqpfa90DFEb+J4nTZxJe9zChHS9XMLLaIZ9hd5WjfjPRegZEIdSH14rF6nos721C2gKMMw39beBDwPnO/iagTxf/na6xp4GHgP9wdl8PPAF82Pm9EdhVck6wW2kdLpd3a09bB8mUN+MeAtgftJbWdpkS1AUHXb+tvye8awPRk88nGRxMduVDyCStB0c+w+4qZ/3Wh+K079lJsbX9kF8jEe/5y3tvu+zuA34HLMdOQPtujJ0MrOxtcF0MBo4B7lRULelcQ7oLOB07cTYDE0qOn4jdimlWVK3VzfKegrY83vwQQy1vB1u/ua0LSS68hfDwKcROvQIrHPc89mrZDraOZfOwfkN1BIJhrFyyLPF0p7fDvueYhr4SGAk8pKhazikqAj/qzWvt57V3mYa+FrjMNPQbnN3Tgc1O2b3AtaahLwBy2K2n+xVV29fd5na5ED0qtq0jadxIfNI3SUydRfr1n1PskPvFRfULROylJ9y+MbbXc+0rqvbofvb9sjzh8EngDuyuwQCwhH/d9zQHGIrdOgsCD2MP1aZC5UIckJXeQ3LhzUTf+z8kps4is+oRchue4cDfDYXwr3eWnnD5xtiAZfXuP4pp6GOB72Dfp2Nh/wG/VVG1NeUPz79MQ28EWs+bPs/Ta0hDmgayW/rfXdHX+g0Pm0TspK9QaFtPeul9WNlD73uvVfIZdle56jc0aCzxyTPY+/ev9ymeRDzKA3fNAGhSVO1dd9n29sbYDwLLsK+vLMS+D+kUYJlp6O/vU6RC1Jj8ttfofOk6CEZInHUDocNOPvCThPChStyDBL3vspsD/D9F1a4s3Wka+k+wZzWYXK7AhKgFVnoPqZdvo270R4lPmk5u4z/JrHoYinmvQxPi4FVg6QnofUI6Cfjyfvb/DLi47+EI
UYsssm/9hfxuk/j4iwkNPZHs2sfJv70IubYkqkGlWki9Hfbdwf5nLjgGkA5yIXpQbF1Hp3E9+W2vETvpPOrPmUPk6LMg2OuxRUJUVCCSAB+2kP4I3G0a+qXAi86+s7FbSO8afSeE6CKfJrvmj2TXP0HdyGnUnfAZ6sZ8kuz6v5Nrfh4KWa8jFOJd7NVi/XcN6SrgfuzZE0r7Gh4GZpYrKCFqXj5Ndt2TZDc8S+ToM6k77sPUHf8xcuufJrv+SbnGJHwlEElgZX3WQlJUbS/wOdPQRwOKs3s59pIRr5TsE0IcjGKO3KZ/kmt+gfCRk4mO/RSRo6eSXvYrCi2rvY5OCMBuIRWTO10/zyF1Xiuq9hbw1r7fTUMfj72MuRDiUFgF8lsXkN/+GtExnyI+eSa5zS+RWfmw66t0CnEggbA/BzUIIdxUyJJZ+RDJ+bMJNY6i/uzZhIdN8joq0d9VaNi3JCQhfKjYvpHk/NnkNjxFbPxFxCZNJxAb5HVYop/y67BvIUSlWAWy656k88UfEAhFqT9rNuGjpngdleiH9i3O57aDuoZkGvpTBzhkQBliEULsh5XcQeqVHxEZOY3YyV8hf9hJpJf/BvJpr0MT/UEwTCAU8dWw7y0HccyqvgQihOhZbtNzFFpWE5twCfVn3kBqyS8otq7zOixR4yq19AQcZEJSVO38Ax8lhHBbce9WkvN/SPSEz5KY8j2yax4n+9ZfkSmIhFv+lZD800ISQvhFMU9mxe/J71pObNzXCA09kfTSu7HSe7yOTNSgQCSBVcxD0f1ZRGRQgxBVqrBzGckXr4NClvqzbiR02DivQxK1qEJDvkESkhBVzcq2k3r1x2Tf+ivxSZdTd/wnsJdlE6I8KnVTLEiXnRA1wCK77kkKbRuJTfg6ocZRpJbeA/mU14GJGlCpId8gLSQhakZh9wqSxg0Eoo3Uqz8gOGC41yGJGlCpm2JBEpIQNcVKt5BceDOFltUkps4ifORpXockqpydkCrTQpIuOyFqTTFPetn9RFrXERt3IbnGY8msegSsoteRiSpUyRaSJCQhalSu+XkKHc3ET7mcQGwQ6aX3SFISvRepr9gtBdJlJ0QNK7auI7lwLqFBY4hN+DoEQl6HJKpMv2whmYa+t8uuKLBCUbVxTnkYuB04DzuRPgJcpqhaphLlQlQrK7mD5MJbSEyZSWzipaSX/FxWpBUHLdAf70NSVG1A6QasAH5fcsg1wDnAScAY4ETg5gqWC1G1rNROkgtuITRwBPFTLoOgb76LCp8LhBPQn0fZmYY+GXs59F+V7L4QmK2o2lZF1XYC1wPnm4YerFD5fgU83vwQQy1vtVS/pHeTWjiXYP0w4pO+SSAY8TymWqtjP259rl/nPqRyxtMdv35NugB4QlG1rQCmoTcBI4DXS45ZDDQBI0xDb3OzHNjYXaCDGhuIR6OH8h7LZnDTQE/PX+tqq34LWOZPyZ14GZEpM4isvJdAMed1UDVWx/5zqPVrBUJkw1EGxgIEy/BvFIvV9Vjuu4RkGnoC+AKglexucH62lexrLSkrulzerT1tHSRT3lxmCmB/0Fpa22WuZxfUbv22E5h/M7HTZ5IfcwGpV38CBfkM16K+1m+groF6oLVlB1a6vc/xJOI9f3n3Y5fdfwNJ4K8l+zqcn40l+5pKytwu75bl8eaHGGp5q9X6LWbbSb18K0TqiZ8+EytSL3Vco1uf6tdZeqKYS5Y1nu74MSFdCPxaUbV3hgEpqtYKNAMTSo6biN2KaXa7vEzvSwhfsbIdJBfeAsUCiSnfIxBtOvCTRL9iLz1RgEJlVif2VZedaegnAFOBr+2n+F7gWtPQFwA57EEH9yuqVqxQuRC1J58kueh24hMvI3HGNSQX3Y7Vud3rqIRPBCL1WPnKjLAD/7WQLgBeVFRt9X7K5gAvAcuBtYCJPVS7UuVC1KZCltRrd1JofYvElKsJDhzpdUTCJwKRyg35BghY1oF69cT+mIbeCLSeN32ep4Ma
hjQNZLdcEHZF/6vfANETv0TkqCmkXr2Twp79fS8s9xn7Wx1XVl/rN3LM+4gMn0py/uyyxJOIR3ngrhkATYqqtXUt91sLSQjhGYvM8t+Q3fAM8ckzCB0+3uuAhMcquTgfSEISQnSRXfMYmVWPED/lMsJHn+l1OMJDlZw2CHw2qEEI4Q+5Dc9gZTuInfw1svVHkF31KAcetCtqTgUnVgVJSEKIbuS3vkwqtZvYKdMJJo4gvfReKGa9DktUUCCSoNi5rWLnky47IUS3CnvWkpw/m2DDUSSmfJdAtPHATxI1wx5lV7kuO0lIQogeWamdJOffhJVPkpg6i2DDCK9DEhViX0OSQQ1CCD/Jp0gt+jH5nctInHE1ocPGeR2RqAB7cT4Z1CCE8BurQOZNnWLnNuKTLie34Vmy657AyvZ90k3hT4FwZVtIkpCEEL2SW/8UxfbNRN/zWeqPmUZu03Nk1z2JlXnXfY6imgWCBCJxaSEJIfytsNskadxI6LBxRMd8gvpp55Lb9LzdYsq0HvgFhP+F4wDSQhJCVIfCzjdI7nyD0NCTnMR0C7nmF8i+9VdJTFUu4Cw9IQlJCFFVCrveJLnrTUJDFaJjPkX92bPJrHyYXPMLyA211SkQSWBZRcinKnZOSUhCiLIp7DJJ7jKJHH0W0fd+nvBRk0kv+zVWcofXoYleCkTqnWRUuS8UMuxbCFF2uc0v0vnCLKxckvqzbqTuuA9DQP7cVJNAhacNAmkhCSFcYmVaSS/+KeFhk4ie+GXCR04m/cb9FDtkEeZqUOmJVUFaSEIIl+W3vUbnC7MotjeTUL9P3eiPeR2SOAjSQhJC1KZcJ+ll9xN6+xXiEy4h2HAU6Tfuh2LO68hEN+ybYqWFJISoUYVdy+lccBOhxlEkTp9JoG6g1yGJ7njQQpKEJISoKKtzO53zb8Iq5IirsygmjvQ6JLEf9kzfkpCEELUu10lq0TwKu0xyJ18hk7X6UCBSj5WXLjshRH9gFcgs+xWh5ieJTbqcyKgPeB2RKCGDGoQQ/U5463Ps3bWJ2ISLCQ0cSdr8LeTTXofV73kx7Nt3Cck09I8DNwBjgQ5gnqJqt5mGHgZuB87Dbtk9AlymqFrGeZ6r5UII9xR2LCG5YA6x8RdRf9YPSS+7n8Iu0+uw+jUvWki+6rIzDf3DwM+BmUATcALwhFN8DXAOcBIwBjgRuLnk6W6XCyFcVOzYTHL+D8ltnk/81CuJnqRBOOZ1WP1UAMKVXXoCfJaQgB8CsxVVe1ZRtbyiau2Kqr3plF3olG1VVG0ncD1wvmnowQqVCyHcVsyTXfNHkvNvIjToeOrPupHQUMXrqPqfcJxAINh/ryGZhl4PTAJ+bxq6CQwB5gNXAO3ACOD1kqcsxm5FjTANvc3NcmBjd3EHnM0LgS4/RXlJ/bqvuzq22jeSMm6k7viPEz/1SvKbXyKz8iG5ttRLh/oZDkQS9oNcsqyf/wO9lm8SEjAIO96vAB8BdgB3AP8HfMY5pnRJyn2LrTQARZfLuw+6sYF4NNrTIa4b3CQ3F7pJ6td93dbxjmcpdq4mePz/EJl2C6G3nyO0zSBQkMTUG739DBfrDycHDK4PE6B8n/9YrK7Hcj8lpA7n552Kqm0AMA39WmAnUHDKGoFdzuOmkud1uFzerT1tHSQK4K5jAAASLklEQVRT3ox7CGB/0Fpa22XFGRdI/brvoOq4dTlsvY7w8KnUHfef5I96P7lNz5Fb/xRWtr2C0VafQ/0Mh0IWsVyKltbyLrKYiPf85d0310cUVWvD7horrbfSx83AhJLfJ2K3YpoVVWt1s7ynuC2PNz/EUMub1K9P6tgqOEtaXGvPiTdUIXHubdSddB4kDvf8Pfh5O6j67fqcSAIr3+laPN3xUwsJ7BF2V5iG/hR2y+iHwKuKqm01Df1e4FrT0BcAOexBB/crqravu83tciGE5yzy214jv+01QkPeS93oj1B/zhzy2xaT2/wShV1vgiX/ZfvKvgepsgMa
wH8J6Vbsa0mLsVtvLwH/5ZTNAYYCy52yh7GHalOhciGEjxR2ryC1ewXBxlHUHfM+4hO/jlXIkN+6kNzm+bLuUh8EBwyn6MEqvwHLOlAjSuyPaeiNQOt50+d5eg1pSNNAdss1DldI/bqvrHUcqiN8xClEhk8lNFSh2LGZ3Jb55Lcs7LfXmg61fhPqdeS2GOQ2PFPWeBLxKA/cNQOgyblM82/81kISQohDU8iS37qQ/NaFBGKDiBx1BpERZxMd+1/kmp8n+9ZfsTLv+hsougrHCQ4cQeGNVZU/dcXPKIQQLrPSe8iu+xvZdX8jNEQhOvZT1E+7hdymf5J964l+22I6GKFBYyCfptixueLnloQkhKhphd0myQUmocNOJjrmU9RPm0Z247Pk1j2JldvrdXi+Exo8lsKeNRx4TFz5SUISQvQLhZ3LSO5cRujwCUTHfoq6Y95HdsPTZNc9ITNAlAgPHkt++2Jvzu3JWYUQwiOFHUtI7lhKeNgp1I39NJER55Bd8xi55hdkyHiojmDjKPLm7z05vSQkIUQ/5NzPtP11IiPOpm7Mp4iM+gCZlQ9T2LHU6+A8E2oaDcUCxfZup+90lSQkIUT/ZRXtaYi2LqTuuI8Qn3gphT1ryaz4Q7+8jyk0+AQKrWvBKhz4YBdIQhJCiHya7OpHyW16jugJnyFx5nXkt75MfvcKih1bKO7dAoWs11G6LjR4LIXdKzw7vyQkIYRwWOkW0kvvJbj+aeqO+w/qjv0wwfojCARDFDt3UNi7xU5Q7ZvIb3/ds5aEK4JhQk2jya55zLMQJCEJIUQXxfaNpJfcbf8SDBOsH0aw4WiCDcMJNRxNZOQ51I3+KOk37vPkfh03hBqPBaDQus6zGCQhCSFET4p5ih2b/z3xhBPElC+QUH9Adu1fyL7116pvLYUGj6XQtg6Kec9i8M3yE0IIUTXySdJv/JLUa3cRGXE2CfX7BBtGeB1Vn4QGn0ChZbWnMUhCEkKIQ1TYuYzOF79PoW0DCfX71I35JARCXofVe4EgoUHHe56QpMtOCCH6Ip8is+xX5N9eROzkrxI+4hQy5u8otKz0OrKDFhw4EoIRCnve8jYOT88uhBA1orBrud1aallFfPIM4qd9q2q68UKDT7Bvhi14O4WSJCQhhCiXfJqM+Vs6n78WK7uXxJnXERt/EYH4YV5H1qPQ4LEUWiq/3ERXkpCEEKLMrNRO0kvvIWncSKBuAPXn3ERU+SKBuoFeh7YfAcKDxpD3+PoRSEISQgjXFNs3kVp0B6lX5hFqOo76aXOJvufzBGJDvA7tHcGG4RBJUGhZ43UokpCEEMJthZaVJOfPJr3kboJNx1I/bS6xCV8n2Hic16ERGjzWvscqn/Q6FBllJ4QQlZLfsYT8jiUEG4+l7tgPkTjjaoqt68hueIr8tsV4sSieH+4/2kdaSEIIUWHFtvWkl/yCzue+S2HPWmInn0/9tLlERn0QwvGKxuKXAQ0gLSQhhPCMlW4hs+phMmv/TGTEmdQd836iYz9NbvNLZDc8g5Xc4er5A/XDCEYbfXH9CCQhCSGE9wppchueIbfhWUKHj6du1AeoP2cOhR1vkN3wtGtLQoQHj6Ww922sbLsrr99bvklIpqH/CvgiULroyLmKqi1yysPA7cB52F2NjwCXKaqWqUS5EEK4z6KwYwmpHUsINhxNZNQHiZ96JcXkdrLrnyK/ZUFZJ3H1U3cd+CghOX6mqNqV3ZRdA5wDnATkgMeBm4EZFSoXQoiKKXZsJrPsfrKrHiEy4hyiJ3yW6NhPk13/FLnm5yHf91kVQoPHkln1aBmiLY9qGtRwITBbUbWtiqrtBK4HzjcNPVih8v0KeLz5IYZa3qR+pY693sh2kHvrLyT/OZPs2j9TN/JcBpz7I6JjP0OwbuAh128wNoRgfCjFllUV//fujt9aSJpp6BrwNvBL4A5F1YqmoTcBI4DXS45dDDQBI0xDb3OzHNjYXcCDGhuIR6OH8l7LZnCTH+/+
rh1Sv+6TOj5I7Yuxlr5Occh4gsPfT91x/0FwxyuEts8n0Lml2z/4XevXCtZRGP4+CundDI4VIFaZ+o/F6nos91NCuhOYCbQAk4E/AEXgDqDBOaat5PhW52eDc5yb5d3a09ZBMuXNZaYA9getpbXdg7sXap/Ur/ukjg9R6wvw1guEhihEjvswoXFXYWXaKexcRn7nGxR2mZBP/Vv9UjeQ0OHjCR8xgdDQE6GQJbvmMfa2Vm5AQyLe85d33yQkRdUWl/y6wDT0uYCGnZA6nP2NwC7ncZPzs6MC5d2y8OJWNv/FUMukft0ndXxo8rtN8rtNiNQTHnoS4cNPJnqiRiASp7BnjZ2gYjFi73kvwUGjsVK7yW9fQnbRjynsWVPxVW4P9G/s52tI+1otKKrWCjQDE0rKJ2K3YprdLi/XGxJCCFfkOsm//TLppffS+eyVJBfMpdCymvCwUykOnUB+15skX7qezue+S2aFs1aTD5dc900LyTT0/waexG6RTAK+B/y05JB7gWtNQ1+APQrueuB+RdWKFSoXQogqYFFsW0e2bR25NX9iSNNAOqqkS9Q3CQm4HLgbO6YtwM+w7wvaZw4wFFiO3bJ7GHuodqXKhRBCuChgWdWQN/3HNPRGoPW86fM8HdQwpGkgu6vk20+1kfp1n9Sxu/xWv4l4lAfumgHQpKhaW9dyP19DEkII0Y9IQhJCCOELkpCEEEL4giQkIYQQviAJSQghhC/4adh3VYofYG4mNwWw54ZKxKO+GEFTa6R+3Sd17C6/1e+B/l5KQjp0AwHuvu1yr+MQQohqM5B/nzsUkITUF5uBkYA/lloUQojqMBD77+e7yI2xQgghfEEGNQghhPAFSUhCCCF8QRKSEEIIX5CEJIQQwhdklJ2HTEO/DPgKMA5YqKjatD6+3seBHwEjgDeBixVVW1JSfjT2kh4fwr5FYYWiamf05Zx+Vsn6NQ19GvBPoLPkKfcqqnZlX87pd5X+DJccdzP2mmmfVlTtsb6c088q/Bn+GDAXGI69uOtrwAxF1Zb15Zy9IS0kb72N/QG4o68vZBr68cDvgJnAIOBR4C+moced8nrsP5grgFHYaz/V9B9LKli/jjZF1QaUbLVev1D5OsY09PHAJ5xz17pK1u9i4IOKqg0CDgf+AlQ02UsLyUOKqj0KYBr6yK5lpqEfht2a+QB2a+ZPwFWKqnV2PdahAc8pqva48/xbgenAh4E/Al8FWhRVu77kOS+X5Y34VIXrt1+qdB2bhh4C7sFe0PP+sr4ZH6pk/SqqtrXk2AB2K+kY09AjiqrlyvSWeiQtJB8yDT2A/c2kBRgDvBf7Jty5PTxtHPD6vl+cpdeXOPsBzgFWm4b+qGnou01DX2Ia+qfdiN/vXKpfgAGmoW81DX2zaegPmoY+vOzBVwkX6/hKwFRU7Z/ljrmauFW/pqGPNA29FUgDPwbmVCoZgSQkvzoV+wN2laJqnYqqtQLXA1/q4TkNvHsqjlZnP8Bg4IvAb4FhwHeBB53uj/7GjfpdCUzA7ps/Ffsb5p9NQ++v/8fKXsemoR8LfBP4dtmjrT5ufIZRVG2TompNQBNwBSUJrBKky86fRgGNwG7T0PftCwBR09AHYF+U/LKz/0VF1f4T6HCeU6oJWO087gAWKKr2iPP7301D/wfwUWCpG2/Cx0ZR5vpVVG0bsM3Zv8009Iux//OPxU5W/c0oyv8Z/gXwA0XVdrkYd7UYRfnr9x2KqnWYhv5TYJdp6EsUVVtf/rfwbpKQ/GkjsFNRtWHdlH/d2Uq9AUzc94vTpB8P3OfsWgq8v8xxVis36rer/j4nlxt1/EFgnGnotzi/Hwbcbxr6uYqqXVG2yKtDJT7DASCKnfwkIdU609DD2P8GYSBoGnoMKAKvAuudoa1zsSdwHQ5MVFTtz928nA7McIZuPsW/RtA94fz8NTDTNPRPAn8GzgXeB1xb9jfmE5WsX9PQzwU2ONsQYB6wHFhT9jfmIxX+DB/Z5fhFwPep
4UElFf4Mfx57qPc67AlQZwNJ7NF3FdFf+7f9YhaQAm4DznIeP+VcbPwE9jfAN7G7fp4GTuzuhRRVW4t9jWiec/xngY8rqpZyytcDnwZuwv7w/hj4kqJqtdxdV7H6xf7m+QKwF1gGRICPKapWKP/b8pVKfoa3lW5AAWhVVO1dyxjUkEp+ho8FnsHu2luN3TL6YCXrV2b7FkII4QvSQhJCCOELkpCEEEL4giQkIYQQviAJSQghhC9IQhJCCOELkpCEEEL4gtwYK0SVMQ19FPad82cpqvaSx+EIUTaSkITowjT0XwFHK6r2Aa9j6UYz9qwFu90+UUny26cDWAvcoajaA718rVnAhYqqjSpbgKKmSJedED5hGnrdwRynqFrBma2gYssCAJ/EToKnYC/sppuG/qEKnl/0A9JCEqKXnPnFZmEvLX0k8BZwp6Jqvyg55grgfOB47OmEngO+paja2075NOwVfD8GXI29nMC3TUPfC9yLvX7VXcB7sOfE+7qiaq85zx1FSZddye+fB87DnkR3G3BdaSvGWb7hF8DZwA7gFuBzwFpF1S48wNtucabrAZhtGvq3gP/AnhNt30Sdd2PPj3gU9kqnvwduUFQtYxr6V4EfOsfumx7mBkXVrj+Y+hT9g7SQhOi9e4HPAJdgr0lzI3CLaegXdDnu28DJ2HMIjsT+A93V7cCtzuvsWy46CNyMvR7NKcAe4CHnD3dP5gIPYC+49hD2TNhj4J2E8Ufs5QfOxp4H7aOUzP58MExDD5mG/gXs9bWyJUUBYDv2XGnvxZ6483zgGqf8D9gJcDN20jkSe4kEOPj6FDVOWkhC9ILTytAARVG1fescrTcN/QTs5aDvA1BU7SclT1tvGvplwGLT0IcrqralpOymfUtKO68P9h/3KxVVW+zs+wGwABgNrOohvP9VVO0h5zmzsJf5fh/2jOMfwF5qYIwzySamoX8ZO0EcjKdMQy8CMSAE7MReShzn/RaxWzn7bDANfTTwDeyWWspp/RVKWloHXZ+if5CEJETv7FsN9tWShdHA/r/0zszeTpfc1YCCvQjavt6IY4DShPTKfs5h8e+LJu47/gh6TkhL9j1QVC1vGvp25zk4cezal4ycY1pMQ+/p9Uqdj700wbHYs0XfoKjautIDTEO/CLgQe5boepwlEw7wugdVn6J/kIQkRO/s+wM7FXutmFIWgGnoI4G/YXef3QjsAo7Gntq/68CFzv2co9hl2Yp911wO9Mc92+V3q8tz+jK1/xYnma11uuwWmob+5r5WjWnonwN+CnwPeB57iZPPYS930pMD1qfoPyQhCdE7rzk/Ryqq9pdujjkNiGN3u6UATEOfVIngemACh5mGfnxJl90g7CXWX+vxmV0oqrbcNPQ/Y6/R83Fn99nA64qqzXvnhPZgi1JZ7O6+UgdTn6KfkIQkxP4NMA19Qpd9aUXVVpqG/kvgHtPQv4N9bacemAQcpqjaLdjXbCzgKtPQH8S+dvODCsa+P89gdwPqzgjALHbrJc+htURuw74mpiqqZmB3JV7grEj8Jvbowc90ec56YJhp6Gdg11FSUbW1B1Gfop+QUXZC7N/pwOtdtn2j4C4G7sBe/t0EnsUesrwOQFG1N7AvyF/ilH+bfy0X7QlF1Szs0X6dwIvAX7CXrl4FpA/h9V7HTnJznV2/wO6ivB+7rk4Hru/ytMeAh4G/Yg+K+I6zv8f6FP2HrBgrRD9lGnoD9ii7WYqq3eV1PEJIl50Q/YRp6J/A7qJbARwOXIfdXfeQl3EJsY8kJCH6jwT2taxR2F13rwFnKqq23cughNhHuuyEEEL4ggxqEEII4QuSkIQQQviCJCQhhBC+IAlJCCGEL0hCEkII4QuSkIQQQvjC/wdV5HIuJYQVnQAAAABJRU5ErkJggg==\n", 407 | "text/plain": [ 408 | "
" 409 | ] 410 | }, 411 | "metadata": {}, 412 | "output_type": "display_data" 413 | } 414 | ], 415 | "source": [ 416 | "learner.recorder.plot()" 417 | ] 418 | }, 419 | { 420 | "cell_type": "code", 421 | "execution_count": 20, 422 | "metadata": {}, 423 | "outputs": [ 424 | { 425 | "data": { 426 | "text/plain": [ 427 | "PosixPath('/home/ubuntu/MultiCampus/MICCAI_BraTS_2019_Data_Training/models/trained_model_fit.pth')" 428 | ] 429 | }, 430 | "execution_count": 20, 431 | "metadata": {}, 432 | "output_type": "execute_result" 433 | } 434 | ], 435 | "source": [ 436 | "learner.save(\"trained_model_fit\", return_path=True)" 437 | ] 438 | }, 439 | { 440 | "cell_type": "code", 441 | "execution_count": null, 442 | "metadata": {}, 443 | "outputs": [], 444 | "source": [ 445 | "learner.load(\"trained_model_fit\", device=2)" 446 | ] 447 | }, 448 | { 449 | "cell_type": "code", 450 | "execution_count": 19, 451 | "metadata": {}, 452 | "outputs": [ 453 | { 454 | "data": { 455 | "text/html": [], 456 | "text/plain": [ 457 | "" 458 | ] 459 | }, 460 | "metadata": {}, 461 | "output_type": "display_data" 462 | }, 463 | { 464 | "data": { 465 | "text/plain": [ 466 | "[10493.069, tensor(0.6972)]" 467 | ] 468 | }, 469 | "execution_count": 19, 470 | "metadata": {}, 471 | "output_type": "execute_result" 472 | } 473 | ], 474 | "source": [ 475 | "learner.validate(metrics=[dice_coefficient])" 476 | ] 477 | }, 478 | { 479 | "cell_type": "code", 480 | "execution_count": 20, 481 | "metadata": {}, 482 | "outputs": [], 483 | "source": [ 484 | "learner.metrics = [dice_coefficient]" 485 | ] 486 | }, 487 | { 488 | "cell_type": "code", 489 | "execution_count": 21, 490 | "metadata": {}, 491 | "outputs": [ 492 | { 493 | "data": { 494 | "text/html": [ 495 | "\n", 496 | " \n", 497 | " \n", 498 | " \n", 499 | " \n", 500 | " \n", 501 | " \n", 502 | " \n", 503 | " \n", 504 | " \n", 505 | " \n", 506 | " \n", 507 | " \n", 508 | " \n", 509 | " \n", 510 | " \n", 511 | " \n", 512 | " \n", 513 | " \n", 514 | " \n", 515 | 
" \n", 516 | " \n", 517 | " \n", 518 | " \n", 519 | " \n", 520 | " \n", 521 | " \n", 522 | " \n", 523 | " \n", 524 | " \n", 525 | " \n", 526 | " \n", 527 | " \n", 528 | " \n", 529 | " \n", 530 | " \n", 531 | " \n", 532 | " \n", 533 | " \n", 534 | " \n", 535 | " \n", 536 | " \n", 537 | " \n", 538 | " \n", 539 | " \n", 540 | " \n", 541 | " \n", 542 | " \n", 543 | " \n", 544 | " \n", 545 | " \n", 546 | " \n", 547 | " \n", 548 | " \n", 549 | " \n", 550 | " \n", 551 | " \n", 552 | " \n", 553 | " \n", 554 | " \n", 555 | " \n", 556 | " \n", 557 | " \n", 558 | " \n", 559 | " \n", 560 | " \n", 561 | " \n", 562 | " \n", 563 | " \n", 564 | " \n", 565 | " \n", 566 | " \n", 567 | " \n", 568 | " \n", 569 | " \n", 570 | " \n", 571 | " \n", 572 | " \n", 573 | " \n", 574 | " \n", 575 | " \n", 576 | " \n", 577 | "
epoch  train_loss    valid_loss    dice_coefficient  time
0      10249.201172  10427.675781  0.699438          31:40
1      9863.964844   10486.278320  0.741603          31:53
2      9876.774414   10650.123047  0.712295          31:48
3      10385.082031  10679.890625  0.735404          31:32
4      9652.380859   10485.993164  0.668722          31:21
5      9648.621094   10375.062500  0.731086          31:28
6      9282.785156   10530.236328  0.752982          31:29
7      9316.895508   10471.415039  0.755438          31:32
8      9555.703125   10649.283203  0.713144          31:32
9      9621.862305   10841.560547  0.737293          31:39
" 578 | ], 579 | "text/plain": [ 580 | "" 581 | ] 582 | }, 583 | "metadata": {}, 584 | "output_type": "display_data" 585 | } 586 | ], 587 | "source": [ 588 | "learner.fit(epochs=10, lr=1e-04)" 589 | ] 590 | }, 591 | { 592 | "cell_type": "code", 593 | "execution_count": 22, 594 | "metadata": {}, 595 | "outputs": [ 596 | { 597 | "data": { 598 | "text/plain": [ 599 | "PosixPath('/home/ubuntu/MultiCampus/MICCAI_BraTS_2019_Data_Training/models/trained_model_fit_2.pth')" 600 | ] 601 | }, 602 | "execution_count": 22, 603 | "metadata": {}, 604 | "output_type": "execute_result" 605 | } 606 | ], 607 | "source": [ 608 | "learner.save(\"trained_model_fit_2\", return_path=True)" 609 | ] 610 | }, 611 | { 612 | "cell_type": "code", 613 | "execution_count": 23, 614 | "metadata": {}, 615 | "outputs": [], 616 | "source": [ 617 | "learner = learner.load(\"trained_model_fit_2\", device=2)" 618 | ] 619 | }, 620 | { 621 | "cell_type": "code", 622 | "execution_count": 24, 623 | "metadata": {}, 624 | "outputs": [ 625 | { 626 | "data": { 627 | "text/html": [ 628 | "\n", 629 | " \n", 630 | " \n", 631 | " \n", 632 | " \n", 633 | " \n", 634 | " \n", 635 | " \n", 636 | " \n", 637 | " \n", 638 | " \n", 639 | " \n", 640 | " \n", 641 | " \n", 642 | " \n", 643 | " \n", 644 | " \n", 645 | " \n", 646 | " \n", 647 | " \n", 648 | " \n", 649 | " \n", 650 | " \n", 651 | " \n", 652 | " \n", 653 | " \n", 654 | " \n", 655 | " \n", 656 | " \n", 657 | " \n", 658 | " \n", 659 | " \n", 660 | " \n", 661 | " \n", 662 | " \n", 663 | " \n", 664 | " \n", 665 | " \n", 666 | " \n", 667 | " \n", 668 | " \n", 669 | " \n", 670 | " \n", 671 | " \n", 672 | " \n", 673 | " \n", 674 | " \n", 675 | "
epoch    train_loss     valid_loss      dice_coefficient    time
0        8968.473633    10620.648438    0.668884            31:43
1        8851.039062    10587.374023    0.722928            31:56
2        9223.566406    10720.088867    0.749549            31:20
3        9645.301758    10704.744141    0.754023            31:26
4        8731.027344    10556.869141    0.747595            31:17
" 676 | ], 677 | "text/plain": [ 678 | "" 679 | ] 680 | }, 681 | "metadata": {}, 682 | "output_type": "display_data" 683 | } 684 | ], 685 | "source": [ 686 | "learner.fit(epochs=5, lr=1e-04)" 687 | ] 688 | }, 689 | { 690 | "cell_type": "code", 691 | "execution_count": 25, 692 | "metadata": {}, 693 | "outputs": [ 694 | { 695 | "data": { 696 | "text/html": [ 697 | "\n", 698 | " \n", 699 | " \n", 700 | " \n", 701 | " \n", 702 | " \n", 703 | " \n", 704 | " \n", 705 | " \n", 706 | " \n", 707 | " \n", 708 | " \n", 709 | " \n", 710 | " \n", 711 | " \n", 712 | " \n", 713 | " \n", 714 | " \n", 715 | " \n", 716 | " \n", 717 | " \n", 718 | " \n", 719 | " \n", 720 | " \n", 721 | " \n", 722 | " \n", 723 | " \n", 724 | " \n", 725 | " \n", 726 | " \n", 727 | " \n", 728 | " \n", 729 | " \n", 730 | " \n", 731 | " \n", 732 | " \n", 733 | " \n", 734 | " \n", 735 | " \n", 736 | " \n", 737 | " \n", 738 | " \n", 739 | " \n", 740 | " \n", 741 | " \n", 742 | " \n", 743 | " \n", 744 | "
epoch    train_loss     valid_loss      dice_coefficient    time
0        8663.437500    10722.753906    0.748986            31:28
1        8925.670898    10696.874023    0.731256            31:27
2        9030.234375    10776.826172    0.758666            31:04
3        8400.864258    10698.207031    0.729448            31:09
4        8817.802734    10819.573242    0.762197            31:09
" 745 | ], 746 | "text/plain": [ 747 | "" 748 | ] 749 | }, 750 | "metadata": {}, 751 | "output_type": "display_data" 752 | } 753 | ], 754 | "source": [ 755 | "learner.fit(epochs=5, lr=1e-04)" 756 | ] 757 | }, 758 | { 759 | "cell_type": "code", 760 | "execution_count": 26, 761 | "metadata": {}, 762 | "outputs": [ 763 | { 764 | "data": { 765 | "text/plain": [ 766 | "PosixPath('/home/ubuntu/MultiCampus/MICCAI_BraTS_2019_Data_Training/models/trained_model_fit_3.pth')" 767 | ] 768 | }, 769 | "execution_count": 26, 770 | "metadata": {}, 771 | "output_type": "execute_result" 772 | } 773 | ], 774 | "source": [ 775 | "learner.save(\"trained_model_fit_3\", return_path=True)" 776 | ] 777 | }, 778 | { 779 | "cell_type": "code", 780 | "execution_count": 27, 781 | "metadata": {}, 782 | "outputs": [], 783 | "source": [ 784 | "learner = learner.load(\"trained_model_fit_3\", device=2)" 785 | ] 786 | }, 787 | { 788 | "cell_type": "code", 789 | "execution_count": null, 790 | "metadata": {}, 791 | "outputs": [ 792 | { 793 | "data": { 794 | "text/html": [ 795 | "\n", 796 | "
\n", 797 | " \n", 809 | " \n", 810 | " 20.00% [1/5 31:07<2:04:30]\n", 811 | "
\n", 812 | " \n", 813 | "\n", 814 | " \n", 815 | " \n", 816 | " \n", 817 | " \n", 818 | " \n", 819 | " \n", 820 | " \n", 821 | " \n", 822 | " \n", 823 | " \n", 824 | " \n", 825 | " \n", 826 | " \n", 827 | " \n", 828 | " \n", 829 | " \n", 830 | " \n", 831 | " \n", 832 | "
epoch    train_loss     valid_loss      dice_coefficient    time
0        8293.036133    10747.576172    0.765893            31:07

\n", 833 | "\n", 834 | "

\n", 835 | " \n", 847 | " \n", 848 | " 26.74% [77/288 07:52<21:33 8336.4150]\n", 849 | "
\n", 850 | " " 851 | ], 852 | "text/plain": [ 853 | "" 854 | ] 855 | }, 856 | "metadata": {}, 857 | "output_type": "display_data" 858 | } 859 | ], 860 | "source": [ 861 | "learner.fit(epochs=5, lr=1e-04)" 862 | ] 863 | } 864 | ], 865 | "metadata": { 866 | "kernelspec": { 867 | "display_name": "Python 3", 868 | "language": "python", 869 | "name": "python3" 870 | }, 871 | "language_info": { 872 | "codemirror_mode": { 873 | "name": "ipython", 874 | "version": 3 875 | }, 876 | "file_extension": ".py", 877 | "mimetype": "text/x-python", 878 | "name": "python", 879 | "nbconvert_exporter": "python", 880 | "pygments_lexer": "ipython3", 881 | "version": "3.6.5" 882 | } 883 | }, 884 | "nbformat": 4, 885 | "nbformat_minor": 4 886 | } 887 | -------------------------------------------------------------------------------- /ndv/utils/notebook2script.py: -------------------------------------------------------------------------------- 1 | 2 | import json, fire, re 3 | from pathlib import Path 4 | 5 | def is_export(cell): 6 | if cell['cell_type'] != 'code': return False 7 | src = cell['source'] 8 | if len(src) == 0 or len(src[0]) < 7: return False 9 | return re.match(r'^\s*#\s*export\s*$', src[0], re.IGNORECASE) is not None 10 | 11 | def notebook2script(fname): 12 | fname = Path(fname) 13 | fname_out = f'nb_{fname.stem.split("_")[0]}.py' 14 | main_dic = json.load(open(fname, 'r')) 15 | code_cells = [c for c in main_dic['cells'] if is_export(c)] 16 | module = f''' 17 | ################################################# 18 | ### THIS FILE WAS AUTOGENERATED! DO NOT EDIT! 
### 19 | ################################################# 20 | # file to edit: dev_nb/{fname.name} 21 | ''' 22 | for cell in code_cells: module += ''.join(cell['source'][1:]) + '\n\n' 23 | # remove trailing spaces 24 | module = re.sub(r' +$', '', module, flags=re.MULTILINE) 25 | open(fname.parent/'exp'/fname_out,'w').write(module[:-2]) 26 | print(f"Converted {fname} to {fname_out}") 27 | 28 | if __name__ == '__main__': fire.Fire(notebook2script) 29 | 30 | 31 | -------------------------------------------------------------------------------- /ndv/utils/run_notebook.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import nbformat, fire 4 | from nbconvert.preprocessors import ExecutePreprocessor 5 | 6 | def run_notebook(path): 7 | nb = nbformat.read(open(path), as_version=nbformat.NO_CONVERT) 8 | # note: the executor's method is .preprocess, not .preprocessor 9 | ExecutePreprocessor(timeout=600).preprocess(nb, {}) 10 | 11 | if __name__ == '__main__': fire.Fire(run_notebook) -------------------------------------------------------------------------------- /ndv/valid.txt: -------------------------------------------------------------------------------- 1 | BraTS19_TCIA01_221_1 2 | BraTS19_CBICA_AWI_1 3 | BraTS19_TCIA02_370_1 4 | BraTS19_TCIA03_375_1 5 | BraTS19_CBICA_ATV_1 6 | BraTS19_CBICA_ATF_1 7 | BraTS19_CBICA_AME_1 8 | BraTS19_CBICA_BHZ_1 9 | BraTS19_TCIA01_186_1 10 | BraTS19_CBICA_ABN_1 11 | BraTS19_CBICA_AVV_1 12 | BraTS19_TCIA08_406_1 13 | BraTS19_TCIA04_343_1 14 | BraTS19_TCIA02_274_1 15 | BraTS19_CBICA_BHK_1 16 | BraTS19_CBICA_ANV_1 17 | BraTS19_TMC_21360_1 18 | BraTS19_TMC_27374_1 19 | BraTS19_TCIA04_437_1 20 | BraTS19_TCIA02_471_1 21 | BraTS19_TMC_30014_1 22 | BraTS19_TCIA06_165_1 23 | BraTS19_TCIA04_479_1 24 | BraTS19_TCIA02_118_1 25 | BraTS19_CBICA_BAX_1 26 | BraTS19_CBICA_AVB_1 27 | BraTS19_CBICA_AXJ_1 28 | BraTS19_TCIA02_208_1 29 | BraTS19_CBICA_ATX_1 30 | BraTS19_CBICA_ATN_1 31 | BraTS19_TCIA02_606_1 32 | BraTS19_2013_26_1 33 | BraTS19_CBICA_ASG_1 34 |
BraTS19_CBICA_AWG_1 35 | BraTS19_TCIA08_218_1 36 | BraTS19_CBICA_AUR_1 37 | BraTS19_CBICA_AVG_1 38 | BraTS19_CBICA_AXO_1 39 | BraTS19_TCIA02_331_1 40 | BraTS19_TCIA09_462_1 41 | BraTS19_TCIA10_299_1 42 | BraTS19_TCIA13_653_1 43 | BraTS19_TCIA10_130_1 44 | BraTS19_2013_15_1 45 | BraTS19_TCIA10_442_1 46 | BraTS19_TCIA10_387_1 47 | BraTS19_2013_9_1 --------------------------------------------------------------------------------
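valid.txt above lists one BraTS19 case ID per line, naming the subjects held out for validation. A minimal sketch of consuming such a list to split case directories into training and validation sets; the `read_valid_ids` and `split_cases` helpers and the one-directory-per-case layout are assumptions for illustration, not part of this repo's `dataloader.py`:

```python
from pathlib import Path

def read_valid_ids(path):
    """Return the set of case IDs listed one per line, skipping blank lines."""
    return {line.strip() for line in Path(path).read_text().splitlines() if line.strip()}

def split_cases(case_dirs, valid_ids):
    """Partition case directories into (train, valid) by matching directory name
    against the held-out ID set."""
    train = [d for d in case_dirs if d.name not in valid_ids]
    valid = [d for d in case_dirs if d.name in valid_ids]
    return train, valid
```

With this layout, the dataloader would hand `train` to the training set and `valid` to the validation set, so the split stays fixed across runs as long as valid.txt is unchanged.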