├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── baseline.sh ├── configs ├── abide_schaefer100 │ └── TUs_graph_classification_ContrastPool_abide_schaefer100_100k.json ├── adni_schaefer100 │ └── TUs_graph_classification_ContrastPool_adni_schaefer100_100k.json ├── neurocon_schaefer100 │ └── TUs_graph_classification_ContrastPool_neurocon_schaefer100_100k.json ├── ppmi_schaefer100 │ └── TUs_graph_classification_ContrastPool_ppmi_schaefer100_100k.json └── taowu_schaefer100 │ └── TUs_graph_classification_ContrastPool_taowu_schaefer100_100k.json ├── contrast_subgraph.py ├── data ├── BrainNet.py ├── abide_schaefer100 │ ├── test.index │ ├── train.index │ └── val.index ├── adni_schaefer100 │ ├── test.index │ ├── train.index │ └── val.index ├── data.py ├── generate_data_from_mat.py ├── neurocon_schaefer100 │ ├── test.index │ ├── train.index │ └── val.index ├── ppmi_schaefer100 │ ├── test.index │ ├── train.index │ └── val.index └── taowu_schaefer100 │ ├── test.index │ ├── train.index │ └── val.index ├── figs └── framework.png ├── layers ├── attention_layer.py ├── contrastpool_layer.py ├── diffpool_layer.py └── graphsage_layer.py ├── main.py ├── metrics.py ├── nets ├── contrastpool_net.py └── load_net.py └── train_TUs_graph_classification.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 
| 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .nox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | .pytest_cache/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | db.sqlite3 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # IPython 77 | profile_default/ 78 | ipython_config.py 79 | 80 | # pyenv 81 | .python-version 82 | 83 | # celery beat schedule file 84 | celerybeat-schedule 85 | 86 | # SageMath parsed files 87 | *.sage.py 88 | 89 | # Environments 90 | .env 91 | .venv 92 | env/ 93 | venv/ 94 | ENV/ 95 | env.bak/ 96 | venv.bak/ 97 | 98 | # Spyder project settings 99 | .spyderproject 100 | .spyproject 101 | 102 | # Rope project settings 103 | .ropeproject 104 | 105 | # mkdocs documentation 106 | /site 107 | 108 | # mypy 109 | .mypy_cache/ 110 | .dmypy.json 111 | dmypy.json 112 | # mypy 113 | .DS_Store 114 | .idea 115 | *.bak 116 | *.pkl 117 | save 118 | log 119 | log.test 120 | log.txt 121 | outputs 122 | out 123 | tmp 124 | tmp1.sh 125 | tmp2.sh 126 | tmp3.sh 127 | tmp4.sh 128 | result/ 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | 133 | 134 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 
5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | Preamble 9 | 10 | The GNU General Public License is a free, copyleft license for 11 | software and other kinds of works. 12 | 13 | The licenses for most software and other practical works are designed 14 | to take away your freedom to share and change the works. By contrast, 15 | the GNU General Public License is intended to guarantee your freedom to 16 | share and change all versions of a program--to make sure it remains free 17 | software for all its users. We, the Free Software Foundation, use the 18 | GNU General Public License for most of our software; it applies also to 19 | any other work released this way by its authors. You can apply it to 20 | your programs, too. 21 | 22 | When we speak of free software, we are referring to freedom, not 23 | price. Our General Public Licenses are designed to make sure that you 24 | have the freedom to distribute copies of free software (and charge for 25 | them if you wish), that you receive source code or can get it if you 26 | want it, that you can change the software or use pieces of it in new 27 | free programs, and that you know you can do these things. 28 | 29 | To protect your rights, we need to prevent others from denying you 30 | these rights or asking you to surrender the rights. Therefore, you have 31 | certain responsibilities if you distribute copies of the software, or if 32 | you modify it: responsibilities to respect the freedom of others. 33 | 34 | For example, if you distribute copies of such a program, whether 35 | gratis or for a fee, you must pass on to the recipients the same 36 | freedoms that you received. You must make sure that they, too, receive 37 | or can get the source code. And you must show them these terms so they 38 | know their rights. 
39 | 40 | Developers that use the GNU GPL protect your rights with two steps: 41 | (1) assert copyright on the software, and (2) offer you this License 42 | giving you legal permission to copy, distribute and/or modify it. 43 | 44 | For the developers' and authors' protection, the GPL clearly explains 45 | that there is no warranty for this free software. For both users' and 46 | authors' sake, the GPL requires that modified versions be marked as 47 | changed, so that their problems will not be attributed erroneously to 48 | authors of previous versions. 49 | 50 | Some devices are designed to deny users access to install or run 51 | modified versions of the software inside them, although the manufacturer 52 | can do so. This is fundamentally incompatible with the aim of 53 | protecting users' freedom to change the software. The systematic 54 | pattern of such abuse occurs in the area of products for individuals to 55 | use, which is precisely where it is most unacceptable. Therefore, we 56 | have designed this version of the GPL to prohibit the practice for those 57 | products. If such problems arise substantially in other domains, we 58 | stand ready to extend this provision to those domains in future versions 59 | of the GPL, as needed to protect the freedom of users. 60 | 61 | Finally, every program is threatened constantly by software patents. 62 | States should not allow patents to restrict development and use of 63 | software on general-purpose computers, but in those that do, we wish to 64 | avoid the special danger that patents applied to a free program could 65 | make it effectively proprietary. To prevent this, the GPL assures that 66 | patents cannot be used to render the program non-free. 67 | 68 | The precise terms and conditions for copying, distribution and 69 | modification follow. 70 | 71 | TERMS AND CONDITIONS 72 | 73 | 0. Definitions. 74 | 75 | "This License" refers to version 3 of the GNU General Public License. 
76 | 77 | "Copyright" also means copyright-like laws that apply to other kinds of 78 | works, such as semiconductor masks. 79 | 80 | "The Program" refers to any copyrightable work licensed under this 81 | License. Each licensee is addressed as "you". "Licensees" and 82 | "recipients" may be individuals or organizations. 83 | 84 | To "modify" a work means to copy from or adapt all or part of the work 85 | in a fashion requiring copyright permission, other than the making of an 86 | exact copy. The resulting work is called a "modified version" of the 87 | earlier work or a work "based on" the earlier work. 88 | 89 | A "covered work" means either the unmodified Program or a work based 90 | on the Program. 91 | 92 | To "propagate" a work means to do anything with it that, without 93 | permission, would make you directly or secondarily liable for 94 | infringement under applicable copyright law, except executing it on a 95 | computer or modifying a private copy. Propagation includes copying, 96 | distribution (with or without modification), making available to the 97 | public, and in some countries other activities as well. 98 | 99 | To "convey" a work means any kind of propagation that enables other 100 | parties to make or receive copies. Mere interaction with a user through 101 | a computer network, with no transfer of a copy, is not conveying. 102 | 103 | An interactive user interface displays "Appropriate Legal Notices" 104 | to the extent that it includes a convenient and prominently visible 105 | feature that (1) displays an appropriate copyright notice, and (2) 106 | tells the user that there is no warranty for the work (except to the 107 | extent that warranties are provided), that licensees may convey the 108 | work under this License, and how to view a copy of this License. If 109 | the interface presents a list of user commands or options, such as a 110 | menu, a prominent item in the list meets this criterion. 111 | 112 | 1. Source Code. 
113 | 114 | The "source code" for a work means the preferred form of the work 115 | for making modifications to it. "Object code" means any non-source 116 | form of a work. 117 | 118 | A "Standard Interface" means an interface that either is an official 119 | standard defined by a recognized standards body, or, in the case of 120 | interfaces specified for a particular programming language, one that 121 | is widely used among developers working in that language. 122 | 123 | The "System Libraries" of an executable work include anything, other 124 | than the work as a whole, that (a) is included in the normal form of 125 | packaging a Major Component, but which is not part of that Major 126 | Component, and (b) serves only to enable use of the work with that 127 | Major Component, or to implement a Standard Interface for which an 128 | implementation is available to the public in source code form. A 129 | "Major Component", in this context, means a major essential component 130 | (kernel, window system, and so on) of the specific operating system 131 | (if any) on which the executable work runs, or a compiler used to 132 | produce the work, or an object code interpreter used to run it. 133 | 134 | The "Corresponding Source" for a work in object code form means all 135 | the source code needed to generate, install, and (for an executable 136 | work) run the object code and to modify the work, including scripts to 137 | control those activities. However, it does not include the work's 138 | System Libraries, or general-purpose tools or generally available free 139 | programs which are used unmodified in performing those activities but 140 | which are not part of the work. 
For example, Corresponding Source 141 | includes interface definition files associated with source files for 142 | the work, and the source code for shared libraries and dynamically 143 | linked subprograms that the work is specifically designed to require, 144 | such as by intimate data communication or control flow between those 145 | subprograms and other parts of the work. 146 | 147 | The Corresponding Source need not include anything that users 148 | can regenerate automatically from other parts of the Corresponding 149 | Source. 150 | 151 | The Corresponding Source for a work in source code form is that 152 | same work. 153 | 154 | 2. Basic Permissions. 155 | 156 | All rights granted under this License are granted for the term of 157 | copyright on the Program, and are irrevocable provided the stated 158 | conditions are met. This License explicitly affirms your unlimited 159 | permission to run the unmodified Program. The output from running a 160 | covered work is covered by this License only if the output, given its 161 | content, constitutes a covered work. This License acknowledges your 162 | rights of fair use or other equivalent, as provided by copyright law. 163 | 164 | You may make, run and propagate covered works that you do not 165 | convey, without conditions so long as your license otherwise remains 166 | in force. You may convey covered works to others for the sole purpose 167 | of having them make modifications exclusively for you, or provide you 168 | with facilities for running those works, provided that you comply with 169 | the terms of this License in conveying all material for which you do 170 | not control copyright. Those thus making or running the covered works 171 | for you must do so exclusively on your behalf, under your direction 172 | and control, on terms that prohibit them from making any copies of 173 | your copyrighted material outside their relationship with you. 
174 | 175 | Conveying under any other circumstances is permitted solely under 176 | the conditions stated below. Sublicensing is not allowed; section 10 177 | makes it unnecessary. 178 | 179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 180 | 181 | No covered work shall be deemed part of an effective technological 182 | measure under any applicable law fulfilling obligations under article 183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 184 | similar laws prohibiting or restricting circumvention of such 185 | measures. 186 | 187 | When you convey a covered work, you waive any legal power to forbid 188 | circumvention of technological measures to the extent such circumvention 189 | is effected by exercising rights under this License with respect to 190 | the covered work, and you disclaim any intention to limit operation or 191 | modification of the work as a means of enforcing, against the work's 192 | users, your or third parties' legal rights to forbid circumvention of 193 | technological measures. 194 | 195 | 4. Conveying Verbatim Copies. 196 | 197 | You may convey verbatim copies of the Program's source code as you 198 | receive it, in any medium, provided that you conspicuously and 199 | appropriately publish on each copy an appropriate copyright notice; 200 | keep intact all notices stating that this License and any 201 | non-permissive terms added in accord with section 7 apply to the code; 202 | keep intact all notices of the absence of any warranty; and give all 203 | recipients a copy of this License along with the Program. 204 | 205 | You may charge any price or no price for each copy that you convey, 206 | and you may offer support or warranty protection for a fee. 207 | 208 | 5. Conveying Modified Source Versions. 
209 | 210 | You may convey a work based on the Program, or the modifications to 211 | produce it from the Program, in the form of source code under the 212 | terms of section 4, provided that you also meet all of these conditions: 213 | 214 | a) The work must carry prominent notices stating that you modified 215 | it, and giving a relevant date. 216 | 217 | b) The work must carry prominent notices stating that it is 218 | released under this License and any conditions added under section 219 | 7. This requirement modifies the requirement in section 4 to 220 | "keep intact all notices". 221 | 222 | c) You must license the entire work, as a whole, under this 223 | License to anyone who comes into possession of a copy. This 224 | License will therefore apply, along with any applicable section 7 225 | additional terms, to the whole of the work, and all its parts, 226 | regardless of how they are packaged. This License gives no 227 | permission to license the work in any other way, but it does not 228 | invalidate such permission if you have separately received it. 229 | 230 | d) If the work has interactive user interfaces, each must display 231 | Appropriate Legal Notices; however, if the Program has interactive 232 | interfaces that do not display Appropriate Legal Notices, your 233 | work need not make them do so. 234 | 235 | A compilation of a covered work with other separate and independent 236 | works, which are not by their nature extensions of the covered work, 237 | and which are not combined with it such as to form a larger program, 238 | in or on a volume of a storage or distribution medium, is called an 239 | "aggregate" if the compilation and its resulting copyright are not 240 | used to limit the access or legal rights of the compilation's users 241 | beyond what the individual works permit. Inclusion of a covered work 242 | in an aggregate does not cause this License to apply to the other 243 | parts of the aggregate. 244 | 245 | 6. 
Conveying Non-Source Forms. 246 | 247 | You may convey a covered work in object code form under the terms 248 | of sections 4 and 5, provided that you also convey the 249 | machine-readable Corresponding Source under the terms of this License, 250 | in one of these ways: 251 | 252 | a) Convey the object code in, or embodied in, a physical product 253 | (including a physical distribution medium), accompanied by the 254 | Corresponding Source fixed on a durable physical medium 255 | customarily used for software interchange. 256 | 257 | b) Convey the object code in, or embodied in, a physical product 258 | (including a physical distribution medium), accompanied by a 259 | written offer, valid for at least three years and valid for as 260 | long as you offer spare parts or customer support for that product 261 | model, to give anyone who possesses the object code either (1) a 262 | copy of the Corresponding Source for all the software in the 263 | product that is covered by this License, on a durable physical 264 | medium customarily used for software interchange, for a price no 265 | more than your reasonable cost of physically performing this 266 | conveying of source, or (2) access to copy the 267 | Corresponding Source from a network server at no charge. 268 | 269 | c) Convey individual copies of the object code with a copy of the 270 | written offer to provide the Corresponding Source. This 271 | alternative is allowed only occasionally and noncommercially, and 272 | only if you received the object code with such an offer, in accord 273 | with subsection 6b. 274 | 275 | d) Convey the object code by offering access from a designated 276 | place (gratis or for a charge), and offer equivalent access to the 277 | Corresponding Source in the same way through the same place at no 278 | further charge. You need not require recipients to copy the 279 | Corresponding Source along with the object code. 
If the place to 280 | copy the object code is a network server, the Corresponding Source 281 | may be on a different server (operated by you or a third party) 282 | that supports equivalent copying facilities, provided you maintain 283 | clear directions next to the object code saying where to find the 284 | Corresponding Source. Regardless of what server hosts the 285 | Corresponding Source, you remain obligated to ensure that it is 286 | available for as long as needed to satisfy these requirements. 287 | 288 | e) Convey the object code using peer-to-peer transmission, provided 289 | you inform other peers where the object code and Corresponding 290 | Source of the work are being offered to the general public at no 291 | charge under subsection 6d. 292 | 293 | A separable portion of the object code, whose source code is excluded 294 | from the Corresponding Source as a System Library, need not be 295 | included in conveying the object code work. 296 | 297 | A "User Product" is either (1) a "consumer product", which means any 298 | tangible personal property which is normally used for personal, family, 299 | or household purposes, or (2) anything designed or sold for incorporation 300 | into a dwelling. In determining whether a product is a consumer product, 301 | doubtful cases shall be resolved in favor of coverage. For a particular 302 | product received by a particular user, "normally used" refers to a 303 | typical or common use of that class of product, regardless of the status 304 | of the particular user or of the way in which the particular user 305 | actually uses, or expects or is expected to use, the product. A product 306 | is a consumer product regardless of whether the product has substantial 307 | commercial, industrial or non-consumer uses, unless such uses represent 308 | the only significant mode of use of the product. 
309 | 310 | "Installation Information" for a User Product means any methods, 311 | procedures, authorization keys, or other information required to install 312 | and execute modified versions of a covered work in that User Product from 313 | a modified version of its Corresponding Source. The information must 314 | suffice to ensure that the continued functioning of the modified object 315 | code is in no case prevented or interfered with solely because 316 | modification has been made. 317 | 318 | If you convey an object code work under this section in, or with, or 319 | specifically for use in, a User Product, and the conveying occurs as 320 | part of a transaction in which the right of possession and use of the 321 | User Product is transferred to the recipient in perpetuity or for a 322 | fixed term (regardless of how the transaction is characterized), the 323 | Corresponding Source conveyed under this section must be accompanied 324 | by the Installation Information. But this requirement does not apply 325 | if neither you nor any third party retains the ability to install 326 | modified object code on the User Product (for example, the work has 327 | been installed in ROM). 328 | 329 | The requirement to provide Installation Information does not include a 330 | requirement to continue to provide support service, warranty, or updates 331 | for a work that has been modified or installed by the recipient, or for 332 | the User Product in which it has been modified or installed. Access to a 333 | network may be denied when the modification itself materially and 334 | adversely affects the operation of the network or violates the rules and 335 | protocols for communication across the network. 
336 | 337 | Corresponding Source conveyed, and Installation Information provided, 338 | in accord with this section must be in a format that is publicly 339 | documented (and with an implementation available to the public in 340 | source code form), and must require no special password or key for 341 | unpacking, reading or copying. 342 | 343 | 7. Additional Terms. 344 | 345 | "Additional permissions" are terms that supplement the terms of this 346 | License by making exceptions from one or more of its conditions. 347 | Additional permissions that are applicable to the entire Program shall 348 | be treated as though they were included in this License, to the extent 349 | that they are valid under applicable law. If additional permissions 350 | apply only to part of the Program, that part may be used separately 351 | under those permissions, but the entire Program remains governed by 352 | this License without regard to the additional permissions. 353 | 354 | When you convey a copy of a covered work, you may at your option 355 | remove any additional permissions from that copy, or from any part of 356 | it. (Additional permissions may be written to require their own 357 | removal in certain cases when you modify the work.) You may place 358 | additional permissions on material, added by you to a covered work, 359 | for which you have or can give appropriate copyright permission. 
360 | 361 | Notwithstanding any other provision of this License, for material you 362 | add to a covered work, you may (if authorized by the copyright holders of 363 | that material) supplement the terms of this License with terms: 364 | 365 | a) Disclaiming warranty or limiting liability differently from the 366 | terms of sections 15 and 16 of this License; or 367 | 368 | b) Requiring preservation of specified reasonable legal notices or 369 | author attributions in that material or in the Appropriate Legal 370 | Notices displayed by works containing it; or 371 | 372 | c) Prohibiting misrepresentation of the origin of that material, or 373 | requiring that modified versions of such material be marked in 374 | reasonable ways as different from the original version; or 375 | 376 | d) Limiting the use for publicity purposes of names of licensors or 377 | authors of the material; or 378 | 379 | e) Declining to grant rights under trademark law for use of some 380 | trade names, trademarks, or service marks; or 381 | 382 | f) Requiring indemnification of licensors and authors of that 383 | material by anyone who conveys the material (or modified versions of 384 | it) with contractual assumptions of liability to the recipient, for 385 | any liability that these contractual assumptions directly impose on 386 | those licensors and authors. 387 | 388 | All other non-permissive additional terms are considered "further 389 | restrictions" within the meaning of section 10. If the Program as you 390 | received it, or any part of it, contains a notice stating that it is 391 | governed by this License along with a term that is a further 392 | restriction, you may remove that term. 
If a license document contains 393 | a further restriction but permits relicensing or conveying under this 394 | License, you may add to a covered work material governed by the terms 395 | of that license document, provided that the further restriction does 396 | not survive such relicensing or conveying. 397 | 398 | If you add terms to a covered work in accord with this section, you 399 | must place, in the relevant source files, a statement of the 400 | additional terms that apply to those files, or a notice indicating 401 | where to find the applicable terms. 402 | 403 | Additional terms, permissive or non-permissive, may be stated in the 404 | form of a separately written license, or stated as exceptions; 405 | the above requirements apply either way. 406 | 407 | 8. Termination. 408 | 409 | You may not propagate or modify a covered work except as expressly 410 | provided under this License. Any attempt otherwise to propagate or 411 | modify it is void, and will automatically terminate your rights under 412 | this License (including any patent licenses granted under the third 413 | paragraph of section 11). 414 | 415 | However, if you cease all violation of this License, then your 416 | license from a particular copyright holder is reinstated (a) 417 | provisionally, unless and until the copyright holder explicitly and 418 | finally terminates your license, and (b) permanently, if the copyright 419 | holder fails to notify you of the violation by some reasonable means 420 | prior to 60 days after the cessation. 421 | 422 | Moreover, your license from a particular copyright holder is 423 | reinstated permanently if the copyright holder notifies you of the 424 | violation by some reasonable means, this is the first time you have 425 | received notice of violation of this License (for any work) from that 426 | copyright holder, and you cure the violation prior to 30 days after 427 | your receipt of the notice. 
428 | 429 | Termination of your rights under this section does not terminate the 430 | licenses of parties who have received copies or rights from you under 431 | this License. If your rights have been terminated and not permanently 432 | reinstated, you do not qualify to receive new licenses for the same 433 | material under section 10. 434 | 435 | 9. Acceptance Not Required for Having Copies. 436 | 437 | You are not required to accept this License in order to receive or 438 | run a copy of the Program. Ancillary propagation of a covered work 439 | occurring solely as a consequence of using peer-to-peer transmission 440 | to receive a copy likewise does not require acceptance. However, 441 | nothing other than this License grants you permission to propagate or 442 | modify any covered work. These actions infringe copyright if you do 443 | not accept this License. Therefore, by modifying or propagating a 444 | covered work, you indicate your acceptance of this License to do so. 445 | 446 | 10. Automatic Licensing of Downstream Recipients. 447 | 448 | Each time you convey a covered work, the recipient automatically 449 | receives a license from the original licensors, to run, modify and 450 | propagate that work, subject to this License. You are not responsible 451 | for enforcing compliance by third parties with this License. 452 | 453 | An "entity transaction" is a transaction transferring control of an 454 | organization, or substantially all assets of one, or subdividing an 455 | organization, or merging organizations. 
If propagation of a covered 456 | work results from an entity transaction, each party to that 457 | transaction who receives a copy of the work also receives whatever 458 | licenses to the work the party's predecessor in interest had or could 459 | give under the previous paragraph, plus a right to possession of the 460 | Corresponding Source of the work from the predecessor in interest, if 461 | the predecessor has it or can get it with reasonable efforts. 462 | 463 | You may not impose any further restrictions on the exercise of the 464 | rights granted or affirmed under this License. For example, you may 465 | not impose a license fee, royalty, or other charge for exercise of 466 | rights granted under this License, and you may not initiate litigation 467 | (including a cross-claim or counterclaim in a lawsuit) alleging that 468 | any patent claim is infringed by making, using, selling, offering for 469 | sale, or importing the Program or any portion of it. 470 | 471 | 11. Patents. 472 | 473 | A "contributor" is a copyright holder who authorizes use under this 474 | License of the Program or a work on which the Program is based. The 475 | work thus licensed is called the contributor's "contributor version". 476 | 477 | A contributor's "essential patent claims" are all patent claims 478 | owned or controlled by the contributor, whether already acquired or 479 | hereafter acquired, that would be infringed by some manner, permitted 480 | by this License, of making, using, or selling its contributor version, 481 | but do not include claims that would be infringed only as a 482 | consequence of further modification of the contributor version. For 483 | purposes of this definition, "control" includes the right to grant 484 | patent sublicenses in a manner consistent with the requirements of 485 | this License. 
486 | 487 | Each contributor grants you a non-exclusive, worldwide, royalty-free 488 | patent license under the contributor's essential patent claims, to 489 | make, use, sell, offer for sale, import and otherwise run, modify and 490 | propagate the contents of its contributor version. 491 | 492 | In the following three paragraphs, a "patent license" is any express 493 | agreement or commitment, however denominated, not to enforce a patent 494 | (such as an express permission to practice a patent or covenant not to 495 | sue for patent infringement). To "grant" such a patent license to a 496 | party means to make such an agreement or commitment not to enforce a 497 | patent against the party. 498 | 499 | If you convey a covered work, knowingly relying on a patent license, 500 | and the Corresponding Source of the work is not available for anyone 501 | to copy, free of charge and under the terms of this License, through a 502 | publicly available network server or other readily accessible means, 503 | then you must either (1) cause the Corresponding Source to be so 504 | available, or (2) arrange to deprive yourself of the benefit of the 505 | patent license for this particular work, or (3) arrange, in a manner 506 | consistent with the requirements of this License, to extend the patent 507 | license to downstream recipients. "Knowingly relying" means you have 508 | actual knowledge that, but for the patent license, your conveying the 509 | covered work in a country, or your recipient's use of the covered work 510 | in a country, would infringe one or more identifiable patents in that 511 | country that you have reason to believe are valid. 
512 | 513 | If, pursuant to or in connection with a single transaction or 514 | arrangement, you convey, or propagate by procuring conveyance of, a 515 | covered work, and grant a patent license to some of the parties 516 | receiving the covered work authorizing them to use, propagate, modify 517 | or convey a specific copy of the covered work, then the patent license 518 | you grant is automatically extended to all recipients of the covered 519 | work and works based on it. 520 | 521 | A patent license is "discriminatory" if it does not include within 522 | the scope of its coverage, prohibits the exercise of, or is 523 | conditioned on the non-exercise of one or more of the rights that are 524 | specifically granted under this License. You may not convey a covered 525 | work if you are a party to an arrangement with a third party that is 526 | in the business of distributing software, under which you make payment 527 | to the third party based on the extent of your activity of conveying 528 | the work, and under which the third party grants, to any of the 529 | parties who would receive the covered work from you, a discriminatory 530 | patent license (a) in connection with copies of the covered work 531 | conveyed by you (or copies made from those copies), or (b) primarily 532 | for and in connection with specific products or compilations that 533 | contain the covered work, unless you entered into that arrangement, 534 | or that patent license was granted, prior to 28 March 2007. 535 | 536 | Nothing in this License shall be construed as excluding or limiting 537 | any implied license or other defenses to infringement that may 538 | otherwise be available to you under applicable patent law. 539 | 540 | 12. No Surrender of Others' Freedom. 541 | 542 | If conditions are imposed on you (whether by court order, agreement or 543 | otherwise) that contradict the conditions of this License, they do not 544 | excuse you from the conditions of this License. 
If you cannot convey a 545 | covered work so as to satisfy simultaneously your obligations under this 546 | License and any other pertinent obligations, then as a consequence you may 547 | not convey it at all. For example, if you agree to terms that obligate you 548 | to collect a royalty for further conveying from those to whom you convey 549 | the Program, the only way you could satisfy both those terms and this 550 | License would be to refrain entirely from conveying the Program. 551 | 552 | 13. Use with the GNU Affero General Public License. 553 | 554 | Notwithstanding any other provision of this License, you have 555 | permission to link or combine any covered work with a work licensed 556 | under version 3 of the GNU Affero General Public License into a single 557 | combined work, and to convey the resulting work. The terms of this 558 | License will continue to apply to the part which is the covered work, 559 | but the special requirements of the GNU Affero General Public License, 560 | section 13, concerning interaction through a network will apply to the 561 | combination as such. 562 | 563 | 14. Revised Versions of this License. 564 | 565 | The Free Software Foundation may publish revised and/or new versions of 566 | the GNU General Public License from time to time. Such new versions will 567 | be similar in spirit to the present version, but may differ in detail to 568 | address new problems or concerns. 569 | 570 | Each version is given a distinguishing version number. If the 571 | Program specifies that a certain numbered version of the GNU General 572 | Public License "or any later version" applies to it, you have the 573 | option of following the terms and conditions either of that numbered 574 | version or of any later version published by the Free Software 575 | Foundation. If the Program does not specify a version number of the 576 | GNU General Public License, you may choose any version ever published 577 | by the Free Software Foundation. 
578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 583 | 584 | Later license versions may give you additional or different 585 | permissions. However, no additional obligations are imposed on any 586 | author or copyright holder as a result of your choosing to follow a 587 | later version. 588 | 589 | 15. Disclaimer of Warranty. 590 | 591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 599 | 600 | 16. Limitation of Liability. 601 | 602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 610 | SUCH DAMAGES. 611 | 612 | 17. Interpretation of Sections 15 and 16. 
613 | 614 | If the disclaimer of warranty and limitation of liability provided 615 | above cannot be given local legal effect according to their terms, 616 | reviewing courts shall apply local law that most closely approximates 617 | an absolute waiver of all civil liability in connection with the 618 | Program, unless a warranty or assumption of liability accompanies a 619 | copy of the Program in return for a fee. 620 | 621 | END OF TERMS AND CONDITIONS 622 | 623 | How to Apply These Terms to Your New Programs 624 | 625 | If you develop a new program, and you want it to be of the greatest 626 | possible use to the public, the best way to achieve this is to make it 627 | free software which everyone can redistribute and change under these terms. 628 | 629 | To do so, attach the following notices to the program. It is safest 630 | to attach them to the start of each source file to most effectively 631 | state the exclusion of warranty; and each file should have at least 632 | the "copyright" line and a pointer to where the full notice is found. 633 | 634 | <one line to give the program's name and a brief idea of what it does.> 635 | Copyright (C) <year> <name of author> 636 | 637 | This program is free software: you can redistribute it and/or modify 638 | it under the terms of the GNU General Public License as published by 639 | the Free Software Foundation, either version 3 of the License, or 640 | (at your option) any later version. 641 | 642 | This program is distributed in the hope that it will be useful, 643 | but WITHOUT ANY WARRANTY; without even the implied warranty of 644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 645 | GNU General Public License for more details. 646 | 647 | You should have received a copy of the GNU General Public License 648 | along with this program. If not, see <https://www.gnu.org/licenses/>. 649 | 650 | Also add information on how to contact you by electronic and paper mail.
651 | 652 | If the program does terminal interaction, make it output a short 653 | notice like this when it starts in an interactive mode: 654 | 655 | <program> Copyright (C) <year> <name of author> 656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 657 | This is free software, and you are welcome to redistribute it 658 | under certain conditions; type `show c' for details. 659 | 660 | The hypothetical commands `show w' and `show c' should show the appropriate 661 | parts of the General Public License. Of course, your program's commands 662 | might be different; for a GUI interface, you would use an "about box". 663 | 664 | You should also get your employer (if you work as a programmer) or school, 665 | if any, to sign a "copyright disclaimer" for the program, if necessary. 666 | For more information on this, and how to apply and follow the GNU GPL, see 667 | <https://www.gnu.org/licenses/>. 668 | 669 | The GNU General Public License does not permit incorporating your program 670 | into proprietary programs. If your program is a subroutine library, you 671 | may consider it more useful to permit linking proprietary applications with 672 | the library. If this is what you want to do, use the GNU Lesser General 673 | Public License instead of this License. But first, please read 674 | <https://www.gnu.org/licenses/why-not-lgpl.html>. 675 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ContrastPool 2 | This is the official PyTorch implementation of ContrastPool from the paper 3 | *"Contrastive Graph Pooling for Explainable Classification of Brain Networks"* published in IEEE Transactions on Medical Imaging (TMI) 2024. 4 | 5 | Link: [arXiv](https://arxiv.org/abs/2307.11133). 6 | 7 | Model 8 | 9 | 10 | ## Data 11 | All preprocessed data used in this paper are published with [this NeurIPS 2023 Datasets and Benchmarks paper](https://proceedings.neurips.cc/paper_files/paper/2023/file/44e3a3115ca26e5127851acd0cedd0d9-Paper-Datasets_and_Benchmarks.pdf).
12 | Data splits and configurations are stored in `./data/` and `./configs/`. If you want to process your own data, please check the dataloader script `./data/BrainNet.py`. 13 | 14 | ## Usage 15 | 16 | Please check `baseline.sh` on how to run the project. 17 | 18 | ## Citation 19 | 20 | If you find this code useful, please consider citing our paper: 21 | 22 | ``` 23 | @ARTICLE{10508252, 24 | author={Xu, Jiaxing and Bian, Qingtian and Li, Xinhang and Zhang, Aihu and Ke, Yiping and Qiao, Miao and Zhang, Wei and Sim, Wei Khang Jeremy and Gulyás, Balázs}, 25 | journal={IEEE Transactions on Medical Imaging}, 26 | title={Contrastive Graph Pooling for Explainable Classification of Brain Networks}, 27 | year={2024}, 28 | volume={}, 29 | number={}, 30 | pages={1-1}, 31 | keywords={Functional magnetic resonance imaging;Feature extraction;Task analysis;Data mining;Alzheimer's disease;Message passing;Brain modeling;Brain Network;Deep Learning for Neuroimaging;fMRI Biomarker;Graph Classification;Graph Neural Network}, 32 | doi={10.1109/TMI.2024.3392988}} 33 | ``` 34 | 35 | ## Contact 36 | 37 | If you have any questions, please feel free to reach out at `jiaxing003@e.ntu.edu.sg`. 
38 | -------------------------------------------------------------------------------- /baseline.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | model="configs/abide_schaefer100/TUs_graph_classification_ContrastPool_abide_schaefer100_100k.json" 4 | echo ${model} 5 | python main.py --config $model --gpu_id 0 --node_feat_transform pearson --max_time 60 --init_lr 1e-2 --threshold 0.0 --batch_size 20 --dropout 0.0 --contrast --pool_ratio 0.5 --lambda1 1e-3 --L 2 6 | -------------------------------------------------------------------------------- /configs/abide_schaefer100/TUs_graph_classification_ContrastPool_abide_schaefer100_100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "ContrastPool", 8 | "dataset": "abide_schaefer100", 9 | 10 | "out_dir": "out/braindata_graph_classification/", 11 | 12 | "params": { 13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 20, 16 | "init_lr": 1e-2, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-6, 20 | "weight_decay": 0.0, 21 | "print_epoch_interval": 5, 22 | "max_time": 30 23 | }, 24 | 25 | "net_params": { 26 | "L": 2, 27 | "hidden_dim": 86, 28 | "out_dim": 86, 29 | "residual": true, 30 | "readout": "mean", 31 | "in_feat_dropout": 0.0, 32 | "dropout": 0.0, 33 | "batch_norm": true, 34 | "sage_aggregator": "maxpool", 35 | "self_loop": false, 36 | "edge_feat": false 37 | } 38 | } -------------------------------------------------------------------------------- /configs/adni_schaefer100/TUs_graph_classification_ContrastPool_adni_schaefer100_100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "ContrastPool", 8 | "dataset": "adni_schaefer100", 9 | 10 | "out_dir": "out/braindata_graph_classification/", 11 | 12 | "params": { 
13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 20, 16 | "init_lr": 1e-2, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-6, 20 | "weight_decay": 0.0, 21 | "print_epoch_interval": 5, 22 | "max_time": 30 23 | }, 24 | 25 | "net_params": { 26 | "L": 2, 27 | "hidden_dim": 86, 28 | "out_dim": 86, 29 | "residual": true, 30 | "readout": "mean", 31 | "in_feat_dropout": 0.0, 32 | "dropout": 0.0, 33 | "batch_norm": true, 34 | "sage_aggregator": "maxpool", 35 | "self_loop": false, 36 | "edge_feat": false 37 | } 38 | } -------------------------------------------------------------------------------- /configs/neurocon_schaefer100/TUs_graph_classification_ContrastPool_neurocon_schaefer100_100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "ContrastPool", 8 | "dataset": "neurocon_schaefer100", 9 | 10 | "out_dir": "out/braindata_graph_classification/", 11 | 12 | "params": { 13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 4, 16 | "init_lr": 1e-2, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-6, 20 | "weight_decay": 0.0, 21 | "print_epoch_interval": 5, 22 | "max_time": 30 23 | }, 24 | 25 | "net_params": { 26 | "L": 2, 27 | "hidden_dim": 86, 28 | "out_dim": 86, 29 | "residual": true, 30 | "readout": "mean", 31 | "in_feat_dropout": 0.0, 32 | "dropout": 0.0, 33 | "batch_norm": true, 34 | "sage_aggregator": "maxpool", 35 | "self_loop": false, 36 | "edge_feat": false 37 | } 38 | } -------------------------------------------------------------------------------- /configs/ppmi_schaefer100/TUs_graph_classification_ContrastPool_ppmi_schaefer100_100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "ContrastPool", 8 | "dataset": "ppmi_schaefer100", 9 | 10 | "out_dir": 
"out/braindata_graph_classification/", 11 | 12 | "params": { 13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 4, 16 | "init_lr": 1e-2, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-6, 20 | "weight_decay": 0.0, 21 | "print_epoch_interval": 5, 22 | "max_time": 30 23 | }, 24 | 25 | "net_params": { 26 | "L": 2, 27 | "hidden_dim": 86, 28 | "out_dim": 86, 29 | "residual": true, 30 | "readout": "mean", 31 | "in_feat_dropout": 0.0, 32 | "dropout": 0.0, 33 | "batch_norm": true, 34 | "sage_aggregator": "maxpool", 35 | "self_loop": false, 36 | "edge_feat": false 37 | } 38 | } -------------------------------------------------------------------------------- /configs/taowu_schaefer100/TUs_graph_classification_ContrastPool_taowu_schaefer100_100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "ContrastPool", 8 | "dataset": "taowu_schaefer100", 9 | 10 | "out_dir": "out/braindata_graph_classification/", 11 | 12 | "params": { 13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 4, 16 | "init_lr": 1e-2, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-6, 20 | "weight_decay": 0.0, 21 | "print_epoch_interval": 5, 22 | "max_time": 30 23 | }, 24 | 25 | "net_params": { 26 | "L": 2, 27 | "hidden_dim": 86, 28 | "out_dim": 86, 29 | "residual": true, 30 | "readout": "mean", 31 | "in_feat_dropout": 0.0, 32 | "dropout": 0.0, 33 | "batch_norm": true, 34 | "sage_aggregator": "maxpool", 35 | "self_loop": false, 36 | "edge_feat": false 37 | } 38 | } -------------------------------------------------------------------------------- /contrast_subgraph.py: -------------------------------------------------------------------------------- 1 | import heapq 2 | import math 3 | import numpy as np 4 | import torch 5 | import dgl 6 | from dgl.data.utils import load_graphs 7 | from copy import deepcopy 8 | from tqdm import 
tqdm 9 | 10 | 11 | def get_summary_tensor(G_dataset, Labels, device, merge_classes=False): 12 | num_G = len(G_dataset) 13 | Labels = Labels.tolist() 14 | node_num = G_dataset[0].ndata['feat'].shape[0] 15 | adj_dict = {} 16 | nodes_dict = {} 17 | final_adj_dict = {} 18 | final_nodes_dict = {} 19 | for i in range(num_G): 20 | if Labels[i] not in adj_dict.keys(): 21 | adj_dict[Labels[i]] = [] 22 | nodes_dict[Labels[i]] = [] 23 | adj_dict[Labels[i]].append(G_dataset[i].edata['feat'].squeeze().view(node_num, -1).tolist()) 24 | nodes_dict[Labels[i]].append(G_dataset[i].ndata['feat'].tolist()) 25 | 26 | for i in adj_dict.keys(): 27 | final_adj_dict[i] = torch.tensor(adj_dict[i]).to(device) 28 | final_nodes_dict[i] = torch.tensor(nodes_dict[i]).to(device) 29 | return final_adj_dict, final_nodes_dict 30 | -------------------------------------------------------------------------------- /data/BrainNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | import time 4 | import os 5 | import numpy as np 6 | import csv 7 | import dgl 8 | from dgl.data.utils import load_graphs 9 | import networkx as nx 10 | from tqdm import tqdm 11 | import random 12 | random.seed(42) 13 | from sklearn.model_selection import StratifiedKFold, train_test_split 14 | 15 | 16 | class DGLFormDataset(torch.utils.data.Dataset): 17 | """ 18 | DGLFormDataset wrapping graph list and label list as per pytorch Dataset. 19 | *lists (list): lists of 'graphs' and 'labels' with same len(). 
20 | """ 21 | def __init__(self, *lists): 22 | assert all(len(lists[0]) == len(li) for li in lists) 23 | self.lists = lists 24 | self.graph_lists = lists[0] 25 | self.graph_labels = lists[1] 26 | 27 | def __getitem__(self, index): 28 | return tuple(li[index] for li in self.lists) 29 | 30 | def __len__(self): 31 | return len(self.lists[0]) 32 | 33 | 34 | def self_loop(g): 35 | """ 36 | Utility function only, to be used only when necessary as per user self_loop flag 37 | : Overwriting the function dgl.transform.add_self_loop() to not miss ndata['feat'] and edata['feat'] 38 | 39 | 40 | This function is called inside a function in TUsDataset class. 41 | """ 42 | new_g = dgl.DGLGraph() 43 | new_g.add_nodes(g.number_of_nodes()) 44 | new_g.ndata['feat'] = g.ndata['feat'] 45 | 46 | src, dst = g.all_edges(order="eid") 47 | src = dgl.backend.zerocopy_to_numpy(src) 48 | dst = dgl.backend.zerocopy_to_numpy(dst) 49 | non_self_edges_idx = src != dst 50 | nodes = np.arange(g.number_of_nodes()) 51 | new_g.add_edges(src[non_self_edges_idx], dst[non_self_edges_idx]) 52 | new_g.add_edges(nodes, nodes) 53 | 54 | # This new edata is not used since this function gets called only for GCN, GAT 55 | # However, we need this for the generic requirement of ndata and edata 56 | new_g.edata['feat'] = torch.zeros(new_g.number_of_edges()) 57 | return new_g 58 | 59 | name2path = { 60 | 'abide_AAL116': '/path/to/data/abide_AAL116.bin', 61 | 'abide_harvard48': '/path/to/data/abide_harvard48.bin', 62 | 'abide_kmeans100': '/path/to/data/abide_kmeans100.bin', 63 | 'abide_schaefer100': '/path/to/data/abide_schaefer100.bin', 64 | 'abide_ward100': '/path/to/data/abide_ward100.bin', 65 | 66 | 'adni_AAL116': '/path/to/data/adni_AAL116.bin', 67 | 'adni_harvard48': '/path/to/data/adni_harvard48.bin', 68 | 'adni_kmeans100': '/path/to/data/adni_kmeans100.bin', 69 | 'adni_schaefer100': '/path/to/data/adni_schaefer100.bin', 70 | 'adni_ward100': '/path/to/data/adni_ward100.bin', 71 | 72 | 'neurocon_AAL116': 
'/path/to/data/neurocon_AAL116.bin', 73 | 'neurocon_harvard48': '/path/to/data/neurocon_harvard48.bin', 74 | 'neurocon_kmeans100': '/path/to/data/neurocon_kmeans100.bin', 75 | 'neurocon_schaefer100': '/path/to/data/neurocon_schaefer100.bin', 76 | 'neurocon_ward100': '/path/to/data/neurocon_ward100.bin', 77 | 78 | 'ppmi_AAL116': '/path/to/data/ppmi_AAL116.bin', 79 | 'ppmi_harvard48': '/path/to/data/ppmi_harvard48.bin', 80 | 'ppmi_kmeans100': '/path/to/data/ppmi_kmeans100.bin', 81 | 'ppmi_schaefer100': '/path/to/data/ppmi_schaefer100.bin', 82 | 'ppmi_ward100': '/path/to/data/ppmi_ward100.bin', 83 | 84 | 'taowu_AAL116': '/path/to/data/taowu_AAL116.bin', 85 | 'taowu_harvard48': '/path/to/data/taowu_harvard48.bin', 86 | 'taowu_kmeans100': '/path/to/data/taowu_kmeans100.bin', 87 | 'taowu_schaefer100': '/path/to/data/taowu_schaefer100.bin', 88 | 'taowu_ward100': '/path/to/data/taowu_ward100.bin', 89 | } 90 | 91 | 92 | class BrainDataset(torch.utils.data.Dataset): 93 | def __init__(self, name, threshold=0.3, edge_ratio=0, node_feat_transform='original'): 94 | t0 = time.time() 95 | self.name = name 96 | 97 | G_dataset, Labels = load_graphs(name2path[self.name]) 98 | 99 | self.node_num = G_dataset[0].ndata['N_features'].size(0) 100 | 101 | print("[!] 
Dataset: ", self.name) 102 | 103 | # transfer DGLHeteroGraph to DGLFormDataset 104 | data = [] 105 | error_case = [] 106 | for i in range(len(G_dataset)): 107 | if len(((G_dataset[i].ndata['N_features'] != 0).sum(dim=-1) == 0).nonzero()) > 0: 108 | error_case.append(i) 109 | print(error_case) 110 | G_dataset = [n for i, n in enumerate(G_dataset) if i not in error_case] 111 | 112 | for i in tqdm(range(len(G_dataset))): 113 | if edge_ratio: 114 | threshold_idx = int(len(G_dataset[i].edata['E_features']) * (1 - edge_ratio)) 115 | threshold = sorted(G_dataset[i].edata['E_features'].tolist())[threshold_idx] 116 | 117 | G_dataset[i].remove_edges(torch.squeeze((torch.abs(G_dataset[i].edata['E_features']) < float(threshold)).nonzero())) 118 | G_dataset[i].edata['feat'] = G_dataset[i].edata['E_features'].unsqueeze(-1).clone() 119 | 120 | if name[:-7] == 'pearson' or node_feat_transform == 'original': 121 | G_dataset[i].ndata['feat'] = G_dataset[i].ndata['N_features'].clone() 122 | elif node_feat_transform == 'one_hot': 123 | G_dataset[i].ndata['feat'] = torch.eye(self.node_num).clone() 124 | elif node_feat_transform == 'pearson': 125 | G_dataset[i].ndata['feat'] = torch.from_numpy(np.corrcoef(G_dataset[i].ndata['N_features'].numpy())).clone() 126 | elif node_feat_transform == 'degree': 127 | G_dataset[i].ndata['feat'] = G_dataset[i].in_degrees().unsqueeze(dim=1).clone() 128 | elif node_feat_transform == 'adj_matrix': 129 | G_dataset[i].ndata['feat'] = G_dataset[i].adj().to_dense().clone() 130 | elif node_feat_transform == 'mean_std': 131 | G_dataset[i].ndata['feat'] = torch.stack(torch.std_mean(G_dataset[i].ndata['N_features'], dim=-1)).T.flip(dims=[1]).clone() 132 | else: 133 | raise NotImplementedError 134 | 135 | G_dataset[i].ndata.pop('N_features') 136 | G_dataset[i].edata.pop('E_features') 137 | data.append([G_dataset[i], Labels['glabel'].tolist()[i]]) 138 | 139 | dataset = self.format_dataset(data) 140 | # this function splits data into train/val/test and returns the 
indices 141 | self.all_idx = self.get_all_split_idx(dataset) 142 | 143 | self.all = dataset 144 | self.train = [self.format_dataset([dataset[idx] for idx in self.all_idx['train'][split_num]]) for split_num in range(10)] 145 | self.val = [self.format_dataset([dataset[idx] for idx in self.all_idx['val'][split_num]]) for split_num in range(10)] 146 | self.test = [self.format_dataset([dataset[idx] for idx in self.all_idx['test'][split_num]]) for split_num in range(10)] 147 | 148 | print("Time taken: {:.4f}s".format(time.time()-t0)) 149 | 150 | def get_all_split_idx(self, dataset): 151 | """ 152 | - Split total number of graphs into 3 (train, val and test) in 80:10:10 153 | - Stratified split proportionate to original distribution of data with respect to classes 154 | - Using sklearn to perform the split and then save the indexes 155 | - Preparing 10 such combinations of indexes split to be used in Graph NNs 156 | - As with KFold, each of the 10 fold have unique test set. 157 | """ 158 | root_idx_dir = './data/{}/'.format(self.name) 159 | if not os.path.exists(root_idx_dir): 160 | os.makedirs(root_idx_dir) 161 | all_idx = {} 162 | 163 | # If there are no idx files, do the split and store the files 164 | if not (os.path.exists(root_idx_dir + 'train.index')): 165 | print("[!] 
Splitting the data into train/val/test ...") 166 | 167 | # Using 10-fold cross val to compare with benchmark papers 168 | k_splits = 10 169 | 170 | cross_val_fold = StratifiedKFold(n_splits=k_splits, shuffle=True) 171 | k_data_splits = [] 172 | 173 | # this is a temporary index assignment, to be used below for val splitting 174 | for i in range(len(dataset.graph_lists)): 175 | dataset[i][0].a = lambda: None 176 | setattr(dataset[i][0].a, 'index', i) 177 | 178 | for indexes in cross_val_fold.split(dataset.graph_lists, dataset.graph_labels): 179 | remain_index, test_index = indexes[0], indexes[1] 180 | 181 | remain_set = self.format_dataset([dataset[index] for index in remain_index]) 182 | 183 | # Gets final 'train' and 'val' 184 | train, val, _, __ = train_test_split(remain_set, 185 | range(len(remain_set.graph_lists)), 186 | test_size=0.111, 187 | stratify=remain_set.graph_labels) 188 | 189 | train, val = self.format_dataset(train), self.format_dataset(val) 190 | test = self.format_dataset([dataset[index] for index in test_index]) 191 | 192 | # Extracting only idx 193 | idx_train = [item[0].a.index for item in train] 194 | idx_val = [item[0].a.index for item in val] 195 | idx_test = [item[0].a.index for item in test] 196 | 197 | f_train_w = csv.writer(open(root_idx_dir + 'train.index', 'a+')) 198 | f_val_w = csv.writer(open(root_idx_dir + 'val.index', 'a+')) 199 | f_test_w = csv.writer(open(root_idx_dir + 'test.index', 'a+')) 200 | 201 | f_train_w.writerow(idx_train) 202 | f_val_w.writerow(idx_val) 203 | f_test_w.writerow(idx_test) 204 | 205 | print("[!] 
Splitting done!") 206 | 207 | # reading idx from the files 208 | for section in ['train', 'val', 'test']: 209 | with open(root_idx_dir + section + '.index', 'r') as f: 210 | reader = csv.reader(f) 211 | all_idx[section] = [list(map(int, idx)) for idx in reader] 212 | return all_idx 213 | 214 | def format_dataset(self, dataset): 215 | """ 216 | Utility function to recover data, 217 | INTO-> dgl/pytorch compatible format 218 | """ 219 | graphs = [data[0] for data in dataset] 220 | labels = [data[1] for data in dataset] 221 | 222 | for graph in graphs: 223 | graph.ndata['feat'] = graph.ndata['feat'].float() # dgl 4.0 224 | # adding edge features for Residual Gated ConvNet, if not there 225 | if 'feat' not in graph.edata.keys(): 226 | edge_feat_dim = graph.ndata['feat'].shape[1] # dim same as node feature dim 227 | graph.edata['feat'] = torch.ones(graph.number_of_edges(), edge_feat_dim) 228 | 229 | return DGLFormDataset(graphs, labels) 230 | 231 | # form a mini batch from a given list of samples = [(graph, label) pairs] 232 | def collate(self, samples): 233 | # The input samples is a list of pairs (graph, label). 234 | graphs, labels = map(list, zip(*samples)) 235 | labels = torch.tensor(np.array(labels)) 236 | batched_graph = dgl.batch(graphs) 237 | 238 | return batched_graph, labels 239 | 240 | # prepare dense tensors for GNNs using them; such as RingGNN, 3WLGNN 241 | def collate_dense_gnn(self, samples): 242 | # The input samples is a list of pairs (graph, label). 243 | graphs, labels = map(list, zip(*samples)) 244 | labels = torch.tensor(np.array(labels)) 245 | 246 | g = graphs[0] 247 | adj = self._sym_normalize_adj(g.adjacency_matrix().to_dense()) 248 | """ 249 | Adapted from https://github.com/leichen2018/Ring-GNN/ 250 | Assigning node and edge feats:: 251 | we have the adjacency matrix in R^{n x n}, the node features in R^{d_n} and edge features R^{d_e}. 252 | Then we build a zero-initialized tensor, say T, in R^{(1 + d_n + d_e) x n x n}. 
T[0, :, :] is the adjacency matrix. 253 | The diagonal T[1:1+d_n, i, i], i = 0 to n-1, store the node feature of node i. 254 | The off diagonal T[1+d_n:, i, j] store edge features of edge(i, j). 255 | """ 256 | 257 | zero_adj = torch.zeros_like(adj) 258 | 259 | in_dim = g.ndata['feat'].shape[1] 260 | 261 | # use node feats to prepare adj 262 | adj_node_feat = torch.stack([zero_adj for j in range(in_dim)]) 263 | adj_node_feat = torch.cat([adj.unsqueeze(0), adj_node_feat], dim=0) 264 | 265 | for node, node_feat in enumerate(g.ndata['feat']): 266 | adj_node_feat[1:, node, node] = node_feat 267 | 268 | x_node_feat = adj_node_feat.unsqueeze(0) 269 | 270 | return x_node_feat, labels 271 | 272 | def _sym_normalize_adj(self, adj): 273 | deg = torch.sum(adj, dim=0) 274 | deg_inv = torch.where(deg>0, 1./torch.sqrt(deg), torch.zeros(deg.size())) 275 | deg_inv = torch.diag(deg_inv) 276 | return torch.mm(deg_inv, torch.mm(adj, deg_inv)) 277 | 278 | def _add_self_loops(self): 279 | 280 | # function for adding self loops 281 | # this function will be called only if self_loop flag is True 282 | for split_num in range(10): 283 | self.train[split_num].graph_lists = [self_loop(g) for g in self.train[split_num].graph_lists] 284 | self.val[split_num].graph_lists = [self_loop(g) for g in self.val[split_num].graph_lists] 285 | self.test[split_num].graph_lists = [self_loop(g) for g in self.test[split_num].graph_lists] 286 | 287 | for split_num in range(10): 288 | self.train[split_num] = DGLFormDataset(self.train[split_num].graph_lists, self.train[split_num].graph_labels) 289 | self.val[split_num] = DGLFormDataset(self.val[split_num].graph_lists, self.val[split_num].graph_labels) 290 | self.test[split_num] = DGLFormDataset(self.test[split_num].graph_lists, self.test[split_num].graph_labels) 291 | -------------------------------------------------------------------------------- /data/abide_schaefer100/test.index: 
-------------------------------------------------------------------------------- 1 | 2,14,18,45,61,99,119,138,144,153,161,165,177,179,190,192,196,201,203,209,258,265,272,305,308,309,313,326,348,364,373,374,387,388,395,399,403,438,439,444,458,477,482,491,492,498,502,503,504,508,509,512,519,526,547,550,571,579,617,620,625,643,650,671,675,686,688,693,694,717,724,726,747,757,759,760,778,791,800,809,816,820,837,853,870,872,878,882,889,891,899,900,902,908,941,960,967,982,983 2 | 0,1,3,7,9,19,26,39,40,52,70,75,83,87,93,97,129,141,162,173,174,181,202,207,213,247,261,264,269,270,279,325,333,338,350,358,362,366,370,394,396,401,410,418,434,440,442,456,461,468,473,485,500,529,543,548,551,554,555,575,592,595,597,626,631,662,664,668,670,673,714,720,729,734,739,754,761,765,774,799,819,824,844,848,871,876,883,886,890,895,897,918,931,934,943,952,953,965,988 3 | 15,22,25,37,38,59,74,82,107,110,126,131,134,163,166,175,188,204,210,223,232,234,251,276,291,299,327,330,339,345,347,351,357,360,379,386,409,411,413,422,425,436,449,454,455,457,462,464,486,496,505,513,517,532,535,540,556,557,558,567,593,600,604,624,636,637,644,666,687,692,699,701,703,756,764,768,797,798,831,834,838,841,857,862,875,887,911,925,937,944,946,947,948,961,969,971,978,985,987 4 | 32,33,34,50,53,55,56,69,72,88,90,92,111,122,135,142,148,152,158,168,193,197,214,218,224,226,229,235,249,252,260,267,278,281,287,302,303,343,359,369,371,372,389,402,405,435,466,478,493,506,507,525,527,533,536,538,545,549,562,563,572,587,612,629,635,638,642,648,651,681,690,696,697,708,723,727,749,753,762,770,794,803,808,822,827,843,849,856,858,863,865,866,868,888,905,909,910,958,977 5 | 
13,16,20,23,24,42,58,62,81,84,85,94,95,98,101,123,124,151,164,169,170,176,184,194,206,230,231,236,241,248,253,257,271,284,292,301,307,319,320,323,344,352,354,381,392,429,447,467,472,487,521,524,530,542,552,553,573,577,601,607,618,630,639,645,654,678,682,683,698,700,705,707,738,741,744,746,779,785,788,789,793,802,805,807,818,828,829,873,885,893,894,898,903,912,935,945,955,956,962 6 | 4,11,63,77,79,96,108,112,113,120,121,128,133,136,139,171,185,189,208,212,215,216,217,227,233,242,245,246,259,268,280,304,310,321,322,328,346,353,367,377,378,385,390,419,431,459,463,474,481,499,514,515,516,534,541,559,560,581,602,608,616,627,647,655,674,677,706,710,711,712,716,719,725,735,743,763,777,780,795,814,825,833,846,850,852,860,861,869,901,914,916,923,927,933,940,949,954,975,979 7 | 8,12,30,35,41,60,71,73,80,102,115,125,127,140,147,154,155,167,178,183,222,238,240,250,255,263,266,274,283,289,311,312,314,324,355,376,382,398,417,420,423,426,428,441,443,460,469,479,488,497,501,510,518,537,544,566,574,584,585,588,598,628,634,658,660,665,672,684,704,713,718,721,728,740,745,758,766,769,775,783,796,801,804,817,830,867,880,904,913,920,926,929,930,942,950,964,970,981,984 8 | 6,17,29,36,43,47,64,65,66,78,86,100,104,109,117,146,149,150,156,172,180,186,187,191,198,219,228,244,254,275,296,297,300,317,331,334,363,368,375,384,391,400,408,412,427,445,446,471,476,489,490,522,523,561,569,589,591,596,603,613,614,619,633,640,641,649,676,685,689,695,715,731,736,771,787,810,811,812,821,826,832,835,836,854,859,864,877,881,884,892,924,932,938,951,957,963,973,980,986 9 | 5,10,21,28,46,48,49,67,68,89,103,105,106,118,145,159,199,200,205,211,221,237,239,243,256,262,273,306,329,332,336,337,340,342,349,356,361,365,380,393,404,406,424,430,437,448,451,453,483,484,494,495,511,539,546,568,578,580,582,583,586,594,605,606,609,611,615,621,646,652,653,656,657,661,663,669,733,742,750,752,755,776,784,786,813,823,842,879,896,906,907,919,922,928,939,966,972,974,976 10 | 
27,31,44,51,54,57,76,91,114,116,130,132,137,143,157,160,182,195,220,225,277,282,285,286,288,290,293,294,295,298,315,316,318,335,341,383,397,407,414,415,416,421,432,433,450,452,465,470,475,480,520,528,531,564,565,570,576,590,599,610,622,623,632,659,667,679,680,691,702,709,722,730,732,737,748,751,767,772,773,781,782,790,792,806,815,839,840,845,847,851,855,874,915,917,921,936,959,968 11 | -------------------------------------------------------------------------------- /data/abide_schaefer100/train.index: -------------------------------------------------------------------------------- 1 | 148,561,678,919,731,517,752,486,523,495,116,985,219,574,879,548,52,708,88,522,841,306,836,465,777,567,430,790,783,743,281,866,185,813,467,277,235,507,173,663,737,611,565,679,154,80,647,380,255,335,788,11,584,947,917,125,664,47,969,323,706,473,600,802,425,469,171,283,418,775,232,110,168,354,661,933,525,166,819,432,954,338,103,582,269,909,839,407,782,368,151,181,57,856,132,366,566,106,552,768,978,43,136,665,713,963,248,776,599,98,280,55,211,806,127,46,261,904,41,601,799,898,670,692,369,150,964,770,867,725,372,74,21,424,356,363,972,468,538,630,224,384,117,573,158,798,5,834,377,871,957,645,466,383,873,94,346,918,876,923,296,69,894,10,711,564,524,628,394,827,421,307,883,682,475,785,370,20,580,90,691,101,734,93,988,155,615,687,471,563,131,358,604,34,359,897,906,583,422,944,545,593,291,780,888,199,118,497,241,869,435,543,848,242,193,732,903,28,884,191,385,936,750,632,735,102,353,160,463,513,318,411,433,189,389,887,922,929,16,516,496,797,920,122,221,470,147,641,44,596,164,123,961,629,959,861,378,794,570,393,927,449,240,357,631,431,771,113,608,478,748,489,500,818,840,324,0,984,642,454,981,169,204,253,91,598,64,334,312,850,767,223,915,766,609,744,718,17,415,48,701,428,260,278,965,499,410,656,56,303,427,738,180,352,342,817,696,980,284,720,597,578,658,557,655,300,924,42,58,107,874,808,448,459,627,446,855,673,928,533,245,975,697,739,271,4,480,129,451,205,76,187,603,932,623,754,912,644,940,27,392,7,
71,695,594,262,246,453,684,152,569,914,842,862,401,481,97,558,81,950,749,141,703,263,488,634,96,753,336,361,413,329,54,830,506,274,685,330,847,236,108,973,660,607,243,188,367,587,208,270,474,715,73,846,676,527,938,450,654,536,610,773,976,476,514,779,447,222,111,213,31,589,729,885,6,544,121,217,824,945,426,805,109,618,341,457,487,325,301,811,49,619,321,437,390,602,592,669,460,266,895,115,68,930,381,556,539,396,350,231,37,194,362,528,958,87,254,79,646,810,843,733,721,741,892,112,762,907,935,949,814,971,934,649,137,78,716,977,761,1,130,807,614,709,518,613,845,651,595,322,417,142,247,745,758,916,546,145,279,755,751,638,835,62,591,534,124,250,815,327,890,256,812,275,968,531,293,252,455,126,328,286,787,677,436,803,423,233,143,186,854,314,635,535,838,315,237,680,398,637,8,826,239,572,210,319,304,966,986,85,490,828,666,157,769,553,501,25,576,311,483,822,756,198,510,970,195,22,714,875,559,30,925,881,51,910,376,792,653,825,505,59,931,953,351,227,149,332,681,452,621,339,652,333,32,657,445,285,35,702,162,292,355,360,33,690,343,409,9,939,175,288,857,434,633,456,962,639,40,220,712,371,860,829,13,218,479,225,226,337,104,77,946,414,606,297,585,12,316,400,197,167,302,206,781,943,849,905,100,229,877,844,267,575,214,146,540,730,859,184,408,746,704,26,19,230,537,23,340,133,375,801,156,736,577,821,134,60,345,956,70,83,344,521,659,289,234,114,441,832,549,39,310,636,379,494,793,207,120,172,865,622,462,299,683,140,420,386,590,484,605,796,443,626,863,926,530,700,50,763,251,419,529,786,182,238,554,772,75,382,485,667,264,464,128,560,722,616,727,937,674,951,216,662,294,765,955,723,244,586,290,317,163,913,868,287,36,15,24,268,555,202,795,53,257,273,699,331,89,581,472 2 | 
349,813,252,571,667,320,833,945,690,352,72,437,180,178,946,105,121,635,703,843,967,21,404,568,391,210,306,63,16,127,368,863,469,89,425,920,617,205,481,846,367,130,806,949,910,419,962,402,815,514,955,248,629,342,332,818,684,177,619,448,919,816,738,208,634,648,671,907,926,790,961,603,191,399,692,836,505,879,81,569,636,665,917,389,104,561,143,15,382,457,238,282,621,300,748,64,984,175,520,851,432,216,125,301,398,284,862,537,726,947,450,678,479,795,695,944,355,612,316,195,880,164,705,522,882,217,560,118,706,860,466,778,299,713,812,214,709,691,618,861,53,439,504,441,8,92,455,318,834,101,803,885,170,585,702,853,533,222,280,924,723,974,94,811,681,272,490,826,977,740,91,239,135,341,808,219,951,159,936,55,535,109,620,606,730,821,724,802,652,232,268,591,132,235,196,447,319,875,13,655,959,111,400,313,487,420,403,831,142,61,117,392,751,688,513,200,666,25,659,777,625,901,250,153,34,643,66,889,898,825,158,154,540,563,867,76,807,852,206,884,698,499,426,644,172,797,287,262,903,930,502,411,155,954,221,199,646,576,18,900,929,149,868,878,365,54,847,719,921,710,489,580,329,814,506,519,712,841,637,464,42,559,562,859,46,809,611,906,209,88,582,2,228,346,384,335,877,891,374,462,865,856,517,395,460,581,163,421,134,638,685,90,744,157,226,870,150,733,624,187,412,787,59,151,86,480,493,339,428,423,796,970,184,6,298,689,887,725,492,436,574,244,197,849,607,321,328,50,731,275,647,167,230,444,627,881,454,717,801,715,459,57,771,324,742,417,672,388,223,827,622,536,842,131,858,755,385,657,140,596,940,817,711,240,344,303,769,5,315,747,41,586,564,477,37,721,950,113,330,291,278,987,925,518,701,530,766,383,752,397,463,286,283,699,27,229,632,794,347,854,4,694,531,99,233,326,708,735,979,538,905,203,49,532,458,527,211,767,780,285,17,452,146,523,722,915,916,311,271,371,47,359,386,491,56,565,616,964,982,497,973,700,855,414,152,192,260,798,682,122,194,570,972,429,255,598,732,115,107,495,828,745,850,176,669,183,524,114,511,539,68,609,869,406,899,743,35,663,67,120,613,942,71,508,658,29,296,594,408,376,549,488,310,
11,969,963,573,476,922,465,380,276,515,290,471,263,494,331,242,590,165,902,243,431,36,677,486,78,793,23,941,100,983,645,544,593,79,650,759,893,33,288,601,294,633,541,686,351,258,58,547,728,556,241,966,185,251,20,904,948,785,381,182,503,112,584,976,416,119,823,138,656,103,449,985,198,583,307,909,507,789,866,337,323,253,630,602,108,932,913,445,96,550,190,845,704,679,874,478,265,293,422,372,387,139,676,30,661,69,409,218,136,687,312,98,212,975,707,38,106,336,363,73,133,896,978,914,696,393,642,123,379,48,567,348,453,927,525,137,528,546,343,498,470,512,501,756,939,369,935,43,857,839,614,776,760,156,509,430,354,289,832,639,259,545,24,753,267,608,126,727,587,908,610,757,84,435,295,169,144,986,304,224,189,378,186,357,604,31,600,628,390,683,737,266,415,168,364,693,605,124,937,215,356,912,424,589,467,837,245,246,148,553,758,201,305,933,14,179,322,483,957,557,768,788,829,188,516,472,193,750,74,281,314,309,256,792,231,784,145,835,443,749,928,302,800,340,373,65,718,838,651,649,110,542,526,911,830,292,804,171,674,892,736,257,274,521,971,220,375,615,773,12,660,762,623,781,62,772,872,273 3 | 
890,333,62,208,200,757,438,466,70,633,832,565,498,669,778,329,440,352,684,336,202,269,90,606,340,19,32,226,663,640,102,39,487,609,958,168,410,796,897,668,921,193,896,96,586,460,949,278,951,288,401,737,54,105,632,415,628,406,452,344,384,929,260,843,13,405,970,790,287,389,268,142,566,613,572,229,154,334,433,86,388,795,711,237,816,75,423,303,777,451,103,760,616,839,94,876,859,959,869,738,292,502,429,955,355,247,806,7,595,597,713,846,403,736,473,424,671,530,132,582,657,235,121,320,612,780,71,83,625,159,646,390,215,608,144,0,702,611,228,520,444,479,882,127,650,468,975,962,972,407,361,225,771,963,252,672,943,734,552,605,952,860,578,673,697,518,106,783,500,680,976,52,913,621,190,945,108,164,280,203,162,263,820,374,854,981,3,509,246,98,983,709,456,534,57,56,915,844,481,272,879,89,805,259,767,735,218,546,220,550,174,367,364,69,784,751,620,559,315,404,892,63,920,21,847,282,139,974,655,926,523,156,706,681,432,115,957,42,746,428,439,719,426,309,867,607,645,335,61,358,68,710,830,236,548,442,782,812,849,18,814,914,799,551,589,630,677,661,480,76,419,222,808,91,393,183,179,964,917,477,956,341,254,138,871,968,79,253,365,772,143,35,617,53,328,9,728,266,450,575,789,11,903,686,323,714,716,895,588,465,391,265,301,763,469,24,447,418,837,840,599,745,833,441,708,239,290,399,47,654,427,197,881,72,97,506,562,64,245,196,982,856,44,529,904,192,744,49,51,695,827,725,811,802,819,109,398,704,383,382,95,116,729,779,128,537,141,250,788,312,524,17,741,171,922,29,906,255,140,885,786,294,36,165,205,580,392,209,123,776,417,682,318,31,306,727,749,178,815,331,172,122,93,187,55,544,402,753,864,773,596,715,446,471,561,346,277,508,898,298,573,338,659,117,626,167,935,43,161,227,870,381,238,810,472,453,85,555,12,289,793,868,216,73,185,155,261,803,927,322,533,836,594,136,781,801,553,189,92,670,262,368,516,314,528,285,111,397,199,376,219,912,372,475,674,591,286,950,622,894,207,414,319,910,979,761,512,928,660,649,387,698,765,526,377,84,828,243,40,787,936,939,295,256,16,583,521,118,694,732,965,587,359,536,267,569
,792,642,794,362,980,923,568,543,707,241,373,217,221,28,302,754,861,324,176,408,279,180,541,863,967,461,33,158,845,667,213,385,954,177,81,305,87,809,4,638,960,135,494,104,119,602,310,938,905,113,866,585,514,889,909,157,817,742,942,717,878,614,705,369,41,264,700,888,966,88,696,484,579,822,153,986,242,50,688,224,731,891,137,283,821,448,850,257,769,300,858,718,80,396,214,629,721,125,348,902,800,186,326,120,690,835,211,683,775,45,570,497,395,321,739,149,610,510,485,653,488,930,720,10,656,67,313,206,181,170,825,77,733,883,114,598,307,627,23,430,194,275,652,304,349,770,308,750,752,872,940,14,474,515,478,590,366,270,437,865,371,619,560,662,66,420,26,99,20,634,675,907,785,504,350,148,233,495,908,354,807,522,932,603,400,743,293,483,34,934,941,547,435,919,212,691,899,65,501,539,665,412,356,648,394,643,824,431,564,924,342,5,880,740,641,527,375,953,146,748,525,658,855,195,554,492,353,759,273,618,984,664,631,574,685,678,689,248,182,240,325,482,848,874,755,151,563,101,184,284,2,886,693,852,48,712,813,380,490,133,30,145,470,317,842,78,8,191,818,747,129,647,152,493,901 4 | 
163,349,23,268,543,984,382,345,76,123,306,180,667,776,136,198,981,571,94,817,270,106,146,89,657,243,513,969,246,79,645,528,974,37,391,509,970,134,603,810,874,337,834,664,363,904,534,236,620,6,85,40,882,552,441,386,422,292,713,101,582,949,932,622,935,703,965,715,748,87,980,376,972,754,215,544,443,558,86,105,242,15,437,380,275,599,623,560,307,438,440,297,258,793,44,850,179,255,631,963,82,487,568,124,333,166,902,495,327,125,464,695,823,979,70,462,481,108,561,916,374,928,238,780,976,454,504,886,898,643,567,8,365,355,673,540,269,434,173,265,415,254,245,31,295,426,65,130,570,97,987,777,322,779,700,200,456,705,912,983,760,263,480,52,465,248,887,48,574,686,352,840,542,833,859,447,739,564,951,583,311,277,626,358,946,332,351,424,666,826,746,191,3,291,160,47,181,379,757,283,811,171,765,548,400,425,433,787,162,720,17,971,677,790,964,233,914,290,324,873,615,147,182,659,51,541,325,202,217,199,954,711,110,178,854,107,216,656,601,155,789,812,2,139,330,296,361,41,220,4,366,861,829,896,884,701,231,743,732,126,771,520,728,144,679,576,460,871,948,143,398,503,684,539,189,573,978,385,192,955,807,206,289,228,940,877,831,16,945,891,120,149,738,194,669,54,170,225,682,516,515,988,806,368,502,973,300,318,482,702,335,150,929,906,676,165,499,961,419,707,486,917,903,279,706,691,617,472,604,895,394,430,452,788,628,221,535,530,982,864,710,378,319,609,234,399,390,565,986,600,346,196,815,455,288,241,396,880,339,658,934,98,730,484,137,878,799,805,409,329,494,35,232,924,261,207,660,479,773,474,406,68,747,508,758,427,608,14,514,960,842,632,566,588,644,59,423,735,606,862,42,475,766,293,802,448,761,726,417,844,907,985,239,519,517,718,320,913,752,870,353,496,649,586,488,809,650,383,957,687,294,759,672,28,744,869,27,921,38,240,717,219,523,847,364,613,75,510,313,796,751,45,553,737,678,356,340,721,625,276,694,908,395,84,885,671,360,943,633,156,733,852,526,725,39,420,347,897,167,210,377,336,373,590,592,922,554,118,26,641,652,93,266,188,491,461,786,315,781,876,272,611,835,203,18,832,894,598,693,205,5,731,699,5
24,942,314,256,237,212,688,637,792,838,785,157,0,104,489,323,46,183,818,901,675,13,393,846,81,522,133,685,595,418,25,634,30,96,436,128,154,745,655,469,253,458,211,804,416,164,920,129,222,550,230,145,813,639,647,569,22,867,80,195,326,630,589,692,830,428,7,498,213,674,127,892,627,153,975,783,74,429,497,966,930,470,19,893,814,341,73,177,362,141,404,185,459,729,190,490,918,9,797,962,640,115,10,851,375,505,704,63,764,947,900,176,584,784,532,555,500,722,605,114,186,941,140,117,367,559,317,956,616,273,29,683,411,121,936,446,12,585,49,132,250,899,384,821,646,828,841,151,457,421,259,78,915,201,937,529,301,889,1,741,575,551,860,953,911,310,284,102,64,546,20,410,716,857,680,512,138,483,939,401,350,756,450,453,581,839,798,66,308,131,923,357,445,321,594,511,251,778,161,596,477,736,11,439,848,879,621,689,819,925,772,264,119,944,62,800,816,36,463,331,926,348,742,71,618,933,247,91,257,791,209,853,312,388,99,309,801,468,875,967,959,670,169,444,824,881,769,883,67,112,451,342,614,408,610,280,174,431,413,709,665,763,719,282,938,442,338,103,845,577,432,837,381,24,77,597,476,414 5 | 
39,782,787,203,650,970,550,41,563,346,466,589,275,972,773,140,608,250,647,670,402,924,493,132,446,274,327,835,221,222,528,440,546,43,477,46,863,651,533,959,287,579,946,799,131,637,14,845,632,389,702,950,382,815,65,658,267,332,926,391,739,742,964,104,322,473,513,383,158,341,936,527,980,822,172,760,481,593,315,441,663,495,674,703,453,91,456,628,625,44,729,224,679,831,87,173,137,534,879,545,99,304,609,673,183,600,463,982,568,424,640,806,840,67,340,262,623,374,339,877,92,624,51,266,298,490,387,457,596,451,329,40,22,778,360,515,398,631,464,695,377,167,666,752,517,384,428,866,719,512,985,944,247,379,951,557,233,66,857,371,273,103,80,363,313,878,291,580,478,106,820,410,677,244,461,584,480,868,324,918,270,372,712,543,930,343,549,720,144,599,154,649,111,146,771,401,328,403,538,251,19,590,425,433,112,288,358,213,356,981,119,691,12,598,468,547,506,497,200,605,21,492,904,210,458,825,733,644,715,0,774,859,11,615,934,196,880,142,730,916,740,412,342,479,895,283,675,963,225,347,668,431,34,303,507,745,141,133,781,830,56,899,834,526,960,149,701,790,204,765,603,18,421,976,77,846,286,984,833,25,749,659,26,238,757,444,968,115,216,498,821,430,30,681,978,380,881,243,17,966,160,646,860,438,159,939,988,876,235,977,500,882,554,711,716,767,378,748,889,714,586,751,919,927,215,922,15,484,811,819,420,496,162,591,892,417,572,690,86,842,201,93,239,501,891,849,302,110,181,523,753,629,486,150,489,185,643,240,627,509,452,794,548,706,174,175,717,214,127,406,269,914,48,232,696,300,209,592,510,692,874,499,212,911,667,295,331,249,884,365,482,669,370,450,72,459,734,947,883,409,933,180,743,809,202,306,520,913,390,505,844,460,5,321,294,858,564,416,847,660,823,942,595,454,54,620,755,443,907,583,282,465,205,330,263,810,708,617,455,407,166,797,800,4,867,796,229,843,388,55,783,393,76,986,832,350,684,305,108,948,404,887,587,764,128,165,541,187,411,474,336,508,400,318,722,279,855,865,694,928,784,153,792,551,157,870,64,685,179,856,504,252,436,923,278,171,394,89,369,126,824,483,522,540,938,442,397,813,107,770,758,6
36,414,518,234,83,974,710,754,277,147,642,297,537,569,612,2,198,125,135,308,544,138,621,633,816,921,585,786,74,953,353,606,426,177,79,555,906,189,419,795,687,476,70,69,852,808,697,338,348,559,731,736,281,558,405,652,841,152,117,648,396,634,726,961,725,71,535,427,491,854,280,525,423,309,853,851,952,186,432,929,987,761,1,875,118,246,672,191,937,264,359,226,861,97,839,871,941,979,581,485,435,574,227,656,35,567,900,561,532,671,68,265,334,597,602,848,817,689,662,439,777,896,82,317,272,178,139,219,814,886,872,969,6,775,728,136,971,488,756,862,293,268,470,32,686,120,905,289,958,897,61,594,565,611,255,299,556,31,915,208,211,727,368,657,762,143,957,197,105,156,349,259,261,759,570,386,445,3,604,129,688,218,195,614,791,28,562,373,653,285,223,36,680,256,850,578,207,836,638,366,763,33,812,983,63,514,502,114,494,345,116,102,469,503,693,217,560,973,335,78,575,134,47,925,713,869,616,188,355,965,9,576,588,750,539,38,333,768,367,130,49,376,228,531,57,864,399,88,53,199,163,949,661,516,772,145,10,351,357,723,908,511,804,582,59,258,888,718,122,37,220,310,193,613,75,536,622,50,362,100 6 | 
160,857,33,983,587,842,379,475,822,456,786,974,64,23,154,584,915,74,275,879,760,426,98,414,963,789,460,680,967,382,206,632,773,768,736,37,18,561,14,141,331,876,599,793,173,729,358,640,155,689,551,111,375,779,90,840,641,25,236,843,805,449,520,582,406,434,697,620,776,106,199,849,177,699,862,965,453,344,28,791,838,740,307,22,859,921,211,152,937,610,32,987,543,309,13,753,467,754,505,540,870,248,169,48,5,867,423,219,536,803,676,898,277,355,519,772,704,903,603,973,305,462,24,167,221,798,393,851,764,56,252,594,897,107,630,125,609,162,204,553,702,46,943,270,116,384,101,176,76,461,102,686,400,944,535,402,809,964,562,480,118,428,841,595,586,357,547,698,976,477,507,82,567,523,70,26,509,192,178,957,417,365,479,713,183,715,149,134,92,875,778,728,554,945,407,722,517,174,484,800,184,142,981,681,928,738,356,147,670,218,685,131,469,596,572,197,369,237,437,951,234,751,66,831,291,688,97,140,972,806,664,294,614,73,749,692,810,900,241,888,401,422,40,500,424,261,550,272,598,464,569,12,153,188,287,717,784,617,314,759,661,442,180,552,143,231,373,579,615,415,694,832,969,868,159,213,372,946,421,457,970,94,671,856,525,910,0,980,105,129,158,765,19,864,71,622,148,52,429,487,796,483,53,207,240,894,959,313,470,376,99,504,513,15,636,408,471,913,820,374,194,274,657,947,85,968,451,508,17,518,909,619,660,889,81,468,835,256,435,168,658,545,448,315,433,172,368,391,38,884,748,672,243,257,691,349,812,542,135,43,527,633,583,117,506,49,288,271,239,251,187,682,911,853,611,381,737,289,495,877,269,31,476,802,854,203,223,3,298,797,226,497,486,416,485,744,549,334,312,8,733,601,899,93,568,58,902,397,190,78,621,198,591,432,444,755,938,783,345,881,354,829,804,124,210,669,266,907,336,488,196,339,585,84,170,555,343,839,845,283,465,144,823,67,175,701,75,273,279,342,301,929,606,146,978,89,771,244,530,491,592,137,844,801,230,42,295,329,335,191,575,29,607,363,695,387,675,303,821,326,123,95,558,880,181,895,883,741,577,774,533,781,597,195,267,395,364,643,906,338,478,528,109,498,756,646,324,925,284,700,347,578,380,161,104,
319,263,574,935,966,817,7,604,956,649,960,281,590,667,59,813,635,302,360,62,745,687,727,318,593,30,886,848,327,858,163,634,286,988,758,410,703,47,878,232,262,300,436,330,890,27,492,182,110,229,361,220,971,165,166,785,72,502,750,538,6,847,214,775,570,493,441,731,905,932,205,819,366,908,132,787,684,642,529,546,260,961,746,922,44,399,466,790,673,952,564,39,311,68,1,761,942,100,9,962,887,645,341,836,412,726,808,625,663,950,83,403,589,445,69,388,130,420,739,539,201,705,119,333,413,494,510,443,668,308,637,580,425,60,752,383,544,912,612,920,629,571,811,157,628,766,418,50,985,103,917,235,618,624,447,707,693,524,891,690,164,882,757,939,254,709,54,936,86,285,639,828,455,654,348,732,45,648,816,977,489,708,934,296,88,21,863,918,320,179,276,651,332,225,576,948,834,36,652,127,371,452,202,893,769,362,721,80,930,588,613,904,51,35,548,394,398,156,563,290,325,830,450,531,65,662,659,511,984,490,264,557,316,340,696,683,873,351,982,370,826,209,10,623,16,678,872,404,20,566,770,454,482,565,679,138,605,865,473,522,986,87,405,222,720,61,794,396,532,512,323,924,34,747,638,446 7 | 
74,948,562,708,849,291,546,141,871,252,32,458,715,877,886,494,797,632,348,427,339,129,714,480,940,579,693,760,787,826,535,199,577,495,523,631,616,360,48,319,695,808,807,392,119,268,515,575,862,980,898,304,189,432,452,0,813,571,681,384,59,834,14,230,517,356,442,86,664,555,270,627,846,682,223,919,422,686,697,408,45,874,855,320,236,184,956,97,111,586,150,937,245,349,716,869,162,833,866,260,67,814,83,550,438,445,674,547,101,108,337,973,482,444,174,903,679,794,249,221,10,130,133,789,931,971,215,958,572,16,675,287,399,278,1,383,484,582,493,487,727,654,778,114,194,831,895,943,24,619,663,455,962,662,656,190,717,424,644,225,220,98,531,146,779,75,896,905,798,208,564,689,413,29,978,363,500,906,543,601,829,128,709,752,354,406,608,309,272,385,509,739,176,592,620,264,892,334,667,22,369,878,156,921,212,271,91,173,415,568,657,439,788,581,205,269,477,848,386,680,983,387,105,322,69,611,917,110,161,153,843,58,232,397,914,451,729,179,511,621,837,393,972,214,610,76,357,87,888,13,434,233,338,604,722,719,894,953,750,617,761,759,793,330,666,365,925,134,165,175,841,882,180,454,700,828,467,710,569,404,669,301,55,277,836,171,596,795,118,685,558,590,542,844,706,402,211,23,131,158,563,603,839,595,450,241,248,374,551,213,781,560,492,859,735,607,557,744,820,88,186,159,315,712,691,375,157,724,416,299,325,720,952,31,40,78,966,891,244,933,335,453,113,436,777,625,653,768,641,670,884,987,811,974,690,254,149,104,870,456,624,534,861,875,676,297,4,231,226,350,50,946,927,93,42,419,965,431,613,502,587,711,379,678,825,806,570,556,407,457,329,845,351,246,358,265,288,418,730,217,967,589,401,139,200,923,207,747,331,803,893,195,228,872,915,885,907,191,378,472,673,466,755,554,99,132,827,321,908,782,889,503,359,545,668,390,955,824,396,364,100,541,197,548,835,106,979,553,361,945,661,699,36,725,934,633,353,166,748,172,447,414,448,683,857,446,897,652,538,463,822,593,216,96,512,963,203,26,251,25,968,591,307,924,333,121,863,326,922,18,784,539,294,302,237,33,411,381,136,77,79,650,247,57,865,56,478,262,809,540,732,328,5
26,959,864,805,34,757,514,84,15,985,605,552,224,975,696,573,258,284,868,181,910,19,949,82,112,486,3,229,840,561,821,988,403,932,313,612,529,107,853,306,879,688,295,430,273,635,395,109,51,851,753,774,800,743,935,636,954,177,471,576,816,227,606,771,957,815,960,39,818,505,734,790,53,142,521,117,332,819,647,124,705,21,852,345,847,762,116,951,810,754,20,838,5,429,137,483,11,292,765,947,198,185,138,887,327,513,459,303,143,145,151,281,832,352,201,671,465,256,342,692,900,257,193,38,388,168,280,377,485,600,791,305,630,481,936,767,300,527,890,776,371,470,638,202,54,68,763,583,911,182,394,646,285,476,27,373,961,746,66,400,72,370,204,733,43,687,764,298,123,609,614,602,267,362,253,405,912,389,928,516,643,854,7,89,261,81,282,786,135,90,519,433,599,17,969,642,850,65,380,651,372,196,731,122,169,567,210,293,618,645,701,37,883,336,148,290,126,209,770,860,243,28,597,982,944,938,276,812,737,92,421,741,308,160,9,506,410,347,578,310,475,580,559,192,163,391,187,802,648,738,368,473,316,881,626,188,62,504,323,47,412,468,296,902,409,773,461,464,63,977,425,52,742,490,367,858,219,2,95 8 | 
81,463,69,396,599,642,233,961,663,634,58,353,551,260,195,82,562,855,485,853,360,269,983,919,175,506,521,923,354,158,22,220,322,540,216,793,770,984,167,876,937,756,177,803,75,278,606,758,616,10,18,947,291,820,886,221,426,138,248,604,817,56,15,597,724,572,790,541,856,683,110,706,507,830,913,255,215,323,105,870,80,52,730,968,838,279,869,28,23,505,656,493,575,669,967,675,247,845,652,173,496,238,129,63,807,559,900,809,126,113,653,508,487,784,733,798,459,135,341,635,948,946,704,909,209,414,539,754,945,543,470,34,395,585,133,866,627,148,773,299,305,592,73,449,181,637,495,465,511,458,342,620,101,964,962,277,646,584,8,533,514,188,952,643,304,740,436,393,159,922,885,681,170,908,536,645,607,595,40,677,140,586,199,718,610,262,251,168,483,288,437,340,499,929,482,970,699,223,298,965,253,608,842,442,960,136,501,702,609,861,558,212,601,544,979,644,873,583,509,473,872,797,750,416,42,235,265,739,791,621,958,332,943,687,696,287,934,723,50,822,560,115,549,379,746,602,674,662,684,329,818,682,630,226,935,819,274,690,612,475,24,185,404,348,127,927,914,2,319,236,343,217,528,432,365,804,9,477,860,44,673,320,301,556,639,62,441,182,959,921,207,227,418,57,390,816,766,261,834,241,728,550,515,448,691,844,907,398,709,795,374,478,377,401,211,453,525,538,858,53,76,0,618,915,234,203,977,41,655,650,920,166,389,92,846,785,258,857,123,975,898,944,232,657,155,840,88,3,89,5,264,794,25,423,27,440,893,678,729,312,350,152,906,524,357,988,708,474,878,918,328,688,814,242,888,871,293,839,460,171,933,430,742,974,204,894,178,444,717,567,547,292,213,936,443,406,716,796,399,598,98,16,579,164,557,828,665,887,891,548,457,214,966,555,615,530,738,953,587,273,950,12,137,658,874,500,245,666,327,769,503,201,128,624,518,249,134,747,59,256,760,31,732,813,355,711,7,239,899,843,779,11,805,315,428,112,897,421,801,194,765,578,14,438,361,517,382,118,106,802,71,789,276,510,311,165,623,455,450,705,985,638,660,452,321,926,202,420,270,823,380,847,941,531,48,862,875,344,827,371,356,447,96,931,890,230,956,153,151,849,763,224,169,381,
883,648,225,762,468,512,573,841,553,825,268,33,108,72,472,912,196,205,491,566,93,545,600,338,972,904,237,903,761,250,176,976,70,780,670,231,628,336,581,667,281,119,103,114,768,313,484,582,316,174,546,700,703,537,535,697,852,302,568,629,346,781,611,565,532,879,513,35,925,792,759,289,280,605,777,564,120,13,333,451,588,60,370,143,593,694,229,189,594,197,467,774,359,308,307,366,415,554,49,727,713,266,240,351,394,955,456,516,84,978,534,303,310,734,497,306,552,111,131,145,751,208,939,362,383,454,737,788,911,916,749,422,337,193,778,851,263,863,376,417,905,74,661,124,519,91,748,435,680,142,413,206,162,132,625,462,940,403,97,372,85,144,901,87,672,651,710,752,433,160,318,712,488,21,571,352,326,121,94,294,577,285,865,764,192,349,184,981,949,679,654,367,895,504,190,309,743,385,464,429,725,125,30,179,808,896,4,942,325,210,590,388,721,971,954,529,387,183,90,358,345,339,461,867,286,1,419,282,741,526,498,55,397,707,720,425,330,410,480,889,479,290,102,38,902,527,295,222,37,753,987,668,68,880,494,431,257,20,130,837,806,32,39,647,910,407,51,469,868,917,636,671,686,829,79,157,800 9 | 
469,674,116,39,947,425,727,331,43,251,638,371,801,294,509,320,138,308,924,24,559,386,438,818,220,595,716,366,743,82,870,779,602,53,395,409,362,126,459,420,182,643,977,788,945,223,715,731,789,238,310,639,571,728,608,553,492,526,965,254,504,811,632,983,97,207,441,289,777,475,673,577,903,222,458,986,180,477,445,768,723,343,396,821,698,981,299,839,31,244,44,228,322,806,515,19,480,186,95,491,408,670,470,468,130,961,898,121,374,971,17,265,41,71,13,970,432,558,877,350,820,55,123,358,78,869,618,415,460,25,978,172,908,496,443,861,832,233,874,304,710,90,599,809,108,259,554,490,617,953,512,3,814,111,426,858,671,176,695,659,481,954,372,882,860,557,143,878,694,96,645,206,835,500,739,107,54,534,146,758,446,263,704,419,636,161,59,689,514,655,137,174,321,455,279,843,163,353,584,212,188,744,120,921,26,802,170,281,938,339,278,804,862,696,736,757,378,540,790,916,680,590,681,439,958,91,783,377,683,407,912,421,883,131,236,248,828,887,541,264,405,561,774,401,902,253,885,218,531,837,360,45,796,934,389,313,452,735,155,631,382,660,732,604,77,20,62,247,677,962,387,517,890,119,628,622,964,1,431,411,297,314,428,753,849,725,654,190,516,691,498,734,585,316,83,150,104,778,623,687,209,9,846,822,442,888,662,946,197,943,502,824,63,464,816,38,551,418,375,564,359,892,246,397,863,367,524,227,926,164,747,871,139,647,162,429,950,782,301,416,948,598,454,901,109,184,668,791,920,194,722,213,463,770,102,955,730,591,471,581,573,918,345,319,185,549,503,967,276,479,740,478,607,771,417,341,537,335,560,153,692,76,765,171,641,853,738,351,848,625,11,140,177,982,444,973,635,675,857,73,787,157,65,567,193,376,923,257,851,317,904,255,203,168,797,762,202,936,610,745,913,933,942,447,940,33,277,252,147,93,250,899,886,241,620,204,287,895,693,751,493,815,100,697,169,726,70,838,312,672,414,50,290,914,149,667,388,352,156,527,344,840,749,507,328,347,99,985,798,467,23,980,592,794,545,240,721,949,600,400,489,81,434,593,714,327,506,80,292,909,98,685,226,688,30,235,746,298,113,956,379,807,385,927,402,293,729,394,601,61,754,229,282
,720,413,900,167,836,7,92,181,844,699,759,543,719,208,72,817,300,897,412,565,175,485,960,189,597,566,482,929,589,52,318,232,173,94,544,772,110,369,963,708,854,456,915,881,261,260,665,748,4,165,859,280,115,855,303,286,87,550,497,214,268,187,291,780,513,364,984,216,160,64,231,718,548,158,579,271,827,326,529,270,525,803,547,333,530,775,284,501,686,219,893,724,399,795,16,125,799,285,519,427,969,702,713,69,664,931,354,391,195,649,392,35,596,532,613,676,198,269,334,894,315,975,27,633,242,42,709,2,472,384,283,968,825,142,129,373,644,75,510,141,910,905,461,133,127,684,381,957,166,626,18,555,74,880,538,616,296,761,486,658,678,101,826,505,183,764,128,741,867,523,959,66,309,865,576,148,847,234,487,355,363,122,773,800,196,440,422,737,630,266,423,701,435,717,612,215,812,410,8,792,348,14,642,499,634,679,767,756,988,79,624,330,769,60,225,562,136,522,682,766,224,834,86,563,872,151,449,706,932,324,619,305,191,841,810,569,781,808,58,450,614,935,36,528,84,521,845,703,258,144,124,152,535,230,850,135,56,520,785,0,22,833,917,357,288,760,987,707,937,29,572,925,952,249,178,307,575,951 10 | 
640,852,346,824,641,706,949,626,427,628,464,382,727,633,899,394,972,289,256,896,12,259,582,37,364,40,644,473,436,609,657,297,363,672,512,663,217,909,272,594,546,864,48,604,719,670,608,261,518,558,126,776,914,498,975,122,352,568,635,189,861,674,292,788,257,723,491,426,337,653,869,203,850,985,343,446,374,18,419,50,522,974,278,511,105,270,651,348,982,186,336,606,514,214,402,787,82,237,58,944,442,555,895,492,424,398,676,78,638,248,879,703,384,935,541,437,59,878,441,45,744,192,267,686,269,349,619,675,856,486,508,525,958,882,799,598,833,49,304,804,213,829,652,654,456,831,593,9,183,300,505,515,766,6,756,334,509,553,19,347,516,617,159,761,649,521,970,759,451,428,418,129,73,131,801,650,340,948,714,988,94,634,563,210,877,218,368,377,871,355,168,953,378,870,938,52,239,692,973,230,250,785,620,557,420,92,924,808,246,30,104,147,187,317,61,922,227,886,134,409,361,43,846,38,897,826,120,234,536,479,396,769,87,208,784,770,627,718,962,739,550,403,669,123,254,255,765,817,666,5,602,645,497,198,322,138,264,931,410,190,556,435,757,830,329,493,868,222,53,881,499,84,690,890,572,258,392,534,573,894,616,283,211,103,946,490,658,88,390,324,885,178,417,549,697,678,717,698,445,439,927,393,89,713,249,754,715,121,950,333,263,825,22,90,835,746,152,513,127,146,580,911,354,142,41,395,613,474,209,139,331,212,109,597,276,618,86,711,903,893,320,312,976,411,677,811,399,141,495,600,729,980,67,794,483,665,734,284,194,113,987,170,524,721,440,977,796,919,726,575,934,888,389,155,101,942,372,821,818,704,561,455,422,873,244,967,453,69,501,332,326,768,226,725,642,163,971,408,306,966,458,118,920,742,174,8,963,696,380,404,574,339,367,793,268,42,466,74,904,28,260,167,449,430,350,822,434,504,743,859,517,85,447,64,947,15,232,485,460,344,376,36,538,369,636,571,21,319,154,898,802,117,583,391,200,241,951,901,150,577,251,388,185,307,646,172,177,79,179,1,7,97,47,841,682,752,779,96,119,603,551,161,11,202,660,216,803,584,986,313,448,365,969,559,716,848,204,519,591,221,247,63,523,469,578,836,135,566,431,529,596,0,872,820,849,
510,892,629,918,70,912,252,733,310,905,816,542,266,184,468,908,271,844,265,955,814,643,173,196,535,133,537,280,231,783,281,797,810,308,964,39,489,862,775,668,891,253,930,55,342,360,287,916,929,813,46,83,108,507,695,345,13,707,699,34,373,476,587,805,106,60,661,75,664,303,279,673,330,745,941,467,933,749,488,615,637,112,301,481,370,205,328,14,26,356,945,162,148,502,755,910,548,567,701,812,625,981,66,188,62,375,165,274,362,655,23,4,900,387,197,961,764,630,457,589,500,791,20,978,965,913,867,171,240,624,880,353,760,710,477,771,875,778,158,876,543,273,153,379,724,35,381,233,671,80,700,181,2,762,68,503,647,687,544,323,786,774,595,979,413,795,827,601,800,487,532,16,243,569,371,612,984,17,25,932,201,180,149,865,484,807,621,423,242,586,429,296,65,438,100,309,275,720,842,552,843,747,863,357,140,359,683,866,819,527,338,191,884,954,960,857,902,923,854,494,136,24,291,684,688,858,789,631,614,454,401,708,56,93,750,199,110,592,482,736,219,837,3,738,581,740,305,459,98,506,705,325,685,32,983,758,299,425,952,496,777,245,385,834,207,412,526,175,166,907,302,545,223,314,662,562,712 11 | -------------------------------------------------------------------------------- /data/abide_schaefer100/val.index: -------------------------------------------------------------------------------- 1 | 276,429,82,568,705,200,391,105,95,282,668,987,215,831,764,416,397,66,29,893,689,176,672,541,72,858,249,719,170,63,640,784,833,139,38,974,551,159,442,178,742,896,864,698,92,183,349,789,298,3,365,520,135,402,710,295,562,174,612,886,851,852,774,86,740,320,588,67,511,493,347,728,880,259,440,532,648,921,542,901,942,823,624,212,461,405,406,228,707,65,84,911,515,948,979,404,952,804,412 2 | 
572,433,32,786,236,577,770,654,805,60,558,578,566,746,958,640,680,249,653,360,297,482,675,95,82,102,166,716,377,980,484,85,317,763,160,327,783,764,204,438,451,822,938,128,80,791,254,227,51,475,888,697,361,308,840,775,741,779,413,981,474,820,407,277,956,405,237,427,345,960,28,864,234,968,810,353,641,334,782,588,116,496,534,77,510,599,894,446,161,225,579,873,45,10,22,552,147,44,923 3 | 458,933,297,973,343,173,877,130,853,601,160,491,918,58,900,730,884,676,893,873,549,615,726,489,378,584,762,545,519,723,851,826,147,679,499,281,363,503,112,988,337,201,100,332,829,46,977,60,571,27,311,724,124,774,507,758,577,370,1,434,316,249,445,538,635,244,274,169,722,916,258,271,459,230,531,296,823,416,592,511,421,476,6,623,150,443,791,651,463,467,542,231,804,576,639,198,581,766,931 4 | 661,579,109,184,298,740,113,734,872,116,61,950,518,724,392,407,663,159,95,767,21,244,227,467,57,60,397,768,387,624,501,492,698,208,593,274,471,370,485,547,556,299,175,223,602,354,187,537,412,607,100,619,271,750,668,521,931,83,890,919,262,927,714,557,591,172,654,755,820,662,473,580,304,636,795,836,952,286,403,531,775,204,653,968,825,344,712,305,328,774,449,285,334,855,578,782,58,43,316 5 | 954,311,902,519,610,242,168,418,60,626,190,161,316,932,910,296,665,943,641,766,113,475,375,529,448,826,326,709,385,314,408,325,364,704,635,395,940,148,8,254,192,73,967,312,52,182,664,290,7,96,769,735,699,890,29,27,245,237,449,724,838,437,415,827,155,975,361,676,803,909,413,109,462,931,260,471,422,920,721,801,732,434,45,571,90,776,780,655,276,566,798,837,901,737,747,917,337,121,619 6 | 503,122,955,573,644,919,653,317,389,718,631,297,931,224,91,537,896,392,253,650,714,200,926,941,762,386,824,438,114,145,782,792,55,41,2,151,827,656,458,126,666,892,186,874,427,350,247,626,818,282,439,600,501,885,150,359,238,337,472,278,409,352,57,730,255,250,430,871,193,723,742,496,767,958,788,526,258,724,815,855,292,293,265,837,440,799,807,299,411,115,556,953,306,521,228,249,866,665,734 7 | 
525,44,341,530,640,524,170,520,64,491,655,275,622,242,474,549,508,823,346,235,85,94,103,916,785,565,976,941,726,120,440,939,522,234,623,437,799,856,694,629,792,536,239,489,899,615,49,435,496,6,449,723,462,286,218,317,61,343,340,279,639,901,637,876,366,259,152,659,164,909,532,499,533,144,507,344,206,649,703,749,707,677,698,873,498,736,756,46,842,594,780,918,318,751,702,528,772,70,986 8 | 580,719,833,722,848,783,659,284,411,283,831,324,632,776,631,218,757,726,775,386,782,520,563,45,767,314,364,617,576,424,570,272,799,693,246,154,502,574,701,46,626,141,83,824,200,99,252,434,259,139,19,378,698,405,107,122,622,402,116,744,95,928,486,147,409,692,982,369,735,815,271,67,347,26,850,542,466,745,969,373,54,930,882,786,392,335,492,61,755,664,481,243,267,439,161,77,163,714,772 9 | 117,640,941,272,154,85,368,875,891,911,179,830,457,325,793,705,588,275,112,574,629,700,979,763,32,873,302,398,12,132,51,245,274,805,711,866,556,114,536,436,930,323,40,57,37,868,889,383,267,637,34,712,884,831,433,508,666,474,864,533,311,201,210,627,603,403,690,346,390,217,570,465,466,88,6,370,134,295,587,542,856,192,518,47,338,819,473,488,462,552,829,852,944,15,476,651,648,650,876 10 | 928,156,10,940,889,151,860,193,925,956,29,639,763,585,957,735,462,838,689,169,71,823,554,681,229,579,753,828,478,386,472,943,215,206,694,463,406,321,780,107,228,832,77,656,648,443,358,883,224,530,102,405,471,798,560,327,164,607,72,351,366,99,906,937,741,809,611,262,728,125,547,311,731,939,853,693,144,111,235,81,588,540,461,236,887,444,400,605,124,539,926,115,33,145,128,95,238,176,533 11 | -------------------------------------------------------------------------------- /data/adni_schaefer100/test.index: -------------------------------------------------------------------------------- 1 | 
4,26,53,57,71,76,77,79,110,120,126,131,137,146,151,162,183,188,199,204,227,232,239,248,249,258,269,285,287,300,301,314,317,324,329,359,382,383,385,400,420,427,431,447,452,469,471,472,473,474,475,484,498,512,518,544,550,557,565,566,571,584,594,600,604,606,609,611,612,634,638,651,656,688,695,696,708,713,714,732,740,756,826,834,841,848,854,858,878,896,897,909,912,919,923,926,937,953,968,980,992,1015,1023,1029,1039,1043,1045,1052,1054,1059,1071,1096,1107,1125,1138,1147,1161,1175,1178,1208,1210,1213,1231,1234,1239,1241,1262,1268,1282,1291,1299,1307,1319 2 | 6,10,16,17,25,27,42,49,69,80,89,101,122,138,144,157,169,174,180,210,212,242,253,266,280,288,306,318,322,332,339,341,350,356,361,365,368,374,377,384,387,392,404,407,411,426,428,448,451,457,504,553,568,569,572,574,579,580,610,622,632,642,645,647,652,660,673,676,685,718,738,739,744,749,754,762,764,765,769,771,778,809,819,836,840,846,849,865,873,884,893,934,950,963,979,990,991,995,997,1003,1006,1022,1027,1037,1055,1064,1075,1082,1088,1094,1136,1143,1148,1165,1166,1176,1182,1184,1188,1193,1211,1219,1233,1240,1243,1245,1264,1274,1285,1297,1298,1310,1323 3 | 0,28,35,36,59,62,78,88,104,113,116,141,164,178,179,181,185,187,190,202,203,208,213,226,230,236,246,251,259,264,277,284,298,299,309,321,323,334,347,357,380,395,417,435,442,446,459,465,483,485,486,495,509,516,517,520,523,527,533,540,543,552,561,613,627,633,641,654,659,667,686,687,690,693,701,717,720,723,731,736,774,813,828,832,837,850,853,860,879,890,891,942,948,949,951,952,959,966,972,976,986,1010,1016,1017,1030,1034,1056,1062,1063,1109,1126,1137,1140,1150,1160,1167,1174,1179,1185,1209,1217,1224,1230,1237,1242,1247,1257,1266,1277,1279,1313,1314,1324 4 | 
3,8,14,33,34,37,41,46,61,66,68,86,90,96,98,112,121,133,139,145,153,159,161,166,184,191,193,195,198,207,215,216,220,223,235,238,244,255,261,267,278,282,305,307,327,333,344,360,379,390,402,405,415,423,424,430,436,458,481,482,493,555,567,575,583,589,601,602,625,628,630,631,648,658,664,668,674,712,724,761,767,783,820,823,842,844,868,876,888,898,902,915,925,928,935,956,969,977,982,983,994,1019,1020,1042,1051,1068,1072,1098,1101,1108,1115,1123,1129,1132,1135,1139,1144,1159,1170,1198,1199,1203,1212,1216,1238,1253,1260,1270,1278,1290,1306,1315,1318 5 | 11,22,23,39,45,65,75,83,115,117,123,134,148,152,155,171,172,194,219,225,229,237,247,270,272,276,295,302,330,342,343,355,363,370,372,386,388,419,421,440,450,460,462,467,476,494,503,513,532,536,542,548,549,576,590,607,614,615,620,626,629,637,666,675,678,684,689,694,698,706,715,722,729,782,784,785,788,789,793,801,804,807,825,830,843,847,852,882,892,901,903,913,924,929,933,938,941,962,988,998,1008,1026,1036,1069,1070,1078,1081,1083,1089,1099,1104,1105,1110,1112,1118,1119,1122,1158,1180,1189,1200,1214,1228,1229,1251,1254,1281,1283,1288,1292,1293,1300,1303 6 | 5,7,9,15,30,43,63,74,84,93,97,103,130,140,143,147,168,186,197,200,201,206,209,231,260,262,263,268,271,286,291,296,304,313,315,345,346,348,366,394,397,418,422,425,429,453,456,463,470,477,478,492,496,497,524,526,528,531,539,546,556,570,593,603,617,639,644,646,655,682,716,725,735,743,753,755,770,787,792,794,795,817,833,838,851,857,866,870,877,899,907,910,920,932,943,964,970,981,985,989,999,1014,1021,1024,1028,1048,1050,1058,1061,1079,1086,1087,1093,1111,1128,1151,1162,1169,1187,1191,1194,1221,1225,1226,1250,1259,1269,1272,1276,1312,1317,1320,1325 7 | 
18,44,55,64,72,82,85,91,109,125,129,132,135,142,211,218,224,228,279,303,319,325,353,358,362,364,367,369,371,391,396,408,412,414,416,432,437,438,441,444,449,479,488,489,499,502,508,525,530,537,554,563,564,591,592,595,598,599,618,640,661,662,671,679,705,719,727,728,733,734,746,747,758,760,763,766,772,777,791,802,810,811,829,856,862,863,867,886,889,894,905,917,918,930,954,973,978,987,1001,1002,1004,1011,1033,1040,1041,1049,1067,1092,1095,1114,1117,1120,1121,1131,1146,1168,1172,1181,1186,1190,1196,1197,1218,1223,1235,1263,1265,1271,1287,1295,1308,1322 8 | 1,20,21,24,50,54,56,95,102,106,114,124,127,165,167,170,175,177,182,192,252,274,297,308,310,316,331,338,351,375,381,389,393,403,409,413,433,434,468,487,490,491,500,501,506,510,514,521,538,547,560,562,573,578,581,582,596,597,621,650,653,657,677,691,692,699,700,702,707,710,721,730,737,741,751,773,775,798,803,805,814,816,821,824,827,839,861,871,887,895,900,916,922,945,946,955,960,961,965,1005,1009,1038,1046,1047,1053,1057,1066,1076,1080,1085,1103,1106,1130,1134,1149,1153,1164,1171,1173,1205,1220,1227,1232,1236,1244,1256,1258,1273,1289,1294,1304,1321 9 | 2,12,13,19,32,38,47,48,58,60,67,70,87,92,94,99,107,108,111,118,119,128,156,160,163,176,189,196,205,221,222,250,254,256,265,273,275,283,289,290,311,312,320,326,336,337,352,373,399,406,439,443,445,454,455,480,511,535,541,551,585,587,616,619,623,635,636,669,672,681,683,697,711,742,748,750,752,799,800,806,812,818,822,835,845,869,872,880,881,904,908,911,914,939,940,957,958,974,975,984,993,1012,1013,1018,1031,1032,1060,1077,1084,1090,1091,1097,1100,1102,1133,1142,1156,1163,1177,1195,1202,1206,1246,1248,1249,1261,1267,1280,1286,1301,1302,1309 10 | 
29,31,40,51,52,73,81,100,105,136,149,150,154,158,173,214,217,233,234,240,241,243,245,257,281,292,293,294,328,335,340,349,354,376,378,398,401,410,461,464,466,505,507,515,519,522,529,534,545,558,559,577,586,588,605,608,624,643,649,663,665,670,680,703,704,709,726,745,757,759,768,776,779,780,781,786,790,796,797,808,815,831,855,859,864,874,875,883,885,906,921,927,931,936,944,947,967,971,996,1000,1007,1025,1035,1044,1065,1073,1074,1113,1116,1124,1127,1141,1145,1152,1154,1155,1157,1183,1192,1201,1204,1207,1215,1222,1252,1255,1275,1284,1296,1305,1311,1316 11 | -------------------------------------------------------------------------------- /data/adni_schaefer100/val.index: -------------------------------------------------------------------------------- 1 | 185,363,1184,96,41,560,1136,1038,450,798,831,526,783,358,1004,1060,1,458,1152,462,1130,334,1308,857,459,1137,435,642,985,1271,265,50,231,1085,357,588,1242,321,936,620,37,576,545,1114,791,628,1191,647,415,351,1112,422,1237,832,187,952,68,1013,1325,531,977,661,337,561,907,994,795,8,1276,81,729,75,820,284,553,1215,1324,1000,1105,786,944,139,707,416,448,666,801,625,1264,203,984,42,215,1221,36,1093,456,543,1082,404,372,333,864,156,847,1088,563,1098,343,461,5,965,122,46,504,541,636,1127,633,1240,583,860,736,895,1193,617,1116,485,654,1272,686,806,174 2 | 540,198,931,164,1284,625,443,410,607,1068,1278,554,51,1035,419,55,700,629,942,1212,880,501,145,902,65,251,525,1072,552,1047,602,1092,1098,706,0,261,968,265,760,192,889,551,282,628,120,1280,658,1286,479,18,305,476,168,166,244,870,281,338,608,816,290,160,94,630,1271,1218,1135,581,947,1031,1058,1206,336,598,998,959,740,231,1106,248,621,90,279,1063,536,441,264,444,992,564,1107,867,1239,1229,903,842,510,634,982,814,317,31,620,662,821,1260,1170,667,1121,1043,1249,1119,314,1008,1127,292,333,465,651,1149,1207,1317,163,961,233,717,692,825,773,289,548,458,28 3 | 
1043,348,814,1301,549,47,682,911,849,759,283,455,646,276,461,383,1275,428,441,665,643,1308,326,1054,1013,1082,797,794,647,129,630,715,163,146,935,940,924,1012,513,1221,770,303,109,73,15,765,680,306,975,482,1295,883,535,1239,234,55,1212,24,507,733,1072,1248,992,225,266,590,1261,1125,856,1229,338,93,669,882,183,855,399,824,1299,274,473,664,871,695,186,521,337,1088,670,977,367,1166,1049,261,964,537,1068,111,571,1041,1057,608,118,27,907,124,210,25,466,340,1270,970,1120,679,1178,278,583,877,1218,1046,222,985,384,494,544,1243,43,1134,13,229,1118,800,677 4 | 446,380,62,326,1083,75,697,167,957,590,409,1166,1100,1150,1057,371,824,22,26,515,684,689,670,1094,5,952,938,1210,736,92,502,510,200,108,862,1130,678,407,1114,394,1237,1119,538,1005,1208,199,604,1136,468,475,136,277,342,651,421,368,30,1227,1294,187,453,351,989,751,709,1084,770,358,899,420,831,1126,944,961,406,1316,1149,288,470,147,694,789,160,1053,999,1034,826,971,1273,395,1324,556,489,1026,1200,1321,1250,896,335,447,905,263,457,189,598,854,181,385,1096,138,960,126,833,611,70,1295,486,1177,25,532,1224,239,706,334,325,699,354,1312,225,228,376,640,299 5 | 770,133,750,797,586,1175,257,1272,926,671,879,1182,859,1117,702,1259,74,915,486,781,602,708,582,136,710,682,1048,808,899,1252,46,977,364,1203,1168,1302,728,1017,530,777,246,374,744,93,1082,1240,528,358,1265,454,260,1262,507,1004,352,13,1244,53,481,173,1035,332,1237,105,153,904,524,1051,1155,673,1195,220,971,851,180,212,127,1170,565,968,499,241,506,131,531,221,1223,63,888,97,1095,983,1183,465,119,831,1023,251,707,283,73,672,946,721,141,597,12,949,410,817,588,187,990,886,900,1113,457,27,128,1021,3,264,116,1116,1315,742,1319,62,664,1030,57,662,56 6 | 
228,846,1081,70,991,11,941,923,122,577,252,715,1042,265,285,1182,821,255,278,790,543,574,129,997,12,509,373,1155,740,342,1136,1122,335,327,559,1138,750,10,534,1026,1036,1181,1017,1248,1288,611,1121,884,764,280,956,774,840,1291,141,827,202,1106,924,371,1313,823,731,1205,612,876,1154,495,1260,2,1173,40,1301,550,1323,1120,219,938,266,843,155,214,338,117,159,945,454,488,1279,942,590,1045,29,1219,215,46,785,239,614,804,59,832,3,665,685,739,1148,230,580,728,1204,979,216,390,586,573,479,419,1241,353,328,729,473,678,23,672,620,1193,466,754,791,382,1238 7 | 446,633,435,118,25,1035,491,106,381,1103,994,1006,801,0,819,605,697,260,239,451,79,576,1032,575,1221,402,1231,1149,861,604,1109,670,372,909,799,518,515,343,596,677,1156,504,220,276,538,102,1253,1185,1277,1258,956,207,189,1207,832,225,468,904,1136,480,1303,631,921,219,1089,1116,624,731,1007,817,1272,695,1026,194,482,1198,828,821,666,988,494,985,465,300,664,1255,929,88,543,1125,481,327,307,943,534,178,1314,609,242,675,1106,654,648,1183,725,213,837,556,375,510,286,637,339,74,316,131,471,409,348,1208,442,1020,680,1113,186,868,1279,873,1229,1060,521,158,1052 8 | 416,341,1008,877,520,494,583,940,1313,1225,605,367,19,212,1167,387,1012,720,1229,92,543,198,845,1126,678,145,12,322,399,213,1190,84,1050,48,1072,306,1271,228,931,286,1209,993,941,493,754,701,412,1197,153,964,140,355,1265,227,155,899,117,574,185,624,422,862,134,687,1147,1306,298,195,304,275,143,1006,1048,637,1092,995,697,156,1122,142,144,1044,216,832,1094,666,255,1262,287,302,1013,1117,439,673,558,266,540,1084,857,806,796,944,400,786,1299,764,97,1160,425,458,248,1204,336,11,865,268,474,766,1185,735,896,319,130,31,566,1098,34,1248,1259,180,957,835,1131 9 | 
899,319,393,371,991,408,733,1178,211,931,478,949,258,972,1157,934,181,360,388,218,1075,441,567,282,41,1251,1303,1167,600,106,304,558,51,479,379,306,988,1104,1135,18,774,281,322,871,503,701,267,970,699,834,540,836,1095,745,35,675,449,246,284,1068,487,76,342,398,85,1052,594,694,1203,1009,559,93,537,639,1034,1270,368,796,666,1228,1264,566,1217,1273,739,179,315,102,1001,831,1193,700,1155,542,55,16,707,561,802,476,435,877,719,1002,117,650,1132,162,1057,885,795,169,1105,768,1256,358,1219,1214,1071,409,1041,1292,573,1291,848,240,64,852,1058,301,7,765,1146 10 | 35,675,694,654,77,1126,579,103,20,508,363,470,222,71,1224,274,1120,517,1133,205,925,923,1178,1199,299,890,765,610,1175,296,364,1208,1235,1076,411,344,89,1230,611,1078,999,1280,1245,365,1031,677,1214,79,969,676,752,562,53,734,1085,485,289,408,600,499,817,727,840,196,812,805,380,1259,445,824,1179,323,1069,564,762,1098,412,63,854,625,527,963,1301,748,879,770,117,232,1292,754,829,871,1297,531,207,1068,437,1312,760,295,182,556,1020,422,761,617,528,1016,9,1105,355,57,740,932,458,246,902,1011,778,989,688,995,1299,432,104,216,914,359,208,1114,839,928,638 11 | -------------------------------------------------------------------------------- /data/data.py: -------------------------------------------------------------------------------- 1 | """ 2 | File to load dataset based on user control from main file 3 | """ 4 | from data.BrainNet import BrainDataset 5 | 6 | 7 | def LoadData(DATASET_NAME, threshold=0, edge_ratio=0, node_feat_transform='original'): 8 | """ 9 | This function is called in the main.py file 10 | returns: 11 | ; dataset object 12 | """ 13 | 14 | return BrainDataset(DATASET_NAME, threshold=threshold, edge_ratio=edge_ratio, node_feat_transform=node_feat_transform) 15 | -------------------------------------------------------------------------------- /data/generate_data_from_mat.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import pandas as pd 3 
| import numpy as np 4 | import networkx as nx 5 | import os # To create directories 6 | import shutil 7 | import scipy.io 8 | import dgl 9 | import torch 10 | import glob 11 | import csv 12 | import re 13 | import json 14 | from tqdm import tqdm 15 | from dgl.data.utils import save_graphs 16 | from sklearn.model_selection import StratifiedKFold, train_test_split 17 | 18 | 19 | def _load_matrix_subject_with_files(files, remove_negative=False): 20 | subjects = [] 21 | for file in files: 22 | mat = scipy.io.loadmat(file) 23 | mat = mat["data"] 24 | np.fill_diagonal(mat, 0) 25 | if remove_negative: 26 | mat[mat < 0] = 0 27 | subjects.append(mat) 28 | return np.array(subjects) 29 | 30 | def construct_dataset(data_name): 31 | feat_dir = 'data/to/connectivity_matrices_schaefer/' + data_name + '/' 32 | 33 | G_dataset = [] 34 | Labels = [] 35 | group2idx = {} 36 | paths = glob.glob(feat_dir + '/*/' + '*_features_timeseries.mat', recursive=True) 37 | feats = _load_matrix_subject_with_files(paths) 38 | 39 | print('Processing ' + data_name + '...') 40 | 41 | for j in tqdm(range(len(feats))): 42 | name = paths[j].split('/')[-1] 43 | group = re.findall('sub-([^\d]+)', name)[0] 44 | if group not in group2idx.keys(): 45 | group2idx[group] = len(group2idx.keys()) 46 | i = group2idx[group] 47 | 48 | G = nx.DiGraph(np.ones([feats[j].shape[0], feats[j].shape[0]])) 49 | graph_dgl = dgl.from_networkx(G) 50 | 51 | graph_dgl.ndata['N_features'] = torch.from_numpy(feats[j]) 52 | # Include edge features 53 | weights = [] 54 | for u, v, w in G.edges.data('weight'): 55 | # if w is not None: 56 | weights.append(w) 57 | graph_dgl.edata['E_features'] = torch.Tensor(weights) 58 | 59 | G_dataset.append(graph_dgl) 60 | Labels.append(i) 61 | 62 | print('Finish process ' + data_name + '. 
' + str(len(feats)) + ' subjects in total.') 63 | 64 | Labels = torch.LongTensor(Labels) 65 | graph_labels = {"glabel": Labels} 66 | if not os.path.exists('./bin_dataset/'): 67 | os.mkdir('./bin_dataset/') 68 | print(Labels.shape) 69 | print(len(G_dataset)) 70 | save_graphs("./bin_dataset/" + data_name + ".bin", G_dataset, graph_labels) 71 | 72 | 73 | def move_files(data_name): 74 | feat_dir = '/data/jiaxing/brain/connectivity_matrices_schaefer/' + data_name + '/' 75 | paths = glob.glob(feat_dir + '/*/*', recursive=True) 76 | for path in paths: 77 | if path[-4:] == '.mat': 78 | if 'schashaefer' in path: 79 | new_path = re.sub('schashaefer', 'schaefer', path) 80 | os.rename(path, new_path) 81 | continue 82 | else: 83 | parcellation = data_name.split('_')[-1] 84 | os.rename(path, path + '_' + parcellation + '_correlation_matrix.mat') 85 | 86 | 87 | if __name__ == '__main__': 88 | error_name = [] 89 | # file_name_list = os.listdir('./correlation_datasets/') 90 | file_name_list = ['adni_schaefer100'] 91 | 92 | for data_name in file_name_list: 93 | move_files(data_name) 94 | # construct_dataset(data_name) 95 | # try: 96 | # construct_dataset(data_name) 97 | # except: 98 | # print('[ERROR]: ' + data_name) 99 | # error_name.append(data_name) 100 | print(error_name) 101 | print('Done!') 102 | -------------------------------------------------------------------------------- /data/neurocon_schaefer100/test.index: -------------------------------------------------------------------------------- 1 | 10,12,15,22,29 2 | 6,9,20,30 3 | 3,14,36,37 4 | 0,1,27,39 5 | 7,11,16,25 6 | 4,31,33,34 7 | 8,32,35,38 8 | 13,17,19,21 9 | 5,23,26,28 10 | 2,18,24,40 11 | -------------------------------------------------------------------------------- /data/neurocon_schaefer100/train.index: -------------------------------------------------------------------------------- 1 | 0,31,23,28,30,26,2,9,40,18,20,19,35,33,27,25,36,4,38,7,24,17,8,21,13,5,3,1,6,39,16,14 2 | 
35,21,10,2,14,4,3,39,23,29,11,28,26,13,38,27,15,18,32,12,16,5,1,24,34,36,33,31,25,8,40,37 3 | 2,8,12,20,11,27,0,32,29,33,31,7,4,18,24,23,40,28,17,5,1,19,30,9,35,38,10,25,22,21,26,34 4 | 20,15,19,26,34,10,16,11,6,23,38,37,7,35,9,18,5,30,32,31,24,8,13,33,29,28,12,40,14,25,36,2 5 | 13,26,2,30,40,21,19,38,32,31,35,14,17,24,15,36,29,1,28,33,34,0,3,6,27,5,37,12,23,39,9,10 6 | 26,18,1,3,10,25,32,36,9,20,28,17,14,27,19,15,22,39,24,8,13,29,23,30,5,6,37,7,35,12,11,40 7 | 4,13,3,23,12,15,10,30,17,19,27,0,11,20,29,22,31,34,9,14,18,24,36,16,21,33,7,26,37,40,6,5 8 | 22,25,30,29,3,11,16,5,28,14,34,36,2,0,32,39,7,8,15,4,10,38,18,1,37,24,26,27,9,35,40,23 9 | 2,10,0,32,15,11,36,27,17,1,13,21,37,20,34,39,38,4,7,22,16,9,3,25,18,8,33,12,40,30,35,19 10 | 38,12,16,8,14,32,39,29,31,10,19,1,20,36,22,4,6,26,13,30,37,34,25,28,0,11,9,23,3,17,33,15 11 | -------------------------------------------------------------------------------- /data/neurocon_schaefer100/val.index: -------------------------------------------------------------------------------- 1 | 34,11,32,37 2 | 19,0,17,22,7 3 | 39,15,13,16,6 4 | 4,22,17,21,3 5 | 18,20,8,4,22 6 | 0,2,38,21,16 7 | 2,25,39,28,1 8 | 12,31,6,20,33 9 | 14,29,31,6,24 10 | 7,5,21,35,27 11 | -------------------------------------------------------------------------------- /data/ppmi_schaefer100/test.index: -------------------------------------------------------------------------------- 1 | 0,14,22,30,37,45,53,63,68,76,90,93,112,128,135,155,157,164,171,181,199 2 | 3,5,41,44,55,65,66,77,81,84,94,105,126,134,139,141,148,159,160,169,205 3 | 8,12,34,35,47,48,58,78,80,98,100,101,113,132,133,136,144,149,173,182,206 4 | 4,6,31,52,56,59,60,71,72,82,119,120,123,130,165,168,172,176,180,193,202 5 | 11,13,20,29,32,46,57,83,106,111,115,118,124,129,142,150,162,166,177,191,203 6 | 10,17,25,38,39,43,74,79,85,88,95,96,102,151,154,156,158,178,185,200,204 7 | 1,15,19,33,61,67,73,104,107,110,117,121,122,138,140,145,146,153,188,201,207 8 | 
9,18,27,36,50,54,64,70,75,86,89,108,125,137,152,161,167,170,183,195,196 9 | 7,21,24,26,40,42,49,92,97,103,114,116,131,147,184,186,187,190,194,198,208 10 | 2,16,23,28,51,62,69,87,91,99,109,127,143,163,174,175,179,189,192,197 11 | -------------------------------------------------------------------------------- /data/ppmi_schaefer100/train.index: -------------------------------------------------------------------------------- 1 | 31,15,169,204,167,24,99,129,180,100,29,203,133,40,60,56,61,80,118,173,33,36,191,150,18,144,102,186,91,2,183,127,175,122,146,103,27,32,192,156,57,177,123,69,116,74,47,41,158,83,137,88,109,111,16,11,77,107,8,184,48,59,198,163,172,208,179,73,108,131,194,95,121,105,120,21,201,72,1,25,161,9,79,139,114,3,170,26,205,75,71,98,126,35,67,17,162,132,94,34,206,197,174,124,38,202,141,87,110,52,188,154,125,130,51,113,7,168,166,148,42,86,196,65,28,142,49,66,182,193,44,143,10,84,187,96,153,12,101,152,62,13,159,70,19,134,178,81,50,106,20,97,89,85,6,145,5,78,136,119,165,185,190,207,92,140,43 2 | 110,71,32,80,83,31,186,46,158,92,4,74,51,136,119,155,88,146,156,26,153,48,37,53,78,20,98,135,208,13,97,58,64,23,33,12,124,150,95,29,123,43,72,149,61,89,196,198,57,75,137,127,68,11,21,175,200,164,163,201,115,128,204,197,117,144,207,90,111,25,109,143,104,151,70,176,54,120,73,206,180,165,82,138,177,60,39,157,194,93,122,188,174,162,178,199,27,47,30,99,125,145,19,147,112,50,91,59,189,181,17,15,22,191,85,106,168,2,101,183,100,185,202,40,108,96,35,69,67,187,42,131,161,170,38,167,121,133,114,113,103,7,86,130,24,14,49,152,0,193,182,8,173,132,34,195,28,179,1,16,6,87,116,192,45,184,129 3 | 
72,38,203,200,156,88,19,104,14,111,207,86,90,126,117,51,155,75,183,143,181,63,124,167,55,151,9,1,119,79,23,197,74,165,36,153,56,81,193,103,166,192,20,141,159,184,114,4,91,106,205,37,150,160,45,121,169,176,178,15,65,198,130,7,107,186,99,17,64,112,32,95,42,129,102,128,185,16,161,195,177,135,96,66,158,163,73,175,142,199,84,40,24,194,31,148,22,108,77,168,92,76,139,85,93,69,171,140,27,116,3,105,50,191,115,49,162,26,57,70,13,30,28,145,10,204,201,154,152,39,89,172,120,188,157,123,196,94,110,67,179,170,131,118,44,180,60,33,43,0,41,52,208,2,127,125,146,97,54,109,5,137,61,62,82,46,174 4 | 30,115,33,146,127,95,145,178,184,143,77,188,87,45,117,208,22,50,139,32,141,183,27,68,51,66,122,153,43,65,207,3,101,192,13,155,53,67,134,113,23,135,17,62,16,159,104,151,41,92,181,170,38,189,136,204,110,84,103,142,112,14,175,88,173,21,156,42,11,140,83,182,86,108,169,203,24,44,40,80,37,15,0,20,19,150,160,109,177,58,121,5,64,54,190,99,137,26,28,179,70,205,93,129,152,79,200,76,74,57,144,131,36,100,157,1,29,158,194,48,149,167,75,46,186,118,97,78,69,7,201,171,197,133,89,116,34,105,114,81,196,10,63,195,107,96,199,187,191,132,174,154,39,8,163,102,126,12,111,55,164,85,2,47,106,128,94 5 | 192,91,70,190,131,183,25,148,157,68,147,65,41,188,151,143,200,145,44,34,154,125,105,120,71,185,155,159,161,149,108,127,85,56,128,45,69,33,78,14,130,103,50,26,9,1,175,60,12,97,202,102,180,123,163,152,146,176,22,58,164,167,156,140,7,182,196,74,207,35,92,137,168,59,141,8,169,165,39,113,15,153,24,40,179,36,17,79,53,119,187,5,66,114,95,201,138,48,139,63,81,178,16,109,101,88,6,96,99,195,160,37,67,204,55,172,184,30,170,126,86,174,199,47,10,206,62,42,43,76,110,73,72,121,205,194,52,122,135,87,117,198,61,104,208,90,80,21,134,100,3,144,173,23,51,89,112,82,158,27,38,64,31,18,93,0,193 6 | 
208,143,83,15,37,6,2,77,11,73,23,160,145,150,122,125,65,194,162,5,36,78,91,60,50,51,199,18,169,42,180,170,45,168,109,172,8,98,206,176,133,148,53,163,94,157,186,142,159,52,183,131,130,135,32,177,34,87,29,110,182,203,81,202,196,61,153,69,205,100,190,33,89,114,71,13,12,191,92,116,75,207,44,108,3,9,80,121,174,171,195,118,97,57,103,47,193,86,139,136,146,165,105,137,16,124,106,66,62,84,181,30,198,0,126,54,166,22,27,28,128,192,173,46,127,70,175,129,201,26,187,101,59,132,117,184,55,107,104,141,120,134,58,188,64,113,140,76,115,68,138,14,21,49,144,48,167,63,152,35,67,119,40,31,7,111,19 7 | 52,100,11,206,148,150,116,81,165,59,111,96,27,76,46,125,105,187,161,114,56,178,30,44,54,134,78,99,135,185,37,194,143,179,43,186,123,13,120,80,98,64,149,139,127,163,48,199,17,92,112,58,86,101,51,189,164,173,183,45,62,118,180,50,75,130,38,32,172,159,24,41,4,176,22,167,16,141,124,175,191,126,18,162,3,94,6,90,197,113,72,40,20,157,21,177,144,68,108,36,119,85,83,77,93,66,8,181,2,88,0,49,53,174,128,195,200,23,203,151,87,7,205,57,106,160,65,142,89,147,182,166,132,79,95,5,190,71,82,137,198,14,133,34,204,152,70,103,91,154,28,9,171,97,63,169,131,156,31,102,129,25,136,115,208,196,29 8 | 7,82,128,83,120,140,22,100,153,21,123,191,177,42,4,201,39,188,49,85,34,122,117,158,112,118,124,131,15,180,163,166,88,102,26,35,135,55,40,68,107,142,139,155,94,38,187,72,162,151,206,20,58,165,44,56,133,134,156,115,46,190,78,169,8,45,98,179,148,185,25,198,28,76,24,197,104,208,109,136,121,73,143,31,65,3,14,144,113,111,96,77,51,48,63,192,129,41,10,99,101,93,74,47,87,171,81,1,29,30,6,71,189,66,17,116,11,168,202,90,157,52,80,207,159,164,205,141,150,145,69,95,92,199,91,103,5,62,194,33,174,130,110,186,182,59,184,114,154,12,57,23,37,203,147,127,43,13,181,146,126,204,175,105,132,138,67 9 | 
170,70,129,96,110,192,160,155,84,154,74,175,135,146,158,145,54,162,35,169,17,3,44,200,2,185,27,68,152,71,85,28,173,125,91,93,207,201,13,121,204,206,73,183,195,181,43,122,25,196,22,138,11,171,159,182,6,9,107,5,112,20,67,50,79,23,143,41,53,88,99,77,52,124,134,32,69,83,55,58,19,36,203,0,100,136,56,1,178,76,18,189,14,118,80,46,15,65,109,63,197,117,127,172,105,111,151,115,60,38,47,113,30,153,48,102,174,16,156,75,101,89,168,157,142,64,150,202,163,132,164,167,8,140,144,137,191,12,205,176,106,166,45,33,78,94,165,72,128,82,139,126,87,123,148,86,31,29,108,37,149,104,39,177,57,61,180 10 | 4,73,57,146,3,121,119,5,133,15,60,141,76,114,145,142,188,67,182,123,169,90,205,30,86,77,100,137,94,116,71,108,45,151,103,178,208,200,107,6,81,155,117,204,44,32,122,26,97,49,193,24,177,46,165,134,36,63,170,66,39,125,157,172,12,186,41,61,33,161,154,106,40,70,93,42,59,111,160,153,89,185,183,92,194,104,25,78,101,129,74,158,79,14,8,128,191,159,150,18,95,113,147,140,190,27,166,164,50,124,55,203,110,98,105,20,207,58,88,35,202,43,75,9,167,29,180,181,22,21,34,47,118,10,130,1,195,131,84,138,85,7,96,187,199,52,13,19,184,112,144,54,162,83,198,173,115,120,56,201,135,82,136,31,196,53,139,149 11 | -------------------------------------------------------------------------------- /data/ppmi_schaefer100/val.index: -------------------------------------------------------------------------------- 1 | 82,55,138,149,117,176,46,39,147,4,58,151,115,200,104,160,189,195,23,64,54 2 | 102,107,62,76,203,142,154,166,63,18,140,172,10,9,79,190,118,56,36,171,52 3 | 202,29,18,11,147,122,190,164,25,189,71,187,87,83,134,21,59,6,138,68,53 4 | 90,138,166,206,198,185,161,73,91,98,61,125,49,25,148,147,9,162,18,35,124 5 | 171,28,19,77,4,107,54,132,181,116,2,186,136,197,98,189,133,84,49,94,75 6 | 112,56,197,4,179,24,155,93,147,99,20,82,1,189,161,41,72,90,149,164,123 7 | 184,10,26,74,170,168,202,60,109,12,193,192,35,158,47,69,42,84,39,55,155 8 | 16,176,149,53,172,200,0,19,97,79,173,60,160,119,2,106,32,61,84,178,193 9 | 
90,161,133,120,179,199,62,4,59,34,10,98,130,119,66,95,141,51,81,188,193 10 | 0,126,148,48,65,11,17,206,152,156,72,37,64,168,102,38,132,171,68,80,176 11 | -------------------------------------------------------------------------------- /data/taowu_schaefer100/test.index: -------------------------------------------------------------------------------- 1 | 0,16,26,31 2 | 1,9,23,27 3 | 10,14,22,24 4 | 15,18,25,30 5 | 3,4,21,38 6 | 6,11,34,36 7 | 17,19,28,35 8 | 5,8,33,37 9 | 2,7,20,39 10 | 12,13,29,32 11 | -------------------------------------------------------------------------------- /data/taowu_schaefer100/train.index: -------------------------------------------------------------------------------- 1 | 3,19,20,34,36,5,32,17,38,21,13,2,9,22,28,29,24,12,14,1,35,11,18,37,39,30,27,4,8,15,23,7 2 | 12,5,8,35,11,21,3,28,22,29,30,15,18,39,17,26,38,0,14,16,2,31,34,4,36,7,10,33,37,20,32,13 3 | 19,20,13,17,33,30,29,34,26,37,6,2,32,28,35,8,15,18,12,38,36,21,27,16,25,1,5,0,11,39,7,4 4 | 5,14,4,39,17,16,29,9,26,31,36,21,11,38,22,3,10,34,2,37,12,7,19,8,23,35,27,0,20,33,1,32 5 | 13,1,37,15,22,35,39,14,0,26,11,8,36,23,20,17,2,34,33,24,16,10,30,5,27,6,31,18,25,28,7,9 6 | 18,38,13,37,10,26,3,16,19,28,15,7,27,5,32,1,17,39,29,9,25,4,0,20,14,33,12,21,35,24,30,22 7 | 22,14,37,11,8,13,20,25,29,34,39,7,30,12,3,36,4,16,9,33,1,23,2,6,31,21,5,27,18,15,38,26 8 | 4,32,6,38,1,13,36,12,24,0,9,25,35,31,22,20,28,30,19,2,23,15,10,3,18,7,34,16,26,11,39,21 9 | 0,12,21,36,17,14,9,11,27,28,25,4,5,29,1,19,31,8,26,22,34,23,32,10,15,24,16,18,35,30,37,6 10 | 7,9,26,10,27,23,6,1,36,31,17,4,3,28,22,25,5,34,0,33,39,16,11,24,20,30,2,15,18,35,19,37 11 | -------------------------------------------------------------------------------- /data/taowu_schaefer100/val.index: -------------------------------------------------------------------------------- 1 | 33,6,10,25 2 | 25,6,19,24 3 | 23,31,3,9 4 | 24,13,28,6 5 | 29,19,12,32 6 | 8,31,2,23 7 | 32,24,10,0 8 | 29,14,27,17 9 | 3,38,13,33 10 | 14,38,8,21 11 | 
-------------------------------------------------------------------------------- /figs/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AngusMonroe/ContrastPool/f3df45d5fc1573b3b6a05a5c214a9552c34a37d7/figs/framework.png -------------------------------------------------------------------------------- /layers/attention_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import csv 5 | import numpy as np 6 | 7 | 8 | class EncoderLayer(nn.Module): 9 | def __init__(self, hid_dim, n_heads, pf_dim, dropout, device, feat_dim, learnable_q=False, pos_enc=None): 10 | super().__init__() 11 | 12 | self.learnable_q = learnable_q 13 | self.self_attn_layer_norm = nn.LayerNorm(hid_dim) 14 | self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device) 15 | self.dropout = nn.Dropout(dropout) 16 | self.q = torch.nn.Parameter(torch.ones([pf_dim, feat_dim, hid_dim])) if self.learnable_q else None 17 | 18 | def forward(self, src, src_mask=None): 19 | if self.learnable_q: 20 | _src, _ = self.self_attention(self.q, src, src, src_mask) 21 | else: 22 | _src, _ = self.self_attention(src, src, src, src_mask) 23 | src = self.self_attn_layer_norm(src + self.dropout(_src)) 24 | # src = [batch size, src len, hid dim] 25 | return src 26 | 27 | 28 | class MultiHeadAttentionLayer(nn.Module): 29 | def __init__(self, hid_dim, n_heads, dropout, device): 30 | super().__init__() 31 | 32 | self.hid_dim = hid_dim 33 | self.n_heads = n_heads 34 | 35 | assert hid_dim % n_heads == 0 36 | 37 | self.w_q = nn.Linear(hid_dim, hid_dim) 38 | self.w_k = nn.Linear(hid_dim, hid_dim) 39 | self.w_v = nn.Linear(hid_dim, hid_dim) 40 | 41 | self.fc = nn.Linear(hid_dim, hid_dim) 42 | 43 | self.dropout = nn.Dropout(dropout) 44 | 45 | self.scale = torch.sqrt(torch.FloatTensor([hid_dim // n_heads])).to(device) 46 | 47 | def 
forward(self, query, key, value, mask=None): 48 | 49 | bsz = query.shape[0] 50 | 51 | Q = self.w_q(query) 52 | K = self.w_k(key) 53 | V = self.w_v(value) 54 | 55 | Q = Q.view(bsz, -1, self.n_heads, self.hid_dim // 56 | self.n_heads).permute(0, 2, 1, 3) 57 | K = K.view(bsz, -1, self.n_heads, self.hid_dim // 58 | self.n_heads).permute(0, 2, 1, 3) 59 | V = V.view(bsz, -1, self.n_heads, self.hid_dim // 60 | self.n_heads).permute(0, 2, 1, 3) 61 | 62 | energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale 63 | 64 | if mask is not None: 65 | energy = energy.masked_fill(mask == 0, -1e10) 66 | 67 | attention = self.dropout(torch.softmax(energy, dim=-1)) 68 | 69 | 70 | x = torch.matmul(attention, V) 71 | 72 | x = x.permute(0, 2, 1, 3).contiguous() 73 | 74 | x = x.view(bsz, -1, self.n_heads * (self.hid_dim // self.n_heads)) 75 | 76 | x = self.fc(x) 77 | 78 | return x, attention.squeeze() 79 | 80 | 81 | class PositionwiseFeedforwardLayer(nn.Module): 82 | def __init__(self, hid_dim, pf_dim, dropout): 83 | super().__init__() 84 | 85 | self.fc_1 = nn.Linear(hid_dim, pf_dim) 86 | self.fc_2 = nn.Linear(pf_dim, hid_dim) 87 | 88 | self.dropout = nn.Dropout(dropout) 89 | 90 | def forward(self, x): 91 | # x = [batch size, seq len, hid dim] 92 | x = self.dropout(torch.relu(self.fc_1(x))) 93 | # x = [batch size, seq len, pf dim] 94 | x = self.fc_2(x) 95 | # x = [batch size, seq len, hid dim] 96 | 97 | return x 98 | 99 | 100 | class PositionalEncoding(nn.Module): 101 | "Implement the PE function." 102 | def __init__(self, d_model, dropout, max_len=5000): 103 | super(PositionalEncoding, self).__init__() 104 | self.dropout = nn.Dropout(p=dropout) 105 | 106 | # Compute the positional encodings once in log space. 
107 | pe = torch.zeros(max_len, d_model) 108 | position = torch.arange(0., max_len).unsqueeze(1) 109 | div_term = torch.exp(torch.arange(0., d_model, 2) * 110 | -(math.log(10000.0) / d_model)) 111 | pe[:, 0::2] = torch.sin(position * div_term) 112 | pe[:, 1::2] = torch.cos(position * div_term) 113 | pe = pe.unsqueeze(0) 114 | self.register_buffer('pe', pe) 115 | 116 | def forward(self, x): 117 | x = x + self.pe[:, :x.size(1)] 118 | return self.dropout(x) 119 | -------------------------------------------------------------------------------- /layers/contrastpool_layer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import functional as F 5 | import numpy as np 6 | from scipy.linalg import block_diag 7 | from torch.autograd import Function 8 | from layers.graphsage_layer import GraphSageLayer, DenseGraphSage 9 | 10 | 11 | def masked_softmax(matrix, mask, dim=-1, memory_efficient=True, 12 | mask_fill_value=-1e32): 13 | ''' 14 | masked_softmax for dgl batch graph 15 | code snippet contributed by AllenNLP (https://github.com/allenai/allennlp) 16 | ''' 17 | if mask is None: 18 | result = torch.nn.functional.softmax(matrix, dim=dim) 19 | else: 20 | mask = mask.float() 21 | while mask.dim() < matrix.dim(): 22 | mask = mask.unsqueeze(1) 23 | if not memory_efficient: 24 | result = torch.nn.functional.softmax(matrix * mask, dim=dim) 25 | result = result * mask 26 | result = result / (result.sum(dim=dim, keepdim=True) + 1e-13) 27 | else: 28 | masked_matrix = matrix.masked_fill((1 - mask).byte(), 29 | mask_fill_value) 30 | result = torch.nn.functional.softmax(masked_matrix, dim=dim) 31 | return result 32 | 33 | 34 | class EntropyLoss(nn.Module): 35 | # Return Scalar 36 | # loss used in diffpool 37 | def forward(self, adj, anext, s_l): 38 | entropy = (torch.distributions.Categorical( 39 | probs=s_l).entropy()).sum(-1).mean(-1) 40 | assert not torch.isnan(entropy) 41 | 
return entropy 42 | 43 | 44 | class ContrastPoolLayer(nn.Module): 45 | 46 | def __init__(self, input_dim, assign_dim, output_feat_dim, 47 | activation, dropout, aggregator_type, link_pred, batch_norm, pool_assign='GraphSage', max_node_num=0): 48 | super().__init__() 49 | self.embedding_dim = input_dim 50 | self.assign_dim = assign_dim 51 | self.hidden_dim = output_feat_dim 52 | self.link_pred = link_pred 53 | self.feat_gc = GraphSageLayer( 54 | input_dim, 55 | output_feat_dim, 56 | activation, 57 | dropout, 58 | aggregator_type, 59 | batch_norm) 60 | if pool_assign == 'GraphSage': 61 | self.pool_gc = GraphSageLayer( 62 | input_dim, 63 | assign_dim, 64 | activation, 65 | dropout, 66 | aggregator_type, 67 | batch_norm) 68 | else: 69 | pass 70 | self.reg_loss = nn.ModuleList([]) 71 | self.loss_log = {} 72 | self.reg_loss.append(EntropyLoss()) 73 | 74 | # cs 75 | self.weight = nn.Parameter(torch.Tensor(max_node_num, assign_dim)) 76 | self.bias = nn.Parameter(torch.Tensor(1, assign_dim)) 77 | stdv = 1. 
/ math.sqrt(self.weight.size(1)) 78 | self.weight.data.uniform_(-stdv, stdv) 79 | self.bias.data.uniform_(-stdv, stdv) 80 | 81 | def forward(self, g, h, diff_h=None, adj=None, e=None): 82 | # h: [1000, 86] 83 | batch_size = len(g.batch_num_nodes()) 84 | feat, e = self.feat_gc(g, h, e) 85 | device = feat.device 86 | # GCN 87 | if diff_h is not None: 88 | # print(diff_h.shape) 89 | # print(self.weight.shape) 90 | support = torch.matmul(diff_h, self.weight) 91 | if adj is not None: 92 | output = torch.matmul(adj.to(device), support) 93 | else: 94 | output = torch.matmul(g.adj().to_dense().clone().to(device), support.repeat(batch_size, 1)) 95 | assign_tensor = output + self.bias 96 | else: 97 | assign_tensor, e = self.pool_gc(g, h, e) 98 | # assign_tensor: [2000, 50] 99 | # print(assign_tensor.shape) 100 | 101 | assign_tensor_masks = [] 102 | assign_size = int(assign_tensor.size()[1]) if adj is not None else int(assign_tensor.size()[1] / batch_size) 103 | for g_n_nodes in g.batch_num_nodes(): 104 | mask = torch.ones((g_n_nodes, assign_size)) 105 | assign_tensor_masks.append(mask) 106 | 107 | """ 108 | The first pooling layer is computed on batched graph. 109 | We first take the adjacency matrix of the batched graph, which is block-wise diagonal. 
110 | We then compute the assignment matrix for the whole batch graph, which will also be block diagonal 111 | """ 112 | mask = torch.FloatTensor( 113 | block_diag( 114 | * 115 | assign_tensor_masks)).to( 116 | device=device) 117 | if adj is not None: 118 | assign_tensor = assign_tensor.repeat(batch_size, batch_size) 119 | 120 | assign_tensor = masked_softmax(assign_tensor, mask, memory_efficient=False) 121 | h = torch.matmul(torch.t(assign_tensor), feat) # equation (3) of DIFFPOOL paper 122 | adj = g.adjacency_matrix(ctx=device) 123 | 124 | adj_new = torch.sparse.mm(adj, assign_tensor) 125 | adj_new = torch.mm(torch.t(assign_tensor), adj_new) # equation (4) of DIFFPOOL paper 126 | 127 | if self.link_pred: 128 | current_lp_loss = torch.norm(adj.to_dense() - 129 | torch.mm(assign_tensor, torch.t(assign_tensor))) / np.power(g.number_of_nodes(), 2) 130 | self.loss_log['LinkPredLoss'] = current_lp_loss 131 | 132 | for loss_layer in self.reg_loss: 133 | loss_name = str(type(loss_layer).__name__) 134 | 135 | self.loss_log[loss_name] = loss_layer(adj, adj_new, assign_tensor) 136 | return adj_new, h 137 | 138 | 139 | class LinkPredLoss(nn.Module): 140 | # loss used in diffpool 141 | def forward(self, adj, anext, s_l): 142 | link_pred_loss = ( 143 | adj - s_l.matmul(s_l.transpose(-1, -2))).norm(dim=(1, 2)) 144 | link_pred_loss = link_pred_loss / (adj.size(1) * adj.size(2)) 145 | return link_pred_loss.mean() 146 | 147 | 148 | class DenseDiffPool(nn.Module): 149 | def __init__(self, nfeat, nnext, nhid, link_pred=False, entropy=True): 150 | super().__init__() 151 | self.link_pred = link_pred 152 | self.log = {} 153 | self.link_pred_layer = LinkPredLoss() 154 | self.embed = DenseGraphSage(nfeat, nhid, use_bn=True) 155 | self.assign = DiffPoolAssignment(nfeat, nnext) 156 | self.reg_loss = nn.ModuleList([]) 157 | self.loss_log = {} 158 | if link_pred: 159 | self.reg_loss.append(LinkPredLoss()) 160 | if entropy: 161 | self.reg_loss.append(EntropyLoss()) 162 | 163 | def 
forward(self, x, adj, log=False): 164 | z_l = self.embed(x, adj) 165 | s_l = self.assign(x, adj) 166 | if log: 167 | self.log['s'] = s_l.cpu().numpy() 168 | xnext = torch.matmul(s_l.transpose(-1, -2), z_l) 169 | anext = (s_l.transpose(-1, -2)).matmul(adj).matmul(s_l) 170 | 171 | for loss_layer in self.reg_loss: 172 | loss_name = str(type(loss_layer).__name__) 173 | self.loss_log[loss_name] = loss_layer(adj, anext, s_l) 174 | if log: 175 | self.log['a'] = anext.cpu().numpy() 176 | return xnext, anext 177 | 178 | 179 | class DiffPoolAssignment(nn.Module): 180 | def __init__(self, nfeat, nnext): 181 | super().__init__() 182 | self.assign_mat = DenseGraphSage(nfeat, nnext, use_bn=True) 183 | 184 | def forward(self, x, adj, log=False): 185 | s_l_init = self.assign_mat(x, adj) 186 | s_l = F.softmax(s_l_init, dim=-1) 187 | return s_l 188 | -------------------------------------------------------------------------------- /layers/diffpool_layer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import functional as F 5 | import numpy as np 6 | from scipy.linalg import block_diag 7 | 8 | from torch.autograd import Function 9 | 10 | """ 11 | DIFFPOOL: 12 | Z. Ying, J. You, C. Morris, X. Ren, W. Hamilton, and J. Leskovec, 13 | Hierarchical graph representation learning with differentiable pooling (NeurIPS 2018) 14 | https://arxiv.org/pdf/1806.08804.pdf 15 | 16 | ! 
code started from dgl diffpool examples dir 17 | """ 18 | 19 | from layers.graphsage_layer import GraphSageLayer, DenseGraphSage 20 | 21 | 22 | def masked_softmax(matrix, mask, dim=-1, memory_efficient=True, 23 | mask_fill_value=-1e32): 24 | ''' 25 | masked_softmax for dgl batch graph 26 | code snippet contributed by AllenNLP (https://github.com/allenai/allennlp) 27 | ''' 28 | if mask is None: 29 | result = torch.nn.functional.softmax(matrix, dim=dim) 30 | else: 31 | mask = mask.float() 32 | while mask.dim() < matrix.dim(): 33 | mask = mask.unsqueeze(1) 34 | if not memory_efficient: 35 | result = torch.nn.functional.softmax(matrix * mask, dim=dim) 36 | result = result * mask 37 | result = result / (result.sum(dim=dim, keepdim=True) + 1e-13) 38 | else: 39 | masked_matrix = matrix.masked_fill((1 - mask).byte(), 40 | mask_fill_value) 41 | result = torch.nn.functional.softmax(masked_matrix, dim=dim) 42 | return result 43 | 44 | 45 | class EntropyLoss(nn.Module): 46 | # Return Scalar 47 | # loss used in diffpool 48 | def forward(self, adj, anext, s_l): 49 | entropy = (torch.distributions.Categorical( 50 | probs=s_l).entropy()).sum(-1).mean(-1) 51 | assert not torch.isnan(entropy) 52 | return entropy 53 | 54 | 55 | class DiffPoolLayer(nn.Module): 56 | 57 | def __init__(self, input_dim, assign_dim, output_feat_dim, 58 | activation, dropout, aggregator_type, link_pred, batch_norm, pool_assign='GraphSage'): 59 | super().__init__() 60 | self.embedding_dim = input_dim 61 | self.assign_dim = assign_dim 62 | self.hidden_dim = output_feat_dim 63 | self.link_pred = link_pred 64 | self.feat_gc = GraphSageLayer( 65 | input_dim, 66 | output_feat_dim, 67 | activation, 68 | dropout, 69 | aggregator_type, 70 | batch_norm) 71 | if pool_assign == 'GraphSage': 72 | self.pool_gc = GraphSageLayer( 73 | input_dim, 74 | assign_dim, 75 | activation, 76 | dropout, 77 | aggregator_type, 78 | batch_norm) 79 | else: 80 | pass 81 | self.reg_loss = nn.ModuleList([]) 82 | self.loss_log = {} 83 | 
self.reg_loss.append(EntropyLoss()) 84 | 85 | def forward(self, g, h, e=None): 86 | # h: [1000, 86] 87 | feat, e = self.feat_gc(g, h, e) 88 | device = feat.device 89 | assign_tensor, e = self.pool_gc(g, h, e) 90 | 91 | assign_tensor_masks = [] 92 | batch_size = len(g.batch_num_nodes()) 93 | for g_n_nodes in g.batch_num_nodes(): 94 | mask = torch.ones((g_n_nodes, 95 | int(assign_tensor.size()[1] / batch_size))) 96 | assign_tensor_masks.append(mask) 97 | """ 98 | The first pooling layer is computed on batched graph. 99 | We first take the adjacency matrix of the batched graph, which is block-wise diagonal. 100 | We then compute the assignment matrix for the whole batch graph, which will also be block diagonal 101 | """ 102 | mask = torch.FloatTensor( 103 | block_diag( 104 | * 105 | assign_tensor_masks)).to( 106 | device=device) 107 | 108 | assign_tensor = masked_softmax(assign_tensor, mask, 109 | memory_efficient=False) 110 | # print(assign_tensor.shape) 111 | h = torch.matmul(torch.t(assign_tensor), feat) # equation (3) of DIFFPOOL paper 112 | adj = g.adjacency_matrix(ctx=device) 113 | 114 | adj_new = torch.sparse.mm(adj, assign_tensor) 115 | adj_new = torch.mm(torch.t(assign_tensor), adj_new) # equation (4) of DIFFPOOL paper 116 | 117 | if self.link_pred: 118 | current_lp_loss = torch.norm(adj.to_dense() - 119 | torch.mm(assign_tensor, torch.t(assign_tensor))) / np.power(g.number_of_nodes(), 2) 120 | self.loss_log['LinkPredLoss'] = current_lp_loss 121 | 122 | for loss_layer in self.reg_loss: 123 | loss_name = str(type(loss_layer).__name__) 124 | 125 | self.loss_log[loss_name] = loss_layer(adj, adj_new, assign_tensor) 126 | return adj_new, h 127 | 128 | 129 | class LinkPredLoss(nn.Module): 130 | # loss used in diffpool 131 | def forward(self, adj, anext, s_l): 132 | link_pred_loss = ( 133 | adj - s_l.matmul(s_l.transpose(-1, -2))).norm(dim=(1, 2)) 134 | link_pred_loss = link_pred_loss / (adj.size(1) * adj.size(2)) 135 | return link_pred_loss.mean() 136 | 137 | 138 | 
class DenseDiffPool(nn.Module): 139 | def __init__(self, nfeat, nnext, nhid, link_pred=False, entropy=True): 140 | super().__init__() 141 | self.link_pred = link_pred 142 | self.log = {} 143 | self.link_pred_layer = self.LinkPredLoss() 144 | self.embed = DenseGraphSage(nfeat, nhid, use_bn=True) 145 | self.assign = DiffPoolAssignment(nfeat, nnext) 146 | self.reg_loss = nn.ModuleList([]) 147 | self.loss_log = {} 148 | if link_pred: 149 | self.reg_loss.append(LinkPredLoss()) 150 | if entropy: 151 | self.reg_loss.append(EntropyLoss()) 152 | 153 | def forward(self, x, adj, log=False): 154 | z_l = self.embed(x, adj) 155 | s_l = self.assign(x, adj) 156 | if log: 157 | self.log['s'] = s_l.cpu().numpy() 158 | xnext = torch.matmul(s_l.transpose(-1, -2), z_l) 159 | anext = (s_l.transpose(-1, -2)).matmul(adj).matmul(s_l) 160 | 161 | for loss_layer in self.reg_loss: 162 | loss_name = str(type(loss_layer).__name__) 163 | self.loss_log[loss_name] = loss_layer(adj, anext, s_l) 164 | if log: 165 | self.log['a'] = anext.cpu().numpy() 166 | return xnext, anext 167 | 168 | 169 | class DiffPoolAssignment(nn.Module): 170 | def __init__(self, nfeat, nnext): 171 | super().__init__() 172 | self.assign_mat = DenseGraphSage(nfeat, nnext, use_bn=True) 173 | 174 | def forward(self, x, adj, log=False): 175 | s_l_init = self.assign_mat(x, adj) 176 | s_l = F.softmax(s_l_init, dim=-1) 177 | return s_l 178 | -------------------------------------------------------------------------------- /layers/graphsage_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import dgl.function as fn 6 | from dgl.nn.pytorch import SAGEConv 7 | 8 | """ 9 | GraphSAGE: 10 | William L. 
Hamilton, Rex Ying, Jure Leskovec, Inductive Representation Learning on Large Graphs (NeurIPS 2017) 11 | https://cs.stanford.edu/people/jure/pubs/graphsage-nips17.pdf 12 | """ 13 | 14 | class GraphSageLayer(nn.Module): 15 | 16 | def __init__(self, in_feats, out_feats, activation, dropout, 17 | aggregator_type, batch_norm, residual=False, 18 | bias=True, dgl_builtin=False, e_feat=False): 19 | super().__init__() 20 | self.in_channels = in_feats 21 | self.out_channels = out_feats 22 | self.aggregator_type = aggregator_type 23 | self.batch_norm = batch_norm 24 | self.residual = residual 25 | self.dgl_builtin = dgl_builtin 26 | 27 | if in_feats != out_feats: 28 | self.residual = False 29 | 30 | self.dropout = nn.Dropout(p=dropout) 31 | 32 | self.message_func = fn.copy_src(src='h', out='m') if not e_feat else fn.u_mul_e('h', 'e', 'm') 33 | 34 | if dgl_builtin == False: 35 | self.nodeapply = NodeApply(in_feats, out_feats, activation, dropout, 36 | bias=bias) 37 | if aggregator_type == "maxpool": 38 | self.aggregator = MaxPoolAggregator(in_feats, in_feats, 39 | activation, bias) 40 | elif aggregator_type == "lstm": 41 | self.aggregator = LSTMAggregator(in_feats, in_feats) 42 | else: 43 | self.aggregator = MeanAggregator() 44 | else: 45 | self.sageconv = SAGEConv(in_feats, out_feats, aggregator_type, 46 | dropout, activation=activation) 47 | 48 | if self.batch_norm: 49 | self.batchnorm_h = nn.BatchNorm1d(out_feats) 50 | self.batchnorm_e = nn.BatchNorm1d(out_feats) 51 | 52 | def forward(self, g, h, e=None): 53 | h_in = h # for residual connection 54 | # e_in = e 55 | 56 | if self.dgl_builtin == False: 57 | h = self.dropout(h) 58 | # e = self.dropout(e) 59 | g.ndata['h'] = h 60 | # g.edata['e'] = e 61 | g.update_all(fn.copy_src(src='h', out='m'), 62 | self.aggregator, 63 | self.nodeapply) 64 | 65 | h = g.ndata['h'] 66 | else: 67 | h = self.sageconv(g, h) 68 | 69 | if self.batch_norm: 70 | h = self.batchnorm_h(g, h) 71 | 72 | if self.residual: 73 | h = h_in + h # residual 
connection 74 | 75 | return h, e 76 | 77 | def __repr__(self): 78 | return '{}(in_channels={}, out_channels={}, aggregator={}, residual={})'.format(self.__class__.__name__, 79 | self.in_channels, 80 | self.out_channels, self.aggregator_type, self.residual) 81 | 82 | 83 | 84 | """ 85 | Aggregators for GraphSage 86 | """ 87 | class Aggregator(nn.Module): 88 | """ 89 | Base Aggregator class. 90 | """ 91 | 92 | def __init__(self): 93 | super().__init__() 94 | 95 | def forward(self, node): 96 | neighbour = node.mailbox['m'] 97 | c = self.aggre(neighbour) 98 | return {"c": c} 99 | 100 | def aggre(self, neighbour): 101 | # N x F 102 | raise NotImplementedError 103 | 104 | 105 | class MeanAggregator(Aggregator): 106 | """ 107 | Mean Aggregator for graphsage 108 | """ 109 | 110 | def __init__(self): 111 | super().__init__() 112 | 113 | def aggre(self, neighbour): 114 | mean_neighbour = torch.mean(neighbour, dim=1) 115 | return mean_neighbour 116 | 117 | 118 | class MaxPoolAggregator(Aggregator): 119 | """ 120 | Maxpooling aggregator for graphsage 121 | """ 122 | 123 | def __init__(self, in_feats, out_feats, activation, bias): 124 | super().__init__() 125 | self.linear = nn.Linear(in_feats, out_feats, bias=bias) 126 | self.activation = activation 127 | 128 | def aggre(self, neighbour): 129 | neighbour = self.linear(neighbour) 130 | if self.activation: 131 | neighbour = self.activation(neighbour) 132 | maxpool_neighbour = torch.max(neighbour, dim=1)[0] 133 | return maxpool_neighbour 134 | 135 | 136 | class LSTMAggregator(Aggregator): 137 | """ 138 | LSTM aggregator for graphsage 139 | """ 140 | 141 | def __init__(self, in_feats, hidden_feats): 142 | super().__init__() 143 | self.lstm = nn.LSTM(in_feats, hidden_feats, batch_first=True) 144 | self.hidden_dim = hidden_feats 145 | self.hidden = self.init_hidden() 146 | 147 | nn.init.xavier_uniform_(self.lstm.weight, 148 | gain=nn.init.calculate_gain('relu')) 149 | 150 | def init_hidden(self): 151 | """ 152 | Defaulted to 
initialite all zero 153 | """ 154 | return (torch.zeros(1, 1, self.hidden_dim), 155 | torch.zeros(1, 1, self.hidden_dim)) 156 | 157 | def aggre(self, neighbours): 158 | """ 159 | aggregation function 160 | """ 161 | # N X F 162 | rand_order = torch.randperm(neighbours.size()[1]) 163 | neighbours = neighbours[:, rand_order, :] 164 | 165 | (lstm_out, self.hidden) = self.lstm(neighbours.view(neighbours.size()[0], neighbours.size()[1], -1)) 166 | return lstm_out[:, -1, :] 167 | 168 | def forward(self, node): 169 | neighbour = node.mailbox['m'] 170 | c = self.aggre(neighbour) 171 | return {"c": c} 172 | 173 | 174 | class NodeApply(nn.Module): 175 | """ 176 | Works -> the node_apply function in DGL paradigm 177 | """ 178 | 179 | def __init__(self, in_feats, out_feats, activation, dropout, bias=True): 180 | super().__init__() 181 | self.dropout = nn.Dropout(p=dropout) 182 | self.linear = nn.Linear(in_feats * 2, out_feats, bias) 183 | self.activation = activation 184 | 185 | def concat(self, h, aggre_result): 186 | bundle = torch.cat((h, aggre_result), 1) 187 | bundle = self.linear(bundle) 188 | return bundle 189 | 190 | def forward(self, node): 191 | h = node.data['h'] 192 | c = node.data['c'] 193 | bundle = self.concat(h, c) 194 | bundle = F.normalize(bundle, p=2, dim=1) 195 | if self.activation: 196 | bundle = self.activation(bundle) 197 | return {"h": bundle} 198 | 199 | 200 | class GraphSageLayerEdgeFeat(nn.Module): 201 | 202 | def __init__(self, in_feats, out_feats, activation, dropout, 203 | aggregator_type, batch_norm, residual=False, 204 | bias=True, dgl_builtin=False): 205 | super().__init__() 206 | self.in_channels = in_feats 207 | self.out_channels = out_feats 208 | self.batch_norm = batch_norm 209 | self.residual = residual 210 | 211 | if in_feats != out_feats: 212 | self.residual = False 213 | 214 | self.dropout = nn.Dropout(p=dropout) 215 | 216 | self.activation = activation 217 | 218 | self.A = nn.Linear(in_feats, out_feats, bias=bias) 219 | self.B = 
nn.Linear(in_feats, out_feats, bias=bias) 220 | 221 | self.nodeapply = NodeApply(in_feats, out_feats, activation, dropout, bias=bias) 222 | 223 | if self.batch_norm: 224 | self.batchnorm_h = nn.BatchNorm1d(out_feats) 225 | 226 | def message_func(self, edges): 227 | Ah_j = edges.src['Ah'] 228 | e_ij = edges.src['Bh'] + edges.dst['Bh'] # e_ij = Bhi + Bhj 229 | edges.data['e'] = e_ij 230 | return {'Ah_j' : Ah_j, 'e_ij' : e_ij} 231 | 232 | def reduce_func(self, nodes): 233 | # Anisotropic MaxPool aggregation 234 | 235 | Ah_j = nodes.mailbox['Ah_j'] 236 | e = nodes.mailbox['e_ij'] 237 | sigma_ij = torch.sigmoid(e) # sigma_ij = sigmoid(e_ij) 238 | 239 | Ah_j = sigma_ij * Ah_j 240 | if self.activation: 241 | Ah_j = self.activation(Ah_j) 242 | 243 | c = torch.max(Ah_j, dim=1)[0] 244 | return {'c' : c} 245 | 246 | def forward(self, g, h): 247 | h_in = h # for residual connection 248 | h = self.dropout(h) 249 | 250 | g.ndata['h'] = h 251 | g.ndata['Ah'] = self.A(h) 252 | g.ndata['Bh'] = self.B(h) 253 | g.update_all(self.message_func, 254 | self.reduce_func, 255 | self.nodeapply) 256 | h = g.ndata['h'] 257 | 258 | if self.batch_norm: 259 | h = self.batchnorm_h(h) 260 | 261 | if self.residual: 262 | h = h_in + h # residual connection 263 | 264 | return h 265 | 266 | def __repr__(self): 267 | return '{}(in_channels={}, out_channels={}, residual={})'.format( 268 | self.__class__.__name__, 269 | self.in_channels, 270 | self.out_channels, 271 | self.residual) 272 | 273 | 274 | ############################################################## 275 | 276 | 277 | class GraphSageLayerEdgeReprFeat(nn.Module): 278 | 279 | def __init__(self, in_feats, out_feats, activation, dropout, 280 | aggregator_type, batch_norm, residual=False, 281 | bias=True, dgl_builtin=False): 282 | super().__init__() 283 | self.in_channels = in_feats 284 | self.out_channels = out_feats 285 | self.batch_norm = batch_norm 286 | self.residual = residual 287 | 288 | if in_feats != out_feats: 289 | self.residual = False 
290 | 291 | self.dropout = nn.Dropout(p=dropout) 292 | 293 | self.activation = activation 294 | 295 | self.A = nn.Linear(in_feats, out_feats, bias=bias) 296 | self.B = nn.Linear(in_feats, out_feats, bias=bias) 297 | self.C = nn.Linear(in_feats, out_feats, bias=bias) 298 | 299 | self.nodeapply = NodeApply(in_feats, out_feats, activation, dropout, bias=bias) 300 | 301 | if self.batch_norm: 302 | self.batchnorm_h = nn.BatchNorm1d(out_feats) 303 | self.batchnorm_e = nn.BatchNorm1d(out_feats) 304 | 305 | def message_func(self, edges): 306 | Ah_j = edges.src['Ah'] 307 | e_ij = edges.data['Ce'] + edges.src['Bh'] + edges.dst['Bh'] # e_ij = Ce_ij + Bhi + Bhj 308 | edges.data['e'] = e_ij 309 | return {'Ah_j' : Ah_j, 'e_ij' : e_ij} 310 | 311 | def reduce_func(self, nodes): 312 | # Anisotropic MaxPool aggregation 313 | 314 | Ah_j = nodes.mailbox['Ah_j'] 315 | e = nodes.mailbox['e_ij'] 316 | sigma_ij = torch.sigmoid(e) # sigma_ij = sigmoid(e_ij) 317 | 318 | Ah_j = sigma_ij * Ah_j 319 | if self.activation: 320 | Ah_j = self.activation(Ah_j) 321 | 322 | c = torch.max(Ah_j, dim=1)[0] 323 | return {'c' : c} 324 | 325 | def forward(self, g, h, e): 326 | h_in = h # for residual connection 327 | e_in = e 328 | h = self.dropout(h) 329 | 330 | g.ndata['h'] = h 331 | g.ndata['Ah'] = self.A(h) 332 | g.ndata['Bh'] = self.B(h) 333 | g.edata['e'] = e 334 | g.edata['Ce'] = self.C(e) 335 | g.update_all(self.message_func, 336 | self.reduce_func, 337 | self.nodeapply) 338 | h = g.ndata['h'] 339 | e = g.edata['e'] 340 | 341 | if self.activation: 342 | e = self.activation(e) # non-linear activation 343 | 344 | if self.batch_norm: 345 | h = self.batchnorm_h(h) 346 | e = self.batchnorm_e(e) 347 | 348 | if self.residual: 349 | h = h_in + h # residual connection 350 | e = e_in + e # residual connection 351 | 352 | return h, e 353 | 354 | def __repr__(self): 355 | return '{}(in_channels={}, out_channels={}, residual={})'.format( 356 | self.__class__.__name__, 357 | self.in_channels, 358 | 
self.out_channels, 359 | self.residual) 360 | 361 | 362 | class DenseGraphSage(nn.Module): 363 | def __init__(self, infeat, outfeat, residual=False, use_bn=True, 364 | mean=False, add_self=False): 365 | super().__init__() 366 | self.add_self = add_self 367 | self.use_bn = use_bn 368 | self.mean = mean 369 | self.residual = residual 370 | 371 | if infeat != outfeat: 372 | self.residual = False 373 | 374 | self.W = nn.Linear(infeat, outfeat, bias=True) 375 | 376 | nn.init.xavier_uniform_( 377 | self.W.weight, 378 | gain=nn.init.calculate_gain('relu')) 379 | 380 | def forward(self, x, adj): 381 | h_in = x # for residual connection 382 | 383 | if self.use_bn and not hasattr(self, 'bn'): 384 | self.bn = nn.BatchNorm1d(adj.size(1)).to(adj.device) 385 | 386 | if self.add_self: 387 | adj = adj + torch.eye(adj.size(0)).to(adj.device) 388 | 389 | if self.mean: 390 | adj = adj / adj.sum(1, keepdim=True) 391 | 392 | h_k_N = torch.matmul(adj, x) 393 | h_k = self.W(h_k_N) 394 | h_k = F.normalize(h_k, dim=2, p=2) 395 | h_k = F.relu(h_k) 396 | 397 | if self.residual: 398 | h_k = h_in + h_k # residual connection 399 | 400 | if self.use_bn: 401 | h_k = self.bn(h_k) 402 | return h_k 403 | 404 | def __repr__(self): 405 | if self.use_bn: 406 | return 'BN' + super(DenseGraphSage, self).__repr__() 407 | else: 408 | return super(DenseGraphSage, self).__repr__() 409 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import socket 4 | import time 5 | import random 6 | import glob 7 | import argparse, json 8 | import dgl 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import torch.optim as optim 13 | from torch.utils.data import DataLoader 14 | from tensorboardX import SummaryWriter 15 | from tqdm import tqdm 16 | from nets.load_net import gnn_model # import GNNs 17 | from data.data import 
LoadData # import dataset


def gpu_setup(use_gpu, gpu_id):
    """
    Select the torch device used for training.

    Exposes only physical GPU ``gpu_id`` to this process (PCI bus order) and
    returns ``cuda`` when CUDA is available and ``use_gpu`` is truthy,
    otherwise ``cpu``.

    NOTE: CUDA_VISIBLE_DEVICES only takes effect if CUDA has not been
    initialised in this process yet.
    """
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

    if torch.cuda.is_available() and use_gpu:
        print('cuda available with GPU:', torch.cuda.get_device_name(0))
        device = torch.device("cuda")
    else:
        print('cuda not available')
        device = torch.device("cpu")
    return device


def view_model_param(MODEL_NAME, net_params):
    """
    Build a throw-away model from ``net_params`` and return its parameter count.

    NOTE: modules created later by ``ContrastPoolNet.cal_contrast`` (the
    attention encoders) do not exist yet here, so any parameters they own are
    not included in this count.
    """
    model = gnn_model(MODEL_NAME, net_params)
    print("MODEL DETAILS:\n")
    # numel() keeps the count an int (np.prod([]) would return the float 1.0
    # for a 0-dim parameter and turn the whole sum into a float).
    total_param = sum(param.numel() for param in model.parameters())
    print('MODEL/Total parameters:', MODEL_NAME, total_param)
    return total_param


def _remove_stale_pkl(directory, epoch):
    """
    Delete ``epoch_<n>.pkl`` snapshots in ``directory`` with ``n < epoch - 1``.

    Snapshots for ``epoch - 1`` and later are kept.
    """
    for path in glob.glob(directory + '/*.pkl'):
        epoch_nb = int(path.split('_')[-1].split('.')[0])
        if epoch_nb < epoch - 1:
            os.remove(path)


def train_val_pipeline(MODEL_NAME, DATASET_NAME, params, net_params, dirs, dataset=None):
    """
    Train and evaluate ``MODEL_NAME`` on the 10 pre-defined splits of ``DATASET_NAME``.

    Args:
        MODEL_NAME: key understood by ``gnn_model``.
        DATASET_NAME: dataset identifier passed to ``LoadData``.
        params: optimisation hyper-parameters (seed, epochs, batch_size,
            init_lr, weight_decay, lr_reduce_factor, lr_schedule_patience,
            min_lr, max_time, threshold, node_feat_transform and, optionally,
            edge_ratio).
        net_params: network hyper-parameters; must contain ``device``,
            ``n_classes``, ``contrast`` and ``total_param``.
        dirs: ``(root_log_dir, root_ckpt_dir, write_file_name, write_config_file)``.
        dataset: an already loaded ``LoadData`` result. When ``None`` the
            dataset is loaded here from ``params``.

    For every split this logs TensorBoard scalars, keeps the two most recent
    checkpoints, and finally writes a summary over the finished splits to
    ``write_file_name + '.txt'``. Ctrl+C stops training early; the summary of
    the splits finished so far is still written.
    """
    avg_test_acc = []
    avg_train_acc = []
    avg_convergence_epochs = []

    t0 = time.time()
    per_epoch_time = []

    if dataset is None:
        if 'edge_ratio' in params:
            # Same positional order main() uses, so --edge_ratio is honoured.
            dataset = LoadData(DATASET_NAME, params['threshold'], params['edge_ratio'],
                               params['node_feat_transform'])
        else:
            dataset = LoadData(DATASET_NAME, threshold=params['threshold'],
                               node_feat_transform=params['node_feat_transform'])

    # Fallbacks used by the summary if training is interrupted before split 0.
    trainset = dataset.train
    model = None

    root_log_dir, root_ckpt_dir, write_file_name, write_config_file = dirs
    device = net_params['device']

    # Write the network and optimization hyper-parameters in folder config/
    with open(write_config_file + '.txt', 'w') as f:
        f.write("""Dataset: {},\nModel: {}\n\nparams={}\n\nnet_params={}\n\n\nTotal Parameters: {}\n\n""".format(DATASET_NAME, MODEL_NAME, params, net_params, net_params['total_param']))

    from train_TUs_graph_classification import train_epoch_sparse as train_epoch, evaluate_network_sparse as evaluate_network

    # batching exception for DiffPool-style models: ContrastPoolNet.forward
    # splits each pooled batch using the configured batch_size, so incomplete
    # batches must be dropped (also for validation/test).
    drop_last = MODEL_NAME in ['DiffPool', 'ContrastPool']

    # At any point you can hit Ctrl + C to break out of training early.
    try:
        for split_number in range(10):
            t0_split = time.time()
            log_dir = os.path.join(root_log_dir, "RUN_" + str(split_number))
            ckpt_dir = os.path.join(root_ckpt_dir, "RUN_" + str(split_number))
            writer = SummaryWriter(log_dir=log_dir)
            try:
                # setting seeds
                random.seed(params['seed'])
                np.random.seed(params['seed'])
                torch.manual_seed(params['seed'])
                if device.type == 'cuda':
                    torch.cuda.manual_seed(params['seed'])

                print("RUN NUMBER: ", split_number)
                trainset, valset, testset = dataset.train[split_number], dataset.val[split_number], dataset.test[split_number]
                print("Training Graphs: ", len(trainset))
                print("Validation Graphs: ", len(valset))
                print("Test Graphs: ", len(testset))
                print("Number of Classes: ", net_params['n_classes'])

                model = gnn_model(MODEL_NAME, net_params)
                model = model.to(device)
                if net_params['contrast'] and MODEL_NAME in ['ContrastPool']:
                    model.cal_contrast(trainset, device)
                optimizer = optim.Adam(model.parameters(), lr=params['init_lr'], weight_decay=params['weight_decay'])
                scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                                 factor=params['lr_reduce_factor'],
                                                                 patience=params['lr_schedule_patience'],
                                                                 verbose=True)

                train_loader = DataLoader(trainset, batch_size=params['batch_size'], shuffle=True, drop_last=drop_last, collate_fn=dataset.collate)
                val_loader = DataLoader(valset, batch_size=params['batch_size'], shuffle=False, drop_last=drop_last, collate_fn=dataset.collate)
                test_loader = DataLoader(testset, batch_size=params['batch_size'], shuffle=False, drop_last=drop_last, collate_fn=dataset.collate)

                with tqdm(range(params['epochs'])) as t:
                    for epoch in t:

                        t.set_description('Epoch %d' % epoch)

                        start = time.time()

                        epoch_train_loss, epoch_train_acc, optimizer = train_epoch(model, optimizer, device, train_loader, epoch)

                        epoch_val_loss, epoch_val_acc = evaluate_network(model, device, val_loader, epoch)
                        # Test accuracy is evaluated once per epoch and only
                        # logged; it is not used for model selection.
                        _, epoch_test_acc = evaluate_network(model, device, test_loader, epoch)

                        writer.add_scalar('train/_loss', epoch_train_loss, epoch)
                        writer.add_scalar('val/_loss', epoch_val_loss, epoch)
                        writer.add_scalar('train/_acc', epoch_train_acc, epoch)
                        writer.add_scalar('val/_acc', epoch_val_acc, epoch)
                        writer.add_scalar('test/_acc', epoch_test_acc, epoch)
                        writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], epoch)

                        t.set_postfix(time=time.time()-start, lr=optimizer.param_groups[0]['lr'],
                                      train_loss=epoch_train_loss, val_loss=epoch_val_loss,
                                      train_acc=epoch_train_acc, val_acc=epoch_val_acc,
                                      test_acc=epoch_test_acc)

                        per_epoch_time.append(time.time()-start)

                        # Saving checkpoint (only the two most recent epochs are kept)
                        os.makedirs(ckpt_dir, exist_ok=True)
                        torch.save(model.state_dict(), '{}.pkl'.format(ckpt_dir + "/epoch_" + str(epoch)))
                        if MODEL_NAME in ['ContrastPool']:
                            torch.save(model.ad_adj, '{}.pkl'.format(log_dir + "/epoch_" + str(epoch)))
                            _remove_stale_pkl(log_dir, epoch)
                        _remove_stale_pkl(ckpt_dir, epoch)

                        scheduler.step(epoch_val_loss)

                        if optimizer.param_groups[0]['lr'] < params['min_lr']:
                            print("\n!! LR EQUAL TO MIN LR SET.")
                            break

                        # Stop training after params['max_time'] hours
                        if time.time()-t0_split > params['max_time']*3600/10:  # Dividing max_time by 10, since there are 10 runs in TUs
                            print('-' * 89)
                            print("Max_time for one train-val-test split experiment elapsed {:.3f} hours, so stopping".format(params['max_time']/10))
                            break

                _, test_acc = evaluate_network(model, device, test_loader, epoch)
                _, train_acc = evaluate_network(model, device, train_loader, epoch)
                avg_test_acc.append(test_acc)
                avg_train_acc.append(train_acc)
                avg_convergence_epochs.append(epoch)

                print("Test Accuracy [LAST EPOCH]: {:.4f}".format(test_acc))
                print("Train Accuracy [LAST EPOCH]: {:.4f}".format(train_acc))
                print("Convergence Time (Epochs): {:.4f}".format(epoch))
            finally:
                # Close every split's writer (previously only the last one was).
                writer.close()

    except KeyboardInterrupt:
        print('-' * 89)
        print('Exiting from training early because of KeyboardInterrupt')

    print("TOTAL TIME TAKEN: {:.4f}hrs".format((time.time()-t0)/3600))
    print("AVG TIME PER EPOCH: {:.4f}s".format(np.mean(per_epoch_time)))
    print("AVG CONVERGENCE Time (Epochs): {:.4f}".format(np.mean(np.array(avg_convergence_epochs))))
    # Final test accuracy value averaged over 10-fold
    print("""\n\n\nFINAL RESULTS\n\nTEST ACCURACY averaged: {:.4f} with s.d. {:.4f}""".format(np.mean(np.array(avg_test_acc))*100, np.std(avg_test_acc)*100))
    print("\nAll splits Test Accuracies:\n", avg_test_acc)
    print("""\n\n\nFINAL RESULTS\n\nTRAIN ACCURACY averaged: {:.4f} with s.d. {:.4f}""".format(np.mean(np.array(avg_train_acc))*100, np.std(avg_train_acc)*100))
    print("\nAll splits Train Accuracies:\n", avg_train_acc)

    # Write the results in out/results folder
    with open(write_file_name + '.txt', 'w') as f:
        f.write("""Dataset: {},\nModel: {}\n\nparams={}\n\nnet_params={}\n\n{}\n\nTotal Parameters: {}\n\n edge_num: {}\n\n
    FINAL RESULTS\nTEST ACCURACY averaged: {:.4f} with s.d. {:.4f}\nTRAIN ACCURACY averaged: {:.4f} with s.d. {:.4f}\n\n
    Average Convergence Time (Epochs): {:.4f} with s.d. {:.4f}\nTotal Time Taken: {:.4f} hrs\nAverage Time Per Epoch: {:.4f} s\n\n\nAll Splits Test Accuracies: {}"""
                .format(DATASET_NAME, MODEL_NAME, params, net_params, model, net_params['total_param'], len(trainset[0][0].edata['feat']),
                        np.mean(np.array(avg_test_acc))*100, np.std(avg_test_acc)*100,
                        np.mean(np.array(avg_train_acc))*100, np.std(avg_train_acc)*100,
                        np.mean(avg_convergence_epochs), np.std(avg_convergence_epochs),
                        (time.time()-t0)/3600, np.mean(per_epoch_time), avg_test_acc))


def main():
    """
    USER CONTROLS

    Command-line entry point: parse overrides, merge them into the JSON
    config, derive dataset-dependent network dimensions, create the output
    folders and run ``train_val_pipeline``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', help="Please give a config.json file with training/model/data/param details")
    parser.add_argument('--gpu_id', help="Please give a value for gpu id")
    parser.add_argument('--model', help="Please give a value for model name")
    parser.add_argument('--dataset', help="Please give a value for dataset name")
    # The remaining plain string options share one help-text pattern.
    for option in ('out_dir', 'seed', 'epochs', 'batch_size', 'init_lr', 'lr_reduce_factor',
                   'lr_schedule_patience', 'min_lr', 'weight_decay', 'print_epoch_interval',
                   'L', 'hidden_dim', 'out_dim', 'residual', 'edge_feat', 'readout', 'kernel',
                   'n_heads', 'gated', 'in_feat_dropout', 'dropout', 'layer_norm', 'batch_norm',
                   'sage_aggregator', 'data_mode', 'num_pool', 'gnn_per_block', 'embedding_dim',
                   'pool_ratio', 'linkpred', 'cat', 'self_loop', 'max_time'):
        parser.add_argument('--' + option, help="Please give a value for " + option)
    parser.add_argument('--threshold', type=float, help="Please give a threshold to drop edge", default=0.3)
    parser.add_argument('--edge_ratio', type=float, help="Please give a ratio to drop edge", default=0)
    parser.add_argument('--node_feat_transform', help="Please give a value for node feature transform", default='original')
    parser.add_argument('--contrast', default=False, action='store_true')
    parser.add_argument('--pooling', type=float, default=0.5)
    parser.add_argument('--lambda1', type=float, default=0.001)
    parser.add_argument('--learnable_q', default=False, action='store_true')
    args = parser.parse_args()
    with open(args.config, encoding='utf-8') as f:
        config = json.load(f)

    # device: a GPU is used only when --gpu_id is given AND the config enables it
    if args.gpu_id is not None and config['gpu']['use']:
        config['gpu']['id'] = int(args.gpu_id)
        config['gpu']['use'] = True
        device = gpu_setup(config['gpu']['use'], config['gpu']['id'])
    else:
        config['gpu']['id'] = 0
        device = torch.device('cpu')

    # model, dataset, out_dir
    MODEL_NAME = args.model if args.model is not None else config['model']
    DATASET_NAME = args.dataset if args.dataset is not None else config['dataset']
    dataset = LoadData(DATASET_NAME, args.threshold, args.edge_ratio, args.node_feat_transform)
    out_dir = args.out_dir if args.out_dir is not None else config['out_dir']

    def as_bool(value):
        # CLI booleans arrive as strings; only the exact text 'True' enables them.
        return value == 'True'

    # parameters (CLI values override the config file when given)
    params = config['params']
    for key, cast in (('seed', int), ('epochs', int), ('batch_size', int), ('init_lr', float),
                      ('lr_reduce_factor', float), ('lr_schedule_patience', int), ('min_lr', float),
                      ('weight_decay', float), ('print_epoch_interval', int), ('max_time', float),
                      ('threshold', float), ('edge_ratio', float), ('node_feat_transform', str)):
        value = getattr(args, key)
        if value is not None:
            params[key] = cast(value)

    # network parameters
    net_params = config['net_params']
    if hasattr(dataset, 'node_num'):
        net_params['node_num'] = int(dataset.node_num)
    net_params['device'] = device
    net_params['gpu_id'] = config['gpu']['id']
    net_params['batch_size'] = params['batch_size']
    for key, cast in (('L', int), ('hidden_dim', int), ('out_dim', int), ('residual', as_bool),
                      ('edge_feat', as_bool), ('readout', str), ('kernel', int), ('n_heads', int),
                      ('gated', as_bool), ('in_feat_dropout', float), ('dropout', float),
                      ('layer_norm', as_bool), ('batch_norm', as_bool), ('sage_aggregator', str),
                      ('data_mode', str), ('num_pool', int), ('gnn_per_block', int),
                      ('embedding_dim', int), ('pool_ratio', float), ('linkpred', as_bool),
                      ('cat', as_bool), ('self_loop', as_bool), ('contrast', bool),
                      ('pooling', float), ('lambda1', float), ('learnable_q', bool)):
        value = getattr(args, key)
        if value is not None:
            net_params[key] = cast(value)

    # TUs
    first_graph = dataset.all.graph_lists[0]
    net_params['in_dim'] = first_graph.ndata['feat'].shape[1]
    net_params['edge_dim'] = first_graph.edata['feat'][0].shape[0] \
        if 'feat' in first_graph.edata else None
    net_params['n_classes'] = len(np.unique(dataset.all.graph_labels))

    if MODEL_NAME in ['DiffPool', 'ContrastPool']:
        net_params['max_num_node'] = dataset.node_num
        # calculate assignment dimension: pool_ratio * largest graph's maximum
        # number of nodes in the dataset
        max_num_node = max(dataset.all[i][0].number_of_nodes() for i in range(len(dataset.all)))
        net_params['assign_dim'] = int(max_num_node * net_params['pool_ratio']) * net_params['batch_size']

    # One timestamp for all four paths so they always name the same run.
    run_tag = MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
    root_log_dir = out_dir + 'logs/' + run_tag
    root_ckpt_dir = out_dir + 'checkpoints/' + run_tag
    write_file_name = out_dir + 'results/result_' + run_tag
    write_config_file = out_dir + 'configs/config_' + run_tag
    dirs = root_log_dir, root_ckpt_dir, write_file_name, write_config_file

    os.makedirs(out_dir + 'results', exist_ok=True)
    os.makedirs(out_dir + 'configs', exist_ok=True)

    net_params['total_param'] = view_model_param(MODEL_NAME, net_params)
    # Reuse the dataset loaded above (same threshold/edge_ratio/transform).
    train_val_pipeline(MODEL_NAME, DATASET_NAME, params, net_params, dirs, dataset=dataset)


if __name__ == "__main__":
    main()


# --------------------------------------------------------------------------------
# /metrics.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.metrics import confusion_matrix
from sklearn.metrics import f1_score
import numpy as np


def accuracy_TU(scores, targets):
    """
    Return the NUMBER of correctly classified samples in a batch.

    The predicted class is the argmax of ``scores`` along dim 1. The result is
    a count (a Python float), not a ratio; callers divide by the sample count.
    """
    scores = scores.detach().argmax(dim=1)
    acc = (scores == targets).float().sum().item()
    return acc


# --------------------------------------------------------------------------------
# /nets/contrastpool_net.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
from layers.attention_layer import EncoderLayer
import time
import numpy as np
from scipy.linalg import block_diag
import dgl

from layers.graphsage_layer import GraphSageLayer, DenseGraphSage
from layers.contrastpool_layer import ContrastPoolLayer, DenseDiffPool


class ContrastPoolNet(nn.Module):
    """
    DiffPool Fuse with GNN layers and pooling layers in sequence

    Pipeline: linear node embedding -> GraphSage layers on the sparse DGL
    graph -> ContrastPoolLayer (guided by the contrast adjacency / node
    features computed in ``cal_contrast_adj``) -> dense GraphSage layers ->
    sum readout -> linear classifier.

    NOTE: ``cal_contrast`` must be called (with the training set) before the
    first ``forward``; otherwise ``adj_dict`` is still ``None``.
    """

    def __init__(self, net_params, pool_ratio=0.5):

        super().__init__()
        input_dim = net_params['in_dim']
        self.hidden_dim = net_params['hidden_dim']
        embedding_dim = net_params['hidden_dim']
        self.n_classes = net_params['n_classes']
        activation = F.relu
        n_layers = net_params['L']
        dropout = net_params['dropout']
        self.batch_norm = net_params['batch_norm']
        self.residual = net_params['residual']
        aggregator_type = net_params['sage_aggregator']
        self.lambda1 = net_params['lambda1']  # weight of the attention-entropy loss term
        self.learnable_q = net_params['learnable_q']

        self.device = net_params['device']
        self.link_pred = True
        self.concat = False
        self.n_pooling = 1
        # forward() assumes every batch has exactly this many graphs
        self.batch_size = net_params['batch_size']
        if 'pool_ratio' in net_params.keys():
            pool_ratio = net_params['pool_ratio']
        self.e_feat = net_params['edge_feat']
        self.link_pred_loss = []
        self.entropy_loss = []

        self.embedding_h = nn.Linear(input_dim, self.hidden_dim)

        # list of GNN modules before the first diffpool operation
        self.gc_before_pool = nn.ModuleList()

        self.assign_dim = int(net_params['max_num_node'] * pool_ratio)
        self.bn = True
        self.num_aggs = 1

        # constructing layers
        # layers before diffpool
        assert n_layers >= 2, "n_layers too few"
        self.gc_before_pool.append(GraphSageLayer(self.hidden_dim, self.hidden_dim, activation,
                                                  dropout, aggregator_type, self.residual, self.bn, e_feat=self.e_feat))

        for _ in range(n_layers - 2):
            self.gc_before_pool.append(GraphSageLayer(self.hidden_dim, self.hidden_dim, activation,
                                                      dropout, aggregator_type, self.residual, self.bn, e_feat=self.e_feat))

        self.gc_before_pool.append(GraphSageLayer(self.hidden_dim, embedding_dim, None, dropout, aggregator_type, self.residual, e_feat=self.e_feat))

        if self.concat:
            # diffpool layer receive pool_emedding_dim node feature tensor
            # and return pool_embedding_dim node embedding
            pool_embedding_dim = self.hidden_dim * (n_layers - 1) + embedding_dim
        else:
            pool_embedding_dim = embedding_dim

        self.first_diffpool_layer = ContrastPoolLayer(pool_embedding_dim, self.assign_dim, self.hidden_dim, activation,
                                                      dropout, aggregator_type, self.link_pred, self.batch_norm,
                                                      max_node_num=net_params['max_num_node'])
        gc_after_per_pool = nn.ModuleList()

        # list of list of GNN modules, each list after one diffpool operation
        self.gc_after_pool = nn.ModuleList()

        for _ in range(n_layers - 1):
            gc_after_per_pool.append(DenseGraphSage(self.hidden_dim, self.hidden_dim, self.residual))
        gc_after_per_pool.append(DenseGraphSage(self.hidden_dim, embedding_dim, self.residual))
        self.gc_after_pool.append(gc_after_per_pool)

        self.assign_dim = int(self.assign_dim * pool_ratio)

        self.diffpool_layers = nn.ModuleList()
        # each additional pooling module (none while n_pooling == 1)
        for _ in range(self.n_pooling - 1):
            self.diffpool_layers.append(DenseDiffPool(pool_embedding_dim, self.assign_dim, self.hidden_dim, self.link_pred))

            gc_after_per_pool = nn.ModuleList()

            for _ in range(n_layers - 1):
                gc_after_per_pool.append(DenseGraphSage(self.hidden_dim, self.hidden_dim, self.residual))
            gc_after_per_pool.append(DenseGraphSage(self.hidden_dim, embedding_dim, self.residual))
            self.gc_after_pool.append(gc_after_per_pool)

            self.assign_dim = int(self.assign_dim * pool_ratio)

        # predicting layer
        if self.concat:
            self.pred_input_dim = pool_embedding_dim * \
                self.num_aggs * (self.n_pooling + 1)
        else:
            self.pred_input_dim = embedding_dim * self.num_aggs
        self.pred_layer = nn.Linear(self.pred_input_dim, self.n_classes)

        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Linear):
                m.weight.data = init.xavier_uniform_(m.weight.data, gain=nn.init.calculate_gain('relu'))
                if m.bias is not None:
                    m.bias.data = init.constant_(m.bias.data, 0.0)

        # State filled in by cal_contrast() / cal_contrast_adj().
        self.contrast_adj = None
        self.adj_dict = None
        self.nodes_dict = None
        self.nodes1 = None
        self.nodes2 = None
        self.encoder1 = None
        self.encoder2 = None
        self.encoder1_node = None
        self.encoder2_node = None
        self.num_A = None
        self.num_B = None
        self.node_num = None
        self.diff_h = None
        self.attn_loss = None
        self.ad_adj = None
        self.softmax = nn.Softmax(dim=-1)

    def cal_attn_loss(self, attn):
        """Mean entropy of a Categorical whose logits are ``attn`` (last dim)."""
        entropy = (torch.distributions.Categorical(logits=attn).entropy()).mean()
        assert not torch.isnan(entropy)
        return entropy

    def cal_contrast(self, trainset, device, merge_classes=True):
        """
        Summarise the training graphs and build the attention encoders.

        Calls ``contrast_subgraph.get_summary_tensor`` on all training graphs
        and labels, stores the returned ``adj_dict`` / ``nodes_dict``, and
        creates four ``EncoderLayer``s sized by the node count and node-feature
        width of the first training graph (assumes all graphs share them).
        """
        from contrast_subgraph import get_summary_tensor
        G_dataset = trainset[:][0]
        Labels = torch.tensor(trainset[:][1])

        self.adj_dict, self.nodes_dict = get_summary_tensor(G_dataset, Labels, device, merge_classes=merge_classes)
        self.node_num = G_dataset[0].ndata['feat'].size(0)
        feat_dim = G_dataset[0].ndata['feat'].size(1)

        learnable_q = self.learnable_q
        n_head = 1
        self.encoder1 = EncoderLayer(self.node_num, n_head, self.node_num, 0.0, device, self.node_num, learnable_q, pos_enc='index').to(device)
        self.encoder2 = EncoderLayer(self.node_num, n_head, self.node_num, 0.0, device, self.node_num, learnable_q).to(device)
        self.encoder1_node = EncoderLayer(self.node_num, n_head, self.node_num, 0.0, device, feat_dim, learnable_q, pos_enc='index').to(device)
        self.encoder2_node = EncoderLayer(self.node_num, n_head, self.node_num, 0.0, device, feat_dim, learnable_q).to(device)

    def cal_contrast_adj(self, device):
        """
        Recompute the contrast adjacency and node features from the summaries.

        Each entry of ``adj_dict`` / ``nodes_dict`` (presumably one per class
        group; verify against ``get_summary_tensor``) is passed through two
        encoders and averaged over dim 1. The standard deviation across entries
        becomes ``contrast_adj`` / ``diff_h``; the stacked adjacencies are kept
        in ``ad_adj`` and the attention-entropy loss in ``attn_loss``.
        """
        adj_list = []
        nodes_list = []
        for i in self.adj_dict.keys():
            adj = self.encoder1(self.adj_dict[i])
            adj = self.encoder2(adj.permute(1, 0, 2))
            adj_list.append(adj.mean(1))

            nodes_feat = self.encoder1_node(self.nodes_dict[i])
            nodes_feat = self.encoder2_node(nodes_feat.permute(1, 0, 2))
            nodes_list.append(nodes_feat.mean(1))
        stacked_adj = torch.stack(adj_list)
        self.ad_adj = stacked_adj
        self.contrast_adj = torch.std(stacked_adj.to(device), 0)
        self.diff_h = torch.std(torch.stack(nodes_list).to(device), 0)
        self.attn_loss = self.cal_attn_loss(self.contrast_adj)

        self.contrast_adj_trans = self.contrast_adj

    def gcn_forward(self, g, h, e, gc_layers, cat=False):
        """
        Return gc_layer embedding cat.

        Applies ``gc_layers`` in order on graph ``g``; returns the last layer's
        node features, or all layers' outputs concatenated on dim 1 if ``cat``.
        """
        block_readout = []
        for gc_layer in gc_layers[:-1]:
            h, e = gc_layer(g, h, e)
            block_readout.append(h)
        h, e = gc_layers[-1](g, h, e)
        block_readout.append(h)
        if cat:
            block = torch.cat(block_readout, dim=1)  # N x F, F = F1 + F2 + ...
        else:
            block = h
        return block

    def gcn_forward_tensorized(self, h, adj, gc_layers, cat=False):
        """Dense counterpart of ``gcn_forward``; concatenates on dim 2 if ``cat``."""
        block_readout = []
        for gc_layer in gc_layers:
            h = gc_layer(h, adj)
            block_readout.append(h)
        if cat:
            block = torch.cat(block_readout, dim=2)  # N x F, F = F1 + F2 + ...
        else:
            block = h
        return block

    def forward(self, g, h, e):
        """
        Classify a batch of graphs.

        Args:
            g: batched DGL graph (read out with ``dgl.sum_nodes``).
            h: node features, embedded by ``embedding_h``.
            e: edge features passed to the GraphSage layers.

        Returns:
            Class scores from ``pred_layer``, one row per graph.
        """
        self.link_pred_loss = []
        self.entropy_loss = []

        # node feature for assignment matrix computation is the same as the
        # original node feature
        h = self.embedding_h(h)

        out_all = []

        # we use GCN blocks to get an embedding first
        g_embedding = self.gcn_forward(g, h, e, self.gc_before_pool, self.concat)

        g.ndata['h'] = g_embedding

        readout = dgl.sum_nodes(g, 'h')
        out_all.append(readout)
        if self.num_aggs == 2:
            readout = dgl.max_nodes(g, 'h')
            out_all.append(readout)

        self.cal_contrast_adj(device=h.device)
        adj, h = self.first_diffpool_layer(g, g_embedding, self.diff_h, self.contrast_adj_trans)
        # Requires a full batch of self.batch_size graphs (drop_last=True upstream).
        node_per_pool_graph = int(adj.size()[0] / self.batch_size)

        h, adj = self.batch2tensor(adj, h, node_per_pool_graph)
        h = self.gcn_forward_tensorized(h, adj, self.gc_after_pool[0], self.concat)

        readout = torch.sum(h, dim=1)
        out_all.append(readout)
        if self.num_aggs == 2:
            readout, _ = torch.max(h, dim=1)
            out_all.append(readout)

        for i, diffpool_layer in enumerate(self.diffpool_layers):
            h, adj = diffpool_layer(h, adj)
            h = self.gcn_forward_tensorized(h, adj, self.gc_after_pool[i + 1], self.concat)

            readout = torch.sum(h, dim=1)
            out_all.append(readout)

            if self.num_aggs == 2:
                readout, _ = torch.max(h, dim=1)
                out_all.append(readout)

        if self.concat or self.num_aggs > 1:
            hg = torch.cat(out_all, dim=1)
        else:
            hg = readout

        ypred = self.pred_layer(hg)
        return ypred

    def batch2tensor(self, batch_adj, batch_feat, node_per_pool_graph):
        """
        transform a batched graph to batched adjacency tensor and node feature tensor

        Consecutive ``node_per_pool_graph``-sized diagonal blocks of the square
        ``batch_adj`` are taken as one graph each (off-diagonal blocks are
        ignored); the matching rows of ``batch_feat`` are scaled by
        1/sqrt(node_per_pool_graph).

        Returns:
            (feat, adj) of shapes (B, n, F) and (B, n, n), n = node_per_pool_graph.
        """
        batch_size = int(batch_adj.size()[0] / node_per_pool_graph)
        # 1/sqrt(V) normalization; identical for every graph, so build it once.
        snorm_n = torch.FloatTensor(node_per_pool_graph, 1).fill_(1. / float(node_per_pool_graph)).sqrt().to(self.device)

        adj_list = []
        feat_list = []
        for i in range(batch_size):
            start = i * node_per_pool_graph
            end = (i + 1) * node_per_pool_graph
            adj_list.append(batch_adj[start:end, start:end].unsqueeze(0))
            feat_list.append((batch_feat[start:end, :] * snorm_n).unsqueeze(0))
        adj = torch.cat(adj_list, dim=0)
        feat = torch.cat(feat_list, dim=0)

        return feat, adj

    def loss(self, pred, label):
        """
        Cross-entropy + auxiliary DiffPool losses + lambda1 * attention entropy.

        ``attn_loss`` is set by the most recent ``forward`` call.
        """
        # softmax + CE
        criterion = nn.CrossEntropyLoss()
        loss = criterion(pred, label)
        e1_loss = 0.0
        for diffpool_layer in self.diffpool_layers:
            for value in diffpool_layer.loss_log.values():
                e1_loss += value
        loss += e1_loss + self.lambda1 * self.attn_loss
        return loss


