├── .gitignore ├── LICENSE ├── README.md ├── aggGNN.png ├── alegnn ├── __init__.py ├── modules │ ├── __init__.py │ ├── architectures.py │ ├── architecturesTime.py │ ├── evaluation.py │ ├── loss.py │ ├── model.py │ └── training.py └── utils │ ├── __init__.py │ ├── dataTools.py │ ├── graphML.py │ ├── graphTools.py │ ├── miscTools.py │ └── visualTools.py ├── datasets ├── authorshipData │ ├── authorshipData.part1.rar │ ├── authorshipData.part2.rar │ └── authorshipData.part3.rar ├── epidemics │ └── edge_list.txt └── facebookEgo │ └── facebookEgo234.pkl ├── examples ├── authorshipGNN.py ├── epidemicGRNN.py ├── flockingGNN.py ├── movieGNN.py └── sourceLocGNN.py ├── pyproject.toml ├── selGNN.png └── tutorial.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | /dist/ 2 | /build/ 3 | /alegnn.egg-info/ 4 | poetry.lock -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | Preamble 9 | 10 | The GNU General Public License is a free, copyleft license for 11 | software and other kinds of works. 12 | 13 | The licenses for most software and other practical works are designed 14 | to take away your freedom to share and change the works. By contrast, 15 | the GNU General Public License is intended to guarantee your freedom to 16 | share and change all versions of a program--to make sure it remains free 17 | software for all its users. We, the Free Software Foundation, use the 18 | GNU General Public License for most of our software; it applies also to 19 | any other work released this way by its authors. You can apply it to 20 | your programs, too. 21 | 22 | When we speak of free software, we are referring to freedom, not 23 | price. Our General Public Licenses are designed to make sure that you 24 | have the freedom to distribute copies of free software (and charge for 25 | them if you wish), that you receive source code or can get it if you 26 | want it, that you can change the software or use pieces of it in new 27 | free programs, and that you know you can do these things. 28 | 29 | To protect your rights, we need to prevent others from denying you 30 | these rights or asking you to surrender the rights. Therefore, you have 31 | certain responsibilities if you distribute copies of the software, or if 32 | you modify it: responsibilities to respect the freedom of others. 33 | 34 | For example, if you distribute copies of such a program, whether 35 | gratis or for a fee, you must pass on to the recipients the same 36 | freedoms that you received. You must make sure that they, too, receive 37 | or can get the source code. And you must show them these terms so they 38 | know their rights. 39 | 40 | Developers that use the GNU GPL protect your rights with two steps: 41 | (1) assert copyright on the software, and (2) offer you this License 42 | giving you legal permission to copy, distribute and/or modify it. 43 | 44 | For the developers' and authors' protection, the GPL clearly explains 45 | that there is no warranty for this free software. 
For both users' and 46 | authors' sake, the GPL requires that modified versions be marked as 47 | changed, so that their problems will not be attributed erroneously to 48 | authors of previous versions. 49 | 50 | Some devices are designed to deny users access to install or run 51 | modified versions of the software inside them, although the manufacturer 52 | can do so. This is fundamentally incompatible with the aim of 53 | protecting users' freedom to change the software. The systematic 54 | pattern of such abuse occurs in the area of products for individuals to 55 | use, which is precisely where it is most unacceptable. Therefore, we 56 | have designed this version of the GPL to prohibit the practice for those 57 | products. If such problems arise substantially in other domains, we 58 | stand ready to extend this provision to those domains in future versions 59 | of the GPL, as needed to protect the freedom of users. 60 | 61 | Finally, every program is threatened constantly by software patents. 62 | States should not allow patents to restrict development and use of 63 | software on general-purpose computers, but in those that do, we wish to 64 | avoid the special danger that patents applied to a free program could 65 | make it effectively proprietary. To prevent this, the GPL assures that 66 | patents cannot be used to render the program non-free. 67 | 68 | The precise terms and conditions for copying, distribution and 69 | modification follow. 70 | 71 | TERMS AND CONDITIONS 72 | 73 | 0. Definitions. 74 | 75 | "This License" refers to version 3 of the GNU General Public License. 76 | 77 | "Copyright" also means copyright-like laws that apply to other kinds of 78 | works, such as semiconductor masks. 79 | 80 | "The Program" refers to any copyrightable work licensed under this 81 | License. Each licensee is addressed as "you". "Licensees" and 82 | "recipients" may be individuals or organizations. 83 | 84 | To "modify" a work means to copy from or adapt all or part of the work 85 | in a fashion requiring copyright permission, other than the making of an 86 | exact copy. The resulting work is called a "modified version" of the 87 | earlier work or a work "based on" the earlier work. 88 | 89 | A "covered work" means either the unmodified Program or a work based 90 | on the Program. 91 | 92 | To "propagate" a work means to do anything with it that, without 93 | permission, would make you directly or secondarily liable for 94 | infringement under applicable copyright law, except executing it on a 95 | computer or modifying a private copy. Propagation includes copying, 96 | distribution (with or without modification), making available to the 97 | public, and in some countries other activities as well. 98 | 99 | To "convey" a work means any kind of propagation that enables other 100 | parties to make or receive copies. Mere interaction with a user through 101 | a computer network, with no transfer of a copy, is not conveying. 102 | 103 | An interactive user interface displays "Appropriate Legal Notices" 104 | to the extent that it includes a convenient and prominently visible 105 | feature that (1) displays an appropriate copyright notice, and (2) 106 | tells the user that there is no warranty for the work (except to the 107 | extent that warranties are provided), that licensees may convey the 108 | work under this License, and how to view a copy of this License. 
If 109 | the interface presents a list of user commands or options, such as a 110 | menu, a prominent item in the list meets this criterion. 111 | 112 | 1. Source Code. 113 | 114 | The "source code" for a work means the preferred form of the work 115 | for making modifications to it. "Object code" means any non-source 116 | form of a work. 117 | 118 | A "Standard Interface" means an interface that either is an official 119 | standard defined by a recognized standards body, or, in the case of 120 | interfaces specified for a particular programming language, one that 121 | is widely used among developers working in that language. 122 | 123 | The "System Libraries" of an executable work include anything, other 124 | than the work as a whole, that (a) is included in the normal form of 125 | packaging a Major Component, but which is not part of that Major 126 | Component, and (b) serves only to enable use of the work with that 127 | Major Component, or to implement a Standard Interface for which an 128 | implementation is available to the public in source code form. A 129 | "Major Component", in this context, means a major essential component 130 | (kernel, window system, and so on) of the specific operating system 131 | (if any) on which the executable work runs, or a compiler used to 132 | produce the work, or an object code interpreter used to run it. 133 | 134 | The "Corresponding Source" for a work in object code form means all 135 | the source code needed to generate, install, and (for an executable 136 | work) run the object code and to modify the work, including scripts to 137 | control those activities. However, it does not include the work's 138 | System Libraries, or general-purpose tools or generally available free 139 | programs which are used unmodified in performing those activities but 140 | which are not part of the work. For example, Corresponding Source 141 | includes interface definition files associated with source files for 142 | the work, and the source code for shared libraries and dynamically 143 | linked subprograms that the work is specifically designed to require, 144 | such as by intimate data communication or control flow between those 145 | subprograms and other parts of the work. 146 | 147 | The Corresponding Source need not include anything that users 148 | can regenerate automatically from other parts of the Corresponding 149 | Source. 150 | 151 | The Corresponding Source for a work in source code form is that 152 | same work. 153 | 154 | 2. Basic Permissions. 155 | 156 | All rights granted under this License are granted for the term of 157 | copyright on the Program, and are irrevocable provided the stated 158 | conditions are met. This License explicitly affirms your unlimited 159 | permission to run the unmodified Program. The output from running a 160 | covered work is covered by this License only if the output, given its 161 | content, constitutes a covered work. This License acknowledges your 162 | rights of fair use or other equivalent, as provided by copyright law. 163 | 164 | You may make, run and propagate covered works that you do not 165 | convey, without conditions so long as your license otherwise remains 166 | in force. You may convey covered works to others for the sole purpose 167 | of having them make modifications exclusively for you, or provide you 168 | with facilities for running those works, provided that you comply with 169 | the terms of this License in conveying all material for which you do 170 | not control copyright. 
Those thus making or running the covered works 171 | for you must do so exclusively on your behalf, under your direction 172 | and control, on terms that prohibit them from making any copies of 173 | your copyrighted material outside their relationship with you. 174 | 175 | Conveying under any other circumstances is permitted solely under 176 | the conditions stated below. Sublicensing is not allowed; section 10 177 | makes it unnecessary. 178 | 179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 180 | 181 | No covered work shall be deemed part of an effective technological 182 | measure under any applicable law fulfilling obligations under article 183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 184 | similar laws prohibiting or restricting circumvention of such 185 | measures. 186 | 187 | When you convey a covered work, you waive any legal power to forbid 188 | circumvention of technological measures to the extent such circumvention 189 | is effected by exercising rights under this License with respect to 190 | the covered work, and you disclaim any intention to limit operation or 191 | modification of the work as a means of enforcing, against the work's 192 | users, your or third parties' legal rights to forbid circumvention of 193 | technological measures. 194 | 195 | 4. Conveying Verbatim Copies. 196 | 197 | You may convey verbatim copies of the Program's source code as you 198 | receive it, in any medium, provided that you conspicuously and 199 | appropriately publish on each copy an appropriate copyright notice; 200 | keep intact all notices stating that this License and any 201 | non-permissive terms added in accord with section 7 apply to the code; 202 | keep intact all notices of the absence of any warranty; and give all 203 | recipients a copy of this License along with the Program. 204 | 205 | You may charge any price or no price for each copy that you convey, 206 | and you may offer support or warranty protection for a fee. 207 | 208 | 5. Conveying Modified Source Versions. 209 | 210 | You may convey a work based on the Program, or the modifications to 211 | produce it from the Program, in the form of source code under the 212 | terms of section 4, provided that you also meet all of these conditions: 213 | 214 | a) The work must carry prominent notices stating that you modified 215 | it, and giving a relevant date. 216 | 217 | b) The work must carry prominent notices stating that it is 218 | released under this License and any conditions added under section 219 | 7. This requirement modifies the requirement in section 4 to 220 | "keep intact all notices". 221 | 222 | c) You must license the entire work, as a whole, under this 223 | License to anyone who comes into possession of a copy. This 224 | License will therefore apply, along with any applicable section 7 225 | additional terms, to the whole of the work, and all its parts, 226 | regardless of how they are packaged. This License gives no 227 | permission to license the work in any other way, but it does not 228 | invalidate such permission if you have separately received it. 229 | 230 | d) If the work has interactive user interfaces, each must display 231 | Appropriate Legal Notices; however, if the Program has interactive 232 | interfaces that do not display Appropriate Legal Notices, your 233 | work need not make them do so. 
234 | 235 | A compilation of a covered work with other separate and independent 236 | works, which are not by their nature extensions of the covered work, 237 | and which are not combined with it such as to form a larger program, 238 | in or on a volume of a storage or distribution medium, is called an 239 | "aggregate" if the compilation and its resulting copyright are not 240 | used to limit the access or legal rights of the compilation's users 241 | beyond what the individual works permit. Inclusion of a covered work 242 | in an aggregate does not cause this License to apply to the other 243 | parts of the aggregate. 244 | 245 | 6. Conveying Non-Source Forms. 246 | 247 | You may convey a covered work in object code form under the terms 248 | of sections 4 and 5, provided that you also convey the 249 | machine-readable Corresponding Source under the terms of this License, 250 | in one of these ways: 251 | 252 | a) Convey the object code in, or embodied in, a physical product 253 | (including a physical distribution medium), accompanied by the 254 | Corresponding Source fixed on a durable physical medium 255 | customarily used for software interchange. 256 | 257 | b) Convey the object code in, or embodied in, a physical product 258 | (including a physical distribution medium), accompanied by a 259 | written offer, valid for at least three years and valid for as 260 | long as you offer spare parts or customer support for that product 261 | model, to give anyone who possesses the object code either (1) a 262 | copy of the Corresponding Source for all the software in the 263 | product that is covered by this License, on a durable physical 264 | medium customarily used for software interchange, for a price no 265 | more than your reasonable cost of physically performing this 266 | conveying of source, or (2) access to copy the 267 | Corresponding Source from a network server at no charge. 268 | 269 | c) Convey individual copies of the object code with a copy of the 270 | written offer to provide the Corresponding Source. This 271 | alternative is allowed only occasionally and noncommercially, and 272 | only if you received the object code with such an offer, in accord 273 | with subsection 6b. 274 | 275 | d) Convey the object code by offering access from a designated 276 | place (gratis or for a charge), and offer equivalent access to the 277 | Corresponding Source in the same way through the same place at no 278 | further charge. You need not require recipients to copy the 279 | Corresponding Source along with the object code. If the place to 280 | copy the object code is a network server, the Corresponding Source 281 | may be on a different server (operated by you or a third party) 282 | that supports equivalent copying facilities, provided you maintain 283 | clear directions next to the object code saying where to find the 284 | Corresponding Source. Regardless of what server hosts the 285 | Corresponding Source, you remain obligated to ensure that it is 286 | available for as long as needed to satisfy these requirements. 287 | 288 | e) Convey the object code using peer-to-peer transmission, provided 289 | you inform other peers where the object code and Corresponding 290 | Source of the work are being offered to the general public at no 291 | charge under subsection 6d. 292 | 293 | A separable portion of the object code, whose source code is excluded 294 | from the Corresponding Source as a System Library, need not be 295 | included in conveying the object code work. 
296 | 297 | A "User Product" is either (1) a "consumer product", which means any 298 | tangible personal property which is normally used for personal, family, 299 | or household purposes, or (2) anything designed or sold for incorporation 300 | into a dwelling. In determining whether a product is a consumer product, 301 | doubtful cases shall be resolved in favor of coverage. For a particular 302 | product received by a particular user, "normally used" refers to a 303 | typical or common use of that class of product, regardless of the status 304 | of the particular user or of the way in which the particular user 305 | actually uses, or expects or is expected to use, the product. A product 306 | is a consumer product regardless of whether the product has substantial 307 | commercial, industrial or non-consumer uses, unless such uses represent 308 | the only significant mode of use of the product. 309 | 310 | "Installation Information" for a User Product means any methods, 311 | procedures, authorization keys, or other information required to install 312 | and execute modified versions of a covered work in that User Product from 313 | a modified version of its Corresponding Source. The information must 314 | suffice to ensure that the continued functioning of the modified object 315 | code is in no case prevented or interfered with solely because 316 | modification has been made. 317 | 318 | If you convey an object code work under this section in, or with, or 319 | specifically for use in, a User Product, and the conveying occurs as 320 | part of a transaction in which the right of possession and use of the 321 | User Product is transferred to the recipient in perpetuity or for a 322 | fixed term (regardless of how the transaction is characterized), the 323 | Corresponding Source conveyed under this section must be accompanied 324 | by the Installation Information. But this requirement does not apply 325 | if neither you nor any third party retains the ability to install 326 | modified object code on the User Product (for example, the work has 327 | been installed in ROM). 328 | 329 | The requirement to provide Installation Information does not include a 330 | requirement to continue to provide support service, warranty, or updates 331 | for a work that has been modified or installed by the recipient, or for 332 | the User Product in which it has been modified or installed. Access to a 333 | network may be denied when the modification itself materially and 334 | adversely affects the operation of the network or violates the rules and 335 | protocols for communication across the network. 336 | 337 | Corresponding Source conveyed, and Installation Information provided, 338 | in accord with this section must be in a format that is publicly 339 | documented (and with an implementation available to the public in 340 | source code form), and must require no special password or key for 341 | unpacking, reading or copying. 342 | 343 | 7. Additional Terms. 344 | 345 | "Additional permissions" are terms that supplement the terms of this 346 | License by making exceptions from one or more of its conditions. 347 | Additional permissions that are applicable to the entire Program shall 348 | be treated as though they were included in this License, to the extent 349 | that they are valid under applicable law. 
If additional permissions 350 | apply only to part of the Program, that part may be used separately 351 | under those permissions, but the entire Program remains governed by 352 | this License without regard to the additional permissions. 353 | 354 | When you convey a copy of a covered work, you may at your option 355 | remove any additional permissions from that copy, or from any part of 356 | it. (Additional permissions may be written to require their own 357 | removal in certain cases when you modify the work.) You may place 358 | additional permissions on material, added by you to a covered work, 359 | for which you have or can give appropriate copyright permission. 360 | 361 | Notwithstanding any other provision of this License, for material you 362 | add to a covered work, you may (if authorized by the copyright holders of 363 | that material) supplement the terms of this License with terms: 364 | 365 | a) Disclaiming warranty or limiting liability differently from the 366 | terms of sections 15 and 16 of this License; or 367 | 368 | b) Requiring preservation of specified reasonable legal notices or 369 | author attributions in that material or in the Appropriate Legal 370 | Notices displayed by works containing it; or 371 | 372 | c) Prohibiting misrepresentation of the origin of that material, or 373 | requiring that modified versions of such material be marked in 374 | reasonable ways as different from the original version; or 375 | 376 | d) Limiting the use for publicity purposes of names of licensors or 377 | authors of the material; or 378 | 379 | e) Declining to grant rights under trademark law for use of some 380 | trade names, trademarks, or service marks; or 381 | 382 | f) Requiring indemnification of licensors and authors of that 383 | material by anyone who conveys the material (or modified versions of 384 | it) with contractual assumptions of liability to the recipient, for 385 | any liability that these contractual assumptions directly impose on 386 | those licensors and authors. 387 | 388 | All other non-permissive additional terms are considered "further 389 | restrictions" within the meaning of section 10. If the Program as you 390 | received it, or any part of it, contains a notice stating that it is 391 | governed by this License along with a term that is a further 392 | restriction, you may remove that term. If a license document contains 393 | a further restriction but permits relicensing or conveying under this 394 | License, you may add to a covered work material governed by the terms 395 | of that license document, provided that the further restriction does 396 | not survive such relicensing or conveying. 397 | 398 | If you add terms to a covered work in accord with this section, you 399 | must place, in the relevant source files, a statement of the 400 | additional terms that apply to those files, or a notice indicating 401 | where to find the applicable terms. 402 | 403 | Additional terms, permissive or non-permissive, may be stated in the 404 | form of a separately written license, or stated as exceptions; 405 | the above requirements apply either way. 406 | 407 | 8. Termination. 408 | 409 | You may not propagate or modify a covered work except as expressly 410 | provided under this License. Any attempt otherwise to propagate or 411 | modify it is void, and will automatically terminate your rights under 412 | this License (including any patent licenses granted under the third 413 | paragraph of section 11). 
414 | 415 | However, if you cease all violation of this License, then your 416 | license from a particular copyright holder is reinstated (a) 417 | provisionally, unless and until the copyright holder explicitly and 418 | finally terminates your license, and (b) permanently, if the copyright 419 | holder fails to notify you of the violation by some reasonable means 420 | prior to 60 days after the cessation. 421 | 422 | Moreover, your license from a particular copyright holder is 423 | reinstated permanently if the copyright holder notifies you of the 424 | violation by some reasonable means, this is the first time you have 425 | received notice of violation of this License (for any work) from that 426 | copyright holder, and you cure the violation prior to 30 days after 427 | your receipt of the notice. 428 | 429 | Termination of your rights under this section does not terminate the 430 | licenses of parties who have received copies or rights from you under 431 | this License. If your rights have been terminated and not permanently 432 | reinstated, you do not qualify to receive new licenses for the same 433 | material under section 10. 434 | 435 | 9. Acceptance Not Required for Having Copies. 436 | 437 | You are not required to accept this License in order to receive or 438 | run a copy of the Program. Ancillary propagation of a covered work 439 | occurring solely as a consequence of using peer-to-peer transmission 440 | to receive a copy likewise does not require acceptance. However, 441 | nothing other than this License grants you permission to propagate or 442 | modify any covered work. These actions infringe copyright if you do 443 | not accept this License. Therefore, by modifying or propagating a 444 | covered work, you indicate your acceptance of this License to do so. 445 | 446 | 10. Automatic Licensing of Downstream Recipients. 447 | 448 | Each time you convey a covered work, the recipient automatically 449 | receives a license from the original licensors, to run, modify and 450 | propagate that work, subject to this License. You are not responsible 451 | for enforcing compliance by third parties with this License. 452 | 453 | An "entity transaction" is a transaction transferring control of an 454 | organization, or substantially all assets of one, or subdividing an 455 | organization, or merging organizations. If propagation of a covered 456 | work results from an entity transaction, each party to that 457 | transaction who receives a copy of the work also receives whatever 458 | licenses to the work the party's predecessor in interest had or could 459 | give under the previous paragraph, plus a right to possession of the 460 | Corresponding Source of the work from the predecessor in interest, if 461 | the predecessor has it or can get it with reasonable efforts. 462 | 463 | You may not impose any further restrictions on the exercise of the 464 | rights granted or affirmed under this License. For example, you may 465 | not impose a license fee, royalty, or other charge for exercise of 466 | rights granted under this License, and you may not initiate litigation 467 | (including a cross-claim or counterclaim in a lawsuit) alleging that 468 | any patent claim is infringed by making, using, selling, offering for 469 | sale, or importing the Program or any portion of it. 470 | 471 | 11. Patents. 472 | 473 | A "contributor" is a copyright holder who authorizes use under this 474 | License of the Program or a work on which the Program is based. 
The 475 | work thus licensed is called the contributor's "contributor version". 476 | 477 | A contributor's "essential patent claims" are all patent claims 478 | owned or controlled by the contributor, whether already acquired or 479 | hereafter acquired, that would be infringed by some manner, permitted 480 | by this License, of making, using, or selling its contributor version, 481 | but do not include claims that would be infringed only as a 482 | consequence of further modification of the contributor version. For 483 | purposes of this definition, "control" includes the right to grant 484 | patent sublicenses in a manner consistent with the requirements of 485 | this License. 486 | 487 | Each contributor grants you a non-exclusive, worldwide, royalty-free 488 | patent license under the contributor's essential patent claims, to 489 | make, use, sell, offer for sale, import and otherwise run, modify and 490 | propagate the contents of its contributor version. 491 | 492 | In the following three paragraphs, a "patent license" is any express 493 | agreement or commitment, however denominated, not to enforce a patent 494 | (such as an express permission to practice a patent or covenant not to 495 | sue for patent infringement). To "grant" such a patent license to a 496 | party means to make such an agreement or commitment not to enforce a 497 | patent against the party. 498 | 499 | If you convey a covered work, knowingly relying on a patent license, 500 | and the Corresponding Source of the work is not available for anyone 501 | to copy, free of charge and under the terms of this License, through a 502 | publicly available network server or other readily accessible means, 503 | then you must either (1) cause the Corresponding Source to be so 504 | available, or (2) arrange to deprive yourself of the benefit of the 505 | patent license for this particular work, or (3) arrange, in a manner 506 | consistent with the requirements of this License, to extend the patent 507 | license to downstream recipients. "Knowingly relying" means you have 508 | actual knowledge that, but for the patent license, your conveying the 509 | covered work in a country, or your recipient's use of the covered work 510 | in a country, would infringe one or more identifiable patents in that 511 | country that you have reason to believe are valid. 512 | 513 | If, pursuant to or in connection with a single transaction or 514 | arrangement, you convey, or propagate by procuring conveyance of, a 515 | covered work, and grant a patent license to some of the parties 516 | receiving the covered work authorizing them to use, propagate, modify 517 | or convey a specific copy of the covered work, then the patent license 518 | you grant is automatically extended to all recipients of the covered 519 | work and works based on it. 520 | 521 | A patent license is "discriminatory" if it does not include within 522 | the scope of its coverage, prohibits the exercise of, or is 523 | conditioned on the non-exercise of one or more of the rights that are 524 | specifically granted under this License. 
You may not convey a covered 525 | work if you are a party to an arrangement with a third party that is 526 | in the business of distributing software, under which you make payment 527 | to the third party based on the extent of your activity of conveying 528 | the work, and under which the third party grants, to any of the 529 | parties who would receive the covered work from you, a discriminatory 530 | patent license (a) in connection with copies of the covered work 531 | conveyed by you (or copies made from those copies), or (b) primarily 532 | for and in connection with specific products or compilations that 533 | contain the covered work, unless you entered into that arrangement, 534 | or that patent license was granted, prior to 28 March 2007. 535 | 536 | Nothing in this License shall be construed as excluding or limiting 537 | any implied license or other defenses to infringement that may 538 | otherwise be available to you under applicable patent law. 539 | 540 | 12. No Surrender of Others' Freedom. 541 | 542 | If conditions are imposed on you (whether by court order, agreement or 543 | otherwise) that contradict the conditions of this License, they do not 544 | excuse you from the conditions of this License. If you cannot convey a 545 | covered work so as to satisfy simultaneously your obligations under this 546 | License and any other pertinent obligations, then as a consequence you may 547 | not convey it at all. For example, if you agree to terms that obligate you 548 | to collect a royalty for further conveying from those to whom you convey 549 | the Program, the only way you could satisfy both those terms and this 550 | License would be to refrain entirely from conveying the Program. 551 | 552 | 13. Use with the GNU Affero General Public License. 553 | 554 | Notwithstanding any other provision of this License, you have 555 | permission to link or combine any covered work with a work licensed 556 | under version 3 of the GNU Affero General Public License into a single 557 | combined work, and to convey the resulting work. The terms of this 558 | License will continue to apply to the part which is the covered work, 559 | but the special requirements of the GNU Affero General Public License, 560 | section 13, concerning interaction through a network will apply to the 561 | combination as such. 562 | 563 | 14. Revised Versions of this License. 564 | 565 | The Free Software Foundation may publish revised and/or new versions of 566 | the GNU General Public License from time to time. Such new versions will 567 | be similar in spirit to the present version, but may differ in detail to 568 | address new problems or concerns. 569 | 570 | Each version is given a distinguishing version number. If the 571 | Program specifies that a certain numbered version of the GNU General 572 | Public License "or any later version" applies to it, you have the 573 | option of following the terms and conditions either of that numbered 574 | version or of any later version published by the Free Software 575 | Foundation. If the Program does not specify a version number of the 576 | GNU General Public License, you may choose any version ever published 577 | by the Free Software Foundation. 578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 
583 | 584 | Later license versions may give you additional or different 585 | permissions. However, no additional obligations are imposed on any 586 | author or copyright holder as a result of your choosing to follow a 587 | later version. 588 | 589 | 15. Disclaimer of Warranty. 590 | 591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 599 | 600 | 16. Limitation of Liability. 601 | 602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 610 | SUCH DAMAGES. 611 | 612 | 17. Interpretation of Sections 15 and 16. 613 | 614 | If the disclaimer of warranty and limitation of liability provided 615 | above cannot be given local legal effect according to their terms, 616 | reviewing courts shall apply local law that most closely approximates 617 | an absolute waiver of all civil liability in connection with the 618 | Program, unless a warranty or assumption of liability accompanies a 619 | copy of the Program in return for a fee. 620 | 621 | END OF TERMS AND CONDITIONS 622 | 623 | How to Apply These Terms to Your New Programs 624 | 625 | If you develop a new program, and you want it to be of the greatest 626 | possible use to the public, the best way to achieve this is to make it 627 | free software which everyone can redistribute and change under these terms. 628 | 629 | To do so, attach the following notices to the program. It is safest 630 | to attach them to the start of each source file to most effectively 631 | state the exclusion of warranty; and each file should have at least 632 | the "copyright" line and a pointer to where the full notice is found. 633 | 634 | <one line to give the program's name and a brief idea of what it does.> 635 | Copyright (C) <year> <name of author> 636 | 637 | This program is free software: you can redistribute it and/or modify 638 | it under the terms of the GNU General Public License as published by 639 | the Free Software Foundation, either version 3 of the License, or 640 | (at your option) any later version. 641 | 642 | This program is distributed in the hope that it will be useful, 643 | but WITHOUT ANY WARRANTY; without even the implied warranty of 644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 645 | GNU General Public License for more details. 646 | 647 | You should have received a copy of the GNU General Public License 648 | along with this program. If not, see <https://www.gnu.org/licenses/>. 649 | 650 | Also add information on how to contact you by electronic and paper mail.
651 | 652 | If the program does terminal interaction, make it output a short 653 | notice like this when it starts in an interactive mode: 654 | 655 | <program> Copyright (C) <year> <name of author> 656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 657 | This is free software, and you are welcome to redistribute it 658 | under certain conditions; type `show c' for details. 659 | 660 | The hypothetical commands `show w' and `show c' should show the appropriate 661 | parts of the General Public License. Of course, your program's commands 662 | might be different; for a GUI interface, you would use an "about box". 663 | 664 | You should also get your employer (if you work as a programmer) or school, 665 | if any, to sign a "copyright disclaimer" for the program, if necessary. 666 | For more information on this, and how to apply and follow the GNU GPL, see 667 | <https://www.gnu.org/licenses/>. 668 | 669 | The GNU General Public License does not permit incorporating your program 670 | into proprietary programs. If your program is a subroutine library, you 671 | may consider it more useful to permit linking proprietary applications with 672 | the library. If this is what you want to do, use the GNU Lesser General 673 | Public License instead of this License. But first, please read 674 | <https://www.gnu.org/licenses/why-not-lgpl.html>. 675 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Graph Neural Networks 2 | This is a PyTorch library to implement graph neural networks and graph recurrent neural networks. For any questions, comments, or suggestions, please e-mail Fernando Gama at fgama@seas.upenn.edu and/or Luana Ruiz at rubruiz@seas.upenn.edu. An in-depth tutorial on a source localization example can be found [here](tutorial.ipynb). 3 | 4 | * [Introduction](#introduction) 5 | * [Code](#code) 6 | * [Dependencies](#dependencies) 7 | * [Datasets](#datasets) 8 | * [Libraries](#libraries) 9 | * [Architectures](#architectures) 10 | * [Examples](#examples) ([tutorial](tutorial.ipynb)) 11 | * [Version](#version) 12 | 13 | Whenever using any part of this code, please cite the following paper 14 | 15 | F. Gama, A. G. Marques, G. Leus, and A. Ribeiro, "[Convolutional Neural Network Architectures for Signals Supported on Graphs](http://ieeexplore.ieee.org/document/8579589)," _IEEE Trans. Signal Process._, vol. 67, no. 4, pp. 1034–1049, Feb. 2019. 16 | 17 | We note that some specific [architectures](#architectures) have specific paper citations to adequately acknowledge the respective contributors. 18 | 19 | Other papers on GNNs by the authors are 20 | 21 | E. Isufi, F. Gama, and A. Ribeiro, "[EdgeNets: Edge Varying Graph Neural Networks](http://arxiv.org/abs/2001.07620)," submitted to _IEEE Trans. Pattern Analysis and Mach. Intell._ 23 | F. Gama, E. Isufi, G. Leus, and A. Ribeiro, "[Graphs, Convolutions, and Neural Networks](http://arxiv.org/abs/2003.03777)," submitted to _IEEE Signal Process. Mag._ 24 | 25 | L. Ruiz, F. Gama, and A. Ribeiro, "[Gated Graph Recurrent Neural Networks](http://arxiv.org/abs/2002.01038)," submitted to _IEEE Trans. Signal Process._ 26 | 27 | F. Gama, J. Bruna, and A. Ribeiro, "[Stability Properties of Graph Neural Networks](http://arxiv.org/abs/1905.04497)," submitted to _IEEE Trans. Signal Process._ 28 | 29 | F. Gama, E. Tolstaya, and A. Ribeiro, "[Graph Neural Networks for Decentralized Controllers](http://arxiv.org/abs/2003.10280)," _arXiv:2003.10280v1 [cs.LG],_ 23 March 2020. 30 | 31 | L. Ruiz, F. Gama, A. G.
Marques, and A. Ribeiro, "[Invariance-Preserving Localized Activation Functions for Graph Neural Networks](https://ieeexplore.ieee.org/document/8911416)," _IEEE Trans. Signal Process._, vol. 68, no. 1, pp. 127-141, Jan. 2020. 32 | 33 | F. Gama, J. Bruna, and A. Ribeiro, "[Stability of Graph Scattering Transforms](http://arxiv.org/abs/1906.04784)," in _33rd Conf. Neural Inform. Process. Syst._ Vancouver, BC: Neural Inform. Process. Syst. Foundation, 8-14 Dec. 2019. 34 | 35 | F. Gama, A. G. Marques, A. Ribeiro, and G. Leus, "[MIMO Graph Filters for Convolutional Networks](http://ieeexplore.ieee.org/document/8445934)," in _19th IEEE Int. Workshop Signal Process. Advances in Wireless Commun._ Kalamata, Greece: IEEE, 25-28 June 2018, pp. 1–5. 36 | 37 | F. Gama, G. Leus, A. G. Marques, and A. Ribeiro, "[Convolutional Neural Networks via Node-Varying Graph Filters](https://ieeexplore.ieee.org/document/8439899)," in _2018 IEEE Data Sci. Workshop._ Lausanne, Switzerland: IEEE, 4-6 June 2018, pp. 220–224. 38 | 39 | 40 | ## Introduction 41 | 42 | We consider data supported by an underlying graph with _N_ nodes. We describe the graph in terms of an _N x N_ matrix _S_ that respects the sparsity of the graph. That is, the element _(i,j)_ of matrix _S_ can be nonzero if and only if _i=j_ or _(j,i)_ is an edge of the graph. Examples of such matrices are the adjacency matrix, the graph Laplacian, the Markov matrix, and many normalized counterparts. In general, we refer to this matrix _S_ as the __graph shift operator__ (GSO). This code supports extension to a tensor GSO whenever we want to assign a vector weight to each edge, instead of a scalar weight. 43 | 44 | To describe the _N_-dimensional data _x_ as supported by the graph, we assume that each element of _x_ represents the data value at each node, i.e., the _i_-th element $[x]_i = x_i$ represents the data value at node _i_. We thus refer to _x_ as a __graph signal__. To effectively relate the graph signal _x_ (which is an _N_-dimensional vector) to the underlying graph support, we use the GSO matrix _S_. In fact, the linear operation _Sx_ represents an exchange of information with neighbors. When computing _Sx_, each node interacts with its one-hop neighbors and computes a weighted average of the information in these neighbors. More precisely, if we denote by $\mathcal{N}_i$ the set of neighbors of node _i_, we see that the output of the operation _Sx_ at node _i_ is given by 45 | 46 | $$[Sx]_i = \sum_{j \in \mathcal{N}_i} [S]_{ij} [x]_j$$ 47 | 48 | due to the sparsity pattern of the matrix _S_, where the only nonzero elements are those where there is an edge _(j,i)_ connecting the nodes. We note that the use of the GSO allows for a very simple and straightforward way of explicitly relating the information between different nodes, following the support specified by the given graph. We can extend the descriptive power of graph signals by assigning an _F_-dimensional vector to each node, instead of a simple scalar. Each element _f_ of this vector is referred to as a __feature__. Then, the data can be thought of as a collection of _F_ graph signals $x^f$, for each _f=1,...,F_, where each graph signal $x^f$ represents the value of the specific feature _f_ across all nodes. Describing the data as a collection of _F_ graph signals, as opposed to a collection of _N_ vectors of dimension _F_, allows us to exploit the GSO to easily relate the data with the underlying graph support (as discussed for the case of a scalar graph signal).
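To make this local exchange concrete, here is a minimal numpy sketch (illustrative only: the path graph, its adjacency-matrix GSO, and the signal values are all made up for this example) that applies the GSO of a 4-node path graph to a graph signal and produces exactly the neighborhood sums described above.

```python
import numpy as np

# GSO of the 4-node path graph 1-2-3-4, here taken as the binary adjacency matrix.
S = np.array([[0., 1., 0., 0.],
              [1., 0., 1., 0.],
              [0., 1., 0., 1.],
              [0., 0., 1., 0.]])

x = np.array([1., 2., 3., 4.])  # graph signal: one scalar data value per node

# One shift: node i receives sum_{j in N_i} [S]_{ij} [x]_j, i.e., a weighted
# combination of the values at its one-hop neighbors.
print(S @ x)  # [2. 4. 6. 3.]
```

Repeated applications _S(Sx)_, _S(S(Sx))_, and so on aggregate information from progressively larger neighborhoods, which is precisely the mechanism exploited by the graph filters introduced next.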
49 | 50 | A graph neural network is an information processing architecture that regularizes the linear transform of neural networks to take into account the support of the graph. In its most general description, we assume that we have a cascade of _L_ layers, where each layer _l_ takes as input the previous signal, which is a graph signal described by $F_{l-1}$ features, and processes it through a bank of $F_l F_{l-1}$ linear operations that exploit the graph structure _S_ to obtain $F_l$ output features, which are processed by an activation function $\sigma_l$. Namely, for layer _l_, the output is computed as 51 | 52 | $$x_l^f = \sigma_l \Big( \sum_{g=1}^{F_{l-1}} H_l^{fg}(S) \, x_{l-1}^g \Big), \qquad f = 1, \ldots, F_l,$$ 53 | 54 | where the linear operators $H_l^{fg}(S)$ represent __graph filters__, which are linear transforms that exploit the underlying graph structure (typically, by means of local exchanges only, and access to partial information). There are several choices of graph filters that give rise to different architectures (the most popular choice being the linear shift-invariant graph filters, which give rise to __graph convolutions__), many of which can be found in the ensuing library. The operation of pooling, and the extension of the activation functions to include local neighborhoods, can also be found in this library. 55 | 56 | ## Code 57 | 58 | The library is written in [Python3](http://www.python.org/), drawing heavily from [numpy](http://www.numpy.org/), and with neural network models that are defined and trained within the [PyTorch](http://pytorch.org/) framework. 59 | 60 | ### Dependencies 61 | 62 | The required packages are os, numpy, matplotlib, pickle, datetime, scipy.io, copy, torch, scipy, math, and sklearn. Additionally, to handle the specific datasets listed below, the following are also required: hdf5storage, urllib, zipfile, gzip, and shutil; and to handle tensorboard visualization, also include glob, torchvision, operator, and tensorboardX. 63 | 64 | ### Datasets 65 | 66 | The different datasets involving graph data that are available in this library are the following. 67 | 68 |

1. Authorship attribution dataset, available under datasets/authorshipData (note that the available .rar files have to be uncompressed into authorshipData.mat before that dataset can be used with the provided code). When using this dataset, please cite

69 | 70 | S. Segarra, M. Eisen, and A. Ribeiro, "[Authorship attribution through function word adjacency networks](http://ieeexplore.ieee.org/document/6638728)," _IEEE Trans. Signal Process._, vol. 63, no. 20, pp. 5464–5478, Oct. 2015. 71 | 72 |

2. The MovieLens-100k dataset. When using this dataset, please cite

73 | 74 | F. M. Harper and J. A. Konstan, "[The MovieLens datasets: History and Context](http://dl.acm.org/citation.cfm?id=2827872)", _ACM Trans. Interactive Intell. Syst._, vol. 5, no. 4, pp. 19:(1-19), Jan. 2016. 75 | 76 |

3. A source localization dataset. This source localization problem generates synthetic data at execution time, either on synthetic graphs, such as the Small World graph or the Stochastic Block Model, or on a real Facebook graph. When using the Facebook graph, please cite

77 | 78 | J. McAuley and J. Leskovec, "[Learning to discover social circles in Ego networks](http://papers.nips.cc/paper/4532-learning-to-discover-social-circles-in-ego-networks)," in _26th Neural Inform. Process. Syst._ Stateline, TX: NeurIPS Foundation, 3-8 Dec. 2012. 79 | 80 |

4. A flocking dataset. The problem of flocking consists of controlling a robot swarm, whose members initially fly at arbitrary, random velocities, so that they come to fly together at the same velocity while avoiding collisions with each other. The task is to do so in a distributed and decentralized manner, where each agent (each robot) can compute its control action at every time instant relying only on information obtained from communications with immediate neighbors. The dataset is synthetic in that it generates different sample trajectories with random initializations. When using this dataset, please cite

5. An epidemic dataset. In this problem, we track the spread of an epidemic on a high school friendship network. The epidemic data is generated by using the SIR model to simulate the spread of an infectious disease on the friendship network built from this SocioPatterns dataset. When using this dataset, please cite 85 | 86 | L. Ruiz, F. Gama, and A. Ribeiro, "[Gated Graph Recurrent Neural Networks](http://arxiv.org/abs/2002.01038)," submitted to _IEEE Trans. Signal Process._ 87 | 88 | 89 | ### Libraries 90 | 91 | The `alegnn` package is split up into two sub-packages: `alegnn.modules` and `alegnn.utils`. 92 | 93 | * modules.architectures contains the implementation of several standard architectures (as nn.Module subclasses) so that they can be readily initialized and trained. Details are provided in the [next section](#architectures). 94 | 95 | * modules.architecturesTime contains the implementation of several standard architectures (as nn.Module subclasses) that handle time-dependent topologies, so that they can be readily initialized and trained. Details are provided in the [next section](#architectures). 96 | 97 | * modules.evaluation contains functions that act as intermediaries between the model and the data in order to evaluate a trained architecture. 98 | 99 | * modules.loss contains a wrapper for the loss function so that it can adapt to multiple scenarios, and the loss function for the F1 score. 100 | 101 | * modules.model defines a Model that binds together the three basic elements that constitute a machine learning model: the (neural network) architecture, the loss function, and the optimizer. Additionally, it assigns a training handler and an evaluator, as well as a name for the model and a directory in which to save the trained parameters of the architecture. It is the basic class that can train and evaluate a model, and it also offers methods to save and load parameters. 102 | 103 | * modules.training contains classes that handle the training of each model, acting as an intermediary between the data and the specific architecture within the model being trained. 104 | 105 | * utils.dataTools loads each of the datasets described [above](#datasets) as classes with several functionalities particular to each dataset. All the data classes have two methods: .getSamples, to gather the corresponding samples for the training, validation, or testing sets, and .evaluate, to compute the corresponding evaluation measure. 106 | 107 | * utils.graphML is the main library containing the implementation of all the possible graph neural network layers (as nn.Module subclasses). This library is the analog of torch.nn, but for graph-based operations. It contains the definition of the basic layers that need to be put together to build a graph neural network. Details are provided in the [next section](#architectures). 108 | 109 | * utils.graphTools defines the Graph class that handles graph-structure information, and offers several other tools to handle graphs. 110 | 111 | * utils.miscTools defines some miscellaneous functions. 112 | 113 | * utils.visualTools contains all the relevant classes and functions to handle visualization in tensorboard. 114 | 115 | ### Architectures 116 | 117 | In what follows, we describe several ways of parameterizing the filters $H_l^{fg}(S)$ that are implemented in this library. 118 | 119 | * ___Convolutional Graph Neural Networks (via Selection)___.
The most popular graph neural network (GNN) is the one that parameterizes $H_l^{fg}(S)$ by a linear shift-invariant graph filter, giving rise to a __graph convolution__. The nn.Module subclass that implements the graph filter (convolutional) layer can be found in utils.graphML.GraphFilter. This layer is the basic linear layer in the Selection GNN architecture (which also adds the pointwise activation function and the zero-padding pooling operation), which is already implemented in modules.architectures.SelectionGNN and shown in several examples. For more details on this graph convolutional layer or its architecture, and whenever using it, please cite the following paper 120 | 121 | F. Gama, A. G. Marques, G. Leus, and A. Ribeiro, "[Convolutional Neural Network Architectures for Signals Supported on Graphs](http://ieeexplore.ieee.org/document/8579589)," _IEEE Trans. Signal Process._, vol. 67, no. 4, pp. 1034–1049, Feb. 2019. 122 | 123 | The modules.architectures.SelectionGNN also has a flag called coarsening that allows for the pooling to be done in terms of graph coarsening, following the Graclus algorithm. This part of the code was mainly adapted to PyTorch from this repository. For more details on graph coarsening, and whenever using the SelectionGNN with graph coarsening pooling, please cite the following [paper](http://papers.nips.cc/paper/6081-convolutional-neural-networks-on-graphs-with-fast-localized-spectral-filtering.pdf). Also note that setting the number of filter taps (nFilterTaps) to 2 on every layer leads to this [architecture](http://openreview.net/forum?id=SJU4ayYgl). Finally, this other [architecture](https://openreview.net/forum?id=ryGs6iA5Km) is obtained by setting the number of filter taps to 1 for the desired number of fully-connected layers, and then setting it to 2 to complete the corresponding _GIN layer_. There is one further implementation that is entirely local (i.e., it only involves operations exchanging information with one-hop neighbors). This implementation essentially replaces the last fully-connected layer by a readout layer that only operates on the features obtained at the node. The implementation is dubbed LocalGNN and is used in the MovieLens example. 124 | 125 | * ___Convolutional Graph Neural Networks (via Spectrum)___. The spectral GNN is an early implementation of the convolutional GNN in the graph frequency domain. It does not scale to large graphs due to the cost of the eigendecomposition of the GSO. The spectral filtering layer is implemented as an nn.Module subclass in utils.graphML.SpectralGF, and the corresponding architecture with these linear layers, together with pointwise nonlinearities, is implemented in modules.architectures.SpectralGNN. For more details on the spectral graph filtering layer or its architecture, and whenever using it, please cite 126 | 127 | J. Bruna, W. Zaremba, A. Szlam, and Y. LeCun, "[Spectral networks and deep locally connected networks on graphs](http://openreview.net/forum?id=DQNsQf-UsoDBa)," in _Int. Conf. Learning Representations 2014_. Banff, AB: Assoc. Comput. Linguistics, 14-16 Apr. 2014, pp. 1–14. 128 | 129 | * ___Convolutional Graph Neural Networks (via Aggregation)___. An alternative way of implementing a graph convolution is by means of building an aggregation sequence at each node.
Instead of thinking of the graph signal as being diffused through the graph and each diffusion being weighed separately (as is the case of a GCNN via Selection), we think of the signal as being aggregated at each node, by means of successive communications with the one-hop neighbors, where each communication is weighed by a separate filter tap. The key point is that these aggregation sequences exhibit a regular structure that simultaneously takes into account the underlying graph support, since each contiguous element in the sequence represents a contiguous neighborhood. Once we have a regular sequence, we can go ahead and apply a regular CNN to process its information. This idea is called an Aggregation GNN and is implemented in modules.architectures.AggregationGNN, since it relies on the regular convolution and pooling already defined in torch.nn. A more sophisticated and powerful variant of the Aggregation GNN, called the __Multi-Node Aggregation GNN__, is also available in modules.architectures.MultiNodeAggregationGNN. For more details on the Aggregation GNN, and whenever using it, please cite the following paper 130 | 131 | F. Gama, A. G. Marques, G. Leus, and A. Ribeiro, "[Convolutional Neural Network Architectures for Signals Supported on Graphs](http://ieeexplore.ieee.org/document/8579589)," _IEEE Trans. Signal Process._, vol. 67, no. 4, pp. 1034–1049, Feb. 2019. 132 | 133 | * ___Node-Variant Graph Neural Networks___. By parameterizing $H_l^{fg}(S)$ with a node-variant graph filter (as opposed to a shift-invariant graph filter), a non-convolutional graph neural network architecture can be built. A node-variant graph filter essentially lets each node learn its own weight for each piece of neighborhood information. In order to allow this architecture to scale (so that the number of learnable parameters does not depend on the size of the graph), we offer a hybrid node-variant GNN approach as well. The graph filtering layer using node-variant graph filters is defined in utils.graphML.NodeVariantGF, and an example of an architecture using these filters for the linear operation, combined with pointwise activation functions and zero-padding pooling, is available in modules.architectures.NodeVariantGNN. For more details on node-variant GNNs, and whenever using these filters or this architecture, please cite the following paper 134 | 135 | E. Isufi, F. Gama, and A. Ribeiro, "[EdgeNets: Edge Varying Graph Neural Networks](http://arxiv.org/abs/2001.07620)," submitted to _IEEE Trans. Pattern Analysis and Mach. Intell._ 136 | 137 | * ___ARMA Graph Neural Networks___. A convolutional architecture that is very flexible and has enlarged descriptive power. It replaces the FIR filter of the graph convolution (i.e., a polynomial of the shift operator) by a ratio of polynomials. This architecture offers a good trade-off between the number of parameters and the selectivity of the learnable filters. The ARMA graph filter layer can be found in utils.graphML.GraphFilterARMA. An example of an architecture with ARMA graph filters as the linear layer, and pointwise activation functions and zero-padding pooling, is available in modules.architectures.ARMAfilterGNN. A Local version of this architecture is also available. For more details on ARMA GNNs, and whenever using these filters or this architecture, please cite the following paper 138 | 139 | E. Isufi, F. Gama, and A. Ribeiro, "[EdgeNets: Edge Varying Graph Neural Networks](http://arxiv.org/abs/2001.07620)," submitted to _IEEE Trans. Pattern Analysis and Mach.
140 | 141 | * ___Edge-Variant Graph Neural Networks___. The most general parameterization that we can make of a linear operation that also takes into account the underlying graph support is to let each node weigh each of its neighbors' information differently. This is achieved by means of an edge-variant graph filter. The edge-variant graph filter, however, has a number of parameters that scales with the number of edges, so a hybrid approach is available as well. The edge-variant graph filter layer can be found in utils.graphML.EdgeVariantGF. An example of an architecture with edge-variant graph filters as the linear layer, and pointwise activation functions and zero-padding pooling, is available in modules.architectures.EdgeVariantGNN. A Local version of this architecture is also available. For more details on edge-variant GNNs, and whenever using these filters or architecture, please cite the following paper 142 | 143 | E. Isufi, F. Gama, and A. Ribeiro, "[EdgeNets: Edge Varying Graph Neural Networks](http://arxiv.org/abs/2001.07620)," submitted to _IEEE Trans. Pattern Analysis and Mach. Intell._ 144 | 145 | * ___Graph Attention Networks___. A particular case of edge-variant graph filters (one that predates the use of the more general edge-variant filters) that has been shown to be successful is the graph attention network (commonly known as GAT). The original implementation of GATs can be found in this repository. In this library, we offer a PyTorch adaptation of this code (which was originally written for TensorFlow). The GAT parameterizes the edge-variant graph filter by taking into account both the graph support and the data, yielding an architecture with a number of parameters that is independent of the size of the graph. The graph attentional layer can be found in utils.graphML.GraphAttentional, and an example of this architecture in modules.architectures.GraphAttentionNetwork. For more details on GATs, and whenever using this code, please cite the following paper 146 | 147 | P. Veličković, G. Cucurull, A. Casanova, A. Romero, P. Liò, and Y. Bengio, "[Graph Attention Networks](http://openreview.net/forum?id=rJXMpikCZ)," in _6th Int. Conf. Learning Representations_. Vancouver, BC: Assoc. Comput. Linguistics, 30 Apr.-3 May 2018, pp. 1–12. 148 | 149 | * ___Local Activation Functions___. Local activation functions exploit the irregular neighborhoods that are inherent to arbitrary graphs. Instead of just applying a pointwise (node-wise) activation function, using a local activation function that carries out a nonlinear operation within a neighborhood has been shown to be effective as well. The corresponding architecture is named LocalActivationGNN and is available under modules/architectures.py. In particular, in this code, the __median activation function__ is implemented in utils.graphML.MedianLocalActivation and the __max activation function__ is implemented in utils.graphML.MaxLocalActivation. For more details on local activation functions, and whenever using these operational layers, please cite the following paper 150 | 151 | L. Ruiz, F. Gama, A. G. Marques, and A. Ribeiro, "[Invariance-Preserving Localized Activation Functions for Graph Neural Networks](https://ieeexplore.ieee.org/document/8911416)," _IEEE Trans. Signal Process._, vol. 68, no. 1, pp. 127-141, Jan. 2020. 152 |
153 | * ___Time-Varying Architectures___. The Selection and Aggregation GNNs have a version adapted to handling time-varying graph signals as well as time-varying shift operators, acting with a unit delay between communications with neighbors. These architectures can be found in architecturesTime.LocalGNN_DB and architecturesTime.AggregationGNN_DB (a minimal instantiation sketch is given right after this list). For more details on these architectures, please see (and if used, please cite) 154 | 155 | F. Gama, E. Tolstaya, and A. Ribeiro, "[Graph Neural Networks for Decentralized Controllers](http://arxiv.org/abs/2003.10280)," _arXiv:2003.10280v1 [cs.LG],_ 23 March 2020. 156 | 157 | E. Tolstaya, F. Gama, J. Paulos, G. Pappas, V. Kumar, and A. Ribeiro, "[Learning Decentralized Controllers for Robot Swarms with Graph Neural Networks](http://arxiv.org/abs/1903.10527)," in _Conf. Robot Learning 2019._ Osaka, Japan: Int. Found. Robotics Res., 30 Oct.-1 Nov. 2019. 158 | 159 | * ___Graph Recurrent Neural Networks___. A graph RNN approximates a time-varying graph process with a hidden Markov model, where the hidden state is learned from data. In a graph RNN, all the linear transforms involved are graph filters that respect the graph. This is a highly flexible architecture that exploits the graph structure as well as the time-dependencies present in the data. For static graphs, the architecture can be found in architectures.GraphRecurrentNN, and in architectures.GatedGraphRecurrentNN for the time, node and edge gated variations. For time-varying graphs, the architecture is architecturesTime.GraphRecurrentNN_DB. For more details please see, and when using this architecture please cite, 160 | 161 | L. Ruiz, F. Gama, and A. Ribeiro, "[Gated Graph Recurrent Neural Networks](http://arxiv.org/abs/2002.01038)," submitted to _IEEE Trans. Signal Process._ 162 |
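To make the time-varying bullet above concrete, the following is a minimal sketch of how one of these architectures can be instantiated and called. The constructor arguments and input/output shapes follow the docstrings in alegnn/modules/architecturesTime.py; the hyperparameter values themselves are arbitrary placeholders.

```python
import torch
import torch.nn as nn
from alegnn.modules.architecturesTime import LocalGNN_DB

B, T, N = 20, 50, 40  # batch size, trajectory length, number of nodes

# Two graph convolutional layers (1 -> 32 -> 32 features, 5 filter taps each),
# followed by a node-wise readout mapping each node's features to 2 outputs.
archit = LocalGNN_DB([1, 32, 32], [5, 5], True,  # dimNodeSignals, nFilterTaps, bias
                     nn.Tanh,                    # nonlinearity
                     [2],                        # dimReadout
                     1)                          # dimEdgeFeatures

x = torch.randn(B, T, 1, N)  # batchSize x timeSamples x dimFeatures x numberNodes
S = torch.randn(B, T, N, N)  # one (possibly different) GSO per time step
y = archit(x, S)             # batchSize x timeSamples x dimReadout[-1] x numberNodes
```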
163 | ### Examples 164 | 165 | We have included an in-depth [tutorial](tutorial.ipynb), tutorial.ipynb, in a [Jupyter Notebook](http://jupyter.org/). We have also included other examples involving all four datasets presented [above](#datasets), with examples of all the architectures [just](#architectures) discussed. 166 | 167 | * [Tutorial](tutorial.ipynb): tutorial.ipynb. The tutorial covers the basic mathematical formulation of graph neural networks and considers a small synthetic problem of source localization. It implements the Aggregation and Selection GNNs (the latter with both zero-padding and graph coarsening). This tutorial explains, in depth, all the elements involved in the setup, training and evaluation of the models, and serves as a skeleton for all the other examples. 168 | 169 | * [Source Localization](examples/sourceLocGNN.py): sourceLocGNN.py. This example deals with the source localization problem on a randomly generated, 100-node, 5-community SBM graph. It can consider multiple graph and data realizations to account for randomness in data generation. Implementations of Selection and Aggregation GNNs with different node sampling criteria are presented. 170 | 171 | * [MovieLens](examples/movieGNN.py): movieGNN.py. This example has the objective of predicting the rating a user would give to a movie, based on the movies they have rated before (following the MovieLens-100k dataset). In this case we present a one- and a two-layer Selection GNN with no padding, as well as the one- and two-layer local implementations available in LocalGNN. 172 | 173 | * [Authorship Attribution](examples/authorshipGNN.py): authorshipGNN.py. This example addresses the problem of authorship attribution, by which a text has to be assigned to some author according to their stylometric signature (based on the underlying word adjacency network; details here). In this case, we test different local activation functions (median, max, and pointwise). 174 | 175 | * [Flocking](examples/flockingGNN.py): flockingGNN.py. This is an example of controlling a robot swarm so that it flies together at the same velocity while avoiding collisions. It is a synthetic dataset on which time-dependent architectures can be tested. In particular, we test the use of a linear filter, a Local GNN, an Aggregation GNN and a GRNN, considering not only samples of the form (S_t, x_t) for each t, but also delayed communications, where the information observed from farther-away neighbors is correspondingly delayed. 176 | 177 | * [Epidemic Tracking](examples/epidemicGRNN.py): epidemicGRNN.py. In this example, we compare GRNNs and gated GRNNs in a binary node classification problem modeling the spread of an epidemic on a high school friendship network. The disease is first recorded on day t=0, when each individual node is infected with probability p_seed=0.05. On the days that follow, an infected student can then spread the disease to each of their susceptible friends with probability p_inf=0.3 each day. Infected students become immune after 4 days, at which point they can no longer spread or contract the disease. Given the state of each node at some point in time (susceptible, infected or recovered), the binary node classification problem is to predict whether each node in the network will have the disease (i.e., be infected) 8 days ahead. (A minimal simulation sketch of these dynamics is given right after this list.) 178 |
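As a rough illustration of the epidemic dynamics just described (and not the library's actual data pipeline, which lives in the Epidemics class of dataTools.py), a self-contained simulation sketch could look as follows; the function name and the adjacency-matrix argument are hypothetical:

```python
import numpy as np

def simulateEpidemic(adjacency, nDays, pSeed=0.05, pInf=0.3, recoveryDays=4,
                     seed=0):
    # States per node: 0 = susceptible, 1 = infected, 2 = recovered (immune).
    rng = np.random.default_rng(seed)
    N = adjacency.shape[0]
    state = (rng.random(N) < pSeed).astype(int)  # day t = 0 seeding
    daysInfected = np.zeros(N, dtype=int)
    trajectory = [state.copy()]
    for _ in range(nDays):
        infected = state == 1
        # A susceptible node escapes infection only if every infected neighbor
        # fails to transmit: P(catch) = 1 - (1 - pInf)^(number of such neighbors)
        nInfNeighbors = adjacency @ infected.astype(int)
        pCatch = 1. - (1. - pInf) ** nInfNeighbors
        newlyInfected = (state == 0) & (rng.random(N) < pCatch)
        daysInfected[infected] += 1
        state[infected & (daysInfected >= recoveryDays)] = 2  # becomes immune
        state[newlyInfected] = 1
        trajectory.append(state.copy())
    return np.stack(trajectory)  # (nDays + 1) x N array of node states
```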
179 | ## Version 180 | 181 | * ___0.4 (March 5, 2021):___ Added the main file for the epidemic tracking experiment, epidemicGRNN.py. Added the edge list from which the graph used in this experiment is built. dataTools.py now has an Epidemics class which handles the aforementioned graph and the epidemic data. loss.py now has a new loss function, which computes the loss corresponding to the F1 score (1 - F1 score). graphML.py now has the functional GatedGRNN and the layers HiddenState, TimeGatedHiddenState, NodeGatedHiddenState, EdgeGatedHiddenState, which are used to calculate the hidden state of (gated) GRNNs. architectures.py now has the architectures GraphRecurrentNN and GatedGraphRecurrentNN. 182 | 183 | * ___0.3 (May 2, 2020):___ Added the time-dependent architectures that handle (graph, graph signal) batch data as well as delayed communications. These architectures can be found in architecturesTime.py. A new synthetic dataset has also been added, namely, the one used in the Flocking problem. Made the Model class the central handler of the whole machine learning model. Training multiple models has been dropped in favor of training through the method offered in the Model class. Trainers and evaluators had to be added to act as effective intermediaries between the architectures and the data, especially in problems that are not classification ones (i.e. regression (interpolation) in the movie recommendation setting, and imitation learning in the flocking problem). This should give flexibility to carry over these architectures to new problems, as well as make prototyping easier, since training and evaluating have been greatly simplified. Minor modifications and eventual bug fixes have been made here and there. 184 | 185 | * ___0.2 (Dec 16, 2019):___ Added new architectures: LocalActivationGNN and LocalGNN. Added a new loss module to handle the logic that gives flexibility to the loss function. Moved the node ordering from outside the architecture to inside it. Added two new methods, .splitForward() and .changeGSO(), to separate the output of the graph layers from that of the MLP, and to change the GSO from training to test time, respectively. The Model class does not keep track of the order anymore. Got rid of MATLAB(R) support. Better memory management (do not move the entire dataset to memory, only the batch). Created methods to normalize data and change data type. Deleted the 20News dataset, which is not supported anymore. Added the method .expandDims() to the data for increased flexibility. Changed the evaluate method so that it is always a decreasing function. Totally revamped the MovieLens class. Corrected a bug in the computeNeighborhood() function (thanks to Bianca Iancu, A (dot) Iancu-1 (at) student (dot) tudelft (dot) nl and Gabriele Mazzola, G (dot) Mazzola (at) student (dot) tudelft (dot) nl for spotting it). Corrected bugs in device handling of local activation functions. Updated tutorial. 186 | 187 | * ___0.1 (Jul 12, 2019):___ First released (beta) version of this graph neural network library. Includes the basic convolutional graph neural networks (selection (with zero-padding and graph coarsening), spectral, and aggregation), and some non-convolutional graph neural networks as well (node-variant, edge-variant and graph attention networks). It also includes local activation functions (max and median). In terms of examples, it considers the source localization problem (both in the tutorial and in a separate example), the movie recommendation problem, the authorship attribution problem and the text categorization problem. In terms of structure, it sets the basis for data handling and training of multiple models. 188 | -------------------------------------------------------------------------------- /aggGNN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alelab-upenn/graph-neural-networks/a84a39fabad5378bdcbaad20b5dbcff14b4eebcd/aggGNN.png -------------------------------------------------------------------------------- /alegnn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alelab-upenn/graph-neural-networks/a84a39fabad5378bdcbaad20b5dbcff14b4eebcd/alegnn/__init__.py -------------------------------------------------------------------------------- /alegnn/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alelab-upenn/graph-neural-networks/a84a39fabad5378bdcbaad20b5dbcff14b4eebcd/alegnn/modules/__init__.py -------------------------------------------------------------------------------- /alegnn/modules/architecturesTime.py: -------------------------------------------------------------------------------- 1 | # 2019/12/31~ 2 | # Fernando Gama, fgama@seas.upenn.edu 3 | # Luana Ruiz, rubruiz@seas.upenn.edu 4 | # Kate Tolstaya, eig@seas.upenn.edu 5 | """ 6 | architecturesTime.py Architectures module 7 | 8 | Definition of GNN architectures. The basic idea of these architectures is that 9 | the data comes in the form {(S_t, x_t)} where the shift operator as well as the 10 | signal change with time, and where each training point consists of a trajectory.
11 | Unlike architectures.py, where the shift operator S is fixed (although it can 12 | be changed after the architecture has been initialized) and the training set 13 | consists of a set of {x_b} with b=1,...,B for a total of B samples, here the 14 | training set is assumed to be a trajectory, and to include a different shift 15 | operator for each sample {(S_t, x_t)_{t=1}^{T}}_{b=1,...,B}. Also, all 16 | implementations consider a unit delay exchange (i.e. the S_t and x_t values 17 | get delayed by one unit of time for each neighboring exchange). 18 | 19 | LocalGNN_DB: implements the selection GNN architecture by means of local 20 | operations only 21 | GraphRecurrentNN_DB: implements the GRNN architecture 22 | AggregationGNN_DB: implements the aggregation GNN architecture 23 | """ 24 | 25 | import numpy as np 26 | import torch 27 | import torch.nn as nn 28 | 29 | import alegnn.utils.graphML as gml 30 | 31 | zeroTolerance = 1e-9 # Absolute values below this number are considered zero. 32 | 33 | class LocalGNN_DB(nn.Module): 34 | """ 35 | LocalGNN_DB: implements the local GNN architecture where all operations are 36 | implemented locally, i.e. by means of neighboring exchanges only. More 37 | specifically, it has graph convolutional layers, but the readout layer, 38 | instead of being an MLP for the entire graph signal, is a linear 39 | combination of the features at each node. It considers signals 40 | that change in time with batch GSOs. 41 | 42 | Initialization: 43 | 44 | LocalGNN_DB(dimNodeSignals, nFilterTaps, bias, # Graph Filtering 45 | nonlinearity, # Nonlinearity 46 | dimReadout, # Local readout layer 47 | dimEdgeFeatures) # Structure 48 | 49 | Input: 50 | /** Graph convolutional layers **/ 51 | dimNodeSignals (list of int): dimension of the signals at each layer 52 | (i.e. number of features at each node, or size of the vector 53 | supported at each node) 54 | nFilterTaps (list of int): number of filter taps on each layer 55 | (i.e. nFilterTaps-1 is the extent of neighborhoods that are 56 | reached, for example K=2 is info from the 1-hop neighbors) 57 | bias (bool): include bias after graph filter on every layer 58 | >> Obs.: dimNodeSignals[0] is the number of features (the dimension 59 | of the node signals) of the data, where dimNodeSignals[l] is the 60 | dimension obtained at the output of layer l, l=1,...,L. 61 | Therefore, for L layers, len(dimNodeSignals) = L+1. Slightly 62 | different, nFilterTaps[l] is the number of filter taps for the 63 | filters implemented at layer l+1, thus len(nFilterTaps) = L. 64 | 65 | /** Activation function **/ 66 | nonlinearity (torch.nn): module from torch.nn non-linear activations 67 | 68 | /** Readout layers **/ 69 | dimReadout (list of int): number of output hidden units of a 70 | sequence of fully connected layers applied locally at each node 71 | (i.e. no exchange of information involved).
72 | 73 | /** Graph structure **/ 74 | dimEdgeFeatures (int): number of edge features 75 | 76 | Output: 77 | nn.Module with a Local GNN architecture with the above specified 78 | characteristics that considers time-varying batch GSO and delayed 79 | signals 80 | 81 | Forward call: 82 | 83 | LocalGNN_DB(x, S) 84 | 85 | Input: 86 | x (torch.tensor): input data of shape 87 | batchSize x timeSamples x dimFeatures x numberNodes 88 | GSO (torch.tensor): graph shift operator; shape 89 | batchSize x timeSamples (x dimEdgeFeatures) 90 | x numberNodes x numberNodes 91 | 92 | Output: 93 | y (torch.tensor): output data after being processed by the GNN; 94 | batchSize x timeSamples x dimReadout[-1] x numberNodes 95 | 96 | Other methods: 97 | 98 | y, yGNN = .splitForward(x, S): gives the output of the entire GNN y, 99 | which has shape batchSize x timeSamples x dimReadout[-1] x numberNodes, 100 | as well as the output of all the GNN layers (i.e. before the readout 101 | layers), yGNN of shape batchSize x timeSamples x dimFeatures[-1] 102 | x numberNodes. This can be used to isolate the effect of the graph 103 | convolutions from the effect of the readout layer. 104 | 105 | y = .singleNodeForward(x, S, nodes): outputs the value of the last 106 | layer at a single node. x is the usual input of shape batchSize 107 | x timeSamples x dimFeatures x numberNodes. nodes is either a single 108 | node (int) or a collection of nodes (list or numpy.array) of length 109 | batchSize, where for each element in the batch, we get the output at 110 | the single specified node. The output y is of shape batchSize 111 | x timeSamples x dimReadout[-1]. 112 | """ 113 | 114 | def __init__(self, 115 | # Graph filtering 116 | dimNodeSignals, nFilterTaps, bias, 117 | # Nonlinearity 118 | nonlinearity, 119 | # MLP in the end 120 | dimReadout, 121 | # Structure 122 | dimEdgeFeatures): 123 | # Initialize parent: 124 | super().__init__() 125 | # dimNodeSignals should be a list and of size 1 more than nFilterTaps. 126 | assert len(dimNodeSignals) == len(nFilterTaps) + 1 127 | 128 | # Store the values (using the notation in the paper): 129 | self.L = len(nFilterTaps) # Number of graph filtering layers 130 | self.F = dimNodeSignals # Features 131 | self.K = nFilterTaps # Filter taps 132 | self.E = dimEdgeFeatures # Number of edge features 133 | self.bias = bias # Boolean 134 | # Store the rest of the variables 135 | self.sigma = nonlinearity 136 | self.dimReadout = dimReadout 137 | # And now, we're finally ready to create the architecture: 138 | #\\\ Graph filtering layers \\\ 139 | # OBS.: We could join this for loop with the one before, but we keep them 140 | # separate for clarity of code.
141 | gfl = [] # Graph Filtering Layers 142 | for l in range(self.L): 143 | #\\ Graph filtering stage: 144 | gfl.append(gml.GraphFilter_DB(self.F[l], self.F[l+1], self.K[l], 145 | self.E, self.bias)) 146 | #\\ Nonlinearity 147 | gfl.append(self.sigma()) 148 | # And now feed them into the sequential 149 | self.GFL = nn.Sequential(*gfl) # Graph Filtering Layers 150 | #\\\ MLP (Fully Connected Layers) \\\ 151 | fc = [] 152 | if len(self.dimReadout) > 0: # Maybe we don't want to read out anything 153 | # The first layer has to connect whatever was left of the graph 154 | # filtering stage to create the number of features required by 155 | # the readout layer 156 | fc.append(nn.Linear(self.F[-1], dimReadout[0], bias = self.bias)) 157 | # The last linear layer cannot be followed by nonlinearity, because 158 | # usually, this nonlinearity depends on the loss function (for 159 | # instance, if we have a classification problem, this nonlinearity 160 | # is already handled by the cross entropy loss or we add a softmax.) 161 | for l in range(len(dimReadout)-1): 162 | # Add the nonlinearity because there's another linear layer 163 | # coming 164 | fc.append(self.sigma()) 165 | # And add the linear layer 166 | fc.append(nn.Linear(dimReadout[l], dimReadout[l+1], 167 | bias = self.bias)) 168 | # And we're done 169 | self.Readout = nn.Sequential(*fc) 170 | # so we finally have the architecture. 171 | 172 | def splitForward(self, x, S): 173 | 174 | # Check the dimensions of the input 175 | # S: B x T (x E) x N x N 176 | # x: B x T x F[0] x N 177 | assert len(S.shape) == 4 or len(S.shape) == 5 178 | if len(S.shape) == 4: 179 | S = S.unsqueeze(2) 180 | B = S.shape[0] 181 | T = S.shape[1] 182 | assert S.shape[2] == self.E 183 | N = S.shape[3] 184 | assert S.shape[4] == N 185 | 186 | assert len(x.shape) == 4 187 | assert x.shape[0] == B 188 | assert x.shape[1] == T 189 | assert x.shape[2] == self.F[0] 190 | assert x.shape[3] == N 191 | 192 | # Add the GSO at each layer 193 | for l in range(self.L): 194 | self.GFL[2*l].addGSO(S) 195 | # Let's call the graph filtering layer 196 | yGFL = self.GFL(x) 197 | # Change the order, for the readout 198 | y = yGFL.permute(0, 1, 3, 2) # B x T x N x F[-1] 199 | # And, feed it into the Readout layer 200 | y = self.Readout(y) # B x T x N x dimReadout[-1] 201 | # Reshape and return 202 | return y.permute(0, 1, 3, 2), yGFL 203 | # B x T x dimReadout[-1] x N, B x T x dimFeatures[-1] x N 204 |
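    # Usage sketch (assumed instance name archit): splitForward above exposes
    # the intermediate GNN features alongside the readout output,
    #     y, yGNN = archit.splitForward(x, S)
    #     # y:    B x T x dimReadout[-1] x N   (after the node-wise readout)
    #     # yGNN: B x T x dimFeatures[-1] x N  (before the readout)
    # while forward below returns only y.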
205 | def forward(self, x, S): 206 | 207 | # Most of the times, we just need the actual, last output. But, since in 208 | # this case, we also want to compare with the output of the GNN itself, 209 | # we need to create this other forward function that takes both outputs 210 | # (the GNN and the MLP) and returns only the MLP output in the proper 211 | # forward function. 212 | output, _ = self.splitForward(x, S) 213 | 214 | return output 215 | 216 | def singleNodeForward(self, x, S, nodes): 217 | 218 | # x is of shape B x T x F[0] x N 219 | batchSize = x.shape[0] 220 | N = x.shape[3] 221 | 222 | # nodes is either an int, or a list/np.array of ints of size B 223 | assert type(nodes) is int \ 224 | or type(nodes) is list \ 225 | or type(nodes) is np.ndarray 226 | 227 | # Let us start by building the selection matrix 228 | # This selection matrix has to be a matrix of shape 229 | # B x 1 x N[-1] x 1 230 | # so that when multiplying with the output of the forward, we get a 231 | # B x T x dimReadout[-1] x 1 232 | # and we just squeeze the last dimension 233 | 234 | # TODO: The big question here is if multiplying by a matrix is faster 235 | # than doing torch.index_select 236 | 237 | # Let's always work with numpy arrays to make it easier. (NOTE: self.order, the internal node reordering used below, is assumed to have been set externally; this class does not define it by itself.) 238 | if type(nodes) is int: 239 | # Change the node number to accommodate the new order 240 | nodes = self.order.index(nodes) 241 | # If it's int, make it a list and an array 242 | nodes = np.array([nodes], dtype=int) 243 | # And repeat for the number of batches 244 | nodes = np.tile(nodes, batchSize) 245 | if type(nodes) is list: 246 | newNodes = [self.order.index(n) for n in nodes] 247 | nodes = np.array(newNodes, dtype = int) 248 | elif type(nodes) is np.ndarray: 249 | newNodes = np.array([np.where(np.array(self.order) == n)[0][0] \ 250 | for n in nodes]) 251 | nodes = newNodes.astype(int) 252 | # Now, nodes is an integer np.ndarray with shape batchSize 253 | 254 | # Build the selection matrix 255 | selectionMatrix = np.zeros([batchSize, 1, N, 1]) 256 | selectionMatrix[np.arange(batchSize), 0, nodes, 0] = 1. 257 | # And convert it to a tensor 258 | selectionMatrix = torch.tensor(selectionMatrix, 259 | dtype = x.dtype, 260 | device = x.device) 261 | 262 | # Now compute the output 263 | y = self.forward(x, S) 264 | # This output is of size B x T x dimReadout[-1] x N 265 | 266 | # Multiply the output 267 | y = torch.matmul(y, selectionMatrix) 268 | # B x T x dimReadout[-1] x 1 269 | 270 | # Squeeze the last dimension and return 271 | return y.squeeze(3) 272 | 273 | class GraphRecurrentNN_DB(nn.Module): 274 | """ 275 | GraphRecurrentNN_DB: implements the GRNN architecture on a time-varying GSO 276 | batch and delayed signals. It is a single-layer GRNN and the hidden 277 | state is initialized at random by drawing from a standard Gaussian.
278 | 279 | Initialization: 280 | 281 | GraphRecurrentNN_DB(dimInputSignals, dimOutputSignals, 282 | dimHiddenSignals, nFilterTaps, bias, # Filtering 283 | nonlinearityHidden, nonlinearityOutput, 284 | nonlinearityReadout, # Nonlinearities 285 | dimReadout, # Local readout layer 286 | dimEdgeFeatures) # Structure 287 | 288 | Input: 289 | /** Graph convolutions **/ 290 | dimInputSignals (int): dimension of the input signals 291 | dimOutputSignals (int): dimension of the output signals 292 | dimHiddenSignals (int): dimension of the hidden state 293 | nFilterTaps (list of int): a list with two elements, the first one 294 | is the number of filter taps for the filters in the hidden 295 | state equation, the second one is the number of filter taps 296 | for the filters in the output 297 | bias (bool): include bias after graph filter on every layer 298 | 299 | /** Activation functions **/ 300 | nonlinearityHidden (torch.function): the nonlinearity to apply 301 | when computing the hidden state; it has to be a torch function, 302 | not a nn.Module 303 | nonlinearityOutput (torch.function): the nonlinearity to apply when 304 | computing the output signal; it has to be a torch function, not 305 | a nn.Module. 306 | nonlinearityReadout (nn.Module): the nonlinearity to apply at the 307 | end of the readout layer (if the readout layer has more than 308 | one layer); this one has to be a nn.Module, instead of just a 309 | torch function. 310 | 311 | /** Readout layer **/ 312 | dimReadout (list of int): number of output hidden units of a 313 | sequence of fully connected layers applied locally at each node 314 | (i.e. no exchange of information involved). 315 | 316 | /** Graph structure **/ 317 | dimEdgeFeatures (int): number of edge features 318 | 319 | Output: 320 | nn.Module with a GRNN architecture with the above specified 321 | characteristics that considers time-varying batch GSO and delayed 322 | signals 323 | 324 | Forward call: 325 | 326 | GraphRecurrentNN_DB(x, S) 327 | 328 | Input: 329 | x (torch.tensor): input data of shape 330 | batchSize x timeSamples x dimInputSignals x numberNodes 331 | GSO (torch.tensor): graph shift operator; shape 332 | batchSize x timeSamples (x dimEdgeFeatures) 333 | x numberNodes x numberNodes 334 | 335 | Output: 336 | y (torch.tensor): output data after being processed by the GRNN; 337 | batchSize x timeSamples x dimReadout[-1] x numberNodes 338 | 339 | Other methods: 340 | 341 | y, yGNN = .splitForward(x, S): gives the output of the entire GRNN y, 342 | which has shape batchSize x timeSamples x dimReadout[-1] x numberNodes, 343 | as well as the output of the GRNN (i.e. before the readout layers), 344 | yGNN of shape batchSize x timeSamples x dimInputSignals x numberNodes. 345 | This can be used to isolate the effect of the graph convolutions from 346 | the effect of the readout layer. 347 | 348 | y = .singleNodeForward(x, S, nodes): outputs the value of the last 349 | layer at a single node. x is the usual input of shape batchSize 350 | x timeSamples x dimInputSignals x numberNodes. nodes is either a single 351 | node (int) or a collection of nodes (list or numpy.array) of length 352 | batchSize, where for each element in the batch, we get the output at 353 | the single specified node. The output y is of shape batchSize 354 | x timeSamples x dimReadout[-1]. 
355 | """ 356 | def __init__(self, 357 | # Graph filtering 358 | dimInputSignals, 359 | dimOutputSignals, 360 | dimHiddenSignals, 361 | nFilterTaps, bias, 362 | # Nonlinearities 363 | nonlinearityHidden, 364 | nonlinearityOutput, 365 | nonlinearityReadout, # nn.Module 366 | # Local MLP in the end 367 | dimReadout, 368 | # Structure 369 | dimEdgeFeatures): 370 | # Initialize parent: 371 | super().__init__() 372 | 373 | # A list of two int, one for the number of filter taps (the computation 374 | # of the hidden state has the same number of filter taps) 375 | assert len(nFilterTaps) == 2 376 | 377 | # Store the values (using the notation in the paper): 378 | self.F = dimInputSignals # Number of input features 379 | self.G = dimOutputSignals # Number of output features 380 | self.H = dimHiddenSignals # NUmber of hidden features 381 | self.K = nFilterTaps # Filter taps 382 | self.E = dimEdgeFeatures # Number of edge features 383 | self.bias = bias # Boolean 384 | # Store the rest of the variables 385 | self.sigma = nonlinearityHidden 386 | self.rho = nonlinearityOutput 387 | self.nonlinearityReadout = nonlinearityReadout 388 | self.dimReadout = dimReadout 389 | #\\\ Hidden State RNN \\\ 390 | # Create the layer that generates the hidden state, and generate z0 391 | self.hiddenState = gml.HiddenState_DB(self.F, self.H, self.K[0], 392 | nonlinearity = self.sigma, E = self.E, 393 | bias = self.bias) 394 | #\\\ Output Graph Filters \\\ 395 | self.outputState = gml.GraphFilter_DB(self.H, self.G, self.K[1], 396 | E = self.E, bias = self.bias) 397 | #\\\ MLP (Fully Connected Layers) \\\ 398 | fc = [] 399 | if len(self.dimReadout) > 0: # Maybe we don't want to readout anything 400 | # The first layer has to connect whatever was left of the graph 401 | # filtering stage to create the number of features required by 402 | # the readout layer 403 | fc.append(nn.Linear(self.G, dimReadout[0], bias = self.bias)) 404 | # The last linear layer cannot be followed by nonlinearity, because 405 | # usually, this nonlinearity depends on the loss function (for 406 | # instance, if we have a classification problem, this nonlinearity 407 | # is already handled by the cross entropy loss or we add a softmax.) 408 | for l in range(len(dimReadout)-1): 409 | # Add the nonlinearity because there's another linear layer 410 | # coming 411 | fc.append(self.nonlinearityReadout()) 412 | # And add the linear layer 413 | fc.append(nn.Linear(dimReadout[l], dimReadout[l+1], 414 | bias = self.bias)) 415 | # And we're done 416 | self.Readout = nn.Sequential(*fc) 417 | # so we finally have the architecture. 
418 | 419 | def splitForward(self, x, S): 420 | 421 | # Check the dimensions of the input 422 | # S: B x T (x E) x N x N 423 | # x: B x T x F[0] x N 424 | assert len(S.shape) == 4 or len(S.shape) == 5 425 | if len(S.shape) == 4: 426 | S = S.unsqueeze(2) 427 | B = S.shape[0] 428 | T = S.shape[1] 429 | assert S.shape[2] == self.E 430 | N = S.shape[3] 431 | assert S.shape[4] == N 432 | 433 | assert len(x.shape) == 4 434 | assert x.shape[0] == B 435 | assert x.shape[1] == T 436 | assert x.shape[2] == self.F 437 | assert x.shape[3] == N 438 | 439 | # The initial hidden state can be generated here or outside of here; it 440 | # is not clear yet what's the most coherent option 441 | z0 = torch.randn((B, self.H, N), device = x.device) 442 | 443 | # Add the GSO for each graph filter 444 | self.hiddenState.addGSO(S) 445 | self.outputState.addGSO(S) 446 | 447 | # Compute the trajectory of hidden states 448 | z, _ = self.hiddenState(x, z0) 449 | # Compute the output trajectory from the hidden states 450 | yOut = self.outputState(z) 451 | yOut = self.rho(yOut) # Don't forget the nonlinearity! 452 | # B x T x G x N 453 | # Change the order, for the readout 454 | y = yOut.permute(0, 1, 3, 2) # B x T x N x G 455 | # And, feed it into the Readout layer 456 | y = self.Readout(y) # B x T x N x dimReadout[-1] 457 | # Reshape and return 458 | return y.permute(0, 1, 3, 2), yOut 459 | # B x T x dimReadout[-1] x N, B x T x dimFeatures[-1] x N 460 | 461 | def forward(self, x, S): 462 | 463 | # Most of the times, we just need the actual, last output. But, since in 464 | # this case, we also want to compare with the output of the GNN itself, 465 | # we need to create this other forward function that takes both outputs 466 | # (the GNN and the MLP) and returns only the MLP output in the proper 467 | # forward function. 468 | output, _ = self.splitForward(x, S) 469 | 470 | return output 471 | 472 | def singleNodeForward(self, x, S, nodes): 473 | 474 | # x is of shape B x T x F[0] x N 475 | batchSize = x.shape[0] 476 | N = x.shape[3] 477 | 478 | # nodes is either an int, or a list/np.array of ints of size B 479 | assert type(nodes) is int \ 480 | or type(nodes) is list \ 481 | or type(nodes) is np.ndarray 482 | 483 | # Let us start by building the selection matrix 484 | # This selection matrix has to be a matrix of shape 485 | # B x 1 x N[-1] x 1 486 | # so that when multiplying with the output of the forward, we get a 487 | # B x T x dimReadout[-1] x 1 488 | # and we just squeeze the last dimension 489 | 490 | # TODO: The big question here is if multiplying by a matrix is faster 491 | # than doing torch.index_select 492 | 493 | # Let's always work with numpy arrays to make it easier. 494 | if type(nodes) is int: 495 | # Change the node number to accommodate the new order 496 | nodes = self.order.index(nodes) 497 | # If it's int, make it a list and an array 498 | nodes = np.array([nodes], dtype=int) 499 | # And repeat for the number of batches 500 | nodes = np.tile(nodes, batchSize) 501 | if type(nodes) is list: 502 | newNodes = [self.order.index(n) for n in nodes] 503 | nodes = np.array(newNodes, dtype = int) 504 | elif type(nodes) is np.ndarray: 505 | newNodes = np.array([np.where(np.array(self.order) == n)[0][0] \ 506 | for n in nodes]) 507 | nodes = newNodes.astype(int) 508 | # Now, nodes is an integer np.ndarray with shape batchSize 509 | 510 | # Build the selection matrix 511 | selectionMatrix = np.zeros([batchSize, 1, N, 1]) 512 | selectionMatrix[np.arange(batchSize), 0, nodes, 0] = 1.
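        # For example (assumed values), with batchSize = 2, N = 4 and
        # nodes = [2, 0], the matrix built above is
        #     selectionMatrix[0, 0, :, 0] = [0., 0., 1., 0.]
        #     selectionMatrix[1, 0, :, 0] = [1., 0., 0., 0.]
        # so that the matmul below picks out, for each element of the batch,
        # the column of y corresponding to the requested node.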
513 | # And convert it to a tensor 514 | selectionMatrix = torch.tensor(selectionMatrix, 515 | dtype = x.dtype, 516 | device = x.device) 517 | 518 | # Now compute the output 519 | y = self.forward(x, S) 520 | # This output is of size B x T x dimReadout[-1] x N 521 | 522 | # Multiply the output 523 | y = torch.matmul(y, selectionMatrix) 524 | # B x T x dimReadout[-1] x 1 525 | 526 | # Squeeze the last dimension and return 527 | return y.squeeze(3) 528 | 529 | class AggregationGNN_DB(nn.Module): 530 | """ 531 | AggregationGNN_DB: implements the aggregation GNN architecture with delayed 532 | time structure and batch GSOs 533 | 534 | Initialization: 535 | 536 | Input: 537 | /** Regular convolutional layers **/ 538 | dimFeatures (list of int): number of features on each layer 539 | nFilterTaps (list of int): number of filter taps on each layer 540 | bias (bool): include bias after graph filter on every layer 541 | >> Obs.: dimFeatures[0] is the number of features (the dimension 542 | of the node signals) of the data, where dimFeatures[l] is the 543 | dimension obtained at the output of layer l, l=1,...,L. 544 | Therefore, for L layers, len(dimFeatures) = L+1. Slightly 545 | different, nFilterTaps[l] is the number of filter taps for the 546 | filters implemented at layer l+1, thus len(nFilterTaps) = L. 547 | 548 | /** Activation function **/ 549 | nonlinearity (torch.nn): module from torch.nn non-linear activations 550 | 551 | /** Pooling **/ 552 | poolingFunction (torch.nn): module from torch.nn pooling layers 553 | poolingSize (list of int): size of the neighborhood to compute the 554 | summary from at each layer 555 | 556 | /** Readout layer **/ 557 | dimReadout (list of int): number of output hidden units of a 558 | sequence of fully connected layers after the filters have 559 | been applied 560 | 561 | /** Graph structure **/ 562 | dimEdgeFeatures (int): number of edge features 563 | nExchanges (int): maximum number of neighborhood exchanges 564 | 565 | Output: 566 | nn.Module with an Aggregation GNN architecture with the above 567 | specified characteristics. 568 | 569 | Forward call: 570 | 571 | Input: 572 | x (torch.tensor): input data of shape 573 | batchSize x timeSamples x dimFeatures x numberNodes 574 | GSO (torch.tensor): graph shift operator of shape 575 | batchSize x timeSamples (x dimEdgeFeatures) 576 | x numberNodes x numberNodes 577 | 578 | Output: 579 | y (torch.tensor): output data after being processed by the aggregation 580 | GNN; shape: batchSize x timeSamples x dimReadout[-1] x nNodes 581 | """ 582 | def __init__(self, 583 | # Graph filtering 584 | dimFeatures, nFilterTaps, bias, 585 | # Nonlinearity 586 | nonlinearity, 587 | # Pooling 588 | poolingFunction, poolingSize, 589 | # MLP in the end 590 | dimReadout, 591 | # Structure 592 | dimEdgeFeatures, nExchanges): 593 | super().__init__() 594 | # dimFeatures should be a list and of size 1 more than nFilterTaps. 595 | assert len(dimFeatures) == len(nFilterTaps) + 1 596 | # poolingSize also has to be a list of the same size 597 | assert len(poolingSize) == len(nFilterTaps) 598 | # Check whether the GSO has features or not. After that, always handle 599 | # it as a matrix of dimension E x N x N.
600 | 601 | # Store the values (using the notation in the paper): 602 | self.L = len(nFilterTaps) # Number of convolutional layers 603 | self.F = dimFeatures # Features 604 | self.K = nFilterTaps # Filter taps 605 | self.E = dimEdgeFeatures # Dimension of edge features 606 | self.bias = bias # Boolean 607 | self.sigma = nonlinearity 608 | self.rho = poolingFunction 609 | self.alpha = poolingSize # This acts as both the kernel_size and the 610 | # stride, so there is no overlap on the elements over which we take 611 | # the maximum (this is how it works by default) 612 | self.dimReadout = dimReadout 613 | self.nExchanges = nExchanges # Number of exchanges 614 | # Let's also record the number of nodes on each layer (L+1, actually) 615 | self.N = [self.nExchanges+1] # If we have one exchange, then we have 616 | # two entries in the collected vector (the zeroth exchange and the 617 | # first exchange) 618 | for l in range(self.L): 619 | # In pyTorch, the convolution is a valid correlation, instead of a 620 | # full one, which means that the output is smaller than the input. 621 | # Precisely, it is K-1 samples smaller (check the docs for nn.Conv1d) 622 | outConvN = self.N[l] - (self.K[l] - 1) # Size of the conv output 623 | # The next equation to compute the number of nodes is obtained from 624 | # the maxPool1d help in the pytorch documentation 625 | self.N += [int( 626 | (outConvN - (self.alpha[l]-1) - 1)/self.alpha[l] + 1 627 | )] 628 | # int() on a float always applies floor() 629 | 630 | # And now, we're finally ready to create the architecture: 631 | #\\\ Graph filtering layers \\\ 632 | # OBS.: We could join this for loop with the one before, but we keep them 633 | # separate for clarity of code. 634 | convl = [] # Convolutional Layers 635 | for l in range(self.L): 636 | #\\ Graph filtering stage: 637 | convl.append(nn.Conv1d(self.F[l]*self.E, 638 | self.F[l+1]*self.E, 639 | self.K[l], 640 | bias = self.bias)) 641 | #\\ Nonlinearity 642 | convl.append(self.sigma()) 643 | #\\ Pooling 644 | convl.append(self.rho(self.alpha[l])) 645 | # And now feed them into the sequential 646 | self.ConvLayers = nn.Sequential(*convl) # Convolutional layers 647 | #\\\ MLP (Fully Connected Layers) \\\ 648 | fc = [] 649 | if len(self.dimReadout) > 0: # Maybe we don't want to MLP anything 650 | # The first layer has to connect whatever was left of the graph 651 | # signal, flattened. 652 | dimInputReadout = self.N[-1] * self.F[-1] * self.E 653 | # (i.e., we have N[-1] nodes left, each one described by F[-1] 654 | # features which means this will be flattened into a vector of size 655 | # N[-1]*F[-1]) 656 | fc.append(nn.Linear(dimInputReadout,dimReadout[0],bias=self.bias)) 657 | # The last linear layer cannot be followed by nonlinearity, because 658 | # usually, this nonlinearity depends on the loss function (for 659 | # instance, if we have a classification problem, this nonlinearity 660 | # is already handled by the cross entropy loss or we add a softmax.) 661 | for l in range(len(dimReadout)-1): 662 | # Add the nonlinearity because there's another linear layer 663 | # coming 664 | fc.append(self.sigma()) 665 | # And add the linear layer 666 | fc.append(nn.Linear(dimReadout[l], dimReadout[l+1], 667 | bias = self.bias)) 668 | # And we're done within each node 669 | self.Readout = nn.Sequential(*fc) 670 |
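    # Worked example of the size bookkeeping done above (assumed values):
    # with nExchanges = 5, nFilterTaps K = [3, 2] and poolingSize alpha = [2, 1]:
    #     N[0] = 6                                (zeroth plus five exchanges)
    #     layer 1: conv 6 - (3-1) = 4;  pool floor((4-1-1)/2 + 1) = 2 -> N[1] = 2
    #     layer 2: conv 2 - (2-1) = 1;  pool floor((1-0-1)/1 + 1) = 1 -> N[2] = 1
    # The readout then acts on the N[-1]*F[-1]*E flattened features of each
    # (batch, time, node) element.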
671 | def forward(self, x, S): 672 | 673 | # Check the dimensions of the input first 674 | # S: B x T (x E) x N x N 675 | # x: B x T x F[0] x N 676 | assert len(S.shape) == 4 or len(S.shape) == 5 677 | if len(S.shape) == 4: 678 | # Then S is B x T x N x N 679 | S = S.unsqueeze(2) # And we want it B x T x 1 x N x N 680 | B = S.shape[0] 681 | T = S.shape[1] 682 | assert S.shape[2] == self.E 683 | N = S.shape[3] 684 | assert S.shape[4] == N 685 | # Check the dimensions of x 686 | assert len(x.shape) == 4 687 | assert x.shape[0] == B 688 | assert x.shape[1] == T 689 | assert x.shape[2] == self.F[0] 690 | assert x.shape[3] == N 691 | 692 | # Now we need to do the exchange to build the aggregation vector at 693 | # every node 694 | # z has to be of shape: B x T x (nExchanges+1) x E x F[0] x N 695 | # to be fed into conv1d it has to be (B*T*N) x F[0] x (nExchanges+1) 696 | 697 | # This vector is built by multiplying x with S, so we need to adapt x 698 | # to have a dimension that can be multiplied by S (we need to add the 699 | # E dimension) 700 | x = x.reshape([B, T, 1, self.F[0], N]).repeat(1, 1, self.E, 1, 1) 701 | 702 | # The first element of z is, precisely, this element (no exchanges) 703 | z = x.reshape([B, T, 1, self.E, self.F[0], N]) # The new dimension is 704 | # the one that accumulates the nExchanges 705 | 706 | # Now we start with the exchanges (multiplying by S) 707 | for k in range(1, self.nExchanges+1): 708 | # Across dim = 1 (time) we need to "displace the dimension down", 709 | # i.e. where it used to be t = 1 we now need it to be t=0 and so 710 | # on. For t=0 we add a "row" of zeros. 711 | x, _ = torch.split(x, [T-1, 1], dim = 1) 712 | # The second part is the most recent time instant which we do 713 | # not need anymore (it's used only once for the first value of K) 714 | # Now, we need to add a "row" of zeros at the beginning (for t = 0) 715 | zeroRow = torch.zeros(B, 1, self.E, self.F[0], N, 716 | dtype=x.dtype,device=x.device) 717 | x = torch.cat((zeroRow, x), dim = 1) 718 | # And now we multiply with S 719 | x = torch.matmul(x, S) 720 | # Add the dimension along K 721 | xS = x.reshape(B, T, 1, self.E, self.F[0], N) 722 | # And concatenate it with z 723 | z = torch.cat((z, xS), dim = 2) 724 | 725 | # Now, we have finally built the vector of delayed aggregations. This 726 | # vector has shape B x T x (nExchanges+1) x E x F[0] x N 727 | # To get rid of the edge features (dim E) we just sum through that 728 | # dimension 729 | z = torch.sum(z, dim = 3) # B x T x (nExchanges+1) x F[0] x N 730 | # It is, essentially, a matrix of N x (nExchanges+1) for each feature, 731 | # for each time instant, for each batch.
732 | # NOTE1: This is inconsequential if self.E = 1 (most of the cases) 733 | # NOTE2: Alternatively, not to lose information, we could concatenate 734 | # dim E after dim F[0] to get E*F[0] features; this increases the 735 | # dimensionality of the data (which could be fine) but it would need 736 | # to be adapted so that the first conv1d takes self.E*self.F[0] input 737 | # features instead of just self.F[0] 738 | 739 | # The operation conv1d takes tensors of shape 740 | # batchSize x nFeatures x nEntries 741 | # This means that the convolution takes place along nEntries with 742 | # a summation along nFeatures, for each of the elements along 743 | # batchSize. So we need to put (nExchanges+1) last since it is along 744 | # those elements that we want the convolution to be performed, and 745 | # we need to put F[0] as nFeatures since there is where we want the 746 | # features to be combined. The other three dimensions are different 747 | # elements (agents, time, batch) to which the convolution needs to be 748 | # applied. 749 | # Therefore, we want a vector z of shape 750 | # (B*T*N) x F[0] x (nExchanges+1) 751 | 752 | # Let's get started with this reorganization 753 | # First, we join B*T*N. Because we always join the last dimensions, 754 | # we need to permute first to put B, T, N as the last dimensions. 755 | # z: B x T x (nExchanges+1) x F[0] x N 756 | z = z.permute(3, 2, 0, 1, 4) # F[0] x (nExchanges+1) x B x T x N 757 | z = z.reshape([self.F[0], self.nExchanges+1, B*T*N]) 758 | # F[0] x (nExchanges+1) x B*T*N 759 | # Second, we put it back at the beginning 760 | z = z.permute(2, 0, 1) # B*T*N x F[0] x (nExchanges+1) 761 | 762 | # Let's call the convolutional layers 763 | y = self.ConvLayers(z) 764 | # B*T*N x F[-1] x N[-1] 765 | # Flatten the output 766 | y = y.reshape([B*T*N, self.F[-1] * self.N[-1]]) 767 | # And, feed it into the per node readout layers 768 | y = self.Readout(y) # (B*T*N) x dimReadout[-1] 769 | # And now we have to unpack it back for every node, i.e. to get it 770 | # back to shape B x T x N x dimReadout[-1] 771 | y = y.permute(1, 0) # dimReadout[-1] x (B*T*N) 772 | y = y.reshape(self.dimReadout[-1], B, T, N) 773 | # And finally put it back to the usual B x T x F x N 774 | y = y.permute(1, 2, 0, 3) 775 | return y 776 | 777 | def to(self, device): 778 | # Only the filter taps and the weights are registered as parameters, 779 | # so the parent's .to() method moves all of them. In this architecture 780 | # the GSO is passed anew with every forward call, so there are no 781 | # stored GSOs left to move by hand. 782 | super().to(device) 783 | -------------------------------------------------------------------------------- /alegnn/modules/evaluation.py: -------------------------------------------------------------------------------- 1 | # 2020/02/25~ 2 | # Fernando Gama, fgama@seas.upenn.edu 3 | # Luana Ruiz, rubruiz@seas.upenn.edu 4 | """ 5 | evaluation.py Evaluation Module 6 | 7 | Methods for evaluating the models. 8 | 9 | evaluate: evaluate a model 10 | evaluateSingleNode: evaluate a model that has a single node forward 11 | evaluateFlocking: evaluate a model using the flocking cost 12 | """ 13 | 14 | import os 15 | import torch 16 | import pickle 17 | 18 | def evaluate(model, data, **kwargs): 19 | """ 20 | evaluate: evaluate a model using classification error 21 | 22 | Input: 23 | model (model class): class from Modules.model 24 | data (data class): a data class from the Utils.dataTools; it needs to 25 | have a getSamples method and an evaluate method.
26 | doPrint (optional, bool): if True prints results 27 | 28 | Output: 29 | evalVars (dict): 'costBest' contains the cost for the best 30 | model, and 'costLast' contains the cost for the last model 31 | """ 32 | 33 | # Get the device we're working on 34 | device = model.device 35 | 36 | if 'doSaveVars' in kwargs.keys(): 37 | doSaveVars = kwargs['doSaveVars'] 38 | else: 39 | doSaveVars = True 40 | 41 | ######## 42 | # DATA # 43 | ######## 44 | 45 | xTest, yTest = data.getSamples('test') 46 | xTest = xTest.to(device) 47 | yTest = yTest.to(device) 48 | 49 | ############## 50 | # BEST MODEL # 51 | ############## 52 | 53 | model.load(label = 'Best') 54 | 55 | with torch.no_grad(): 56 | # Process the samples 57 | yHatTest = model.archit(xTest) 58 | # yHatTest is of shape 59 | # testSize x numberOfClasses 60 | # We compute the error 61 | costBest = data.evaluate(yHatTest, yTest) 62 | 63 | ############## 64 | # LAST MODEL # 65 | ############## 66 | 67 | model.load(label = 'Last') 68 | 69 | with torch.no_grad(): 70 | # Process the samples 71 | yHatTest = model.archit(xTest) 72 | # yHatTest is of shape 73 | # testSize x numberOfClasses 74 | # We compute the error 75 | costLast = data.evaluate(yHatTest, yTest) 76 | 77 | evalVars = {} 78 | evalVars['costBest'] = costBest.item() 79 | evalVars['costLast'] = costLast.item() 80 | 81 | if doSaveVars: 82 | saveDirVars = os.path.join(model.saveDir, 'evalVars') 83 | if not os.path.exists(saveDirVars): 84 | os.makedirs(saveDirVars) 85 | pathToFile = os.path.join(saveDirVars, model.name + 'evalVars.pkl') 86 | with open(pathToFile, 'wb') as evalVarsFile: 87 | pickle.dump(evalVars, evalVarsFile) 88 | 89 | return evalVars 90 |
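# Usage sketch (assumed setup): given a Modules.model.Model instance `model`
# and a dataTools data object `data` exposing getSamples and evaluate,
#     evalVars = evaluate(model, data, doSaveVars = False)
#     print(evalVars['costBest'], evalVars['costLast'])
# evaluateSingleNode below follows the same pattern, but queries the output
# at a single (label-dependent) node through singleNodeForward.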
91 | def evaluateSingleNode(model, data, **kwargs): 92 | """ 93 | evaluateSingleNode: evaluate a model that has a single node forward 94 | 95 | Input: 96 | model (model class): class from Modules.model, needs to have a 97 | 'singleNodeForward' method 98 | data (data class): a data class from the Utils.dataTools; it needs to 99 | have a getSamples method and an evaluate method and it also needs to 100 | have a 'getLabelID' method 101 | doPrint (optional, bool): if True prints results 102 | 103 | Output: 104 | evalVars (dict): 'costBest' contains the cost for the best 105 | model, and 'costLast' contains the cost for the last model 106 | """ 107 | 108 | assert 'singleNodeForward' in dir(model.archit) 109 | assert 'getLabelID' in dir(data) 110 | 111 | # Get the device we're working on 112 | device = model.device 113 | 114 | if 'doSaveVars' in kwargs.keys(): 115 | doSaveVars = kwargs['doSaveVars'] 116 | else: 117 | doSaveVars = True 118 | 119 | ######## 120 | # DATA # 121 | ######## 122 | 123 | xTest, yTest = data.getSamples('test') 124 | xTest = xTest.to(device) 125 | yTest = yTest.to(device) 126 | targetIDs = data.getLabelID('test') 127 | 128 | ############## 129 | # BEST MODEL # 130 | ############## 131 | 132 | model.load(label = 'Best') 133 | 134 | with torch.no_grad(): 135 | # Process the samples 136 | yHatTest = model.archit.singleNodeForward(xTest, targetIDs) 137 | # yHatTest is of shape 138 | # testSize x numberOfClasses 139 | # We compute the error 140 | costBest = data.evaluate(yHatTest, yTest) 141 | 142 | ############## 143 | # LAST MODEL # 144 | ############## 145 | 146 | model.load(label = 'Last') 147 | 148 | with torch.no_grad(): 149 | # Process the samples 150 | yHatTest = model.archit.singleNodeForward(xTest, targetIDs) 151 | # yHatTest is of shape 152 | # testSize x numberOfClasses 153 | # We compute the error 154 | costLast = data.evaluate(yHatTest, yTest) 155 | 156 | evalVars = {} 157 | evalVars['costBest'] = costBest.item() 158 | evalVars['costLast'] = costLast.item() 159 | 160 | if doSaveVars: 161 | saveDirVars = os.path.join(model.saveDir, 'evalVars') 162 | if not os.path.exists(saveDirVars): 163 | os.makedirs(saveDirVars) 164 | pathToFile = os.path.join(saveDirVars, model.name + 'evalVars.pkl') 165 | with open(pathToFile, 'wb') as evalVarsFile: 166 | pickle.dump(evalVars, evalVarsFile) 167 | 168 | return evalVars 169 | 170 | def evaluateFlocking(model, data, **kwargs): 171 | """ 172 | evaluateFlocking: evaluate a model using the flocking cost (the velocity 173 | variance of the team) 174 | 175 | Input: 176 | model (model class): class from Modules.model 177 | data (data class): the data class that generates the flocking data 178 | doPrint (optional; bool, default: True): if True prints results 179 | nVideos (optional; int, default: 3): number of videos to save 180 | graphNo (optional): identify the run with a number 181 | realizationNo (optional): identify the run with another number 182 | 183 | Output: 184 | evalVars (dict): 185 | 'costBestFull': cost of the best model over the full trajectory 186 | 'costBestEnd': cost of the best model at the end of the trajectory 187 | 'costLastFull': cost of the last model over the full trajectory 188 | 'costLastEnd': cost of the last model at the end of the trajectory 189 | """ 190 | 191 | if 'doPrint' in kwargs.keys(): 192 | doPrint = kwargs['doPrint'] 193 | else: 194 | doPrint = True 195 | 196 | if 'nVideos' in kwargs.keys(): 197 | nVideos = kwargs['nVideos'] 198 | else: 199 | nVideos = 3 200 | 201 | if 'graphNo' in kwargs.keys(): 202 | graphNo = kwargs['graphNo'] 203 | else: 204 | graphNo = -1 205 | 206 | if 'realizationNo' in kwargs.keys(): 207 | if 'graphNo' in kwargs.keys(): 208 | realizationNo = kwargs['realizationNo'] 209 | else: 210 | graphNo = kwargs['realizationNo'] 211 | realizationNo = -1 212 | else: 213 | realizationNo = -1 214 | 215 | #\\\\\\\\\\\\\\\\\\\\ 216 | #\\\ TRAJECTORIES \\\ 217 | #\\\\\\\\\\\\\\\\\\\\ 218 | 219 | ######## 220 | # DATA # 221 | ######## 222 | 223 | # Initial data 224 | initPosTest = data.getData('initPos', 'test') 225 | initVelTest = data.getData('initVel', 'test') 226 | 227 | ############## 228 | # BEST MODEL # 229 | ############## 230 | 231 | model.load(label = 'Best') 232 | 233 | if doPrint: 234 | print("\tComputing learned trajectory for best model...", 235 | end = ' ', flush = True) 236 | 237 | posTestBest, \ 238 | velTestBest, \ 239 | accelTestBest, \ 240 | stateTestBest, \ 241 | commGraphTestBest = \ 242 | data.computeTrajectory(initPosTest, initVelTest, data.duration, 243 | archit = model.archit) 244 | 245 | if doPrint: 246 | print("OK") 247 | 248 | ############## 249 | # LAST MODEL # 250 | ############## 251 | 252 | model.load(label = 'Last') 253 | 254 | if doPrint: 255 | print("\tComputing learned trajectory for last model...", 256 | end = ' ', flush = True) 257 | 258 | posTestLast, \ 259 | velTestLast, \ 260 | accelTestLast, \ 261 | stateTestLast, \ 262 | commGraphTestLast = \ 263 | data.computeTrajectory(initPosTest, initVelTest, data.duration, 264 | archit = model.archit) 265 | 266 | if doPrint: 267 | print("OK") 268 | 269 | ########### 270 | # PREVIEW # 271 | ########### 272 | 273 | learnedTrajectoriesDir = os.path.join(model.saveDir, 274 | 'learnedTrajectories') 275 | 276 | if not os.path.exists(learnedTrajectoriesDir): 277 | os.mkdir(learnedTrajectoriesDir) 278 | 279 | if graphNo > -1:
280 | learnedTrajectoriesDir = os.path.join(learnedTrajectoriesDir, 281 | '%03d' % graphNo) 282 | if not os.path.exists(learnedTrajectoriesDir): 283 | os.mkdir(learnedTrajectoriesDir) 284 | if realizationNo > -1: 285 | learnedTrajectoriesDir = os.path.join(learnedTrajectoriesDir, 286 | '%03d' % realizationNo) 287 | if not os.path.exists(learnedTrajectoriesDir): 288 | os.mkdir(learnedTrajectoriesDir) 289 | 290 | learnedTrajectoriesDir = os.path.join(learnedTrajectoriesDir, model.name) 291 | 292 | if not os.path.exists(learnedTrajectoriesDir): 293 | os.mkdir(learnedTrajectoriesDir) 294 | 295 | if doPrint: 296 | print("\tPreview data...", 297 | end = ' ', flush = True) 298 | 299 | data.saveVideo(os.path.join(learnedTrajectoriesDir,'Best'), 300 | posTestBest, 301 | nVideos, 302 | commGraph = commGraphTestBest, 303 | vel = velTestBest, 304 | videoSpeed = 0.5, 305 | doPrint = False) 306 | 307 | data.saveVideo(os.path.join(learnedTrajectoriesDir,'Last'), 308 | posTestLast, 309 | nVideos, 310 | commGraph = commGraphTestLast, 311 | vel = velTestLast, 312 | videoSpeed = 0.5, 313 | doPrint = False) 314 | 315 | if doPrint: 316 | print("OK", flush = True) 317 | 318 | #\\\\\\\\\\\\\\\\\\ 319 | #\\\ EVALUATION \\\ 320 | #\\\\\\\\\\\\\\\\\\ 321 | 322 | evalVars = {} 323 | evalVars['costBestFull'] = data.evaluate(vel = velTestBest) 324 | evalVars['costBestEnd'] = data.evaluate(vel = velTestBest[:,-1:,:,:]) 325 | evalVars['costLastFull'] = data.evaluate(vel = velTestLast) 326 | evalVars['costLastEnd'] = data.evaluate(vel = velTestLast[:,-1:,:,:]) 327 | 328 | return evalVars -------------------------------------------------------------------------------- /alegnn/modules/loss.py: -------------------------------------------------------------------------------- 1 | # 2021/03/04~ 2 | # Fernando Gama, fgama@seas.upenn.edu 3 | # Luana Ruiz, rubruiz@seas.upenn.edu 4 | """ 5 | loss.py Loss functions 6 | 7 | adaptExtraDimensionLoss: wrapper that handles extra dimensions 8 | F1Score: loss function corresponding to 1 - F1 score 9 | """ 10 | 11 | import torch 12 | import torch.nn as nn 13 | 14 | # An arbitrary loss function handling penalties needs to satisfy the following 15 | # conditions 16 | # .penaltyList attribute listing the names of the penalties 17 | # .nPenalties attribute is an int with the number of penalties 18 | # Forward function has to output the actual loss, the main loss (with no 19 | # penalties), and a dictionary with the value of each of the penalties. 20 | # This will be standard procedure for all loss functions that have penalties. 21 | # Note: The existence of a penalty will be signaled by an attribute in the model 22 | 23 | class adaptExtraDimensionLoss(nn.modules.loss._Loss): 24 | """ 25 | adaptExtraDimensionLoss: wrapper that handles extra dimensions 26 | 27 | Some loss functions take vectors as inputs while others take scalars; if we 28 | input a one-dimensional vector instead of a scalar, although virtually the 29 | same, the loss function could complain. 30 | 31 | The output of the GNNs is, by default, a vector. And sometimes we want it 32 | to still be a vector (i.e. crossEntropyLoss where we output a one-hot 33 | vector) and sometimes we want it to be treated as a scalar (i.e. MSELoss). 34 | Since we still have a single training function to train multiple models, we 35 | do not know whether we will have a scalar or a vector. So this wrapper 36 | adapts the input to the loss function seamlessly.
37 | 38 | Eventually, more loss functions could be added to the code below to better 39 | handle their dimensions. 40 | 41 | Initialization: 42 | 43 | Input: 44 | lossFunction (torch.nn loss function): desired loss function 45 | arguments: arguments required to initialize the loss function 46 | >> Obs.: The loss function gets initialized as well 47 | 48 | Forward: 49 | Input: 50 | estimate (torch.tensor): output of the GNN 51 | target (torch.tensor): target representation 52 | """ 53 | 54 | # When we want to compare scalars, we will have a B x 1 output of the GNN, 55 | # since the number of features is always there. However, most of the scalar 56 | # comparative functions take just a B vector, so we have an extra 1 dim 57 | # that raises a warning. This container will simply get rid of it. 58 | 59 | # This allows changing the loss from crossEntropy (class based, expecting 60 | # B x C input) to MSE or SmoothL1Loss (expecting B input) 61 | 62 | def __init__(self, lossFunction, *args): 63 | # The second argument is optional; it is used if there are any extra 64 | # arguments with which we want to initialize the loss 65 | 66 | super().__init__() 67 | 68 | if len(args) > 0: 69 | self.loss = lossFunction(*args) # Initialize loss function 70 | else: 71 | self.loss = lossFunction() 72 | 73 | def forward(self, estimate, target): 74 | 75 | # What we're doing here is checking what kind of loss it is and 76 | # what kind of reshape we have to do on the estimate 77 | 78 | if 'CrossEntropyLoss' in repr(self.loss): 79 | # This is supposed to be a one-hot vector batchSize x nClasses 80 | assert len(estimate.shape) == 2 81 | elif 'SmoothL1Loss' in repr(self.loss) \ 82 | or 'MSELoss' in repr(self.loss) \ 83 | or 'L1Loss' in repr(self.loss): 84 | # In this case, the estimate has to be a batchSize tensor, so if 85 | # it has two dimensions, the second dimension has to be 1 86 | if len(estimate.shape) == 2: 87 | assert estimate.shape[1] == 1 88 | estimate = estimate.squeeze(1) 89 | assert len(estimate.shape) == 1 90 | 91 | return self.loss(estimate, target) 92 | 93 | def F1Score(yHat, y): 94 | # Luana R. Ruiz, rubruiz@seas.upenn.edu, 2021/03/04 95 | dimensions = len(yHat.shape) 96 | C = yHat.shape[dimensions-2] 97 | N = yHat.shape[dimensions-1] 98 | yHat = yHat.reshape((-1,C,N)) 99 | yHat = torch.nn.functional.log_softmax(yHat, dim=1) 100 | yHat = torch.exp(yHat) 101 | yHat = yHat[:,1,:] 102 | y = y.reshape((-1,N)) 103 | 104 | tp = torch.sum(y*yHat,1) 105 | #tn = torch.sum((1-y)*(1-yHat),1) 106 | fp = torch.sum((1-y)*yHat,1) 107 | fn = torch.sum(y*(1-yHat),1) 108 | 109 | p = tp / (tp + fp) 110 | r = tp / (tp + fn) 111 | 112 | idx_p = p!=p # NaN entries of precision (0/0 divisions) 113 | idx_tp = tp==0 # samples with no true positives 114 | idx_p1 = idx_p*idx_tp 115 | p[idx_p] = 0 116 | p[idx_p1] = 1 117 | idx_r = r!=r # NaN entries of recall, handled the same way 118 | idx_r1 = idx_r*idx_tp 119 | r[idx_r] = 0 120 | r[idx_r1] = 1 121 | 122 | f1 = 2*p*r / (p+r) 123 | f1[f1!=f1] = 0 # F1 is NaN where p + r = 0 124 | 125 | return 1 - torch.mean(f1) -------------------------------------------------------------------------------- /alegnn/modules/model.py: -------------------------------------------------------------------------------- 1 | # 2018/10/02~ 2 | # Fernando Gama, fgama@seas.upenn.edu 3 | # Luana Ruiz, rubruiz@seas.upenn.edu 4 | """ 5 | model.py Model Module 6 | 7 | Utilities useful for working on the model 8 | 9 | Model: binds together the architecture, the loss function, the optimizer, 10 | the trainer, and the evaluator. 
11 | """ 12 | 13 | import os 14 | import torch 15 | 16 | class Model: 17 | """ 18 | Model: binds together the architecture, the loss function, the optimizer, 19 | the trainer, and the evaluator. 20 | 21 | Initialization: 22 | 23 | architecture (nn.Module) 24 | loss (nn.modules.loss._Loss) 25 | optimizer (nn.optim) 26 | trainer (Modules.training) 27 | evaluator (Modules.evaluation) 28 | device (string or device) 29 | name (string) 30 | saveDir (string or path) 31 | 32 | .train(data, nEpochs, batchSize, **kwargs): train the model for nEpochs 33 | epochs, using batches of size batchSize and running over data data 34 | class; see the specific selected trainer for extra options 35 | 36 | .evaluate(data): evaluate the model over data data class; see the specific 37 | selected evaluator for extra options 38 | 39 | .save(label = '', [saveDir=dirPath]): save the model parameters under the 40 | name given by label; if the saveDir is different from the one specified 41 | in the initialization, it needs to be specified now 42 | 43 | .load(label = '', [loadFiles=(architLoadFile, optimLoadFile)]): loads the 44 | model parameters under the specified name inside the specific saveDir, 45 | unless they are provided externally through the keyword 'loadFiles'. 46 | 47 | .getTrainingOptions(): get a dict with the options used during training; it 48 | returns None if it hasn't been trained yet. 49 | """ 50 | 51 | def __init__(self, 52 | # Architecture (nn.Module) 53 | architecture, 54 | # Loss Function (nn.modules.loss._Loss) 55 | loss, 56 | # Optimization Algorithm (nn.optim) 57 | optimizer, 58 | # Training Algorithm (Modules.training) 59 | trainer, 60 | # Evaluating Algorithm (Modules.evaluation) 61 | evaluator, 62 | # Other 63 | device, name, saveDir): 64 | 65 | #\\\ ARCHITECTURE 66 | # Store 67 | self.archit = architecture 68 | # Move it to device 69 | self.archit.to(device) 70 | # Count parameters (doesn't work for EdgeVarying) 71 | self.nParameters = 0 72 | for param in list(self.archit.parameters()): 73 | if len(param.shape)>0: 74 | thisNParam = 1 75 | for p in range(len(param.shape)): 76 | thisNParam *= param.shape[p] 77 | self.nParameters += thisNParam 78 | else: 79 | pass 80 | #\\\ LOSS FUNCTION 81 | self.loss = loss 82 | #\\\ OPTIMIZATION ALGORITHM 83 | self.optim = optimizer 84 | #\\\ TRAINING ALGORITHM 85 | self.trainer = trainer 86 | #\\\ EVALUATING ALGORITHM 87 | self.evaluator = evaluator 88 | #\\\ OTHER 89 | # Device 90 | self.device = device 91 | # Model name 92 | self.name = name 93 | # Saving directory 94 | self.saveDir = saveDir 95 | 96 | def train(self, data, nEpochs, batchSize, **kwargs): 97 | 98 | self.trainer = self.trainer(self, data, nEpochs, batchSize, **kwargs) 99 | 100 | return self.trainer.train() 101 | 102 | def evaluate(self, data, **kwargs): 103 | 104 | return self.evaluator(self, data, **kwargs) 105 | 106 | def save(self, label = '', **kwargs): 107 | if 'saveDir' in kwargs.keys(): 108 | saveDir = kwargs['saveDir'] 109 | else: 110 | saveDir = self.saveDir 111 | saveModelDir = os.path.join(saveDir,'savedModels') 112 | # Create directory savedModels if it doesn't exist yet: 113 | if not os.path.exists(saveModelDir): 114 | os.makedirs(saveModelDir) 115 | saveFile = os.path.join(saveModelDir, self.name) 116 | torch.save(self.archit.state_dict(), saveFile+'Archit'+ label+'.ckpt') 117 | torch.save(self.optim.state_dict(), saveFile+'Optim'+label+'.ckpt') 118 | 119 | def load(self, label = '', **kwargs): 120 | if 'loadFiles' in kwargs.keys(): 121 | (architLoadFile, optimLoadFile) = 
kwargs['loadFiles'] 122 | else: 123 | saveModelDir = os.path.join(self.saveDir,'savedModels') 124 | architLoadFile = os.path.join(saveModelDir, 125 | self.name + 'Archit' + label +'.ckpt') 126 | optimLoadFile = os.path.join(saveModelDir, 127 | self.name + 'Optim' + label + '.ckpt') 128 | self.archit.load_state_dict(torch.load(architLoadFile)) 129 | self.optim.load_state_dict(torch.load(optimLoadFile)) 130 | 131 | def getTrainingOptions(self): 132 | 133 | return self.trainer.trainingOptions \ 134 | if 'trainingOptions' in dir(self.trainer) \ 135 | else None 136 | 137 | def __repr__(self): 138 | reprString = "Name: %s\n" % (self.name) 139 | reprString += "Number of learnable parameters: %d\n"%(self.nParameters) 140 | reprString += "\n" 141 | reprString += "Model architecture:\n" 142 | reprString += "----- -------------\n" 143 | reprString += "\n" 144 | reprString += repr(self.archit) + "\n" 145 | reprString += "\n" 146 | reprString += "Loss function:\n" 147 | reprString += "---- ---------\n" 148 | reprString += "\n" 149 | reprString += repr(self.loss) + "\n" 150 | reprString += "\n" 151 | reprString += "Optimizer:\n" 152 | reprString += "----------\n" 153 | reprString += "\n" 154 | reprString += repr(self.optim) + "\n" 155 | reprString += "Training algorithm:\n" 156 | reprString += "-------- ----------\n" 157 | reprString += "\n" 158 | reprString += repr(self.trainer) + "\n" 159 | reprString += "Evaluation algorithm:\n" 160 | reprString += "---------- ----------\n" 161 | reprString += "\n" 162 | reprString += repr(self.evaluator) + "\n" 163 | return reprString 164 | -------------------------------------------------------------------------------- /alegnn/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alelab-upenn/graph-neural-networks/a84a39fabad5378bdcbaad20b5dbcff14b4eebcd/alegnn/utils/__init__.py -------------------------------------------------------------------------------- /alegnn/utils/miscTools.py: -------------------------------------------------------------------------------- 1 | # 2018/10/15~ 2 | # Fernando Gama, fgama@seas.upenn.edu. 3 | # Luana Ruiz, rubruiz@seas.upenn.edu. 4 | """ 5 | miscTools Miscellaneous Tools module 6 | 7 | num2filename: change a numerical value into a string usable as a filename 8 | saveSeed: save the random state of generators 9 | loadSeed: load the random states of generators 10 | writeVarValues: write the specified values in the specified txt file 11 | """ 12 | 13 | import os 14 | import pickle 15 | import numpy as np 16 | import torch 17 | 18 | def num2filename(x,d): 19 | """ 20 | Takes a number and returns a string with the value of the number, but in a 21 | format that is writable into a filename. 22 | 23 | s = num2filename(x,d) Gets rid of decimal points which are usually 24 | inconvenient to have in a filename. 25 | If the number x is an integer, then s = str(int(x)). 26 | If the number x is a decimal number, then it replaces the '.' by the 27 | character specified by d. Setting d = '' erases the decimal point, 28 | setting d = '.' simply returns a string with the exact same number. 
29 | 30 | Example: 31 | >> num2filename(2,'d') 32 | >> '2' 33 | 34 | >> num2filename(3.1415,'d') 35 | >> '3d1415' 36 | 37 | >> num2filename(3.1415,'') 38 | >> '31415' 39 | 40 | >> num2filename(3.1415,'.') 41 | >> '3.1415' 42 | """ 43 | if x == int(x): 44 | return str(int(x)) 45 | else: 46 | return str(x).replace('.',d) 47 | 48 | def saveSeed(randomStates, saveDir): 49 | """ 50 | Takes a list of dictionaries of random generator states of different modules 51 | and saves them in a .pkl format. 52 | 53 | Inputs: 54 | randomStates (list): The length of this list is equal to the number of 55 | modules whose states want to be saved (torch, numpy, etc.). Each 56 | element in this list is a dictionary. The dictionary has three keys: 57 | 'module' with the name of the module in string format ('numpy' or 58 | 'torch', for example), 'state' with the saved generator state and, 59 | if applicable, 'seed' with the specific seed for the generator 60 | (note that torch has both state and seed, but numpy only has state) 61 | saveDir (path): where to save the seed, it will be saved under the 62 | filename 'randomSeedUsed.pkl' 63 | """ 64 | pathToSeed = os.path.join(saveDir, 'randomSeedUsed.pkl') 65 | with open(pathToSeed, 'wb') as seedFile: 66 | pickle.dump({'randomStates': randomStates}, seedFile) 67 | 68 | def loadSeed(loadDir): 69 | """ 70 | Loads the states and seed saved in a specified path 71 | 72 | Inputs: 73 | loadDir (path): where to look for the seed to load; it is expected that 74 | the appropriate file within loadDir is named 'randomSeedUsed.pkl' 75 | 76 | Obs.: The file 'randomSeedUsed.pkl' should contain a list structured as 77 | follows. The length of this list is equal to the number of modules whose 78 | states were saved (torch, numpy, etc.). Each element in this list is a 79 | dictionary. The dictionary has three keys: 'module' with the name of 80 | the module in string format ('numpy' or 'torch', for example), 'state' 81 | with the saved generator state and, if applicable, 'seed' with the 82 | specific seed for the generator (note that torch has both state and 83 | seed, but numpy only has state) 84 | """ 85 | pathToSeed = os.path.join(loadDir, 'randomSeedUsed.pkl') 86 | with open(pathToSeed, 'rb') as seedFile: 87 | randomStates = pickle.load(seedFile) 88 | randomStates = randomStates['randomStates'] 89 | for module in randomStates: 90 | thisModule = module['module'] 91 | if thisModule == 'numpy': 92 | np.random.set_state(module['state']) # restore the global numpy generator 93 | elif thisModule == 'torch': 94 | torch.set_rng_state(module['state']) 95 | torch.manual_seed(module['seed']) 96 | 97 | 98 | def writeVarValues(fileToWrite, varValues): 99 | """ 100 | Write the value of several string variables specified by a dictionary into 101 | the designated .txt file. 102 | 103 | Input: 104 | fileToWrite (os.path): text file to save the specified variables 105 | varValues (dictionary): values to save in the text file. They are 106 | saved in the format "key = value". 107 | """ 108 | with open(fileToWrite, 'a+') as file: 109 | for key in varValues.keys(): 110 | file.write('%s = %s\n' % (key, varValues[key])) 111 | file.write('\n') 112 | -------------------------------------------------------------------------------- /alegnn/utils/visualTools.py: -------------------------------------------------------------------------------- 1 | # 2019/01/21~2018/07/12 2 | # This function is taken almost verbatim from https://github.com/amaiasalvador 3 | # and all credit should go to Amaia Salvador. 
4 | 5 | import os 6 | import glob 7 | import torchvision.utils as vutils 8 | from operator import itemgetter 9 | from tensorboardX import SummaryWriter 10 | 11 | class Visualizer(): 12 | def __init__(self, checkpoints_dir, name): 13 | self.win_size = 256 14 | self.name = name 15 | self.saved = False 16 | self.checkpoints_dir = checkpoints_dir 17 | self.ncols = 4 18 | 19 | # remove existing 20 | for filename in glob.glob(self.checkpoints_dir+"/events*"): 21 | os.remove(filename) 22 | self.writer = SummaryWriter(checkpoints_dir) 23 | 24 | def reset(self): 25 | self.saved = False 26 | 27 | # images: (b, c, 0, 1) array of images 28 | def image_summary(self, mode, epoch, images): 29 | images = vutils.make_grid(images, normalize=True, scale_each=True) 30 | self.writer.add_image('{}/Image'.format(mode), images, epoch) 31 | 32 | # figure (for matplotlib figures) 33 | def figure_summary(self, mode, epoch, fig): 34 | self.writer.add_figure('{}/Figure'.format(mode), fig, epoch) 35 | 36 | # text: type: ingredients/recipe 37 | def text_summary(self, mode, epoch, type, text, vocabulary, gt=True, max_length=20): 38 | for i, el in enumerate(text): # text_list 39 | if not gt: # we are printing a sample 40 | idx = el.nonzero().squeeze() + 1 41 | else: 42 | idx = el # we are printing the ground truth 43 | 44 | words_list = itemgetter(*idx)(vocabulary) 45 | 46 | if len(words_list) <= max_length: 47 | self.writer.add_text('{}/{}_{}_{}'.format(mode, type, i, 'gt' if gt else 'prediction'), ', '.join(filter(lambda x: x != '', words_list)), epoch) 48 | else: 49 | self.writer.add_text('{}/{}_{}_{}'.format(mode, type, i, 'gt' if gt else 'prediction'), 'Number of sampled ingredients is too big: {}'.format(len(words_list)), epoch) 50 | 51 | # losses: dictionary of error labels and values 52 | def scalar_summary(self, mode, epoch, **args): 53 | for k, v in args.items(): 54 | self.writer.add_scalar('{}/{}'.format(mode, k), v, epoch) 55 | 56 | self.writer.export_scalars_to_json("{}/tensorboard_all_scalars.json".format(self.checkpoints_dir)) 57 | 58 | def histo_summary(self, model, step): 59 | """Log a histogram of the tensor of values.""" 60 | 61 | for name, param in model.named_parameters(): 62 | self.writer.add_histogram(name, param, step) 63 | 64 | def close(self): 65 | self.writer.close() 66 | -------------------------------------------------------------------------------- /datasets/authorshipData/authorshipData.part1.rar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alelab-upenn/graph-neural-networks/a84a39fabad5378bdcbaad20b5dbcff14b4eebcd/datasets/authorshipData/authorshipData.part1.rar -------------------------------------------------------------------------------- /datasets/authorshipData/authorshipData.part2.rar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alelab-upenn/graph-neural-networks/a84a39fabad5378bdcbaad20b5dbcff14b4eebcd/datasets/authorshipData/authorshipData.part2.rar -------------------------------------------------------------------------------- /datasets/authorshipData/authorshipData.part3.rar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alelab-upenn/graph-neural-networks/a84a39fabad5378bdcbaad20b5dbcff14b4eebcd/datasets/authorshipData/authorshipData.part3.rar -------------------------------------------------------------------------------- /datasets/epidemics/edge_list.txt: 
-------------------------------------------------------------------------------- 1 | 1 55 2 | 1 205 3 | 1 272 4 | 1 494 5 | 1 779 6 | 1 894 7 | 3 1 8 | 3 28 9 | 3 147 10 | 3 272 11 | 3 407 12 | 3 674 13 | 3 884 14 | 27 63 15 | 27 173 16 | 28 202 17 | 28 327 18 | 28 353 19 | 28 407 20 | 28 429 21 | 28 441 22 | 28 492 23 | 28 545 24 | 32 440 25 | 32 624 26 | 32 797 27 | 32 920 28 | 34 151 29 | 34 277 30 | 34 502 31 | 34 866 32 | 45 48 33 | 45 79 34 | 45 335 35 | 45 496 36 | 45 601 37 | 45 674 38 | 45 765 39 | 46 117 40 | 46 196 41 | 46 257 42 | 46 268 43 | 48 45 44 | 48 79 45 | 48 496 46 | 55 1 47 | 55 170 48 | 55 205 49 | 55 252 50 | 55 272 51 | 55 779 52 | 55 883 53 | 55 894 54 | 61 797 55 | 63 27 56 | 63 125 57 | 63 173 58 | 70 101 59 | 70 132 60 | 70 240 61 | 70 425 62 | 70 447 63 | 72 407 64 | 72 674 65 | 72 857 66 | 79 45 67 | 79 48 68 | 79 335 69 | 79 496 70 | 79 601 71 | 79 674 72 | 79 765 73 | 80 120 74 | 80 285 75 | 80 468 76 | 80 601 77 | 85 190 78 | 85 213 79 | 85 214 80 | 85 335 81 | 85 603 82 | 85 605 83 | 85 765 84 | 92 468 85 | 92 845 86 | 101 70 87 | 101 119 88 | 101 122 89 | 101 132 90 | 101 240 91 | 101 343 92 | 101 364 93 | 101 425 94 | 101 447 95 | 117 1 96 | 117 46 97 | 117 196 98 | 117 205 99 | 117 252 100 | 117 257 101 | 117 265 102 | 117 268 103 | 117 272 104 | 117 364 105 | 117 407 106 | 117 465 107 | 117 494 108 | 117 587 109 | 117 883 110 | 117 894 111 | 119 101 112 | 119 122 113 | 119 132 114 | 119 240 115 | 119 425 116 | 119 447 117 | 120 80 118 | 120 285 119 | 120 488 120 | 122 101 121 | 122 119 122 | 122 255 123 | 122 425 124 | 122 447 125 | 124 471 126 | 124 970 127 | 125 92 128 | 125 248 129 | 125 325 130 | 125 468 131 | 125 491 132 | 125 622 133 | 125 624 134 | 125 797 135 | 125 960 136 | 132 70 137 | 132 101 138 | 132 119 139 | 132 122 140 | 132 240 141 | 132 425 142 | 132 447 143 | 134 388 144 | 134 492 145 | 134 496 146 | 147 1 147 | 147 3 148 | 147 28 149 | 147 72 150 | 147 184 151 | 147 272 152 | 147 353 153 | 147 407 154 | 147 674 155 | 147 884 156 | 151 34 157 | 151 38 158 | 151 201 159 | 151 277 160 | 151 452 161 | 151 502 162 | 151 634 163 | 151 642 164 | 151 691 165 | 151 694 166 | 151 753 167 | 151 866 168 | 151 869 169 | 156 694 170 | 159 38 171 | 159 642 172 | 165 498 173 | 170 1 174 | 170 55 175 | 170 205 176 | 170 779 177 | 170 883 178 | 170 894 179 | 173 27 180 | 173 63 181 | 173 1332 182 | 184 327 183 | 184 429 184 | 184 441 185 | 190 85 186 | 190 213 187 | 190 214 188 | 190 272 189 | 196 46 190 | 196 117 191 | 196 252 192 | 196 268 193 | 196 364 194 | 200 480 195 | 200 845 196 | 201 34 197 | 201 38 198 | 201 245 199 | 201 502 200 | 201 642 201 | 201 691 202 | 201 753 203 | 201 869 204 | 202 28 205 | 202 170 206 | 202 407 207 | 202 545 208 | 202 883 209 | 205 1 210 | 205 55 211 | 205 117 212 | 205 170 213 | 205 265 214 | 205 272 215 | 205 494 216 | 205 587 217 | 205 691 218 | 205 779 219 | 205 883 220 | 205 894 221 | 211 242 222 | 211 468 223 | 211 845 224 | 213 190 225 | 213 214 226 | 214 190 227 | 214 213 228 | 214 603 229 | 219 605 230 | 222 248 231 | 222 343 232 | 222 867 233 | 232 492 234 | 232 798 235 | 240 70 236 | 240 101 237 | 240 119 238 | 240 132 239 | 240 327 240 | 240 425 241 | 240 447 242 | 242 211 243 | 242 468 244 | 242 845 245 | 245 325 246 | 245 440 247 | 245 634 248 | 245 691 249 | 245 753 250 | 245 869 251 | 245 959 252 | 248 222 253 | 248 564 254 | 248 694 255 | 252 55 256 | 252 117 257 | 252 196 258 | 252 265 259 | 252 272 260 | 252 364 261 | 252 779 262 | 255 275 263 | 257 46 264 | 257 268 265 | 257 364 266 | 265 117 
267 | 265 170 268 | 265 196 269 | 265 205 270 | 265 252 271 | 265 494 272 | 265 587 273 | 265 883 274 | 265 894 275 | 268 46 276 | 268 117 277 | 268 196 278 | 268 257 279 | 268 364 280 | 272 1 281 | 272 55 282 | 272 170 283 | 272 190 284 | 272 205 285 | 272 214 286 | 272 441 287 | 272 494 288 | 272 587 289 | 272 779 290 | 272 883 291 | 275 255 292 | 275 312 293 | 275 612 294 | 277 34 295 | 277 151 296 | 277 502 297 | 277 634 298 | 277 691 299 | 277 866 300 | 285 80 301 | 285 120 302 | 285 232 303 | 285 488 304 | 285 492 305 | 312 275 306 | 312 612 307 | 325 125 308 | 325 245 309 | 325 622 310 | 325 624 311 | 325 769 312 | 325 797 313 | 325 959 314 | 327 27 315 | 327 28 316 | 327 119 317 | 327 184 318 | 327 353 319 | 327 407 320 | 327 429 321 | 327 441 322 | 335 45 323 | 335 79 324 | 335 765 325 | 343 101 326 | 343 222 327 | 343 867 328 | 353 28 329 | 353 46 330 | 353 122 331 | 353 407 332 | 353 425 333 | 364 117 334 | 364 196 335 | 364 252 336 | 364 257 337 | 364 268 338 | 366 974 339 | 388 45 340 | 388 79 341 | 388 134 342 | 388 335 343 | 388 492 344 | 388 496 345 | 388 603 346 | 388 765 347 | 407 3 348 | 407 28 349 | 407 72 350 | 407 147 351 | 407 184 352 | 407 202 353 | 407 327 354 | 407 353 355 | 407 429 356 | 407 441 357 | 407 545 358 | 407 674 359 | 407 884 360 | 425 70 361 | 425 101 362 | 425 119 363 | 425 122 364 | 425 132 365 | 425 240 366 | 425 343 367 | 425 353 368 | 425 441 369 | 425 447 370 | 429 28 371 | 429 119 372 | 429 184 373 | 429 327 374 | 429 353 375 | 429 407 376 | 429 441 377 | 440 32 378 | 440 245 379 | 440 605 380 | 440 797 381 | 440 920 382 | 441 28 383 | 441 184 384 | 441 272 385 | 441 327 386 | 441 407 387 | 441 429 388 | 441 447 389 | 447 70 390 | 447 101 391 | 447 119 392 | 447 122 393 | 447 132 394 | 447 240 395 | 447 425 396 | 447 441 397 | 452 85 398 | 452 634 399 | 452 691 400 | 452 869 401 | 452 1332 402 | 465 486 403 | 465 531 404 | 465 857 405 | 468 80 406 | 468 92 407 | 468 125 408 | 468 211 409 | 468 242 410 | 468 601 411 | 468 845 412 | 471 124 413 | 471 970 414 | 480 200 415 | 480 771 416 | 486 465 417 | 486 531 418 | 488 48 419 | 488 120 420 | 488 285 421 | 491 219 422 | 491 520 423 | 491 576 424 | 491 605 425 | 492 28 426 | 492 120 427 | 492 134 428 | 492 232 429 | 492 285 430 | 492 388 431 | 492 447 432 | 492 488 433 | 492 496 434 | 494 1 435 | 494 117 436 | 494 205 437 | 494 265 438 | 494 272 439 | 494 587 440 | 494 883 441 | 494 894 442 | 496 45 443 | 496 48 444 | 496 79 445 | 496 85 446 | 496 134 447 | 496 388 448 | 496 492 449 | 496 603 450 | 498 165 451 | 498 857 452 | 502 151 453 | 502 277 454 | 502 691 455 | 502 866 456 | 502 869 457 | 520 219 458 | 520 491 459 | 520 576 460 | 520 605 461 | 531 465 462 | 531 486 463 | 531 691 464 | 545 28 465 | 545 202 466 | 545 407 467 | 564 577 468 | 564 694 469 | 576 219 470 | 576 491 471 | 576 520 472 | 576 605 473 | 577 564 474 | 577 694 475 | 587 117 476 | 587 265 477 | 587 272 478 | 587 494 479 | 587 883 480 | 601 45 481 | 601 79 482 | 601 80 483 | 601 134 484 | 601 200 485 | 601 388 486 | 601 603 487 | 603 85 488 | 603 214 489 | 603 335 490 | 603 388 491 | 603 496 492 | 603 765 493 | 605 85 494 | 605 201 495 | 605 219 496 | 605 520 497 | 605 576 498 | 612 275 499 | 612 312 500 | 622 125 501 | 622 452 502 | 622 624 503 | 622 769 504 | 622 797 505 | 622 798 506 | 622 959 507 | 622 960 508 | 624 125 509 | 624 245 510 | 624 325 511 | 624 452 512 | 624 491 513 | 624 622 514 | 624 769 515 | 624 797 516 | 624 798 517 | 624 960 518 | 634 691 519 | 634 869 520 | 634 1332 521 | 642 151 522 | 642 201 523 | 
642 452 524 | 642 634 525 | 642 691 526 | 642 866 527 | 642 869 528 | 674 3 529 | 674 45 530 | 674 72 531 | 674 79 532 | 674 147 533 | 674 407 534 | 691 205 535 | 691 245 536 | 691 452 537 | 691 502 538 | 691 531 539 | 691 634 540 | 691 642 541 | 691 869 542 | 691 883 543 | 691 1332 544 | 694 564 545 | 694 577 546 | 753 201 547 | 753 245 548 | 765 45 549 | 765 79 550 | 765 335 551 | 769 125 552 | 769 245 553 | 769 325 554 | 769 452 555 | 769 622 556 | 769 624 557 | 769 797 558 | 769 959 559 | 769 960 560 | 771 200 561 | 771 480 562 | 779 1 563 | 779 48 564 | 779 55 565 | 779 170 566 | 779 205 567 | 779 252 568 | 779 272 569 | 779 883 570 | 779 894 571 | 797 125 572 | 797 325 573 | 797 622 574 | 797 624 575 | 797 959 576 | 798 605 577 | 798 960 578 | 845 80 579 | 845 92 580 | 845 200 581 | 845 211 582 | 845 242 583 | 845 468 584 | 857 55 585 | 857 72 586 | 857 465 587 | 857 498 588 | 857 779 589 | 866 34 590 | 866 151 591 | 866 452 592 | 866 502 593 | 866 634 594 | 866 642 595 | 866 691 596 | 866 753 597 | 866 869 598 | 867 222 599 | 867 343 600 | 869 245 601 | 869 634 602 | 869 691 603 | 869 1332 604 | 883 1 605 | 883 55 606 | 883 117 607 | 883 170 608 | 883 205 609 | 883 252 610 | 883 265 611 | 883 272 612 | 883 494 613 | 883 587 614 | 883 779 615 | 883 894 616 | 883 1401 617 | 884 3 618 | 884 272 619 | 894 1 620 | 894 55 621 | 894 117 622 | 894 170 623 | 894 205 624 | 894 265 625 | 894 272 626 | 894 494 627 | 894 587 628 | 894 779 629 | 894 883 630 | 920 32 631 | 920 440 632 | 920 797 633 | 959 245 634 | 959 325 635 | 959 622 636 | 959 624 637 | 959 769 638 | 959 797 639 | 960 622 640 | 960 798 641 | 970 124 642 | 970 471 643 | 974 366 644 | 974 1485 645 | 1228 642 646 | 1228 1401 647 | 1228 1519 648 | 1332 452 649 | 1332 634 650 | 1332 691 651 | 1332 869 652 | 1332 1401 653 | 1332 1519 654 | 1401 642 655 | 1401 1228 656 | 1401 1332 657 | 1401 1519 658 | 1485 974 659 | 1519 642 660 | 1519 1228 661 | 1519 1332 662 | 1519 1401 663 | 1519 1594 664 | 1519 1828 665 | 1594 1519 666 | 1594 1828 667 | 1828 1519 668 | 1828 1594 669 | -------------------------------------------------------------------------------- /datasets/facebookEgo/facebookEgo234.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alelab-upenn/graph-neural-networks/a84a39fabad5378bdcbaad20b5dbcff14b4eebcd/datasets/facebookEgo/facebookEgo234.pkl -------------------------------------------------------------------------------- /examples/authorshipGNN.py: -------------------------------------------------------------------------------- 1 | # 2019/04/08~ 2 | # Fernando Gama, fgama@seas.upenn.edu 3 | # Luana Ruiz, rubruiz@seas.upenn.edu 4 | 5 | # Test the authorship attribution dataset. The dataset consists on word 6 | # adjacency networks (graph support) and word frequency count of short texts 7 | # (graph signal) for a pool of authors of the 19th century. The word adjacency 8 | # networks are graphs whose nodes are function words and whose edges are 9 | # measures of co-occurrence between these words. These graphs are different 10 | # for each author, but it takes long texts to produce them. In this problem, 11 | # we will use WANs already created, and try to attribute authorship of short 12 | # texts; we count the number of function words present in each short text, 13 | # assign them to the corresponding nodes of the WAN (i.e. graph signals), and 14 | # use those to classify texts. 
The classification is binary: each text either 15 | # belongs to the author whose WAN we are using or does not. 16 | 17 | # Outputs: 18 | # - Text file with all the hyperparameters selected for the run and the 19 | # corresponding results (hyperparameters.txt) 20 | # - Pickle file with the random seeds of both torch and numpy for accurate 21 | # reproduction of results (randomSeedUsed.pkl) 22 | # - The parameters of the trained models, for both the Best and the Last 23 | # instance of each model (savedModels/) 24 | # - The figures of loss and evaluation through the training iterations for 25 | # each model (figs/ and trainVars/) 26 | # - If selected, logs in tensorboardX certain useful training variables 27 | 28 | #%%################################################################## 29 | # # 30 | # IMPORTING # 31 | # # 32 | ##################################################################### 33 | 34 | #\\\ Standard libraries: 35 | import os 36 | import numpy as np 37 | import matplotlib 38 | matplotlib.rcParams['text.usetex'] = True 39 | matplotlib.rcParams['font.family'] = 'serif' 40 | matplotlib.rcParams['text.latex.preamble']=[r'\usepackage{amsmath}'] 41 | import matplotlib.pyplot as plt 42 | import pickle 43 | import datetime 44 | from copy import deepcopy 45 | 46 | import torch; torch.set_default_dtype(torch.float64) 47 | import torch.nn as nn 48 | import torch.optim as optim 49 | 50 | #\\\ Own libraries: 51 | import alegnn.utils.graphTools as graphTools 52 | import alegnn.utils.dataTools 53 | import alegnn.utils.graphML as gml 54 | import alegnn.modules.architectures as archit 55 | import alegnn.modules.model as model 56 | import alegnn.modules.training as training 57 | import alegnn.modules.evaluation as evaluation 58 | import alegnn.modules.loss as loss 59 | 60 | #\\\ Separate functions: 61 | from alegnn.utils.miscTools import writeVarValues 62 | from alegnn.utils.miscTools import saveSeed 63 | 64 | # Start measuring time 65 | startRunTime = datetime.datetime.now() 66 | 67 | #%%################################################################## 68 | # # 69 | # SETTING PARAMETERS # 70 | # # 71 | ##################################################################### 72 | 73 | authorName = 'austen' 74 | # jacob 'abbott', robert louis 'stevenson', louisa may 'alcott', 75 | # horatio 'alger', james 'allen', jane 'austen', 76 | # emily 'bronte', james 'cooper', charles 'dickens', 77 | # hamlin 'garland', nathaniel 'hawthorne', henry 'james', 78 | # herman 'melville', 'page', henry 'thoreau', 79 | # mark 'twain', arthur conan 'doyle', washington 'irving', 80 | # edgar allan 'poe', sarah orne 'jewett', edith 'wharton' 81 | 82 | thisFilename = 'authorshipGNN' # This is the general name of all related files 83 | 84 | saveDirRoot = 'experiments' # In this case, relative location 85 | saveDir = os.path.join(saveDirRoot, thisFilename + '-' + authorName) 86 | # Dir where to save all the results from each run 87 | dataPath = os.path.join('datasets','authorshipData','authorshipData.mat') 88 | 89 | #\\\ Create .txt to store the values of the setting parameters for easier 90 | # reference when running multiple experiments 91 | today = datetime.datetime.now().strftime("%Y%m%d%H%M%S") 92 | # Append date and time of the run to the directory, to avoid several runs 93 | # overwriting each other. 94 | saveDir = saveDir + '-' + today 95 | # Create directory 96 | if not os.path.exists(saveDir): 97 | os.makedirs(saveDir) 98 | # Create the file where all the (hyper)parameters and results will be saved. 
99 | varsFile = os.path.join(saveDir,'hyperparameters.txt') 100 | with open(varsFile, 'w+') as file: 101 | file.write('%s\n\n' % datetime.datetime.now().strftime("%Y/%m/%d %H:%M:%S")) 102 | 103 | #\\\ Save seeds for reproducibility 104 | # PyTorch seeds 105 | torchState = torch.get_rng_state() 106 | torchSeed = torch.initial_seed() 107 | # Numpy seeds 108 | numpyState = np.random.get_state() # state of the global numpy generator 109 | # Collect all random states 110 | randomStates = [] 111 | randomStates.append({}) 112 | randomStates[0]['module'] = 'numpy' 113 | randomStates[0]['state'] = numpyState 114 | randomStates.append({}) 115 | randomStates[1]['module'] = 'torch' 116 | randomStates[1]['state'] = torchState 117 | randomStates[1]['seed'] = torchSeed 118 | # This list and dictionary follow the format to then be loaded, if needed, 119 | # by calling the loadSeed function in Utils.miscTools 120 | saveSeed(randomStates, saveDir) 121 | 122 | ######## 123 | # DATA # 124 | ######## 125 | 126 | useGPU = True # If true, and GPU is available, use it. 127 | 128 | nClasses = 2 # Either authorName or not 129 | ratioTrain = 0.95 # Ratio of training samples 130 | ratioValid = 0.08 # Ratio of validation samples (out of the total training 131 | # samples) 132 | # Final split is: 133 | # nValidation = round(ratioValid * ratioTrain * nTotal) 134 | # nTrain = round((1 - ratioValid) * ratioTrain * nTotal) 135 | # nTest = nTotal - nTrain - nValidation 136 | 137 | nDataSplits = 10 # Number of data realizations 138 | # Obs.: The built graph depends on the split between training, validation and 139 | # testing. Therefore, we will run several of these splits and average across 140 | # them, to obtain some result that is more robust to this split. 141 | 142 | # Every training excerpt has a WAN associated to it. We combine all these WANs 143 | # into a single graph to use as the supporting graph for all samples. This 144 | # combination happens under some extra options: 145 | graphNormalizationType = 'rows' # or 'cols' - Makes all rows add up to 1. 146 | keepIsolatedNodes = False # If True keeps isolated nodes 147 | forceUndirected = True # If True forces the graph to be undirected (symmetrizes) 148 | forceConnected = True # If True removes nodes (from lowest to highest degree) 149 | # until the resulting graph is connected. 150 | 151 | #\\\ Save values: 152 | writeVarValues(varsFile, 153 | {'authorName': authorName, 154 | 'nClasses': nClasses, 155 | 'ratioTrain': ratioTrain, 156 | 'ratioValid': ratioValid, 157 | 'nDataSplits': nDataSplits, 158 | 'graphNormalizationType': graphNormalizationType, 159 | 'keepIsolatedNodes': keepIsolatedNodes, 160 | 'forceUndirected': forceUndirected, 161 | 'forceConnected': forceConnected, 162 | 'useGPU': useGPU}) 163 | 164 | ############ 165 | # TRAINING # 166 | ############ 167 | 168 | #\\\ Individual model training options 169 | optimAlg = 'ADAM' # Options: 'SGD', 'ADAM', 'RMSprop' 170 | learningRate = 0.005 # In all options 171 | beta1 = 0.9 # beta1 if 'ADAM', alpha if 'RMSprop' 172 | beta2 = 0.999 # ADAM option only 173 | 174 | #\\\ Loss function choice 175 | lossFunction = nn.CrossEntropyLoss # This applies a softmax before feeding 176 | # it into the NLL, so we don't have to apply the softmax ourselves. 
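As a quick illustration of what the adaptExtraDimensionLoss wrapper (defined in alegnn/modules/loss.py above) does with this choice of loss, here is a standalone sketch; the tensors are hypothetical and not part of this script:

import torch
from alegnn.modules.loss import adaptExtraDimensionLoss

# Scalar-comparing loss: the GNN readout produces B x 1, but MSELoss wants B
mseLoss = adaptExtraDimensionLoss(torch.nn.MSELoss)
yHat = torch.randn(20, 1)            # B x 1 output of a GNN readout
y = torch.randn(20)                  # B scalar targets
value = mseLoss(yHat, y)             # the singleton dimension is squeezed internally

# Class-based loss: a B x C input passes through unchanged
ceLoss = adaptExtraDimensionLoss(torch.nn.CrossEntropyLoss)
yHatClass = torch.randn(20, 2)       # B x C logits
yClass = torch.randint(0, 2, (20,))  # B integer class labels
valueClass = ceLoss(yHatClass, yClass)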
177 | 178 | #\\\ Overall training options 179 | nEpochs = 25 # Number of epochs 180 | batchSize = 20 # Batch size 181 | doLearningRateDecay = False # Learning rate decay 182 | learningRateDecayRate = 0.9 # Rate 183 | learningRateDecayPeriod = 1 # Number of epochs after which to update the lr 184 | validationInterval = 5 # After how many training steps to run the validation 185 | 186 | #\\\ Save values 187 | writeVarValues(varsFile, 188 | {'optimAlg': optimAlg, 189 | 'learningRate': learningRate, 190 | 'beta1': beta1, 191 | 'lossFunction': lossFunction, 192 | 'nEpochs': nEpochs, 193 | 'batchSize': batchSize, 194 | 'doLearningRateDecay': doLearningRateDecay, 195 | 'learningRateDecayRate': learningRateDecayRate, 196 | 'learningRateDecayPeriod': learningRateDecayPeriod, 197 | 'validationInterval': validationInterval}) 198 | 199 | ################# 200 | # ARCHITECTURES # 201 | ################# 202 | 203 | # Here, there will be three one-layer architectures 204 | 205 | doLocalMax = True 206 | doLocalMed = True 207 | doPointwise = True 208 | 209 | # In this section, we determine the (hyper)parameters of models that we are 210 | # going to train. This only sets the parameters. The architectures need to be 211 | # created later below. Do not forget to add the name of the architecture 212 | # to modelList. 213 | 214 | # If the model dictionary is called 'model' + name, then it can be 215 | # picked up immediately later on, and there's no need to recode anything after 216 | # the section 'Setup' (except for setting the number of nodes in the 'N' 217 | # variable after it has been coded). 218 | 219 | # The names of the keys in the model dictionary have to be the same 220 | # as the names of the variables in the architecture call, because they will 221 | # be called by unpacking the dictionary. 
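As a standalone sketch of the convention just described (the model below is hypothetical; the real dictionaries follow), the generic setup loop later on does, roughly:

# modelMyGNN = {'name': 'MyGNN', 'device': 'cpu',
#               'archit': archit.SelectionGNN,
#               'trainer': training.Trainer,
#               'evaluator': evaluation.evaluate,
#               ...}                       # plus the architecture keyword arguments
# modelDict = deepcopy(modelMyGNN)
# thisName = modelDict.pop('name')         # the general keys are popped out first,
# callArchit = modelDict.pop('archit')     # together with 'device', 'trainer'
# thisDevice = modelDict.pop('device')     # and 'evaluator'...
# thisTrainer = modelDict.pop('trainer')
# thisEvaluator = modelDict.pop('evaluator')
# thisArchit = callArchit(**modelDict)     # ...so that what remains unpacks
#                                          # exactly into the architecture call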
222 | 223 | modelList = [] 224 | 225 | #\\\\\\\\\\\\\\\\\\\\\ 226 | #\\\ SELECTION GNN \\\ 227 | #\\\\\\\\\\\\\\\\\\\\\ 228 | 229 | # Hyperparameters to be shared by all architectures 230 | 231 | modelActvFn = {} 232 | 233 | modelActvFn['name'] = 'ActvFn' # To be modified later on depending on the 234 | # specific activation function selected 235 | modelActvFn['device'] = 'cuda:0' if (useGPU and torch.cuda.is_available()) \ 236 | else 'cpu' 237 | 238 | #\\\ ARCHITECTURE 239 | 240 | # Select architectural nn.Module to use 241 | modelActvFn['archit'] = archit.LocalActivationGNN 242 | # Graph convolutional layers 243 | modelActvFn['dimNodeSignals'] = [1, 32] # Number of features per layer 244 | modelActvFn['nFilterTaps'] = [5] # Number of filter taps 245 | modelActvFn['bias'] = True # Include bias 246 | # Nonlinearity 247 | modelActvFn['nonlinearity'] = gml.NoActivation 248 | modelActvFn['kHopActivation'] = [2] 249 | # Pooling 250 | modelActvFn['nSelectedNodes'] = None # To be determined later 251 | modelActvFn['poolingFunction'] = gml.NoPool # Summarizing function 252 | modelActvFn['poolingSize'] = [1] # Summarizing neighborhoods 253 | # Readout layer 254 | modelActvFn['dimLayersMLP'] = [nClasses] 255 | # Graph Structure 256 | modelActvFn['GSO'] = None # To be determined later on, based on data 257 | modelActvFn['order'] = None # Not used because there is no pooling 258 | 259 | #\\\ TRAINER 260 | 261 | modelActvFn['trainer'] = training.Trainer 262 | 263 | #\\\ EVALUATOR 264 | 265 | modelActvFn['evaluator'] = evaluation.evaluate 266 | 267 | #\\\\\\\\\\\\ 268 | #\\\ MODEL 1: Max Local Activation 269 | #\\\\\\\\\\\\ 270 | 271 | if doLocalMax: 272 | 273 | #\\\ Start from the basic parameters shared by all architectures 274 | 275 | modelActvFnMax = deepcopy(modelActvFn) 276 | 277 | modelActvFnMax['name'] += 'Max' 278 | # Nonlinearity 279 | modelActvFnMax['nonlinearity'] = gml.MaxLocalActivation 280 | 281 | #\\\ Save Values: 282 | writeVarValues(varsFile, modelActvFnMax) 283 | modelList += [modelActvFnMax['name']] 284 | 285 | #\\\\\\\\\\\\ 286 | #\\\ MODEL 2: Median Local Activation 287 | #\\\\\\\\\\\\ 288 | 289 | if doLocalMed: 290 | 291 | #\\\ Start from the basic parameters shared by all architectures 292 | 293 | modelActvFnMed = deepcopy(modelActvFn) 294 | 295 | modelActvFnMed['name'] += 'Med' 296 | # Nonlinearity 297 | modelActvFnMed['nonlinearity'] = gml.MedianLocalActivation 298 | 299 | #\\\ Save Values: 300 | writeVarValues(varsFile, modelActvFnMed) 301 | modelList += [modelActvFnMed['name']] 302 | 303 | #\\\\\\\\\\\\ 304 | #\\\ MODEL 3: ReLU nonlinearity 305 | #\\\\\\\\\\\\ 306 | 307 | if doPointwise: 308 | 309 | #\\\ Start from the basic parameters shared by all architectures 310 | 311 | modelActvFnPnt = deepcopy(modelActvFn) 312 | 313 | modelActvFnPnt['name'] += 'Pnt' 314 | # Change the architecture 315 | modelActvFnPnt['archit'] = archit.SelectionGNN 316 | # Nonlinearity 317 | modelActvFnPnt['nonlinearity'] = nn.ReLU 318 | # Get rid of the parameter kHopActivation that we do not need anymore 319 | modelActvFnPnt.pop('kHopActivation') 320 | 321 | #\\\ Save Values: 322 | writeVarValues(varsFile, modelActvFnPnt) 323 | modelList += [modelActvFnPnt['name']] 324 | 325 | ########### 326 | # LOGGING # 327 | ########### 328 | 329 | # Options: 330 | doPrint = True # Decide whether to print stuff while running 331 | doLogging = False # Log into tensorboard 332 | doSaveVars = True # Save (pickle) useful variables 333 | doFigs = True # Plot some figures (this only works if doSaveVars is True) 334 | # Parameters: 
335 | printInterval = 5 # After how many training steps, print the partial results 336 | # 0 means to never print partial results while training 337 | xAxisMultiplierTrain = 10 # How many training steps in between those shown in 338 | # the plot, i.e., one training step every xAxisMultiplierTrain is shown. 339 | xAxisMultiplierValid = 2 # How many validation steps in between those shown, 340 | # same as above. 341 | figSize = 5 # Overall size of the figure that contains the plot 342 | lineWidth = 2 # Width of the plot lines 343 | markerShape = 'o' # Shape of the markers 344 | markerSize = 3 # Size of the markers 345 | 346 | #\\\ Save values: 347 | writeVarValues(varsFile, 348 | {'doPrint': doPrint, 349 | 'doLogging': doLogging, 350 | 'doSaveVars': doSaveVars, 351 | 'doFigs': doFigs, 352 | 'saveDir': saveDir, 353 | 'printInterval': printInterval, 354 | 'figSize': figSize, 355 | 'lineWidth': lineWidth, 356 | 'markerShape': markerShape, 357 | 'markerSize': markerSize}) 358 | 359 | #%%################################################################## 360 | # # 361 | # SETUP # 362 | # # 363 | ##################################################################### 364 | 365 | #\\\ Determine processing unit: 366 | if useGPU and torch.cuda.is_available(): 367 | torch.cuda.empty_cache() 368 | 369 | #\\\ Notify of processing units 370 | if doPrint: 371 | print("Selected devices:") 372 | for thisModel in modelList: 373 | modelDict = eval('model' + thisModel) 374 | print("\t%s: %s" % (thisModel, modelDict['device'])) 375 | 376 | #\\\ Logging options 377 | if doLogging: 378 | # If logging is on, load the tensorboard visualizer and initialize it 379 | from alegnn.utils.visualTools import Visualizer 380 | logsTB = os.path.join(saveDir, 'logsTB') 381 | logger = Visualizer(logsTB, name='visualResults') 382 | 383 | #\\\ Save variables during evaluation. 384 | # We will save all the evaluations obtained for each of the trained models. 385 | # It basically is a dictionary, containing a list. The key of the 386 | # dictionary determines the model, then the first list index determines 387 | # which split realization. Then, this will be converted to numpy to compute 388 | # mean and standard deviation (across the split dimension). 389 | costBest = {} # Cost for the best model (Evaluation cost: Error rate) 390 | costLast = {} # Cost for the last model 391 | for thisModel in modelList: # Create an element for each split realization, 392 | costBest[thisModel] = [None] * nDataSplits 393 | costLast[thisModel] = [None] * nDataSplits 394 | 395 | if doFigs: 396 | #\\\ SAVE SPACE: 397 | # Create the variables to save all the realizations. This is, again, a 398 | # dictionary, where each key represents a model, and each model is a list 399 | # for each data split. 400 | # Each data split, in this case, is not a scalar, but a vector of 401 | # length the number of training steps (or of validation steps) 402 | lossTrain = {} 403 | costTrain = {} 404 | lossValid = {} 405 | costValid = {} 406 | # Initialize the splits dimension 407 | for thisModel in modelList: 408 | lossTrain[thisModel] = [None] * nDataSplits 409 | costTrain[thisModel] = [None] * nDataSplits 410 | lossValid[thisModel] = [None] * nDataSplits 411 | costValid[thisModel] = [None] * nDataSplits 412 | 413 | 414 | #################### 415 | # TRAINING OPTIONS # 416 | #################### 417 | 418 | # Training phase. It has a lot of options that are input through a 419 | # dictionary of arguments. 
420 | # The value of these options was decided above with the rest of the parameters. 421 | # This just creates a dictionary necessary to pass to the train function. 422 | 423 | trainingOptions = {} 424 | 425 | if doLogging: 426 | trainingOptions['logger'] = logger 427 | if doSaveVars: 428 | trainingOptions['saveDir'] = saveDir 429 | if doPrint: 430 | trainingOptions['printInterval'] = printInterval 431 | if doLearningRateDecay: 432 | trainingOptions['learningRateDecayRate'] = learningRateDecayRate 433 | trainingOptions['learningRateDecayPeriod'] = learningRateDecayPeriod 434 | trainingOptions['validationInterval'] = validationInterval 435 | 436 | # And in case each model has specific training options, we create a 437 | # separate dictionary per model. 438 | 439 | trainingOptsPerModel = {} 440 | 441 | #%%################################################################## 442 | # # 443 | # DATA SPLIT REALIZATION # 444 | # # 445 | ##################################################################### 446 | 447 | # Generate a new data split for each of the realizations that 448 | # we previously specified 449 | 450 | for split in range(nDataSplits): 451 | 452 | #%%################################################################## 453 | # # 454 | # DATA HANDLING # 455 | # # 456 | ##################################################################### 457 | 458 | ############ 459 | # DATASETS # 460 | ############ 461 | 462 | if doPrint: 463 | print("\nLoading data", end = '') 464 | if nDataSplits > 1: 465 | print(" for split %d" % (split+1), end = '') 466 | print("...", end = ' ', flush = True) 467 | 468 | # Load the data, which will give a specific split 469 | data = alegnn.utils.dataTools.Authorship(authorName, 470 | ratioTrain, 471 | ratioValid, 472 | dataPath, 473 | graphNormalizationType, 474 | keepIsolatedNodes, 475 | forceUndirected, 476 | forceConnected) 477 | 478 | if doPrint: 479 | print("OK") 480 | 481 | ######### 482 | # GRAPH # 483 | ######### 484 | 485 | if doPrint: 486 | print("Setting up the graph...", end = ' ', flush = True) 487 | 488 | # Create graph 489 | adjacencyMatrix = data.getGraph() 490 | G = graphTools.Graph('adjacency', adjacencyMatrix.shape[0], 491 | {'adjacencyMatrix': adjacencyMatrix}) 492 | G.computeGFT() # Compute the GFT of the stored GSO 493 | 494 | # And re-update the number of nodes for changes in the graph (due to 495 | # enforced connectedness, for instance) 496 | nNodes = G.N 497 | 498 | # Once data is completely formatted and in appropriate fashion, change its 499 | # type to torch and move it to the appropriate device 500 | data.astype(torch.float64) 501 | # And the corresponding feature dimension that we will need to use 502 | data.expandDims() # Data are just graph signals, but the architectures 503 | # require that the input signals are of the form B x F x N, so we need 504 | # to expand the middle dimensions to convert them from B x N to 505 | # B x 1 x N 506 | 507 | if doPrint: 508 | print("OK") 509 | 510 | #%%################################################################## 511 | # # 512 | # MODELS INITIALIZATION # 513 | # # 514 | ##################################################################### 515 | 516 | # This is the dictionary where we store the models (in a model.Model 517 | # class, that is then passed to training). 518 | modelsGNN = {} 519 | 520 | # If a new model is to be created, it should be instantiated here. 
521 | 522 | if doPrint: 523 | print("Model initialization...", flush = True) 524 | 525 | for thisModel in modelList: 526 | 527 | # Get the corresponding parameter dictionary 528 | modelDict = deepcopy(eval('model' + thisModel)) 529 | # and training options 530 | trainingOptsPerModel[thisModel] = deepcopy(trainingOptions) 531 | 532 | # Now, this dictionary has all the hyperparameters that we need to pass 533 | # to the architecture function, but it also has other keys that belong 534 | # to the more general model (like 'name' or 'device'), so we need to 535 | # extract them and save them in separate variables for future use. 536 | thisName = modelDict.pop('name') 537 | callArchit = modelDict.pop('archit') 538 | thisDevice = modelDict.pop('device') 539 | thisTrainer = modelDict.pop('trainer') 540 | thisEvaluator = modelDict.pop('evaluator') 541 | 542 | # If more than one graph or data realization is going to be carried out, 543 | # we are going to store all of those models separately, so that any of 544 | # them can be brought back and studied in detail. 545 | if nDataSplits > 1: 546 | thisName += 'G%02d' % split 547 | 548 | if doPrint: 549 | print("\tInitializing %s..." % thisName, 550 | end = ' ', flush = True) 551 | 552 | ############## 553 | # PARAMETERS # 554 | ############## 555 | 556 | #\\\ Optimizer options 557 | # (If different from the default ones, change here.) 558 | thisOptimAlg = optimAlg 559 | thisLearningRate = learningRate 560 | thisBeta1 = beta1 561 | thisBeta2 = beta2 562 | 563 | #\\\ GSO normalization 564 | S = G.S.copy()/np.max(np.real(G.E)) # scale by the largest eigenvalue 565 | # Do not forget to add the GSO to the input parameters of the archit 566 | modelDict['GSO'] = S 567 | # Add the number of nodes for the no-pooling part 568 | modelDict['nSelectedNodes'] = [nNodes] 569 | 570 | ################ 571 | # ARCHITECTURE # 572 | ################ 573 | 574 | thisArchit = callArchit(**modelDict) 575 | 576 | ############# 577 | # OPTIMIZER # 578 | ############# 579 | 580 | if thisOptimAlg == 'ADAM': 581 | thisOptim = optim.Adam(thisArchit.parameters(), 582 | lr = thisLearningRate, 583 | betas = (thisBeta1, thisBeta2)) 584 | elif thisOptimAlg == 'SGD': 585 | thisOptim = optim.SGD(thisArchit.parameters(), 586 | lr = thisLearningRate) 587 | elif thisOptimAlg == 'RMSprop': 588 | thisOptim = optim.RMSprop(thisArchit.parameters(), 589 | lr = thisLearningRate, alpha = thisBeta1) 590 | 591 | ######## 592 | # LOSS # 593 | ######## 594 | 595 | # Initialize the loss function 596 | thisLossFunction = loss.adaptExtraDimensionLoss(lossFunction) 597 | 598 | ######### 599 | # MODEL # 600 | ######### 601 | 602 | # Create the model 603 | modelCreated = model.Model(thisArchit, 604 | thisLossFunction, 605 | thisOptim, 606 | thisTrainer, 607 | thisEvaluator, 608 | thisDevice, 609 | thisName, 610 | saveDir) 611 | 612 | # Store it 613 | modelsGNN[thisName] = modelCreated 614 | 615 | # Write the main hyperparameters 616 | writeVarValues(varsFile, 617 | {'name': thisName, 618 | 'thisOptimizationAlgorithm': thisOptimAlg, 619 | 'thisTrainer': thisTrainer, 620 | 'thisEvaluator': thisEvaluator, 621 | 'thisLearningRate': thisLearningRate, 622 | 'thisBeta1': thisBeta1, 623 | 'thisBeta2': thisBeta2}) 624 | 625 | if doPrint: 626 | print("OK") 627 | 628 | if doPrint: 629 | print("Model initialization... 
COMPLETE") 630 | 631 | #%%################################################################## 632 | # # 633 | # TRAINING # 634 | # # 635 | ##################################################################### 636 | 637 | print("") 638 | 639 | # We train each model separately 640 | 641 | for thisModel in modelsGNN.keys(): 642 | 643 | if doPrint: 644 | print("Training model %s..." % thisModel) 645 | 646 | # Remember that modelsGNN.keys() has the split numbering as well as the 647 | # name, while modelList has only the name. So we need to map the 648 | # specific model for this specific split with the actual model name, 649 | # since there are several variables that are indexed by the model name 650 | # (for instance, the training options, or the dictionaries saving the 651 | # loss values) 652 | for m in modelList: 653 | if m in thisModel: 654 | modelName = m 655 | 656 | # Identify the specific split number at training time 657 | if nDataSplits > 1: 658 | trainingOptsPerModel[modelName]['graphNo'] = split 659 | 660 | # Train the model 661 | thisTrainVars = modelsGNN[thisModel].train(data, 662 | nEpochs, 663 | batchSize, 664 | **trainingOptsPerModel[modelName]) 665 | 666 | if doFigs: 667 | # Find which model to save the results (when having multiple 668 | # realizations) 669 | lossTrain[modelName][split] = thisTrainVars['lossTrain'] 670 | costTrain[modelName][split] = thisTrainVars['costTrain'] 671 | lossValid[modelName][split] = thisTrainVars['lossValid'] 672 | costValid[modelName][split] = thisTrainVars['costValid'] 673 | 674 | # And we also need to save 'nBatches' but is the same for all models, so 675 | if doFigs: 676 | nBatches = thisTrainVars['nBatches'] 677 | 678 | #%%################################################################## 679 | # # 680 | # EVALUATION # 681 | # # 682 | ##################################################################### 683 | 684 | # Now that the models have been trained, we evaluate them on the test 685 | # samples. 686 | 687 | # We have two versions of each model to evaluate: the one obtained 688 | # at the best result of the validation step, and the last trained model. 689 | 690 | if doPrint: 691 | print("\nTotal testing error rate", end = '', flush = True) 692 | if nDataSplits > 1: 693 | print(" (Split %02d)" % split, end = '', flush = True) 694 | print(":", flush = True) 695 | 696 | 697 | for thisModel in modelsGNN.keys(): 698 | 699 | # Same as before, separate the model name from the data split 700 | # realization number 701 | for m in modelList: 702 | if m in thisModel: 703 | modelName = m 704 | 705 | # Evaluate the model 706 | thisEvalVars = modelsGNN[thisModel].evaluate(data) 707 | 708 | # Save the outputs 709 | thisCostBest = thisEvalVars['costBest'] 710 | thisCostLast = thisEvalVars['costLast'] 711 | 712 | # Write values 713 | writeVarValues(varsFile, 714 | {'costBest%s' % thisModel: thisCostBest, 715 | 'costLast%s' % thisModel: thisCostLast}) 716 | 717 | # Now check which is the model being trained 718 | costBest[modelName][split] = thisCostBest 719 | costLast[modelName][split] = thisCostLast 720 | # This is so that we can later compute a total accuracy with 721 | # the corresponding error. 
722 | 723 | if doPrint: 724 | print("\t%s: %6.2f%% [Best] %6.2f%% [Last]" % (thisModel, 725 | thisCostBest*100, 726 | thisCostLast*100)) 727 | 728 | ############################ 729 | # FINAL EVALUATION RESULTS # 730 | ############################ 731 | 732 | # Now that we have computed the accuracy of all runs, we can obtain a final 733 | # result (mean and standard deviation) 734 | 735 | meanCostBest = {} # Mean across data splits 736 | meanCostLast = {} # Mean across data splits 737 | stdDevCostBest = {} # Standard deviation across data splits 738 | stdDevCostLast = {} # Standard deviation across data splits 739 | 740 | if doPrint: 741 | print("\nFinal evaluations (%02d data splits)" % (nDataSplits)) 742 | 743 | for thisModel in modelList: 744 | # Convert the lists into a nDataSplits vector 745 | costBest[thisModel] = np.array(costBest[thisModel]) 746 | costLast[thisModel] = np.array(costLast[thisModel]) 747 | 748 | # And now compute the statistics (across graphs) 749 | meanCostBest[thisModel] = np.mean(costBest[thisModel]) 750 | meanCostLast[thisModel] = np.mean(costLast[thisModel]) 751 | stdDevCostBest[thisModel] = np.std(costBest[thisModel]) 752 | stdDevCostLast[thisModel] = np.std(costLast[thisModel]) 753 | 754 | # And print it: 755 | if doPrint: 756 | print("\t%s: %6.2f%% (+-%6.2f%%) [Best] %6.2f%% (+-%6.2f%%) [Last]" % ( 757 | thisModel, 758 | meanCostBest[thisModel] * 100, 759 | stdDevCostBest[thisModel] * 100, 760 | meanCostLast[thisModel] * 100, 761 | stdDevCostLast[thisModel] * 100)) 762 | 763 | # Save values 764 | writeVarValues(varsFile, 765 | {'meanCostBest%s' % thisModel: meanCostBest[thisModel], 766 | 'stdDevCostBest%s' % thisModel: stdDevCostBest[thisModel], 767 | 'meanCostLast%s' % thisModel: meanCostLast[thisModel], 768 | 'stdDevCostLast%s' % thisModel : stdDevCostLast[thisModel]}) 769 | 770 | # Save the printed info into the .txt file as well 771 | with open(varsFile, 'a+') as file: 772 | file.write("Final evaluations (%02d data splits)\n" % (nDataSplits)) 773 | for thisModel in modelList: 774 | file.write("\t%s: %6.2f%% (+-%6.2f%%) [Best] %6.2f%% (+-%6.2f%%) [Last]\n" % ( 775 | thisModel, 776 | meanCostBest[thisModel] * 100, 777 | stdDevCostBest[thisModel] * 100, 778 | meanCostLast[thisModel] * 100, 779 | stdDevCostLast[thisModel] * 100)) 780 | file.write('\n') 781 | 782 | #%%################################################################## 783 | # # 784 | # PLOT # 785 | # # 786 | ##################################################################### 787 | 788 | # Finally, we might want to plot several quantities of interest 789 | 790 | if doFigs and doSaveVars: 791 | 792 | ################### 793 | # DATA PROCESSING # 794 | ################### 795 | 796 | #\\\ FIGURES DIRECTORY: 797 | saveDirFigs = os.path.join(saveDir,'figs') 798 | # If it doesn't exist, create it. 799 | if not os.path.exists(saveDirFigs): 800 | os.makedirs(saveDirFigs) 801 | 802 | #\\\ COMPUTE STATISTICS: 803 | # The first thing to do is to transform those into a matrix with all the 804 | # realizations, so create the variables to save that. 
805 | meanLossTrain = {} 806 | meanCostTrain = {} 807 | meanLossValid = {} 808 | meanCostValid = {} 809 | stdDevLossTrain = {} 810 | stdDevCostTrain = {} 811 | stdDevLossValid = {} 812 | stdDevCostValid = {} 813 | # Initialize the variables 814 | for thisModel in modelList: 815 | # Transform into np.array 816 | lossTrain[thisModel] = np.array(lossTrain[thisModel]) 817 | costTrain[thisModel] = np.array(costTrain[thisModel]) 818 | lossValid[thisModel] = np.array(lossValid[thisModel]) 819 | costValid[thisModel] = np.array(costValid[thisModel]) 820 | # Each one of these variables should be of shape 821 | # nDataSplits x numberOfTrainingSteps 822 | # And compute the statistics 823 | meanLossTrain[thisModel] = np.mean(lossTrain[thisModel], axis = 0) 824 | meanCostTrain[thisModel] = np.mean(costTrain[thisModel], axis = 0) 825 | meanLossValid[thisModel] = np.mean(lossValid[thisModel], axis = 0) 826 | meanCostValid[thisModel] = np.mean(costValid[thisModel], axis = 0) 827 | stdDevLossTrain[thisModel] = np.std(lossTrain[thisModel], axis = 0) 828 | stdDevCostTrain[thisModel] = np.std(costTrain[thisModel], axis = 0) 829 | stdDevLossValid[thisModel] = np.std(lossValid[thisModel], axis = 0) 830 | stdDevCostValid[thisModel] = np.std(costValid[thisModel], axis = 0) 831 | 832 | #################### 833 | # SAVE FIGURE DATA # 834 | #################### 835 | 836 | # And finally, we can plot. But before, let's save the variables mean and 837 | # stdDev so, if we don't like the plot, we can re-open them, and re-plot 838 | # them, as we please. 839 | # Pickle, first: 840 | varsPickle = {} 841 | varsPickle['nEpochs'] = nEpochs 842 | varsPickle['nBatches'] = nBatches 843 | varsPickle['meanLossTrain'] = meanLossTrain 844 | varsPickle['stdDevLossTrain'] = stdDevLossTrain 845 | varsPickle['meanCostTrain'] = meanCostTrain 846 | varsPickle['stdDevCostTrain'] = stdDevCostTrain 847 | varsPickle['meanLossValid'] = meanLossValid 848 | varsPickle['stdDevLossValid'] = stdDevLossValid 849 | varsPickle['meanCostValid'] = meanCostValid 850 | varsPickle['stdDevCostValid'] = stdDevCostValid 851 | with open(os.path.join(saveDirFigs,'figVars.pkl'), 'wb') as figVarsFile: 852 | pickle.dump(varsPickle, figVarsFile) 853 | 854 | ######## 855 | # PLOT # 856 | ######## 857 | 858 | # Compute the x-axis 859 | xTrain = np.arange(0, nEpochs * nBatches, xAxisMultiplierTrain) 860 | xValid = np.arange(0, nEpochs * nBatches, \ 861 | validationInterval*xAxisMultiplierValid) 862 | 863 | # If we do not want to plot all the elements (to avoid overcrowded plots) 864 | # we need to recompute the x axis and take those elements corresponding 865 | # to the training steps we want to plot 866 | if xAxisMultiplierTrain > 1: 867 | # Actual selected samples 868 | selectSamplesTrain = xTrain 869 | # Go and fetch them 870 | for thisModel in modelList: 871 | meanLossTrain[thisModel] = meanLossTrain[thisModel]\ 872 | [selectSamplesTrain] 873 | stdDevLossTrain[thisModel] = stdDevLossTrain[thisModel]\ 874 | [selectSamplesTrain] 875 | meanCostTrain[thisModel] = meanCostTrain[thisModel]\ 876 | [selectSamplesTrain] 877 | stdDevCostTrain[thisModel] = stdDevCostTrain[thisModel]\ 878 | [selectSamplesTrain] 879 | # And same for the validation, if necessary. 
880 |     if xAxisMultiplierValid > 1:
881 |         selectSamplesValid = np.arange(0, len(meanLossValid[thisModel]), \
882 |                                        xAxisMultiplierValid)
883 |         for thisModel in modelList:
884 |             meanLossValid[thisModel] = meanLossValid[thisModel]\
885 |                                                     [selectSamplesValid]
886 |             stdDevLossValid[thisModel] = stdDevLossValid[thisModel]\
887 |                                                     [selectSamplesValid]
888 |             meanCostValid[thisModel] = meanCostValid[thisModel]\
889 |                                                     [selectSamplesValid]
890 |             stdDevCostValid[thisModel] = stdDevCostValid[thisModel]\
891 |                                                     [selectSamplesValid]
892 | 
893 |     #\\\ LOSS (Training and validation) for EACH MODEL
894 |     for key in meanLossTrain.keys():
895 |         lossFig = plt.figure(figsize=(1.61*figSize, 1*figSize))
896 |         plt.errorbar(xTrain, meanLossTrain[key], yerr = stdDevLossTrain[key],
897 |                      color = '#01256E', linewidth = lineWidth,
898 |                      marker = markerShape, markersize = markerSize)
899 |         plt.errorbar(xValid, meanLossValid[key], yerr = stdDevLossValid[key],
900 |                      color = '#95001A', linewidth = lineWidth,
901 |                      marker = markerShape, markersize = markerSize)
902 |         plt.ylabel(r'Loss')
903 |         plt.xlabel(r'Training steps')
904 |         plt.legend([r'Training', r'Validation'])
905 |         plt.title(r'%s' % key)
906 |         lossFig.savefig(os.path.join(saveDirFigs,'loss%s.pdf' % key),
907 |                         bbox_inches = 'tight')
908 | 
909 |     #\\\ ERROR RATE (Training and validation) for EACH MODEL
910 |     for key in meanCostTrain.keys():
911 |         costFig = plt.figure(figsize=(1.61*figSize, 1*figSize))
912 |         plt.errorbar(xTrain, meanCostTrain[key], yerr = stdDevCostTrain[key],
913 |                      color = '#01256E', linewidth = lineWidth,
914 |                      marker = markerShape, markersize = markerSize)
915 |         plt.errorbar(xValid, meanCostValid[key], yerr = stdDevCostValid[key],
916 |                      color = '#95001A', linewidth = lineWidth,
917 |                      marker = markerShape, markersize = markerSize)
918 |         plt.ylabel(r'Error rate')
919 |         plt.xlabel(r'Training steps')
920 |         plt.legend([r'Training', r'Validation'])
921 |         plt.title(r'%s' % key)
922 |         costFig.savefig(os.path.join(saveDirFigs,'cost%s.pdf' % key),
923 |                         bbox_inches = 'tight')
924 | 
925 |     # LOSS (training) for ALL MODELS
926 |     allLossTrain = plt.figure(figsize=(1.61*figSize, 1*figSize))
927 |     for key in meanLossTrain.keys():
928 |         plt.errorbar(xTrain, meanLossTrain[key], yerr = stdDevLossTrain[key],
929 |                      linewidth = lineWidth,
930 |                      marker = markerShape, markersize = markerSize)
931 |     plt.ylabel(r'Loss')
932 |     plt.xlabel(r'Training steps')
933 |     plt.legend(list(meanLossTrain.keys()))
934 |     allLossTrain.savefig(os.path.join(saveDirFigs,'allLossTrain.pdf'),
935 |                          bbox_inches = 'tight')
936 | 
937 |     # ERROR RATE (validation) for ALL MODELS
938 |     allCostValidFig = plt.figure(figsize=(1.61*figSize, 1*figSize))
939 |     for key in meanCostValid.keys():
940 |         plt.errorbar(xValid, meanCostValid[key], yerr = stdDevCostValid[key],
941 |                      linewidth = lineWidth,
942 |                      marker = markerShape, markersize = markerSize)
943 |     plt.ylabel(r'Error rate')
944 |     plt.xlabel(r'Training steps')
945 |     plt.legend(list(meanCostValid.keys()))
946 |     allCostValidFig.savefig(os.path.join(saveDirFigs,'allCostValid.pdf'),
947 |                             bbox_inches = 'tight')
948 | 
949 | # Finish measuring time
950 | endRunTime = datetime.datetime.now()
951 | 
952 | totalRunTime = abs(endRunTime - startRunTime)
953 | totalRunTimeH = int(divmod(totalRunTime.total_seconds(), 3600)[0])
954 | totalRunTimeM, totalRunTimeS = \
955 |     divmod(totalRunTime.total_seconds() - totalRunTimeH * 3600., 60)
956 | totalRunTimeM = int(totalRunTimeM)
957 | 
958 | if doPrint:
959 |     print(" ")
960 |     print("Simulation started: %s"
          %startRunTime.strftime("%Y/%m/%d %H:%M:%S"))
961 |     print("Simulation ended: %s" % endRunTime.strftime("%Y/%m/%d %H:%M:%S"))
962 |     print("Total time: %dh %dm %.2fs" % (totalRunTimeH,
963 |                                           totalRunTimeM,
964 |                                           totalRunTimeS))
965 | 
966 | # And save this info into the .txt file as well
967 | with open(varsFile, 'a+') as file:
968 |     file.write("Simulation started: %s\n" %
969 |                startRunTime.strftime("%Y/%m/%d %H:%M:%S"))
970 |     file.write("Simulation ended: %s\n" %
971 |                endRunTime.strftime("%Y/%m/%d %H:%M:%S"))
972 |     file.write("Total time: %dh %dm %.2fs" % (totalRunTimeH,
973 |                                               totalRunTimeM,
974 |                                               totalRunTimeS))
975 | 
--------------------------------------------------------------------------------
/examples/epidemicGRNN.py:
--------------------------------------------------------------------------------
1 | # 2021/03/04~
2 | # Luana Ruiz, rubruiz@seas.upenn.edu.
3 | # Fernando Gama, fgama@seas.upenn.edu.
4 | 
5 | # Simulate the epidemic tracking problem. In this experiment, we compare GRNNs
6 | # and gated GRNNs in a binary node classification problem modeling the spread of
7 | # an epidemic on a high school friendship network. The epidemic data is generated
8 | # using the SIR model to simulate the spread of an infectious disease on the
9 | # friendship network. The disease is first recorded on day t=0, when each individual
10 | # node is infected with probability p_seed=0.05. On the days that follow, an
11 | # infected student can then spread the disease to their susceptible friends with
12 | # probability p_inf=0.3 each day. Infected students become immune after 4 days,
13 | # at which point they can no longer spread or contract the disease.
14 | # Given the state of each node at some point in time (susceptible, infected or
15 | # recovered), the binary node classification problem is to predict whether each
16 | # node in the network will have the disease (i.e., be infected) seqLen=8 days ahead.
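# For reference, the generative process just described can be sketched in a few
# lines of plain numpy. The sketch below is illustrative only (simulateSIR is a
# hypothetical helper, not part of this script); the actual data generation used
# below lives in alegnn.utils.dataTools.Epidemics. States are coded here as
# 0 = susceptible, 1 = infected, 2 = recovered.
#
#     import numpy as np
#
#     def simulateSIR(Adj, T, pSeed = 0.05, pInf = 0.3, tRec = 4):
#         # Adj: N x N adjacency matrix of the friendship network
#         N = Adj.shape[0]
#         state = (np.random.rand(N) < pSeed).astype(int) # Day t = 0 seeding
#         daysInfected = np.zeros(N) # Days each node has spent infected
#         states = [state.copy()]
#         for t in range(1, T):
#             infected = (state == 1)
#             # Each infected friend transmits independently with prob. pInf
#             nInfNeighbors = Adj @ infected.astype(float)
#             pContagion = 1. - (1. - pInf) ** nInfNeighbors
#             newInfections = (state == 0) & (np.random.rand(N) < pContagion)
#             daysInfected[infected] += 1
#             state[infected & (daysInfected >= tRec)] = 2 # Immune from now on
#             state[newInfections] = 1
#             states.append(state.copy())
#         return np.stack(states) # T x N matrix of node states over time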
17 | 
18 | # Outputs:
19 | # - Text file with all the hyperparameters selected for the run and the
20 | #   corresponding results (hyperparameters.txt)
21 | # - Pickle file with the random seeds of both torch and numpy for accurate
22 | #   reproduction of results (randomSeedUsed.pkl)
23 | # - The parameters of the trained models, for both the Best and the Last
24 | #   instance of each model (savedModels/)
25 | # - The figures of loss and evaluation through the training iterations for
26 | #   each model (figs/ and trainVars/)
27 | # - If selected, logs in tensorboardX certain useful training variables
28 | 
29 | #%%##################################################################
30 | #                                                                   #
31 | #                            IMPORTING                              #
32 | #                                                                   #
33 | #####################################################################
34 | 
35 | #\\\ Standard libraries:
36 | import os
37 | import numpy as np
38 | import matplotlib
39 | matplotlib.rcParams['text.usetex'] = True
40 | matplotlib.rcParams['font.family'] = 'serif'
41 | matplotlib.rcParams['text.latex.preamble']=[r'\usepackage{amsmath}']
42 | import matplotlib.pyplot as plt
43 | import pickle
44 | import datetime
45 | from copy import deepcopy
46 | 
47 | import torch; torch.set_default_dtype(torch.float64)
48 | import torch.nn as nn
49 | import torch.optim as optim
50 | 
51 | #\\\ Own libraries:
52 | import alegnn.utils.graphTools as graphTools
53 | import alegnn.utils.dataTools
54 | import alegnn.modules.architectures as archit
55 | import alegnn.modules.model as model
56 | import alegnn.modules.training as training
57 | import alegnn.modules.evaluation as evaluation
58 | import alegnn.modules.loss as loss
59 | 
60 | #\\\ Separate functions:
61 | from alegnn.utils.miscTools import writeVarValues
62 | from alegnn.utils.miscTools import saveSeed
63 | 
64 | # Start measuring time
65 | startRunTime = datetime.datetime.now()
66 | 
67 | #%%##################################################################
68 | #                                                                   #
69 | #                    SETTING PARAMETERS                             #
70 | #                                                                   #
71 | #####################################################################
72 | 
73 | thisFilename = 'epidemicGRNN' # This is the general name of all related files
74 | 
75 | saveDirRoot = 'experiments' # In this case, relative location
76 | saveDir = os.path.join(saveDirRoot, thisFilename) # Dir where to save all
77 |     # the results from each run
78 | 
79 | #\\\ Create .txt to store the values of the setting parameters for easier
80 | # reference when running multiple experiments
81 | today = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
82 | # Append date and time of the run to the directory, to avoid several runs
83 | # overwriting each other.
84 | saveDir = saveDir + '-' + today
85 | # Create directory
86 | if not os.path.exists(saveDir):
87 |     os.makedirs(saveDir)
88 | # Create the file where all the (hyper)parameters and results will be saved.
89 | varsFile = os.path.join(saveDir,'hyperparameters.txt')
90 | with open(varsFile, 'w+') as file:
91 |     file.write('%s\n\n' % datetime.datetime.now().strftime("%Y/%m/%d %H:%M:%S"))
92 | 
93 | #\\\ Save seeds for reproducibility
94 | # PyTorch seeds
95 | torchState = torch.get_rng_state()
96 | torchSeed = torch.initial_seed()
97 | # Numpy seeds
98 | numpyState = np.random.RandomState().get_state()
99 | # Collect all random states
100 | randomStates = []
101 | randomStates.append({})
102 | randomStates[0]['module'] = 'numpy'
103 | randomStates[0]['state'] = numpyState
104 | randomStates.append({})
105 | randomStates[1]['module'] = 'torch'
106 | randomStates[1]['state'] = torchState
107 | randomStates[1]['seed'] = torchSeed
108 | # This list and dictionary follow the format to then be loaded, if needed,
109 | # by calling the loadSeed function in alegnn.utils.miscTools
110 | saveSeed(randomStates, saveDir)
111 | 
112 | ########
113 | # DATA #
114 | ########
115 | 
116 | useGPU = True # If true, and GPU is available, use it.
117 | 
118 | nTrain = 1000 # Number of training samples
119 | nValid = 120 # Number of validation samples
120 | nTest = 200 # Number of testing samples
121 | seqLen = 8 # Sequence length
122 | seedProb = 0.05
123 | infectionProb = 0.3
124 | recoveryTime = 4
125 | 
126 | nDataRealizations = 10 # Number of data realizations
127 | 
128 | #\\\ Save values:
129 | writeVarValues(varsFile, {'nTrain': nTrain,
130 |                           'nValid': nValid,
131 |                           'nTest': nTest,
132 |                           'seqLen': seqLen,
133 |                           'seedProb': seedProb,
134 |                           'infectionProb': infectionProb,
135 |                           'recoveryTime': recoveryTime,
136 |                           'nDataRealizations':nDataRealizations,
137 |                           'useGPU': useGPU})
138 | 
139 | ############
140 | # TRAINING #
141 | ############
142 | 
143 | #\\\ Individual model training options
144 | optimAlg = 'ADAM' # Options: 'SGD', 'ADAM', 'RMSprop'
145 | learningRate = 0.0005 # In all options
146 | beta1 = 0.9 # beta1 if 'ADAM', alpha if 'RMSprop'
147 | beta2 = 0.999 # ADAM option only
148 | 
149 | #\\\ Loss function choice
150 | lossFunction = loss.F1Score
151 | 
152 | #\\\ Overall training options
153 | nEpochs = 10 # Number of epochs
154 | batchSize = 100 # Batch size
155 | doLearningRateDecay = False # Learning rate decay
156 | learningRateDecayRate = 0.9 # Rate
157 | learningRateDecayPeriod = 1 # How many epochs between lr updates
158 | validationInterval = 5 # How many training steps in between validations
159 | 
160 | #\\\ Save values
161 | writeVarValues(varsFile,
162 |                {'optimAlg': optimAlg,
163 |                 'learningRate': learningRate,
164 |                 'beta1': beta1,
165 |                 'lossFunction': lossFunction,
166 |                 'nEpochs': nEpochs,
167 |                 'batchSize': batchSize,
168 |                 'doLearningRateDecay': doLearningRateDecay,
169 |                 'learningRateDecayRate': learningRateDecayRate,
170 |                 'learningRateDecayPeriod': learningRateDecayPeriod,
171 |                 'validationInterval': validationInterval})
172 | 
173 | #################
174 | # ARCHITECTURES #
175 | #################
176 | 
177 | # Select desired architectures
178 | doGRNN = True
179 | doTimeGatedGRNN = True
180 | doNodeGatedGRNN = True
181 | doEdgeGatedGRNN = True
182 | 
183 | # In this section, we determine the (hyper)parameters of models that we are
184 | # going to train. This only sets the parameters. The architectures need to be
185 | # created later below. Do not forget to add the name of the architecture
186 | # to modelList.
187 | 
188 | # If the model dictionary is called 'model' + name, then it can be
189 | # picked up immediately later on, and there's no need to recode anything after
190 | # the section 'Setup' (except for setting the number of nodes in the 'N'
191 | # variable after it has been coded).
192 | 
193 | # The names of the keys in the model dictionary have to be the same
194 | # as the names of the variables in the architecture call, because they will
195 | # be called by unpacking the dictionary.
196 | 
197 | modelList = []
198 | 
199 | #\\\\\\\\\\\\
200 | #\\\ MODEL 1: GRNN
201 | #\\\\\\\\\\\\
202 | 
203 | #\\\ Basic parameters shared by all the GRNN architectures
204 | 
205 | modelGRNN = {}
206 | modelGRNN['name'] = 'GRNN' # To be modified later on depending on the
207 |     # specific data realization trained
208 | modelGRNN['device'] = 'cuda:0' if (useGPU and torch.cuda.is_available()) \
209 |                                  else 'cpu'
210 | 
211 | #\\\ ARCHITECTURE
212 | 
213 | # Select architectural nn.Module to use
214 | modelGRNN['archit'] = archit.GraphRecurrentNN
215 | # Graph convolutional layers
216 | modelGRNN['dimInputSignals'] = 1 # Number of features of x
217 | modelGRNN['dimOutputSignals'] = 2 # Number of features of y
218 | modelGRNN['dimHiddenSignals'] = 12 # Number of features of z
219 | modelGRNN['nFilterTaps'] = [5,5] # Number of filter taps
220 | modelGRNN['bias'] = True # Include bias
221 | # Nonlinearity
222 | modelGRNN['nonlinearityHidden'] = nn.Tanh()
223 | modelGRNN['nonlinearityOutput'] = nn.ReLU()
224 | modelGRNN['nonlinearityReadout'] = nn.ReLU()
225 | # Readout layer
226 | modelGRNN['dimReadout'] = []
227 | # Graph Structure
228 | modelGRNN['GSO'] = None # To be determined later on, based on data
229 | 
230 | #\\\ TRAINER
231 | 
232 | modelGRNN['trainer'] = training.Trainer
233 | 
234 | #\\\ EVALUATOR
235 | 
236 | modelGRNN['evaluator'] = evaluation.evaluate
237 | 
238 | if doGRNN:
239 | 
240 |     #\\\ Save Values:
241 |     writeVarValues(varsFile, modelGRNN)
242 |     modelList += [modelGRNN['name']]
243 | 
244 | #\\\\\\\\\\\\
245 | #\\\ MODEL 2: Time-gated GRNN
246 | #\\\\\\\\\\\\
247 | 
248 | if doTimeGatedGRNN:
249 | 
250 |     modelTimeGatedGRNN = deepcopy(modelGRNN)
251 | 
252 |     modelTimeGatedGRNN['name'] = 'TimeGatedGRNN'
253 |     modelTimeGatedGRNN['archit'] = archit.GatedGraphRecurrentNN
254 |     modelTimeGatedGRNN['gateType'] = 'time'
255 | 
256 |     #\\\ Save Values:
257 |     writeVarValues(varsFile, modelTimeGatedGRNN)
258 |     modelList += [modelTimeGatedGRNN['name']]
259 | 
260 | #\\\\\\\\\\\\
261 | #\\\ MODEL 3: Node-gated GRNN
262 | #\\\\\\\\\\\\
263 | 
264 | if doNodeGatedGRNN:
265 | 
266 |     modelNodeGatedGRNN = deepcopy(modelGRNN)
267 | 
268 |     modelNodeGatedGRNN['name'] = 'NodeGatedGRNN'
269 |     modelNodeGatedGRNN['archit'] = archit.GatedGraphRecurrentNN
270 |     modelNodeGatedGRNN['gateType'] = 'node'
271 | 
272 |     #\\\ Save Values:
273 |     writeVarValues(varsFile, modelNodeGatedGRNN)
274 |     modelList += [modelNodeGatedGRNN['name']]
275 | 
276 | #\\\\\\\\\\\\
277 | #\\\ MODEL 4: Edge-gated GRNN
278 | #\\\\\\\\\\\\
279 | 
280 | if doEdgeGatedGRNN:
281 | 
282 |     modelEdgeGatedGRNN = deepcopy(modelGRNN)
283 | 
284 |     modelEdgeGatedGRNN['name'] = 'EdgeGatedGRNN'
285 |     modelEdgeGatedGRNN['archit'] = archit.GatedGraphRecurrentNN
286 |     modelEdgeGatedGRNN['gateType'] = 'edge'
287 | 
288 |     #\\\ Save Values:
289 |     writeVarValues(varsFile, modelEdgeGatedGRNN)
290 |     modelList += [modelEdgeGatedGRNN['name']]
291 | 
292 | ###########
293 | # LOGGING #
294 | ###########
295 | 
296 | # Options:
297 | doPrint = True # Decide whether to print stuff while running
298 | doLogging = False # Log into tensorboard
299 | doSaveVars = True # Save (pickle) useful variables
300 | doFigs = True # Plot some figures (this only works if doSaveVars is True)
301 | # Parameters:
302 | printInterval = 10 # After how many training steps, print the partial results
303 | xAxisMultiplierTrain = 100 # How many training steps in between those shown in
304 |     # the plot, i.e., one training step every xAxisMultiplierTrain is shown.
305 | xAxisMultiplierValid = 10 # How many validation steps in between those shown,
306 |     # same as above.
307 | figSize = 5 # Overall size of the figure that contains the plot
308 | lineWidth = 2 # Width of the plot lines
309 | markerShape = 'o' # Shape of the markers
310 | markerSize = 3 # Size of the markers
311 | 
312 | #\\\ Save values:
313 | writeVarValues(varsFile,
314 |                {'doPrint': doPrint,
315 |                 'doLogging': doLogging,
316 |                 'doSaveVars': doSaveVars,
317 |                 'doFigs': doFigs,
318 |                 'saveDir': saveDir,
319 |                 'printInterval': printInterval,
320 |                 'figSize': figSize,
321 |                 'lineWidth': lineWidth,
322 |                 'markerShape': markerShape,
323 |                 'markerSize': markerSize})
324 | 
325 | #%%##################################################################
326 | #                                                                   #
327 | #                             SETUP                                 #
328 | #                                                                   #
329 | #####################################################################
330 | 
331 | #\\\ Determine processing unit:
332 | if useGPU and torch.cuda.is_available():
333 |     torch.cuda.empty_cache()
334 | 
335 | #\\\ Notify of processing units
336 | if doPrint:
337 |     print("Selected devices:")
338 |     for thisModel in modelList:
339 |         modelDict = eval('model' + thisModel)
340 |         print("\t%s: %s" % (thisModel, modelDict['device']))
341 | 
342 | #\\\ Logging options
343 | if doLogging:
344 |     from alegnn.utils.visualTools import Visualizer
345 |     logsTB = os.path.join(saveDir, 'logsTB')
346 |     logger = Visualizer(logsTB, name='visualResults')
347 | 
348 | #\\\ Save variables during evaluation.
349 | # We will save all the evaluations obtained for each of the trained models.
350 | # It is basically a dictionary of lists: the key of the dictionary determines
351 | # the model, and the list index determines the data realization. This will
352 | # later be converted to numpy to compute the mean and standard deviation
353 | # (across the realization dimension).
354 | costBest = {} # Cost for the best model (Evaluation cost: Error rate)
355 | costLast = {} # Cost for the last model
356 | for thisModel in modelList: # Create an element for each data realization
357 |     costBest[thisModel] = []
358 |     costLast[thisModel] = []
359 | 
360 | if doFigs:
361 |     #\\\ SAVE SPACE:
362 |     # Create the variables to save all the realizations. This is, again, a
363 |     # dictionary, where each key represents a model, and each model is a list
364 |     # with one entry per data realization.
365 |     # Each entry, in this case, is not a scalar, but a vector of
366 |     # length the number of training steps (or of validation steps)
367 |     lossTrain = {}
368 |     costTrain = {}
369 |     lossValid = {}
370 |     costValid = {}
371 |     # Initialize the realizations dimension
372 |     for thisModel in modelList:
373 |         lossTrain[thisModel] = []
374 |         costTrain[thisModel] = []
375 |         lossValid[thisModel] = []
376 |         costValid[thisModel] = []
377 | 
378 | 
379 | ####################
380 | # TRAINING OPTIONS #
381 | ####################
382 | 
383 | # Training phase. It has a lot of options that are input through a
384 | # dictionary of arguments.
385 | # The value of these options was decided above with the rest of the parameters.
386 | # This just creates a dictionary necessary to pass to the train function.
387 | 
388 | trainingOptions = {}
389 | 
390 | if doLogging:
391 |     trainingOptions['logger'] = logger
392 | if doSaveVars:
393 |     trainingOptions['saveDir'] = saveDir
394 | if doPrint:
395 |     trainingOptions['printInterval'] = printInterval
396 | if doLearningRateDecay:
397 |     trainingOptions['learningRateDecayRate'] = learningRateDecayRate
398 |     trainingOptions['learningRateDecayPeriod'] = learningRateDecayPeriod
399 | trainingOptions['validationInterval'] = validationInterval
400 | 
401 | # And in case each model has specific training options, we create a
402 | # separate dictionary per model.
403 | 
404 | trainingOptsPerModel = {}
405 | 
406 | #%%##################################################################
407 | #                                                                   #
408 | #                         DATA HANDLING                             #
409 | #                                                                   #
410 | #####################################################################
411 | 
412 | #########
413 | # GRAPH #
414 | #########
415 | 
416 | # Create graph
417 | Adj = alegnn.utils.dataTools.Epidemics.createGraph()
418 | nNodes = Adj.shape[0]
419 | graphOptions = {}
420 | graphOptions['adjacencyMatrix'] = Adj
421 | G = graphTools.Graph('adjacency', nNodes, graphOptions)
422 | G.computeGFT() # Compute the eigendecomposition of the stored GSO
423 | 
424 | for realization in range(nDataRealizations):
425 | 
426 |     ############
427 |     # DATASETS #
428 |     ############
429 | 
430 |     data = alegnn.utils.dataTools.Epidemics(seqLen, seedProb, infectionProb,
431 |                                             recoveryTime, nTrain, nValid,
432 |                                             nTest)
433 |     data.astype(torch.float64)
434 |     #data.to(device)
435 |     data.expandDims() # Data are just graph processes, but the architectures
436 |         # require that the input signals are of the form B x T x F x N, so we
437 |         # need to expand the middle dimensions to convert them from B x T x N
438 |         # to B x T x 1 x N
439 | 
440 |     #%%##################################################################
441 |     #                                                                   #
442 |     #                    MODELS INITIALIZATION                          #
443 |     #                                                                   #
444 |     #####################################################################
445 | 
446 |     # This is the dictionary where we store the models (in a model.Model
447 |     # class, that is then passed to training).
448 |     modelsGRNN = {}
449 | 
450 |     # If a new model is to be created, it should be called for here.
451 | 
452 |     if doPrint:
453 |         print("Model initialization...", flush = True)
454 | 
455 |     for thisModel in modelList:
456 | 
457 |         # Get the corresponding parameter dictionary
458 |         modelDict = deepcopy(eval('model' + thisModel))
459 |         # and training options
460 |         trainingOptsPerModel[thisModel] = deepcopy(trainingOptions)
461 | 
462 |         # Now, this dictionary has all the hyperparameters that we need to
463 |         # pass to the architecture function, but it also has other keys
464 |         # that belong to the more general model (like 'name' or 'device'),
465 |         # so we need to extract them and save them in separate variables
466 |         # for future use.
467 |         thisName = modelDict.pop('name')
468 |         callArchit = modelDict.pop('archit')
469 |         thisDevice = modelDict.pop('device')
470 |         thisTrainer = modelDict.pop('trainer')
471 |         thisEvaluator = modelDict.pop('evaluator')
472 | 
473 |         # If more than one data realization is going to be
474 |         # carried out, we are going to store all of those models
475 |         # separately, so that any of them can be brought back and
476 |         # studied in detail.
477 |         if nDataRealizations > 1:
478 |             thisName += 'R%02d' % realization
479 | 
480 |         if doPrint:
481 |             print("\tInitializing %s..." % thisName,
482 |                   end = ' ',flush = True)
483 | 
484 |         ##############
485 |         # PARAMETERS #
486 |         ##############
487 | 
488 |         #\\\ Optimizer options
489 |         #    (If different from the default ones, change here.)
490 |         thisOptimAlg = optimAlg
491 |         thisLearningRate = learningRate
492 |         thisBeta1 = beta1
493 |         thisBeta2 = beta2
494 | 
495 |         #\\\ GSO
496 |         # Normalize adjacency
497 |         S = G.S.copy()/np.max(np.real(G.E))
498 | 
499 |         modelDict['GSO'] = S
500 | 
501 |         ################
502 |         # ARCHITECTURE #
503 |         ################
504 | 
505 |         thisArchit = callArchit(**modelDict)
506 | 
507 |         #############
508 |         # OPTIMIZER #
509 |         #############
510 | 
511 |         if thisOptimAlg == 'ADAM':
512 |             thisOptim = optim.Adam(thisArchit.parameters(),
513 |                                    lr = thisLearningRate,
514 |                                    betas = (thisBeta1, thisBeta2))
515 |         elif thisOptimAlg == 'SGD':
516 |             thisOptim = optim.SGD(thisArchit.parameters(),
517 |                                   lr = thisLearningRate)
518 |         elif thisOptimAlg == 'RMSprop':
519 |             thisOptim = optim.RMSprop(thisArchit.parameters(),
520 |                                       lr = thisLearningRate, alpha = thisBeta1)
521 | 
522 |         ########
523 |         # LOSS #
524 |         ########
525 | 
526 |         # Initialize the loss function
527 |         thisLossFunction = lossFunction
528 | 
529 |         #########
530 |         # MODEL #
531 |         #########
532 | 
533 |         # Create the model
534 |         modelCreated = model.Model(thisArchit,
535 |                                    thisLossFunction,
536 |                                    thisOptim,
537 |                                    thisTrainer,
538 |                                    thisEvaluator,
539 |                                    thisDevice,
540 |                                    thisName,
541 |                                    saveDir)
542 | 
543 |         # Store it
544 |         modelsGRNN[thisName] = modelCreated
545 | 
546 |         # Write the main hyperparameters
547 |         writeVarValues(varsFile,
548 |                        {'name': thisName,
549 |                         'thisOptimizationAlgorithm': thisOptimAlg,
550 |                         'thisTrainer': thisTrainer,
551 |                         'thisEvaluator': thisEvaluator,
552 |                         'thisLearningRate': thisLearningRate,
553 |                         'thisBeta1': thisBeta1,
554 |                         'thisBeta2': thisBeta2})
555 | 
556 |         if doPrint:
557 |             print("OK")
558 | 
559 |     if doPrint:
560 |         print("Model initialization... COMPLETE")
561 | 
562 |     #%%##################################################################
563 |     #                                                                   #
564 |     #                           TRAINING                                #
565 |     #                                                                   #
566 |     #####################################################################
567 | 
568 |     print("")
569 | 
570 |     # We train each model separately
571 | 
572 |     for thisModel in modelsGRNN.keys():
573 | 
574 |         if doPrint:
575 |             print("Training model %s..." % thisModel)
576 | 
577 |         # Remember that modelsGRNN.keys() has the split numbering as well as
578 |         # the name, while modelList has only the name. So we need to map
579 |         # the specific model for this specific split with the actual model
580 |         # name, since there are several variables that are indexed by the
581 |         # model name (for instance, the training options, or the
582 |         # dictionaries saving the loss values)
583 |         for m in modelList:
584 |             if m in thisModel:
585 |                 modelName = m
586 | 
587 |         # Identify the specific data realization at training time
588 |         if nDataRealizations > 1:
589 |             trainingOptions['realizationNo'] = realization
590 | 
591 |         # Train the model
592 |         thisTrainVars = modelsGRNN[thisModel].train(data,
593 |                                                     nEpochs,
594 |                                                     batchSize,
595 |                                                     **trainingOptsPerModel[modelName])
596 | 
597 |         if doFigs:
598 |             # Find which model to save the results under (when having
599 |             # multiple realizations)
600 |             lossTrain[modelName].append(thisTrainVars['lossTrain'])
601 |             costTrain[modelName].append(thisTrainVars['costTrain'])
602 |             lossValid[modelName].append(thisTrainVars['lossValid'])
603 |             costValid[modelName].append(thisTrainVars['costValid'])
604 | 
605 |     # And we also need to save 'nBatches', but it is the same for all models, so
606 |     if doFigs:
607 |         nBatches = thisTrainVars['nBatches']
608 | 
609 |     #%%##################################################################
610 |     #                                                                   #
611 |     #                          EVALUATION                               #
612 |     #                                                                   #
613 |     #####################################################################
614 | 
615 |     # Now that the models have been trained, we evaluate them on the test
616 |     # samples.
617 | 
618 |     # We have two versions of each model to evaluate: the one obtained
619 |     # at the best result of the validation step, and the last trained model.
620 | 
621 |     if doPrint:
622 |         print("\nTotal testing error rate", end = '', flush = True)
623 |         if nDataRealizations > 1:
624 |             print(" (", end = '', flush = True)
625 |             print("Realization %02d" % realization, end = '', flush = True)
626 |             print(")", end = '', flush = True)
627 |         print(":", flush = True)
628 | 
629 | 
630 |     for thisModel in modelsGRNN.keys():
631 | 
632 |         # Same as before, separate the model name from the data
633 |         # realization number
634 |         for m in modelList:
635 |             if m in thisModel:
636 |                 modelName = m
637 | 
638 |         # Evaluate the model
639 |         thisEvalVars = modelsGRNN[thisModel].evaluate(data)
640 | 
641 |         # Save the outputs
642 |         thisCostBest = thisEvalVars['costBest']
643 |         thisCostLast = thisEvalVars['costLast']
644 | 
645 |         # Write values
646 |         writeVarValues(varsFile,
647 |                        {'costBest%s' % thisModel: thisCostBest,
648 |                         'costLast%s' % thisModel: thisCostLast})
649 | 
650 |         # Now check which is the model being trained
651 |         costBest[modelName].append(thisCostBest)
652 |         costLast[modelName].append(thisCostLast)
653 |         # This is so that we can later compute a total accuracy with
654 |         # the corresponding error.
655 | 
656 |         if doPrint:
657 |             print("\t%s: %1.4f [Best] %1.4f [Last]" % (thisModel,
658 |                                                        thisCostBest,
659 |                                                        thisCostLast))
660 | 
661 | ############################
662 | # FINAL EVALUATION RESULTS #
663 | ############################
664 | 
665 | # Now that we have computed the accuracy of all runs, we can obtain a final
666 | # result (mean and standard deviation)
667 | 
668 | meanCostBestPerGraph = {} # Compute the mean accuracy (best) across all
669 |     # data realizations of a given graph
670 | meanCostLastPerGraph = {} # Compute the mean accuracy (last) across all
671 |     # data realizations of a given graph
672 | meanCostBest = {} # Mean across data realizations (the graph is fixed
673 |     # in this experiment)
674 | meanCostLast = {} # Mean across data realizations
675 | stdDevCostBest = {} # Standard deviation across data realizations
676 | stdDevCostLast = {} # Standard deviation across data realizations
677 | 
678 | if doPrint:
679 |     print("\nFinal evaluations (%02d realizations)" % (nDataRealizations))
680 | 
681 | for thisModel in modelList:
682 |     # Convert the lists into an nDataRealizations vector
683 |     costBest[thisModel] = np.array(costBest[thisModel])
684 |     costLast[thisModel] = np.array(costLast[thisModel])
685 | 
686 |     if nDataRealizations == 1:
687 |         meanCostBest[thisModel] = np.squeeze(costBest[thisModel])
688 |         meanCostLast[thisModel] = np.squeeze(costLast[thisModel])
689 |     else:
690 |         meanCostBest[thisModel] = np.mean(costBest[thisModel])
691 |         meanCostLast[thisModel] = np.mean(costLast[thisModel])
692 |     stdDevCostBest[thisModel] = np.std(costBest[thisModel])
693 |     stdDevCostLast[thisModel] = np.std(costLast[thisModel])
694 | 
695 |     # And print it:
696 |     if doPrint:
697 |         print("\t%s: %1.4f (+-%1.4f) [Best] %1.4f (+-%1.4f) [Last]" % (
698 |                 thisModel,
699 |                 meanCostBest[thisModel],
700 |                 stdDevCostBest[thisModel],
701 |                 meanCostLast[thisModel],
702 |                 stdDevCostLast[thisModel]))
703 | 
704 |     # Save values
705 |     writeVarValues(varsFile,
706 |                {'meanCostBest%s' % thisModel: meanCostBest[thisModel],
707 |                 'stdDevCostBest%s' % thisModel: stdDevCostBest[thisModel],
708 |                 'meanCostLast%s' % thisModel: meanCostLast[thisModel],
709 |                 'stdDevCostLast%s' % thisModel : stdDevCostLast[thisModel]})
710 | 
711 | with open(varsFile, 'a+') as file:
712 |     file.write("Final evaluations (%02d realizations)\n" % (nDataRealizations))
713 |     for thisModel in modelList:
714 |         file.write("\t%s: %1.4f (+-%1.4f) [Best] %1.4f (+-%1.4f) [Last]\n" % (
715 |                    thisModel,
716 |                    meanCostBest[thisModel],
717 |                    stdDevCostBest[thisModel],
718 |                    meanCostLast[thisModel],
719 |                    stdDevCostLast[thisModel]))
720 |     file.write('\n')
721 | 
722 | 
723 | 
724 | #%%##################################################################
725 | #                                                                   #
726 | #                             PLOT                                  #
727 | #                                                                   #
728 | #####################################################################
729 | 
730 | # Finally, we might want to plot several quantities of interest
731 | 
732 | if doFigs and doSaveVars:
733 | 
734 |     ###################
735 |     # DATA PROCESSING #
736 |     ###################
737 | 
738 |     #\\\ FIGURES DIRECTORY:
739 |     saveDirFigs = os.path.join(saveDir,'figs')
740 |     # If it doesn't exist, create it.
741 |     if not os.path.exists(saveDirFigs):
742 |         os.makedirs(saveDirFigs)
743 | 
744 |     #\\\ COMPUTE STATISTICS:
745 |     # The first thing to do is to transform those into a matrix with all the
746 |     # realizations, so create the variables to save that.
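    # (A note on shapes, for reference: lossTrain[thisModel] below is a list
    # of nDataRealizations training curves, so np.array(...) stacks it into an
    # nDataRealizations x numberOfTrainingSteps matrix, and axis = 0 takes the
    # statistics across realizations.)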
747 |     meanLossTrain = {}
748 |     meanCostTrain = {}
749 |     meanLossValid = {}
750 |     meanCostValid = {}
751 |     stdDevLossTrain = {}
752 |     stdDevCostTrain = {}
753 |     stdDevLossValid = {}
754 |     stdDevCostValid = {}
755 |     # Initialize the variables
756 |     for thisModel in modelList:
757 |         # And compute the statistics
758 |         meanLossTrain[thisModel] = \
759 |                           np.mean(np.array(lossTrain[thisModel]), axis = 0)
760 |         meanCostTrain[thisModel] = \
761 |                           np.mean(np.array(costTrain[thisModel]), axis = 0)
762 |         meanLossValid[thisModel] = \
763 |                           np.mean(np.array(lossValid[thisModel]), axis = 0)
764 |         meanCostValid[thisModel] = \
765 |                           np.mean(np.array(costValid[thisModel]), axis = 0)
766 |         stdDevLossTrain[thisModel] = \
767 |                           np.std(np.array(lossTrain[thisModel]), axis = 0)
768 |         stdDevCostTrain[thisModel] = \
769 |                           np.std(np.array(costTrain[thisModel]), axis = 0)
770 |         stdDevLossValid[thisModel] = \
771 |                           np.std(np.array(lossValid[thisModel]), axis = 0)
772 |         stdDevCostValid[thisModel] = \
773 |                           np.std(np.array(costValid[thisModel]), axis = 0)
774 | 
775 |     ####################
776 |     # SAVE FIGURE DATA #
777 |     ####################
778 | 
779 |     # And finally, we can plot. But before, let's save the variables mean and
780 |     # stdDev so, if we don't like the plot, we can re-open them, and re-plot
781 |     # them, a piacere.
782 |     varsPickle = {}
783 |     varsPickle['nEpochs'] = nEpochs
784 |     varsPickle['nBatches'] = nBatches
785 |     varsPickle['meanLossTrain'] = meanLossTrain
786 |     varsPickle['stdDevLossTrain'] = stdDevLossTrain
787 |     varsPickle['meanCostTrain'] = meanCostTrain
788 |     varsPickle['stdDevCostTrain'] = stdDevCostTrain
789 |     varsPickle['meanLossValid'] = meanLossValid
790 |     varsPickle['stdDevLossValid'] = stdDevLossValid
791 |     varsPickle['meanCostValid'] = meanCostValid
792 |     varsPickle['stdDevCostValid'] = stdDevCostValid
793 |     with open(os.path.join(saveDirFigs,'figVars.pkl'), 'wb') as figVarsFile:
794 |         pickle.dump(varsPickle, figVarsFile)
795 | 
796 |     ########
797 |     # PLOT #
798 |     ########
799 | 
800 |     # Compute the x-axis
801 |     xTrain = np.arange(0, nEpochs * nBatches, xAxisMultiplierTrain)
802 |     xValid = np.arange(0, nEpochs * nBatches, \
803 |                        validationInterval*xAxisMultiplierValid)
804 | 
805 |     # If we do not want to plot all the elements (to avoid overcrowded plots)
806 |     # we need to recompute the x axis and take those elements corresponding
807 |     # to the training steps we want to plot
808 |     if xAxisMultiplierTrain > 1:
809 |         # Actual selected samples
810 |         selectSamplesTrain = xTrain
811 |         # Go and fetch them
812 |         for thisModel in modelList:
813 |             meanLossTrain[thisModel] = meanLossTrain[thisModel]\
814 |                                                     [selectSamplesTrain]
815 |             stdDevLossTrain[thisModel] = stdDevLossTrain[thisModel]\
816 |                                                     [selectSamplesTrain]
817 |             meanCostTrain[thisModel] = meanCostTrain[thisModel]\
818 |                                                     [selectSamplesTrain]
819 |             stdDevCostTrain[thisModel] = stdDevCostTrain[thisModel]\
820 |                                                     [selectSamplesTrain]
821 |     # And same for the validation, if necessary.
822 |     if xAxisMultiplierValid > 1:
823 |         selectSamplesValid = np.arange(0, len(meanLossValid[thisModel]), \
824 |                                        xAxisMultiplierValid)
825 |         for thisModel in modelList:
826 |             meanLossValid[thisModel] = meanLossValid[thisModel]\
827 |                                                     [selectSamplesValid]
828 |             stdDevLossValid[thisModel] = stdDevLossValid[thisModel]\
829 |                                                     [selectSamplesValid]
830 |             meanCostValid[thisModel] = meanCostValid[thisModel]\
831 |                                                     [selectSamplesValid]
832 |             stdDevCostValid[thisModel] = stdDevCostValid[thisModel]\
833 |                                                     [selectSamplesValid]
834 | 
835 |     #\\\ LOSS (Training and validation) for EACH MODEL
836 |     for key in meanLossTrain.keys():
837 |         lossFig = plt.figure(figsize=(1.61*figSize, 1*figSize))
838 |         plt.errorbar(xTrain, meanLossTrain[key], yerr = stdDevLossTrain[key],
839 |                      color = '#01256E', linewidth = lineWidth,
840 |                      marker = markerShape, markersize = markerSize)
841 |         plt.errorbar(xValid, meanLossValid[key], yerr = stdDevLossValid[key],
842 |                      color = '#95001A', linewidth = lineWidth,
843 |                      marker = markerShape, markersize = markerSize)
844 |         plt.ylabel(r'Loss')
845 |         plt.xlabel(r'Training steps')
846 |         plt.legend([r'Training', r'Validation'])
847 |         plt.title(r'%s' % key)
848 |         lossFig.savefig(os.path.join(saveDirFigs,'loss%s.pdf' % key),
849 |                         bbox_inches = 'tight')
850 | 
851 |     #\\\ ERROR RATE (Training and validation) for EACH MODEL
852 |     for key in meanCostTrain.keys():
853 |         costFig = plt.figure(figsize=(1.61*figSize, 1*figSize))
854 |         plt.errorbar(xTrain, meanCostTrain[key], yerr = stdDevCostTrain[key],
855 |                      color = '#01256E', linewidth = lineWidth,
856 |                      marker = markerShape, markersize = markerSize)
857 |         plt.errorbar(xValid, meanCostValid[key], yerr = stdDevCostValid[key],
858 |                      color = '#95001A', linewidth = lineWidth,
859 |                      marker = markerShape, markersize = markerSize)
860 |         plt.ylabel(r'Error rate')
861 |         plt.xlabel(r'Training steps')
862 |         plt.legend([r'Training', r'Validation'])
863 |         plt.title(r'%s' % key)
864 |         costFig.savefig(os.path.join(saveDirFigs,'cost%s.pdf' % key),
865 |                         bbox_inches = 'tight')
866 | 
867 |     # LOSS (training) for ALL MODELS
868 |     allLossTrain = plt.figure(figsize=(1.61*figSize, 1*figSize))
869 |     for key in meanLossTrain.keys():
870 |         plt.errorbar(xTrain, meanLossTrain[key], yerr = stdDevLossTrain[key],
871 |                      linewidth = lineWidth,
872 |                      marker = markerShape, markersize = markerSize)
873 |     plt.ylabel(r'Loss')
874 |     plt.xlabel(r'Training steps')
875 |     plt.legend(list(meanLossTrain.keys()))
876 |     allLossTrain.savefig(os.path.join(saveDirFigs,'allLossTrain.pdf'),
877 |                          bbox_inches = 'tight')
878 | 
879 |     # ERROR RATE (validation) for ALL MODELS
880 |     allCostValidFig = plt.figure(figsize=(1.61*figSize, 1*figSize))
881 |     for key in meanCostValid.keys():
882 |         plt.errorbar(xValid, meanCostValid[key], yerr = stdDevCostValid[key],
883 |                      linewidth = lineWidth,
884 |                      marker = markerShape, markersize = markerSize)
885 |     plt.ylabel(r'Error rate')
886 |     plt.xlabel(r'Training steps')
887 |     plt.legend(list(meanCostValid.keys()))
888 |     allCostValidFig.savefig(os.path.join(saveDirFigs,'allCostValid.pdf'),
889 |                             bbox_inches = 'tight')
890 | 
891 | # Finish measuring time
892 | endRunTime = datetime.datetime.now()
893 | 
894 | totalRunTime = abs(endRunTime - startRunTime)
895 | totalRunTimeH = int(divmod(totalRunTime.total_seconds(), 3600)[0])
896 | totalRunTimeM, totalRunTimeS = \
897 |     divmod(totalRunTime.total_seconds() - totalRunTimeH * 3600., 60)
898 | totalRunTimeM = int(totalRunTimeM)
899 | 
900 | if doPrint:
901 |     print(" ")
902 |     print("Simulation started: %s"
          %startRunTime.strftime("%Y/%m/%d %H:%M:%S"))
903 |     print("Simulation ended: %s" % endRunTime.strftime("%Y/%m/%d %H:%M:%S"))
904 |     print("Total time: %dh %dm %.2fs" % (totalRunTimeH,
905 |                                           totalRunTimeM,
906 |                                           totalRunTimeS))
907 | 
908 | # And save this info into the .txt file as well
909 | with open(varsFile, 'a+') as file:
910 |     file.write("Simulation started: %s\n" %
911 |                startRunTime.strftime("%Y/%m/%d %H:%M:%S"))
912 |     file.write("Simulation ended: %s\n" %
913 |                endRunTime.strftime("%Y/%m/%d %H:%M:%S"))
914 |     file.write("Total time: %dh %dm %.2fs" % (totalRunTimeH,
915 |                                               totalRunTimeM,
916 |                                               totalRunTimeS))
917 | 
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "alegnn"
3 | version = "0.4.0"
4 | description = ""
5 | authors = ["Damian Owekro "]
6 | license = "GPL-3.0-or-later"
7 | 
8 | [tool.poetry.dependencies]
9 | python = ">=3.7,<3.10"
10 | scikit-learn = "^0.24.2"
11 | torch = "^1.8.1"
12 | matplotlib = "^3.4.2"
13 | scipy = ">=1.5.4"
14 | hdf5storage = "^0.1.18"
15 | gensim = "^4.0.1"
16 | python-Levenshtein = "^0.12.2"
17 | 
18 | [tool.poetry.dev-dependencies]
19 | 
20 | [build-system]
21 | requires = ["poetry-core>=1.0.0"]
22 | build-backend = "poetry.core.masonry.api"
--------------------------------------------------------------------------------
/selGNN.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alelab-upenn/graph-neural-networks/a84a39fabad5378bdcbaad20b5dbcff14b4eebcd/selGNN.png
--------------------------------------------------------------------------------