├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── baseline.sh
├── configs
│   ├── abide_schaefer100
│   │   └── TUs_graph_classification_ContrastPool_abide_schaefer100_100k.json
│   ├── adni_schaefer100
│   │   └── TUs_graph_classification_ContrastPool_adni_schaefer100_100k.json
│   ├── neurocon_schaefer100
│   │   └── TUs_graph_classification_ContrastPool_neurocon_schaefer100_100k.json
│   ├── ppmi_schaefer100
│   │   └── TUs_graph_classification_ContrastPool_ppmi_schaefer100_100k.json
│   └── taowu_schaefer100
│       └── TUs_graph_classification_ContrastPool_taowu_schaefer100_100k.json
├── contrast_subgraph.py
├── data
│   ├── BrainNet.py
│   ├── abide_schaefer100
│   │   ├── test.index
│   │   ├── train.index
│   │   └── val.index
│   ├── adni_schaefer100
│   │   ├── test.index
│   │   ├── train.index
│   │   └── val.index
│   ├── data.py
│   ├── generate_data_from_mat.py
│   ├── neurocon_schaefer100
│   │   ├── test.index
│   │   ├── train.index
│   │   └── val.index
│   ├── ppmi_schaefer100
│   │   ├── test.index
│   │   ├── train.index
│   │   └── val.index
│   └── taowu_schaefer100
│       ├── test.index
│       ├── train.index
│       └── val.index
├── figs
│   └── framework.png
├── layers
│   ├── attention_layer.py
│   ├── contrastpool_layer.py
│   ├── diffpool_layer.py
│   └── graphsage_layer.py
├── main.py
├── metrics.py
├── nets
│   ├── contrastpool_net.py
│   └── load_net.py
└── train_TUs_graph_classification.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .nox/
42 | .coverage
43 | .coverage.*
44 | .cache
45 | nosetests.xml
46 | coverage.xml
47 | *.cover
48 | .hypothesis/
49 | .pytest_cache/
50 |
51 | # Translations
52 | *.mo
53 | *.pot
54 |
55 | # Django stuff:
56 | *.log
57 | local_settings.py
58 | db.sqlite3
59 |
60 | # Flask stuff:
61 | instance/
62 | .webassets-cache
63 |
64 | # Scrapy stuff:
65 | .scrapy
66 |
67 | # Sphinx documentation
68 | docs/_build/
69 |
70 | # PyBuilder
71 | target/
72 |
73 | # Jupyter Notebook
74 | .ipynb_checkpoints
75 |
76 | # IPython
77 | profile_default/
78 | ipython_config.py
79 |
80 | # pyenv
81 | .python-version
82 |
83 | # celery beat schedule file
84 | celerybeat-schedule
85 |
86 | # SageMath parsed files
87 | *.sage.py
88 |
89 | # Environments
90 | .env
91 | .venv
92 | env/
93 | venv/
94 | ENV/
95 | env.bak/
96 | venv.bak/
97 |
98 | # Spyder project settings
99 | .spyderproject
100 | .spyproject
101 |
102 | # Rope project settings
103 | .ropeproject
104 |
105 | # mkdocs documentation
106 | /site
107 |
108 | # mypy
109 | .mypy_cache/
110 | .dmypy.json
111 | dmypy.json
112 | # Project-specific
113 | .DS_Store
114 | .idea
115 | *.bak
116 | *.pkl
117 | save
118 | log
119 | log.test
120 | log.txt
121 | outputs
122 | out
123 | tmp
124 | tmp1.sh
125 | tmp2.sh
126 | tmp3.sh
127 | tmp4.sh
128 | result/
129 |
130 | # Pyre type checker
131 | .pyre/
132 |
133 |
134 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 |  Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 | Preamble
9 |
10 | The GNU General Public License is a free, copyleft license for
11 | software and other kinds of works.
12 |
13 | The licenses for most software and other practical works are designed
14 | to take away your freedom to share and change the works. By contrast,
15 | the GNU General Public License is intended to guarantee your freedom to
16 | share and change all versions of a program--to make sure it remains free
17 | software for all its users. We, the Free Software Foundation, use the
18 | GNU General Public License for most of our software; it applies also to
19 | any other work released this way by its authors. You can apply it to
20 | your programs, too.
21 |
22 | When we speak of free software, we are referring to freedom, not
23 | price. Our General Public Licenses are designed to make sure that you
24 | have the freedom to distribute copies of free software (and charge for
25 | them if you wish), that you receive source code or can get it if you
26 | want it, that you can change the software or use pieces of it in new
27 | free programs, and that you know you can do these things.
28 |
29 | To protect your rights, we need to prevent others from denying you
30 | these rights or asking you to surrender the rights. Therefore, you have
31 | certain responsibilities if you distribute copies of the software, or if
32 | you modify it: responsibilities to respect the freedom of others.
33 |
34 | For example, if you distribute copies of such a program, whether
35 | gratis or for a fee, you must pass on to the recipients the same
36 | freedoms that you received. You must make sure that they, too, receive
37 | or can get the source code. And you must show them these terms so they
38 | know their rights.
39 |
40 | Developers that use the GNU GPL protect your rights with two steps:
41 | (1) assert copyright on the software, and (2) offer you this License
42 | giving you legal permission to copy, distribute and/or modify it.
43 |
44 | For the developers' and authors' protection, the GPL clearly explains
45 | that there is no warranty for this free software. For both users' and
46 | authors' sake, the GPL requires that modified versions be marked as
47 | changed, so that their problems will not be attributed erroneously to
48 | authors of previous versions.
49 |
50 | Some devices are designed to deny users access to install or run
51 | modified versions of the software inside them, although the manufacturer
52 | can do so. This is fundamentally incompatible with the aim of
53 | protecting users' freedom to change the software. The systematic
54 | pattern of such abuse occurs in the area of products for individuals to
55 | use, which is precisely where it is most unacceptable. Therefore, we
56 | have designed this version of the GPL to prohibit the practice for those
57 | products. If such problems arise substantially in other domains, we
58 | stand ready to extend this provision to those domains in future versions
59 | of the GPL, as needed to protect the freedom of users.
60 |
61 | Finally, every program is threatened constantly by software patents.
62 | States should not allow patents to restrict development and use of
63 | software on general-purpose computers, but in those that do, we wish to
64 | avoid the special danger that patents applied to a free program could
65 | make it effectively proprietary. To prevent this, the GPL assures that
66 | patents cannot be used to render the program non-free.
67 |
68 | The precise terms and conditions for copying, distribution and
69 | modification follow.
70 |
71 | TERMS AND CONDITIONS
72 |
73 | 0. Definitions.
74 |
75 | "This License" refers to version 3 of the GNU General Public License.
76 |
77 | "Copyright" also means copyright-like laws that apply to other kinds of
78 | works, such as semiconductor masks.
79 |
80 | "The Program" refers to any copyrightable work licensed under this
81 | License. Each licensee is addressed as "you". "Licensees" and
82 | "recipients" may be individuals or organizations.
83 |
84 | To "modify" a work means to copy from or adapt all or part of the work
85 | in a fashion requiring copyright permission, other than the making of an
86 | exact copy. The resulting work is called a "modified version" of the
87 | earlier work or a work "based on" the earlier work.
88 |
89 | A "covered work" means either the unmodified Program or a work based
90 | on the Program.
91 |
92 | To "propagate" a work means to do anything with it that, without
93 | permission, would make you directly or secondarily liable for
94 | infringement under applicable copyright law, except executing it on a
95 | computer or modifying a private copy. Propagation includes copying,
96 | distribution (with or without modification), making available to the
97 | public, and in some countries other activities as well.
98 |
99 | To "convey" a work means any kind of propagation that enables other
100 | parties to make or receive copies. Mere interaction with a user through
101 | a computer network, with no transfer of a copy, is not conveying.
102 |
103 | An interactive user interface displays "Appropriate Legal Notices"
104 | to the extent that it includes a convenient and prominently visible
105 | feature that (1) displays an appropriate copyright notice, and (2)
106 | tells the user that there is no warranty for the work (except to the
107 | extent that warranties are provided), that licensees may convey the
108 | work under this License, and how to view a copy of this License. If
109 | the interface presents a list of user commands or options, such as a
110 | menu, a prominent item in the list meets this criterion.
111 |
112 | 1. Source Code.
113 |
114 | The "source code" for a work means the preferred form of the work
115 | for making modifications to it. "Object code" means any non-source
116 | form of a work.
117 |
118 | A "Standard Interface" means an interface that either is an official
119 | standard defined by a recognized standards body, or, in the case of
120 | interfaces specified for a particular programming language, one that
121 | is widely used among developers working in that language.
122 |
123 | The "System Libraries" of an executable work include anything, other
124 | than the work as a whole, that (a) is included in the normal form of
125 | packaging a Major Component, but which is not part of that Major
126 | Component, and (b) serves only to enable use of the work with that
127 | Major Component, or to implement a Standard Interface for which an
128 | implementation is available to the public in source code form. A
129 | "Major Component", in this context, means a major essential component
130 | (kernel, window system, and so on) of the specific operating system
131 | (if any) on which the executable work runs, or a compiler used to
132 | produce the work, or an object code interpreter used to run it.
133 |
134 | The "Corresponding Source" for a work in object code form means all
135 | the source code needed to generate, install, and (for an executable
136 | work) run the object code and to modify the work, including scripts to
137 | control those activities. However, it does not include the work's
138 | System Libraries, or general-purpose tools or generally available free
139 | programs which are used unmodified in performing those activities but
140 | which are not part of the work. For example, Corresponding Source
141 | includes interface definition files associated with source files for
142 | the work, and the source code for shared libraries and dynamically
143 | linked subprograms that the work is specifically designed to require,
144 | such as by intimate data communication or control flow between those
145 | subprograms and other parts of the work.
146 |
147 | The Corresponding Source need not include anything that users
148 | can regenerate automatically from other parts of the Corresponding
149 | Source.
150 |
151 | The Corresponding Source for a work in source code form is that
152 | same work.
153 |
154 | 2. Basic Permissions.
155 |
156 | All rights granted under this License are granted for the term of
157 | copyright on the Program, and are irrevocable provided the stated
158 | conditions are met. This License explicitly affirms your unlimited
159 | permission to run the unmodified Program. The output from running a
160 | covered work is covered by this License only if the output, given its
161 | content, constitutes a covered work. This License acknowledges your
162 | rights of fair use or other equivalent, as provided by copyright law.
163 |
164 | You may make, run and propagate covered works that you do not
165 | convey, without conditions so long as your license otherwise remains
166 | in force. You may convey covered works to others for the sole purpose
167 | of having them make modifications exclusively for you, or provide you
168 | with facilities for running those works, provided that you comply with
169 | the terms of this License in conveying all material for which you do
170 | not control copyright. Those thus making or running the covered works
171 | for you must do so exclusively on your behalf, under your direction
172 | and control, on terms that prohibit them from making any copies of
173 | your copyrighted material outside their relationship with you.
174 |
175 | Conveying under any other circumstances is permitted solely under
176 | the conditions stated below. Sublicensing is not allowed; section 10
177 | makes it unnecessary.
178 |
179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180 |
181 | No covered work shall be deemed part of an effective technological
182 | measure under any applicable law fulfilling obligations under article
183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184 | similar laws prohibiting or restricting circumvention of such
185 | measures.
186 |
187 | When you convey a covered work, you waive any legal power to forbid
188 | circumvention of technological measures to the extent such circumvention
189 | is effected by exercising rights under this License with respect to
190 | the covered work, and you disclaim any intention to limit operation or
191 | modification of the work as a means of enforcing, against the work's
192 | users, your or third parties' legal rights to forbid circumvention of
193 | technological measures.
194 |
195 | 4. Conveying Verbatim Copies.
196 |
197 | You may convey verbatim copies of the Program's source code as you
198 | receive it, in any medium, provided that you conspicuously and
199 | appropriately publish on each copy an appropriate copyright notice;
200 | keep intact all notices stating that this License and any
201 | non-permissive terms added in accord with section 7 apply to the code;
202 | keep intact all notices of the absence of any warranty; and give all
203 | recipients a copy of this License along with the Program.
204 |
205 | You may charge any price or no price for each copy that you convey,
206 | and you may offer support or warranty protection for a fee.
207 |
208 | 5. Conveying Modified Source Versions.
209 |
210 | You may convey a work based on the Program, or the modifications to
211 | produce it from the Program, in the form of source code under the
212 | terms of section 4, provided that you also meet all of these conditions:
213 |
214 | a) The work must carry prominent notices stating that you modified
215 | it, and giving a relevant date.
216 |
217 | b) The work must carry prominent notices stating that it is
218 | released under this License and any conditions added under section
219 | 7. This requirement modifies the requirement in section 4 to
220 | "keep intact all notices".
221 |
222 | c) You must license the entire work, as a whole, under this
223 | License to anyone who comes into possession of a copy. This
224 | License will therefore apply, along with any applicable section 7
225 | additional terms, to the whole of the work, and all its parts,
226 | regardless of how they are packaged. This License gives no
227 | permission to license the work in any other way, but it does not
228 | invalidate such permission if you have separately received it.
229 |
230 | d) If the work has interactive user interfaces, each must display
231 | Appropriate Legal Notices; however, if the Program has interactive
232 | interfaces that do not display Appropriate Legal Notices, your
233 | work need not make them do so.
234 |
235 | A compilation of a covered work with other separate and independent
236 | works, which are not by their nature extensions of the covered work,
237 | and which are not combined with it such as to form a larger program,
238 | in or on a volume of a storage or distribution medium, is called an
239 | "aggregate" if the compilation and its resulting copyright are not
240 | used to limit the access or legal rights of the compilation's users
241 | beyond what the individual works permit. Inclusion of a covered work
242 | in an aggregate does not cause this License to apply to the other
243 | parts of the aggregate.
244 |
245 | 6. Conveying Non-Source Forms.
246 |
247 | You may convey a covered work in object code form under the terms
248 | of sections 4 and 5, provided that you also convey the
249 | machine-readable Corresponding Source under the terms of this License,
250 | in one of these ways:
251 |
252 | a) Convey the object code in, or embodied in, a physical product
253 | (including a physical distribution medium), accompanied by the
254 | Corresponding Source fixed on a durable physical medium
255 | customarily used for software interchange.
256 |
257 | b) Convey the object code in, or embodied in, a physical product
258 | (including a physical distribution medium), accompanied by a
259 | written offer, valid for at least three years and valid for as
260 | long as you offer spare parts or customer support for that product
261 | model, to give anyone who possesses the object code either (1) a
262 | copy of the Corresponding Source for all the software in the
263 | product that is covered by this License, on a durable physical
264 | medium customarily used for software interchange, for a price no
265 | more than your reasonable cost of physically performing this
266 | conveying of source, or (2) access to copy the
267 | Corresponding Source from a network server at no charge.
268 |
269 | c) Convey individual copies of the object code with a copy of the
270 | written offer to provide the Corresponding Source. This
271 | alternative is allowed only occasionally and noncommercially, and
272 | only if you received the object code with such an offer, in accord
273 | with subsection 6b.
274 |
275 | d) Convey the object code by offering access from a designated
276 | place (gratis or for a charge), and offer equivalent access to the
277 | Corresponding Source in the same way through the same place at no
278 | further charge. You need not require recipients to copy the
279 | Corresponding Source along with the object code. If the place to
280 | copy the object code is a network server, the Corresponding Source
281 | may be on a different server (operated by you or a third party)
282 | that supports equivalent copying facilities, provided you maintain
283 | clear directions next to the object code saying where to find the
284 | Corresponding Source. Regardless of what server hosts the
285 | Corresponding Source, you remain obligated to ensure that it is
286 | available for as long as needed to satisfy these requirements.
287 |
288 | e) Convey the object code using peer-to-peer transmission, provided
289 | you inform other peers where the object code and Corresponding
290 | Source of the work are being offered to the general public at no
291 | charge under subsection 6d.
292 |
293 | A separable portion of the object code, whose source code is excluded
294 | from the Corresponding Source as a System Library, need not be
295 | included in conveying the object code work.
296 |
297 | A "User Product" is either (1) a "consumer product", which means any
298 | tangible personal property which is normally used for personal, family,
299 | or household purposes, or (2) anything designed or sold for incorporation
300 | into a dwelling. In determining whether a product is a consumer product,
301 | doubtful cases shall be resolved in favor of coverage. For a particular
302 | product received by a particular user, "normally used" refers to a
303 | typical or common use of that class of product, regardless of the status
304 | of the particular user or of the way in which the particular user
305 | actually uses, or expects or is expected to use, the product. A product
306 | is a consumer product regardless of whether the product has substantial
307 | commercial, industrial or non-consumer uses, unless such uses represent
308 | the only significant mode of use of the product.
309 |
310 | "Installation Information" for a User Product means any methods,
311 | procedures, authorization keys, or other information required to install
312 | and execute modified versions of a covered work in that User Product from
313 | a modified version of its Corresponding Source. The information must
314 | suffice to ensure that the continued functioning of the modified object
315 | code is in no case prevented or interfered with solely because
316 | modification has been made.
317 |
318 | If you convey an object code work under this section in, or with, or
319 | specifically for use in, a User Product, and the conveying occurs as
320 | part of a transaction in which the right of possession and use of the
321 | User Product is transferred to the recipient in perpetuity or for a
322 | fixed term (regardless of how the transaction is characterized), the
323 | Corresponding Source conveyed under this section must be accompanied
324 | by the Installation Information. But this requirement does not apply
325 | if neither you nor any third party retains the ability to install
326 | modified object code on the User Product (for example, the work has
327 | been installed in ROM).
328 |
329 | The requirement to provide Installation Information does not include a
330 | requirement to continue to provide support service, warranty, or updates
331 | for a work that has been modified or installed by the recipient, or for
332 | the User Product in which it has been modified or installed. Access to a
333 | network may be denied when the modification itself materially and
334 | adversely affects the operation of the network or violates the rules and
335 | protocols for communication across the network.
336 |
337 | Corresponding Source conveyed, and Installation Information provided,
338 | in accord with this section must be in a format that is publicly
339 | documented (and with an implementation available to the public in
340 | source code form), and must require no special password or key for
341 | unpacking, reading or copying.
342 |
343 | 7. Additional Terms.
344 |
345 | "Additional permissions" are terms that supplement the terms of this
346 | License by making exceptions from one or more of its conditions.
347 | Additional permissions that are applicable to the entire Program shall
348 | be treated as though they were included in this License, to the extent
349 | that they are valid under applicable law. If additional permissions
350 | apply only to part of the Program, that part may be used separately
351 | under those permissions, but the entire Program remains governed by
352 | this License without regard to the additional permissions.
353 |
354 | When you convey a copy of a covered work, you may at your option
355 | remove any additional permissions from that copy, or from any part of
356 | it. (Additional permissions may be written to require their own
357 | removal in certain cases when you modify the work.) You may place
358 | additional permissions on material, added by you to a covered work,
359 | for which you have or can give appropriate copyright permission.
360 |
361 | Notwithstanding any other provision of this License, for material you
362 | add to a covered work, you may (if authorized by the copyright holders of
363 | that material) supplement the terms of this License with terms:
364 |
365 | a) Disclaiming warranty or limiting liability differently from the
366 | terms of sections 15 and 16 of this License; or
367 |
368 | b) Requiring preservation of specified reasonable legal notices or
369 | author attributions in that material or in the Appropriate Legal
370 | Notices displayed by works containing it; or
371 |
372 | c) Prohibiting misrepresentation of the origin of that material, or
373 | requiring that modified versions of such material be marked in
374 | reasonable ways as different from the original version; or
375 |
376 | d) Limiting the use for publicity purposes of names of licensors or
377 | authors of the material; or
378 |
379 | e) Declining to grant rights under trademark law for use of some
380 | trade names, trademarks, or service marks; or
381 |
382 | f) Requiring indemnification of licensors and authors of that
383 | material by anyone who conveys the material (or modified versions of
384 | it) with contractual assumptions of liability to the recipient, for
385 | any liability that these contractual assumptions directly impose on
386 | those licensors and authors.
387 |
388 | All other non-permissive additional terms are considered "further
389 | restrictions" within the meaning of section 10. If the Program as you
390 | received it, or any part of it, contains a notice stating that it is
391 | governed by this License along with a term that is a further
392 | restriction, you may remove that term. If a license document contains
393 | a further restriction but permits relicensing or conveying under this
394 | License, you may add to a covered work material governed by the terms
395 | of that license document, provided that the further restriction does
396 | not survive such relicensing or conveying.
397 |
398 | If you add terms to a covered work in accord with this section, you
399 | must place, in the relevant source files, a statement of the
400 | additional terms that apply to those files, or a notice indicating
401 | where to find the applicable terms.
402 |
403 | Additional terms, permissive or non-permissive, may be stated in the
404 | form of a separately written license, or stated as exceptions;
405 | the above requirements apply either way.
406 |
407 | 8. Termination.
408 |
409 | You may not propagate or modify a covered work except as expressly
410 | provided under this License. Any attempt otherwise to propagate or
411 | modify it is void, and will automatically terminate your rights under
412 | this License (including any patent licenses granted under the third
413 | paragraph of section 11).
414 |
415 | However, if you cease all violation of this License, then your
416 | license from a particular copyright holder is reinstated (a)
417 | provisionally, unless and until the copyright holder explicitly and
418 | finally terminates your license, and (b) permanently, if the copyright
419 | holder fails to notify you of the violation by some reasonable means
420 | prior to 60 days after the cessation.
421 |
422 | Moreover, your license from a particular copyright holder is
423 | reinstated permanently if the copyright holder notifies you of the
424 | violation by some reasonable means, this is the first time you have
425 | received notice of violation of this License (for any work) from that
426 | copyright holder, and you cure the violation prior to 30 days after
427 | your receipt of the notice.
428 |
429 | Termination of your rights under this section does not terminate the
430 | licenses of parties who have received copies or rights from you under
431 | this License. If your rights have been terminated and not permanently
432 | reinstated, you do not qualify to receive new licenses for the same
433 | material under section 10.
434 |
435 | 9. Acceptance Not Required for Having Copies.
436 |
437 | You are not required to accept this License in order to receive or
438 | run a copy of the Program. Ancillary propagation of a covered work
439 | occurring solely as a consequence of using peer-to-peer transmission
440 | to receive a copy likewise does not require acceptance. However,
441 | nothing other than this License grants you permission to propagate or
442 | modify any covered work. These actions infringe copyright if you do
443 | not accept this License. Therefore, by modifying or propagating a
444 | covered work, you indicate your acceptance of this License to do so.
445 |
446 | 10. Automatic Licensing of Downstream Recipients.
447 |
448 | Each time you convey a covered work, the recipient automatically
449 | receives a license from the original licensors, to run, modify and
450 | propagate that work, subject to this License. You are not responsible
451 | for enforcing compliance by third parties with this License.
452 |
453 | An "entity transaction" is a transaction transferring control of an
454 | organization, or substantially all assets of one, or subdividing an
455 | organization, or merging organizations. If propagation of a covered
456 | work results from an entity transaction, each party to that
457 | transaction who receives a copy of the work also receives whatever
458 | licenses to the work the party's predecessor in interest had or could
459 | give under the previous paragraph, plus a right to possession of the
460 | Corresponding Source of the work from the predecessor in interest, if
461 | the predecessor has it or can get it with reasonable efforts.
462 |
463 | You may not impose any further restrictions on the exercise of the
464 | rights granted or affirmed under this License. For example, you may
465 | not impose a license fee, royalty, or other charge for exercise of
466 | rights granted under this License, and you may not initiate litigation
467 | (including a cross-claim or counterclaim in a lawsuit) alleging that
468 | any patent claim is infringed by making, using, selling, offering for
469 | sale, or importing the Program or any portion of it.
470 |
471 | 11. Patents.
472 |
473 | A "contributor" is a copyright holder who authorizes use under this
474 | License of the Program or a work on which the Program is based. The
475 | work thus licensed is called the contributor's "contributor version".
476 |
477 | A contributor's "essential patent claims" are all patent claims
478 | owned or controlled by the contributor, whether already acquired or
479 | hereafter acquired, that would be infringed by some manner, permitted
480 | by this License, of making, using, or selling its contributor version,
481 | but do not include claims that would be infringed only as a
482 | consequence of further modification of the contributor version. For
483 | purposes of this definition, "control" includes the right to grant
484 | patent sublicenses in a manner consistent with the requirements of
485 | this License.
486 |
487 | Each contributor grants you a non-exclusive, worldwide, royalty-free
488 | patent license under the contributor's essential patent claims, to
489 | make, use, sell, offer for sale, import and otherwise run, modify and
490 | propagate the contents of its contributor version.
491 |
492 | In the following three paragraphs, a "patent license" is any express
493 | agreement or commitment, however denominated, not to enforce a patent
494 | (such as an express permission to practice a patent or covenant not to
495 | sue for patent infringement). To "grant" such a patent license to a
496 | party means to make such an agreement or commitment not to enforce a
497 | patent against the party.
498 |
499 | If you convey a covered work, knowingly relying on a patent license,
500 | and the Corresponding Source of the work is not available for anyone
501 | to copy, free of charge and under the terms of this License, through a
502 | publicly available network server or other readily accessible means,
503 | then you must either (1) cause the Corresponding Source to be so
504 | available, or (2) arrange to deprive yourself of the benefit of the
505 | patent license for this particular work, or (3) arrange, in a manner
506 | consistent with the requirements of this License, to extend the patent
507 | license to downstream recipients. "Knowingly relying" means you have
508 | actual knowledge that, but for the patent license, your conveying the
509 | covered work in a country, or your recipient's use of the covered work
510 | in a country, would infringe one or more identifiable patents in that
511 | country that you have reason to believe are valid.
512 |
513 | If, pursuant to or in connection with a single transaction or
514 | arrangement, you convey, or propagate by procuring conveyance of, a
515 | covered work, and grant a patent license to some of the parties
516 | receiving the covered work authorizing them to use, propagate, modify
517 | or convey a specific copy of the covered work, then the patent license
518 | you grant is automatically extended to all recipients of the covered
519 | work and works based on it.
520 |
521 | A patent license is "discriminatory" if it does not include within
522 | the scope of its coverage, prohibits the exercise of, or is
523 | conditioned on the non-exercise of one or more of the rights that are
524 | specifically granted under this License. You may not convey a covered
525 | work if you are a party to an arrangement with a third party that is
526 | in the business of distributing software, under which you make payment
527 | to the third party based on the extent of your activity of conveying
528 | the work, and under which the third party grants, to any of the
529 | parties who would receive the covered work from you, a discriminatory
530 | patent license (a) in connection with copies of the covered work
531 | conveyed by you (or copies made from those copies), or (b) primarily
532 | for and in connection with specific products or compilations that
533 | contain the covered work, unless you entered into that arrangement,
534 | or that patent license was granted, prior to 28 March 2007.
535 |
536 | Nothing in this License shall be construed as excluding or limiting
537 | any implied license or other defenses to infringement that may
538 | otherwise be available to you under applicable patent law.
539 |
540 | 12. No Surrender of Others' Freedom.
541 |
542 | If conditions are imposed on you (whether by court order, agreement or
543 | otherwise) that contradict the conditions of this License, they do not
544 | excuse you from the conditions of this License. If you cannot convey a
545 | covered work so as to satisfy simultaneously your obligations under this
546 | License and any other pertinent obligations, then as a consequence you may
547 | not convey it at all. For example, if you agree to terms that obligate you
548 | to collect a royalty for further conveying from those to whom you convey
549 | the Program, the only way you could satisfy both those terms and this
550 | License would be to refrain entirely from conveying the Program.
551 |
552 | 13. Use with the GNU Affero General Public License.
553 |
554 | Notwithstanding any other provision of this License, you have
555 | permission to link or combine any covered work with a work licensed
556 | under version 3 of the GNU Affero General Public License into a single
557 | combined work, and to convey the resulting work. The terms of this
558 | License will continue to apply to the part which is the covered work,
559 | but the special requirements of the GNU Affero General Public License,
560 | section 13, concerning interaction through a network will apply to the
561 | combination as such.
562 |
563 | 14. Revised Versions of this License.
564 |
565 | The Free Software Foundation may publish revised and/or new versions of
566 | the GNU General Public License from time to time. Such new versions will
567 | be similar in spirit to the present version, but may differ in detail to
568 | address new problems or concerns.
569 |
570 | Each version is given a distinguishing version number. If the
571 | Program specifies that a certain numbered version of the GNU General
572 | Public License "or any later version" applies to it, you have the
573 | option of following the terms and conditions either of that numbered
574 | version or of any later version published by the Free Software
575 | Foundation. If the Program does not specify a version number of the
576 | GNU General Public License, you may choose any version ever published
577 | by the Free Software Foundation.
578 |
579 | If the Program specifies that a proxy can decide which future
580 | versions of the GNU General Public License can be used, that proxy's
581 | public statement of acceptance of a version permanently authorizes you
582 | to choose that version for the Program.
583 |
584 | Later license versions may give you additional or different
585 | permissions. However, no additional obligations are imposed on any
586 | author or copyright holder as a result of your choosing to follow a
587 | later version.
588 |
589 | 15. Disclaimer of Warranty.
590 |
591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599 |
600 | 16. Limitation of Liability.
601 |
602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610 | SUCH DAMAGES.
611 |
612 | 17. Interpretation of Sections 15 and 16.
613 |
614 | If the disclaimer of warranty and limitation of liability provided
615 | above cannot be given local legal effect according to their terms,
616 | reviewing courts shall apply local law that most closely approximates
617 | an absolute waiver of all civil liability in connection with the
618 | Program, unless a warranty or assumption of liability accompanies a
619 | copy of the Program in return for a fee.
620 |
621 | END OF TERMS AND CONDITIONS
622 |
623 | How to Apply These Terms to Your New Programs
624 |
625 | If you develop a new program, and you want it to be of the greatest
626 | possible use to the public, the best way to achieve this is to make it
627 | free software which everyone can redistribute and change under these terms.
628 |
629 | To do so, attach the following notices to the program. It is safest
630 | to attach them to the start of each source file to most effectively
631 | state the exclusion of warranty; and each file should have at least
632 | the "copyright" line and a pointer to where the full notice is found.
633 |
634 |
635 |     Copyright (C) <year>  <name of author>
636 |
637 | This program is free software: you can redistribute it and/or modify
638 | it under the terms of the GNU General Public License as published by
639 | the Free Software Foundation, either version 3 of the License, or
640 | (at your option) any later version.
641 |
642 | This program is distributed in the hope that it will be useful,
643 | but WITHOUT ANY WARRANTY; without even the implied warranty of
644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645 | GNU General Public License for more details.
646 |
647 | You should have received a copy of the GNU General Public License
648 | along with this program.  If not, see <https://www.gnu.org/licenses/>.
649 |
650 | Also add information on how to contact you by electronic and paper mail.
651 |
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 |
655 |     <program>  Copyright (C) <year>  <name of author>
656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 | This is free software, and you are welcome to redistribute it
658 | under certain conditions; type `show c' for details.
659 |
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License. Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 |
664 | You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | <https://www.gnu.org/licenses/>.
668 |
669 | The GNU General Public License does not permit incorporating your program
670 | into proprietary programs. If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library. If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License. But first, please read
674 | <https://www.gnu.org/licenses/why-not-lgpl.html>.
675 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ContrastPool
2 | This is the official PyTorch implementation of ContrastPool from the paper
3 | *"Contrastive Graph Pooling for Explainable Classification of Brain Networks"* published in IEEE Transactions on Medical Imaging (TMI) 2024.
4 |
5 | Link: [arXiv](https://arxiv.org/abs/2307.11133).
6 |
7 | ![ContrastPool framework](figs/framework.png)
8 |
9 |
10 | ## Data
11 | All preprocessed data used in this paper are published in [this paper](https://proceedings.neurips.cc/paper_files/paper/2023/file/44e3a3115ca26e5127851acd0cedd0d9-Paper-Datasets_and_Benchmarks.pdf).
12 | Data splits and configurations are stored in `./data/` and `./configs/`. If you want to process your own data, please check the dataloader script `./data/BrainNet.py`.
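
For reference, loading one of the provided datasets through this dataloader looks roughly like the sketch below. It assumes the preprocessed `.bin` graph files have been downloaded and that the paths in `name2path` inside `./data/BrainNet.py` have been updated to point to them:

```python
from data.BrainNet import BrainDataset

# 10-fold train/val/test index files are read from (or created under) ./data/abide_schaefer100/
dataset = BrainDataset('abide_schaefer100', threshold=0.0, node_feat_transform='pearson')
train_set, val_set, test_set = dataset.train[0], dataset.val[0], dataset.test[0]  # fold 0
```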
13 |
14 | ## Usage
15 |
16 | Please check `baseline.sh` for how to run the project; for reference, the command it runs on the ABIDE dataset is shown below.
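
The command from `baseline.sh`, reproduced here for convenience (all flags are taken verbatim from that script):

```sh
python main.py --config configs/abide_schaefer100/TUs_graph_classification_ContrastPool_abide_schaefer100_100k.json \
    --gpu_id 0 --node_feat_transform pearson --max_time 60 --init_lr 1e-2 --threshold 0.0 \
    --batch_size 20 --dropout 0.0 --contrast --pool_ratio 0.5 --lambda1 1e-3 --L 2
```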
17 |
18 | ## Citation
19 |
20 | If you find this code useful, please consider citing our paper:
21 |
22 | ```
23 | @ARTICLE{10508252,
24 | author={Xu, Jiaxing and Bian, Qingtian and Li, Xinhang and Zhang, Aihu and Ke, Yiping and Qiao, Miao and Zhang, Wei and Sim, Wei Khang Jeremy and Gulyás, Balázs},
25 | journal={IEEE Transactions on Medical Imaging},
26 | title={Contrastive Graph Pooling for Explainable Classification of Brain Networks},
27 | year={2024},
28 | volume={},
29 | number={},
30 | pages={1-1},
31 | keywords={Functional magnetic resonance imaging;Feature extraction;Task analysis;Data mining;Alzheimer's disease;Message passing;Brain modeling;Brain Network;Deep Learning for Neuroimaging;fMRI Biomarker;Graph Classification;Graph Neural Network},
32 | doi={10.1109/TMI.2024.3392988}}
33 | ```
34 |
35 | ## Contact
36 |
37 | If you have any questions, please feel free to reach out at `jiaxing003@e.ntu.edu.sg`.
38 |
--------------------------------------------------------------------------------
/baseline.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | model="configs/abide_schaefer100/TUs_graph_classification_ContrastPool_abide_schaefer100_100k.json"
4 | echo ${model}
5 | python main.py --config $model --gpu_id 0 --node_feat_transform pearson --max_time 60 --init_lr 1e-2 --threshold 0.0 --batch_size 20 --dropout 0.0 --contrast --pool_ratio 0.5 --lambda1 1e-3 --L 2
6 |
--------------------------------------------------------------------------------
/configs/abide_schaefer100/TUs_graph_classification_ContrastPool_abide_schaefer100_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "ContrastPool",
8 | "dataset": "abide_schaefer100",
9 |
10 | "out_dir": "out/braindata_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 20,
16 | "init_lr": 1e-2,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-6,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 30
23 | },
24 |
25 | "net_params": {
26 | "L": 2,
27 | "hidden_dim": 86,
28 | "out_dim": 86,
29 | "residual": true,
30 | "readout": "mean",
31 | "in_feat_dropout": 0.0,
32 | "dropout": 0.0,
33 | "batch_norm": true,
34 | "sage_aggregator": "maxpool",
35 | "self_loop": false,
36 | "edge_feat": false
37 | }
38 | }
--------------------------------------------------------------------------------
/configs/adni_schaefer100/TUs_graph_classification_ContrastPool_adni_schaefer100_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "ContrastPool",
8 | "dataset": "adni_schaefer100",
9 |
10 | "out_dir": "out/braindata_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 20,
16 | "init_lr": 1e-2,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-6,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 30
23 | },
24 |
25 | "net_params": {
26 | "L": 2,
27 | "hidden_dim": 86,
28 | "out_dim": 86,
29 | "residual": true,
30 | "readout": "mean",
31 | "in_feat_dropout": 0.0,
32 | "dropout": 0.0,
33 | "batch_norm": true,
34 | "sage_aggregator": "maxpool",
35 | "self_loop": false,
36 | "edge_feat": false
37 | }
38 | }
--------------------------------------------------------------------------------
/configs/neurocon_schaefer100/TUs_graph_classification_ContrastPool_neurocon_schaefer100_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "ContrastPool",
8 | "dataset": "neurocon_schaefer100",
9 |
10 | "out_dir": "out/braindata_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 4,
16 | "init_lr": 1e-2,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-6,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 30
23 | },
24 |
25 | "net_params": {
26 | "L": 2,
27 | "hidden_dim": 86,
28 | "out_dim": 86,
29 | "residual": true,
30 | "readout": "mean",
31 | "in_feat_dropout": 0.0,
32 | "dropout": 0.0,
33 | "batch_norm": true,
34 | "sage_aggregator": "maxpool",
35 | "self_loop": false,
36 | "edge_feat": false
37 | }
38 | }
--------------------------------------------------------------------------------
/configs/ppmi_schaefer100/TUs_graph_classification_ContrastPool_ppmi_schaefer100_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "ContrastPool",
8 | "dataset": "ppmi_schaefer100",
9 |
10 | "out_dir": "out/braindata_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 4,
16 | "init_lr": 1e-2,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-6,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 30
23 | },
24 |
25 | "net_params": {
26 | "L": 2,
27 | "hidden_dim": 86,
28 | "out_dim": 86,
29 | "residual": true,
30 | "readout": "mean",
31 | "in_feat_dropout": 0.0,
32 | "dropout": 0.0,
33 | "batch_norm": true,
34 | "sage_aggregator": "maxpool",
35 | "self_loop": false,
36 | "edge_feat": false
37 | }
38 | }
--------------------------------------------------------------------------------
/configs/taowu_schaefer100/TUs_graph_classification_ContrastPool_taowu_schaefer100_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "ContrastPool",
8 | "dataset": "taowu_schaefer100",
9 |
10 | "out_dir": "out/braindata_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 4,
16 | "init_lr": 1e-2,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-6,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 30
23 | },
24 |
25 | "net_params": {
26 | "L": 2,
27 | "hidden_dim": 86,
28 | "out_dim": 86,
29 | "residual": true,
30 | "readout": "mean",
31 | "in_feat_dropout": 0.0,
32 | "dropout": 0.0,
33 | "batch_norm": true,
34 | "sage_aggregator": "maxpool",
35 | "self_loop": false,
36 | "edge_feat": false
37 | }
38 | }
--------------------------------------------------------------------------------
/contrast_subgraph.py:
--------------------------------------------------------------------------------
1 | import heapq
2 | import math
3 | import numpy as np
4 | import torch
5 | import dgl
6 | from dgl.data.utils import load_graphs
7 | from copy import deepcopy
8 | from tqdm import tqdm
9 |
10 |
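# Groups the dataset by class label: for every label, stacks each graph's
# edge-feature matrix (reshaped to node_num x node_num) and its node-feature
# matrix into one tensor, and returns two dicts {label: tensor} on `device`.
# Note: the `merge_classes` argument is currently unused.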
11 | def get_summary_tensor(G_dataset, Labels, device, merge_classes=False):
12 | num_G = len(G_dataset)
13 | Labels = Labels.tolist()
14 | node_num = G_dataset[0].ndata['feat'].shape[0]
15 | adj_dict = {}
16 | nodes_dict = {}
17 | final_adj_dict = {}
18 | final_nodes_dict = {}
19 | for i in range(num_G):
20 | if Labels[i] not in adj_dict.keys():
21 | adj_dict[Labels[i]] = []
22 | nodes_dict[Labels[i]] = []
23 | adj_dict[Labels[i]].append(G_dataset[i].edata['feat'].squeeze().view(node_num, -1).tolist())
24 | nodes_dict[Labels[i]].append(G_dataset[i].ndata['feat'].tolist())
25 |
26 | for i in adj_dict.keys():
27 | final_adj_dict[i] = torch.tensor(adj_dict[i]).to(device)
28 | final_nodes_dict[i] = torch.tensor(nodes_dict[i]).to(device)
29 | return final_adj_dict, final_nodes_dict
30 |
--------------------------------------------------------------------------------
/data/BrainNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data
3 | import time
4 | import os
5 | import numpy as np
6 | import csv
7 | import dgl
8 | from dgl.data.utils import load_graphs
9 | import networkx as nx
10 | from tqdm import tqdm
11 | import random
12 | random.seed(42)
13 | from sklearn.model_selection import StratifiedKFold, train_test_split
14 |
15 |
16 | class DGLFormDataset(torch.utils.data.Dataset):
17 | """
18 | DGLFormDataset wrapping graph list and label list as per pytorch Dataset.
19 | *lists (list): lists of 'graphs' and 'labels' with same len().
20 | """
21 | def __init__(self, *lists):
22 | assert all(len(lists[0]) == len(li) for li in lists)
23 | self.lists = lists
24 | self.graph_lists = lists[0]
25 | self.graph_labels = lists[1]
26 |
27 | def __getitem__(self, index):
28 | return tuple(li[index] for li in self.lists)
29 |
30 | def __len__(self):
31 | return len(self.lists[0])
32 |
33 |
34 | def self_loop(g):
35 | """
36 |         Utility function, to be used only when the user sets the self_loop flag.
37 |         Re-implementation of dgl.transform.add_self_loop() that preserves ndata['feat'] and edata['feat'].
38 |
39 |
40 | This function is called inside a function in TUsDataset class.
41 | """
42 | new_g = dgl.DGLGraph()
43 | new_g.add_nodes(g.number_of_nodes())
44 | new_g.ndata['feat'] = g.ndata['feat']
45 |
46 | src, dst = g.all_edges(order="eid")
47 | src = dgl.backend.zerocopy_to_numpy(src)
48 | dst = dgl.backend.zerocopy_to_numpy(dst)
49 | non_self_edges_idx = src != dst
50 | nodes = np.arange(g.number_of_nodes())
51 | new_g.add_edges(src[non_self_edges_idx], dst[non_self_edges_idx])
52 | new_g.add_edges(nodes, nodes)
53 |
54 | # This new edata is not used since this function gets called only for GCN, GAT
55 | # However, we need this for the generic requirement of ndata and edata
56 | new_g.edata['feat'] = torch.zeros(new_g.number_of_edges())
57 | return new_g
58 |
59 | name2path = {
60 | 'abide_AAL116': '/path/to/data/abide_AAL116.bin',
61 | 'abide_harvard48': '/path/to/data/abide_harvard48.bin',
62 | 'abide_kmeans100': '/path/to/data/abide_kmeans100.bin',
63 | 'abide_schaefer100': '/path/to/data/abide_schaefer100.bin',
64 | 'abide_ward100': '/path/to/data/abide_ward100.bin',
65 |
66 | 'adni_AAL116': '/path/to/data/adni_AAL116.bin',
67 | 'adni_harvard48': '/path/to/data/adni_harvard48.bin',
68 | 'adni_kmeans100': '/path/to/data/adni_kmeans100.bin',
69 | 'adni_schaefer100': '/path/to/data/adni_schaefer100.bin',
70 | 'adni_ward100': '/path/to/data/adni_ward100.bin',
71 |
72 | 'neurocon_AAL116': '/path/to/data/neurocon_AAL116.bin',
73 | 'neurocon_harvard48': '/path/to/data/neurocon_harvard48.bin',
74 | 'neurocon_kmeans100': '/path/to/data/neurocon_kmeans100.bin',
75 | 'neurocon_schaefer100': '/path/to/data/neurocon_schaefer100.bin',
76 | 'neurocon_ward100': '/path/to/data/neurocon_ward100.bin',
77 |
78 | 'ppmi_AAL116': '/path/to/data/ppmi_AAL116.bin',
79 | 'ppmi_harvard48': '/path/to/data/ppmi_harvard48.bin',
80 | 'ppmi_kmeans100': '/path/to/data/ppmi_kmeans100.bin',
81 | 'ppmi_schaefer100': '/path/to/data/ppmi_schaefer100.bin',
82 | 'ppmi_ward100': '/path/to/data/ppmi_ward100.bin',
83 |
84 | 'taowu_AAL116': '/path/to/data/taowu_AAL116.bin',
85 | 'taowu_harvard48': '/path/to/data/taowu_harvard48.bin',
86 | 'taowu_kmeans100': '/path/to/data/taowu_kmeans100.bin',
87 | 'taowu_schaefer100': '/path/to/data/taowu_schaefer100.bin',
88 | 'taowu_ward100': '/path/to/data/taowu_ward100.bin',
89 | }
90 |
91 |
92 | class BrainDataset(torch.utils.data.Dataset):
93 | def __init__(self, name, threshold=0.3, edge_ratio=0, node_feat_transform='original'):
94 | t0 = time.time()
95 | self.name = name
96 |
97 | G_dataset, Labels = load_graphs(name2path[self.name])
98 |
99 | self.node_num = G_dataset[0].ndata['N_features'].size(0)
100 |
101 | print("[!] Dataset: ", self.name)
102 |
103 | # transfer DGLHeteroGraph to DGLFormDataset
104 | data = []
105 | error_case = []
106 | for i in range(len(G_dataset)):
107 | if len(((G_dataset[i].ndata['N_features'] != 0).sum(dim=-1) == 0).nonzero()) > 0:
108 | error_case.append(i)
109 | print(error_case)
110 | G_dataset = [n for i, n in enumerate(G_dataset) if i not in error_case]
111 |
112 | for i in tqdm(range(len(G_dataset))):
113 | if edge_ratio:
114 | threshold_idx = int(len(G_dataset[i].edata['E_features']) * (1 - edge_ratio))
115 | threshold = sorted(G_dataset[i].edata['E_features'].tolist())[threshold_idx]
116 |
117 | G_dataset[i].remove_edges(torch.squeeze((torch.abs(G_dataset[i].edata['E_features']) < float(threshold)).nonzero()))
118 | G_dataset[i].edata['feat'] = G_dataset[i].edata['E_features'].unsqueeze(-1).clone()
119 |
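            # Build node features according to `node_feat_transform`:
            #   'original'   - keep the raw 'N_features'
            #   'one_hot'    - identity matrix (one-hot node IDs)
            #   'pearson'    - Pearson correlation matrix of 'N_features'
            #   'degree'     - node in-degrees
            #   'adj_matrix' - dense adjacency matrix
            #   'mean_std'   - per-node mean and std of 'N_features'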
120 | if name[:-7] == 'pearson' or node_feat_transform == 'original':
121 | G_dataset[i].ndata['feat'] = G_dataset[i].ndata['N_features'].clone()
122 | elif node_feat_transform == 'one_hot':
123 | G_dataset[i].ndata['feat'] = torch.eye(self.node_num).clone()
124 | elif node_feat_transform == 'pearson':
125 | G_dataset[i].ndata['feat'] = torch.from_numpy(np.corrcoef(G_dataset[i].ndata['N_features'].numpy())).clone()
126 | elif node_feat_transform == 'degree':
127 | G_dataset[i].ndata['feat'] = G_dataset[i].in_degrees().unsqueeze(dim=1).clone()
128 | elif node_feat_transform == 'adj_matrix':
129 | G_dataset[i].ndata['feat'] = G_dataset[i].adj().to_dense().clone()
130 | elif node_feat_transform == 'mean_std':
131 | G_dataset[i].ndata['feat'] = torch.stack(torch.std_mean(G_dataset[i].ndata['N_features'], dim=-1)).T.flip(dims=[1]).clone()
132 | else:
133 | raise NotImplementedError
134 |
135 | G_dataset[i].ndata.pop('N_features')
136 | G_dataset[i].edata.pop('E_features')
137 | data.append([G_dataset[i], Labels['glabel'].tolist()[i]])
138 |
139 | dataset = self.format_dataset(data)
140 | # this function splits data into train/val/test and returns the indices
141 | self.all_idx = self.get_all_split_idx(dataset)
142 |
143 | self.all = dataset
144 | self.train = [self.format_dataset([dataset[idx] for idx in self.all_idx['train'][split_num]]) for split_num in range(10)]
145 | self.val = [self.format_dataset([dataset[idx] for idx in self.all_idx['val'][split_num]]) for split_num in range(10)]
146 | self.test = [self.format_dataset([dataset[idx] for idx in self.all_idx['test'][split_num]]) for split_num in range(10)]
147 |
148 | print("Time taken: {:.4f}s".format(time.time()-t0))
149 |
150 | def get_all_split_idx(self, dataset):
151 | """
152 |         - Split the graphs into train/val/test in an 80:10:10 ratio
153 |         - Stratified split, proportionate to the original class distribution
154 |         - Use sklearn to perform the split and then save the indices
155 |         - Prepare 10 such index splits to be used in the graph NNs
156 |         - As with KFold, each of the 10 folds has a unique test set.
157 | """
158 | root_idx_dir = './data/{}/'.format(self.name)
159 | if not os.path.exists(root_idx_dir):
160 | os.makedirs(root_idx_dir)
161 | all_idx = {}
162 |
163 | # If there are no idx files, do the split and store the files
164 | if not (os.path.exists(root_idx_dir + 'train.index')):
165 | print("[!] Splitting the data into train/val/test ...")
166 |
167 | # Using 10-fold cross val to compare with benchmark papers
168 | k_splits = 10
169 |
170 | cross_val_fold = StratifiedKFold(n_splits=k_splits, shuffle=True)
171 | k_data_splits = []
172 |
173 | # this is a temporary index assignment, to be used below for val splitting
174 | for i in range(len(dataset.graph_lists)):
175 | dataset[i][0].a = lambda: None
176 | setattr(dataset[i][0].a, 'index', i)
177 |
178 | for indexes in cross_val_fold.split(dataset.graph_lists, dataset.graph_labels):
179 | remain_index, test_index = indexes[0], indexes[1]
180 |
181 | remain_set = self.format_dataset([dataset[index] for index in remain_index])
182 |
183 | # Gets final 'train' and 'val'
184 | train, val, _, __ = train_test_split(remain_set,
185 | range(len(remain_set.graph_lists)),
186 | test_size=0.111,
187 | stratify=remain_set.graph_labels)
188 |
189 | train, val = self.format_dataset(train), self.format_dataset(val)
190 | test = self.format_dataset([dataset[index] for index in test_index])
191 |
192 | # Extracting only idx
193 | idx_train = [item[0].a.index for item in train]
194 | idx_val = [item[0].a.index for item in val]
195 | idx_test = [item[0].a.index for item in test]
196 |
197 | f_train_w = csv.writer(open(root_idx_dir + 'train.index', 'a+'))
198 | f_val_w = csv.writer(open(root_idx_dir + 'val.index', 'a+'))
199 | f_test_w = csv.writer(open(root_idx_dir + 'test.index', 'a+'))
200 |
201 | f_train_w.writerow(idx_train)
202 | f_val_w.writerow(idx_val)
203 | f_test_w.writerow(idx_test)
204 |
205 | print("[!] Splitting done!")
206 |
207 | # reading idx from the files
208 | for section in ['train', 'val', 'test']:
209 | with open(root_idx_dir + section + '.index', 'r') as f:
210 | reader = csv.reader(f)
211 | all_idx[section] = [list(map(int, idx)) for idx in reader]
212 | return all_idx
213 |
214 | def format_dataset(self, dataset):
215 | """
216 |             Utility function to convert the dataset
217 |             into a dgl/pytorch compatible format
218 | """
219 | graphs = [data[0] for data in dataset]
220 | labels = [data[1] for data in dataset]
221 |
222 | for graph in graphs:
223 | graph.ndata['feat'] = graph.ndata['feat'].float() # dgl 4.0
224 | # adding edge features for Residual Gated ConvNet, if not there
225 | if 'feat' not in graph.edata.keys():
226 | edge_feat_dim = graph.ndata['feat'].shape[1] # dim same as node feature dim
227 | graph.edata['feat'] = torch.ones(graph.number_of_edges(), edge_feat_dim)
228 |
229 | return DGLFormDataset(graphs, labels)
230 |
231 | # form a mini batch from a given list of samples = [(graph, label) pairs]
232 | def collate(self, samples):
233 |         # The input samples is a list of (graph, label) pairs.
234 | graphs, labels = map(list, zip(*samples))
235 | labels = torch.tensor(np.array(labels))
236 | batched_graph = dgl.batch(graphs)
237 |
238 | return batched_graph, labels
239 |
240 |     # prepare dense tensors for GNNs that use them, such as RingGNN and 3WLGNN
241 | def collate_dense_gnn(self, samples):
242 |         # The input samples is a list of (graph, label) pairs.
243 | graphs, labels = map(list, zip(*samples))
244 | labels = torch.tensor(np.array(labels))
245 |
246 | g = graphs[0]
247 | adj = self._sym_normalize_adj(g.adjacency_matrix().to_dense())
248 | """
249 | Adapted from https://github.com/leichen2018/Ring-GNN/
250 |             Assigning node and edge feats:
251 |             we have the adjacency matrix in R^{n x n}, node features in R^{d_n} and edge features in R^{d_e}.
252 |             We build a zero-initialized tensor, say T, in R^{(1 + d_n + d_e) x n x n}. T[0, :, :] is the adjacency matrix.
253 |             The diagonal entries T[1:1+d_n, i, i], i = 0 to n-1, store the node features of node i.
254 |             The off-diagonal entries T[1+d_n:, i, j] store the edge features of edge (i, j).
255 | """
256 |
257 | zero_adj = torch.zeros_like(adj)
258 |
259 | in_dim = g.ndata['feat'].shape[1]
260 |
261 | # use node feats to prepare adj
262 | adj_node_feat = torch.stack([zero_adj for j in range(in_dim)])
263 | adj_node_feat = torch.cat([adj.unsqueeze(0), adj_node_feat], dim=0)
264 |
265 | for node, node_feat in enumerate(g.ndata['feat']):
266 | adj_node_feat[1:, node, node] = node_feat
267 |
268 | x_node_feat = adj_node_feat.unsqueeze(0)
269 |
270 | return x_node_feat, labels
271 |
272 | def _sym_normalize_adj(self, adj):
273 | deg = torch.sum(adj, dim=0)
274 | deg_inv = torch.where(deg>0, 1./torch.sqrt(deg), torch.zeros(deg.size()))
275 | deg_inv = torch.diag(deg_inv)
276 | return torch.mm(deg_inv, torch.mm(adj, deg_inv))
277 |
278 | def _add_self_loops(self):
279 |
280 | # function for adding self loops
281 | # this function will be called only if self_loop flag is True
282 | for split_num in range(10):
283 | self.train[split_num].graph_lists = [self_loop(g) for g in self.train[split_num].graph_lists]
284 | self.val[split_num].graph_lists = [self_loop(g) for g in self.val[split_num].graph_lists]
285 | self.test[split_num].graph_lists = [self_loop(g) for g in self.test[split_num].graph_lists]
286 |
287 | for split_num in range(10):
288 | self.train[split_num] = DGLFormDataset(self.train[split_num].graph_lists, self.train[split_num].graph_labels)
289 | self.val[split_num] = DGLFormDataset(self.val[split_num].graph_lists, self.val[split_num].graph_labels)
290 | self.test[split_num] = DGLFormDataset(self.test[split_num].graph_lists, self.test[split_num].graph_labels)
291 |
--------------------------------------------------------------------------------
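A minimal usage sketch (not part of the repository) of how the pre-computed 10-fold splits and the collate function defined in BrainNet.py above can be fed to a PyTorch DataLoader, assuming DGLFormDataset behaves as a standard map-style dataset; the dataset name, fold index, and batch size below are illustrative assumptions:

    from torch.utils.data import DataLoader
    from data.BrainNet import BrainDataset

    # build the dataset; this reads the .bin file registered in name2path
    dataset = BrainDataset('abide_schaefer100', threshold=0.3)

    fold = 0  # any value in 0..9 selects one of the 10 stratified folds
    train_loader = DataLoader(dataset.train[fold], batch_size=8, shuffle=True,
                              collate_fn=dataset.collate)
    val_loader = DataLoader(dataset.val[fold], batch_size=8, shuffle=False,
                            collate_fn=dataset.collate)

    for batched_graph, labels in train_loader:
        # batched_graph is a dgl.batch of the sampled graphs,
        # labels is a 1-D tensor of the corresponding class labels
        pass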
/data/abide_schaefer100/test.index:
--------------------------------------------------------------------------------
1 | 2,14,18,45,61,99,119,138,144,153,161,165,177,179,190,192,196,201,203,209,258,265,272,305,308,309,313,326,348,364,373,374,387,388,395,399,403,438,439,444,458,477,482,491,492,498,502,503,504,508,509,512,519,526,547,550,571,579,617,620,625,643,650,671,675,686,688,693,694,717,724,726,747,757,759,760,778,791,800,809,816,820,837,853,870,872,878,882,889,891,899,900,902,908,941,960,967,982,983
2 | 0,1,3,7,9,19,26,39,40,52,70,75,83,87,93,97,129,141,162,173,174,181,202,207,213,247,261,264,269,270,279,325,333,338,350,358,362,366,370,394,396,401,410,418,434,440,442,456,461,468,473,485,500,529,543,548,551,554,555,575,592,595,597,626,631,662,664,668,670,673,714,720,729,734,739,754,761,765,774,799,819,824,844,848,871,876,883,886,890,895,897,918,931,934,943,952,953,965,988
3 | 15,22,25,37,38,59,74,82,107,110,126,131,134,163,166,175,188,204,210,223,232,234,251,276,291,299,327,330,339,345,347,351,357,360,379,386,409,411,413,422,425,436,449,454,455,457,462,464,486,496,505,513,517,532,535,540,556,557,558,567,593,600,604,624,636,637,644,666,687,692,699,701,703,756,764,768,797,798,831,834,838,841,857,862,875,887,911,925,937,944,946,947,948,961,969,971,978,985,987
4 | 32,33,34,50,53,55,56,69,72,88,90,92,111,122,135,142,148,152,158,168,193,197,214,218,224,226,229,235,249,252,260,267,278,281,287,302,303,343,359,369,371,372,389,402,405,435,466,478,493,506,507,525,527,533,536,538,545,549,562,563,572,587,612,629,635,638,642,648,651,681,690,696,697,708,723,727,749,753,762,770,794,803,808,822,827,843,849,856,858,863,865,866,868,888,905,909,910,958,977
5 | 13,16,20,23,24,42,58,62,81,84,85,94,95,98,101,123,124,151,164,169,170,176,184,194,206,230,231,236,241,248,253,257,271,284,292,301,307,319,320,323,344,352,354,381,392,429,447,467,472,487,521,524,530,542,552,553,573,577,601,607,618,630,639,645,654,678,682,683,698,700,705,707,738,741,744,746,779,785,788,789,793,802,805,807,818,828,829,873,885,893,894,898,903,912,935,945,955,956,962
6 | 4,11,63,77,79,96,108,112,113,120,121,128,133,136,139,171,185,189,208,212,215,216,217,227,233,242,245,246,259,268,280,304,310,321,322,328,346,353,367,377,378,385,390,419,431,459,463,474,481,499,514,515,516,534,541,559,560,581,602,608,616,627,647,655,674,677,706,710,711,712,716,719,725,735,743,763,777,780,795,814,825,833,846,850,852,860,861,869,901,914,916,923,927,933,940,949,954,975,979
7 | 8,12,30,35,41,60,71,73,80,102,115,125,127,140,147,154,155,167,178,183,222,238,240,250,255,263,266,274,283,289,311,312,314,324,355,376,382,398,417,420,423,426,428,441,443,460,469,479,488,497,501,510,518,537,544,566,574,584,585,588,598,628,634,658,660,665,672,684,704,713,718,721,728,740,745,758,766,769,775,783,796,801,804,817,830,867,880,904,913,920,926,929,930,942,950,964,970,981,984
8 | 6,17,29,36,43,47,64,65,66,78,86,100,104,109,117,146,149,150,156,172,180,186,187,191,198,219,228,244,254,275,296,297,300,317,331,334,363,368,375,384,391,400,408,412,427,445,446,471,476,489,490,522,523,561,569,589,591,596,603,613,614,619,633,640,641,649,676,685,689,695,715,731,736,771,787,810,811,812,821,826,832,835,836,854,859,864,877,881,884,892,924,932,938,951,957,963,973,980,986
9 | 5,10,21,28,46,48,49,67,68,89,103,105,106,118,145,159,199,200,205,211,221,237,239,243,256,262,273,306,329,332,336,337,340,342,349,356,361,365,380,393,404,406,424,430,437,448,451,453,483,484,494,495,511,539,546,568,578,580,582,583,586,594,605,606,609,611,615,621,646,652,653,656,657,661,663,669,733,742,750,752,755,776,784,786,813,823,842,879,896,906,907,919,922,928,939,966,972,974,976
10 | 27,31,44,51,54,57,76,91,114,116,130,132,137,143,157,160,182,195,220,225,277,282,285,286,288,290,293,294,295,298,315,316,318,335,341,383,397,407,414,415,416,421,432,433,450,452,465,470,475,480,520,528,531,564,565,570,576,590,599,610,622,623,632,659,667,679,680,691,702,709,722,730,732,737,748,751,767,772,773,781,782,790,792,806,815,839,840,845,847,851,855,874,915,917,921,936,959,968
11 |
--------------------------------------------------------------------------------
/data/abide_schaefer100/train.index:
--------------------------------------------------------------------------------
1 | 148,561,678,919,731,517,752,486,523,495,116,985,219,574,879,548,52,708,88,522,841,306,836,465,777,567,430,790,783,743,281,866,185,813,467,277,235,507,173,663,737,611,565,679,154,80,647,380,255,335,788,11,584,947,917,125,664,47,969,323,706,473,600,802,425,469,171,283,418,775,232,110,168,354,661,933,525,166,819,432,954,338,103,582,269,909,839,407,782,368,151,181,57,856,132,366,566,106,552,768,978,43,136,665,713,963,248,776,599,98,280,55,211,806,127,46,261,904,41,601,799,898,670,692,369,150,964,770,867,725,372,74,21,424,356,363,972,468,538,630,224,384,117,573,158,798,5,834,377,871,957,645,466,383,873,94,346,918,876,923,296,69,894,10,711,564,524,628,394,827,421,307,883,682,475,785,370,20,580,90,691,101,734,93,988,155,615,687,471,563,131,358,604,34,359,897,906,583,422,944,545,593,291,780,888,199,118,497,241,869,435,543,848,242,193,732,903,28,884,191,385,936,750,632,735,102,353,160,463,513,318,411,433,189,389,887,922,929,16,516,496,797,920,122,221,470,147,641,44,596,164,123,961,629,959,861,378,794,570,393,927,449,240,357,631,431,771,113,608,478,748,489,500,818,840,324,0,984,642,454,981,169,204,253,91,598,64,334,312,850,767,223,915,766,609,744,718,17,415,48,701,428,260,278,965,499,410,656,56,303,427,738,180,352,342,817,696,980,284,720,597,578,658,557,655,300,924,42,58,107,874,808,448,459,627,446,855,673,928,533,245,975,697,739,271,4,480,129,451,205,76,187,603,932,623,754,912,644,940,27,392,7,71,695,594,262,246,453,684,152,569,914,842,862,401,481,97,558,81,950,749,141,703,263,488,634,96,753,336,361,413,329,54,830,506,274,685,330,847,236,108,973,660,607,243,188,367,587,208,270,474,715,73,846,676,527,938,450,654,536,610,773,976,476,514,779,447,222,111,213,31,589,729,885,6,544,121,217,824,945,426,805,109,618,341,457,487,325,301,811,49,619,321,437,390,602,592,669,460,266,895,115,68,930,381,556,539,396,350,231,37,194,362,528,958,87,254,79,646,810,843,733,721,741,892,112,762,907,935,949,814,971,934,649,137,78,716,977,761,1,130,807,614,709,518,613,845,651,595,322,417,142,247,745,758,916,546,145,279,755,751,638,835,62,591,534,124,250,815,327,890,256,812,275,968,531,293,252,455,126,328,286,787,677,436,803,423,233,143,186,854,314,635,535,838,315,237,680,398,637,8,826,239,572,210,319,304,966,986,85,490,828,666,157,769,553,501,25,576,311,483,822,756,198,510,970,195,22,714,875,559,30,925,881,51,910,376,792,653,825,505,59,931,953,351,227,149,332,681,452,621,339,652,333,32,657,445,285,35,702,162,292,355,360,33,690,343,409,9,939,175,288,857,434,633,456,962,639,40,220,712,371,860,829,13,218,479,225,226,337,104,77,946,414,606,297,585,12,316,400,197,167,302,206,781,943,849,905,100,229,877,844,267,575,214,146,540,730,859,184,408,746,704,26,19,230,537,23,340,133,375,801,156,736,577,821,134,60,345,956,70,83,344,521,659,289,234,114,441,832,549,39,310,636,379,494,793,207,120,172,865,622,462,299,683,140,420,386,590,484,605,796,443,626,863,926,530,700,50,763,251,419,529,786,182,238,554,772,75,382,485,667,264,464,128,560,722,616,727,937,674,951,216,662,294,765,955,723,244,586,290,317,163,913,868,287,36,15,24,268,555,202,795,53,257,273,699,331,89,581,472
2 | 349,813,252,571,667,320,833,945,690,352,72,437,180,178,946,105,121,635,703,843,967,21,404,568,391,210,306,63,16,127,368,863,469,89,425,920,617,205,481,846,367,130,806,949,910,419,962,402,815,514,955,248,629,342,332,818,684,177,619,448,919,816,738,208,634,648,671,907,926,790,961,603,191,399,692,836,505,879,81,569,636,665,917,389,104,561,143,15,382,457,238,282,621,300,748,64,984,175,520,851,432,216,125,301,398,284,862,537,726,947,450,678,479,795,695,944,355,612,316,195,880,164,705,522,882,217,560,118,706,860,466,778,299,713,812,214,709,691,618,861,53,439,504,441,8,92,455,318,834,101,803,885,170,585,702,853,533,222,280,924,723,974,94,811,681,272,490,826,977,740,91,239,135,341,808,219,951,159,936,55,535,109,620,606,730,821,724,802,652,232,268,591,132,235,196,447,319,875,13,655,959,111,400,313,487,420,403,831,142,61,117,392,751,688,513,200,666,25,659,777,625,901,250,153,34,643,66,889,898,825,158,154,540,563,867,76,807,852,206,884,698,499,426,644,172,797,287,262,903,930,502,411,155,954,221,199,646,576,18,900,929,149,868,878,365,54,847,719,921,710,489,580,329,814,506,519,712,841,637,464,42,559,562,859,46,809,611,906,209,88,582,2,228,346,384,335,877,891,374,462,865,856,517,395,460,581,163,421,134,638,685,90,744,157,226,870,150,733,624,187,412,787,59,151,86,480,493,339,428,423,796,970,184,6,298,689,887,725,492,436,574,244,197,849,607,321,328,50,731,275,647,167,230,444,627,881,454,717,801,715,459,57,771,324,742,417,672,388,223,827,622,536,842,131,858,755,385,657,140,596,940,817,711,240,344,303,769,5,315,747,41,586,564,477,37,721,950,113,330,291,278,987,925,518,701,530,766,383,752,397,463,286,283,699,27,229,632,794,347,854,4,694,531,99,233,326,708,735,979,538,905,203,49,532,458,527,211,767,780,285,17,452,146,523,722,915,916,311,271,371,47,359,386,491,56,565,616,964,982,497,973,700,855,414,152,192,260,798,682,122,194,570,972,429,255,598,732,115,107,495,828,745,850,176,669,183,524,114,511,539,68,609,869,406,899,743,35,663,67,120,613,942,71,508,658,29,296,594,408,376,549,488,310,11,969,963,573,476,922,465,380,276,515,290,471,263,494,331,242,590,165,902,243,431,36,677,486,78,793,23,941,100,983,645,544,593,79,650,759,893,33,288,601,294,633,541,686,351,258,58,547,728,556,241,966,185,251,20,904,948,785,381,182,503,112,584,976,416,119,823,138,656,103,449,985,198,583,307,909,507,789,866,337,323,253,630,602,108,932,913,445,96,550,190,845,704,679,874,478,265,293,422,372,387,139,676,30,661,69,409,218,136,687,312,98,212,975,707,38,106,336,363,73,133,896,978,914,696,393,642,123,379,48,567,348,453,927,525,137,528,546,343,498,470,512,501,756,939,369,935,43,857,839,614,776,760,156,509,430,354,289,832,639,259,545,24,753,267,608,126,727,587,908,610,757,84,435,295,169,144,986,304,224,189,378,186,357,604,31,600,628,390,683,737,266,415,168,364,693,605,124,937,215,356,912,424,589,467,837,245,246,148,553,758,201,305,933,14,179,322,483,957,557,768,788,829,188,516,472,193,750,74,281,314,309,256,792,231,784,145,835,443,749,928,302,800,340,373,65,718,838,651,649,110,542,526,911,830,292,804,171,674,892,736,257,274,521,971,220,375,615,773,12,660,762,623,781,62,772,872,273
3 | 890,333,62,208,200,757,438,466,70,633,832,565,498,669,778,329,440,352,684,336,202,269,90,606,340,19,32,226,663,640,102,39,487,609,958,168,410,796,897,668,921,193,896,96,586,460,949,278,951,288,401,737,54,105,632,415,628,406,452,344,384,929,260,843,13,405,970,790,287,389,268,142,566,613,572,229,154,334,433,86,388,795,711,237,816,75,423,303,777,451,103,760,616,839,94,876,859,959,869,738,292,502,429,955,355,247,806,7,595,597,713,846,403,736,473,424,671,530,132,582,657,235,121,320,612,780,71,83,625,159,646,390,215,608,144,0,702,611,228,520,444,479,882,127,650,468,975,962,972,407,361,225,771,963,252,672,943,734,552,605,952,860,578,673,697,518,106,783,500,680,976,52,913,621,190,945,108,164,280,203,162,263,820,374,854,981,3,509,246,98,983,709,456,534,57,56,915,844,481,272,879,89,805,259,767,735,218,546,220,550,174,367,364,69,784,751,620,559,315,404,892,63,920,21,847,282,139,974,655,926,523,156,706,681,432,115,957,42,746,428,439,719,426,309,867,607,645,335,61,358,68,710,830,236,548,442,782,812,849,18,814,914,799,551,589,630,677,661,480,76,419,222,808,91,393,183,179,964,917,477,956,341,254,138,871,968,79,253,365,772,143,35,617,53,328,9,728,266,450,575,789,11,903,686,323,714,716,895,588,465,391,265,301,763,469,24,447,418,837,840,599,745,833,441,708,239,290,399,47,654,427,197,881,72,97,506,562,64,245,196,982,856,44,529,904,192,744,49,51,695,827,725,811,802,819,109,398,704,383,382,95,116,729,779,128,537,141,250,788,312,524,17,741,171,922,29,906,255,140,885,786,294,36,165,205,580,392,209,123,776,417,682,318,31,306,727,749,178,815,331,172,122,93,187,55,544,402,753,864,773,596,715,446,471,561,346,277,508,898,298,573,338,659,117,626,167,935,43,161,227,870,381,238,810,472,453,85,555,12,289,793,868,216,73,185,155,261,803,927,322,533,836,594,136,781,801,553,189,92,670,262,368,516,314,528,285,111,397,199,376,219,912,372,475,674,591,286,950,622,894,207,414,319,910,979,761,512,928,660,649,387,698,765,526,377,84,828,243,40,787,936,939,295,256,16,583,521,118,694,732,965,587,359,536,267,569,792,642,794,362,980,923,568,543,707,241,373,217,221,28,302,754,861,324,176,408,279,180,541,863,967,461,33,158,845,667,213,385,954,177,81,305,87,809,4,638,960,135,494,104,119,602,310,938,905,113,866,585,514,889,909,157,817,742,942,717,878,614,705,369,41,264,700,888,966,88,696,484,579,822,153,986,242,50,688,224,731,891,137,283,821,448,850,257,769,300,858,718,80,396,214,629,721,125,348,902,800,186,326,120,690,835,211,683,775,45,570,497,395,321,739,149,610,510,485,653,488,930,720,10,656,67,313,206,181,170,825,77,733,883,114,598,307,627,23,430,194,275,652,304,349,770,308,750,752,872,940,14,474,515,478,590,366,270,437,865,371,619,560,662,66,420,26,99,20,634,675,907,785,504,350,148,233,495,908,354,807,522,932,603,400,743,293,483,34,934,941,547,435,919,212,691,899,65,501,539,665,412,356,648,394,643,824,431,564,924,342,5,880,740,641,527,375,953,146,748,525,658,855,195,554,492,353,759,273,618,984,664,631,574,685,678,689,248,182,240,325,482,848,874,755,151,563,101,184,284,2,886,693,852,48,712,813,380,490,133,30,145,470,317,842,78,8,191,818,747,129,647,152,493,901
4 | 163,349,23,268,543,984,382,345,76,123,306,180,667,776,136,198,981,571,94,817,270,106,146,89,657,243,513,969,246,79,645,528,974,37,391,509,970,134,603,810,874,337,834,664,363,904,534,236,620,6,85,40,882,552,441,386,422,292,713,101,582,949,932,622,935,703,965,715,748,87,980,376,972,754,215,544,443,558,86,105,242,15,437,380,275,599,623,560,307,438,440,297,258,793,44,850,179,255,631,963,82,487,568,124,333,166,902,495,327,125,464,695,823,979,70,462,481,108,561,916,374,928,238,780,976,454,504,886,898,643,567,8,365,355,673,540,269,434,173,265,415,254,245,31,295,426,65,130,570,97,987,777,322,779,700,200,456,705,912,983,760,263,480,52,465,248,887,48,574,686,352,840,542,833,859,447,739,564,951,583,311,277,626,358,946,332,351,424,666,826,746,191,3,291,160,47,181,379,757,283,811,171,765,548,400,425,433,787,162,720,17,971,677,790,964,233,914,290,324,873,615,147,182,659,51,541,325,202,217,199,954,711,110,178,854,107,216,656,601,155,789,812,2,139,330,296,361,41,220,4,366,861,829,896,884,701,231,743,732,126,771,520,728,144,679,576,460,871,948,143,398,503,684,539,189,573,978,385,192,955,807,206,289,228,940,877,831,16,945,891,120,149,738,194,669,54,170,225,682,516,515,988,806,368,502,973,300,318,482,702,335,150,929,906,676,165,499,961,419,707,486,917,903,279,706,691,617,472,604,895,394,430,452,788,628,221,535,530,982,864,710,378,319,609,234,399,390,565,986,600,346,196,815,455,288,241,396,880,339,658,934,98,730,484,137,878,799,805,409,329,494,35,232,924,261,207,660,479,773,474,406,68,747,508,758,427,608,14,514,960,842,632,566,588,644,59,423,735,606,862,42,475,766,293,802,448,761,726,417,844,907,985,239,519,517,718,320,913,752,870,353,496,649,586,488,809,650,383,957,687,294,759,672,28,744,869,27,921,38,240,717,219,523,847,364,613,75,510,313,796,751,45,553,737,678,356,340,721,625,276,694,908,395,84,885,671,360,943,633,156,733,852,526,725,39,420,347,897,167,210,377,336,373,590,592,922,554,118,26,641,652,93,266,188,491,461,786,315,781,876,272,611,835,203,18,832,894,598,693,205,5,731,699,524,942,314,256,237,212,688,637,792,838,785,157,0,104,489,323,46,183,818,901,675,13,393,846,81,522,133,685,595,418,25,634,30,96,436,128,154,745,655,469,253,458,211,804,416,164,920,129,222,550,230,145,813,639,647,569,22,867,80,195,326,630,589,692,830,428,7,498,213,674,127,892,627,153,975,783,74,429,497,966,930,470,19,893,814,341,73,177,362,141,404,185,459,729,190,490,918,9,797,962,640,115,10,851,375,505,704,63,764,947,900,176,584,784,532,555,500,722,605,114,186,941,140,117,367,559,317,956,616,273,29,683,411,121,936,446,12,585,49,132,250,899,384,821,646,828,841,151,457,421,259,78,915,201,937,529,301,889,1,741,575,551,860,953,911,310,284,102,64,546,20,410,716,857,680,512,138,483,939,401,350,756,450,453,581,839,798,66,308,131,923,357,445,321,594,511,251,778,161,596,477,736,11,439,848,879,621,689,819,925,772,264,119,944,62,800,816,36,463,331,926,348,742,71,618,933,247,91,257,791,209,853,312,388,99,309,801,468,875,967,959,670,169,444,824,881,769,883,67,112,451,342,614,408,610,280,174,431,413,709,665,763,719,282,938,442,338,103,845,577,432,837,381,24,77,597,476,414
5 | 39,782,787,203,650,970,550,41,563,346,466,589,275,972,773,140,608,250,647,670,402,924,493,132,446,274,327,835,221,222,528,440,546,43,477,46,863,651,533,959,287,579,946,799,131,637,14,845,632,389,702,950,382,815,65,658,267,332,926,391,739,742,964,104,322,473,513,383,158,341,936,527,980,822,172,760,481,593,315,441,663,495,674,703,453,91,456,628,625,44,729,224,679,831,87,173,137,534,879,545,99,304,609,673,183,600,463,982,568,424,640,806,840,67,340,262,623,374,339,877,92,624,51,266,298,490,387,457,596,451,329,40,22,778,360,515,398,631,464,695,377,167,666,752,517,384,428,866,719,512,985,944,247,379,951,557,233,66,857,371,273,103,80,363,313,878,291,580,478,106,820,410,677,244,461,584,480,868,324,918,270,372,712,543,930,343,549,720,144,599,154,649,111,146,771,401,328,403,538,251,19,590,425,433,112,288,358,213,356,981,119,691,12,598,468,547,506,497,200,605,21,492,904,210,458,825,733,644,715,0,774,859,11,615,934,196,880,142,730,916,740,412,342,479,895,283,675,963,225,347,668,431,34,303,507,745,141,133,781,830,56,899,834,526,960,149,701,790,204,765,603,18,421,976,77,846,286,984,833,25,749,659,26,238,757,444,968,115,216,498,821,430,30,681,978,380,881,243,17,966,160,646,860,438,159,939,988,876,235,977,500,882,554,711,716,767,378,748,889,714,586,751,919,927,215,922,15,484,811,819,420,496,162,591,892,417,572,690,86,842,201,93,239,501,891,849,302,110,181,523,753,629,486,150,489,185,643,240,627,509,452,794,548,706,174,175,717,214,127,406,269,914,48,232,696,300,209,592,510,692,874,499,212,911,667,295,331,249,884,365,482,669,370,450,72,459,734,947,883,409,933,180,743,809,202,306,520,913,390,505,844,460,5,321,294,858,564,416,847,660,823,942,595,454,54,620,755,443,907,583,282,465,205,330,263,810,708,617,455,407,166,797,800,4,867,796,229,843,388,55,783,393,76,986,832,350,684,305,108,948,404,887,587,764,128,165,541,187,411,474,336,508,400,318,722,279,855,865,694,928,784,153,792,551,157,870,64,685,179,856,504,252,436,923,278,171,394,89,369,126,824,483,522,540,938,442,397,813,107,770,758,636,414,518,234,83,974,710,754,277,147,642,297,537,569,612,2,198,125,135,308,544,138,621,633,816,921,585,786,74,953,353,606,426,177,79,555,906,189,419,795,687,476,70,69,852,808,697,338,348,559,731,736,281,558,405,652,841,152,117,648,396,634,726,961,725,71,535,427,491,854,280,525,423,309,853,851,952,186,432,929,987,761,1,875,118,246,672,191,937,264,359,226,861,97,839,871,941,979,581,485,435,574,227,656,35,567,900,561,532,671,68,265,334,597,602,848,817,689,662,439,777,896,82,317,272,178,139,219,814,886,872,969,6,775,728,136,971,488,756,862,293,268,470,32,686,120,905,289,958,897,61,594,565,611,255,299,556,31,915,208,211,727,368,657,762,143,957,197,105,156,349,259,261,759,570,386,445,3,604,129,688,218,195,614,791,28,562,373,653,285,223,36,680,256,850,578,207,836,638,366,763,33,812,983,63,514,502,114,494,345,116,102,469,503,693,217,560,973,335,78,575,134,47,925,713,869,616,188,355,965,9,576,588,750,539,38,333,768,367,130,49,376,228,531,57,864,399,88,53,199,163,949,661,516,772,145,10,351,357,723,908,511,804,582,59,258,888,718,122,37,220,310,193,613,75,536,622,50,362,100
6 | 160,857,33,983,587,842,379,475,822,456,786,974,64,23,154,584,915,74,275,879,760,426,98,414,963,789,460,680,967,382,206,632,773,768,736,37,18,561,14,141,331,876,599,793,173,729,358,640,155,689,551,111,375,779,90,840,641,25,236,843,805,449,520,582,406,434,697,620,776,106,199,849,177,699,862,965,453,344,28,791,838,740,307,22,859,921,211,152,937,610,32,987,543,309,13,753,467,754,505,540,870,248,169,48,5,867,423,219,536,803,676,898,277,355,519,772,704,903,603,973,305,462,24,167,221,798,393,851,764,56,252,594,897,107,630,125,609,162,204,553,702,46,943,270,116,384,101,176,76,461,102,686,400,944,535,402,809,964,562,480,118,428,841,595,586,357,547,698,976,477,507,82,567,523,70,26,509,192,178,957,417,365,479,713,183,715,149,134,92,875,778,728,554,945,407,722,517,174,484,800,184,142,981,681,928,738,356,147,670,218,685,131,469,596,572,197,369,237,437,951,234,751,66,831,291,688,97,140,972,806,664,294,614,73,749,692,810,900,241,888,401,422,40,500,424,261,550,272,598,464,569,12,153,188,287,717,784,617,314,759,661,442,180,552,143,231,373,579,615,415,694,832,969,868,159,213,372,946,421,457,970,94,671,856,525,910,0,980,105,129,158,765,19,864,71,622,148,52,429,487,796,483,53,207,240,894,959,313,470,376,99,504,513,15,636,408,471,913,820,374,194,274,657,947,85,968,451,508,17,518,909,619,660,889,81,468,835,256,435,168,658,545,448,315,433,172,368,391,38,884,748,672,243,257,691,349,812,542,135,43,527,633,583,117,506,49,288,271,239,251,187,682,911,853,611,381,737,289,495,877,269,31,476,802,854,203,223,3,298,797,226,497,486,416,485,744,549,334,312,8,733,601,899,93,568,58,902,397,190,78,621,198,591,432,444,755,938,783,345,881,354,829,804,124,210,669,266,907,336,488,196,339,585,84,170,555,343,839,845,283,465,144,823,67,175,701,75,273,279,342,301,929,606,146,978,89,771,244,530,491,592,137,844,801,230,42,295,329,335,191,575,29,607,363,695,387,675,303,821,326,123,95,558,880,181,895,883,741,577,774,533,781,597,195,267,395,364,643,906,338,478,528,109,498,756,646,324,925,284,700,347,578,380,161,104,319,263,574,935,966,817,7,604,956,649,960,281,590,667,59,813,635,302,360,62,745,687,727,318,593,30,886,848,327,858,163,634,286,988,758,410,703,47,878,232,262,300,436,330,890,27,492,182,110,229,361,220,971,165,166,785,72,502,750,538,6,847,214,775,570,493,441,731,905,932,205,819,366,908,132,787,684,642,529,546,260,961,746,922,44,399,466,790,673,952,564,39,311,68,1,761,942,100,9,962,887,645,341,836,412,726,808,625,663,950,83,403,589,445,69,388,130,420,739,539,201,705,119,333,413,494,510,443,668,308,637,580,425,60,752,383,544,912,612,920,629,571,811,157,628,766,418,50,985,103,917,235,618,624,447,707,693,524,891,690,164,882,757,939,254,709,54,936,86,285,639,828,455,654,348,732,45,648,816,977,489,708,934,296,88,21,863,918,320,179,276,651,332,225,576,948,834,36,652,127,371,452,202,893,769,362,721,80,930,588,613,904,51,35,548,394,398,156,563,290,325,830,450,531,65,662,659,511,984,490,264,557,316,340,696,683,873,351,982,370,826,209,10,623,16,678,872,404,20,566,770,454,482,565,679,138,605,865,473,522,986,87,405,222,720,61,794,396,532,512,323,924,34,747,638,446
7 | 74,948,562,708,849,291,546,141,871,252,32,458,715,877,886,494,797,632,348,427,339,129,714,480,940,579,693,760,787,826,535,199,577,495,523,631,616,360,48,319,695,808,807,392,119,268,515,575,862,980,898,304,189,432,452,0,813,571,681,384,59,834,14,230,517,356,442,86,664,555,270,627,846,682,223,919,422,686,697,408,45,874,855,320,236,184,956,97,111,586,150,937,245,349,716,869,162,833,866,260,67,814,83,550,438,445,674,547,101,108,337,973,482,444,174,903,679,794,249,221,10,130,133,789,931,971,215,958,572,16,675,287,399,278,1,383,484,582,493,487,727,654,778,114,194,831,895,943,24,619,663,455,962,662,656,190,717,424,644,225,220,98,531,146,779,75,896,905,798,208,564,689,413,29,978,363,500,906,543,601,829,128,709,752,354,406,608,309,272,385,509,739,176,592,620,264,892,334,667,22,369,878,156,921,212,271,91,173,415,568,657,439,788,581,205,269,477,848,386,680,983,387,105,322,69,611,917,110,161,153,843,58,232,397,914,451,729,179,511,621,837,393,972,214,610,76,357,87,888,13,434,233,338,604,722,719,894,953,750,617,761,759,793,330,666,365,925,134,165,175,841,882,180,454,700,828,467,710,569,404,669,301,55,277,836,171,596,795,118,685,558,590,542,844,706,402,211,23,131,158,563,603,839,595,450,241,248,374,551,213,781,560,492,859,735,607,557,744,820,88,186,159,315,712,691,375,157,724,416,299,325,720,952,31,40,78,966,891,244,933,335,453,113,436,777,625,653,768,641,670,884,987,811,974,690,254,149,104,870,456,624,534,861,875,676,297,4,231,226,350,50,946,927,93,42,419,965,431,613,502,587,711,379,678,825,806,570,556,407,457,329,845,351,246,358,265,288,418,730,217,967,589,401,139,200,923,207,747,331,803,893,195,228,872,915,885,907,191,378,472,673,466,755,554,99,132,827,321,908,782,889,503,359,545,668,390,955,824,396,364,100,541,197,548,835,106,979,553,361,945,661,699,36,725,934,633,353,166,748,172,447,414,448,683,857,446,897,652,538,463,822,593,216,96,512,963,203,26,251,25,968,591,307,924,333,121,863,326,922,18,784,539,294,302,237,33,411,381,136,77,79,650,247,57,865,56,478,262,809,540,732,328,526,959,864,805,34,757,514,84,15,985,605,552,224,975,696,573,258,284,868,181,910,19,949,82,112,486,3,229,840,561,821,988,403,932,313,612,529,107,853,306,879,688,295,430,273,635,395,109,51,851,753,774,800,743,935,636,954,177,471,576,816,227,606,771,957,815,960,39,818,505,734,790,53,142,521,117,332,819,647,124,705,21,852,345,847,762,116,951,810,754,20,838,5,429,137,483,11,292,765,947,198,185,138,887,327,513,459,303,143,145,151,281,832,352,201,671,465,256,342,692,900,257,193,38,388,168,280,377,485,600,791,305,630,481,936,767,300,527,890,776,371,470,638,202,54,68,763,583,911,182,394,646,285,476,27,373,961,746,66,400,72,370,204,733,43,687,764,298,123,609,614,602,267,362,253,405,912,389,928,516,643,854,7,89,261,81,282,786,135,90,519,433,599,17,969,642,850,65,380,651,372,196,731,122,169,567,210,293,618,645,701,37,883,336,148,290,126,209,770,860,243,28,597,982,944,938,276,812,737,92,421,741,308,160,9,506,410,347,578,310,475,580,559,192,163,391,187,802,648,738,368,473,316,881,626,188,62,504,323,47,412,468,296,902,409,773,461,464,63,977,425,52,742,490,367,858,219,2,95
8 | 81,463,69,396,599,642,233,961,663,634,58,353,551,260,195,82,562,855,485,853,360,269,983,919,175,506,521,923,354,158,22,220,322,540,216,793,770,984,167,876,937,756,177,803,75,278,606,758,616,10,18,947,291,820,886,221,426,138,248,604,817,56,15,597,724,572,790,541,856,683,110,706,507,830,913,255,215,323,105,870,80,52,730,968,838,279,869,28,23,505,656,493,575,669,967,675,247,845,652,173,496,238,129,63,807,559,900,809,126,113,653,508,487,784,733,798,459,135,341,635,948,946,704,909,209,414,539,754,945,543,470,34,395,585,133,866,627,148,773,299,305,592,73,449,181,637,495,465,511,458,342,620,101,964,962,277,646,584,8,533,514,188,952,643,304,740,436,393,159,922,885,681,170,908,536,645,607,595,40,677,140,586,199,718,610,262,251,168,483,288,437,340,499,929,482,970,699,223,298,965,253,608,842,442,960,136,501,702,609,861,558,212,601,544,979,644,873,583,509,473,872,797,750,416,42,235,265,739,791,621,958,332,943,687,696,287,934,723,50,822,560,115,549,379,746,602,674,662,684,329,818,682,630,226,935,819,274,690,612,475,24,185,404,348,127,927,914,2,319,236,343,217,528,432,365,804,9,477,860,44,673,320,301,556,639,62,441,182,959,921,207,227,418,57,390,816,766,261,834,241,728,550,515,448,691,844,907,398,709,795,374,478,377,401,211,453,525,538,858,53,76,0,618,915,234,203,977,41,655,650,920,166,389,92,846,785,258,857,123,975,898,944,232,657,155,840,88,3,89,5,264,794,25,423,27,440,893,678,729,312,350,152,906,524,357,988,708,474,878,918,328,688,814,242,888,871,293,839,460,171,933,430,742,974,204,894,178,444,717,567,547,292,213,936,443,406,716,796,399,598,98,16,579,164,557,828,665,887,891,548,457,214,966,555,615,530,738,953,587,273,950,12,137,658,874,500,245,666,327,769,503,201,128,624,518,249,134,747,59,256,760,31,732,813,355,711,7,239,899,843,779,11,805,315,428,112,897,421,801,194,765,578,14,438,361,517,382,118,106,802,71,789,276,510,311,165,623,455,450,705,985,638,660,452,321,926,202,420,270,823,380,847,941,531,48,862,875,344,827,371,356,447,96,931,890,230,956,153,151,849,763,224,169,381,883,648,225,762,468,512,573,841,553,825,268,33,108,72,472,912,196,205,491,566,93,545,600,338,972,904,237,903,761,250,176,976,70,780,670,231,628,336,581,667,281,119,103,114,768,313,484,582,316,174,546,700,703,537,535,697,852,302,568,629,346,781,611,565,532,879,513,35,925,792,759,289,280,605,777,564,120,13,333,451,588,60,370,143,593,694,229,189,594,197,467,774,359,308,307,366,415,554,49,727,713,266,240,351,394,955,456,516,84,978,534,303,310,734,497,306,552,111,131,145,751,208,939,362,383,454,737,788,911,916,749,422,337,193,778,851,263,863,376,417,905,74,661,124,519,91,748,435,680,142,413,206,162,132,625,462,940,403,97,372,85,144,901,87,672,651,710,752,433,160,318,712,488,21,571,352,326,121,94,294,577,285,865,764,192,349,184,981,949,679,654,367,895,504,190,309,743,385,464,429,725,125,30,179,808,896,4,942,325,210,590,388,721,971,954,529,387,183,90,358,345,339,461,867,286,1,419,282,741,526,498,55,397,707,720,425,330,410,480,889,479,290,102,38,902,527,295,222,37,753,987,668,68,880,494,431,257,20,130,837,806,32,39,647,910,407,51,469,868,917,636,671,686,829,79,157,800
9 | 469,674,116,39,947,425,727,331,43,251,638,371,801,294,509,320,138,308,924,24,559,386,438,818,220,595,716,366,743,82,870,779,602,53,395,409,362,126,459,420,182,643,977,788,945,223,715,731,789,238,310,639,571,728,608,553,492,526,965,254,504,811,632,983,97,207,441,289,777,475,673,577,903,222,458,986,180,477,445,768,723,343,396,821,698,981,299,839,31,244,44,228,322,806,515,19,480,186,95,491,408,670,470,468,130,961,898,121,374,971,17,265,41,71,13,970,432,558,877,350,820,55,123,358,78,869,618,415,460,25,978,172,908,496,443,861,832,233,874,304,710,90,599,809,108,259,554,490,617,953,512,3,814,111,426,858,671,176,695,659,481,954,372,882,860,557,143,878,694,96,645,206,835,500,739,107,54,534,146,758,446,263,704,419,636,161,59,689,514,655,137,174,321,455,279,843,163,353,584,212,188,744,120,921,26,802,170,281,938,339,278,804,862,696,736,757,378,540,790,916,680,590,681,439,958,91,783,377,683,407,912,421,883,131,236,248,828,887,541,264,405,561,774,401,902,253,885,218,531,837,360,45,796,934,389,313,452,735,155,631,382,660,732,604,77,20,62,247,677,962,387,517,890,119,628,622,964,1,431,411,297,314,428,753,849,725,654,190,516,691,498,734,585,316,83,150,104,778,623,687,209,9,846,822,442,888,662,946,197,943,502,824,63,464,816,38,551,418,375,564,359,892,246,397,863,367,524,227,926,164,747,871,139,647,162,429,950,782,301,416,948,598,454,901,109,184,668,791,920,194,722,213,463,770,102,955,730,591,471,581,573,918,345,319,185,549,503,967,276,479,740,478,607,771,417,341,537,335,560,153,692,76,765,171,641,853,738,351,848,625,11,140,177,982,444,973,635,675,857,73,787,157,65,567,193,376,923,257,851,317,904,255,203,168,797,762,202,936,610,745,913,933,942,447,940,33,277,252,147,93,250,899,886,241,620,204,287,895,693,751,493,815,100,697,169,726,70,838,312,672,414,50,290,914,149,667,388,352,156,527,344,840,749,507,328,347,99,985,798,467,23,980,592,794,545,240,721,949,600,400,489,81,434,593,714,327,506,80,292,909,98,685,226,688,30,235,746,298,113,956,379,807,385,927,402,293,729,394,601,61,754,229,282,720,413,900,167,836,7,92,181,844,699,759,543,719,208,72,817,300,897,412,565,175,485,960,189,597,566,482,929,589,52,318,232,173,94,544,772,110,369,963,708,854,456,915,881,261,260,665,748,4,165,859,280,115,855,303,286,87,550,497,214,268,187,291,780,513,364,984,216,160,64,231,718,548,158,579,271,827,326,529,270,525,803,547,333,530,775,284,501,686,219,893,724,399,795,16,125,799,285,519,427,969,702,713,69,664,931,354,391,195,649,392,35,596,532,613,676,198,269,334,894,315,975,27,633,242,42,709,2,472,384,283,968,825,142,129,373,644,75,510,141,910,905,461,133,127,684,381,957,166,626,18,555,74,880,538,616,296,761,486,658,678,101,826,505,183,764,128,741,867,523,959,66,309,865,576,148,847,234,487,355,363,122,773,800,196,440,422,737,630,266,423,701,435,717,612,215,812,410,8,792,348,14,642,499,634,679,767,756,988,79,624,330,769,60,225,562,136,522,682,766,224,834,86,563,872,151,449,706,932,324,619,305,191,841,810,569,781,808,58,450,614,935,36,528,84,521,845,703,258,144,124,152,535,230,850,135,56,520,785,0,22,833,917,357,288,760,987,707,937,29,572,925,952,249,178,307,575,951
10 | 640,852,346,824,641,706,949,626,427,628,464,382,727,633,899,394,972,289,256,896,12,259,582,37,364,40,644,473,436,609,657,297,363,672,512,663,217,909,272,594,546,864,48,604,719,670,608,261,518,558,126,776,914,498,975,122,352,568,635,189,861,674,292,788,257,723,491,426,337,653,869,203,850,985,343,446,374,18,419,50,522,974,278,511,105,270,651,348,982,186,336,606,514,214,402,787,82,237,58,944,442,555,895,492,424,398,676,78,638,248,879,703,384,935,541,437,59,878,441,45,744,192,267,686,269,349,619,675,856,486,508,525,958,882,799,598,833,49,304,804,213,829,652,654,456,831,593,9,183,300,505,515,766,6,756,334,509,553,19,347,516,617,159,761,649,521,970,759,451,428,418,129,73,131,801,650,340,948,714,988,94,634,563,210,877,218,368,377,871,355,168,953,378,870,938,52,239,692,973,230,250,785,620,557,420,92,924,808,246,30,104,147,187,317,61,922,227,886,134,409,361,43,846,38,897,826,120,234,536,479,396,769,87,208,784,770,627,718,962,739,550,403,669,123,254,255,765,817,666,5,602,645,497,198,322,138,264,931,410,190,556,435,757,830,329,493,868,222,53,881,499,84,690,890,572,258,392,534,573,894,616,283,211,103,946,490,658,88,390,324,885,178,417,549,697,678,717,698,445,439,927,393,89,713,249,754,715,121,950,333,263,825,22,90,835,746,152,513,127,146,580,911,354,142,41,395,613,474,209,139,331,212,109,597,276,618,86,711,903,893,320,312,976,411,677,811,399,141,495,600,729,980,67,794,483,665,734,284,194,113,987,170,524,721,440,977,796,919,726,575,934,888,389,155,101,942,372,821,818,704,561,455,422,873,244,967,453,69,501,332,326,768,226,725,642,163,971,408,306,966,458,118,920,742,174,8,963,696,380,404,574,339,367,793,268,42,466,74,904,28,260,167,449,430,350,822,434,504,743,859,517,85,447,64,947,15,232,485,460,344,376,36,538,369,636,571,21,319,154,898,802,117,583,391,200,241,951,901,150,577,251,388,185,307,646,172,177,79,179,1,7,97,47,841,682,752,779,96,119,603,551,161,11,202,660,216,803,584,986,313,448,365,969,559,716,848,204,519,591,221,247,63,523,469,578,836,135,566,431,529,596,0,872,820,849,510,892,629,918,70,912,252,733,310,905,816,542,266,184,468,908,271,844,265,955,814,643,173,196,535,133,537,280,231,783,281,797,810,308,964,39,489,862,775,668,891,253,930,55,342,360,287,916,929,813,46,83,108,507,695,345,13,707,699,34,373,476,587,805,106,60,661,75,664,303,279,673,330,745,941,467,933,749,488,615,637,112,301,481,370,205,328,14,26,356,945,162,148,502,755,910,548,567,701,812,625,981,66,188,62,375,165,274,362,655,23,4,900,387,197,961,764,630,457,589,500,791,20,978,965,913,867,171,240,624,880,353,760,710,477,771,875,778,158,876,543,273,153,379,724,35,381,233,671,80,700,181,2,762,68,503,647,687,544,323,786,774,595,979,413,795,827,601,800,487,532,16,243,569,371,612,984,17,25,932,201,180,149,865,484,807,621,423,242,586,429,296,65,438,100,309,275,720,842,552,843,747,863,357,140,359,683,866,819,527,338,191,884,954,960,857,902,923,854,494,136,24,291,684,688,858,789,631,614,454,401,708,56,93,750,199,110,592,482,736,219,837,3,738,581,740,305,459,98,506,705,325,685,32,983,758,299,425,952,496,777,245,385,834,207,412,526,175,166,907,302,545,223,314,662,562,712
11 |
--------------------------------------------------------------------------------
/data/abide_schaefer100/val.index:
--------------------------------------------------------------------------------
1 | 276,429,82,568,705,200,391,105,95,282,668,987,215,831,764,416,397,66,29,893,689,176,672,541,72,858,249,719,170,63,640,784,833,139,38,974,551,159,442,178,742,896,864,698,92,183,349,789,298,3,365,520,135,402,710,295,562,174,612,886,851,852,774,86,740,320,588,67,511,493,347,728,880,259,440,532,648,921,542,901,942,823,624,212,461,405,406,228,707,65,84,911,515,948,979,404,952,804,412
2 | 572,433,32,786,236,577,770,654,805,60,558,578,566,746,958,640,680,249,653,360,297,482,675,95,82,102,166,716,377,980,484,85,317,763,160,327,783,764,204,438,451,822,938,128,80,791,254,227,51,475,888,697,361,308,840,775,741,779,413,981,474,820,407,277,956,405,237,427,345,960,28,864,234,968,810,353,641,334,782,588,116,496,534,77,510,599,894,446,161,225,579,873,45,10,22,552,147,44,923
3 | 458,933,297,973,343,173,877,130,853,601,160,491,918,58,900,730,884,676,893,873,549,615,726,489,378,584,762,545,519,723,851,826,147,679,499,281,363,503,112,988,337,201,100,332,829,46,977,60,571,27,311,724,124,774,507,758,577,370,1,434,316,249,445,538,635,244,274,169,722,916,258,271,459,230,531,296,823,416,592,511,421,476,6,623,150,443,791,651,463,467,542,231,804,576,639,198,581,766,931
4 | 661,579,109,184,298,740,113,734,872,116,61,950,518,724,392,407,663,159,95,767,21,244,227,467,57,60,397,768,387,624,501,492,698,208,593,274,471,370,485,547,556,299,175,223,602,354,187,537,412,607,100,619,271,750,668,521,931,83,890,919,262,927,714,557,591,172,654,755,820,662,473,580,304,636,795,836,952,286,403,531,775,204,653,968,825,344,712,305,328,774,449,285,334,855,578,782,58,43,316
5 | 954,311,902,519,610,242,168,418,60,626,190,161,316,932,910,296,665,943,641,766,113,475,375,529,448,826,326,709,385,314,408,325,364,704,635,395,940,148,8,254,192,73,967,312,52,182,664,290,7,96,769,735,699,890,29,27,245,237,449,724,838,437,415,827,155,975,361,676,803,909,413,109,462,931,260,471,422,920,721,801,732,434,45,571,90,776,780,655,276,566,798,837,901,737,747,917,337,121,619
6 | 503,122,955,573,644,919,653,317,389,718,631,297,931,224,91,537,896,392,253,650,714,200,926,941,762,386,824,438,114,145,782,792,55,41,2,151,827,656,458,126,666,892,186,874,427,350,247,626,818,282,439,600,501,885,150,359,238,337,472,278,409,352,57,730,255,250,430,871,193,723,742,496,767,958,788,526,258,724,815,855,292,293,265,837,440,799,807,299,411,115,556,953,306,521,228,249,866,665,734
7 | 525,44,341,530,640,524,170,520,64,491,655,275,622,242,474,549,508,823,346,235,85,94,103,916,785,565,976,941,726,120,440,939,522,234,623,437,799,856,694,629,792,536,239,489,899,615,49,435,496,6,449,723,462,286,218,317,61,343,340,279,639,901,637,876,366,259,152,659,164,909,532,499,533,144,507,344,206,649,703,749,707,677,698,873,498,736,756,46,842,594,780,918,318,751,702,528,772,70,986
8 | 580,719,833,722,848,783,659,284,411,283,831,324,632,776,631,218,757,726,775,386,782,520,563,45,767,314,364,617,576,424,570,272,799,693,246,154,502,574,701,46,626,141,83,824,200,99,252,434,259,139,19,378,698,405,107,122,622,402,116,744,95,928,486,147,409,692,982,369,735,815,271,67,347,26,850,542,466,745,969,373,54,930,882,786,392,335,492,61,755,664,481,243,267,439,161,77,163,714,772
9 | 117,640,941,272,154,85,368,875,891,911,179,830,457,325,793,705,588,275,112,574,629,700,979,763,32,873,302,398,12,132,51,245,274,805,711,866,556,114,536,436,930,323,40,57,37,868,889,383,267,637,34,712,884,831,433,508,666,474,864,533,311,201,210,627,603,403,690,346,390,217,570,465,466,88,6,370,134,295,587,542,856,192,518,47,338,819,473,488,462,552,829,852,944,15,476,651,648,650,876
10 | 928,156,10,940,889,151,860,193,925,956,29,639,763,585,957,735,462,838,689,169,71,823,554,681,229,579,753,828,478,386,472,943,215,206,694,463,406,321,780,107,228,832,77,656,648,443,358,883,224,530,102,405,471,798,560,327,164,607,72,351,366,99,906,937,741,809,611,262,728,125,547,311,731,939,853,693,144,111,235,81,588,540,461,236,887,444,400,605,124,539,926,115,33,145,128,95,238,176,533
11 |
--------------------------------------------------------------------------------
/data/adni_schaefer100/test.index:
--------------------------------------------------------------------------------
1 | 4,26,53,57,71,76,77,79,110,120,126,131,137,146,151,162,183,188,199,204,227,232,239,248,249,258,269,285,287,300,301,314,317,324,329,359,382,383,385,400,420,427,431,447,452,469,471,472,473,474,475,484,498,512,518,544,550,557,565,566,571,584,594,600,604,606,609,611,612,634,638,651,656,688,695,696,708,713,714,732,740,756,826,834,841,848,854,858,878,896,897,909,912,919,923,926,937,953,968,980,992,1015,1023,1029,1039,1043,1045,1052,1054,1059,1071,1096,1107,1125,1138,1147,1161,1175,1178,1208,1210,1213,1231,1234,1239,1241,1262,1268,1282,1291,1299,1307,1319
2 | 6,10,16,17,25,27,42,49,69,80,89,101,122,138,144,157,169,174,180,210,212,242,253,266,280,288,306,318,322,332,339,341,350,356,361,365,368,374,377,384,387,392,404,407,411,426,428,448,451,457,504,553,568,569,572,574,579,580,610,622,632,642,645,647,652,660,673,676,685,718,738,739,744,749,754,762,764,765,769,771,778,809,819,836,840,846,849,865,873,884,893,934,950,963,979,990,991,995,997,1003,1006,1022,1027,1037,1055,1064,1075,1082,1088,1094,1136,1143,1148,1165,1166,1176,1182,1184,1188,1193,1211,1219,1233,1240,1243,1245,1264,1274,1285,1297,1298,1310,1323
3 | 0,28,35,36,59,62,78,88,104,113,116,141,164,178,179,181,185,187,190,202,203,208,213,226,230,236,246,251,259,264,277,284,298,299,309,321,323,334,347,357,380,395,417,435,442,446,459,465,483,485,486,495,509,516,517,520,523,527,533,540,543,552,561,613,627,633,641,654,659,667,686,687,690,693,701,717,720,723,731,736,774,813,828,832,837,850,853,860,879,890,891,942,948,949,951,952,959,966,972,976,986,1010,1016,1017,1030,1034,1056,1062,1063,1109,1126,1137,1140,1150,1160,1167,1174,1179,1185,1209,1217,1224,1230,1237,1242,1247,1257,1266,1277,1279,1313,1314,1324
4 | 3,8,14,33,34,37,41,46,61,66,68,86,90,96,98,112,121,133,139,145,153,159,161,166,184,191,193,195,198,207,215,216,220,223,235,238,244,255,261,267,278,282,305,307,327,333,344,360,379,390,402,405,415,423,424,430,436,458,481,482,493,555,567,575,583,589,601,602,625,628,630,631,648,658,664,668,674,712,724,761,767,783,820,823,842,844,868,876,888,898,902,915,925,928,935,956,969,977,982,983,994,1019,1020,1042,1051,1068,1072,1098,1101,1108,1115,1123,1129,1132,1135,1139,1144,1159,1170,1198,1199,1203,1212,1216,1238,1253,1260,1270,1278,1290,1306,1315,1318
5 | 11,22,23,39,45,65,75,83,115,117,123,134,148,152,155,171,172,194,219,225,229,237,247,270,272,276,295,302,330,342,343,355,363,370,372,386,388,419,421,440,450,460,462,467,476,494,503,513,532,536,542,548,549,576,590,607,614,615,620,626,629,637,666,675,678,684,689,694,698,706,715,722,729,782,784,785,788,789,793,801,804,807,825,830,843,847,852,882,892,901,903,913,924,929,933,938,941,962,988,998,1008,1026,1036,1069,1070,1078,1081,1083,1089,1099,1104,1105,1110,1112,1118,1119,1122,1158,1180,1189,1200,1214,1228,1229,1251,1254,1281,1283,1288,1292,1293,1300,1303
6 | 5,7,9,15,30,43,63,74,84,93,97,103,130,140,143,147,168,186,197,200,201,206,209,231,260,262,263,268,271,286,291,296,304,313,315,345,346,348,366,394,397,418,422,425,429,453,456,463,470,477,478,492,496,497,524,526,528,531,539,546,556,570,593,603,617,639,644,646,655,682,716,725,735,743,753,755,770,787,792,794,795,817,833,838,851,857,866,870,877,899,907,910,920,932,943,964,970,981,985,989,999,1014,1021,1024,1028,1048,1050,1058,1061,1079,1086,1087,1093,1111,1128,1151,1162,1169,1187,1191,1194,1221,1225,1226,1250,1259,1269,1272,1276,1312,1317,1320,1325
7 | 18,44,55,64,72,82,85,91,109,125,129,132,135,142,211,218,224,228,279,303,319,325,353,358,362,364,367,369,371,391,396,408,412,414,416,432,437,438,441,444,449,479,488,489,499,502,508,525,530,537,554,563,564,591,592,595,598,599,618,640,661,662,671,679,705,719,727,728,733,734,746,747,758,760,763,766,772,777,791,802,810,811,829,856,862,863,867,886,889,894,905,917,918,930,954,973,978,987,1001,1002,1004,1011,1033,1040,1041,1049,1067,1092,1095,1114,1117,1120,1121,1131,1146,1168,1172,1181,1186,1190,1196,1197,1218,1223,1235,1263,1265,1271,1287,1295,1308,1322
8 | 1,20,21,24,50,54,56,95,102,106,114,124,127,165,167,170,175,177,182,192,252,274,297,308,310,316,331,338,351,375,381,389,393,403,409,413,433,434,468,487,490,491,500,501,506,510,514,521,538,547,560,562,573,578,581,582,596,597,621,650,653,657,677,691,692,699,700,702,707,710,721,730,737,741,751,773,775,798,803,805,814,816,821,824,827,839,861,871,887,895,900,916,922,945,946,955,960,961,965,1005,1009,1038,1046,1047,1053,1057,1066,1076,1080,1085,1103,1106,1130,1134,1149,1153,1164,1171,1173,1205,1220,1227,1232,1236,1244,1256,1258,1273,1289,1294,1304,1321
9 | 2,12,13,19,32,38,47,48,58,60,67,70,87,92,94,99,107,108,111,118,119,128,156,160,163,176,189,196,205,221,222,250,254,256,265,273,275,283,289,290,311,312,320,326,336,337,352,373,399,406,439,443,445,454,455,480,511,535,541,551,585,587,616,619,623,635,636,669,672,681,683,697,711,742,748,750,752,799,800,806,812,818,822,835,845,869,872,880,881,904,908,911,914,939,940,957,958,974,975,984,993,1012,1013,1018,1031,1032,1060,1077,1084,1090,1091,1097,1100,1102,1133,1142,1156,1163,1177,1195,1202,1206,1246,1248,1249,1261,1267,1280,1286,1301,1302,1309
10 | 29,31,40,51,52,73,81,100,105,136,149,150,154,158,173,214,217,233,234,240,241,243,245,257,281,292,293,294,328,335,340,349,354,376,378,398,401,410,461,464,466,505,507,515,519,522,529,534,545,558,559,577,586,588,605,608,624,643,649,663,665,670,680,703,704,709,726,745,757,759,768,776,779,780,781,786,790,796,797,808,815,831,855,859,864,874,875,883,885,906,921,927,931,936,944,947,967,971,996,1000,1007,1025,1035,1044,1065,1073,1074,1113,1116,1124,1127,1141,1145,1152,1154,1155,1157,1183,1192,1201,1204,1207,1215,1222,1252,1255,1275,1284,1296,1305,1311,1316
11 |
--------------------------------------------------------------------------------
/data/adni_schaefer100/val.index:
--------------------------------------------------------------------------------
1 | 185,363,1184,96,41,560,1136,1038,450,798,831,526,783,358,1004,1060,1,458,1152,462,1130,334,1308,857,459,1137,435,642,985,1271,265,50,231,1085,357,588,1242,321,936,620,37,576,545,1114,791,628,1191,647,415,351,1112,422,1237,832,187,952,68,1013,1325,531,977,661,337,561,907,994,795,8,1276,81,729,75,820,284,553,1215,1324,1000,1105,786,944,139,707,416,448,666,801,625,1264,203,984,42,215,1221,36,1093,456,543,1082,404,372,333,864,156,847,1088,563,1098,343,461,5,965,122,46,504,541,636,1127,633,1240,583,860,736,895,1193,617,1116,485,654,1272,686,806,174
2 | 540,198,931,164,1284,625,443,410,607,1068,1278,554,51,1035,419,55,700,629,942,1212,880,501,145,902,65,251,525,1072,552,1047,602,1092,1098,706,0,261,968,265,760,192,889,551,282,628,120,1280,658,1286,479,18,305,476,168,166,244,870,281,338,608,816,290,160,94,630,1271,1218,1135,581,947,1031,1058,1206,336,598,998,959,740,231,1106,248,621,90,279,1063,536,441,264,444,992,564,1107,867,1239,1229,903,842,510,634,982,814,317,31,620,662,821,1260,1170,667,1121,1043,1249,1119,314,1008,1127,292,333,465,651,1149,1207,1317,163,961,233,717,692,825,773,289,548,458,28
3 | 1043,348,814,1301,549,47,682,911,849,759,283,455,646,276,461,383,1275,428,441,665,643,1308,326,1054,1013,1082,797,794,647,129,630,715,163,146,935,940,924,1012,513,1221,770,303,109,73,15,765,680,306,975,482,1295,883,535,1239,234,55,1212,24,507,733,1072,1248,992,225,266,590,1261,1125,856,1229,338,93,669,882,183,855,399,824,1299,274,473,664,871,695,186,521,337,1088,670,977,367,1166,1049,261,964,537,1068,111,571,1041,1057,608,118,27,907,124,210,25,466,340,1270,970,1120,679,1178,278,583,877,1218,1046,222,985,384,494,544,1243,43,1134,13,229,1118,800,677
4 | 446,380,62,326,1083,75,697,167,957,590,409,1166,1100,1150,1057,371,824,22,26,515,684,689,670,1094,5,952,938,1210,736,92,502,510,200,108,862,1130,678,407,1114,394,1237,1119,538,1005,1208,199,604,1136,468,475,136,277,342,651,421,368,30,1227,1294,187,453,351,989,751,709,1084,770,358,899,420,831,1126,944,961,406,1316,1149,288,470,147,694,789,160,1053,999,1034,826,971,1273,395,1324,556,489,1026,1200,1321,1250,896,335,447,905,263,457,189,598,854,181,385,1096,138,960,126,833,611,70,1295,486,1177,25,532,1224,239,706,334,325,699,354,1312,225,228,376,640,299
5 | 770,133,750,797,586,1175,257,1272,926,671,879,1182,859,1117,702,1259,74,915,486,781,602,708,582,136,710,682,1048,808,899,1252,46,977,364,1203,1168,1302,728,1017,530,777,246,374,744,93,1082,1240,528,358,1265,454,260,1262,507,1004,352,13,1244,53,481,173,1035,332,1237,105,153,904,524,1051,1155,673,1195,220,971,851,180,212,127,1170,565,968,499,241,506,131,531,221,1223,63,888,97,1095,983,1183,465,119,831,1023,251,707,283,73,672,946,721,141,597,12,949,410,817,588,187,990,886,900,1113,457,27,128,1021,3,264,116,1116,1315,742,1319,62,664,1030,57,662,56
6 | 228,846,1081,70,991,11,941,923,122,577,252,715,1042,265,285,1182,821,255,278,790,543,574,129,997,12,509,373,1155,740,342,1136,1122,335,327,559,1138,750,10,534,1026,1036,1181,1017,1248,1288,611,1121,884,764,280,956,774,840,1291,141,827,202,1106,924,371,1313,823,731,1205,612,876,1154,495,1260,2,1173,40,1301,550,1323,1120,219,938,266,843,155,214,338,117,159,945,454,488,1279,942,590,1045,29,1219,215,46,785,239,614,804,59,832,3,665,685,739,1148,230,580,728,1204,979,216,390,586,573,479,419,1241,353,328,729,473,678,23,672,620,1193,466,754,791,382,1238
7 | 446,633,435,118,25,1035,491,106,381,1103,994,1006,801,0,819,605,697,260,239,451,79,576,1032,575,1221,402,1231,1149,861,604,1109,670,372,909,799,518,515,343,596,677,1156,504,220,276,538,102,1253,1185,1277,1258,956,207,189,1207,832,225,468,904,1136,480,1303,631,921,219,1089,1116,624,731,1007,817,1272,695,1026,194,482,1198,828,821,666,988,494,985,465,300,664,1255,929,88,543,1125,481,327,307,943,534,178,1314,609,242,675,1106,654,648,1183,725,213,837,556,375,510,286,637,339,74,316,131,471,409,348,1208,442,1020,680,1113,186,868,1279,873,1229,1060,521,158,1052
8 | 416,341,1008,877,520,494,583,940,1313,1225,605,367,19,212,1167,387,1012,720,1229,92,543,198,845,1126,678,145,12,322,399,213,1190,84,1050,48,1072,306,1271,228,931,286,1209,993,941,493,754,701,412,1197,153,964,140,355,1265,227,155,899,117,574,185,624,422,862,134,687,1147,1306,298,195,304,275,143,1006,1048,637,1092,995,697,156,1122,142,144,1044,216,832,1094,666,255,1262,287,302,1013,1117,439,673,558,266,540,1084,857,806,796,944,400,786,1299,764,97,1160,425,458,248,1204,336,11,865,268,474,766,1185,735,896,319,130,31,566,1098,34,1248,1259,180,957,835,1131
9 | 899,319,393,371,991,408,733,1178,211,931,478,949,258,972,1157,934,181,360,388,218,1075,441,567,282,41,1251,1303,1167,600,106,304,558,51,479,379,306,988,1104,1135,18,774,281,322,871,503,701,267,970,699,834,540,836,1095,745,35,675,449,246,284,1068,487,76,342,398,85,1052,594,694,1203,1009,559,93,537,639,1034,1270,368,796,666,1228,1264,566,1217,1273,739,179,315,102,1001,831,1193,700,1155,542,55,16,707,561,802,476,435,877,719,1002,117,650,1132,162,1057,885,795,169,1105,768,1256,358,1219,1214,1071,409,1041,1292,573,1291,848,240,64,852,1058,301,7,765,1146
10 | 35,675,694,654,77,1126,579,103,20,508,363,470,222,71,1224,274,1120,517,1133,205,925,923,1178,1199,299,890,765,610,1175,296,364,1208,1235,1076,411,344,89,1230,611,1078,999,1280,1245,365,1031,677,1214,79,969,676,752,562,53,734,1085,485,289,408,600,499,817,727,840,196,812,805,380,1259,445,824,1179,323,1069,564,762,1098,412,63,854,625,527,963,1301,748,879,770,117,232,1292,754,829,871,1297,531,207,1068,437,1312,760,295,182,556,1020,422,761,617,528,1016,9,1105,355,57,740,932,458,246,902,1011,778,989,688,995,1299,432,104,216,914,359,208,1114,839,928,638
11 |
--------------------------------------------------------------------------------
/data/data.py:
--------------------------------------------------------------------------------
1 | """
2 | File to load dataset based on user control from main file
3 | """
4 | from data.BrainNet import BrainDataset
5 |
6 |
7 | def LoadData(DATASET_NAME, threshold=0, edge_ratio=0, node_feat_transform='original'):
8 | """
9 | This function is called in the main.py file
10 | returns:
11 | dataset object
12 | """
13 |
14 | return BrainDataset(DATASET_NAME, threshold=threshold, edge_ratio=edge_ratio, node_feat_transform=node_feat_transform)
15 |
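A minimal usage sketch of LoadData (illustrative only; the dataset name, threshold and split index below are assumptions chosen to match the repository defaults, and main.py remains the actual entry point):

from data.data import LoadData

# Assumes the corresponding .bin graph file has already been generated
# (see data/generate_data_from_mat.py). Split 0 of the 10 folds is used here.
dataset = LoadData('taowu_schaefer100', threshold=0.3, edge_ratio=0,
                   node_feat_transform='original')
trainset, valset, testset = dataset.train[0], dataset.val[0], dataset.test[0]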
--------------------------------------------------------------------------------
/data/generate_data_from_mat.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import pandas as pd
3 | import numpy as np
4 | import networkx as nx
5 | import os # To create directories
6 | import shutil
7 | import scipy.io
8 | import dgl
9 | import torch
10 | import glob
11 | import csv
12 | import re
13 | import json
14 | from tqdm import tqdm
15 | from dgl.data.utils import save_graphs
16 | from sklearn.model_selection import StratifiedKFold, train_test_split
17 |
18 |
19 | def _load_matrix_subject_with_files(files, remove_negative=False):
20 | subjects = []
21 | for file in files:
22 | mat = scipy.io.loadmat(file)
23 | mat = mat["data"]
24 | np.fill_diagonal(mat, 0)
25 | if remove_negative:
26 | mat[mat < 0] = 0
27 | subjects.append(mat)
28 | return np.array(subjects)
29 |
30 | def construct_dataset(data_name):
31 | feat_dir = 'data/to/connectivity_matrices_schaefer/' + data_name + '/'
32 |
33 | G_dataset = []
34 | Labels = []
35 | group2idx = {}
36 | paths = glob.glob(feat_dir + '/*/' + '*_features_timeseries.mat', recursive=True)
37 | feats = _load_matrix_subject_with_files(paths)
38 |
39 | print('Processing ' + data_name + '...')
40 |
41 | for j in tqdm(range(len(feats))):
42 | name = paths[j].split('/')[-1]
43 | group = re.findall(r'sub-([^\d]+)', name)[0]
44 | if group not in group2idx.keys():
45 | group2idx[group] = len(group2idx.keys())
46 | i = group2idx[group]
47 |
48 | G = nx.DiGraph(np.ones([feats[j].shape[0], feats[j].shape[0]]))
49 | graph_dgl = dgl.from_networkx(G)
50 |
51 | graph_dgl.ndata['N_features'] = torch.from_numpy(feats[j])
52 | # Include edge features
53 | weights = []
54 | for u, v, w in G.edges.data('weight'):
55 | # if w is not None:
56 | weights.append(w)
57 | graph_dgl.edata['E_features'] = torch.Tensor(weights)
58 |
59 | G_dataset.append(graph_dgl)
60 | Labels.append(i)
61 |
62 | print('Finished processing ' + data_name + '. ' + str(len(feats)) + ' subjects in total.')
63 |
64 | Labels = torch.LongTensor(Labels)
65 | graph_labels = {"glabel": Labels}
66 | if not os.path.exists('./bin_dataset/'):
67 | os.mkdir('./bin_dataset/')
68 | print(Labels.shape)
69 | print(len(G_dataset))
70 | save_graphs("./bin_dataset/" + data_name + ".bin", G_dataset, graph_labels)
71 |
72 |
73 | def move_files(data_name):
74 | feat_dir = '/data/jiaxing/brain/connectivity_matrices_schaefer/' + data_name + '/'
75 | paths = glob.glob(feat_dir + '/*/*', recursive=True)
76 | for path in paths:
77 | if path[-4:] == '.mat':
78 | if 'schashaefer' in path:
79 | new_path = re.sub('schashaefer', 'schaefer', path)
80 | os.rename(path, new_path)
81 | continue
82 | else:
83 | parcellation = data_name.split('_')[-1]
84 | os.rename(path, path + '_' + parcellation + '_correlation_matrix.mat')
85 |
86 |
87 | if __name__ == '__main__':
88 | error_name = []
89 | # file_name_list = os.listdir('./correlation_datasets/')
90 | file_name_list = ['adni_schaefer100']
91 |
92 | for data_name in file_name_list:
93 | move_files(data_name)
94 | # construct_dataset(data_name)
95 | # try:
96 | # construct_dataset(data_name)
97 | # except:
98 | # print('[ERROR]: ' + data_name)
99 | # error_name.append(data_name)
100 | print(error_name)
101 | print('Done!')
102 |
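A short sketch of reading back a dataset saved by construct_dataset (illustrative; adni_schaefer100 is just one of the dataset names used in this repository):

from dgl.data.utils import load_graphs

# load_graphs returns the list of DGLGraphs and the label dict saved above.
graphs, label_dict = load_graphs('./bin_dataset/adni_schaefer100.bin')
labels = label_dict['glabel']   # LongTensor of class indices, one per subject
print(len(graphs), labels.shape)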
--------------------------------------------------------------------------------
/data/neurocon_schaefer100/test.index:
--------------------------------------------------------------------------------
1 | 10,12,15,22,29
2 | 6,9,20,30
3 | 3,14,36,37
4 | 0,1,27,39
5 | 7,11,16,25
6 | 4,31,33,34
7 | 8,32,35,38
8 | 13,17,19,21
9 | 5,23,26,28
10 | 2,18,24,40
11 |
--------------------------------------------------------------------------------
/data/neurocon_schaefer100/train.index:
--------------------------------------------------------------------------------
1 | 0,31,23,28,30,26,2,9,40,18,20,19,35,33,27,25,36,4,38,7,24,17,8,21,13,5,3,1,6,39,16,14
2 | 35,21,10,2,14,4,3,39,23,29,11,28,26,13,38,27,15,18,32,12,16,5,1,24,34,36,33,31,25,8,40,37
3 | 2,8,12,20,11,27,0,32,29,33,31,7,4,18,24,23,40,28,17,5,1,19,30,9,35,38,10,25,22,21,26,34
4 | 20,15,19,26,34,10,16,11,6,23,38,37,7,35,9,18,5,30,32,31,24,8,13,33,29,28,12,40,14,25,36,2
5 | 13,26,2,30,40,21,19,38,32,31,35,14,17,24,15,36,29,1,28,33,34,0,3,6,27,5,37,12,23,39,9,10
6 | 26,18,1,3,10,25,32,36,9,20,28,17,14,27,19,15,22,39,24,8,13,29,23,30,5,6,37,7,35,12,11,40
7 | 4,13,3,23,12,15,10,30,17,19,27,0,11,20,29,22,31,34,9,14,18,24,36,16,21,33,7,26,37,40,6,5
8 | 22,25,30,29,3,11,16,5,28,14,34,36,2,0,32,39,7,8,15,4,10,38,18,1,37,24,26,27,9,35,40,23
9 | 2,10,0,32,15,11,36,27,17,1,13,21,37,20,34,39,38,4,7,22,16,9,3,25,18,8,33,12,40,30,35,19
10 | 38,12,16,8,14,32,39,29,31,10,19,1,20,36,22,4,6,26,13,30,37,34,25,28,0,11,9,23,3,17,33,15
11 |
--------------------------------------------------------------------------------
/data/neurocon_schaefer100/val.index:
--------------------------------------------------------------------------------
1 | 34,11,32,37
2 | 19,0,17,22,7
3 | 39,15,13,16,6
4 | 4,22,17,21,3
5 | 18,20,8,4,22
6 | 0,2,38,21,16
7 | 2,25,39,28,1
8 | 12,31,6,20,33
9 | 14,29,31,6,24
10 | 7,5,21,35,27
11 |
--------------------------------------------------------------------------------
/data/ppmi_schaefer100/test.index:
--------------------------------------------------------------------------------
1 | 0,14,22,30,37,45,53,63,68,76,90,93,112,128,135,155,157,164,171,181,199
2 | 3,5,41,44,55,65,66,77,81,84,94,105,126,134,139,141,148,159,160,169,205
3 | 8,12,34,35,47,48,58,78,80,98,100,101,113,132,133,136,144,149,173,182,206
4 | 4,6,31,52,56,59,60,71,72,82,119,120,123,130,165,168,172,176,180,193,202
5 | 11,13,20,29,32,46,57,83,106,111,115,118,124,129,142,150,162,166,177,191,203
6 | 10,17,25,38,39,43,74,79,85,88,95,96,102,151,154,156,158,178,185,200,204
7 | 1,15,19,33,61,67,73,104,107,110,117,121,122,138,140,145,146,153,188,201,207
8 | 9,18,27,36,50,54,64,70,75,86,89,108,125,137,152,161,167,170,183,195,196
9 | 7,21,24,26,40,42,49,92,97,103,114,116,131,147,184,186,187,190,194,198,208
10 | 2,16,23,28,51,62,69,87,91,99,109,127,143,163,174,175,179,189,192,197
11 |
--------------------------------------------------------------------------------
/data/ppmi_schaefer100/train.index:
--------------------------------------------------------------------------------
1 | 31,15,169,204,167,24,99,129,180,100,29,203,133,40,60,56,61,80,118,173,33,36,191,150,18,144,102,186,91,2,183,127,175,122,146,103,27,32,192,156,57,177,123,69,116,74,47,41,158,83,137,88,109,111,16,11,77,107,8,184,48,59,198,163,172,208,179,73,108,131,194,95,121,105,120,21,201,72,1,25,161,9,79,139,114,3,170,26,205,75,71,98,126,35,67,17,162,132,94,34,206,197,174,124,38,202,141,87,110,52,188,154,125,130,51,113,7,168,166,148,42,86,196,65,28,142,49,66,182,193,44,143,10,84,187,96,153,12,101,152,62,13,159,70,19,134,178,81,50,106,20,97,89,85,6,145,5,78,136,119,165,185,190,207,92,140,43
2 | 110,71,32,80,83,31,186,46,158,92,4,74,51,136,119,155,88,146,156,26,153,48,37,53,78,20,98,135,208,13,97,58,64,23,33,12,124,150,95,29,123,43,72,149,61,89,196,198,57,75,137,127,68,11,21,175,200,164,163,201,115,128,204,197,117,144,207,90,111,25,109,143,104,151,70,176,54,120,73,206,180,165,82,138,177,60,39,157,194,93,122,188,174,162,178,199,27,47,30,99,125,145,19,147,112,50,91,59,189,181,17,15,22,191,85,106,168,2,101,183,100,185,202,40,108,96,35,69,67,187,42,131,161,170,38,167,121,133,114,113,103,7,86,130,24,14,49,152,0,193,182,8,173,132,34,195,28,179,1,16,6,87,116,192,45,184,129
3 | 72,38,203,200,156,88,19,104,14,111,207,86,90,126,117,51,155,75,183,143,181,63,124,167,55,151,9,1,119,79,23,197,74,165,36,153,56,81,193,103,166,192,20,141,159,184,114,4,91,106,205,37,150,160,45,121,169,176,178,15,65,198,130,7,107,186,99,17,64,112,32,95,42,129,102,128,185,16,161,195,177,135,96,66,158,163,73,175,142,199,84,40,24,194,31,148,22,108,77,168,92,76,139,85,93,69,171,140,27,116,3,105,50,191,115,49,162,26,57,70,13,30,28,145,10,204,201,154,152,39,89,172,120,188,157,123,196,94,110,67,179,170,131,118,44,180,60,33,43,0,41,52,208,2,127,125,146,97,54,109,5,137,61,62,82,46,174
4 | 30,115,33,146,127,95,145,178,184,143,77,188,87,45,117,208,22,50,139,32,141,183,27,68,51,66,122,153,43,65,207,3,101,192,13,155,53,67,134,113,23,135,17,62,16,159,104,151,41,92,181,170,38,189,136,204,110,84,103,142,112,14,175,88,173,21,156,42,11,140,83,182,86,108,169,203,24,44,40,80,37,15,0,20,19,150,160,109,177,58,121,5,64,54,190,99,137,26,28,179,70,205,93,129,152,79,200,76,74,57,144,131,36,100,157,1,29,158,194,48,149,167,75,46,186,118,97,78,69,7,201,171,197,133,89,116,34,105,114,81,196,10,63,195,107,96,199,187,191,132,174,154,39,8,163,102,126,12,111,55,164,85,2,47,106,128,94
5 | 192,91,70,190,131,183,25,148,157,68,147,65,41,188,151,143,200,145,44,34,154,125,105,120,71,185,155,159,161,149,108,127,85,56,128,45,69,33,78,14,130,103,50,26,9,1,175,60,12,97,202,102,180,123,163,152,146,176,22,58,164,167,156,140,7,182,196,74,207,35,92,137,168,59,141,8,169,165,39,113,15,153,24,40,179,36,17,79,53,119,187,5,66,114,95,201,138,48,139,63,81,178,16,109,101,88,6,96,99,195,160,37,67,204,55,172,184,30,170,126,86,174,199,47,10,206,62,42,43,76,110,73,72,121,205,194,52,122,135,87,117,198,61,104,208,90,80,21,134,100,3,144,173,23,51,89,112,82,158,27,38,64,31,18,93,0,193
6 | 208,143,83,15,37,6,2,77,11,73,23,160,145,150,122,125,65,194,162,5,36,78,91,60,50,51,199,18,169,42,180,170,45,168,109,172,8,98,206,176,133,148,53,163,94,157,186,142,159,52,183,131,130,135,32,177,34,87,29,110,182,203,81,202,196,61,153,69,205,100,190,33,89,114,71,13,12,191,92,116,75,207,44,108,3,9,80,121,174,171,195,118,97,57,103,47,193,86,139,136,146,165,105,137,16,124,106,66,62,84,181,30,198,0,126,54,166,22,27,28,128,192,173,46,127,70,175,129,201,26,187,101,59,132,117,184,55,107,104,141,120,134,58,188,64,113,140,76,115,68,138,14,21,49,144,48,167,63,152,35,67,119,40,31,7,111,19
7 | 52,100,11,206,148,150,116,81,165,59,111,96,27,76,46,125,105,187,161,114,56,178,30,44,54,134,78,99,135,185,37,194,143,179,43,186,123,13,120,80,98,64,149,139,127,163,48,199,17,92,112,58,86,101,51,189,164,173,183,45,62,118,180,50,75,130,38,32,172,159,24,41,4,176,22,167,16,141,124,175,191,126,18,162,3,94,6,90,197,113,72,40,20,157,21,177,144,68,108,36,119,85,83,77,93,66,8,181,2,88,0,49,53,174,128,195,200,23,203,151,87,7,205,57,106,160,65,142,89,147,182,166,132,79,95,5,190,71,82,137,198,14,133,34,204,152,70,103,91,154,28,9,171,97,63,169,131,156,31,102,129,25,136,115,208,196,29
8 | 7,82,128,83,120,140,22,100,153,21,123,191,177,42,4,201,39,188,49,85,34,122,117,158,112,118,124,131,15,180,163,166,88,102,26,35,135,55,40,68,107,142,139,155,94,38,187,72,162,151,206,20,58,165,44,56,133,134,156,115,46,190,78,169,8,45,98,179,148,185,25,198,28,76,24,197,104,208,109,136,121,73,143,31,65,3,14,144,113,111,96,77,51,48,63,192,129,41,10,99,101,93,74,47,87,171,81,1,29,30,6,71,189,66,17,116,11,168,202,90,157,52,80,207,159,164,205,141,150,145,69,95,92,199,91,103,5,62,194,33,174,130,110,186,182,59,184,114,154,12,57,23,37,203,147,127,43,13,181,146,126,204,175,105,132,138,67
9 | 170,70,129,96,110,192,160,155,84,154,74,175,135,146,158,145,54,162,35,169,17,3,44,200,2,185,27,68,152,71,85,28,173,125,91,93,207,201,13,121,204,206,73,183,195,181,43,122,25,196,22,138,11,171,159,182,6,9,107,5,112,20,67,50,79,23,143,41,53,88,99,77,52,124,134,32,69,83,55,58,19,36,203,0,100,136,56,1,178,76,18,189,14,118,80,46,15,65,109,63,197,117,127,172,105,111,151,115,60,38,47,113,30,153,48,102,174,16,156,75,101,89,168,157,142,64,150,202,163,132,164,167,8,140,144,137,191,12,205,176,106,166,45,33,78,94,165,72,128,82,139,126,87,123,148,86,31,29,108,37,149,104,39,177,57,61,180
10 | 4,73,57,146,3,121,119,5,133,15,60,141,76,114,145,142,188,67,182,123,169,90,205,30,86,77,100,137,94,116,71,108,45,151,103,178,208,200,107,6,81,155,117,204,44,32,122,26,97,49,193,24,177,46,165,134,36,63,170,66,39,125,157,172,12,186,41,61,33,161,154,106,40,70,93,42,59,111,160,153,89,185,183,92,194,104,25,78,101,129,74,158,79,14,8,128,191,159,150,18,95,113,147,140,190,27,166,164,50,124,55,203,110,98,105,20,207,58,88,35,202,43,75,9,167,29,180,181,22,21,34,47,118,10,130,1,195,131,84,138,85,7,96,187,199,52,13,19,184,112,144,54,162,83,198,173,115,120,56,201,135,82,136,31,196,53,139,149
11 |
--------------------------------------------------------------------------------
/data/ppmi_schaefer100/val.index:
--------------------------------------------------------------------------------
1 | 82,55,138,149,117,176,46,39,147,4,58,151,115,200,104,160,189,195,23,64,54
2 | 102,107,62,76,203,142,154,166,63,18,140,172,10,9,79,190,118,56,36,171,52
3 | 202,29,18,11,147,122,190,164,25,189,71,187,87,83,134,21,59,6,138,68,53
4 | 90,138,166,206,198,185,161,73,91,98,61,125,49,25,148,147,9,162,18,35,124
5 | 171,28,19,77,4,107,54,132,181,116,2,186,136,197,98,189,133,84,49,94,75
6 | 112,56,197,4,179,24,155,93,147,99,20,82,1,189,161,41,72,90,149,164,123
7 | 184,10,26,74,170,168,202,60,109,12,193,192,35,158,47,69,42,84,39,55,155
8 | 16,176,149,53,172,200,0,19,97,79,173,60,160,119,2,106,32,61,84,178,193
9 | 90,161,133,120,179,199,62,4,59,34,10,98,130,119,66,95,141,51,81,188,193
10 | 0,126,148,48,65,11,17,206,152,156,72,37,64,168,102,38,132,171,68,80,176
11 |
--------------------------------------------------------------------------------
/data/taowu_schaefer100/test.index:
--------------------------------------------------------------------------------
1 | 0,16,26,31
2 | 1,9,23,27
3 | 10,14,22,24
4 | 15,18,25,30
5 | 3,4,21,38
6 | 6,11,34,36
7 | 17,19,28,35
8 | 5,8,33,37
9 | 2,7,20,39
10 | 12,13,29,32
11 |
--------------------------------------------------------------------------------
/data/taowu_schaefer100/train.index:
--------------------------------------------------------------------------------
1 | 3,19,20,34,36,5,32,17,38,21,13,2,9,22,28,29,24,12,14,1,35,11,18,37,39,30,27,4,8,15,23,7
2 | 12,5,8,35,11,21,3,28,22,29,30,15,18,39,17,26,38,0,14,16,2,31,34,4,36,7,10,33,37,20,32,13
3 | 19,20,13,17,33,30,29,34,26,37,6,2,32,28,35,8,15,18,12,38,36,21,27,16,25,1,5,0,11,39,7,4
4 | 5,14,4,39,17,16,29,9,26,31,36,21,11,38,22,3,10,34,2,37,12,7,19,8,23,35,27,0,20,33,1,32
5 | 13,1,37,15,22,35,39,14,0,26,11,8,36,23,20,17,2,34,33,24,16,10,30,5,27,6,31,18,25,28,7,9
6 | 18,38,13,37,10,26,3,16,19,28,15,7,27,5,32,1,17,39,29,9,25,4,0,20,14,33,12,21,35,24,30,22
7 | 22,14,37,11,8,13,20,25,29,34,39,7,30,12,3,36,4,16,9,33,1,23,2,6,31,21,5,27,18,15,38,26
8 | 4,32,6,38,1,13,36,12,24,0,9,25,35,31,22,20,28,30,19,2,23,15,10,3,18,7,34,16,26,11,39,21
9 | 0,12,21,36,17,14,9,11,27,28,25,4,5,29,1,19,31,8,26,22,34,23,32,10,15,24,16,18,35,30,37,6
10 | 7,9,26,10,27,23,6,1,36,31,17,4,3,28,22,25,5,34,0,33,39,16,11,24,20,30,2,15,18,35,19,37
11 |
--------------------------------------------------------------------------------
/data/taowu_schaefer100/val.index:
--------------------------------------------------------------------------------
1 | 33,6,10,25
2 | 25,6,19,24
3 | 23,31,3,9
4 | 24,13,28,6
5 | 29,19,12,32
6 | 8,31,2,23
7 | 32,24,10,0
8 | 29,14,27,17
9 | 3,38,13,33
10 | 14,38,8,21
11 |
--------------------------------------------------------------------------------
/figs/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AngusMonroe/ContrastPool/f3df45d5fc1573b3b6a05a5c214a9552c34a37d7/figs/framework.png
--------------------------------------------------------------------------------
/layers/attention_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 | import csv
5 | import numpy as np
6 |
7 |
8 | class EncoderLayer(nn.Module):
9 | def __init__(self, hid_dim, n_heads, pf_dim, dropout, device, feat_dim, learnable_q=False, pos_enc=None):
10 | super().__init__()
11 |
12 | self.learnable_q = learnable_q
13 | self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
14 | self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
15 | self.dropout = nn.Dropout(dropout)
16 | self.q = torch.nn.Parameter(torch.ones([pf_dim, feat_dim, hid_dim])) if self.learnable_q else None
17 |
18 | def forward(self, src, src_mask=None):
19 | if self.learnable_q:
20 | _src, _ = self.self_attention(self.q, src, src, src_mask)
21 | else:
22 | _src, _ = self.self_attention(src, src, src, src_mask)
23 | src = self.self_attn_layer_norm(src + self.dropout(_src))
24 | # src = [batch size, src len, hid dim]
25 | return src
26 |
27 |
28 | class MultiHeadAttentionLayer(nn.Module):
29 | def __init__(self, hid_dim, n_heads, dropout, device):
30 | super().__init__()
31 |
32 | self.hid_dim = hid_dim
33 | self.n_heads = n_heads
34 |
35 | assert hid_dim % n_heads == 0
36 |
37 | self.w_q = nn.Linear(hid_dim, hid_dim)
38 | self.w_k = nn.Linear(hid_dim, hid_dim)
39 | self.w_v = nn.Linear(hid_dim, hid_dim)
40 |
41 | self.fc = nn.Linear(hid_dim, hid_dim)
42 |
43 | self.dropout = nn.Dropout(dropout)
44 |
45 | self.scale = torch.sqrt(torch.FloatTensor([hid_dim // n_heads])).to(device)
46 |
47 | def forward(self, query, key, value, mask=None):
48 |
49 | bsz = query.shape[0]
50 |
51 | Q = self.w_q(query)
52 | K = self.w_k(key)
53 | V = self.w_v(value)
54 |
55 | Q = Q.view(bsz, -1, self.n_heads, self.hid_dim //
56 | self.n_heads).permute(0, 2, 1, 3)
57 | K = K.view(bsz, -1, self.n_heads, self.hid_dim //
58 | self.n_heads).permute(0, 2, 1, 3)
59 | V = V.view(bsz, -1, self.n_heads, self.hid_dim //
60 | self.n_heads).permute(0, 2, 1, 3)
61 |
62 | energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
63 |
64 | if mask is not None:
65 | energy = energy.masked_fill(mask == 0, -1e10)
66 |
67 | attention = self.dropout(torch.softmax(energy, dim=-1))
68 |
69 |
70 | x = torch.matmul(attention, V)
71 |
72 | x = x.permute(0, 2, 1, 3).contiguous()
73 |
74 | x = x.view(bsz, -1, self.n_heads * (self.hid_dim // self.n_heads))
75 |
76 | x = self.fc(x)
77 |
78 | return x, attention.squeeze()
79 |
80 |
81 | class PositionwiseFeedforwardLayer(nn.Module):
82 | def __init__(self, hid_dim, pf_dim, dropout):
83 | super().__init__()
84 |
85 | self.fc_1 = nn.Linear(hid_dim, pf_dim)
86 | self.fc_2 = nn.Linear(pf_dim, hid_dim)
87 |
88 | self.dropout = nn.Dropout(dropout)
89 |
90 | def forward(self, x):
91 | # x = [batch size, seq len, hid dim]
92 | x = self.dropout(torch.relu(self.fc_1(x)))
93 | # x = [batch size, seq len, pf dim]
94 | x = self.fc_2(x)
95 | # x = [batch size, seq len, hid dim]
96 |
97 | return x
98 |
99 |
100 | class PositionalEncoding(nn.Module):
101 | "Implement the PE function."
102 | def __init__(self, d_model, dropout, max_len=5000):
103 | super(PositionalEncoding, self).__init__()
104 | self.dropout = nn.Dropout(p=dropout)
105 |
106 | # Compute the positional encodings once in log space.
107 | pe = torch.zeros(max_len, d_model)
108 | position = torch.arange(0., max_len).unsqueeze(1)
109 | div_term = torch.exp(torch.arange(0., d_model, 2) *
110 | -(math.log(10000.0) / d_model))
111 | pe[:, 0::2] = torch.sin(position * div_term)
112 | pe[:, 1::2] = torch.cos(position * div_term)
113 | pe = pe.unsqueeze(0)
114 | self.register_buffer('pe', pe)
115 |
116 | def forward(self, x):
117 | x = x + self.pe[:, :x.size(1)]
118 | return self.dropout(x)
119 |
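A shape sketch for the attention layer above (illustrative values; 100 nodes stands in for the Schaefer-100 parcellation):

import torch

layer = MultiHeadAttentionLayer(hid_dim=64, n_heads=4, dropout=0.1, device='cpu')
x = torch.randn(2, 100, 64)      # [batch, nodes, hid_dim]
out, attn = layer(x, x, x)       # out: [2, 100, 64], attn: [2, 4, 100, 100]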
--------------------------------------------------------------------------------
/layers/contrastpool_layer.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | from torch.nn import functional as F
5 | import numpy as np
6 | from scipy.linalg import block_diag
7 | from torch.autograd import Function
8 | from layers.graphsage_layer import GraphSageLayer, DenseGraphSage
9 |
10 |
11 | def masked_softmax(matrix, mask, dim=-1, memory_efficient=True,
12 | mask_fill_value=-1e32):
13 | '''
14 | masked_softmax for dgl batch graph
15 | code snippet contributed by AllenNLP (https://github.com/allenai/allennlp)
16 | '''
17 | if mask is None:
18 | result = torch.nn.functional.softmax(matrix, dim=dim)
19 | else:
20 | mask = mask.float()
21 | while mask.dim() < matrix.dim():
22 | mask = mask.unsqueeze(1)
23 | if not memory_efficient:
24 | result = torch.nn.functional.softmax(matrix * mask, dim=dim)
25 | result = result * mask
26 | result = result / (result.sum(dim=dim, keepdim=True) + 1e-13)
27 | else:
28 | masked_matrix = matrix.masked_fill((1 - mask).byte(),
29 | mask_fill_value)
30 | result = torch.nn.functional.softmax(masked_matrix, dim=dim)
31 | return result
32 |
33 |
34 | class EntropyLoss(nn.Module):
35 | # Return Scalar
36 | # loss used in diffpool
37 | def forward(self, adj, anext, s_l):
38 | entropy = (torch.distributions.Categorical(
39 | probs=s_l).entropy()).sum(-1).mean(-1)
40 | assert not torch.isnan(entropy)
41 | return entropy
42 |
43 |
44 | class ContrastPoolLayer(nn.Module):
45 |
46 | def __init__(self, input_dim, assign_dim, output_feat_dim,
47 | activation, dropout, aggregator_type, link_pred, batch_norm, pool_assign='GraphSage', max_node_num=0):
48 | super().__init__()
49 | self.embedding_dim = input_dim
50 | self.assign_dim = assign_dim
51 | self.hidden_dim = output_feat_dim
52 | self.link_pred = link_pred
53 | self.feat_gc = GraphSageLayer(
54 | input_dim,
55 | output_feat_dim,
56 | activation,
57 | dropout,
58 | aggregator_type,
59 | batch_norm)
60 | if pool_assign == 'GraphSage':
61 | self.pool_gc = GraphSageLayer(
62 | input_dim,
63 | assign_dim,
64 | activation,
65 | dropout,
66 | aggregator_type,
67 | batch_norm)
68 | else:
69 | pass
70 | self.reg_loss = nn.ModuleList([])
71 | self.loss_log = {}
72 | self.reg_loss.append(EntropyLoss())
73 |
74 | # cs
75 | self.weight = nn.Parameter(torch.Tensor(max_node_num, assign_dim))
76 | self.bias = nn.Parameter(torch.Tensor(1, assign_dim))
77 | stdv = 1. / math.sqrt(self.weight.size(1))
78 | self.weight.data.uniform_(-stdv, stdv)
79 | self.bias.data.uniform_(-stdv, stdv)
80 |
81 | def forward(self, g, h, diff_h=None, adj=None, e=None):
82 | # h: [1000, 86]
83 | batch_size = len(g.batch_num_nodes())
84 | feat, e = self.feat_gc(g, h, e)
85 | device = feat.device
86 | # GCN
87 | if diff_h is not None:
88 | # print(diff_h.shape)
89 | # print(self.weight.shape)
90 | support = torch.matmul(diff_h, self.weight)
91 | if adj is not None:
92 | output = torch.matmul(adj.to(device), support)
93 | else:
94 | output = torch.matmul(g.adj().to_dense().clone().to(device), support.repeat(batch_size, 1))
95 | assign_tensor = output + self.bias
96 | else:
97 | assign_tensor, e = self.pool_gc(g, h, e)
98 | # assign_tensor: [2000, 50]
99 | # print(assign_tensor.shape)
100 |
101 | assign_tensor_masks = []
102 | assign_size = int(assign_tensor.size()[1]) if adj is not None else int(assign_tensor.size()[1] / batch_size)
103 | for g_n_nodes in g.batch_num_nodes():
104 | mask = torch.ones((g_n_nodes, assign_size))
105 | assign_tensor_masks.append(mask)
106 |
107 | """
108 | The first pooling layer is computed on batched graph.
109 | We first take the adjacency matrix of the batched graph, which is block-wise diagonal.
110 | We then compute the assignment matrix for the whole batch graph, which will also be block diagonal
111 | """
112 | mask = torch.FloatTensor(
113 | block_diag(
114 | *
115 | assign_tensor_masks)).to(
116 | device=device)
117 | if adj is not None:
118 | assign_tensor = assign_tensor.repeat(batch_size, batch_size)
119 |
120 | assign_tensor = masked_softmax(assign_tensor, mask, memory_efficient=False)
121 | h = torch.matmul(torch.t(assign_tensor), feat) # equation (3) of DIFFPOOL paper
122 | adj = g.adjacency_matrix(ctx=device)
123 |
124 | adj_new = torch.sparse.mm(adj, assign_tensor)
125 | adj_new = torch.mm(torch.t(assign_tensor), adj_new) # equation (4) of DIFFPOOL paper
126 |
127 | if self.link_pred:
128 | current_lp_loss = torch.norm(adj.to_dense() -
129 | torch.mm(assign_tensor, torch.t(assign_tensor))) / np.power(g.number_of_nodes(), 2)
130 | self.loss_log['LinkPredLoss'] = current_lp_loss
131 |
132 | for loss_layer in self.reg_loss:
133 | loss_name = str(type(loss_layer).__name__)
134 |
135 | self.loss_log[loss_name] = loss_layer(adj, adj_new, assign_tensor)
136 | return adj_new, h
137 |
138 |
139 | class LinkPredLoss(nn.Module):
140 | # loss used in diffpool
141 | def forward(self, adj, anext, s_l):
142 | link_pred_loss = (
143 | adj - s_l.matmul(s_l.transpose(-1, -2))).norm(dim=(1, 2))
144 | link_pred_loss = link_pred_loss / (adj.size(1) * adj.size(2))
145 | return link_pred_loss.mean()
146 |
147 |
148 | class DenseDiffPool(nn.Module):
149 | def __init__(self, nfeat, nnext, nhid, link_pred=False, entropy=True):
150 | super().__init__()
151 | self.link_pred = link_pred
152 | self.log = {}
153 | self.link_pred_layer = LinkPredLoss()
154 | self.embed = DenseGraphSage(nfeat, nhid, use_bn=True)
155 | self.assign = DiffPoolAssignment(nfeat, nnext)
156 | self.reg_loss = nn.ModuleList([])
157 | self.loss_log = {}
158 | if link_pred:
159 | self.reg_loss.append(LinkPredLoss())
160 | if entropy:
161 | self.reg_loss.append(EntropyLoss())
162 |
163 | def forward(self, x, adj, log=False):
164 | z_l = self.embed(x, adj)
165 | s_l = self.assign(x, adj)
166 | if log:
167 | self.log['s'] = s_l.cpu().numpy()
168 | xnext = torch.matmul(s_l.transpose(-1, -2), z_l)
169 | anext = (s_l.transpose(-1, -2)).matmul(adj).matmul(s_l)
170 |
171 | for loss_layer in self.reg_loss:
172 | loss_name = str(type(loss_layer).__name__)
173 | self.loss_log[loss_name] = loss_layer(adj, anext, s_l)
174 | if log:
175 | self.log['a'] = anext.cpu().numpy()
176 | return xnext, anext
177 |
178 |
179 | class DiffPoolAssignment(nn.Module):
180 | def __init__(self, nfeat, nnext):
181 | super().__init__()
182 | self.assign_mat = DenseGraphSage(nfeat, nnext, use_bn=True)
183 |
184 | def forward(self, x, adj, log=False):
185 | s_l_init = self.assign_mat(x, adj)
186 | s_l = F.softmax(s_l_init, dim=-1)
187 | return s_l
188 |
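A small recap of the pooling algebra referenced in the comments above (equations (3) and (4) of the DIFFPOOL paper), with illustrative sizes:

import torch

n, k, d = 100, 10, 64                           # nodes, clusters, feature dim
S = torch.softmax(torch.randn(n, k), dim=-1)    # soft assignment matrix
Z = torch.randn(n, d)                           # node embeddings
A = torch.rand(n, n)                            # adjacency matrix
h_next = S.t() @ Z                              # equation (3): [k, d]
adj_next = S.t() @ A @ S                        # equation (4): [k, k]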
--------------------------------------------------------------------------------
/layers/diffpool_layer.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | from torch.nn import functional as F
5 | import numpy as np
6 | from scipy.linalg import block_diag
7 |
8 | from torch.autograd import Function
9 |
10 | """
11 | DIFFPOOL:
12 | Z. Ying, J. You, C. Morris, X. Ren, W. Hamilton, and J. Leskovec,
13 | Hierarchical graph representation learning with differentiable pooling (NeurIPS 2018)
14 | https://arxiv.org/pdf/1806.08804.pdf
15 |
16 | ! code started from dgl diffpool examples dir
17 | """
18 |
19 | from layers.graphsage_layer import GraphSageLayer, DenseGraphSage
20 |
21 |
22 | def masked_softmax(matrix, mask, dim=-1, memory_efficient=True,
23 | mask_fill_value=-1e32):
24 | '''
25 | masked_softmax for dgl batch graph
26 | code snippet contributed by AllenNLP (https://github.com/allenai/allennlp)
27 | '''
28 | if mask is None:
29 | result = torch.nn.functional.softmax(matrix, dim=dim)
30 | else:
31 | mask = mask.float()
32 | while mask.dim() < matrix.dim():
33 | mask = mask.unsqueeze(1)
34 | if not memory_efficient:
35 | result = torch.nn.functional.softmax(matrix * mask, dim=dim)
36 | result = result * mask
37 | result = result / (result.sum(dim=dim, keepdim=True) + 1e-13)
38 | else:
39 | masked_matrix = matrix.masked_fill((1 - mask).byte(),
40 | mask_fill_value)
41 | result = torch.nn.functional.softmax(masked_matrix, dim=dim)
42 | return result
43 |
44 |
45 | class EntropyLoss(nn.Module):
46 | # Return Scalar
47 | # loss used in diffpool
48 | def forward(self, adj, anext, s_l):
49 | entropy = (torch.distributions.Categorical(
50 | probs=s_l).entropy()).sum(-1).mean(-1)
51 | assert not torch.isnan(entropy)
52 | return entropy
53 |
54 |
55 | class DiffPoolLayer(nn.Module):
56 |
57 | def __init__(self, input_dim, assign_dim, output_feat_dim,
58 | activation, dropout, aggregator_type, link_pred, batch_norm, pool_assign='GraphSage'):
59 | super().__init__()
60 | self.embedding_dim = input_dim
61 | self.assign_dim = assign_dim
62 | self.hidden_dim = output_feat_dim
63 | self.link_pred = link_pred
64 | self.feat_gc = GraphSageLayer(
65 | input_dim,
66 | output_feat_dim,
67 | activation,
68 | dropout,
69 | aggregator_type,
70 | batch_norm)
71 | if pool_assign == 'GraphSage':
72 | self.pool_gc = GraphSageLayer(
73 | input_dim,
74 | assign_dim,
75 | activation,
76 | dropout,
77 | aggregator_type,
78 | batch_norm)
79 | else:
80 | pass
81 | self.reg_loss = nn.ModuleList([])
82 | self.loss_log = {}
83 | self.reg_loss.append(EntropyLoss())
84 |
85 | def forward(self, g, h, e=None):
86 | # h: [1000, 86]
87 | feat, e = self.feat_gc(g, h, e)
88 | device = feat.device
89 | assign_tensor, e = self.pool_gc(g, h, e)
90 |
91 | assign_tensor_masks = []
92 | batch_size = len(g.batch_num_nodes())
93 | for g_n_nodes in g.batch_num_nodes():
94 | mask = torch.ones((g_n_nodes,
95 | int(assign_tensor.size()[1] / batch_size)))
96 | assign_tensor_masks.append(mask)
97 | """
98 | The first pooling layer is computed on batched graph.
99 | We first take the adjacency matrix of the batched graph, which is block-wise diagonal.
100 | We then compute the assignment matrix for the whole batch graph, which will also be block diagonal
101 | """
102 | mask = torch.FloatTensor(
103 | block_diag(
104 | *
105 | assign_tensor_masks)).to(
106 | device=device)
107 |
108 | assign_tensor = masked_softmax(assign_tensor, mask,
109 | memory_efficient=False)
110 | # print(assign_tensor.shape)
111 | h = torch.matmul(torch.t(assign_tensor), feat) # equation (3) of DIFFPOOL paper
112 | adj = g.adjacency_matrix(ctx=device)
113 |
114 | adj_new = torch.sparse.mm(adj, assign_tensor)
115 | adj_new = torch.mm(torch.t(assign_tensor), adj_new) # equation (4) of DIFFPOOL paper
116 |
117 | if self.link_pred:
118 | current_lp_loss = torch.norm(adj.to_dense() -
119 | torch.mm(assign_tensor, torch.t(assign_tensor))) / np.power(g.number_of_nodes(), 2)
120 | self.loss_log['LinkPredLoss'] = current_lp_loss
121 |
122 | for loss_layer in self.reg_loss:
123 | loss_name = str(type(loss_layer).__name__)
124 |
125 | self.loss_log[loss_name] = loss_layer(adj, adj_new, assign_tensor)
126 | return adj_new, h
127 |
128 |
129 | class LinkPredLoss(nn.Module):
130 | # loss used in diffpool
131 | def forward(self, adj, anext, s_l):
132 | link_pred_loss = (
133 | adj - s_l.matmul(s_l.transpose(-1, -2))).norm(dim=(1, 2))
134 | link_pred_loss = link_pred_loss / (adj.size(1) * adj.size(2))
135 | return link_pred_loss.mean()
136 |
137 |
138 | class DenseDiffPool(nn.Module):
139 | def __init__(self, nfeat, nnext, nhid, link_pred=False, entropy=True):
140 | super().__init__()
141 | self.link_pred = link_pred
142 | self.log = {}
143 | self.link_pred_layer = LinkPredLoss()
144 | self.embed = DenseGraphSage(nfeat, nhid, use_bn=True)
145 | self.assign = DiffPoolAssignment(nfeat, nnext)
146 | self.reg_loss = nn.ModuleList([])
147 | self.loss_log = {}
148 | if link_pred:
149 | self.reg_loss.append(LinkPredLoss())
150 | if entropy:
151 | self.reg_loss.append(EntropyLoss())
152 |
153 | def forward(self, x, adj, log=False):
154 | z_l = self.embed(x, adj)
155 | s_l = self.assign(x, adj)
156 | if log:
157 | self.log['s'] = s_l.cpu().numpy()
158 | xnext = torch.matmul(s_l.transpose(-1, -2), z_l)
159 | anext = (s_l.transpose(-1, -2)).matmul(adj).matmul(s_l)
160 |
161 | for loss_layer in self.reg_loss:
162 | loss_name = str(type(loss_layer).__name__)
163 | self.loss_log[loss_name] = loss_layer(adj, anext, s_l)
164 | if log:
165 | self.log['a'] = anext.cpu().numpy()
166 | return xnext, anext
167 |
168 |
169 | class DiffPoolAssignment(nn.Module):
170 | def __init__(self, nfeat, nnext):
171 | super().__init__()
172 | self.assign_mat = DenseGraphSage(nfeat, nnext, use_bn=True)
173 |
174 | def forward(self, x, adj, log=False):
175 | s_l_init = self.assign_mat(x, adj)
176 | s_l = F.softmax(s_l_init, dim=-1)
177 | return s_l
178 |
--------------------------------------------------------------------------------
/layers/graphsage_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import dgl.function as fn
6 | from dgl.nn.pytorch import SAGEConv
7 |
8 | """
9 | GraphSAGE:
10 | William L. Hamilton, Rex Ying, Jure Leskovec, Inductive Representation Learning on Large Graphs (NeurIPS 2017)
11 | https://cs.stanford.edu/people/jure/pubs/graphsage-nips17.pdf
12 | """
13 |
14 | class GraphSageLayer(nn.Module):
15 |
16 | def __init__(self, in_feats, out_feats, activation, dropout,
17 | aggregator_type, batch_norm, residual=False,
18 | bias=True, dgl_builtin=False, e_feat=False):
19 | super().__init__()
20 | self.in_channels = in_feats
21 | self.out_channels = out_feats
22 | self.aggregator_type = aggregator_type
23 | self.batch_norm = batch_norm
24 | self.residual = residual
25 | self.dgl_builtin = dgl_builtin
26 |
27 | if in_feats != out_feats:
28 | self.residual = False
29 |
30 | self.dropout = nn.Dropout(p=dropout)
31 |
32 | self.message_func = fn.copy_src(src='h', out='m') if not e_feat else fn.u_mul_e('h', 'e', 'm')
33 |
34 | if dgl_builtin == False:
35 | self.nodeapply = NodeApply(in_feats, out_feats, activation, dropout,
36 | bias=bias)
37 | if aggregator_type == "maxpool":
38 | self.aggregator = MaxPoolAggregator(in_feats, in_feats,
39 | activation, bias)
40 | elif aggregator_type == "lstm":
41 | self.aggregator = LSTMAggregator(in_feats, in_feats)
42 | else:
43 | self.aggregator = MeanAggregator()
44 | else:
45 | self.sageconv = SAGEConv(in_feats, out_feats, aggregator_type,
46 | dropout, activation=activation)
47 |
48 | if self.batch_norm:
49 | self.batchnorm_h = nn.BatchNorm1d(out_feats)
50 | self.batchnorm_e = nn.BatchNorm1d(out_feats)
51 |
52 | def forward(self, g, h, e=None):
53 | h_in = h # for residual connection
54 | # e_in = e
55 |
56 | if self.dgl_builtin == False:
57 | h = self.dropout(h)
58 | # e = self.dropout(e)
59 | g.ndata['h'] = h
60 | # g.edata['e'] = e
61 | g.update_all(fn.copy_src(src='h', out='m'),
62 | self.aggregator,
63 | self.nodeapply)
64 |
65 | h = g.ndata['h']
66 | else:
67 | h = self.sageconv(g, h)
68 |
69 | if self.batch_norm:
70 | h = self.batchnorm_h(h)
71 |
72 | if self.residual:
73 | h = h_in + h # residual connection
74 |
75 | return h, e
76 |
77 | def __repr__(self):
78 | return '{}(in_channels={}, out_channels={}, aggregator={}, residual={})'.format(self.__class__.__name__,
79 | self.in_channels,
80 | self.out_channels, self.aggregator_type, self.residual)
81 |
82 |
83 |
84 | """
85 | Aggregators for GraphSage
86 | """
87 | class Aggregator(nn.Module):
88 | """
89 | Base Aggregator class.
90 | """
91 |
92 | def __init__(self):
93 | super().__init__()
94 |
95 | def forward(self, node):
96 | neighbour = node.mailbox['m']
97 | c = self.aggre(neighbour)
98 | return {"c": c}
99 |
100 | def aggre(self, neighbour):
101 | # N x F
102 | raise NotImplementedError
103 |
104 |
105 | class MeanAggregator(Aggregator):
106 | """
107 | Mean Aggregator for graphsage
108 | """
109 |
110 | def __init__(self):
111 | super().__init__()
112 |
113 | def aggre(self, neighbour):
114 | mean_neighbour = torch.mean(neighbour, dim=1)
115 | return mean_neighbour
116 |
117 |
118 | class MaxPoolAggregator(Aggregator):
119 | """
120 | Maxpooling aggregator for graphsage
121 | """
122 |
123 | def __init__(self, in_feats, out_feats, activation, bias):
124 | super().__init__()
125 | self.linear = nn.Linear(in_feats, out_feats, bias=bias)
126 | self.activation = activation
127 |
128 | def aggre(self, neighbour):
129 | neighbour = self.linear(neighbour)
130 | if self.activation:
131 | neighbour = self.activation(neighbour)
132 | maxpool_neighbour = torch.max(neighbour, dim=1)[0]
133 | return maxpool_neighbour
134 |
135 |
136 | class LSTMAggregator(Aggregator):
137 | """
138 | LSTM aggregator for graphsage
139 | """
140 |
141 | def __init__(self, in_feats, hidden_feats):
142 | super().__init__()
143 | self.lstm = nn.LSTM(in_feats, hidden_feats, batch_first=True)
144 | self.hidden_dim = hidden_feats
145 | self.hidden = self.init_hidden()
146 |
147 | nn.init.xavier_uniform_(self.lstm.weight_ih_l0, gain=nn.init.calculate_gain('relu'))
148 | nn.init.xavier_uniform_(self.lstm.weight_hh_l0, gain=nn.init.calculate_gain('relu'))
149 |
150 | def init_hidden(self):
151 | """
152 | Defaults to initializing the hidden state with zeros
153 | """
154 | return (torch.zeros(1, 1, self.hidden_dim),
155 | torch.zeros(1, 1, self.hidden_dim))
156 |
157 | def aggre(self, neighbours):
158 | """
159 | aggregation function
160 | """
161 | # N X F
162 | rand_order = torch.randperm(neighbours.size()[1])
163 | neighbours = neighbours[:, rand_order, :]
164 |
165 | (lstm_out, self.hidden) = self.lstm(neighbours.view(neighbours.size()[0], neighbours.size()[1], -1))
166 | return lstm_out[:, -1, :]
167 |
168 | def forward(self, node):
169 | neighbour = node.mailbox['m']
170 | c = self.aggre(neighbour)
171 | return {"c": c}
172 |
173 |
174 | class NodeApply(nn.Module):
175 | """
176 | Implements the node_apply function in the DGL paradigm
177 | """
178 |
179 | def __init__(self, in_feats, out_feats, activation, dropout, bias=True):
180 | super().__init__()
181 | self.dropout = nn.Dropout(p=dropout)
182 | self.linear = nn.Linear(in_feats * 2, out_feats, bias)
183 | self.activation = activation
184 |
185 | def concat(self, h, aggre_result):
186 | bundle = torch.cat((h, aggre_result), 1)
187 | bundle = self.linear(bundle)
188 | return bundle
189 |
190 | def forward(self, node):
191 | h = node.data['h']
192 | c = node.data['c']
193 | bundle = self.concat(h, c)
194 | bundle = F.normalize(bundle, p=2, dim=1)
195 | if self.activation:
196 | bundle = self.activation(bundle)
197 | return {"h": bundle}
198 |
199 |
200 | class GraphSageLayerEdgeFeat(nn.Module):
201 |
202 | def __init__(self, in_feats, out_feats, activation, dropout,
203 | aggregator_type, batch_norm, residual=False,
204 | bias=True, dgl_builtin=False):
205 | super().__init__()
206 | self.in_channels = in_feats
207 | self.out_channels = out_feats
208 | self.batch_norm = batch_norm
209 | self.residual = residual
210 |
211 | if in_feats != out_feats:
212 | self.residual = False
213 |
214 | self.dropout = nn.Dropout(p=dropout)
215 |
216 | self.activation = activation
217 |
218 | self.A = nn.Linear(in_feats, out_feats, bias=bias)
219 | self.B = nn.Linear(in_feats, out_feats, bias=bias)
220 |
221 | self.nodeapply = NodeApply(in_feats, out_feats, activation, dropout, bias=bias)
222 |
223 | if self.batch_norm:
224 | self.batchnorm_h = nn.BatchNorm1d(out_feats)
225 |
226 | def message_func(self, edges):
227 | Ah_j = edges.src['Ah']
228 | e_ij = edges.src['Bh'] + edges.dst['Bh'] # e_ij = Bhi + Bhj
229 | edges.data['e'] = e_ij
230 | return {'Ah_j' : Ah_j, 'e_ij' : e_ij}
231 |
232 | def reduce_func(self, nodes):
233 | # Anisotropic MaxPool aggregation
234 |
235 | Ah_j = nodes.mailbox['Ah_j']
236 | e = nodes.mailbox['e_ij']
237 | sigma_ij = torch.sigmoid(e) # sigma_ij = sigmoid(e_ij)
238 |
239 | Ah_j = sigma_ij * Ah_j
240 | if self.activation:
241 | Ah_j = self.activation(Ah_j)
242 |
243 | c = torch.max(Ah_j, dim=1)[0]
244 | return {'c' : c}
245 |
246 | def forward(self, g, h):
247 | h_in = h # for residual connection
248 | h = self.dropout(h)
249 |
250 | g.ndata['h'] = h
251 | g.ndata['Ah'] = self.A(h)
252 | g.ndata['Bh'] = self.B(h)
253 | g.update_all(self.message_func,
254 | self.reduce_func,
255 | self.nodeapply)
256 | h = g.ndata['h']
257 |
258 | if self.batch_norm:
259 | h = self.batchnorm_h(h)
260 |
261 | if self.residual:
262 | h = h_in + h # residual connection
263 |
264 | return h
265 |
266 | def __repr__(self):
267 | return '{}(in_channels={}, out_channels={}, residual={})'.format(
268 | self.__class__.__name__,
269 | self.in_channels,
270 | self.out_channels,
271 | self.residual)
272 |
273 |
274 | ##############################################################
275 |
276 |
277 | class GraphSageLayerEdgeReprFeat(nn.Module):
278 |
279 | def __init__(self, in_feats, out_feats, activation, dropout,
280 | aggregator_type, batch_norm, residual=False,
281 | bias=True, dgl_builtin=False):
282 | super().__init__()
283 | self.in_channels = in_feats
284 | self.out_channels = out_feats
285 | self.batch_norm = batch_norm
286 | self.residual = residual
287 |
288 | if in_feats != out_feats:
289 | self.residual = False
290 |
291 | self.dropout = nn.Dropout(p=dropout)
292 |
293 | self.activation = activation
294 |
295 | self.A = nn.Linear(in_feats, out_feats, bias=bias)
296 | self.B = nn.Linear(in_feats, out_feats, bias=bias)
297 | self.C = nn.Linear(in_feats, out_feats, bias=bias)
298 |
299 | self.nodeapply = NodeApply(in_feats, out_feats, activation, dropout, bias=bias)
300 |
301 | if self.batch_norm:
302 | self.batchnorm_h = nn.BatchNorm1d(out_feats)
303 | self.batchnorm_e = nn.BatchNorm1d(out_feats)
304 |
305 | def message_func(self, edges):
306 | Ah_j = edges.src['Ah']
307 | e_ij = edges.data['Ce'] + edges.src['Bh'] + edges.dst['Bh'] # e_ij = Ce_ij + Bhi + Bhj
308 | edges.data['e'] = e_ij
309 | return {'Ah_j' : Ah_j, 'e_ij' : e_ij}
310 |
311 | def reduce_func(self, nodes):
312 | # Anisotropic MaxPool aggregation
313 |
314 | Ah_j = nodes.mailbox['Ah_j']
315 | e = nodes.mailbox['e_ij']
316 | sigma_ij = torch.sigmoid(e) # sigma_ij = sigmoid(e_ij)
317 |
318 | Ah_j = sigma_ij * Ah_j
319 | if self.activation:
320 | Ah_j = self.activation(Ah_j)
321 |
322 | c = torch.max(Ah_j, dim=1)[0]
323 | return {'c' : c}
324 |
325 | def forward(self, g, h, e):
326 | h_in = h # for residual connection
327 | e_in = e
328 | h = self.dropout(h)
329 |
330 | g.ndata['h'] = h
331 | g.ndata['Ah'] = self.A(h)
332 | g.ndata['Bh'] = self.B(h)
333 | g.edata['e'] = e
334 | g.edata['Ce'] = self.C(e)
335 | g.update_all(self.message_func,
336 | self.reduce_func,
337 | self.nodeapply)
338 | h = g.ndata['h']
339 | e = g.edata['e']
340 |
341 | if self.activation:
342 | e = self.activation(e) # non-linear activation
343 |
344 | if self.batch_norm:
345 | h = self.batchnorm_h(h)
346 | e = self.batchnorm_e(e)
347 |
348 | if self.residual:
349 | h = h_in + h # residual connection
350 | e = e_in + e # residual connection
351 |
352 | return h, e
353 |
354 | def __repr__(self):
355 | return '{}(in_channels={}, out_channels={}, residual={})'.format(
356 | self.__class__.__name__,
357 | self.in_channels,
358 | self.out_channels,
359 | self.residual)
360 |
361 |
362 | class DenseGraphSage(nn.Module):
363 | def __init__(self, infeat, outfeat, residual=False, use_bn=True,
364 | mean=False, add_self=False):
365 | super().__init__()
366 | self.add_self = add_self
367 | self.use_bn = use_bn
368 | self.mean = mean
369 | self.residual = residual
370 |
371 | if infeat != outfeat:
372 | self.residual = False
373 |
374 | self.W = nn.Linear(infeat, outfeat, bias=True)
375 |
376 | nn.init.xavier_uniform_(
377 | self.W.weight,
378 | gain=nn.init.calculate_gain('relu'))
379 |
380 | def forward(self, x, adj):
381 | h_in = x # for residual connection
382 |
383 | if self.use_bn and not hasattr(self, 'bn'):
384 | self.bn = nn.BatchNorm1d(adj.size(1)).to(adj.device)
385 |
386 | if self.add_self:
387 | adj = adj + torch.eye(adj.size(0)).to(adj.device)
388 |
389 | if self.mean:
390 | adj = adj / adj.sum(1, keepdim=True)
391 |
392 | h_k_N = torch.matmul(adj, x)
393 | h_k = self.W(h_k_N)
394 | h_k = F.normalize(h_k, dim=2, p=2)
395 | h_k = F.relu(h_k)
396 |
397 | if self.residual:
398 | h_k = h_in + h_k # residual connection
399 |
400 | if self.use_bn:
401 | h_k = self.bn(h_k)
402 | return h_k
403 |
404 | def __repr__(self):
405 | if self.use_bn:
406 | return 'BN' + super(DenseGraphSage, self).__repr__()
407 | else:
408 | return super(DenseGraphSage, self).__repr__()
409 |
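A minimal sketch of the non-builtin GraphSAGE node update implemented by NodeApply above (illustrative tensors only):

import torch
import torch.nn.functional as F

h_self = torch.randn(5, 16)      # node features
h_neigh = torch.randn(5, 16)     # aggregated neighbour features (e.g. MeanAggregator)
W = torch.nn.Linear(32, 16)
# concat -> linear -> L2 normalisation, as in NodeApply.concat/forward
h_new = F.normalize(W(torch.cat([h_self, h_neigh], dim=1)), p=2, dim=1)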
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import socket
4 | import time
5 | import random
6 | import glob
7 | import argparse, json
8 | import dgl
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | import torch.optim as optim
13 | from torch.utils.data import DataLoader
14 | from tensorboardX import SummaryWriter
15 | from tqdm import tqdm
16 | from nets.load_net import gnn_model # import GNNs
17 | from data.data import LoadData # import dataset
18 |
19 |
20 | def gpu_setup(use_gpu, gpu_id):
21 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
22 | os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
23 |
24 | if torch.cuda.is_available() and use_gpu:
25 | print('cuda available with GPU:', torch.cuda.get_device_name(0))
26 | device = torch.device("cuda")
27 | else:
28 | print('cuda not available')
29 | device = torch.device("cpu")
30 | return device
31 |
32 |
33 | def view_model_param(MODEL_NAME, net_params):
34 | model = gnn_model(MODEL_NAME, net_params)
35 | total_param = 0
36 | print("MODEL DETAILS:\n")
37 | for param in model.parameters():
38 | total_param += np.prod(list(param.data.size()))
39 | print('MODEL/Total parameters:', MODEL_NAME, total_param)
40 | return total_param
41 |
42 |
43 | def train_val_pipeline(MODEL_NAME, DATASET_NAME, params, net_params, dirs):
44 | avg_test_acc = []
45 | avg_train_acc = []
46 | avg_convergence_epochs = []
47 |
48 | t0 = time.time()
49 | per_epoch_time = []
50 |
51 | dataset = LoadData(DATASET_NAME, threshold=params['threshold'], node_feat_transform=params['node_feat_transform'])
52 |
53 | trainset, valset, testset = dataset.train, dataset.val, dataset.test
54 |
55 | root_log_dir, root_ckpt_dir, write_file_name, write_config_file = dirs
56 | device = net_params['device']
57 |
58 | # Write the network and optimization hyper-parameters in folder config/
59 | with open(write_config_file + '.txt', 'w') as f:
60 | f.write("""Dataset: {},\nModel: {}\n\nparams={}\n\nnet_params={}\n\n\nTotal Parameters: {}\n\n""".format(DATASET_NAME, MODEL_NAME, params, net_params, net_params['total_param']))
61 |
62 | # At any point you can hit Ctrl + C to break out of training early.
63 | try:
64 | for split_number in range(10):
65 | t0_split = time.time()
66 | log_dir = os.path.join(root_log_dir, "RUN_" + str(split_number))
67 | writer = SummaryWriter(log_dir=log_dir)
68 |
69 | # setting seeds
70 | random.seed(params['seed'])
71 | np.random.seed(params['seed'])
72 | torch.manual_seed(params['seed'])
73 | if device.type == 'cuda':
74 | torch.cuda.manual_seed(params['seed'])
75 |
76 | print("RUN NUMBER: ", split_number)
77 | trainset, valset, testset = dataset.train[split_number], dataset.val[split_number], dataset.test[split_number]
78 | print("Training Graphs: ", len(trainset))
79 | print("Validation Graphs: ", len(valset))
80 | print("Test Graphs: ", len(testset))
81 | print("Number of Classes: ", net_params['n_classes'])
82 |
83 | model = gnn_model(MODEL_NAME, net_params)
84 | model = model.to(device)
85 | if net_params['contrast'] and MODEL_NAME in ['ContrastPool']:
86 | model.cal_contrast(trainset, device)
87 | optimizer = optim.Adam(model.parameters(), lr=params['init_lr'], weight_decay=params['weight_decay'])
88 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
89 | factor=params['lr_reduce_factor'],
90 | patience=params['lr_schedule_patience'],
91 | verbose=True)
92 |
93 | epoch_train_losses, epoch_val_losses = [], []
94 | epoch_train_accs, epoch_val_accs = [], []
95 |
96 | # batching exception for Diffpool
97 | drop_last = True if MODEL_NAME in ['DiffPool', 'ContrastPool'] else False
98 |
99 | from train_TUs_graph_classification import train_epoch_sparse as train_epoch, evaluate_network_sparse as evaluate_network
100 |
101 | train_loader = DataLoader(trainset, batch_size=params['batch_size'], shuffle=True, drop_last=drop_last, collate_fn=dataset.collate)
102 | val_loader = DataLoader(valset, batch_size=params['batch_size'], shuffle=False, drop_last=drop_last, collate_fn=dataset.collate)
103 | test_loader = DataLoader(testset, batch_size=params['batch_size'], shuffle=False, drop_last=drop_last, collate_fn=dataset.collate)
104 |
105 | with tqdm(range(params['epochs'])) as t:
106 | for epoch in t:
107 |
108 | t.set_description('Epoch %d' % epoch)
109 |
110 | start = time.time()
111 |
112 | epoch_train_loss, epoch_train_acc, optimizer = train_epoch(model, optimizer, device, train_loader, epoch)
113 |
114 | epoch_val_loss, epoch_val_acc = evaluate_network(model, device, val_loader, epoch)
115 | _, epoch_test_acc = evaluate_network(model, device, test_loader, epoch)
116 |
117 | epoch_train_losses.append(epoch_train_loss)
118 | epoch_val_losses.append(epoch_val_loss)
119 | epoch_train_accs.append(epoch_train_acc)
120 | epoch_val_accs.append(epoch_val_acc)
121 |
122 | writer.add_scalar('train/_loss', epoch_train_loss, epoch)
123 | writer.add_scalar('val/_loss', epoch_val_loss, epoch)
124 | writer.add_scalar('train/_acc', epoch_train_acc, epoch)
125 | writer.add_scalar('val/_acc', epoch_val_acc, epoch)
126 | writer.add_scalar('test/_acc', epoch_test_acc, epoch)
127 | writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], epoch)
128 |
129 | _, epoch_test_acc = evaluate_network(model, device, test_loader, epoch)
130 | t.set_postfix(time=time.time()-start, lr=optimizer.param_groups[0]['lr'],
131 | train_loss=epoch_train_loss, val_loss=epoch_val_loss,
132 | train_acc=epoch_train_acc, val_acc=epoch_val_acc,
133 | test_acc=epoch_test_acc)
134 |
135 | per_epoch_time.append(time.time()-start)
136 |
137 | # Saving checkpoint
138 | ckpt_dir = os.path.join(root_ckpt_dir, "RUN_" + str(split_number))
139 | if not os.path.exists(ckpt_dir):
140 | os.makedirs(ckpt_dir)
141 |
142 | torch.save(model.state_dict(), '{}.pkl'.format(ckpt_dir + "/epoch_" + str(epoch)))
143 | if MODEL_NAME in ['ContrastPool']:
144 | torch.save(model.ad_adj, '{}.pkl'.format(log_dir + "/epoch_" + str(epoch)))
145 |
146 | adj_files = glob.glob(log_dir + '/*.pkl')
147 | for adj_file in adj_files:
148 | epoch_nb = adj_file.split('_')[-1]
149 | epoch_nb = int(epoch_nb.split('.')[0])
150 | if epoch_nb < epoch - 1:
151 | os.remove(adj_file)
152 |
153 | files = glob.glob(ckpt_dir + '/*.pkl')
154 | for file in files:
155 | epoch_nb = file.split('_')[-1]
156 | epoch_nb = int(epoch_nb.split('.')[0])
157 | if epoch_nb < epoch-1:
158 | os.remove(file)
159 |
160 | scheduler.step(epoch_val_loss)
161 |
162 | if optimizer.param_groups[0]['lr'] < params['min_lr']:
163 | print("\n!! LR EQUAL TO MIN LR SET.")
164 | break
165 |
166 | # Stop training after params['max_time'] hours
167 | if time.time()-t0_split > params['max_time']*3600/10: # Dividing max_time by 10, since there are 10 runs in TUs
168 | print('-' * 89)
169 | print("Max_time for one train-val-test split experiment elapsed {:.3f} hours, so stopping".format(params['max_time']/10))
170 | break
171 |
172 | _, test_acc = evaluate_network(model, device, test_loader, epoch)
173 | _, train_acc = evaluate_network(model, device, train_loader, epoch)
174 | avg_test_acc.append(test_acc)
175 | avg_train_acc.append(train_acc)
176 | avg_convergence_epochs.append(epoch)
177 |
178 | print("Test Accuracy [LAST EPOCH]: {:.4f}".format(test_acc))
179 | print("Train Accuracy [LAST EPOCH]: {:.4f}".format(train_acc))
180 | print("Convergence Time (Epochs): {:.4f}".format(epoch))
181 |
182 | except KeyboardInterrupt:
183 | print('-' * 89)
184 | print('Exiting from training early because of KeyboardInterrupt')
185 |
186 |
187 | print("TOTAL TIME TAKEN: {:.4f}hrs".format((time.time()-t0)/3600))
188 | print("AVG TIME PER EPOCH: {:.4f}s".format(np.mean(per_epoch_time)))
189 | print("AVG CONVERGENCE Time (Epochs): {:.4f}".format(np.mean(np.array(avg_convergence_epochs))))
190 | # Final test accuracy value averaged over 10-fold
191 | print("""\n\n\nFINAL RESULTS\n\nTEST ACCURACY averaged: {:.4f} with s.d. {:.4f}""".format(np.mean(np.array(avg_test_acc))*100, np.std(avg_test_acc)*100))
192 | print("\nAll splits Test Accuracies:\n", avg_test_acc)
193 | print("""\n\n\nFINAL RESULTS\n\nTRAIN ACCURACY averaged: {:.4f} with s.d. {:.4f}""".format(np.mean(np.array(avg_train_acc))*100, np.std(avg_train_acc)*100))
194 | print("\nAll splits Train Accuracies:\n", avg_train_acc)
195 |
196 | writer.close()
197 |
198 | """
199 | Write the results in out/results folder
200 | """
201 | with open(write_file_name + '.txt', 'w') as f:
202 | f.write("""Dataset: {},\nModel: {}\n\nparams={}\n\nnet_params={}\n\n{}\n\nTotal Parameters: {}\n\n edge_num: {}\n\n
203 | FINAL RESULTS\nTEST ACCURACY averaged: {:.4f} with s.d. {:.4f}\nTRAIN ACCURACY averaged: {:.4f} with s.d. {:.4f}\n\n
204 | Average Convergence Time (Epochs): {:.4f} with s.d. {:.4f}\nTotal Time Taken: {:.4f} hrs\nAverage Time Per Epoch: {:.4f} s\n\n\nAll Splits Test Accuracies: {}""" \
205 | .format(DATASET_NAME, MODEL_NAME, params, net_params, model, net_params['total_param'], len(trainset[0][0].edata['feat']),
206 | np.mean(np.array(avg_test_acc))*100, np.std(avg_test_acc)*100,
207 | np.mean(np.array(avg_train_acc))*100, np.std(avg_train_acc)*100,
208 | np.mean(avg_convergence_epochs), np.std(avg_convergence_epochs),
209 | (time.time()-t0)/3600, np.mean(per_epoch_time), avg_test_acc))
210 |
211 |
212 | def main():
213 | """
214 | USER CONTROLS
215 | """
216 | parser = argparse.ArgumentParser()
217 | parser.add_argument('--config', help="Please give a config.json file with training/model/data/param details")
218 | parser.add_argument('--gpu_id', help="Please give a value for gpu id")
219 | parser.add_argument('--model', help="Please give a value for model name")
220 | parser.add_argument('--dataset', help="Please give a value for dataset name")
221 | parser.add_argument('--out_dir', help="Please give a value for out_dir")
222 | parser.add_argument('--seed', help="Please give a value for seed")
223 | parser.add_argument('--epochs', help="Please give a value for epochs")
224 | parser.add_argument('--batch_size', help="Please give a value for batch_size")
225 | parser.add_argument('--init_lr', help="Please give a value for init_lr")
226 | parser.add_argument('--lr_reduce_factor', help="Please give a value for lr_reduce_factor")
227 | parser.add_argument('--lr_schedule_patience', help="Please give a value for lr_schedule_patience")
228 | parser.add_argument('--min_lr', help="Please give a value for min_lr")
229 | parser.add_argument('--weight_decay', help="Please give a value for weight_decay")
230 | parser.add_argument('--print_epoch_interval', help="Please give a value for print_epoch_interval")
231 | parser.add_argument('--L', help="Please give a value for L")
232 | parser.add_argument('--hidden_dim', help="Please give a value for hidden_dim")
233 | parser.add_argument('--out_dim', help="Please give a value for out_dim")
234 | parser.add_argument('--residual', help="Please give a value for residual")
235 | parser.add_argument('--edge_feat', help="Please give a value for edge_feat")
236 | parser.add_argument('--readout', help="Please give a value for readout")
237 | parser.add_argument('--kernel', help="Please give a value for kernel")
238 | parser.add_argument('--n_heads', help="Please give a value for n_heads")
239 | parser.add_argument('--gated', help="Please give a value for gated")
240 | parser.add_argument('--in_feat_dropout', help="Please give a value for in_feat_dropout")
241 | parser.add_argument('--dropout', help="Please give a value for dropout")
242 | parser.add_argument('--layer_norm', help="Please give a value for layer_norm")
243 | parser.add_argument('--batch_norm', help="Please give a value for batch_norm")
244 | parser.add_argument('--sage_aggregator', help="Please give a value for sage_aggregator")
245 | parser.add_argument('--data_mode', help="Please give a value for data_mode")
246 | parser.add_argument('--num_pool', help="Please give a value for num_pool")
247 | parser.add_argument('--gnn_per_block', help="Please give a value for gnn_per_block")
248 | parser.add_argument('--embedding_dim', help="Please give a value for embedding_dim")
249 | parser.add_argument('--pool_ratio', help="Please give a value for pool_ratio")
250 | parser.add_argument('--linkpred', help="Please give a value for linkpred")
251 | parser.add_argument('--cat', help="Please give a value for cat")
252 | parser.add_argument('--self_loop', help="Please give a value for self_loop")
253 | parser.add_argument('--max_time', help="Please give a value for max_time")
254 |     parser.add_argument('--threshold', type=float, help="Please give a threshold for dropping edges", default=0.3)
255 |     parser.add_argument('--edge_ratio', type=float, help="Please give a ratio of edges to drop", default=0)
256 | parser.add_argument('--node_feat_transform', help="Please give a value for node feature transform", default='original')
257 | parser.add_argument('--contrast', default=False, action='store_true')
258 | parser.add_argument('--pooling', type=float, default=0.5)
259 | parser.add_argument('--lambda1', type=float, default=0.001)
260 | parser.add_argument('--learnable_q', default=False, action='store_true')
261 | args = parser.parse_args()
262 | with open(args.config) as f:
263 | config = json.load(f)
264 |
265 | # device
266 | if args.gpu_id is not None and config['gpu']['use']:
267 | config['gpu']['id'] = int(args.gpu_id)
268 | config['gpu']['use'] = True
269 | device = gpu_setup(config['gpu']['use'], config['gpu']['id'])
270 | else:
271 | config['gpu']['id'] = 0
272 | device = torch.device('cpu')
273 | # model, dataset, out_dir
274 | if args.model is not None:
275 | MODEL_NAME = args.model
276 | else:
277 | MODEL_NAME = config['model']
278 | if args.dataset is not None:
279 | DATASET_NAME = args.dataset
280 | else:
281 | DATASET_NAME = config['dataset']
282 | dataset = LoadData(DATASET_NAME, args.threshold, args.edge_ratio, args.node_feat_transform)
283 | if args.out_dir is not None:
284 | out_dir = args.out_dir
285 | else:
286 | out_dir = config['out_dir']
287 | # parameters
288 | params = config['params']
289 | if args.seed is not None:
290 | params['seed'] = int(args.seed)
291 | if args.epochs is not None:
292 | params['epochs'] = int(args.epochs)
293 | if args.batch_size is not None:
294 | params['batch_size'] = int(args.batch_size)
295 | if args.init_lr is not None:
296 | params['init_lr'] = float(args.init_lr)
297 | if args.lr_reduce_factor is not None:
298 | params['lr_reduce_factor'] = float(args.lr_reduce_factor)
299 | if args.lr_schedule_patience is not None:
300 | params['lr_schedule_patience'] = int(args.lr_schedule_patience)
301 | if args.min_lr is not None:
302 | params['min_lr'] = float(args.min_lr)
303 | if args.weight_decay is not None:
304 | params['weight_decay'] = float(args.weight_decay)
305 | if args.print_epoch_interval is not None:
306 | params['print_epoch_interval'] = int(args.print_epoch_interval)
307 | if args.max_time is not None:
308 | params['max_time'] = float(args.max_time)
309 | if args.threshold is not None:
310 | params['threshold'] = float(args.threshold)
311 | if args.edge_ratio is not None:
312 | params['edge_ratio'] = float(args.edge_ratio)
313 | if args.node_feat_transform is not None:
314 | params['node_feat_transform'] = args.node_feat_transform
315 | # network parameters
316 | net_params = config['net_params']
317 |     if hasattr(dataset, 'node_num'):
318 | net_params['node_num'] = int(dataset.node_num)
319 | net_params['device'] = device
320 | net_params['gpu_id'] = config['gpu']['id']
321 | net_params['batch_size'] = params['batch_size']
322 | if args.L is not None:
323 | net_params['L'] = int(args.L)
324 | if args.hidden_dim is not None:
325 | net_params['hidden_dim'] = int(args.hidden_dim)
326 | if args.out_dim is not None:
327 | net_params['out_dim'] = int(args.out_dim)
328 | if args.residual is not None:
329 | net_params['residual'] = True if args.residual=='True' else False
330 | if args.edge_feat is not None:
331 | net_params['edge_feat'] = True if args.edge_feat=='True' else False
332 | if args.readout is not None:
333 | net_params['readout'] = args.readout
334 | if args.kernel is not None:
335 | net_params['kernel'] = int(args.kernel)
336 | if args.n_heads is not None:
337 | net_params['n_heads'] = int(args.n_heads)
338 | if args.gated is not None:
339 | net_params['gated'] = True if args.gated=='True' else False
340 | if args.in_feat_dropout is not None:
341 | net_params['in_feat_dropout'] = float(args.in_feat_dropout)
342 | if args.dropout is not None:
343 | net_params['dropout'] = float(args.dropout)
344 | if args.layer_norm is not None:
345 | net_params['layer_norm'] = True if args.layer_norm=='True' else False
346 | if args.batch_norm is not None:
347 | net_params['batch_norm'] = True if args.batch_norm=='True' else False
348 | if args.sage_aggregator is not None:
349 | net_params['sage_aggregator'] = args.sage_aggregator
350 | if args.data_mode is not None:
351 | net_params['data_mode'] = args.data_mode
352 | if args.num_pool is not None:
353 | net_params['num_pool'] = int(args.num_pool)
354 | if args.gnn_per_block is not None:
355 | net_params['gnn_per_block'] = int(args.gnn_per_block)
356 | if args.embedding_dim is not None:
357 | net_params['embedding_dim'] = int(args.embedding_dim)
358 | if args.pool_ratio is not None:
359 | net_params['pool_ratio'] = float(args.pool_ratio)
360 | if args.linkpred is not None:
361 | net_params['linkpred'] = True if args.linkpred=='True' else False
362 | if args.cat is not None:
363 | net_params['cat'] = True if args.cat=='True' else False
364 | if args.self_loop is not None:
365 | net_params['self_loop'] = True if args.self_loop=='True' else False
366 | if args.contrast is not None:
367 | net_params['contrast'] = args.contrast
368 | if args.pooling is not None:
369 | net_params['pooling'] = float(args.pooling)
370 | if args.lambda1 is not None:
371 | net_params['lambda1'] = float(args.lambda1)
372 | if args.learnable_q is not None:
373 | net_params['learnable_q'] = args.learnable_q
374 |
375 | # TUs
376 | net_params['in_dim'] = dataset.all.graph_lists[0].ndata['feat'].shape[1]
377 | net_params['edge_dim'] = dataset.all.graph_lists[0].edata['feat'][0].shape[0] \
378 | if 'feat' in dataset.all.graph_lists[0].edata else None
379 | num_classes = len(np.unique(dataset.all.graph_labels))
380 | net_params['n_classes'] = num_classes
381 |
382 | if MODEL_NAME in ['DiffPool', 'ContrastPool']:
383 | net_params['max_num_node'] = dataset.node_num
384 |         # calculate assignment dimension: pool_ratio * the maximum number of
385 |         # nodes over all graphs in the dataset
386 | num_nodes = [dataset.all[i][0].number_of_nodes() for i in range(len(dataset.all))]
387 | max_num_node = max(num_nodes)
388 | net_params['assign_dim'] = int(max_num_node * net_params['pool_ratio']) * net_params['batch_size']
389 |
390 | root_log_dir = out_dir + 'logs/' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
391 | root_ckpt_dir = out_dir + 'checkpoints/' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
392 | write_file_name = out_dir + 'results/result_' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
393 | write_config_file = out_dir + 'configs/config_' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
394 | dirs = root_log_dir, root_ckpt_dir, write_file_name, write_config_file
395 |
396 | if not os.path.exists(out_dir + 'results'):
397 | os.makedirs(out_dir + 'results')
398 |
399 | if not os.path.exists(out_dir + 'configs'):
400 | os.makedirs(out_dir + 'configs')
401 |
402 | net_params['total_param'] = view_model_param(MODEL_NAME, net_params)
403 | train_val_pipeline(MODEL_NAME, DATASET_NAME, params, net_params, dirs)
404 |
405 |
406 | main()
407 |
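408 | # Example invocation (a sketch; the config path is a placeholder):
409 | #   python main.py --config <path/to/config.json> --gpu_id 0 --contrast
410 | # Most command-line flags, when supplied, override the corresponding entry in the JSON config.
411 | 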
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from sklearn.metrics import confusion_matrix
6 | from sklearn.metrics import f1_score
7 | import numpy as np
8 |
9 |
10 | def accuracy_TU(scores, targets):
11 | scores = scores.detach().argmax(dim=1)
12 | acc = (scores==targets).float().sum().item()
13 | return acc
14 |
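15 | # Note: accuracy_TU returns the *count* of correct predictions in the batch
16 | # (as a float), not a ratio; callers accumulate it and divide by the number of
17 | # samples. A minimal sketch with hypothetical tensors:
18 | #   scores = torch.tensor([[0.2, 0.8], [0.9, 0.1]])  # logits for two samples
19 | #   targets = torch.tensor([1, 1])
20 | #   accuracy_TU(scores, targets)  # -> 1.0 (one correct prediction)
21 | 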
--------------------------------------------------------------------------------
/nets/contrastpool_net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 | import torch.nn.functional as F
5 | from layers.attention_layer import EncoderLayer
6 | import time
7 | import numpy as np
8 | from scipy.linalg import block_diag
9 | import dgl
10 |
11 | from layers.graphsage_layer import GraphSageLayer, DenseGraphSage
12 | from layers.contrastpool_layer import ContrastPoolLayer, DenseDiffPool
13 |
14 |
15 | class ContrastPoolNet(nn.Module):
16 | """
17 |     DiffPool-style network that fuses GNN layers and pooling layers in sequence
18 | """
19 |
20 | def __init__(self, net_params, pool_ratio=0.5):
21 |
22 | super().__init__()
23 | input_dim = net_params['in_dim']
24 | self.hidden_dim = net_params['hidden_dim']
25 | embedding_dim = net_params['hidden_dim']
26 | out_dim = net_params['hidden_dim']
27 | self.n_classes = net_params['n_classes']
28 | activation = F.relu
29 | n_layers = net_params['L']
30 | dropout = net_params['dropout']
31 | self.batch_norm = net_params['batch_norm']
32 | self.residual = net_params['residual']
33 | aggregator_type = net_params['sage_aggregator']
34 | self.lambda1 = net_params['lambda1']
35 | self.learnable_q = net_params['learnable_q']
36 |
37 | self.device = net_params['device']
38 | self.link_pred = True
39 | self.concat = False
40 | self.n_pooling = 1
41 | self.batch_size = net_params['batch_size']
42 |         if 'pool_ratio' in net_params:
43 | pool_ratio = net_params['pool_ratio']
44 | self.e_feat = net_params['edge_feat']
45 | self.link_pred_loss = []
46 | self.entropy_loss = []
47 |
48 | self.embedding_h = nn.Linear(input_dim, self.hidden_dim)
49 |
50 | # list of GNN modules before the first diffpool operation
51 | self.gc_before_pool = nn.ModuleList()
52 |
53 | self.assign_dim = int(net_params['max_num_node'] * pool_ratio)
54 | self.bn = True
55 | self.num_aggs = 1
56 |
57 | # constructing layers
58 | # layers before diffpool
59 | assert n_layers >= 2, "n_layers too few"
60 | self.gc_before_pool.append(GraphSageLayer(self.hidden_dim, self.hidden_dim, activation,
61 | dropout, aggregator_type, self.residual, self.bn, e_feat=self.e_feat))
62 |
63 | for _ in range(n_layers - 2):
64 | self.gc_before_pool.append(GraphSageLayer(self.hidden_dim, self.hidden_dim, activation,
65 | dropout, aggregator_type, self.residual, self.bn, e_feat=self.e_feat))
66 |
67 | self.gc_before_pool.append(GraphSageLayer(self.hidden_dim, embedding_dim, None, dropout, aggregator_type, self.residual, e_feat=self.e_feat))
68 |
69 |
70 | assign_dims = []
71 | assign_dims.append(self.assign_dim)
72 | if self.concat:
73 |             # the diffpool layer receives a pool_embedding_dim node feature tensor
74 |             # and returns a pool_embedding_dim node embedding
75 | pool_embedding_dim = self.hidden_dim * (n_layers - 1) + embedding_dim
76 | else:
77 |
78 | pool_embedding_dim = embedding_dim
79 |
80 | self.first_diffpool_layer = ContrastPoolLayer(pool_embedding_dim, self.assign_dim, self.hidden_dim, activation,
81 | dropout, aggregator_type, self.link_pred, self.batch_norm,
82 | max_node_num=net_params['max_num_node'])
83 | gc_after_per_pool = nn.ModuleList()
84 |
85 |         # list of lists of GNN modules; one list is applied after each diffpool operation
86 | self.gc_after_pool = nn.ModuleList()
87 |
88 | for _ in range(n_layers - 1):
89 | gc_after_per_pool.append(DenseGraphSage(self.hidden_dim, self.hidden_dim, self.residual))
90 | gc_after_per_pool.append(DenseGraphSage(self.hidden_dim, embedding_dim, self.residual))
91 | self.gc_after_pool.append(gc_after_per_pool)
92 |
93 | self.assign_dim = int(self.assign_dim * pool_ratio)
94 |
95 | self.diffpool_layers = nn.ModuleList()
96 | # each pooling module
97 | for _ in range(self.n_pooling - 1):
98 | self.diffpool_layers.append(DenseDiffPool(pool_embedding_dim, self.assign_dim, self.hidden_dim, self.link_pred))
99 |
100 | gc_after_per_pool = nn.ModuleList()
101 |
102 | for _ in range(n_layers - 1):
103 | gc_after_per_pool.append(DenseGraphSage(self.hidden_dim, self.hidden_dim, self.residual))
104 | gc_after_per_pool.append(DenseGraphSage(self.hidden_dim, embedding_dim, self.residual))
105 | self.gc_after_pool.append(gc_after_per_pool)
106 |
107 | assign_dims.append(self.assign_dim)
108 | self.assign_dim = int(self.assign_dim * pool_ratio)
109 |
110 | # predicting layer
111 | if self.concat:
112 | self.pred_input_dim = pool_embedding_dim * \
113 | self.num_aggs * (self.n_pooling + 1)
114 | else:
115 | self.pred_input_dim = embedding_dim * self.num_aggs
116 | self.pred_layer = nn.Linear(self.pred_input_dim, self.n_classes)
117 |
118 | # weight initialization
119 | for m in self.modules():
120 | if isinstance(m, nn.Linear):
121 | m.weight.data = init.xavier_uniform_(m.weight.data, gain=nn.init.calculate_gain('relu'))
122 | if m.bias is not None:
123 | m.bias.data = init.constant_(m.bias.data, 0.0)
124 |
125 | self.contrast_adj = None
126 | self.adj_dict = None
127 | self.nodes_dict = None
128 | self.nodes1 = None
129 | self.nodes2 = None
130 | self.encoder1 = None
131 | self.encoder2 = None
132 | self.encoder1_node = None
133 | self.encoder2_node = None
134 | self.num_A = None
135 | self.num_B = None
136 | self.node_num = None
137 | self.diff_h = None
138 | self.attn_loss = None
139 | self.ad_adj = None
140 | self.softmax = nn.Softmax(dim=-1)
141 | # self.sim = nn.CosineSimilarity(dim=-1, eps=1e-08)
142 |
143 | def cal_attn_loss(self, attn):
144 | entropy = (torch.distributions.Categorical(logits=attn).entropy()).mean()
145 | assert not torch.isnan(entropy)
146 | return entropy
147 |
148 | def cal_contrast(self, trainset, device, merge_classes=True):
149 | from contrast_subgraph import get_summary_tensor
150 | G_dataset = trainset[:][0]
151 | Labels = torch.tensor(trainset[:][1])
152 |
153 | self.adj_dict, self.nodes_dict = get_summary_tensor(G_dataset, Labels, device, merge_classes=merge_classes)
154 | self.node_num = G_dataset[0].ndata['feat'].size(0)
155 | feat_dim = G_dataset[0].ndata['feat'].size(1)
156 |
157 | learnable_q = self.learnable_q
158 | n_head = 1
159 | self.encoder1 = EncoderLayer(self.node_num, n_head, self.node_num, 0.0, device, self.node_num, learnable_q, pos_enc='index').to(device)
160 | self.encoder2 = EncoderLayer(self.node_num, n_head, self.node_num, 0.0, device, self.node_num, learnable_q).to(device)
161 | self.encoder1_node = EncoderLayer(self.node_num, n_head, self.node_num, 0.0, device, feat_dim, learnable_q, pos_enc='index').to(device)
162 | self.encoder2_node = EncoderLayer(self.node_num, n_head, self.node_num, 0.0, device, feat_dim, learnable_q).to(device)
163 |
164 | def cal_contrast_adj(self, device):
165 | adj_list = []
166 | nodes_list = []
167 | for i in self.adj_dict.keys():
168 | adj = self.encoder1(self.adj_dict[i])
169 | adj = self.encoder2(adj.permute(1, 0, 2))
170 | adj_list.append(adj.mean(1))
171 |
172 | nodes_feat = self.encoder1_node(self.nodes_dict[i])
173 | nodes_feat = self.encoder2_node(nodes_feat.permute(1, 0, 2))
174 | nodes_list.append(nodes_feat.mean(1))
175 | self.ad_adj = torch.stack(adj_list)
176 | adj_var = torch.std(torch.stack(adj_list).to(device), 0)
177 | nodes_var = torch.std(torch.stack(nodes_list).to(device), 0)
178 |
179 | self.contrast_adj = adj_var
180 | self.diff_h = nodes_var
181 | self.attn_loss = self.cal_attn_loss(self.contrast_adj)
182 |
183 | self.contrast_adj_trans = self.contrast_adj
184 |
185 | def gcn_forward(self, g, h, e, gc_layers, cat=False):
186 | """
187 |         Run gc_layers sequentially; return the final embedding, or the concatenation of all layer outputs if cat is True.
188 | """
189 | block_readout = []
190 | for gc_layer in gc_layers[:-1]:
191 | h, e = gc_layer(g, h, e)
192 | block_readout.append(h)
193 | h, e = gc_layers[-1](g, h, e)
194 | block_readout.append(h)
195 | if cat:
196 | block = torch.cat(block_readout, dim=1) # N x F, F = F1 + F2 + ...
197 | else:
198 | block = h
199 | return block
200 |
201 | def gcn_forward_tensorized(self, h, adj, gc_layers, cat=False):
202 | block_readout = []
203 | for gc_layer in gc_layers:
204 | h = gc_layer(h, adj)
205 | block_readout.append(h)
206 | if cat:
207 | block = torch.cat(block_readout, dim=2) # N x F, F = F1 + F2 + ...
208 | else:
209 | block = h
210 | return block
211 |
212 | def forward(self, g, h, e):
213 | self.link_pred_loss = []
214 | self.entropy_loss = []
215 |
216 |         # the node features used to compute the assignment matrix are the same
217 |         # as the original node features
218 | h = self.embedding_h(h)
219 |
220 | out_all = []
221 |
222 | # we use GCN blocks to get an embedding first
223 | g_embedding = self.gcn_forward(g, h, e, self.gc_before_pool, self.concat)
224 |
225 | g.ndata['h'] = g_embedding
226 |
227 | readout = dgl.sum_nodes(g, 'h')
228 | out_all.append(readout)
229 | if self.num_aggs == 2:
230 | readout = dgl.max_nodes(g, 'h')
231 | out_all.append(readout)
232 |
233 | self.cal_contrast_adj(device=h.device)
234 | adj, h = self.first_diffpool_layer(g, g_embedding, self.diff_h, self.contrast_adj_trans)
235 | node_per_pool_graph = int(adj.size()[0] / self.batch_size)
236 |
237 | h, adj = self.batch2tensor(adj, h, node_per_pool_graph)
238 | h = self.gcn_forward_tensorized(h, adj, self.gc_after_pool[0], self.concat)
239 |
240 | readout = torch.sum(h, dim=1)
241 | out_all.append(readout)
242 | if self.num_aggs == 2:
243 | readout, _ = torch.max(h, dim=1)
244 | out_all.append(readout)
245 |
246 | for i, diffpool_layer in enumerate(self.diffpool_layers):
247 | h, adj = diffpool_layer(h, adj)
248 | h = self.gcn_forward_tensorized(h, adj, self.gc_after_pool[i + 1], self.concat)
249 |
250 | readout = torch.sum(h, dim=1)
251 | out_all.append(readout)
252 |
253 | if self.num_aggs == 2:
254 | readout, _ = torch.max(h, dim=1)
255 | out_all.append(readout)
256 |
257 | if self.concat or self.num_aggs > 1:
258 | hg = torch.cat(out_all, dim=1)
259 | else:
260 | hg = readout
261 |
262 | ypred = self.pred_layer(hg)
263 | return ypred
264 |
265 | def batch2tensor(self, batch_adj, batch_feat, node_per_pool_graph):
266 | """
267 |         Transform a batched (block-diagonal) adjacency matrix and node feature matrix into per-graph tensors of shape (batch, n, n) and (batch, n, feat).
268 | """
269 | batch_size = int(batch_adj.size()[0] / node_per_pool_graph)
270 | adj_list = []
271 | feat_list = []
272 |
273 | for i in range(batch_size):
274 | start = i * node_per_pool_graph
275 | end = (i + 1) * node_per_pool_graph
276 |
277 | # 1/sqrt(V) normalization
278 | snorm_n = torch.FloatTensor(node_per_pool_graph, 1).fill_(1./float(node_per_pool_graph)).sqrt().to(self.device)
279 |
280 | adj_list.append(batch_adj[start:end, start:end])
281 | feat_list.append((batch_feat[start:end, :])*snorm_n)
282 | adj_list = list(map(lambda x: torch.unsqueeze(x, 0), adj_list))
283 | feat_list = list(map(lambda x: torch.unsqueeze(x, 0), feat_list))
284 | adj = torch.cat(adj_list, dim=0)
285 | feat = torch.cat(feat_list, dim=0)
286 |
287 | return feat, adj
288 |
289 | def loss(self, pred, label):
290 |         '''
291 |         Cross-entropy classification loss plus the pooling layers' auxiliary losses and the weighted attention-entropy loss
292 |         '''
293 |         # softmax + cross-entropy on the class predictions
294 | criterion = nn.CrossEntropyLoss()
295 | loss = criterion(pred, label)
296 | e1_loss = 0.0
297 | for diffpool_layer in self.diffpool_layers:
298 | for key, value in diffpool_layer.loss_log.items():
299 | e1_loss += value
300 | loss += e1_loss + self.lambda1 * self.attn_loss
301 | return loss
302 |
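303 | # Usage sketch (hypothetical variable names): the training pipeline is expected to
304 | # call cal_contrast once on the training set before training, so that the class-wise
305 | # summary tensors and attention encoders exist, and then run forward per batch:
306 | #   model = ContrastPoolNet(net_params)
307 | #   model.cal_contrast(trainset, device)       # builds adj_dict / nodes_dict / encoders
308 | #   ypred = model(batch_graphs, batch_x, batch_e)
309 | #   loss = model.loss(ypred, batch_labels)     # CE + pooling losses + lambda1 * attn_loss
310 | 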
--------------------------------------------------------------------------------
/nets/load_net.py:
--------------------------------------------------------------------------------
1 | """
2 | Utility file to select the GraphNN model
3 | specified by the user
4 | """
5 |
6 | from nets.contrastpool_net import ContrastPoolNet
7 |
8 |
9 | def ContrastPool(net_params):
10 | return ContrastPoolNet(net_params)
11 |
12 |
13 | def gnn_model(MODEL_NAME, net_params):
14 | models = {
15 | "ContrastPool": ContrastPool
16 | }
17 | model = models[MODEL_NAME](net_params)
18 | model.name = MODEL_NAME
19 |
20 | return model
21 |
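22 | # Usage sketch: net_params is the 'net_params' dict loaded from the JSON config
23 | # in main.py, e.g.:
24 | #   model = gnn_model('ContrastPool', net_params)
25 | #   print(model.name)  # 'ContrastPool'
26 | 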
--------------------------------------------------------------------------------
/train_TUs_graph_classification.py:
--------------------------------------------------------------------------------
1 | """
2 | Utility functions for training one epoch
3 | and evaluating one epoch
4 | """
5 | import torch
6 | import torch.nn as nn
7 | import math
8 |
9 | from metrics import accuracy_TU as accuracy
10 |
11 | """
12 | For GCNs
13 | """
14 | def train_epoch_sparse(model, optimizer, device, data_loader, epoch):
15 | model.train()
16 | epoch_loss = 0
17 | epoch_train_acc = 0
18 | nb_data = 0
19 | for iter, (batch_graphs, batch_labels) in enumerate(data_loader):
20 | batch_graphs = batch_graphs.to(device)
21 | batch_x = batch_graphs.ndata['feat'].to(device) # num x feat
22 | batch_e = batch_graphs.edata['feat'].to(device)
23 | batch_labels = batch_labels.to(device)
24 | optimizer.zero_grad()
25 | if model.name in ["PRGNN", "LINet"]:
26 | batch_scores, score1, score2 = model.forward(batch_graphs, batch_x, batch_e)
27 | loss = model.loss(batch_scores, batch_labels, score1, score2)
28 | else:
29 | batch_scores = model.forward(batch_graphs, batch_x, batch_e)
30 | loss = model.loss(batch_scores, batch_labels)
31 | loss.backward()
32 | optimizer.step()
33 | epoch_loss += loss.detach().item()
34 | epoch_train_acc += accuracy(batch_scores, batch_labels)
35 | nb_data += batch_labels.size(0)
36 | epoch_loss /= (iter + 1)
37 | epoch_train_acc /= nb_data
38 |
39 | return epoch_loss, epoch_train_acc, optimizer
40 |
41 |
42 | def evaluate_network_sparse(model, device, data_loader, epoch):
43 | model.eval()
44 | epoch_test_loss = 0
45 | epoch_test_acc = 0
46 | nb_data = 0
47 | with torch.no_grad():
48 | for iter, (batch_graphs, batch_labels) in enumerate(data_loader):
49 | batch_graphs = batch_graphs.to(device)
50 | batch_x = batch_graphs.ndata['feat'].to(device)
51 | batch_e = batch_graphs.edata['feat'].to(device)
52 | batch_labels = batch_labels.to(device)
53 | if model.name in ["PRGNN", "LINet"]:
54 | batch_scores, score1, score2 = model.forward(batch_graphs, batch_x, batch_e)
55 | loss = model.loss(batch_scores, batch_labels, score1, score2)
56 | else:
57 | batch_scores = model.forward(batch_graphs, batch_x, batch_e)
58 | loss = model.loss(batch_scores, batch_labels)
59 | epoch_test_loss += loss.detach().item()
60 | epoch_test_acc += accuracy(batch_scores, batch_labels)
61 | nb_data += batch_labels.size(0)
62 | epoch_test_loss /= (iter + 1)
63 | epoch_test_acc /= nb_data
64 |
65 | return epoch_test_loss, epoch_test_acc
66 |
67 | def check_patience(all_losses, best_loss, best_epoch, curr_loss, curr_epoch, counter):
68 | if curr_loss < best_loss:
69 | counter = 0
70 | best_loss = curr_loss
71 | best_epoch = curr_epoch
72 | else:
73 | counter += 1
74 | return best_loss, best_epoch, counter
75 |
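76 | # Early-stopping sketch (hypothetical loop): check_patience resets the counter when
77 | # the validation loss improves and increments it otherwise (all_losses is currently
78 | # unused), so training can stop once the counter exceeds a chosen patience:
79 | #   best_loss, best_epoch, counter = float('inf'), 0, 0
80 | #   for epoch in range(max_epochs):
81 | #       val_loss, _ = evaluate_network_sparse(model, device, val_loader, epoch)
82 | #       best_loss, best_epoch, counter = check_patience(None, best_loss, best_epoch, val_loss, epoch, counter)
83 | #       if counter > patience:
84 | #           break
85 | 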
--------------------------------------------------------------------------------