# --------------------------------------------------------------------------------
# /nets/load_net.py:
# --------------------------------------------------------------------------------
"""
Utility file to select GraphNN model as
selected by the user
"""

from nets.contrastpool_net import ContrastPoolNet


def ContrastPool(net_params):
    """Build a ``ContrastPoolNet`` from ``net_params``."""
    return ContrastPoolNet(net_params)


def gnn_model(MODEL_NAME, net_params):
    """
    Instantiate the model registered under ``MODEL_NAME`` and tag it with ``.name``.

    Raises:
        KeyError: if ``MODEL_NAME`` is not registered (only "ContrastPool" is).
    """
    models = {
        "ContrastPool": ContrastPool
    }
    model = models[MODEL_NAME](net_params)
    model.name = MODEL_NAME

    return model


# --------------------------------------------------------------------------------
# /train_TUs_graph_classification.py:
# --------------------------------------------------------------------------------
"""
Utility functions for training one epoch
and evaluating one epoch
"""
import torch
import torch.nn as nn
import math

from metrics import accuracy_TU as accuracy

# For GCNs


def _move_batch(batch_graphs, batch_labels, device):
    """Move one collated batch to ``device``; return (graph, node feats, edge feats, labels)."""
    batch_graphs = batch_graphs.to(device)
    batch_x = batch_graphs.ndata['feat'].to(device)  # num x feat
    batch_e = batch_graphs.edata['feat'].to(device)
    batch_labels = batch_labels.to(device)
    return batch_graphs, batch_x, batch_e, batch_labels


def _forward_and_loss(model, batch_graphs, batch_x, batch_e, batch_labels):
    """Run ``model`` on one batch and return (class scores, loss)."""
    if model.name in ["PRGNN", "LINet"]:
        # These models return two auxiliary scores that their loss also needs.
        batch_scores, score1, score2 = model.forward(batch_graphs, batch_x, batch_e)
        loss = model.loss(batch_scores, batch_labels, score1, score2)
    else:
        batch_scores = model.forward(batch_graphs, batch_x, batch_e)
        loss = model.loss(batch_scores, batch_labels)
    return batch_scores, loss


def _empty_loader_error():
    """Error raised when a loader yields no batches at all."""
    return ValueError("data_loader yielded no batches (is the split smaller than "
                      "batch_size while drop_last=True?)")


def train_epoch_sparse(model, optimizer, device, data_loader, epoch):
    """
    Train ``model`` for one pass over ``data_loader``.

    Returns:
        (mean loss per batch, accuracy over all samples, optimizer)

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    model.train()
    epoch_loss = 0
    epoch_train_acc = 0
    nb_data = 0
    nb_batches = 0
    for batch_graphs, batch_labels in data_loader:
        batch_graphs, batch_x, batch_e, batch_labels = _move_batch(batch_graphs, batch_labels, device)
        optimizer.zero_grad()
        batch_scores, loss = _forward_and_loss(model, batch_graphs, batch_x, batch_e, batch_labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()
        epoch_train_acc += accuracy(batch_scores, batch_labels)
        nb_data += batch_labels.size(0)
        nb_batches += 1
    if nb_batches == 0:
        raise _empty_loader_error()
    epoch_loss /= nb_batches
    epoch_train_acc /= nb_data

    return epoch_loss, epoch_train_acc, optimizer


def evaluate_network_sparse(model, device, data_loader, epoch):
    """
    Evaluate ``model`` on ``data_loader`` without gradient tracking.

    Returns:
        (mean loss per batch, accuracy over all samples)

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    model.eval()
    epoch_test_loss = 0
    epoch_test_acc = 0
    nb_data = 0
    nb_batches = 0
    with torch.no_grad():
        for batch_graphs, batch_labels in data_loader:
            batch_graphs, batch_x, batch_e, batch_labels = _move_batch(batch_graphs, batch_labels, device)
            batch_scores, loss = _forward_and_loss(model, batch_graphs, batch_x, batch_e, batch_labels)
            epoch_test_loss += loss.detach().item()
            epoch_test_acc += accuracy(batch_scores, batch_labels)
            nb_data += batch_labels.size(0)
            nb_batches += 1
        if nb_batches == 0:
            raise _empty_loader_error()
        epoch_test_loss /= nb_batches
        epoch_test_acc /= nb_data

    return epoch_test_loss, epoch_test_acc


def check_patience(all_losses, best_loss, best_epoch, curr_loss, curr_epoch, counter):
    """
    Early-stopping bookkeeping.

    If ``curr_loss`` beats ``best_loss`` it becomes the new best and the
    counter resets to 0; otherwise the counter is incremented.
    ``all_losses`` is accepted for interface compatibility but unused.

    Returns:
        (best_loss, best_epoch, counter)
    """
    if curr_loss < best_loss:
        return curr_loss, curr_epoch, 0
    return best_loss, best_epoch, counter + 1
# --------------------------------------------------------------------------------