├── README.md
├── netgan
│   ├── __pycache__
│   │   ├── models.cpython-310.pyc
│   │   ├── training.cpython-310.pyc
│   │   └── utils.cpython-310.pyc
│   ├── cora.pt
│   ├── cora_ml.npz
│   ├── demo_full_example.ipynb
│   ├── demo_pytorch.ipynb
│   ├── inf-USAir97.mtx
│   ├── models.py
│   ├── training.py
│   └── utils.py
└── netgan_modified
    ├── .ipynb_checkpoints
    │   └── demo_pytorch-checkpoint.ipynb
    ├── 17000_model.pt
    ├── 20000_model.pt
    ├── branch.csv
    ├── bus.csv
    ├── demo_pytorch.ipynb
    ├── models.py
    ├── training.py
    └── utils.py

/README.md:
--------------------------------------------------------------------------------
1 | # NetGAN: Generating Graphs via Random Walks
2 | **PyTorch implementation of the method proposed in the paper:**
3 | [NetGAN: Generating Graphs via Random Walks](https://arxiv.org/abs/1803.00816)
4 | **based on the TensorFlow implementation:**
5 | https://github.com/danielzuegner/netgan
6 | There are two folders, "netgan" and "netgan_modified". The first folder is the standard NetGAN implementation. It contains three Python files, *training.py*, *models.py* and *utils.py*, plus the notebook *demo_pytorch.ipynb*. *training.py* is the main file: with it you can train on a graph and generate synthetic graphs afterwards. *models.py* contains the generator and the discriminator, *utils.py* provides useful helper functions, and *demo_pytorch.ipynb* is a demo in which training is done on the Cora dataset.
7 | For better understanding, the architectures of the models are shown as images below, together with their respective hyperparameters.
8 |
9 | The folder "netgan_modified" is a modified version of NetGAN. The generator has been changed as the bottom picture shows. With the additional LSTM it is possible to generate graphs with an additional feature. *demo_pytorch.ipynb* is an example in which synthetic graphs are created from an electrical grid; both the structure and the conductor lengths are generated. *branch.csv* and *bus.csv* contain information from different electrical grids. They were created from https://electricgrids.engr.tamu.edu/electric-grid-test-cases/
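A minimal usage sketch for the standard "netgan" folder (the `.npz` key names and the training call are assumptions, not verified against *utils.py*; see *demo_pytorch.ipynb* for the exact calls):

```python
# Run from inside the netgan/ folder.
import numpy as np
import scipy.sparse as sp
from training import Trainer

# Load the Cora-ML adjacency matrix. The key names below are an assumption
# about how cora_ml.npz is laid out; utils.py may ship its own loader.
data = np.load('cora_ml.npz', allow_pickle=True)
A = sp.csr_matrix((data['adj_data'], data['adj_indices'], data['adj_indptr']),
                  shape=data['adj_shape'])

trainer = Trainer(A, N=A.shape[0], max_iterations=20000, rw_len=16, batch_size=128)
# trainer.train(...)  # hypothetical call; the actual training loop is in demo_pytorch.ipynb
walks = trainer.generator.sample_discrete(512, trainer.device)  # (512, rw_len) node indices
```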
10 | # How GANs work:
11 | ![GAN](https://user-images.githubusercontent.com/17961647/81090125-8eb02500-8efd-11ea-8df5-34ec4ad643f7.png)
12 | # Generator model:
13 | ![Generator](https://user-images.githubusercontent.com/17961647/81085459-88b74580-8ef7-11ea-9614-368f8543a1f2.png)
14 | # Discriminator model:
15 | ![Discriminator](https://user-images.githubusercontent.com/17961647/81088760-bb633d00-8efb-11ea-8301-bd4887d91b91.png)
16 | # Generator model expanded with conductor length:
17 | ![Generator_expanded](https://user-images.githubusercontent.com/17961647/83631876-869be180-a59e-11ea-83d8-2ba005c09983.png)
--------------------------------------------------------------------------------
/netgan/__pycache__/models.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mmiller96/netgan_pytorch/4511f7de6fb87000435c1fd498d720391f7ccdc5/netgan/__pycache__/models.cpython-310.pyc
--------------------------------------------------------------------------------
/netgan/__pycache__/training.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mmiller96/netgan_pytorch/4511f7de6fb87000435c1fd498d720391f7ccdc5/netgan/__pycache__/training.cpython-310.pyc
--------------------------------------------------------------------------------
/netgan/__pycache__/utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mmiller96/netgan_pytorch/4511f7de6fb87000435c1fd498d720391f7ccdc5/netgan/__pycache__/utils.cpython-310.pyc
--------------------------------------------------------------------------------
/netgan/cora.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mmiller96/netgan_pytorch/4511f7de6fb87000435c1fd498d720391f7ccdc5/netgan/cora.pt
--------------------------------------------------------------------------------
/netgan/cora_ml.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mmiller96/netgan_pytorch/4511f7de6fb87000435c1fd498d720391f7ccdc5/netgan/cora_ml.npz
--------------------------------------------------------------------------------
/netgan/inf-USAir97.mtx:
--------------------------------------------------------------------------------
1 | %%MatrixMarket matrix coordinate real symmetric
2 | %-------------------------------------------------------------------------------
3 | % UF Sparse Matrix Collection, Tim Davis
4 | % http://www.cise.ufl.edu/research/sparse/matrices/Pajek/USAir97
5 | % name: Pajek/USAir97
6 | % [Pajek network: US Air flights, 1997]
7 | % id: 1529
8 | % date: 1997
9 | % author: US Air
10 | % ed: V. Batagelj
11 | % fields: name title A id kind notes aux date author ed
12 | % aux: nodename coord
13 | % kind: undirected weighted graph
14 | %-------------------------------------------------------------------------------
15 | % notes:
16 | % ------------------------------------------------------------------------------
17 | % Pajek network converted to sparse adjacency matrix for inclusion in UF sparse
18 | % matrix collection, Tim Davis. For Pajek datasets, See V. Batagelj & A. Mrvar,
19 | % http://vlado.fmf.uni-lj.si/pub/networks/data/.
20 | % ------------------------------------------------------------------------------ 21 | % The original problem had 3D xyz coordinates, but all values of z were equal 22 | % to 0.5, and have been removed. This graph has 2D coordinates. 23 | %------------------------------------------------------------------------------- 24 | 332 332 2126 25 | 2 1 .0436 26 | 4 1 .0767 27 | 8 1 .1026 28 | 4 2 .0515 29 | 8 2 .0866 30 | 5 3 .0269 31 | 8 3 .0843 32 | 8 4 .0365 33 | 26 4 .0915 34 | 47 4 .2109 35 | 8 5 .0849 36 | 7 6 .0197 37 | 8 6 .0683 38 | 13 6 .0143 39 | 8 7 .0488 40 | 13 7 .014 41 | 13 8 .0604 42 | 16 8 .0239 43 | 23 8 .0478 44 | 24 8 .0414 45 | 26 8 .083 46 | 27 8 .0351 47 | 28 8 .1099 48 | 30 8 .0596 49 | 34 8 .0774 50 | 35 8 .0863 51 | 36 8 .1093 52 | 37 8 .1999 53 | 38 8 .1622 54 | 47 8 .1926 55 | 65 8 .204 56 | 67 8 .3284 57 | 112 8 .3847 58 | 118 8 .3655 59 | 144 8 .2746 60 | 201 8 .2654 61 | 248 8 .3059 62 | 313 8 .3829 63 | 10 9 .003 64 | 11 9 .0022 65 | 12 9 .0037 66 | 13 9 .0053 67 | 11 10 .0011 68 | 12 10 .001 69 | 13 10 .0023 70 | 13 11 .0033 71 | 13 12 .002 72 | 14 13 .0009 73 | 15 13 .0011 74 | 17 13 .006 75 | 18 13 .0055 76 | 19 13 .0095 77 | 20 13 .0112 78 | 21 13 .0098 79 | 22 16 .0311 80 | 19 17 .0038 81 | 20 17 .0053 82 | 20 19 .0018 83 | 26 22 .0281 84 | 24 23 .01 85 | 26 25 .0057 86 | 29 26 .0131 87 | 31 26 .017 88 | 33 26 .0322 89 | 47 26 .1216 90 | 35 28 .0425 91 | 31 29 .0125 92 | 33 29 .0247 93 | 47 29 .1134 94 | 34 30 .0184 95 | 32 31 .0043 96 | 33 32 .0113 97 | 47 33 .0895 98 | 36 35 .0232 99 | 38 36 .0549 100 | 38 37 .0478 101 | 47 39 .013 102 | 65 39 .0306 103 | 142 39 .0835 104 | 45 40 .018 105 | 46 40 .0167 106 | 50 40 .0133 107 | 55 40 .02 108 | 61 40 .0242 109 | 62 40 .0377 110 | 144 40 .0728 111 | 53 41 .0145 112 | 67 41 .0522 113 | 166 41 .0831 114 | 47 42 .0089 115 | 67 43 .0364 116 | 67 44 .0356 117 | 46 45 .0314 118 | 47 45 .0244 119 | 50 45 .0187 120 | 54 45 .0086 121 | 55 45 .0299 122 | 58 45 .0122 123 | 59 45 .0152 124 | 61 45 .0302 125 | 62 45 .049 126 | 63 45 .037 127 | 65 45 .0323 128 | 67 45 .1266 129 | 83 45 .0393 130 | 118 45 .1602 131 | 142 45 .0655 132 | 144 45 .0712 133 | 166 45 .0988 134 | 201 45 .0986 135 | 50 46 .0149 136 | 55 46 .009 137 | 62 46 .0215 138 | 63 46 .0163 139 | 67 46 .0958 140 | 144 46 .0641 141 | 48 47 .0107 142 | 49 47 .0154 143 | 54 47 .0273 144 | 56 47 .0122 145 | 58 47 .0288 146 | 59 47 .0197 147 | 60 47 .0242 148 | 65 47 .0178 149 | 67 47 .1502 150 | 74 47 .031 151 | 75 47 .032 152 | 83 47 .0483 153 | 86 47 .0555 154 | 94 47 .1804 155 | 109 47 .2658 156 | 112 47 .2046 157 | 118 47 .1828 158 | 123 47 .1841 159 | 131 47 .2139 160 | 142 47 .0714 161 | 144 47 .0825 162 | 147 47 .2535 163 | 150 47 .2555 164 | 152 47 .2244 165 | 166 47 .115 166 | 169 47 .077 167 | 172 47 .1606 168 | 176 47 .2078 169 | 177 47 .2424 170 | 182 47 .1826 171 | 183 47 .0836 172 | 197 47 .0929 173 | 201 47 .0938 174 | 203 47 .0963 175 | 219 47 .1144 176 | 230 47 .2408 177 | 232 47 .2028 178 | 245 47 .1281 179 | 246 47 .1301 180 | 248 47 .1305 181 | 253 47 .1334 182 | 255 47 .2336 183 | 258 47 .1437 184 | 261 47 .1894 185 | 263 47 .1429 186 | 293 47 .216 187 | 311 47 .2976 188 | 313 47 .3086 189 | 316 47 .3074 190 | 49 48 .0048 191 | 65 48 .0212 192 | 56 49 .0087 193 | 60 49 .0119 194 | 55 50 .0112 195 | 61 50 .0123 196 | 62 50 .0302 197 | 63 50 .0185 198 | 67 50 .1082 199 | 144 50 .0596 200 | 166 50 .0828 201 | 67 51 .0267 202 | 82 51 .0319 203 | 118 51 .0655 204 | 166 51 .0796 205 | 67 52 .0194 206 | 67 53 .0424 207 | 82 53 .0367 208 | 144 53 
.0809 209 | 166 53 .0701 210 | 58 54 .0036 211 | 59 54 .0113 212 | 60 54 .0087 213 | 65 54 .0301 214 | 83 54 .0307 215 | 61 55 .0067 216 | 62 55 .0191 217 | 63 55 .0089 218 | 144 55 .0556 219 | 58 56 .018 220 | 59 56 .0078 221 | 60 56 .0124 222 | 65 56 .0141 223 | 69 57 .0198 224 | 70 57 .0199 225 | 71 57 .0197 226 | 118 57 .0435 227 | 59 58 .0107 228 | 60 58 .0069 229 | 65 58 .0294 230 | 83 58 .0271 231 | 60 59 .0046 232 | 64 59 .0057 233 | 65 59 .0189 234 | 83 59 .0297 235 | 144 59 .0638 236 | 64 60 .0048 237 | 65 60 .0225 238 | 62 61 .0201 239 | 63 61 .007 240 | 63 62 .0133 241 | 67 62 .0786 242 | 144 62 .0511 243 | 166 62 .0605 244 | 67 63 .0918 245 | 144 63 .0479 246 | 166 63 .0656 247 | 65 64 .0191 248 | 66 65 .0068 249 | 67 65 .1498 250 | 74 65 .0147 251 | 75 65 .0143 252 | 83 65 .0378 253 | 87 65 .0223 254 | 108 65 .0307 255 | 112 65 .2025 256 | 116 65 .0331 257 | 118 65 .18 258 | 142 65 .0572 259 | 144 65 .0709 260 | 151 65 .0485 261 | 166 65 .1059 262 | 169 65 .06 263 | 176 65 .2031 264 | 183 65 .066 265 | 197 65 .0752 266 | 201 65 .0761 267 | 203 65 .0787 268 | 219 65 .0984 269 | 245 65 .1109 270 | 246 65 .1131 271 | 248 65 .1133 272 | 253 65 .1163 273 | 255 65 .2254 274 | 258 65 .128 275 | 261 65 .1779 276 | 263 65 .1259 277 | 293 65 .2037 278 | 87 66 .0156 279 | 116 66 .0271 280 | 151 66 .0422 281 | 71 67 .0262 282 | 76 67 .0508 283 | 78 67 .01 284 | 79 67 .0138 285 | 82 67 .0219 286 | 90 67 .0258 287 | 94 67 .0327 288 | 95 67 .076 289 | 99 67 .0435 290 | 109 67 .1156 291 | 111 67 .0465 292 | 112 67 .0563 293 | 118 67 .0388 294 | 119 67 .1082 295 | 120 67 .0297 296 | 123 67 .0406 297 | 128 67 .0321 298 | 131 67 .0667 299 | 133 67 .0368 300 | 136 67 .1058 301 | 144 67 .1033 302 | 146 67 .1061 303 | 147 67 .105 304 | 150 67 .1069 305 | 152 67 .0783 306 | 153 67 .0829 307 | 159 67 .0703 308 | 161 67 .0661 309 | 162 67 .1033 310 | 166 67 .077 311 | 167 67 .0607 312 | 172 67 .054 313 | 174 67 .1004 314 | 176 67 .0708 315 | 177 67 .0983 316 | 179 67 .1005 317 | 182 67 .0604 318 | 183 67 .1562 319 | 189 67 .0745 320 | 197 67 .1629 321 | 201 67 .1639 322 | 203 67 .163 323 | 217 67 .09 324 | 219 67 .1399 325 | 230 67 .1115 326 | 232 67 .0954 327 | 233 67 .1163 328 | 246 67 .1617 329 | 248 67 .1655 330 | 253 67 .1651 331 | 255 67 .1163 332 | 258 67 .1454 333 | 261 67 .1162 334 | 263 67 .1685 335 | 274 67 .1518 336 | 292 67 .1431 337 | 293 67 .1428 338 | 296 67 .146 339 | 299 67 .1685 340 | 301 67 .1705 341 | 307 67 .1847 342 | 310 67 .1917 343 | 311 67 .1939 344 | 313 67 .3992 345 | 80 68 .0134 346 | 96 68 .0223 347 | 109 68 .0259 348 | 118 68 .1008 349 | 125 68 .0323 350 | 147 68 .0478 351 | 71 69 .0084 352 | 73 69 .0077 353 | 118 69 .0283 354 | 112 70 .0267 355 | 118 70 .0289 356 | 73 71 .003 357 | 77 71 .0053 358 | 90 71 .0143 359 | 111 71 .0252 360 | 112 71 .0326 361 | 118 71 .024 362 | 96 72 .0171 363 | 118 72 .0788 364 | 146 72 .0355 365 | 147 72 .0365 366 | 152 72 .0523 367 | 162 72 .0452 368 | 177 72 .0572 369 | 77 73 .0026 370 | 90 73 .0115 371 | 94 73 .0129 372 | 118 73 .022 373 | 75 74 .0106 374 | 116 74 .0203 375 | 201 74 .0637 376 | 87 75 .0085 377 | 108 75 .0168 378 | 116 75 .0202 379 | 151 75 .0348 380 | 166 75 .1021 381 | 201 75 .0622 382 | 82 76 .0325 383 | 98 76 .0204 384 | 118 76 .0797 385 | 142 76 .0717 386 | 144 76 .0551 387 | 94 77 .0104 388 | 118 77 .0195 389 | 79 78 .0063 390 | 118 78 .0298 391 | 118 79 .025 392 | 101 80 .0198 393 | 109 80 .0128 394 | 118 80 .091 395 | 146 80 .0329 396 | 147 80 .0344 397 | 152 80 .0588 398 | 162 80 .044 399 | 176 80 .0853 400 | 
177 80 .0579 401 | 85 81 .0069 402 | 118 81 .1173 403 | 144 81 .0277 404 | 166 81 .0472 405 | 261 81 .1239 406 | 106 82 .0113 407 | 118 82 .0475 408 | 130 82 .0376 409 | 140 82 .0261 410 | 144 82 .0821 411 | 166 82 .0551 412 | 182 82 .0564 413 | 85 83 .0211 414 | 86 83 .0098 415 | 97 83 .0195 416 | 104 83 .0136 417 | 118 83 .145 418 | 144 83 .0343 419 | 166 83 .0681 420 | 201 83 .0649 421 | 248 83 .0927 422 | 112 84 .0131 423 | 118 84 .0245 424 | 97 85 .0064 425 | 144 85 .0261 426 | 261 85 .1272 427 | 144 86 .0286 428 | 118 88 .0142 429 | 174 89 .0386 430 | 94 90 .0075 431 | 112 90 .0317 432 | 118 90 .0133 433 | 152 90 .0528 434 | 166 90 .0855 435 | 182 90 .0422 436 | 92 91 .008 437 | 95 91 .0057 438 | 109 91 .0348 439 | 112 91 .0302 440 | 118 91 .0532 441 | 119 91 .0278 442 | 146 91 .0296 443 | 147 91 .0293 444 | 150 91 .0309 445 | 152 91 .0283 446 | 162 91 .0334 447 | 174 91 .038 448 | 176 91 .0527 449 | 177 91 .0399 450 | 179 91 .0408 451 | 221 91 .0694 452 | 230 91 .0773 453 | 255 91 .0968 454 | 95 92 .0135 455 | 109 92 .027 456 | 112 92 .0379 457 | 118 92 .0611 458 | 119 92 .0207 459 | 143 92 .0269 460 | 146 92 .025 461 | 147 92 .0251 462 | 150 92 .0264 463 | 152 92 .0326 464 | 162 92 .0312 465 | 174 92 .0377 466 | 176 92 .0584 467 | 179 92 .0409 468 | 182 92 .0837 469 | 221 92 .0704 470 | 230 92 .0793 471 | 255 92 .0998 472 | 299 92 .1427 473 | 100 93 .0047 474 | 118 93 .0232 475 | 152 93 .0297 476 | 161 93 .0294 477 | 99 94 .0121 478 | 109 94 .0863 479 | 112 94 .0242 480 | 118 94 .0093 481 | 131 94 .0341 482 | 146 94 .0745 483 | 147 94 .0733 484 | 152 94 .0456 485 | 159 94 .038 486 | 162 94 .0709 487 | 166 94 .0917 488 | 172 94 .0492 489 | 176 94 .0408 490 | 179 94 .0677 491 | 182 94 .042 492 | 201 94 .1827 493 | 217 94 .0655 494 | 219 94 .1535 495 | 230 94 .0819 496 | 232 94 .0762 497 | 248 94 .1776 498 | 255 94 .0906 499 | 258 94 .1528 500 | 261 94 .1068 501 | 263 94 .1782 502 | 299 94 .1427 503 | 301 94 .1457 504 | 310 94 .1661 505 | 311 94 .1684 506 | 101 95 .0252 507 | 109 95 .0398 508 | 112 95 .0245 509 | 118 95 .0476 510 | 119 95 .0323 511 | 131 95 .0216 512 | 146 95 .0323 513 | 147 95 .0317 514 | 150 95 .0335 515 | 152 95 .0246 516 | 162 95 .0343 517 | 174 95 .0375 518 | 176 95 .0479 519 | 177 95 .0387 520 | 179 95 .0399 521 | 221 95 .0675 522 | 230 95 .0747 523 | 255 95 .0934 524 | 299 95 .1392 525 | 118 96 .0844 526 | 147 96 .0255 527 | 152 96 .0505 528 | 162 96 .0351 529 | 177 96 .0489 530 | 104 97 .0105 531 | 144 98 .0347 532 | 112 99 .0128 533 | 118 99 .0149 534 | 147 99 .0616 535 | 152 99 .0354 536 | 161 99 .0293 537 | 167 99 .0305 538 | 176 99 .037 539 | 111 100 .0071 540 | 112 100 .0083 541 | 118 100 .0186 542 | 137 100 .0175 543 | 152 100 .0312 544 | 161 100 .0276 545 | 112 101 .0489 546 | 118 101 .0722 547 | 129 101 .012 548 | 146 101 .0189 549 | 152 101 .0392 550 | 162 101 .0285 551 | 174 101 .0372 552 | 176 101 .0657 553 | 179 101 .0407 554 | 221 101 .0704 555 | 230 101 .0807 556 | 255 101 .1024 557 | 107 102 .0087 558 | 118 102 .0236 559 | 120 102 .0073 560 | 182 102 .0378 561 | 115 103 .0038 562 | 152 103 .0271 563 | 144 104 .0206 564 | 123 105 .0248 565 | 167 105 .0306 566 | 182 105 .0512 567 | 118 106 .0434 568 | 120 106 .0243 569 | 140 106 .015 570 | 166 106 .05 571 | 182 106 .0465 572 | 118 107 .0149 573 | 116 108 .0062 574 | 201 108 .0455 575 | 112 109 .0629 576 | 118 109 .0862 577 | 119 109 .0095 578 | 131 109 .056 579 | 144 109 .2092 580 | 146 109 .0211 581 | 147 109 .0227 582 | 150 109 .0218 583 | 152 109 .0503 584 | 159 109 .0647 585 | 161 109 .0713 
586 | 162 109 .0321 587 | 166 109 .1744 588 | 167 109 .0819 589 | 170 109 .0332 590 | 174 109 .042 591 | 176 109 .0765 592 | 177 109 .0464 593 | 179 109 .0455 594 | 182 109 .1045 595 | 201 109 .2655 596 | 202 109 .0565 597 | 203 109 .2638 598 | 212 109 .0586 599 | 217 109 .0997 600 | 219 109 .2328 601 | 221 109 .0736 602 | 230 109 .0851 603 | 232 109 .1193 604 | 248 109 .2545 605 | 255 109 .1078 606 | 258 109 .2256 607 | 261 109 .1605 608 | 273 109 .1108 609 | 293 109 .1714 610 | 299 109 .1431 611 | 301 109 .1495 612 | 306 109 .1568 613 | 307 109 .1608 614 | 310 109 .1625 615 | 311 109 .1652 616 | 321 109 .23 617 | 152 110 .0458 618 | 162 110 .0286 619 | 112 111 .0112 620 | 118 111 .0123 621 | 152 111 .0318 622 | 161 111 .0233 623 | 118 112 .0233 624 | 119 112 .0544 625 | 123 112 .0228 626 | 125 112 .0609 627 | 126 112 .0159 628 | 131 112 .0108 629 | 136 112 .0503 630 | 137 112 .0151 631 | 144 112 .1464 632 | 146 112 .0502 633 | 147 112 .049 634 | 149 112 .043 635 | 150 112 .051 636 | 152 112 .0229 637 | 157 112 .0387 638 | 159 112 .0213 639 | 161 112 .0225 640 | 162 112 .047 641 | 166 112 .1121 642 | 167 112 .0282 643 | 170 112 .0519 644 | 172 112 .0643 645 | 174 112 .0448 646 | 176 112 .031 647 | 177 112 .0434 648 | 179 112 .0454 649 | 182 112 .0487 650 | 189 112 .0405 651 | 201 112 .2035 652 | 202 112 .0545 653 | 203 112 .2019 654 | 212 112 .0626 655 | 217 112 .0606 656 | 219 112 .1723 657 | 221 112 .0649 658 | 230 112 .068 659 | 232 112 .0764 660 | 248 112 .1952 661 | 253 112 .1938 662 | 255 112 .0821 663 | 258 112 .1684 664 | 261 112 .1131 665 | 263 112 .1947 666 | 292 112 .122 667 | 293 112 .1319 668 | 296 112 .1346 669 | 299 112 .1321 670 | 301 112 .1362 671 | 305 112 .1417 672 | 306 112 .1493 673 | 307 112 .15 674 | 310 112 .1552 675 | 311 112 .1577 676 | 152 113 .0271 677 | 174 113 .0292 678 | 118 114 .0065 679 | 152 115 .0233 680 | 151 116 .0159 681 | 152 117 .0152 682 | 119 118 .0775 683 | 120 118 .0194 684 | 125 118 .0839 685 | 126 118 .0085 686 | 127 118 .0212 687 | 128 118 .0296 688 | 129 118 .0704 689 | 130 118 .0142 690 | 131 118 .0313 691 | 133 118 .0412 692 | 134 118 .0768 693 | 136 118 .0728 694 | 137 118 .0168 695 | 139 118 .0344 696 | 140 118 .0464 697 | 143 118 .0762 698 | 144 118 .1231 699 | 145 118 .02 700 | 146 118 .0724 701 | 147 118 .0711 702 | 148 118 .0155 703 | 149 118 .0648 704 | 150 118 .073 705 | 152 118 .0416 706 | 153 118 .0994 707 | 154 118 .0152 708 | 155 118 .0157 709 | 157 118 .0592 710 | 158 118 .0186 711 | 159 118 .0317 712 | 161 118 .0273 713 | 162 118 .0676 714 | 163 118 .0223 715 | 164 118 .0211 716 | 166 118 .089 717 | 167 118 .0231 718 | 168 118 .0993 719 | 169 118 .164 720 | 172 118 .0431 721 | 173 118 .1001 722 | 174 118 .0632 723 | 176 118 .0325 724 | 177 118 .0606 725 | 179 118 .0628 726 | 181 118 .0907 727 | 182 118 .0332 728 | 183 118 .1744 729 | 186 118 .0471 730 | 189 118 .0379 731 | 191 118 .0376 732 | 192 118 .0412 733 | 197 118 .1795 734 | 198 118 .0637 735 | 201 118 .1804 736 | 202 118 .0688 737 | 203 118 .1788 738 | 204 118 .06 739 | 212 118 .0769 740 | 216 118 .0686 741 | 217 118 .0562 742 | 218 118 .0693 743 | 219 118 .1498 744 | 221 118 .0745 745 | 222 118 .0622 746 | 225 118 .08 747 | 229 118 .0955 748 | 230 118 .0737 749 | 232 118 .067 750 | 233 118 .1161 751 | 240 118 .0703 752 | 246 118 .1691 753 | 248 118 .1733 754 | 249 118 .0842 755 | 250 118 .1652 756 | 253 118 .172 757 | 255 118 .0815 758 | 256 118 .0806 759 | 258 118 .1474 760 | 260 118 .0955 761 | 261 118 .0985 762 | 263 118 .1733 763 | 273 118 .1001 764 | 274 118 
.1504 765 | 276 118 .1353 766 | 284 118 .1142 767 | 288 118 .1222 768 | 292 118 .1151 769 | 293 118 .1207 770 | 296 118 .1237 771 | 297 118 .1305 772 | 299 118 .1337 773 | 301 118 .1366 774 | 305 118 .1419 775 | 306 118 .1514 776 | 307 118 .1508 777 | 310 118 .157 778 | 311 118 .1594 779 | 313 118 .4076 780 | 321 118 .251 781 | 131 119 .047 782 | 146 119 .0126 783 | 147 119 .0141 784 | 150 119 .0136 785 | 152 119 .0408 786 | 162 119 .0236 787 | 174 119 .0333 788 | 176 119 .067 789 | 177 119 .0375 790 | 179 119 .0368 791 | 182 119 .095 792 | 217 119 .0904 793 | 221 119 .0657 794 | 230 119 .0768 795 | 255 119 .0992 796 | 261 119 .1511 797 | 299 119 .1363 798 | 301 119 .1425 799 | 306 119 .1505 800 | 310 119 .1563 801 | 311 119 .159 802 | 321 119 .2271 803 | 130 120 .0073 804 | 166 120 .07 805 | 182 120 .0307 806 | 258 120 .1312 807 | 142 121 .0629 808 | 163 122 .022 809 | 131 123 .0302 810 | 146 123 .0714 811 | 147 123 .07 812 | 152 123 .0403 813 | 159 123 .03 814 | 163 123 .021 815 | 166 123 .0894 816 | 167 123 .0211 817 | 172 123 .0427 818 | 179 123 .0613 819 | 182 123 .0319 820 | 217 123 .0544 821 | 219 123 .1499 822 | 255 123 .0796 823 | 261 123 .0973 824 | 273 123 .0981 825 | 296 123 .1222 826 | 299 123 .1317 827 | 301 123 .1347 828 | 302 123 .1351 829 | 307 123 .1489 830 | 310 123 .1551 831 | 138 124 .0076 832 | 201 124 .0408 833 | 131 125 .0532 834 | 147 125 .0171 835 | 152 125 .0463 836 | 162 125 .0263 837 | 170 125 .027 838 | 174 125 .0362 839 | 177 125 .0406 840 | 179 125 .0396 841 | 221 125 .0673 842 | 230 125 .0789 843 | 255 125 .1017 844 | 299 125 .1366 845 | 137 126 .0089 846 | 152 126 .0331 847 | 161 126 .0203 848 | 176 126 .0268 849 | 182 126 .0349 850 | 217 126 .0533 851 | 137 127 .0092 852 | 152 127 .021 853 | 255 127 .076 854 | 133 128 .0116 855 | 142 128 .1129 856 | 152 128 .0691 857 | 166 128 .0596 858 | 182 128 .0315 859 | 258 128 .1214 860 | 261 128 .0844 861 | 152 129 .0327 862 | 221 129 .0589 863 | 255 129 .0917 864 | 145 130 .0071 865 | 148 130 .0086 866 | 160 130 .0148 867 | 166 130 .0749 868 | 182 130 .0258 869 | 258 130 .1336 870 | 146 131 .0412 871 | 147 131 .0398 872 | 150 131 .0418 873 | 152 131 .0121 874 | 153 131 .1295 875 | 162 131 .0368 876 | 166 131 .1183 877 | 167 131 .0278 878 | 170 131 .0415 879 | 172 131 .0685 880 | 174 131 .034 881 | 176 131 .0267 882 | 177 131 .0325 883 | 179 131 .0346 884 | 182 131 .0502 885 | 189 131 .0367 886 | 201 131 .2095 887 | 212 131 .0519 888 | 217 131 .0562 889 | 219 131 .1771 890 | 221 131 .0552 891 | 230 131 .0594 892 | 248 131 .1993 893 | 255 131 .0753 894 | 258 131 .1715 895 | 261 131 .1123 896 | 292 131 .1172 897 | 293 131 .129 898 | 299 131 .124 899 | 301 131 .1284 900 | 305 131 .1339 901 | 306 131 .141 902 | 307 131 .1421 903 | 310 131 .1469 904 | 311 131 .1494 905 | 152 132 .0243 906 | 174 132 .0212 907 | 230 132 .0643 908 | 255 132 .0858 909 | 140 133 .0062 910 | 144 133 .0821 911 | 152 133 .0802 912 | 166 133 .0481 913 | 172 133 .0201 914 | 182 133 .0373 915 | 219 133 .1101 916 | 232 133 .0669 917 | 255 133 .0936 918 | 258 133 .1113 919 | 261 133 .0806 920 | 293 133 .1082 921 | 136 134 .0046 922 | 139 135 .005 923 | 152 136 .0337 924 | 139 137 .0191 925 | 152 137 .0257 926 | 161 137 .0114 927 | 176 137 .0187 928 | 255 137 .0702 929 | 183 138 .0253 930 | 201 138 .0333 931 | 161 139 .0171 932 | 162 139 .0331 933 | 167 139 .0272 934 | 177 139 .0277 935 | 230 139 .0546 936 | 166 140 .0426 937 | 182 140 .0382 938 | 157 141 .0083 939 | 177 141 .0183 940 | 144 142 .0194 941 | 166 142 .0564 942 | 172 142 .1083 943 | 198 
142 .0983 944 | 216 142 .1105 945 | 225 142 .1062 946 | 229 142 .0895 947 | 233 142 .0724 948 | 261 142 .1219 949 | 293 142 .1469 950 | 147 143 .0055 951 | 152 143 .0364 952 | 162 143 .014 953 | 170 143 .0148 954 | 174 143 .0239 955 | 221 143 .0552 956 | 230 143 .0666 957 | 299 143 .1253 958 | 306 143 .1394 959 | 310 143 .1452 960 | 311 143 .1479 961 | 150 144 .1945 962 | 162 144 .1873 963 | 166 144 .0374 964 | 168 144 .028 965 | 169 144 .0416 966 | 172 144 .089 967 | 176 144 .1401 968 | 177 144 .1767 969 | 181 144 .0416 970 | 182 144 .1118 971 | 183 144 .0529 972 | 197 144 .0599 973 | 201 144 .0609 974 | 203 144 .0603 975 | 213 144 .0549 976 | 216 144 .0929 977 | 219 144 .0477 978 | 225 144 .0895 979 | 233 144 .0613 980 | 245 144 .0708 981 | 246 144 .0704 982 | 248 144 .0731 983 | 250 144 .0704 984 | 253 144 .0742 985 | 255 144 .156 986 | 258 144 .0702 987 | 261 144 .1071 988 | 263 144 .0813 989 | 274 144 .083 990 | 292 144 .1512 991 | 293 144 .1335 992 | 297 144 .1276 993 | 299 144 .1958 994 | 148 145 .0074 995 | 182 145 .0198 996 | 147 146 .0017 997 | 150 146 .0014 998 | 152 146 .0325 999 | 159 146 .0466 1000 | 161 146 .0533 1001 | 166 146 .1583 1002 | 167 146 .0641 1003 | 168 146 .1686 1004 | 172 146 .1071 1005 | 174 146 .0209 1006 | 176 146 .0574 1007 | 177 146 .0253 1008 | 179 146 .0244 1009 | 182 146 .0862 1010 | 189 146 .0654 1011 | 202 146 .0358 1012 | 212 146 .039 1013 | 217 146 .079 1014 | 218 146 .0543 1015 | 221 146 .0531 1016 | 230 146 .0642 1017 | 232 146 .0987 1018 | 237 146 .0704 1019 | 255 146 .0868 1020 | 260 146 .0815 1021 | 261 146 .14 1022 | 284 146 .106 1023 | 292 146 .1326 1024 | 293 146 .1503 1025 | 299 146 .1238 1026 | 301 146 .13 1027 | 305 146 .1353 1028 | 306 146 .1383 1029 | 307 146 .1418 1030 | 310 146 .1441 1031 | 311 146 .1468 1032 | 150 147 .002 1033 | 152 147 .031 1034 | 153 147 .1684 1035 | 159 147 .045 1036 | 161 147 .0517 1037 | 162 147 .0095 1038 | 166 147 .1567 1039 | 167 147 .0625 1040 | 172 147 .1055 1041 | 174 147 .0193 1042 | 176 147 .0558 1043 | 177 147 .0237 1044 | 179 147 .0229 1045 | 182 147 .0845 1046 | 201 147 .2472 1047 | 202 147 .0344 1048 | 212 147 .0378 1049 | 217 147 .0773 1050 | 218 147 .0528 1051 | 219 147 .2134 1052 | 221 147 .0517 1053 | 230 147 .0627 1054 | 232 147 .097 1055 | 248 147 .2344 1056 | 249 147 .0736 1057 | 255 147 .0852 1058 | 258 147 .2049 1059 | 260 147 .0802 1060 | 261 147 .1383 1061 | 273 147 .0893 1062 | 284 147 .1047 1063 | 292 147 .131 1064 | 293 147 .1487 1065 | 298 147 .1155 1066 | 299 147 .1226 1067 | 300 147 .1247 1068 | 301 147 .1288 1069 | 305 147 .134 1070 | 306 147 .1372 1071 | 307 147 .1406 1072 | 310 147 .143 1073 | 311 147 .1457 1074 | 320 147 .2151 1075 | 321 147 .2166 1076 | 322 147 .2187 1077 | 324 147 .2201 1078 | 325 147 .2247 1079 | 158 148 .0094 1080 | 166 148 .0779 1081 | 182 148 .0186 1082 | 152 149 .0244 1083 | 157 149 .008 1084 | 221 149 .0487 1085 | 230 149 .059 1086 | 255 149 .0811 1087 | 152 150 .0328 1088 | 159 150 .0468 1089 | 162 150 .0104 1090 | 166 150 .1586 1091 | 167 150 .0643 1092 | 174 150 .0203 1093 | 176 150 .0574 1094 | 177 150 .0247 1095 | 179 150 .0238 1096 | 182 150 .0863 1097 | 201 150 .2491 1098 | 202 150 .0349 1099 | 212 150 .0379 1100 | 219 150 .2151 1101 | 221 150 .0522 1102 | 230 150 .0634 1103 | 248 150 .236 1104 | 255 150 .086 1105 | 258 150 .2065 1106 | 261 150 .1396 1107 | 263 150 .2335 1108 | 292 150 .1318 1109 | 296 150 .1517 1110 | 299 150 .1227 1111 | 301 150 .1289 1112 | 306 150 .1371 1113 | 310 150 .1429 1114 | 311 150 .1456 1115 | 320 150 .2143 1116 | 321 
150 .2158 1117 | 322 150 .2178 1118 | 324 150 .2193 1119 | 325 150 .2238 1120 | 201 151 .0276 1121 | 156 152 .0277 1122 | 157 152 .0178 1123 | 158 152 .0412 1124 | 159 152 .0144 1125 | 161 152 .0211 1126 | 162 152 .0261 1127 | 166 152 .1258 1128 | 167 152 .0318 1129 | 170 152 .0305 1130 | 172 152 .0747 1131 | 174 152 .0221 1132 | 176 152 .0265 1133 | 177 152 .0204 1134 | 179 152 .0225 1135 | 182 152 .0542 1136 | 183 152 .2114 1137 | 186 152 .0214 1138 | 187 152 .0235 1139 | 189 152 .0357 1140 | 191 152 .0439 1141 | 192 152 .0323 1142 | 198 152 .0917 1143 | 201 152 .2163 1144 | 202 152 .0321 1145 | 210 152 .0373 1146 | 212 152 .04 1147 | 215 152 .04 1148 | 217 152 .0531 1149 | 218 152 .042 1150 | 219 152 .1828 1151 | 221 152 .0447 1152 | 222 152 .0486 1153 | 230 152 .0505 1154 | 232 152 .0719 1155 | 233 152 .1441 1156 | 248 152 .2042 1157 | 252 152 .0654 1158 | 255 152 .0688 1159 | 256 152 .0741 1160 | 258 152 .1754 1161 | 261 152 .1122 1162 | 263 152 .2023 1163 | 284 152 .0958 1164 | 288 152 .1319 1165 | 292 152 .1126 1166 | 293 152 .1265 1167 | 297 152 .14 1168 | 299 152 .1153 1169 | 301 152 .1202 1170 | 305 152 .1257 1171 | 306 152 .1319 1172 | 307 152 .1335 1173 | 310 152 .1378 1174 | 311 152 .1404 1175 | 166 153 .0137 1176 | 203 153 .0806 1177 | 248 153 .0845 1178 | 261 153 .0892 1179 | 293 153 .1171 1180 | 164 154 .0062 1181 | 171 154 .0128 1182 | 182 154 .018 1183 | 171 155 .0094 1184 | 177 157 .0125 1185 | 221 157 .0426 1186 | 230 157 .0521 1187 | 255 157 .0738 1188 | 161 158 .0207 1189 | 171 158 .0074 1190 | 182 158 .0162 1191 | 161 159 .0068 1192 | 162 159 .039 1193 | 166 159 .112 1194 | 167 159 .0175 1195 | 174 159 .0327 1196 | 176 159 .0128 1197 | 179 159 .0317 1198 | 182 159 .0398 1199 | 217 159 .0417 1200 | 219 159 .1685 1201 | 230 159 .0468 1202 | 248 159 .1898 1203 | 255 159 .0612 1204 | 258 159 .161 1205 | 261 159 .099 1206 | 293 159 .1148 1207 | 299 159 .1108 1208 | 301 159 .1149 1209 | 182 160 .0122 1210 | 162 161 .0457 1211 | 167 161 .0108 1212 | 174 161 .039 1213 | 179 161 .0378 1214 | 182 161 .0332 1215 | 189 161 .0183 1216 | 191 161 .0245 1217 | 230 161 .0478 1218 | 248 161 .1832 1219 | 255 161 .0598 1220 | 261 161 .0936 1221 | 299 161 .1106 1222 | 301 161 .1143 1223 | 166 162 .151 1224 | 167 162 .0563 1225 | 172 162 .0994 1226 | 174 162 .0099 1227 | 176 162 .0486 1228 | 177 162 .0144 1229 | 179 162 .0134 1230 | 182 162 .0777 1231 | 201 162 .241 1232 | 202 162 .025 1233 | 212 162 .0289 1234 | 217 162 .0684 1235 | 218 162 .0433 1236 | 219 162 .2065 1237 | 221 162 .0423 1238 | 230 162 .0532 1239 | 232 162 .0881 1240 | 237 162 .0593 1241 | 243 162 .0552 1242 | 248 162 .227 1243 | 249 162 .0641 1244 | 252 162 .0621 1245 | 255 162 .0757 1246 | 258 162 .1971 1247 | 261 162 .1295 1248 | 273 162 .08 1249 | 284 162 .0954 1250 | 288 162 .1465 1251 | 292 162 .1215 1252 | 293 162 .1393 1253 | 299 162 .1136 1254 | 301 162 .1196 1255 | 305 162 .1249 1256 | 306 162 .1284 1257 | 307 162 .1317 1258 | 310 162 .1342 1259 | 311 162 .1369 1260 | 321 162 .2101 1261 | 182 163 .011 1262 | 182 164 .0128 1263 | 201 165 .021 1264 | 167 166 .0947 1265 | 168 166 .0104 1266 | 169 166 .0758 1267 | 172 166 .052 1268 | 173 166 .0114 1269 | 174 166 .1438 1270 | 175 166 .0195 1271 | 176 166 .1032 1272 | 177 166 .1399 1273 | 181 166 .0094 1274 | 182 166 .0746 1275 | 183 166 .0857 1276 | 184 166 .0158 1277 | 197 166 .0905 1278 | 198 166 .0431 1279 | 201 166 .0914 1280 | 203 166 .0898 1281 | 206 166 .0634 1282 | 213 166 .0808 1283 | 216 166 .0572 1284 | 217 166 .0991 1285 | 219 166 .0631 1286 | 225 166 
.056 1287 | 230 166 .1295 1288 | 232 166 .0883 1289 | 233 166 .0461 1290 | 242 166 .0917 1291 | 245 166 .0869 1292 | 246 166 .0847 1293 | 248 166 .0886 1294 | 250 166 .0821 1295 | 251 166 .0883 1296 | 253 166 .0881 1297 | 255 166 .1195 1298 | 258 166 .0706 1299 | 261 166 .077 1300 | 263 166 .0919 1301 | 274 166 .0794 1302 | 276 166 .0766 1303 | 288 166 .0977 1304 | 292 166 .1196 1305 | 293 166 .1055 1306 | 296 166 .1085 1307 | 297 166 .1032 1308 | 299 166 .1618 1309 | 301 166 .1603 1310 | 311 166 .1832 1311 | 313 166 .3226 1312 | 172 167 .0431 1313 | 174 167 .0493 1314 | 176 167 .0105 1315 | 179 167 .0478 1316 | 182 167 .0227 1317 | 189 167 .015 1318 | 191 167 .0172 1319 | 201 167 .1848 1320 | 217 167 .0344 1321 | 219 167 .151 1322 | 230 167 .0509 1323 | 232 167 .0484 1324 | 248 167 .1726 1325 | 255 167 .0588 1326 | 258 167 .1441 1327 | 261 167 .0851 1328 | 292 167 .0951 1329 | 293 167 .1038 1330 | 299 167 .1108 1331 | 301 167 .1138 1332 | 302 167 .1143 1333 | 307 167 .128 1334 | 310 167 .1341 1335 | 311 167 .1365 1336 | 248 168 .08 1337 | 258 168 .0647 1338 | 261 168 .0818 1339 | 183 169 .012 1340 | 197 169 .0211 1341 | 201 169 .0222 1342 | 203 169 .0232 1343 | 213 169 .026 1344 | 219 169 .0402 1345 | 246 169 .0532 1346 | 248 169 .0536 1347 | 258 169 .0701 1348 | 261 169 .1319 1349 | 263 169 .0659 1350 | 174 172 .0919 1351 | 176 172 .0513 1352 | 179 172 .0901 1353 | 182 172 .0228 1354 | 198 172 .0209 1355 | 201 172 .1417 1356 | 216 172 .0302 1357 | 217 172 .0509 1358 | 219 172 .1085 1359 | 225 172 .0401 1360 | 230 172 .0802 1361 | 232 172 .0472 1362 | 248 172 .131 1363 | 255 172 .0752 1364 | 258 172 .1044 1365 | 261 172 .0623 1366 | 263 172 .1305 1367 | 293 172 .0891 1368 | 299 172 .1243 1369 | 248 173 .0774 1370 | 261 173 .0786 1371 | 176 174 .0407 1372 | 177 174 .0046 1373 | 182 174 .0698 1374 | 201 174 .2332 1375 | 202 174 .0162 1376 | 204 174 .0243 1377 | 212 174 .0219 1378 | 217 174 .0587 1379 | 218 174 .0338 1380 | 219 174 .1982 1381 | 221 174 .0333 1382 | 230 174 .0436 1383 | 248 174 .2183 1384 | 255 174 .0659 1385 | 258 174 .1882 1386 | 261 174 .1198 1387 | 273 174 .0712 1388 | 284 174 .0867 1389 | 292 174 .1117 1390 | 293 174 .1294 1391 | 299 174 .1053 1392 | 301 174 .1111 1393 | 305 174 .1164 1394 | 306 174 .1206 1395 | 307 174 .1234 1396 | 310 174 .1264 1397 | 311 174 .1291 1398 | 321 174 .2054 1399 | 179 176 .0388 1400 | 182 176 .0291 1401 | 189 176 .01 1402 | 192 176 .0096 1403 | 201 176 .1925 1404 | 202 176 .0402 1405 | 212 176 .0478 1406 | 217 176 .0297 1407 | 218 176 .0371 1408 | 219 176 .1579 1409 | 221 176 .0426 1410 | 222 176 .0311 1411 | 230 176 .0412 1412 | 232 176 .0468 1413 | 233 176 .1182 1414 | 239 176 .0564 1415 | 248 176 .1786 1416 | 255 176 .0516 1417 | 256 176 .0535 1418 | 258 176 .1493 1419 | 261 176 .0862 1420 | 263 176 .1762 1421 | 271 176 .0701 1422 | 284 176 .0831 1423 | 292 176 .0911 1424 | 296 176 .1048 1425 | 299 176 .1028 1426 | 301 176 .1064 1427 | 305 176 .1118 1428 | 306 176 .1204 1429 | 307 176 .1204 1430 | 310 176 .1261 1431 | 311 176 .1285 1432 | 179 177 .0022 1433 | 182 177 .0657 1434 | 190 177 .0092 1435 | 201 177 .229 1436 | 202 177 .0137 1437 | 204 177 .02 1438 | 212 177 .0206 1439 | 218 177 .03 1440 | 230 177 .0398 1441 | 240 177 .0627 1442 | 248 177 .2139 1443 | 255 177 .0618 1444 | 258 177 .1837 1445 | 260 177 .0592 1446 | 261 177 .1152 1447 | 263 177 .2108 1448 | 273 177 .0678 1449 | 292 177 .1075 1450 | 293 177 .125 1451 | 299 177 .1023 1452 | 301 177 .1079 1453 | 310 177 .1238 1454 | 311 177 .1264 1455 | 321 177 .2044 1456 | 201 178 
.0171 1457 | 182 179 .0678 1458 | 189 179 .0448 1459 | 217 179 .0556 1460 | 218 179 .0302 1461 | 221 179 .0299 1462 | 230 179 .0401 1463 | 232 179 .0753 1464 | 255 179 .0624 1465 | 260 179 .0589 1466 | 261 179 .1167 1467 | 292 179 .1082 1468 | 293 179 .126 1469 | 299 179 .1019 1470 | 301 179 .1077 1471 | 306 179 .1173 1472 | 307 179 .1201 1473 | 310 179 .1232 1474 | 311 179 .1259 1475 | 182 180 .0095 1476 | 182 181 .073 1477 | 188 181 .005 1478 | 219 181 .0593 1479 | 258 181 .0634 1480 | 261 181 .0687 1481 | 189 182 .0241 1482 | 191 182 .0159 1483 | 192 182 .0301 1484 | 195 182 .0118 1485 | 196 182 .0133 1486 | 198 182 .0375 1487 | 201 182 .1634 1488 | 203 182 .1613 1489 | 206 182 .0211 1490 | 207 182 .0151 1491 | 209 182 .0261 1492 | 211 182 .0181 1493 | 212 182 .0743 1494 | 216 182 .0373 1495 | 217 182 .0313 1496 | 219 182 .1288 1497 | 220 182 .0327 1498 | 221 182 .065 1499 | 222 182 .0429 1500 | 225 182 .0489 1501 | 226 182 .0385 1502 | 230 182 .0586 1503 | 232 182 .0355 1504 | 233 182 .0901 1505 | 239 182 .0396 1506 | 246 182 .1458 1507 | 248 182 .15 1508 | 250 182 .1413 1509 | 253 182 .1483 1510 | 255 182 .0574 1511 | 256 182 .0529 1512 | 258 182 .1214 1513 | 261 182 .0655 1514 | 263 182 .1482 1515 | 274 182 .1225 1516 | 284 182 .0904 1517 | 292 182 .0837 1518 | 293 182 .0876 1519 | 296 182 .0905 1520 | 297 182 .0973 1521 | 299 182 .1088 1522 | 301 182 .1105 1523 | 305 182 .1155 1524 | 306 182 .1266 1525 | 307 182 .1246 1526 | 310 182 .1319 1527 | 311 182 .134 1528 | 313 182 .3823 1529 | 321 182 .2303 1530 | 197 183 .0099 1531 | 201 183 .011 1532 | 203 183 .0129 1533 | 213 183 .0208 1534 | 219 183 .0412 1535 | 224 183 .0338 1536 | 245 183 .0461 1537 | 246 183 .0489 1538 | 248 183 .0483 1539 | 253 183 .0516 1540 | 258 183 .0701 1541 | 261 183 .1369 1542 | 263 183 .0613 1543 | 293 183 .1575 1544 | 255 184 .1237 1545 | 261 184 .0738 1546 | 293 184 .1008 1547 | 201 185 .0088 1548 | 192 186 .0156 1549 | 230 186 .0304 1550 | 192 189 .0059 1551 | 217 189 .0202 1552 | 230 189 .0374 1553 | 232 189 .0369 1554 | 255 189 .0438 1555 | 256 189 .0444 1556 | 261 189 .0765 1557 | 202 190 .0083 1558 | 230 190 .0307 1559 | 204 192 .0246 1560 | 217 192 .0211 1561 | 222 192 .0215 1562 | 255 192 .042 1563 | 201 193 .0063 1564 | 221 194 .0207 1565 | 207 195 .0058 1566 | 201 197 .0011 1567 | 213 197 .0156 1568 | 219 197 .0392 1569 | 245 197 .0389 1570 | 246 197 .0422 1571 | 248 197 .041 1572 | 250 197 .0472 1573 | 251 197 .0427 1574 | 253 197 .0445 1575 | 258 197 .0662 1576 | 261 197 .1363 1577 | 263 197 .0541 1578 | 206 198 .021 1579 | 216 198 .0159 1580 | 219 198 .0915 1581 | 232 198 .0453 1582 | 258 198 .0845 1583 | 261 198 .0455 1584 | 201 199 .0072 1585 | 213 199 .0103 1586 | 213 200 .0093 1587 | 203 201 .0033 1588 | 205 201 .0099 1589 | 213 201 .0157 1590 | 214 201 .0103 1591 | 217 201 .1823 1592 | 219 201 .0395 1593 | 224 201 .0268 1594 | 228 201 .0244 1595 | 230 201 .2122 1596 | 232 201 .1667 1597 | 233 201 .0839 1598 | 236 201 .0278 1599 | 242 201 .0331 1600 | 245 201 .0385 1601 | 246 201 .0419 1602 | 248 201 .0406 1603 | 250 201 .047 1604 | 251 201 .0423 1605 | 253 201 .0441 1606 | 255 201 .1969 1607 | 258 201 .0662 1608 | 261 201 .1367 1609 | 263 201 .0536 1610 | 276 201 .0986 1611 | 292 201 .179 1612 | 293 201 .1558 1613 | 299 201 .2268 1614 | 311 201 .2423 1615 | 313 201 .2389 1616 | 316 201 .2359 1617 | 318 201 .2421 1618 | 204 202 .0136 1619 | 212 202 .0082 1620 | 217 202 .0495 1621 | 221 202 .0173 1622 | 230 202 .0287 1623 | 255 202 .0517 1624 | 261 202 .1097 1625 | 219 203 .0366 1626 | 245 
203 .0353 1627 | 246 203 .0386 1628 | 248 203 .0373 1629 | 250 203 .0436 1630 | 251 203 .039 1631 | 253 203 .0408 1632 | 258 203 .0629 1633 | 261 203 .1338 1634 | 263 203 .0504 1635 | 288 203 .1407 1636 | 293 203 .1527 1637 | 311 203 .2392 1638 | 218 204 .0118 1639 | 230 204 .0208 1640 | 213 205 .0064 1641 | 209 206 .0058 1642 | 239 206 .0247 1643 | 261 206 .0456 1644 | 258 208 .0416 1645 | 217 212 .0539 1646 | 221 212 .0164 1647 | 230 212 .029 1648 | 255 212 .0521 1649 | 261 212 .1128 1650 | 245 213 .0255 1651 | 246 213 .0282 1652 | 248 213 .0279 1653 | 253 213 .0311 1654 | 261 213 .1213 1655 | 263 213 .0407 1656 | 248 214 .0308 1657 | 230 215 .0141 1658 | 255 215 .0289 1659 | 225 216 .0116 1660 | 232 216 .032 1661 | 239 216 .0234 1662 | 255 216 .0633 1663 | 258 216 .0862 1664 | 261 216 .0321 1665 | 262 216 .0324 1666 | 293 216 .0595 1667 | 296 216 .0627 1668 | 218 217 .0344 1669 | 219 217 .145 1670 | 221 217 .0403 1671 | 222 217 .014 1672 | 230 217 .0305 1673 | 232 217 .0197 1674 | 239 217 .0312 1675 | 240 217 .0142 1676 | 248 217 .1628 1677 | 255 217 .0263 1678 | 256 217 .0245 1679 | 258 217 .1315 1680 | 261 217 .0611 1681 | 263 217 .1586 1682 | 271 217 .0404 1683 | 281 217 .0525 1684 | 283 217 .058 1685 | 284 217 .0595 1686 | 286 217 .0541 1687 | 292 217 .0614 1688 | 293 217 .0734 1689 | 296 217 .0758 1690 | 297 217 .087 1691 | 299 217 .0784 1692 | 301 217 .0807 1693 | 306 217 .0963 1694 | 307 217 .095 1695 | 310 217 .1017 1696 | 311 217 .1039 1697 | 321 217 .1991 1698 | 221 218 .0062 1699 | 230 218 .0099 1700 | 235 218 .0118 1701 | 249 218 .0215 1702 | 255 218 .0327 1703 | 261 218 .0923 1704 | 299 218 .0735 1705 | 230 219 .1744 1706 | 232 219 .1286 1707 | 233 219 .0446 1708 | 245 219 .0243 1709 | 246 219 .023 1710 | 248 219 .0263 1711 | 250 219 .0227 1712 | 253 219 .0268 1713 | 255 219 .1582 1714 | 258 219 .0299 1715 | 261 219 .0972 1716 | 263 219 .0336 1717 | 274 219 .0436 1718 | 276 219 .0606 1719 | 293 219 .1165 1720 | 297 219 .1056 1721 | 299 219 .1873 1722 | 301 219 .1834 1723 | 313 219 .2595 1724 | 226 220 .0064 1725 | 261 220 .0331 1726 | 230 221 .0127 1727 | 237 221 .0197 1728 | 243 221 .0159 1729 | 249 221 .0219 1730 | 252 221 .021 1731 | 255 221 .0357 1732 | 260 221 .0291 1733 | 261 221 .0972 1734 | 273 221 .0379 1735 | 284 221 .0535 1736 | 298 221 .065 1737 | 299 221 .0722 1738 | 300 221 .0749 1739 | 301 221 .0779 1740 | 305 221 .0832 1741 | 306 221 .088 1742 | 307 221 .0905 1743 | 310 221 .094 1744 | 311 221 .0966 1745 | 321 221 .1789 1746 | 322 221 .1818 1747 | 325 221 .1877 1748 | 230 222 .0165 1749 | 232 222 .0314 1750 | 255 222 .0208 1751 | 261 222 .072 1752 | 230 223 .0084 1753 | 255 223 .0196 1754 | 242 224 .0104 1755 | 261 224 .1148 1756 | 232 225 .0389 1757 | 255 225 .0691 1758 | 258 225 .0757 1759 | 261 225 .0241 1760 | 262 225 .0246 1761 | 293 225 .053 1762 | 296 225 .0562 1763 | 261 226 .0271 1764 | 230 227 .017 1765 | 236 228 .0033 1766 | 233 229 .0251 1767 | 254 229 .0148 1768 | 261 229 .0325 1769 | 262 229 .0335 1770 | 274 229 .0555 1771 | 232 230 .046 1772 | 234 230 .0217 1773 | 235 230 .0108 1774 | 237 230 .0071 1775 | 238 230 .0174 1776 | 239 230 .0576 1777 | 240 230 .0301 1778 | 243 230 .0179 1779 | 248 230 .1911 1780 | 249 230 .0122 1781 | 252 230 .0179 1782 | 255 230 .0232 1783 | 256 230 .0335 1784 | 258 230 .1591 1785 | 259 230 .0184 1786 | 260 230 .0226 1787 | 261 230 .0849 1788 | 263 230 .186 1789 | 272 230 .0393 1790 | 273 230 .0295 1791 | 281 230 .0569 1792 | 284 230 .0453 1793 | 286 230 .0553 1794 | 287 230 .0491 1795 | 288 230 .0974 1796 | 292 
230 .0688 1797 | 293 230 .0887 1798 | 295 230 .0532 1799 | 297 230 .1044 1800 | 298 230 .0577 1801 | 299 230 .0648 1802 | 300 230 .068 1803 | 301 230 .0697 1804 | 305 230 .0751 1805 | 306 230 .0816 1806 | 307 230 .083 1807 | 310 230 .0875 1808 | 311 230 .0901 1809 | 321 230 .1774 1810 | 258 231 .0164 1811 | 239 232 .0118 1812 | 240 232 .0169 1813 | 248 232 .1451 1814 | 255 232 .0313 1815 | 256 232 .0218 1816 | 258 232 .1133 1817 | 261 232 .0414 1818 | 283 232 .0435 1819 | 285 232 .047 1820 | 288 232 .06 1821 | 292 232 .0483 1822 | 296 232 .0582 1823 | 297 232 .0681 1824 | 299 232 .0771 1825 | 301 232 .0775 1826 | 307 232 .0915 1827 | 310 232 .0994 1828 | 311 232 .1013 1829 | 248 233 .0609 1830 | 254 233 .0277 1831 | 255 233 .1138 1832 | 258 233 .0315 1833 | 261 233 .0529 1834 | 262 233 .0539 1835 | 263 233 .0581 1836 | 274 233 .0355 1837 | 275 233 .0372 1838 | 276 233 .0309 1839 | 293 233 .075 1840 | 296 233 .0773 1841 | 255 234 .0139 1842 | 255 237 .0165 1843 | 255 239 .041 1844 | 261 239 .0302 1845 | 262 239 .0297 1846 | 255 240 .0152 1847 | 256 240 .0103 1848 | 261 240 .0549 1849 | 261 241 .0175 1850 | 245 242 .0078 1851 | 248 242 .0087 1852 | 251 242 .0104 1853 | 261 242 .117 1854 | 248 244 .0048 1855 | 246 245 .0041 1856 | 248 245 .0025 1857 | 253 245 .0056 1858 | 258 245 .0331 1859 | 261 245 .1093 1860 | 263 245 .0152 1861 | 250 246 .0059 1862 | 253 246 .0039 1863 | 255 246 .1689 1864 | 258 246 .029 1865 | 261 246 .1053 1866 | 263 246 .0128 1867 | 293 246 .1198 1868 | 313 246 .2387 1869 | 261 247 .0128 1870 | 250 248 .0097 1871 | 253 248 .0037 1872 | 255 248 .173 1873 | 258 248 .0329 1874 | 261 248 .1093 1875 | 263 248 .0131 1876 | 274 248 .0418 1877 | 276 248 .0645 1878 | 292 248 .1482 1879 | 293 248 .1233 1880 | 299 248 .1961 1881 | 301 248 .1914 1882 | 311 248 .2091 1883 | 313 248 .2347 1884 | 316 248 .2303 1885 | 318 248 .2349 1886 | 331 248 .5326 1887 | 255 249 .0171 1888 | 259 249 .007 1889 | 260 249 .0113 1890 | 258 250 .0232 1891 | 261 250 .0996 1892 | 263 250 .0109 1893 | 274 250 .0327 1894 | 258 251 .0314 1895 | 261 251 .1079 1896 | 255 253 .1703 1897 | 258 253 .0299 1898 | 261 253 .1064 1899 | 263 253 .0097 1900 | 293 253 .12 1901 | 261 254 .0255 1902 | 262 254 .0265 1903 | 276 254 .0292 1904 | 288 254 .0384 1905 | 256 255 .0118 1906 | 258 255 .1405 1907 | 259 255 .0128 1908 | 260 255 .0234 1909 | 261 255 .0646 1910 | 263 255 .167 1911 | 265 255 .0111 1912 | 266 255 .0403 1913 | 267 255 .0493 1914 | 271 255 .0315 1915 | 272 255 .0163 1916 | 273 255 .0219 1917 | 276 255 .1132 1918 | 281 255 .0342 1919 | 283 255 .0453 1920 | 284 255 .0332 1921 | 286 255 .0334 1922 | 287 255 .031 1923 | 288 255 .0748 1924 | 289 255 .0334 1925 | 292 255 .0458 1926 | 293 255 .0657 1927 | 295 255 .0394 1928 | 296 255 .0672 1929 | 297 255 .0816 1930 | 298 255 .046 1931 | 299 255 .0523 1932 | 300 255 .0564 1933 | 301 255 .0551 1934 | 305 255 .0605 1935 | 306 255 .0701 1936 | 307 255 .0692 1937 | 310 255 .0756 1938 | 311 255 .0779 1939 | 313 255 .3924 1940 | 321 255 .1729 1941 | 322 255 .1767 1942 | 261 256 .0528 1943 | 267 256 .0377 1944 | 271 256 .0208 1945 | 272 256 .0122 1946 | 292 256 .0385 1947 | 296 256 .0574 1948 | 299 256 .0563 1949 | 261 257 .0164 1950 | 261 258 .0765 1951 | 263 258 .0271 1952 | 264 258 .0151 1953 | 274 258 .0137 1954 | 276 258 .0327 1955 | 288 258 .0788 1956 | 293 258 .0911 1957 | 296 258 .0926 1958 | 297 258 .0784 1959 | 299 258 .1635 1960 | 313 258 .2609 1961 | 273 259 .0125 1962 | 273 260 .0095 1963 | 284 260 .0245 1964 | 263 261 .1026 1965 | 266 261 .0257 1966 | 267 
261 .0169 1967 | 268 261 .0143 1968 | 269 261 .0127 1969 | 270 261 .0097 1970 | 271 261 .0358 1971 | 273 261 .0809 1972 | 274 261 .0712 1973 | 275 261 .0278 1974 | 276 261 .0487 1975 | 277 261 .0122 1976 | 278 261 .028 1977 | 279 261 .0229 1978 | 280 261 .0176 1979 | 281 261 .0495 1980 | 282 261 .0223 1981 | 283 261 .0375 1982 | 284 261 .0815 1983 | 286 261 .0553 1984 | 288 261 .025 1985 | 290 261 .0363 1986 | 291 261 .0328 1987 | 292 261 .0442 1988 | 293 261 .0291 1989 | 294 261 .032 1990 | 296 261 .0323 1991 | 297 261 .0329 1992 | 299 261 .0907 1993 | 301 261 .0876 1994 | 303 261 .049 1995 | 304 261 .0526 1996 | 305 261 .0905 1997 | 306 261 .1047 1998 | 307 261 .0987 1999 | 308 261 .0638 2000 | 309 261 .0645 2001 | 310 261 .1079 2002 | 311 261 .109 2003 | 313 261 .3292 2004 | 321 261 .2098 2005 | 275 262 .0286 2006 | 276 262 .0495 2007 | 288 262 .0248 2008 | 292 262 .0433 2009 | 293 262 .0285 2010 | 296 262 .0316 2011 | 297 262 .0327 2012 | 276 263 .0557 2013 | 293 263 .1143 2014 | 311 263 .1993 2015 | 313 263 .2344 2016 | 267 266 .0091 2017 | 271 266 .0102 2018 | 272 266 .0289 2019 | 281 266 .0261 2020 | 286 266 .0315 2021 | 271 267 .0191 2022 | 281 267 .033 2023 | 283 267 .0228 2024 | 270 269 .0035 2025 | 281 271 .018 2026 | 283 271 .0177 2027 | 283 272 .0295 2028 | 284 273 .0158 2029 | 299 273 .0353 2030 | 322 273 .1556 2031 | 325 273 .1612 2032 | 293 274 .082 2033 | 276 275 .0213 2034 | 288 275 .0278 2035 | 296 275 .0415 2036 | 288 276 .0465 2037 | 293 276 .0588 2038 | 296 276 .0601 2039 | 297 276 .0458 2040 | 283 281 .0149 2041 | 286 281 .0058 2042 | 293 281 .0368 2043 | 293 283 .022 2044 | 286 284 .028 2045 | 293 284 .0698 2046 | 298 284 .013 2047 | 299 284 .0198 2048 | 300 284 .0235 2049 | 306 284 .0373 2050 | 311 284 .0455 2051 | 293 286 .0419 2052 | 299 286 .0357 2053 | 301 286 .0337 2054 | 299 287 .0244 2055 | 301 287 .025 2056 | 311 287 .0486 2057 | 293 288 .0124 2058 | 296 288 .0138 2059 | 297 288 .0083 2060 | 308 288 .0389 2061 | 293 290 .0172 2062 | 293 292 .0259 2063 | 296 292 .0258 2064 | 297 292 .042 2065 | 299 292 .048 2066 | 301 292 .0439 2067 | 310 292 .0637 2068 | 311 292 .0647 2069 | 296 293 .0032 2070 | 297 293 .0165 2071 | 299 293 .073 2072 | 301 293 .0681 2073 | 303 293 .0238 2074 | 305 293 .0696 2075 | 306 293 .0838 2076 | 307 293 .0766 2077 | 308 293 .0378 2078 | 309 293 .0393 2079 | 310 293 .086 2080 | 311 293 .0865 2081 | 297 296 .0162 2082 | 303 296 .0212 2083 | 308 296 .0349 2084 | 311 296 .0848 2085 | 308 297 .0319 2086 | 299 298 .0072 2087 | 300 298 .0105 2088 | 301 299 .0076 2089 | 305 299 .0118 2090 | 306 299 .0178 2091 | 307 299 .0182 2092 | 310 299 .0234 2093 | 311 299 .0257 2094 | 321 299 .1232 2095 | 322 299 .1274 2096 | 325 299 .1326 2097 | 301 300 .0097 2098 | 307 300 .016 2099 | 305 301 .0055 2100 | 306 301 .0175 2101 | 307 301 .0142 2102 | 310 301 .0219 2103 | 311 301 .0237 2104 | 321 301 .124 2105 | 310 302 .0218 2106 | 307 305 .0092 2107 | 311 305 .0192 2108 | 307 306 .0086 2109 | 310 306 .0059 2110 | 311 306 .0086 2111 | 310 307 .0094 2112 | 311 307 .0103 2113 | 311 310 .0027 2114 | 320 311 .0968 2115 | 321 311 .1011 2116 | 322 311 .1057 2117 | 324 311 .1021 2118 | 325 311 .1105 2119 | 313 312 .0096 2120 | 316 312 .018 2121 | 318 312 .0272 2122 | 314 313 .0045 2123 | 315 313 .0072 2124 | 316 313 .0086 2125 | 317 313 .0071 2126 | 318 313 .0179 2127 | 319 313 .0212 2128 | 326 313 .0736 2129 | 329 313 .3011 2130 | 331 313 .3468 2131 | 316 314 .0042 2132 | 317 314 .0036 2133 | 316 315 .0013 2134 | 317 315 .0022 2135 | 317 316 .0028 2136 | 318 316 
.0113 2137 | 319 316 .0133 2138 | 319 318 .0051 2139 | 324 320 .0054 2140 | 322 321 .0054 2141 | 323 321 .0061 2142 | 324 321 .005 2143 | 325 321 .0094 2144 | 325 322 .0061 2145 | 328 327 .0093 2146 | 329 327 .0163 2147 | 330 327 .0935 2148 | 332 327 .0013 2149 | 329 328 .007 2150 | 330 329 .0784 2151 |
--------------------------------------------------------------------------------
/netgan/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 |
6 | class Generator(nn.Module):
7 |
8 |     def __init__(self, H_inputs, H, z_dim, N, rw_len, temp):
9 |         '''
10 |         H_inputs: input dimension
11 |         H: hidden dimension
12 |         z_dim: latent dimension
13 |         N: number of nodes (needed for the up and down projection)
14 |         rw_len: random walk length (number of LSTM steps)
15 |         temp: temperature for the gumbel softmax
16 |         '''
17 |         super(Generator, self).__init__()
18 |         self.intermediate = nn.Linear(z_dim, H).type(torch.float64)
19 |         torch.nn.init.xavier_uniform_(self.intermediate.weight)
20 |         torch.nn.init.zeros_(self.intermediate.bias)
21 |         self.c_up = nn.Linear(H, H).type(torch.float64)
22 |         torch.nn.init.xavier_uniform_(self.c_up.weight)
23 |         torch.nn.init.zeros_(self.c_up.bias)
24 |         self.h_up = nn.Linear(H, H).type(torch.float64)
25 |         torch.nn.init.xavier_uniform_(self.h_up.weight)
26 |         torch.nn.init.zeros_(self.h_up.bias)
27 |         self.lstmcell = LSTMCell(H_inputs, H).type(torch.float64)
28 |
29 |
30 |         self.W_up = nn.Linear(H, N).type(torch.float64)
31 |         self.W_down = nn.Linear(N, H_inputs, bias=False).type(torch.float64)
32 |         self.rw_len = rw_len
33 |         self.temp = temp
34 |         self.H = H
35 |         self.latent_dim = z_dim
36 |         self.N = N
37 |         self.H_inputs = H_inputs
38 |
39 |     def forward(self, latent, inputs, device='cuda'):  # h_down = input_zeros
40 |         intermediate = torch.tanh(self.intermediate(latent))
41 |         hc = (torch.tanh(self.h_up(intermediate)), torch.tanh(self.c_up(intermediate)))
42 |         out = []  # gumbel_noise = uniform noise [0, 1]
43 |         for i in range(self.rw_len):
44 |             hh, cc = self.lstmcell(inputs, hc)
45 |             hc = (hh, cc)
46 |             h_up = self.W_up(hh)  # blow up to dimension N using W_up
47 |             h_sample = self.gumbel_softmax_sample(h_up, self.temp, device)
48 |             inputs = self.W_down(h_sample)  # back to dimension H (in netgan they reduce the dimension to d)
49 |             out.append(h_sample)
50 |         return torch.stack(out, dim=1)
51 |
52 |     def sample_latent(self, num_samples, device):
53 |         return torch.randn((num_samples, self.latent_dim)).type(torch.float64).to(device)
54 |
55 |
56 |     def sample(self, num_samples, device):
57 |         noise = self.sample_latent(num_samples, device)
58 |         input_zeros = self.init_hidden(num_samples).contiguous().type(torch.float64).to(device)
59 |         generated_data = self(noise, input_zeros, device)
60 |         return generated_data
61 |
62 |     def sample_discrete(self, num_samples, device):
63 |         with torch.no_grad():
64 |             proba = self.sample(num_samples, device)
65 |         return np.argmax(proba.cpu().numpy(), axis=2)
66 |
67 |     def sample_gumbel(self, logits, eps=1e-20):
68 |         U = torch.rand(logits.shape, dtype=torch.float64)
69 |         return -torch.log(-torch.log(U + eps) + eps)
70 |
71 |     def gumbel_softmax_sample(self, logits, temperature, device, hard=True):
72 |         """ Draw a sample from the Gumbel-Softmax distribution"""
73 |         gumbel = self.sample_gumbel(logits).type(torch.float64).to(device)
74 |         y = logits + gumbel
75 |         y = torch.nn.functional.softmax(y / temperature, dim=1)
76 |         if hard:
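            # Note: straight-through estimator. The hard one-hot sample is used
            # in the forward pass, while (y_hard - y).detach() + y below routes
            # gradients through the soft sample y in the backward pass.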
77 |             y_hard = torch.max(y, 1, keepdim=True)[0].eq(y).type(torch.float64).to(device)
78 |             y = (y_hard - y).detach() + y
79 |         return y
80 |
81 |     def init_hidden(self, batch_size):
82 |         weight = next(self.parameters()).data
83 |         return weight.new(batch_size, self.H_inputs).zero_().type(torch.float64)
84 |
85 |     #def reset_weights(self):
86 |     #    import h5py
87 |     #    weights = h5py.File(r'C:\Users\Data Miner\PycharmProjects\Master_Projekt4\weights.h5', 'r')
88 |     #    self.intermediate.weight = torch.nn.Parameter(torch.tensor(np.array(weights.get('intermediate')).T).type(torch.float64))
89 |     #    self.intermediate.bias = torch.nn.Parameter(torch.tensor(np.array(weights.get('intermediate_bias'))).type(torch.float64))
90 |     #    self.c_up.weight = torch.nn.Parameter(torch.tensor(np.array(weights.get('c')).T).type(torch.float64))
91 |     #    self.c_up.bias = torch.nn.Parameter(torch.tensor(np.array(weights.get('c_bias'))).type(torch.float64))
92 |     #    self.h_up.weight = torch.nn.Parameter(torch.tensor(np.array(weights.get('h')).T).type(torch.float64))
93 |     #    self.h_up.bias = torch.nn.Parameter(torch.tensor(np.array(weights.get('h_bias'))).type(torch.float64))
94 |     #    self.lstmcell.cell.weight = torch.nn.Parameter(torch.tensor(np.array(weights.get('generator_lstm')).T).type(torch.float64))
95 |     #    self.lstmcell.cell.bias = torch.nn.Parameter(torch.tensor(np.array(weights.get('generator_lstm_bias'))).type(torch.float64))
96 |     #    self.W_up.weight = torch.nn.Parameter(torch.tensor(np.array(weights.get('W_up_generator')).T).type(torch.float64))
97 |     #    self.W_up.bias = torch.nn.Parameter(torch.tensor(np.array(weights.get('W_up_generator_bias'))).type(torch.float64))
98 |     #    self.W_down.weight = torch.nn.Parameter(torch.tensor(np.array(weights.get('W_down_generator')).T).type(torch.float64))
99 |
100 |
101 | class Discriminator(nn.Module):
102 |     def __init__(self, H_inputs, H, N, rw_len):
103 |         '''
104 |         H_inputs: input dimension
105 |         H: hidden dimension
106 |         N: number of nodes (needed for the up and down projection)
107 |         rw_len: random walk length (number of LSTM steps)
108 |         '''
109 |         super(Discriminator, self).__init__()
110 |         self.W_down = nn.Linear(N, H_inputs, bias=False).type(torch.float64)
111 |         torch.nn.init.xavier_uniform_(self.W_down.weight)
112 |         self.lstmcell = LSTMCell(H_inputs, H).type(torch.float64)
113 |         self.lin_out = nn.Linear(H, 1, bias=True).type(torch.float64)
114 |         torch.nn.init.xavier_uniform_(self.lin_out.weight)
115 |         torch.nn.init.zeros_(self.lin_out.bias)
116 |         self.H = H
117 |         self.N = N
118 |         self.rw_len = rw_len
119 |         self.H_inputs = H_inputs
120 |
121 |     def forward(self, x):
122 |         x = x.view(-1, self.N)
123 |         xa = self.W_down(x)
124 |         xa = xa.view(-1, self.rw_len, self.H_inputs)
125 |         hc = self.init_hidden(xa.size(0))
126 |         for i in range(self.rw_len):
127 |             hc = self.lstmcell(xa[:, i, :], hc)
128 |         out = hc[0]
129 |         pred = self.lin_out(out)
130 |         return pred
131 |
132 |     def init_inputs(self, num_samples):
133 |         weight = next(self.parameters()).data
134 |         return weight.new(num_samples, self.H_inputs).zero_().type(torch.float64)
135 |
136 |     def init_hidden(self, num_samples):
137 |         weight = next(self.parameters()).data
138 |         return (weight.new(num_samples, self.H).zero_().contiguous().type(torch.float64), weight.new(num_samples, self.H).zero_().contiguous().type(torch.float64))
139 |
140 |     #def reset_weights(self):
141 |     #    import h5py
142 |     #    weights = h5py.File(r'C:\Users\Data Miner\PycharmProjects\Master_Projekt4\weights.h5', 'r')
143 |     #    self.W_down.weight = torch.nn.Parameter(torch.tensor(np.array(weights.get('W_down_discriminator')).T).type(torch.float64))
144 |     #    self.lin_out.weight = torch.nn.Parameter(torch.tensor(np.array(weights.get('discriminator_out')).T).type(torch.float64))
145 |     #    self.lin_out.bias = torch.nn.Parameter(torch.tensor(np.array(weights.get('discriminator_out_bias'))).type(torch.float64))
146 |     #    self.lstmcell.cell.weight = torch.nn.Parameter(torch.tensor(np.array(weights.get('discriminator_lstm')).T).type(torch.float64))
147 |     #    self.lstmcell.cell.bias = torch.nn.Parameter(torch.tensor(np.array(weights.get('discriminator_lstm_bias'))).type(torch.float64))
148 |
149 | class LSTMCell(nn.Module):
150 |     def __init__(self, input_size, hidden_size):
151 |         super(LSTMCell, self).__init__()
152 |         self.input_size = input_size
153 |         self.hidden_size = hidden_size
154 |
155 |         self.cell = nn.Linear(input_size + hidden_size, 4 * hidden_size, bias=True)
156 |         torch.nn.init.xavier_uniform_(self.cell.weight)
157 |         torch.nn.init.zeros_(self.cell.bias)
158 |
159 |     def forward(self, x, hidden):
160 |         hx, cx = hidden
161 |         gates = torch.cat((x, hx), dim=1)
162 |         gates = self.cell(gates)
163 |
164 |         ingate, cellgate, forgetgate, outgate = gates.chunk(4, 1)
165 |
166 |         ingate = torch.sigmoid(ingate)
167 |         forgetgate = torch.sigmoid(torch.add(forgetgate, 1.0))
168 |         cellgate = torch.tanh(cellgate)
169 |         outgate = torch.sigmoid(outgate)
170 |         cy = torch.mul(cx, forgetgate) + torch.mul(ingate, cellgate)
171 |         hy = torch.mul(outgate, torch.tanh(cy))
172 |         return (hy, cy)
--------------------------------------------------------------------------------
/netgan/training.py:
--------------------------------------------------------------------------------
1 | from models import Generator, Discriminator
2 | import utils
3 |
4 | import numpy as np
5 | import scipy.sparse as sp
6 | from sklearn.metrics import roc_auc_score, average_precision_score
7 | import math
8 |
9 | import torch
10 | import torch.optim as optim
11 | from torch.nn.functional import one_hot
12 | from torch.autograd import grad
13 | #from torch.utils.tensorboard import SummaryWriter
14 | import time
15 | from joblib import Parallel, delayed
16 | import pdb
17 | from matplotlib import pyplot as plt
18 |
19 | class Trainer():
20 |     def __init__(self, graph, N, max_iterations=20000, rw_len=16, batch_size=128, H_gen=40, H_disc=30, H_inp=128, z_dim=16, lr=0.0003, n_critic=3, gp_weight=10.0, betas=(.5, .9),
21 |                  l2_penalty_disc=5e-5, l2_penalty_gen=1e-7, temp_start=5.0, temp_decay=1-5e-5, min_temp=0.5, val_share=0.1, test_share=0.05, seed=498164, set_ops=False):
22 |         """
23 |         Initialize NetGAN.
24 |         Parameters
25 |         ----------
26 |         graph: scipy_sparse_matrix
27 |                Adjacency matrix of the input graph.
28 |         N: int
29 |            Number of nodes in the graph to generate.
30 |         max_iterations: int, default: 20,000
31 |                         Maximum number of iterations if the stopping criterion is not fulfilled.
32 |         rw_len: int
33 |                 Length of random walks to generate.
34 |         batch_size: int, default: 128
35 |                     The batch size.
36 |         H_gen: int, default: 40
37 |                The hidden size of the generator.
38 |         H_disc: int, default: 30
39 |                 The hidden size of the discriminator.
40 |         H_inp: int, default: 128
41 |                Input size of the LSTM cells.
42 |         z_dim: int, default: 16
43 |                The dimension of the random noise that is used as input to the generator.
44 |         lr: float, default: 0.0003
45 |             The learning rate, used for the generator as well as for the discriminator.
46 |         n_critic: int, default: 3
47 |                   The number of discriminator iterations per generator training iteration.
48 | gp_weight: float, default: 10 49 | Gradient penalty weight for the Wasserstein GAN. See the paper 'Improved Training of Wasserstein GANs' for more details. 50 | betas: tuple, default: (.5, .9) 51 | Decay rates of the Adam optimizers. 52 | l2_penalty_gen: float, default: 1e-7 53 | L2 penalty on the generator weights. 54 | l2_penalty_disc: float, default: 5e-5 55 | L2 penalty on the discriminator weights. 56 | temp_start: float, default: 5.0 57 | The initial temperature for the Gumbel softmax. 58 | temp_decay: float, default: 1-5e-5 59 | After each evaluation, the current temperature is updated as 60 | current_temp := max(temperature_decay*current_temp, min_temperature) 61 | min_temp: float, default: 0.5 62 | The minimal temperature for the Gumbel softmax. 63 | val_share: float, default: 0.1 64 | Percentage of validation edges. 65 | test_share: float, default: 0.05 66 | Percentage of test edges. 67 | seed: int, default: 498164 68 | Seed for numpy.random. It is used for splitting the graph into train, validation and test sets. 69 | set_ops: bool, default: False 70 | Whether utils.train_val_test_split_adjacency samples the non-edges via set operations (faster) instead of a while loop. 71 | """ 72 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 73 | self.max_iterations = max_iterations 74 | self.rw_len = rw_len 75 | self.batch_size = batch_size 76 | self.N = N 77 | self.generator = Generator(H_inputs=H_inp, H=H_gen, N=N, rw_len=rw_len, z_dim=z_dim, temp=temp_start).to(self.device) 78 | self.discriminator = Discriminator(H_inputs=H_inp, H=H_disc, N=N, rw_len=rw_len).to(self.device) 79 | self.G_optimizer = optim.Adam(self.generator.parameters(), lr=lr, betas=betas) 80 | self.D_optimizer = optim.Adam(self.discriminator.parameters(), lr=lr, betas=betas) 81 | self.n_critic = n_critic 82 | self.gp_weight = gp_weight 83 | self.l2_penalty_disc = l2_penalty_disc 84 | self.l2_penalty_gen = l2_penalty_gen 85 | self.temp_start = temp_start 86 | self.temp_decay = temp_decay 87 | self.min_temp = min_temp 88 | 89 | self.graph = graph 90 | #self.train_ones, self.val_ones, self.val_zeros, self.test_ones, self.test_zeros = stuff[0], stuff[1], stuff[2], stuff[3], stuff[4] 91 | #self.train_graph = stuff[5] 92 | #self.walker = stuff[6] 93 | self.train_ones, self.val_ones, self.val_zeros, self.test_ones, self.test_zeros = utils.train_val_test_split_adjacency(graph, val_share, test_share, seed, undirected=True, connected=True, asserts=True, set_ops=set_ops) 94 | self.train_graph = sp.coo_matrix((np.ones(len(self.train_ones)), (self.train_ones[:, 0], self.train_ones[:, 1]))).tocsr() 95 | assert (self.train_graph.toarray() == self.train_graph.toarray().T).all() 96 | self.walker = utils.RandomWalker(self.train_graph, rw_len, p=1, q=1, batch_size=batch_size) 97 | self.eo = [] 98 | self.critic_loss = [] 99 | self.generator_loss = [] 100 | self.avp = [] 101 | self.roc_auc = [] 102 | self.best_performance = 0.0 103 | self.running = True 104 | 105 | 106 | def l2_regularization_G(self, G): 107 | # L2 regularization for the generator. All generator weights, including W_down, enter the penalty.
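# Note: the eleven terms below spell out ordinary weight decay per parameter
# tensor, l2 = l2_penalty_gen * sum_w ||w||^2 / 2. A minimal equivalent sketch,
# assuming (as holds here) that these tensors are exactly G's parameters:
#     l2 = self.l2_penalty_gen * sum((p ** 2).sum() / 2 for p in G.parameters())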
108 | l2_1 = torch.sum(torch.cat([x.view(-1) for x in G.W_down.weight]) ** 2 / 2) 109 | l2_2 = torch.sum(torch.cat([x.view(-1) for x in G.W_up.weight]) ** 2 / 2) 110 | l2_3 = torch.sum(torch.cat([x.view(-1) for x in G.W_up.bias]) ** 2 / 2) 111 | l2_4 = torch.sum(torch.cat([x.view(-1) for x in G.intermediate.weight]) ** 2 / 2) 112 | l2_5 = torch.sum(torch.cat([x.view(-1) for x in G.intermediate.bias]) ** 2 / 2) 113 | l2_6 = torch.sum(torch.cat([x.view(-1) for x in G.h_up.weight]) ** 2 / 2) 114 | l2_7 = torch.sum(torch.cat([x.view(-1) for x in G.h_up.bias]) ** 2 / 2) 115 | l2_8 = torch.sum(torch.cat([x.view(-1) for x in G.c_up.weight]) ** 2 / 2) 116 | l2_9 = torch.sum(torch.cat([x.view(-1) for x in G.c_up.bias]) ** 2 / 2) 117 | l2_10 = torch.sum(torch.cat([x.view(-1) for x in G.lstmcell.cell.weight]) ** 2 / 2) 118 | l2_11 = torch.sum(torch.cat([x.view(-1) for x in G.lstmcell.cell.bias]) ** 2 / 2) 119 | l2 = self.l2_penalty_gen * (l2_1 + l2_2 + l2_3 + l2_4 + l2_5 + l2_6 + l2_7 + l2_8 + l2_9 + l2_10 + l2_11) 120 | return l2 121 | 122 | def l2_regularization_D(self, D): 123 | # L2 regularization for the discriminator. All discriminator weights, including W_down, enter the penalty. 124 | l2_1 = torch.sum(torch.cat([x.view(-1) for x in D.W_down.weight]) ** 2 / 2) 125 | l2_2 = torch.sum(torch.cat([x.view(-1) for x in D.lstmcell.cell.weight]) ** 2 / 2) 126 | l2_3 = torch.sum(torch.cat([x.view(-1) for x in D.lstmcell.cell.bias]) ** 2 / 2) 127 | l2_4 = torch.sum(torch.cat([x.view(-1) for x in D.lin_out.weight]) ** 2 / 2) 128 | l2_5 = torch.sum(torch.cat([x.view(-1) for x in D.lin_out.bias]) ** 2 / 2) 129 | l2 = self.l2_penalty_disc * (l2_1 + l2_2 + l2_3 + l2_4 + l2_5) 130 | return l2 131 | 132 | def calc_gp(self, fake_inputs, real_inputs): 133 | # calculate the gradient penalty. For more details see the paper 'Improved Training of Wasserstein GANs'.
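# Sketch of what follows: draw alpha ~ U[0, 1] per sample, build the interpolates
#     x_hat = x_real + alpha * (x_fake - x_real),
# and penalize the critic's gradient norm on them,
#     gp = gp_weight * mean((||dD(x_hat)/dx_hat||_2 - 1)^2),
# where the 2-norm runs over the walk-step and node dimensions (dim=[1, 2]).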
134 | alpha = torch.rand((self.batch_size, 1, 1), dtype=torch.float64).to(self.device) 135 | differences = fake_inputs - real_inputs 136 | interpolates = real_inputs + alpha * differences 137 | 138 | y_pred_interpolates = self.discriminator(interpolates) 139 | gradients = grad(outputs=y_pred_interpolates, inputs=interpolates, grad_outputs=torch.ones_like(y_pred_interpolates), create_graph=True, retain_graph=True)[0] 140 | slopes = torch.sqrt(torch.sum(gradients ** 2, dim=[1, 2])) 141 | gradient_penalty = torch.mean((slopes - 1) ** 2) 142 | gradient_penalty = gradient_penalty * self.gp_weight 143 | return gradient_penalty 144 | 145 | def critic_train_iteration(self): 146 | self.D_optimizer.zero_grad() 147 | # create fake and real inputs 148 | fake_inputs = self.generator.sample(self.batch_size, self.device) 149 | real_inputs = one_hot(torch.tensor(next(self.walker.walk())), num_classes=self.N).type(torch.float64).to(self.device) 150 | 151 | y_pred_fake = self.discriminator(fake_inputs) 152 | y_pred_real = self.discriminator(real_inputs) 153 | gp = self.calc_gp(fake_inputs, real_inputs) # gradient penalty 154 | 155 | disc_cost = torch.mean(y_pred_fake) - torch.mean(y_pred_real) + gp + self.l2_regularization_D(self.discriminator) 156 | disc_cost.backward() 157 | self.D_optimizer.step() 158 | return disc_cost.item() 159 | 160 | def generator_train_iteration(self): 161 | self.generator.train() 162 | self.G_optimizer.zero_grad() 163 | fake_inputs = self.generator.sample(self.batch_size, self.device) 164 | 165 | y_pred_fake = self.discriminator(fake_inputs) 166 | gen_cost = -torch.mean(y_pred_fake) + self.l2_regularization_G(self.generator) 167 | 168 | gen_cost.backward() 169 | self.G_optimizer.step() 170 | return gen_cost.item() 171 | 172 | def create_graph(self, num_samples, i, batch_size=1000, reset_weights=False): 173 | if reset_weights: 174 | self.generator.reset_weights() 175 | self.generator.eval() 176 | 177 | self.generator.temp = 0.5 178 | samples = [] 179 | num_iterations = int(num_samples/batch_size) 180 | print("Number iterations: " + str(num_iterations)) 181 | for j in range(num_iterations): 182 | if(j%10 == 1): print(j) 183 | samples.append(self.generator.sample_discrete(batch_size, self.device)) 184 | samples = np.vstack(samples) 185 | gr = utils.score_matrix_from_random_walks(samples, self.N) 186 | gr = gr.tocsr() 187 | 188 | # Assemble a graph from the score matrix 189 | _graph = utils.graph_from_scores(gr, self.graph.sum()) 190 | # Compute edge overlap 191 | edge_overlap = utils.edge_overlap(self.graph.toarray(), _graph) 192 | edge_scores = np.append(gr[tuple(self.val_ones.T)].A1, gr[tuple(self.val_zeros.T)].A1) 193 | actual_labels_val = np.append(np.ones(len(self.val_ones)), np.zeros(len(self.val_zeros))) 194 | # Compute Validation ROC-AUC and average precision scores. 
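# edge_scores stacks the generated transition scores of the held-out validation
# edges (val_ones) and non-edges (val_zeros); scoring them against the 0/1 labels
# above treats link prediction as a binary classification problem, which is what
# the ROC-AUC and average-precision calls below measure.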
195 | self.roc_auc.append(roc_auc_score(actual_labels_val, edge_scores)) 196 | self.avp.append(average_precision_score(actual_labels_val, edge_scores)) 197 | self.eo.append(edge_overlap/self.graph.sum()) 198 | 199 | print('roc: {:.4f} avp: {:.4f} eo: {:.4f}'.format(self.roc_auc[-1], self.avp[-1], self.eo[-1])) 200 | self.generator.temp = np.maximum(self.temp_start * np.exp(-(1 - self.temp_decay) * i), self.min_temp) 201 | 202 | def create_transition_matrix(self, num_samples): # should be multiples of 1000 203 | self.generator.eval() 204 | samples = [] 205 | num_iterations = int(num_samples/1000)+1 206 | for j in range(num_iterations): 207 | if(j%10 == 1): print(j) 208 | samples.append(self.generator.sample_discrete(int(num_samples/1000), self.device)) 209 | samples = np.vstack(samples) 210 | gr = utils.score_matrix_from_random_walks(samples, self.N) 211 | gr = gr.tocsr() 212 | return gr 213 | 214 | def check_running(self, i): 215 | if (self.stopping_criterion == 'val'): 216 | if (self.roc_auc[-1] + self.avp[-1] > self.best_performance): 217 | self.best_performance = self.roc_auc[-1] + self.avp[-1] 218 | self.patience = self.max_patience 219 | else: 220 | self.patience -= 1 221 | if self.patience == 0: 222 | print('finished after {} iterations'.format(i)) 223 | self.running = False 224 | else: 225 | if (self.stopping_eo < self.eo[-1]): 226 | print('finished after {} iterations'.format(i)) 227 | self.running = False 228 | 229 | def initialize_validation_settings(self, stopping_criterion, stopping_eo, max_patience): 230 | self.stopping_criterion = stopping_criterion 231 | self.stopping_eo = stopping_eo # needed for 'eo' criterion # 232 | self.max_patience = max_patience # needed for 'val' criterion 233 | self.patience = max_patience # 234 | if (self.stopping_criterion == 'val'): 235 | print("**** Using VAL criterion for early stopping with max patience of: {}****".format(self.max_patience)) 236 | else: 237 | assert self.stopping_eo is not None, "stopping_eo is not a float" 238 | print("**** Using EO criterion of {} for early stopping".format(self.stopping_eo)) 239 | 240 | def plot_graph(self): 241 | if len(self.critic_loss) > 10: 242 | plt.plot(self.critic_loss[9::], label="Critic loss") 243 | plt.plot(self.generator_loss[9::], label="Generator loss") 244 | else: 245 | plt.plot(self.critic_loss, label="Critic loss") 246 | plt.plot(self.generator_loss, label="Generator loss") 247 | plt.legend() 248 | plt.show() 249 | 250 | def train(self, create_graph_every = 2000, plot_graph_every=500, num_samples_graph = 100000, stopping_criterion='val', max_patience=5, stopping_eo=None): 251 | """ 252 | create_graph_every: int, default: 2000 253 | Creates a graph from generated random walks every nth iteration. 254 | plot_graph_every: int, default: 500 255 | Plots the loss functions of the generator and discriminator every nth iteration. 256 | num_samples_graph: int, default: 100,000 257 | Number of random walks that will be created for the graphs. Higher values mean more precise evaluations but also more computational time. 258 | stopping_criterion: str, default: 'val' 259 | The stopping_criterion can be either 'val' or 'eo': 260 | 'val': Stops the optimization if there are no improvements after several iterations. --> defined by max_patience 261 | 'eo': Stops if the edge overlap exceeds a certain threshold. --> defined by stopping_eo 262 | max_patience: int, default: 5 263 | Maximum evaluation steps without improvement of the validation accuracy to tolerate. Only 264 | applies to the VAL criterion.
stopping_eo: float in (0,1], default: None 266 | Stops when the edge overlap exceeds this threshold. Will be used when stopping_criterion is 'eo'. 267 | """ 268 | self.initialize_validation_settings(stopping_criterion, stopping_eo, max_patience) 269 | starting_time = time.time() 270 | # Start Training 271 | for i in range(self.max_iterations): 272 | if(self.running): 273 | self.critic_loss.append(np.mean([self.critic_train_iteration() for _ in range(self.n_critic)])) 274 | self.generator_loss.append(self.generator_train_iteration()) 275 | if(i%10 ==1): print('iteration: {} critic: {:.6f} gen {:.6f}'.format(i, self.critic_loss[-1], self.generator_loss[-1])) 276 | if (i % create_graph_every == create_graph_every-1): 277 | self.create_graph(num_samples_graph, i) 278 | self.check_running(i) 279 | print('Took {} minutes so far..'.format((time.time() - starting_time)/60)) 280 | if plot_graph_every > 0 and (i + 1) % plot_graph_every == 0: self.plot_graph() 281 | -------------------------------------------------------------------------------- /netgan/utils.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | import scipy.sparse as sp 3 | import numpy as np 4 | from scipy.sparse.csgraph import connected_components, minimum_spanning_tree 5 | import warnings 6 | import pandas as pd 7 | from matplotlib import pyplot as plt 8 | import pdb 9 | from numba import jit 10 | 11 | def load_npz(file_name): 12 | """Load a SparseGraph from a Numpy binary file. 13 | 14 | Parameters 15 | ---------- 16 | file_name : str 17 | Name of the file to load. 18 | 19 | Returns 20 | ------- 21 | adj_matrix, attr_matrix, labels : tuple 22 | Sparse adjacency matrix, sparse attribute matrix (None if absent) and label array (None if absent). 23 | 24 | """ 25 | if not file_name.endswith('.npz'): 26 | file_name += '.npz' 27 | with np.load(file_name, allow_pickle=True) as loader: 28 | #pdb.set_trace() 29 | loader = dict(loader)['arr_0'].item() 30 | adj_matrix = sp.csr_matrix((loader['adj_data'], loader['adj_indices'], 31 | loader['adj_indptr']), shape=loader['adj_shape']) 32 | 33 | if 'attr_data' in loader: 34 | attr_matrix = sp.csr_matrix((loader['attr_data'], loader['attr_indices'], 35 | loader['attr_indptr']), shape=loader['attr_shape']) 36 | else: 37 | attr_matrix = None 38 | 39 | labels = loader.get('labels') 40 | 41 | return adj_matrix, attr_matrix, labels 42 | 43 | 44 | def largest_connected_components(adj, n_components=1): 45 | """Select the largest connected components in the graph. 46 | 47 | Parameters 48 | ---------- 49 | adj : scipy.sparse matrix 50 | Input adjacency matrix. 51 | n_components : int, default 1 52 | Number of largest connected components to keep. 53 | 54 | Returns 55 | ------- 56 | nodes_to_keep : list 57 | Indices of the nodes that belong to the n_components largest connected components. 58 | 59 | """ 60 | _, component_indices = connected_components(adj) 61 | component_sizes = np.bincount(component_indices) 62 | components_to_keep = np.argsort(component_sizes)[::-1][:n_components] # reverse order to sort descending 63 | nodes_to_keep = [ 64 | idx for (idx, component) in enumerate(component_indices) if component in components_to_keep 65 | 66 | 67 | ] 68 | print("Selecting {0} largest connected components".format(n_components)) 69 | return nodes_to_keep 70 | 71 | 72 | def edges_to_sparse(edges, N, values=None): 73 | """ 74 | Create a sparse adjacency matrix from an array of edge indices and (optionally) values.
75 | 76 | Parameters 77 | ---------- 78 | edges : array-like, shape [n_edges, 2] 79 | Edge indices 80 | N : int 81 | Number of nodes 82 | values : array_like, shape [n_edges] 83 | The values to put at the specified edge indices. Optional, default: np.ones(.) 84 | 85 | Returns 86 | ------- 87 | A : scipy.sparse.csr.csr_matrix 88 | Sparse adjacency matrix 89 | 90 | """ 91 | if values is None: 92 | values = np.ones(edges.shape[0]) 93 | 94 | return sp.coo_matrix((values, (edges[:, 0], edges[:, 1])), shape=(N, N)).tocsr() 95 | 96 | 97 | def train_val_test_split_adjacency(A, p_val=0.10, p_test=0.05, seed=0, neg_mul=1, 98 | every_node=True, connected=False, undirected=False, 99 | use_edge_cover=True, set_ops=True, asserts=False): 100 | """ 101 | Split the edges of the adjacency matrix into train, validation and test edges 102 | and randomly samples an equal amount of validation and test non-edges. 103 | 104 | Parameters 105 | ---------- 106 | A : scipy.sparse.spmatrix 107 | Sparse unweighted adjacency matrix 108 | p_val : float 109 | Percentage of validation edges. Default p_val=0.10 110 | p_test : float 111 | Percentage of test edges. Default p_test=0.05 112 | seed : int 113 | Seed for numpy.random. Default seed=0 114 | neg_mul : int 115 | What multiplicity of negative samples (non-edges) to have in the test/validation set 116 | w.r.t. the number of edges, i.e. len(non-edges) = neg_mul * len(edges). Default neg_mul=1 117 | every_node : bool 118 | Make sure each node appears at least once in the train set. Default every_node=True 119 | connected : bool 120 | Make sure the training graph is still connected after the split 121 | undirected : bool 122 | Whether to make the split undirected, that is if (i, j) is in val/test set then (j, i) is there as well. 123 | Default undirected=False 124 | use_edge_cover: bool 125 | Whether to use (approximate) edge_cover to find the minimum set of edges that cover every node. 126 | Only active when every_node=True. Default use_edge_cover=True 127 | set_ops : bool 128 | Whether to use set operations to construct the test zeros. Default set_ops=True 129 | Otherwise use a while loop. 130 | asserts : bool 131 | Unit test like checks. Default asserts=False 132 | 133 | Returns 134 | ------- 135 | train_ones : array-like, shape [n_train, 2] 136 | Indices of the train edges 137 | val_ones : array-like, shape [n_val, 2] 138 | Indices of the validation edges 139 | val_zeros : array-like, shape [n_val, 2] 140 | Indices of the validation non-edges 141 | test_ones : array-like, shape [n_test, 2] 142 | Indices of the test edges 143 | test_zeros : array-like, shape [n_test, 2] 144 | Indices of the test non-edges 145 | 146 | """ 147 | 148 | assert p_val + p_test > 0 149 | assert A.max() == 1 # no weights 150 | assert A.min() == 0 # no negative edges 151 | assert A.diagonal().sum() == 0 # no self-loops 152 | assert not np.any(A.sum(0).A1 + A.sum(1).A1 == 0) # no dangling nodes 153 | 154 | is_undirected = (A != A.T).nnz == 0 155 | 156 | if undirected: 157 | assert is_undirected # make sure the graph is undirected 158 | A = sp.tril(A).tocsr() # consider only the lower triangle 159 | A.eliminate_zeros() 160 | else: 161 | if is_undirected: 162 | warnings.warn('Graph appears to be undirected. 
Did you forget to set undirected=True?') 163 | 164 | np.random.seed(seed) 165 | 166 | E = A.nnz 167 | N = A.shape[0] 168 | 169 | s_train = int(E * (1 - p_val - p_test)) 170 | 171 | idx = np.arange(N) 172 | 173 | # hold some edges so each node appears at least once 174 | if every_node: 175 | if connected: 176 | assert connected_components(A)[0] == 1 # make sure original graph is connected 177 | A_hold = minimum_spanning_tree(A) 178 | else: 179 | A.eliminate_zeros() # makes sure A.tolil().rows contains only indices of non-zero elements 180 | d = A.sum(1).A1 181 | 182 | if use_edge_cover: 183 | hold_edges = np.array(list(nx.maximal_matching(nx.DiGraph(A)))) 184 | not_in_cover = np.array(list(set(range(N)).difference(hold_edges.flatten()))) 185 | 186 | # makes sure the training percentage is not smaller than N/E when every_node is set to True 187 | min_size = hold_edges.shape[0] + len(not_in_cover) 188 | if min_size > s_train: 189 | raise ValueError('Training percentage too low to guarantee every node. Min train size needed {:.2f}' 190 | .format(min_size / E)) 191 | 192 | d_nic = d[not_in_cover] 193 | 194 | hold_edges_d1 = np.column_stack((not_in_cover[d_nic > 0], 195 | np.row_stack(map(np.random.choice, 196 | A[not_in_cover[d_nic > 0]].tolil().rows)))) 197 | 198 | if np.any(d_nic == 0): 199 | hold_edges_d0 = np.column_stack((np.row_stack(map(np.random.choice, A[:, not_in_cover[d_nic == 0]].T.tolil().rows)), 200 | not_in_cover[d_nic == 0])) 201 | hold_edges = np.row_stack((hold_edges, hold_edges_d0, hold_edges_d1)) 202 | else: 203 | hold_edges = np.row_stack((hold_edges, hold_edges_d1)) 204 | 205 | else: 206 | # makes sure the training percentage is not smaller than N/E when every_node is set to True 207 | if N > s_train: 208 | raise ValueError('Training percentage too low to guarantee every node. 
Min train size needed {:.2f}' 209 | .format(N / E)) 210 | 211 | hold_edges_d1 = np.column_stack( 212 | (idx[d > 0], np.row_stack(map(np.random.choice, A[d > 0].tolil().rows)))) 213 | 214 | if np.any(d == 0): 215 | hold_edges_d0 = np.column_stack((np.row_stack(map(np.random.choice, A[:, d == 0].T.tolil().rows)), 216 | idx[d == 0])) 217 | hold_edges = np.row_stack((hold_edges_d0, hold_edges_d1)) 218 | else: 219 | hold_edges = hold_edges_d1 220 | 221 | if asserts: 222 | assert np.all(A[hold_edges[:, 0], hold_edges[:, 1]]) 223 | assert len(np.unique(hold_edges.flatten())) == N 224 | 225 | A_hold = edges_to_sparse(hold_edges, N) 226 | 227 | A_hold[A_hold > 1] = 1 228 | A_hold.eliminate_zeros() 229 | A_sample = A - A_hold 230 | 231 | s_train = s_train - A_hold.nnz 232 | else: 233 | A_sample = A 234 | 235 | idx_ones = np.random.permutation(A_sample.nnz) 236 | ones = np.column_stack(A_sample.nonzero()) 237 | train_ones = ones[idx_ones[:s_train]] 238 | test_ones = ones[idx_ones[s_train:]] 239 | 240 | # add the held-out edges back to the train set 241 | if every_node: 242 | train_ones = np.row_stack((train_ones, np.column_stack(A_hold.nonzero()))) 243 | 244 | n_test = len(test_ones) * neg_mul 245 | if set_ops: 246 | # generate slightly more uniformly random candidate pairs than needed and discard any that hit an edge 247 | # much faster compared to a while loop 248 | # in the future: estimate the multiplicity (currently fixed 1.3/2.3) based on A_obs.nnz 249 | if undirected: 250 | random_sample = np.random.randint(0, N, [int(2.3 * n_test), 2]) 251 | random_sample = random_sample[random_sample[:, 0] > random_sample[:, 1]] 252 | else: 253 | random_sample = np.random.randint(0, N, [int(1.3 * n_test), 2]) 254 | random_sample = random_sample[random_sample[:, 0] != random_sample[:, 1]] 255 | 256 | test_zeros = random_sample[A[random_sample[:, 0], random_sample[:, 1]].A1 == 0] 257 | test_zeros = np.row_stack(test_zeros)[:n_test] 258 | assert test_zeros.shape[0] == n_test 259 | else: 260 | test_zeros = [] 261 | while len(test_zeros) < n_test: 262 | i, j = np.random.randint(0, N, 2) 263 | if A[i, j] == 0 and (not undirected or i > j) and (i, j) not in test_zeros: 264 | test_zeros.append((i, j)) 265 | test_zeros = np.array(test_zeros) 266 | 267 | # split the test set into validation and test set 268 | s_val_ones = int(len(test_ones) * p_val / (p_val + p_test)) 269 | s_val_zeros = int(len(test_zeros) * p_val / (p_val + p_test)) 270 | 271 | val_ones = test_ones[:s_val_ones] 272 | test_ones = test_ones[s_val_ones:] 273 | 274 | val_zeros = test_zeros[:s_val_zeros] 275 | test_zeros = test_zeros[s_val_zeros:] 276 | 277 | if undirected: 278 | # put (j, i) edges for every (i, j) edge in the respective sets and restore the original A 279 | symmetrize = lambda x: np.row_stack((x, np.column_stack((x[:, 1], x[:, 0])))) 280 | train_ones = symmetrize(train_ones) 281 | val_ones = symmetrize(val_ones) 282 | val_zeros = symmetrize(val_zeros) 283 | test_ones = symmetrize(test_ones) 284 | test_zeros = symmetrize(test_zeros) 285 | A = A.maximum(A.T) 286 | 287 | if asserts: 288 | set_of_train_ones = set(map(tuple, train_ones)) 289 | assert train_ones.shape[0] + test_ones.shape[0] + val_ones.shape[0] == A.nnz 290 | assert (edges_to_sparse(np.row_stack((train_ones, test_ones, val_ones)), N) != A).nnz == 0 291 | assert set_of_train_ones.intersection(set(map(tuple, test_ones))) == set() 292 | assert set_of_train_ones.intersection(set(map(tuple, val_ones))) == set() 293 | assert set_of_train_ones.intersection(set(map(tuple, test_zeros))) == set() 294 | 
assert set_of_train_ones.intersection(set(map(tuple, val_zeros))) == set() 295 | assert len(set(map(tuple, test_zeros))) == len(test_ones) * neg_mul 296 | assert len(set(map(tuple, val_zeros))) == len(val_ones) * neg_mul 297 | assert not connected or connected_components(A_hold)[0] == 1 298 | assert not every_node or ((A_hold - A) > 0).sum() == 0 299 | 300 | 301 | return train_ones, val_ones, val_zeros, test_ones, test_zeros 302 | 303 | 304 | def score_matrix_from_random_walks(random_walks, N, symmetric=True): 305 | """ 306 | Compute the transition scores, i.e. how often a transition occurs, for all node pairs from 307 | the random walks provided. 308 | Parameters 309 | ---------- 310 | random_walks: np.array of shape (n_walks, rw_len) 311 | The input random walks (as node indices) to count the transitions in. 312 | N: int 313 | The number of nodes 314 | symmetric: bool, default: True 315 | Whether to symmetrize the resulting scores matrix. 316 | 317 | Returns 318 | ------- 319 | scores_matrix: sparse matrix, shape (N, N) 320 | Matrix whose entries (i,j) correspond to the number of times a transition from node i to j was 321 | observed in the input random walks. 322 | 323 | """ 324 | random_walks = np.array(random_walks) 325 | bigrams = np.array(list(zip(random_walks[:, :-1], random_walks[:, 1:]))) 326 | bigrams = np.transpose(bigrams, [0, 2, 1]) 327 | bigrams = bigrams.reshape([-1, 2]) 328 | if symmetric: 329 | bigrams = np.row_stack((bigrams, bigrams[:, ::-1])) 330 | 331 | mat = sp.coo_matrix((np.ones(bigrams.shape[0]), (bigrams[:, 0], bigrams[:, 1])), 332 | shape=[N, N]) 333 | return mat 334 | 335 | @jit(nopython=True) 336 | def random_walk(edges, node_ixs, rwlen, p=1, q=1, n_walks=1): 337 | N=len(node_ixs) 338 | 339 | walk = [] 340 | prev_nbs = None 341 | for w in range(n_walks): 342 | source_node = np.random.choice(N) 343 | walk.append(source_node) 344 | for it in range(rwlen-1): 345 | 346 | if walk[-1] == N-1: 347 | nbs = edges[node_ixs[walk[-1]]::,1] 348 | else: 349 | nbs = edges[node_ixs[walk[-1]]:node_ixs[walk[-1]+1],1] 350 | 351 | if it == 0: 352 | walk.append(np.random.choice(nbs)) 353 | prev_nbs = set(nbs) 354 | continue 355 | 356 | is_dist_1 = [] 357 | for n in nbs: 358 | is_dist_1.append(int(n in set(prev_nbs))) 359 | 360 | is_dist_1_np = np.array(is_dist_1) 361 | is_dist_0 = nbs == walk[-2] 362 | is_dist_2 = 1 - is_dist_1_np - is_dist_0 363 | 364 | alpha_pq = is_dist_0 / p + is_dist_1_np + is_dist_2/q 365 | alpha_pq_norm = alpha_pq/np.sum(alpha_pq) 366 | rdm_num = np.random.rand() 367 | cumsum = np.cumsum(alpha_pq_norm) 368 | nxt = nbs[np.sum(1-(cumsum > rdm_num))] 369 | walk.append(nxt) 370 | prev_nbs = set(nbs) 371 | return np.array(walk) 372 | 373 | class RandomWalker: 374 | """ 375 | Helper class to generate random walks on the input adjacency matrix. 376 | """ 377 | def __init__(self, adj, rw_len, p=1, q=1, batch_size=128): 378 | self.adj = adj 379 | #if not "lil" in str(type(adj)): 380 | # warnings.warn("Input adjacency matrix not in lil format. 
Converting it to lil.") 381 | # self.adj = self.adj.tolil() 382 | 383 | self.rw_len = rw_len 384 | self.p = p 385 | self.q = q 386 | self.edges = np.array(self.adj.nonzero()).T 387 | self.node_ixs = np.unique(self.edges[:, 0], return_index=True)[1] 388 | self.batch_size = batch_size 389 | 390 | def walk(self): 391 | while True: 392 | yield random_walk(self.edges, self.node_ixs, self.rw_len, self.p, self.q, self.batch_size).reshape([-1, self.rw_len]) 393 | 394 | 395 | 396 | def edge_overlap(A, B): 397 | """ 398 | Compute edge overlap between input graphs A and B, i.e. how many edges in A are also present in graph B. Assumes 399 | that both graphs contain the same number of edges. 400 | 401 | Parameters 402 | ---------- 403 | A: sparse matrix or np.array of shape (N,N). 404 | First input adjacency matrix. 405 | B: sparse matrix or np.array of shape (N,N). 406 | Second input adjacency matrix. 407 | 408 | Returns 409 | ------- 410 | float, the edge overlap. 411 | """ 412 | 413 | return ((A == B) & (A == 1)).sum() 414 | 415 | 416 | def graph_from_scores(scores, n_edges): 417 | """ 418 | Assemble a symmetric binary graph from the input score matrix. Ensures that there will be no singleton nodes. 419 | See the paper for details. 420 | 421 | Parameters 422 | ---------- 423 | scores: np.array of shape (N,N) 424 | The input transition scores. 425 | n_edges: int 426 | The desired number of edges in the target graph. 427 | 428 | Returns 429 | ------- 430 | target_g: symmetric binary matrix of shape (N,N) 431 | The assembled graph. 432 | 433 | """ 434 | 435 | if len(scores.nonzero()[0]) < n_edges: 436 | return symmetric(scores) > 0 437 | target_g = np.zeros(scores.shape) # initialize target graph 438 | scores_int = scores.toarray().copy() # internal copy of the scores matrix 439 | scores_int[np.diag_indices_from(scores_int)] = 0 # set diagonal to zero 440 | degrees_int = scores_int.sum(0) # The row sum over the scores. 441 | 442 | N = scores.shape[0] 443 | 444 | for n in np.random.choice(N, replace=False, size=N): # Iterate the nodes in random order 445 | 446 | row = scores_int[n,:].copy() 447 | if row.sum() == 0: 448 | continue 449 | 450 | probs = row / row.sum() 451 | 452 | target = np.random.choice(N, p=probs) 453 | target_g[n, target] = 1 454 | target_g[target, n] = 1 455 | 456 | diff = np.round((n_edges - target_g.sum())/2) 457 | if diff > 0: 458 | triu = np.triu(scores_int) 459 | triu[target_g > 0] = 0 460 | triu[np.diag_indices_from(scores_int)] = 0 461 | triu = triu / triu.sum() 462 | 463 | triu_ixs = np.triu_indices_from(scores_int) 464 | extra_edges = np.random.choice(triu_ixs[0].shape[0], replace=False, p=triu[triu_ixs], size=int(diff)) 465 | 466 | target_g[(triu_ixs[0][extra_edges], triu_ixs[1][extra_edges])] = 1 467 | target_g[(triu_ixs[1][extra_edges], triu_ixs[0][extra_edges])] = 1 468 | 469 | target_g = symmetric(target_g) 470 | return target_g 471 | 472 | 473 | def symmetric(directed_adjacency, clip_to_one=True): 474 | """ 475 | Symmetrize the input adjacency matrix. 476 | Parameters 477 | ---------- 478 | directed_adjacency: sparse matrix or np.array of shape (N,N) 479 | Input adjacency matrix. 480 | clip_to_one: bool, default: True 481 | Whether the output should be binarized (i.e. clipped to 1) 482 | 483 | Returns 484 | ------- 485 | A_symmetric: sparse matrix or np.array of the same shape as the input 486 | Symmetrized adjacency matrix.
487 | 488 | """ 489 | 490 | A_symmetric = directed_adjacency + directed_adjacency.T 491 | if clip_to_one: 492 | A_symmetric[A_symmetric > 1] = 1 493 | return A_symmetric 494 | 495 | def squares(g): 496 | """ 497 | Count the number of squares for each node 498 | Parameters 499 | ---------- 500 | g: igraph Graph object 501 | The input graph. 502 | 503 | Returns 504 | ------- 505 | List with N entries (N is number of nodes) that give the number of squares a node is part of. 506 | """ 507 | 508 | cliques = g.cliques(min=4, max=4) 509 | result = [0] * g.vcount() 510 | for i, j, k, l in cliques: 511 | result[i] += 1 512 | result[j] += 1 513 | result[k] += 1 514 | result[l] += 1 515 | return result 516 | 517 | 518 | def statistics_degrees(A_in): 519 | """ 520 | Compute min, max, mean degree 521 | 522 | Parameters 523 | ---------- 524 | A_in: sparse matrix or np.array 525 | The input adjacency matrix. 526 | Returns 527 | ------- 528 | d_max, d_min, d_mean 529 | """ 530 | 531 | degrees = A_in.sum(axis=0) 532 | return np.max(degrees), np.min(degrees), np.mean(degrees) 533 | 534 | 535 | def statistics_LCC(A_in): 536 | """ 537 | Determine the largest connected component (LCC) 538 | 539 | Parameters 540 | ---------- 541 | A_in: sparse matrix or np.array 542 | The input adjacency matrix. 543 | Returns 544 | ------- 545 | Array with the node indices of the LCC 546 | 547 | """ 548 | 549 | unique, counts = np.unique(connected_components(A_in)[1], return_counts=True) 550 | LCC = np.where(connected_components(A_in)[1] == np.argmax(counts))[0] 551 | return LCC 552 | 553 | 554 | def statistics_wedge_count(A_in): 555 | """ 556 | Compute the wedge count of the input graph 557 | 558 | Parameters 559 | ---------- 560 | A_in: sparse matrix or np.array 561 | The input adjacency matrix. 562 | 563 | Returns 564 | ------- 565 | The wedge count. 566 | """ 567 | 568 | degrees = A_in.sum(axis=0) 569 | return float(np.sum(np.array([0.5 * x * (x - 1) for x in degrees]))) 570 | 571 | 572 | def statistics_claw_count(A_in): 573 | """ 574 | Compute the claw count of the input graph 575 | 576 | Parameters 577 | ---------- 578 | A_in: sparse matrix or np.array 579 | The input adjacency matrix. 580 | 581 | Returns 582 | ------- 583 | Claw count 584 | """ 585 | 586 | degrees = A_in.sum(axis=0) 587 | return float(np.sum(np.array([1 / 6. * x * (x - 1) * (x - 2) for x in degrees]))) 588 | 589 | 590 | def statistics_triangle_count(A_in): 591 | """ 592 | Compute the triangle count of the input graph 593 | 594 | Parameters 595 | ---------- 596 | A_in: sparse matrix or np.array 597 | The input adjacency matrix. 598 | Returns 599 | ------- 600 | Triangle count 601 | """ 602 | 603 | A_graph = nx.from_numpy_matrix(A_in) 604 | triangles = nx.triangles(A_graph) 605 | t = np.sum(list(triangles.values())) / 3 606 | return int(t) 607 | 608 | 609 | #def statistics_square_count(A_in): 610 | # """ 611 | # Compute the square count of the input graph 612 | # 613 | # Parameters 614 | # ---------- 615 | # A_in: sparse matrix or np.array 616 | # The input adjacency matrix. 617 | # Returns 618 | # ------- 619 | # Square count 620 | # """ 621 | # 622 | # A_igraph = igraph.Graph.Adjacency((A_in > 0).tolist()).as_undirected() 623 | # return int(np.sum(squares(A_igraph)) / 4) 624 | 625 | 626 | def statistics_power_law_alpha(A_in): 627 | """ 628 | Compute the power law coefficient of the degree distribution of the input graph 629 | 630 | Parameters 631 | ---------- 632 | A_in: sparse matrix or np.array 633 | The input adjacency matrix.
634 | 635 | Returns 636 | ------- 637 | Power law coefficient 638 | """ 639 | import powerlaw # powerlaw is not imported in the file header, so import it locally here 640 | degrees = A_in.sum(axis=0) 641 | return powerlaw.Fit(degrees, xmin=max(np.min(degrees),1)).power_law.alpha 642 | 643 | 644 | def statistics_gini(A_in): 645 | """ 646 | Compute the Gini coefficient of the degree distribution of the input graph 647 | 648 | Parameters 649 | ---------- 650 | A_in: sparse matrix or np.array 651 | The input adjacency matrix. 652 | 653 | Returns 654 | ------- 655 | Gini coefficient 656 | """ 657 | 658 | n = A_in.shape[0] 659 | degrees = A_in.sum(axis=0) 660 | degrees_sorted = np.sort(degrees) 661 | G = (2 * np.sum(np.array([i * degrees_sorted[i] for i in range(len(degrees))]))) / (n * np.sum(degrees)) - ( 662 | n + 1) / n 663 | return float(G) 664 | 665 | 666 | def statistics_edge_distribution_entropy(A_in): 667 | """ 668 | Compute the relative edge distribution entropy of the input graph. 669 | 670 | Parameters 671 | ---------- 672 | A_in: sparse matrix or np.array 673 | The input adjacency matrix. 674 | 675 | Returns 676 | ------- 677 | Rel. edge distribution entropy 678 | """ 679 | 680 | degrees = A_in.sum(axis=0) 681 | m = 0.5 * np.sum(np.square(A_in)) 682 | n = A_in.shape[0] 683 | 684 | H_er = 1 / np.log(n) * np.sum(-degrees / (2 * float(m)) * np.log((degrees+.0001) / (2 * float(m)))) 685 | return H_er 686 | 687 | def statistics_cluster_props(A, Z_obs): 688 | def get_blocks(A_in, Z_obs, normalize=True): 689 | block = Z_obs.T.dot(A_in.dot(Z_obs)) 690 | counts = np.sum(Z_obs, axis=0) 691 | blocks_outer = counts[:,None].dot(counts[None,:]) 692 | if normalize: 693 | blocks_outer = np.multiply(block, 1/blocks_outer) 694 | return blocks_outer 695 | 696 | in_blocks = get_blocks(A, Z_obs) 697 | diag_mean = np.multiply(in_blocks, np.eye(in_blocks.shape[0])).mean() 698 | offdiag_mean = np.multiply(in_blocks, 1-np.eye(in_blocks.shape[0])).mean() 699 | return diag_mean, offdiag_mean 700 | 701 | def statistics_compute_cpl(A): 702 | """Compute characteristic path length.""" 703 | P = sp.csgraph.shortest_path(sp.csr_matrix(A)) 704 | return P[((1 - np.isinf(P)) * (1 - np.eye(P.shape[0]))).astype(bool)].mean() # the builtin bool; np.bool is removed in newer NumPy 705 | 706 | 707 | def compute_graph_statistics(A_in, Z_obs=None): 708 | """ 709 | 710 | Parameters 711 | ---------- 712 | A_in: sparse matrix 713 | The input adjacency matrix. 714 | Z_obs: np.matrix [N, K], where K is the number of classes. 715 | Matrix whose rows are one-hot vectors indicating the class membership of the respective node.
716 | 717 | Returns 718 | ------- 719 | Dictionary containing the following statistics: 720 | * Maximum, minimum, mean degree of nodes 721 | * Size of the largest connected component (LCC) 722 | * Wedge count 723 | * Claw count 724 | * Triangle count 725 | * Square count (currently disabled, see below) 726 | * Power law exponent 727 | * Gini coefficient 728 | * Relative edge distribution entropy 729 | * Assortativity 730 | * Clustering coefficient 731 | * Number of connected components 732 | * Intra- and inter-community density (if Z_obs is passed) 733 | * Characteristic path length 734 | """ 735 | 736 | A = A_in.copy() 737 | 738 | assert ((A == A.T).all()) 739 | A_graph = nx.from_numpy_matrix(A).to_undirected() 740 | 741 | statistics = {} 742 | 743 | d_max, d_min, d_mean = statistics_degrees(A) 744 | 745 | # Degree statistics 746 | statistics['d_max'] = d_max 747 | statistics['d_min'] = d_min 748 | statistics['d'] = d_mean 749 | 750 | # largest connected component 751 | LCC = statistics_LCC(A) 752 | 753 | statistics['LCC'] = LCC.shape[0] 754 | # wedge count 755 | statistics['wedge_count'] = statistics_wedge_count(A) 756 | 757 | # claw count 758 | statistics['claw_count'] = statistics_claw_count(A) 759 | 760 | # triangle count 761 | statistics['triangle_count'] = statistics_triangle_count(A) 762 | 763 | # Square count (disabled: statistics_square_count and its igraph dependency are commented out above) 764 | #statistics['square_count'] = statistics_square_count(A) 765 | 766 | # power law exponent 767 | statistics['power_law_exp'] = statistics_power_law_alpha(A) 768 | 769 | # gini coefficient 770 | statistics['gini'] = statistics_gini(A) 771 | 772 | # Relative edge distribution entropy 773 | statistics['rel_edge_distr_entropy'] = statistics_edge_distribution_entropy(A) 774 | 775 | # Assortativity 776 | statistics['assortativity'] = nx.degree_assortativity_coefficient(A_graph) 777 | 778 | # Clustering coefficient 779 | statistics['clustering_coefficient'] = 3 * statistics['triangle_count'] / statistics['claw_count'] 780 | 781 | # Number of connected components 782 | statistics['n_components'] = connected_components(A)[0] 783 | 784 | if Z_obs is not None: 785 | # inter- and intra-community density 786 | intra, inter = statistics_cluster_props(A, Z_obs) 787 | statistics['intra_community_density'] = intra 788 | statistics['inter_community_density'] = inter 789 | 790 | statistics['cpl'] = statistics_compute_cpl(A) 791 | 792 | return statistics 793 | 794 | def get_graph(path_folder, tag): 795 | #data = pd.read_csv(path, index_col=0) 796 | #data = data[data['Network'] == tag] 797 | #keys = np.unique(data[['fbus', 'tbus']].values) 798 | #dicto = dict(zip(keys, np.arange(len(keys)))) 799 | #data = data[['fbus', 'tbus']].values.reshape(-1) 800 | #data = np.array([dicto[key] for key in data]) 801 | #data = data.reshape(-1, 2) 802 | #edges = [list(edge) for edge in data] 803 | #nodes = np.sort(np.unique(data)) 804 | #G = nx.Graph() 805 | #G.add_nodes_from(nodes) 806 | #G.add_edges_from(edges) 807 | #G = nx.to_scipy_sparse_matrix(G) 808 | 809 | path_branch = path_folder + r'\branch.csv' 810 | path_bus = path_folder + r'\bus.csv' 811 | data_branch = pd.read_csv(path_branch) 812 | data_bus = pd.read_csv(path_bus) 813 | 814 | data_branch = data_branch[data_branch['Network'] == tag] 815 | data_bus = data_bus[data_bus['Network'] == tag] 816 | 817 | data_branch = data_branch.drop_duplicates(subset=['fbus', 'tbus']).sort_values(by=['fbus', 'tbus']) 818 | 819 | edges = data_branch[['fbus', 'tbus']].values 820 | nodes = np.arange(data_bus.shape[0]) 821 | 822 | G = nx.Graph() 823 | G.add_nodes_from(nodes) 824 | 
G.add_edges_from(edges) 825 | G = nx.to_scipy_sparse_matrix(G) 826 | 827 | return G 828 | 829 | 830 | -------------------------------------------------------------------------------- /netgan_modified/17000_model.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mmiller96/netgan_pytorch/4511f7de6fb87000435c1fd498d720391f7ccdc5/netgan_modified/17000_model.pt -------------------------------------------------------------------------------- /netgan_modified/20000_model.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mmiller96/netgan_pytorch/4511f7de6fb87000435c1fd498d720391f7ccdc5/netgan_modified/20000_model.pt -------------------------------------------------------------------------------- /netgan_modified/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | class Generator(nn.Module): 7 | 8 | def __init__(self, H_inputs, H, z_dim, N, rw_len, temp, state='structure'): 9 | ''' 10 | H_inputs: input dimension 11 | H: hidden dimension 12 | z_dim: latent dimension 13 | N: number of nodes (needed for the up and down projection) 14 | rw_len: number of LSTM cells 15 | temp: temperature for the gumbel softmax 16 | ''' 17 | super(Generator, self).__init__() 18 | self.intermediate = nn.Linear(z_dim, H).type(torch.float64) 19 | torch.nn.init.xavier_uniform_(self.intermediate.weight) 20 | torch.nn.init.zeros_(self.intermediate.bias) 21 | self.intermediate_lines = nn.Linear(z_dim, H).type(torch.float64) 22 | torch.nn.init.xavier_uniform_(self.intermediate_lines.weight) 23 | torch.nn.init.zeros_(self.intermediate_lines.bias) 24 | 25 | self.c_up = nn.Linear(H, H).type(torch.float64) 26 | torch.nn.init.xavier_uniform_(self.c_up.weight) 27 | torch.nn.init.zeros_(self.c_up.bias) 28 | self.h_up = nn.Linear(H, H).type(torch.float64) 29 | torch.nn.init.xavier_uniform_(self.h_up.weight) 30 | torch.nn.init.zeros_(self.h_up.bias) 31 | self.c_up_lines = nn.Linear(H, H).type(torch.float64) 32 | torch.nn.init.xavier_uniform_(self.c_up_lines.weight) 33 | torch.nn.init.zeros_(self.c_up_lines.bias) 34 | self.h_up_lines = nn.Linear(H, H).type(torch.float64) 35 | torch.nn.init.xavier_uniform_(self.h_up_lines.weight) 36 | torch.nn.init.zeros_(self.h_up_lines.bias) 37 | 38 | self.lstmcell = LSTMCell(H_inputs, H).type(torch.float64) 39 | self.lstmcell_lines = LSTMCell(H_inputs, H).type(torch.float64) 40 | 41 | 42 | self.W_up = nn.Linear(H, N).type(torch.float64) 43 | self.W_down = nn.Linear(N, H_inputs, bias=False).type(torch.float64) 44 | self.W_down_lines = nn.Linear(N, H_inputs, bias=False).type(torch.float64) 45 | self.W_out_lines = nn.Linear(H, 1).type(torch.float64) 46 | self.rw_len = rw_len 47 | self.temp = temp 48 | self.H = H 49 | self.latent_dim = z_dim 50 | self.N = N 51 | self.H_inputs = H_inputs 52 | self.freeze_params(state) 53 | 54 | def forward(self, latent, inputs, device='cuda'): # h_down = input_zeros 55 | intermediate = torch.tanh(self.intermediate(latent)) 56 | intermediate_lines = torch.tanh(self.intermediate_lines(latent)) 57 | hc = (torch.tanh(self.h_up(intermediate)), torch.tanh(self.c_up(intermediate))) 58 | hc_lines = (torch.tanh(self.h_up_lines(intermediate_lines)), torch.tanh(self.c_up_lines(intermediate_lines))) 59 | out, out_lines = [], [] # gumbel_noise = uniform noise [0, 1] 60 | for i in range(self.rw_len): 61 | hh, cc = self.lstmcell(inputs, 
hc) 62 | hc = (hh, cc) 63 | h_up = self.W_up(hh) # blow up to dimension N using W_up 64 | h_sample = self.gumbel_softmax_sample(h_up, self.temp, device) 65 | inputs = self.W_down(h_sample) # back to dimension H (in netgan they reduce the dimension to d) 66 | out.append(h_sample) 67 | for j in range(self.rw_len): 68 | inputs_lines = self.W_down_lines(out[j]) 69 | hh_lines, cc_lines = self.lstmcell_lines(inputs_lines, hc_lines) # run the line lengths through the dedicated lines-LSTM 70 | hc_lines = (hh_lines, cc_lines) 71 | hh_out = self.W_out_lines(hh_lines) 72 | out_lines.append(hh_out) 73 | return torch.stack(out, dim=1), torch.stack(out_lines, dim=1) 74 | 75 | def sample_latent(self, num_samples, device): 76 | return torch.randn((num_samples, self.latent_dim)).type(torch.float64).to(device) 77 | 78 | 79 | def sample(self, num_samples, device): 80 | noise = self.sample_latent(num_samples, device) 81 | input_zeros = self.init_hidden(num_samples).contiguous().type(torch.float64).to(device) 82 | generated_data, generated_weights = self(noise, input_zeros, device) 83 | return generated_data, generated_weights 84 | 85 | def sample_discrete(self, num_samples, device): 86 | with torch.no_grad(): 87 | proba, proba_weights = self.sample(num_samples, device) 88 | return np.argmax(proba.cpu().numpy(), axis=2), proba_weights.cpu().numpy() 89 | 90 | def sample_gumbel(self, logits, eps=1e-20): 91 | U = torch.rand(logits.shape, dtype=torch.float64) 92 | return -torch.log(-torch.log(U + eps) + eps) 93 | 94 | def gumbel_softmax_sample(self, logits, temperature, device, hard=True): 95 | """ Draw a sample from the Gumbel-Softmax distribution""" 96 | gumbel = self.sample_gumbel(logits).type(torch.float64).to(device) 97 | y = logits + gumbel 98 | y = torch.nn.functional.softmax(y / temperature, dim=1) 99 | if hard: 100 | y_hard = torch.max(y, 1, keepdim=True)[0].eq(y).type(torch.float64).to(device) 101 | y = (y_hard - y).detach() + y 102 | return y 103 | 104 | def init_hidden(self, batch_size): 105 | weight = next(self.parameters()).data 106 | return weight.new(batch_size, self.H_inputs).zero_().type(torch.float64) 107 | 108 | def freeze_params(self, state): 109 | if(state=='structure'): 110 | self.intermediate_lines.weight.requires_grad_(False) 111 | self.intermediate_lines.bias.requires_grad_(False) 112 | self.h_up_lines.weight.requires_grad_(False) 113 | self.h_up_lines.bias.requires_grad_(False) 114 | self.c_up_lines.weight.requires_grad_(False) 115 | self.c_up_lines.bias.requires_grad_(False) 116 | self.lstmcell_lines.cell.weight.requires_grad_(False) 117 | self.lstmcell_lines.cell.bias.requires_grad_(False) 118 | self.W_down_lines.weight.requires_grad_(False) 119 | self.W_out_lines.weight.requires_grad_(False) 120 | self.W_out_lines.bias.requires_grad_(False) 121 | 122 | self.intermediate.weight.requires_grad_(True) 123 | self.intermediate.bias.requires_grad_(True) 124 | self.h_up.weight.requires_grad_(True) 125 | self.h_up.bias.requires_grad_(True) 126 | self.c_up.weight.requires_grad_(True) 127 | self.c_up.bias.requires_grad_(True) 128 | self.lstmcell.cell.weight.requires_grad_(True) 129 | self.lstmcell.cell.bias.requires_grad_(True) 130 | self.W_down.weight.requires_grad_(True) 131 | self.W_up.weight.requires_grad_(True) 132 | self.W_up.bias.requires_grad_(True) 133 | else: # state = 'lines' 134 | self.intermediate_lines.weight.requires_grad_(True) 135 | self.intermediate_lines.bias.requires_grad_(True) 136 | self.h_up_lines.weight.requires_grad_(True) 137 | self.h_up_lines.bias.requires_grad_(True) 138 | self.c_up_lines.weight.requires_grad_(True) 139 | 
self.c_up_lines.bias.requires_grad_(True) 140 | self.lstmcell_lines.cell.weight.requires_grad_(True) 141 | self.lstmcell_lines.cell.bias.requires_grad_(True) 142 | self.W_down_lines.weight.requires_grad_(True) 143 | self.W_out_lines.weight.requires_grad_(True) 144 | self.W_out_lines.bias.requires_grad_(True) 145 | 146 | self.intermediate.weight.requires_grad_(False) 147 | self.intermediate.bias.requires_grad_(False) 148 | self.h_up.weight.requires_grad_(False) 149 | self.h_up.bias.requires_grad_(False) 150 | self.c_up.weight.requires_grad_(False) 151 | self.c_up.bias.requires_grad_(False) 152 | self.lstmcell.cell.weight.requires_grad_(False) 153 | self.lstmcell.cell.bias.requires_grad_(False) 154 | self.W_down.weight.requires_grad_(False) 155 | self.W_up.weight.requires_grad_(False) 156 | self.W_up.bias.requires_grad_(False) 157 | 158 | 159 | 160 | 161 | 162 | class Discriminator(nn.Module): 163 | def __init__(self, H_inputs, H, N, rw_len): 164 | ''' 165 | H_inputs: input dimension 166 | H: hidden dimension 167 | N: number of nodes (needed for the up and down projection) 168 | rw_len: number of LSTM cells 169 | ''' 170 | super(Discriminator, self).__init__() 171 | self.W_down = nn.Linear(N, H_inputs, bias=False).type(torch.float64) 172 | torch.nn.init.xavier_uniform_(self.W_down.weight) 173 | self.lstmcell = LSTMCell(H_inputs+1, H).type(torch.float64) 174 | self.lin_out = nn.Linear(H, 1, bias=True).type(torch.float64) 175 | torch.nn.init.xavier_uniform_(self.lin_out.weight) 176 | torch.nn.init.zeros_(self.lin_out.bias) 177 | self.H = H 178 | self.N = N 179 | self.rw_len = rw_len 180 | self.H_inputs = H_inputs 181 | 182 | #def forward(self, x_rw, x_weights): 183 | def forward(self, x): 184 | x_rw = x[:, :, :self.N] 185 | x_weights = x[:, :, -1:] 186 | x_rw = x_rw.view(-1, self.N) 187 | xa = self.W_down(x_rw) 188 | xa = xa.view(-1, self.rw_len, self.H_inputs) 189 | xc = torch.cat((xa, x_weights), dim=2) 190 | hc = self.init_hidden(xc.size(0)) 191 | for i in range(self.rw_len): 192 | hc = self.lstmcell(xc[:, i, :], hc) 193 | out = hc[0] 194 | pred = self.lin_out(out) 195 | return pred 196 | 197 | def init_inputs(self, num_samples): 198 | weight = next(self.parameters()).data 199 | return weight.new(num_samples, self.H_inputs).zero_().type(torch.float64) 200 | 201 | def init_hidden(self, num_samples): 202 | weight = next(self.parameters()).data 203 | return (weight.new(num_samples, self.H).zero_().contiguous().type(torch.float64), weight.new(num_samples, self.H).zero_().contiguous().type(torch.float64)) 204 | 205 | #def reset_weights(self): 206 | # import h5py 207 | # weights = h5py.File(r'C:\Users\Data Miner\PycharmProjects\Master_Projekt4\weights.h5', 'r') 208 | # self.W_down.weight = torch.nn.Parameter(torch.tensor(np.array(weights.get('W_down_discriminator')).T).type(torch.float64)) 209 | # self.lin_out.weight = torch.nn.Parameter(torch.tensor(np.array(weights.get('discriminator_out')).T).type(torch.float64)) 210 | # self.lin_out.bias = torch.nn.Parameter(torch.tensor(np.array(weights.get('discriminator_out_bias'))).type(torch.float64)) 211 | # self.lstmcell.cell.weight = torch.nn.Parameter(torch.tensor(np.array(weights.get('discriminator_lstm')).T).type(torch.float64)) 212 | # self.lstmcell.cell.bias = torch.nn.Parameter(torch.tensor(np.array(weights.get('discriminator_lstm_bias'))).type(torch.float64)) 213 | 214 | class LSTMCell(nn.Module): 215 | def __init__(self, input_size, hidden_size): 216 | super(LSTMCell, self).__init__() 217 | self.input_size = input_size 218 | self.hidden_size = 
hidden_size 219 | 220 | self.cell = nn.Linear(input_size+hidden_size, 4 * hidden_size, bias=True) 221 | torch.nn.init.xavier_uniform_(self.cell.weight) 222 | torch.nn.init.zeros_(self.cell.bias) 223 | 224 | def forward(self, x, hidden): 225 | hx, cx = hidden 226 | gates = torch.cat((x, hx), dim=1) 227 | gates = self.cell(gates) 228 | 229 | ingate, cellgate, forgetgate, outgate = gates.chunk(4, 1) 230 | 231 | ingate = torch.sigmoid(ingate) 232 | forgetgate = torch.sigmoid(torch.add(forgetgate, 1.0)) 233 | cellgate = torch.tanh(cellgate) 234 | outgate = torch.sigmoid(outgate) 235 | cy = torch.mul(cx, forgetgate) + torch.mul(ingate, cellgate) 236 | hy = torch.mul(outgate, torch.tanh(cy)) 237 | return (hy, cy) -------------------------------------------------------------------------------- /netgan_modified/training.py: -------------------------------------------------------------------------------- 1 | from models import Generator, Discriminator 2 | import utils 3 | 4 | import numpy as np 5 | import scipy.sparse as sp 6 | from sklearn.metrics import roc_auc_score, average_precision_score 7 | import math 8 | 9 | import torch 10 | import torch.optim as optim 11 | from torch.nn.functional import one_hot 12 | from torch.autograd import grad 13 | #from torch.utils.tensorboard import SummaryWriter 14 | import time 15 | from joblib import Parallel, delayed 16 | import pdb 17 | from matplotlib import pyplot as plt 18 | 19 | class Trainer(): 20 | def __init__(self, graph, graph_weighted, scaler, N, max_iterations=20000, rw_len=16, batch_size=128, H_gen=40, H_disc=30, H_inp=128, z_dim=16, lr=0.0003, n_critic=3, gp_weight=10.0, betas=(.5, .9), 21 | l2_penalty_disc=5e-5, l2_penalty_gen=1e-7, temp_start=5.0, temp_decay=1-5e-5, min_temp=0.5, val_share=0.1, test_share=0.05, seed=498164, delta=0.02): 22 | """ 23 | Initialize NetGAN. 24 | Parameters 25 | ---------- 26 | graph: scipy_sparse_matrix 27 | The binary input graph. graph_weighted is the same adjacency carrying the line lengths as edge weights; scaler is the scaler fitted to those lengths (stored on the trainer). 28 | N: int 29 | Number of nodes in the graph to generate. 30 | max_iterations: int, default: 20,000 31 | Maximal iterations if the stopping_criterion is not fulfilled. 32 | rw_len: int 33 | Length of random walks to generate. 34 | batch_size: int, default: 128 35 | The batch size. 36 | H_gen: int, default: 40 37 | The hidden_size of the generator. 38 | H_disc: int, default: 30 39 | The hidden_size of the discriminator. 40 | H_inp: int, default: 128 41 | Input size of the LSTM cells. 42 | z_dim: int, default: 16 43 | The dimension of the random noise that is used as input to the generator. 44 | lr: float, default: 0.0003 45 | The learning rate, used for both the generator and the discriminator. 46 | n_critic: int, default: 3 47 | The number of discriminator iterations per generator training iteration. 48 | gp_weight: float, default: 10 49 | Gradient penalty weight for the Wasserstein GAN. See the paper 'Improved Training of Wasserstein GANs' for more details. 50 | betas: tuple, default: (.5, .9) 51 | Decay rates of the Adam optimizers. 52 | l2_penalty_gen: float, default: 1e-7 53 | L2 penalty on the generator weights. 54 | l2_penalty_disc: float, default: 5e-5 55 | L2 penalty on the discriminator weights. 56 | temp_start: float, default: 5.0 57 | The initial temperature for the Gumbel softmax. 58 | temp_decay: float, default: 1-5e-5 59 | After each evaluation, the current temperature is updated as 60 | current_temp := max(temperature_decay*current_temp, min_temperature) 61 | min_temp: float, default: 0.5 62 | The minimal temperature for the Gumbel softmax.
63 | val_share: float, default: 0.1 64 | Percentage of validation edges. 65 | test_share: float, default: 0.05 66 | Percentage of test edges. 67 | seed: int, default: 498164 68 | Seed for numpy.random. It is used for splitting the graph into train, validation and test sets. 69 | delta: float, default: 0.02 70 | Minimum improvement of roc_auc + avp that resets the early-stopping patience during the 'structure' phase (see check_running). 71 | """ 72 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 73 | self.max_iterations = max_iterations 74 | self.rw_len = rw_len 75 | self.batch_size = batch_size 76 | self.N = N 77 | self.state = 'structure' 78 | self.generator = Generator(H_inputs=H_inp, H=H_gen, N=N, rw_len=rw_len, z_dim=z_dim, temp=temp_start).to(self.device) 79 | self.generator.freeze_params('structure') 80 | self.discriminator = Discriminator(H_inputs=H_inp, H=H_disc, N=N, rw_len=rw_len).to(self.device) 81 | self.G_optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.generator.parameters()), lr=lr, betas=betas) 82 | self.D_optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.discriminator.parameters()), lr=lr, betas=betas) 83 | self.n_critic = n_critic 84 | self.gp_weight = gp_weight 85 | self.l2_penalty_disc = l2_penalty_disc 86 | self.l2_penalty_gen = l2_penalty_gen 87 | self.temp_start = temp_start 88 | self.temp_decay = temp_decay 89 | self.min_temp = min_temp 90 | 91 | self.graph = graph 92 | self.graph_weighted = graph_weighted 93 | self.scaler = scaler 94 | self.train_ones, self.val_ones, self.val_zeros, self.test_ones, self.test_zeros = utils.train_val_test_split_adjacency(graph, val_share, test_share, seed, undirected=True, connected=True, asserts=True) 95 | self.train_graph = sp.coo_matrix((np.ones(len(self.train_ones)), (self.train_ones[:, 0], self.train_ones[:, 1]))).tocsr() 96 | assert (self.train_graph.toarray() == self.train_graph.toarray().T).all() 97 | self.walker = utils.RandomWalker(self.train_graph, self.graph_weighted, self.rw_len, p=1, q=1, batch_size=batch_size) 98 | self.eo = [] 99 | self.critic_loss = [] 100 | self.generator_loss = [] 101 | self.avp = [] 102 | self.roc_auc = [] 103 | self.best_performance = 0.0 104 | self.running = True 105 | self.delta = delta 106 | self.lr = lr 107 | self.betas = betas 108 | self.loss_val = [] 109 | self.loss_train = [] 110 | 111 | def l2_regularization_G(self, G): 112 | # L2 regularization for the generator's structure weights. All of them, including W_down, enter the penalty. 113 | l2_1 = torch.sum(torch.cat([x.view(-1) for x in G.W_down.weight]) ** 2 / 2) 114 | l2_2 = torch.sum(torch.cat([x.view(-1) for x in G.W_up.weight]) ** 2 / 2) 115 | l2_3 = torch.sum(torch.cat([x.view(-1) for x in G.W_up.bias]) ** 2 / 2) 116 | l2_4 = torch.sum(torch.cat([x.view(-1) for x in G.intermediate.weight]) ** 2 / 2) 117 | l2_5 = torch.sum(torch.cat([x.view(-1) for x in G.intermediate.bias]) ** 2 / 2) 118 | l2_6 = torch.sum(torch.cat([x.view(-1) for x in G.h_up.weight]) ** 2 / 2) 119 | l2_7 = torch.sum(torch.cat([x.view(-1) for x in G.h_up.bias]) ** 2 / 2) 120 | l2_8 = torch.sum(torch.cat([x.view(-1) for x in G.c_up.weight]) ** 2 / 2) 121 | l2_9 = torch.sum(torch.cat([x.view(-1) for x in G.c_up.bias]) ** 2 / 2) 122 | l2_10 = torch.sum(torch.cat([x.view(-1) for x in G.lstmcell.cell.weight]) ** 2 / 2) 123 | l2_11 = torch.sum(torch.cat([x.view(-1) for x in G.lstmcell.cell.bias]) ** 2 / 2) 124 | l2 = self.l2_penalty_gen * (l2_1 + l2_2 + l2_3 + l2_4 + l2_5 + l2_6 + l2_7 + l2_8 + l2_9 + l2_10 + l2_11) 125 | return l2 126 | 127 | def l2_regularization_G_lines(self, G): 128 | # L2 regularization for the generator's line-length head.
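# Same weight-decay pattern as l2_regularization_G above, but restricted to the
# '*_lines' parameters, i.e. the ones that stay trainable in the 'lines' phase
# while the structure weights are frozen (see Generator.freeze_params).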
129 | l2_1 = torch.sum(torch.cat([x.view(-1) for x in G.W_down_lines.weight]) ** 2 / 2) 130 | l2_2 = torch.sum(torch.cat([x.view(-1) for x in G.h_up_lines.weight]) ** 2 / 2) 131 | l2_3 = torch.sum(torch.cat([x.view(-1) for x in G.h_up_lines.bias]) ** 2 / 2) 132 | l2_4 = torch.sum(torch.cat([x.view(-1) for x in G.c_up_lines.weight]) ** 2 / 2) 133 | l2_5 = torch.sum(torch.cat([x.view(-1) for x in G.c_up_lines.bias]) ** 2 / 2) 134 | l2_6 = torch.sum(torch.cat([x.view(-1) for x in G.lstmcell_lines.cell.weight]) ** 2 / 2) 135 | l2_7 = torch.sum(torch.cat([x.view(-1) for x in G.lstmcell_lines.cell.bias]) ** 2 / 2) 136 | l2_8 = torch.sum(torch.cat([x.view(-1) for x in G.W_out_lines.weight]) ** 2 / 2) 137 | l2_9 = torch.sum(torch.cat([x.view(-1) for x in G.W_out_lines.bias]) ** 2 / 2) 138 | l2_10 = torch.sum(torch.cat([x.view(-1) for x in G.intermediate_lines.weight]) ** 2 / 2) 139 | l2_11 = torch.sum(torch.cat([x.view(-1) for x in G.intermediate_lines.bias]) ** 2 / 2) 140 | l2 = self.l2_penalty_gen * (l2_1 + l2_2 + l2_3 + l2_4 + l2_5 + l2_6 + l2_7 + l2_8 + l2_9 + l2_10 + l2_11) 141 | return l2 142 | 143 | def l2_regularization_D(self, D): 144 | # L2 regularization for the discriminator. All discriminator weights, including W_down, enter the penalty. 145 | l2_1 = torch.sum(torch.cat([x.view(-1) for x in D.W_down.weight]) ** 2 / 2) 146 | l2_2 = torch.sum(torch.cat([x.view(-1) for x in D.lstmcell.cell.weight]) ** 2 / 2) 147 | l2_3 = torch.sum(torch.cat([x.view(-1) for x in D.lstmcell.cell.bias]) ** 2 / 2) 148 | l2_4 = torch.sum(torch.cat([x.view(-1) for x in D.lin_out.weight]) ** 2 / 2) 149 | l2_5 = torch.sum(torch.cat([x.view(-1) for x in D.lin_out.bias]) ** 2 / 2) 150 | l2 = self.l2_penalty_disc * (l2_1 + l2_2 + l2_3 + l2_4 + l2_5) 151 | return l2 152 | 153 | 154 | #def calc_gp(self, fake_inputs, fake_inputs_weights, real_inputs, real_inputs_weights): 155 | def calc_gp(self, fake_inputs, real_inputs): 156 | # calculate the gradient penalty. For more details see the paper 'Improved Training of Wasserstein GANs'.
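# Identical to the penalty in netgan/training.py; the only difference is that
# fake_inputs/real_inputs are here the concatenated [one-hot walk || line weight]
# tensors of shape (batch_size, rw_len, N + 1).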
157 | alpha = torch.rand((self.batch_size, 1, 1), dtype=torch.float64).to(self.device)
158 | differences = fake_inputs - real_inputs
159 | interpolates = real_inputs + alpha * differences
160 | y_pred_interpolates = self.discriminator(interpolates)
161 | gradients = grad(outputs=y_pred_interpolates, inputs=interpolates, grad_outputs=torch.ones_like(y_pred_interpolates), create_graph=True, retain_graph=True)[0]
162 | slopes = torch.sqrt(torch.sum(gradients ** 2, dim=[1, 2]))
163 | gradient_penalty = torch.mean((slopes - 1) ** 2)
164 | gradient_penalty = gradient_penalty * self.gp_weight
165 | return gradient_penalty
166 | 
167 | def critic_train_iteration(self):
168 | self.D_optimizer.zero_grad()
169 | # create fake and real inputs
170 | fake_inputs_rw, fake_inputs_weights = self.generator.sample(self.batch_size, self.device)
171 | random_walks, weights = self.walker.walk()
172 | 
173 | real_inputs = one_hot(torch.tensor(random_walks, dtype=torch.int64), num_classes=self.N).type(torch.float64).to(self.device)
174 | real_inputs_weights = torch.tensor(weights, dtype=torch.float64).to(self.device)
175 | if(self.state=='structure'):
176 | real_inputs_weights = torch.zeros_like(real_inputs_weights)
177 | fake_inputs_weights = torch.zeros_like(fake_inputs_weights)
178 | real_inputs = torch.cat((real_inputs, real_inputs_weights), dim=2)
179 | fake_inputs = torch.cat((fake_inputs_rw, fake_inputs_weights), dim=2)
180 | 
181 | y_pred_fake = self.discriminator(fake_inputs)
182 | y_pred_real = self.discriminator(real_inputs)
183 | 
184 | gp = self.calc_gp(fake_inputs, real_inputs) # gradient penalty
185 | disc_cost = torch.mean(y_pred_fake) - torch.mean(y_pred_real) + gp + self.l2_regularization_D(self.discriminator)
186 | 
187 | disc_cost.backward()
188 | self.D_optimizer.step()
189 | return disc_cost.item()
190 | 
191 | def generator_train_iteration(self):
192 | self.generator.train()
193 | self.G_optimizer.zero_grad()
194 | fake_inputs_rw, fake_weights = self.generator.sample(self.batch_size, self.device)
195 | if(self.state=='structure'): fake_weights = torch.zeros_like(fake_weights)
196 | fake_inputs = torch.cat((fake_inputs_rw, fake_weights), dim=2)
197 | y_pred_fake = self.discriminator(fake_inputs)
198 | 
199 | if(self.state=='structure'): gen_cost = -torch.mean(y_pred_fake) + self.l2_regularization_G(self.generator)
200 | else: gen_cost = -torch.mean(y_pred_fake) + self.l2_regularization_G_lines(self.generator)
201 | gen_cost.backward()
202 | self.G_optimizer.step()
203 | return gen_cost.item()
204 | 
205 | def create_graph(self, num_samples, i, batch_size=1000, reset_weights=False):
206 | if reset_weights:
207 | self.generator.reset_weights()
208 | self.generator.eval()
209 | 
210 | self.generator.temp = 0.5
211 | samples, samples_lines = [], []
212 | num_iterations = int(num_samples/batch_size)
213 | print("Number of iterations: " + str(num_iterations))
214 | for j in range(num_iterations):
215 | if(j%10 == 1): print('{}/{}'.format(j, num_iterations))
216 | rw_smpls, lines_smpls = self.generator.sample_discrete(batch_size, self.device)
217 | samples.append(rw_smpls)
218 | samples_lines.append(lines_smpls)
219 | samples = np.vstack(samples)
220 | gr, gr_weights = utils.score_matrix_from_random_walks(samples, self.N, samples_lines)
221 | gr = gr.tocsr()
222 | 
223 | # Assemble a graph from the score matrix
224 | _graph = utils.graph_from_scores(gr, self.graph.sum())
225 | 
226 | # Compute edge overlap
227 | edge_overlap = utils.edge_overlap(self.graph.toarray(), _graph)
228 | edge_scores = np.append(gr[tuple(self.val_ones.T)].A1, gr[tuple(self.val_zeros.T)].A1)
229 | actual_labels_val = np.append(np.ones(len(self.val_ones)), np.zeros(len(self.val_zeros)))
230 | # Compute validation ROC-AUC and average precision scores.
231 | self.roc_auc.append(roc_auc_score(actual_labels_val, edge_scores))
232 | self.avp.append(average_precision_score(actual_labels_val, edge_scores))
233 | self.eo.append(edge_overlap/self.graph.sum())
234 | 
235 | loss_lines = utils.calc_lines_mse(self.train_ones, self.graph_weighted, gr_weights)
236 | loss_lines_val = utils.calc_lines_mse(self.val_ones, self.graph_weighted, gr_weights)
237 | print('roc: {:.4f} avp: {:.4f} eo: {:.4f}'.format(self.roc_auc[-1], self.avp[-1], self.eo[-1]))
238 | print('loss_lines Train: {}'.format(loss_lines))
239 | print('loss_lines Val: {}'.format(loss_lines_val))
240 | self.loss_train.append(loss_lines)
241 | self.loss_val.append(loss_lines_val)
242 | self.generator.temp = np.maximum(self.temp_start * np.exp(-(1 - self.temp_decay) * i), self.min_temp)
243 | 
244 | 
245 | def check_running(self, i):
246 | torch.save(self, str(i+1) + '_model.pt')
247 | if(self.state=='structure'):
248 | if (self.stopping_criterion == 'val'):
249 | if (self.roc_auc[-1] + self.avp[-1] > self.best_performance + self.delta):
250 | self.best_performance = self.roc_auc[-1] + self.avp[-1]
251 | self.patience = self.max_patience
252 | else:
253 | self.patience -= 1
254 | 
255 | if self.patience == 0:
256 | print('Structure training finished after {} iterations, starting to train the line lengths'.format(i))
257 | self.state = 'lines'
258 | self.generator.freeze_params('lines')
259 | self.G_optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.generator.parameters()), lr=self.lr, betas=self.betas)
260 | 
261 | else:
262 | if (self.stopping_eo < self.eo[-1]):
263 | print('Structure training finished after {} iterations, starting to train the line lengths'.format(i))
264 | self.state = 'lines'
265 | self.generator.freeze_params('lines')
266 | self.G_optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.generator.parameters()), lr=self.lr, betas=self.betas)
267 | 
268 | def initialize_validation_settings(self, stopping_criterion, stopping_eo, max_patience):
269 | self.stopping_criterion = stopping_criterion
270 | self.stopping_eo = stopping_eo # needed for the 'eo' criterion
271 | self.max_patience = max_patience # needed for the 'val' criterion
272 | self.patience = max_patience
273 | if (self.stopping_criterion == 'val'):
274 | print("**** Using VAL criterion for early stopping with max patience of: {} ****".format(self.max_patience))
275 | else:
276 | assert self.stopping_eo is not None, "stopping_eo must be set when using the 'eo' criterion"
277 | print("**** Using EO criterion of {} for early stopping ****".format(self.stopping_eo))
278 | 
279 | def plot_graph(self):
280 | if len(self.critic_loss) > 10:
281 | plt.plot(self.critic_loss[9::], label="Critic loss")
282 | plt.plot(self.generator_loss[9::], label="Generator loss")
283 | else:
284 | plt.plot(self.critic_loss, label="Critic loss")
285 | plt.plot(self.generator_loss, label="Generator loss")
286 | plt.legend()
287 | plt.show()
288 | 
289 | def train(self, create_graph_every = 2000, plot_graph_every=500, num_samples_graph = 100000, stopping_criterion='val', max_patience=5, stopping_eo=None, i_start=0):
290 | """
291 | create_graph_every: int, default: 2000
292 | Every nth iteration, assembles a graph from the sampled random walks and evaluates it.
293 | plot_graph_every: int, default: 500
294 | Plots the loss functions of the generator and the discriminator every nth iteration.
295 | num_samples_graph: int, default: 100000
296 | Number of random walks that will be created for the graphs. Higher values mean more precise evaluations but also more computational time.
297 | stopping_criterion: str, default: 'val'
298 | The stopping_criterion can be either 'val' or 'eo':
299 | 'val': Stops the optimization if there are no improvements after several iterations. --> defined by max_patience
300 | 'eo': Stops if the edge overlap exceeds a certain threshold. --> defined by stopping_eo
301 | max_patience: int, default: 5
302 | Maximum number of evaluation steps without improvement of the validation accuracy to tolerate. Only
303 | applies to the 'val' criterion.
304 | stopping_eo: float in (0,1], default: None
305 | Stops when the edge overlap exceeds this threshold. Must be set when stopping_criterion is 'eo'.
306 | """
307 | self.initialize_validation_settings(stopping_criterion, stopping_eo, max_patience)
308 | starting_time = time.time()
309 | # Start Training
310 | for i in np.arange(i_start, self.max_iterations):
311 | if(self.running):
312 | self.critic_loss.append(np.mean([self.critic_train_iteration() for _ in range(self.n_critic)]))
313 | self.generator_loss.append(self.generator_train_iteration())
314 | if (i % create_graph_every == create_graph_every-1):
315 | self.create_graph(num_samples_graph, i)
316 | self.check_running(i)
317 | print('Took {:.2f} minutes so far.'.format((time.time() - starting_time)/60))
318 | if plot_graph_every > 0 and (i + 1) % plot_graph_every == 0: self.plot_graph()
319 | 
--------------------------------------------------------------------------------
/netgan_modified/utils.py:
--------------------------------------------------------------------------------
1 | import networkx as nx
2 | import scipy.sparse as sp
3 | import numpy as np
4 | from scipy.sparse.csgraph import connected_components, minimum_spanning_tree
5 | import warnings
6 | import pandas as pd
7 | from matplotlib import pyplot as plt
8 | #import igraph  # optional dependency, only needed for the (commented-out) statistics_square_count
9 | #import powerlaw  # optional dependency, only needed for statistics_power_law_alpha
10 | from numba import jit
11 | import pdb
12 | from sklearn.preprocessing import StandardScaler
13 | 
14 | def load_npz(file_name):
15 | """Load a sparse graph from a Numpy binary file.
16 | 
17 | Parameters
18 | ----------
19 | file_name : str
20 | Name of the file to load.
21 | 
22 | Returns
23 | -------
24 | adj_matrix, attr_matrix, labels : tuple
25 | Sparse adjacency matrix, sparse attribute matrix (or None) and label array (or None).
26 | 
27 | """
28 | if not file_name.endswith('.npz'):
29 | file_name += '.npz'
30 | with np.load(file_name, allow_pickle=True) as loader:
31 | #pdb.set_trace()
32 | loader = dict(loader)['arr_0'].item()
33 | adj_matrix = sp.csr_matrix((loader['adj_data'], loader['adj_indices'],
34 | loader['adj_indptr']), shape=loader['adj_shape'])
35 | 
36 | if 'attr_data' in loader:
37 | attr_matrix = sp.csr_matrix((loader['attr_data'], loader['attr_indices'],
38 | loader['attr_indptr']), shape=loader['attr_shape'])
39 | else:
40 | attr_matrix = None
41 | 
42 | labels = loader.get('labels')
43 | 
44 | return adj_matrix, attr_matrix, labels
45 | 
46 | 
47 | def largest_connected_components(adj, n_components=1):
48 | """Select the largest connected components in the graph.
49 | 
50 | Parameters
51 | ----------
52 | adj : scipy.sparse matrix
53 | Input adjacency matrix.
54 | n_components : int, default 1
55 | Number of largest connected components to keep.
56 | 
57 | Returns
58 | -------
59 | nodes_to_keep : list
60 | Indices of the nodes that belong to the n_components largest connected components.
61 | 
62 | """
63 | _, component_indices = connected_components(adj)
64 | component_sizes = np.bincount(component_indices)
65 | components_to_keep = np.argsort(component_sizes)[::-1][:n_components] # reverse order to sort descending
66 | nodes_to_keep = [
67 | idx for (idx, component) in enumerate(component_indices) if component in components_to_keep
68 | ]
69 | 
70 | 
71 | print("Selecting {0} largest connected components".format(n_components))
72 | return nodes_to_keep
73 | 
74 | 
75 | def edges_to_sparse(edges, N, values=None):
76 | """
77 | Create a sparse adjacency matrix from an array of edge indices and (optionally) values.
78 | 
79 | Parameters
80 | ----------
81 | edges : array-like, shape [n_edges, 2]
82 | Edge indices
83 | N : int
84 | Number of nodes
85 | values : array_like, shape [n_edges]
86 | The values to put at the specified edge indices. Optional, default: np.ones(.)
87 | 
88 | Returns
89 | -------
90 | A : scipy.sparse.csr.csr_matrix
91 | Sparse adjacency matrix
92 | 
93 | """
94 | if values is None:
95 | values = np.ones(edges.shape[0])
96 | 
97 | return sp.coo_matrix((values, (edges[:, 0], edges[:, 1])), shape=(N, N)).tocsr()
98 | 
99 | 
100 | def train_val_test_split_adjacency(A, p_val=0.10, p_test=0.05, seed=0, neg_mul=1,
101 | every_node=True, connected=False, undirected=False,
102 | use_edge_cover=True, set_ops=True, asserts=False):
103 | """
104 | Split the edges of the adjacency matrix into train, validation and test edges
105 | and randomly sample an equal amount of validation and test non-edges.
106 | 
107 | Parameters
108 | ----------
109 | A : scipy.sparse.spmatrix
110 | Sparse unweighted adjacency matrix
111 | p_val : float
112 | Percentage of validation edges. Default p_val=0.10
113 | p_test : float
114 | Percentage of test edges. Default p_test=0.05
115 | seed : int
116 | Seed for numpy.random. Default seed=0
117 | neg_mul : int
118 | What multiplicity of negative samples (non-edges) to have in the test/validation set
119 | w.r.t the number of edges, i.e. len(non-edges) = neg_mul * len(edges). Default neg_mul=1
120 | every_node : bool
121 | Make sure each node appears at least once in the train set. Default every_node=True
122 | connected : bool
123 | Make sure the training graph is still connected after the split. Default connected=False
124 | undirected : bool
125 | Whether to make the split undirected, that is if (i, j) is in val/test set then (j, i) is there as well.
126 | Default undirected=False
127 | use_edge_cover: bool
128 | Whether to use (approximate) edge_cover to find the minimum set of edges that cover every node.
129 | Only active when every_node=True. Default use_edge_cover=True
130 | set_ops : bool
131 | Whether to use set operations to construct the test zeros. Default set_ops=True.
132 | Otherwise use a while loop.
133 | asserts : bool
134 | Unit-test-like consistency checks. Default asserts=False
135 | 
136 | Returns
137 | -------
138 | train_ones : array-like, shape [n_train, 2]
139 | Indices of the train edges
140 | val_ones : array-like, shape [n_val, 2]
141 | Indices of the validation edges
142 | val_zeros : array-like, shape [n_val, 2]
143 | Indices of the validation non-edges
144 | test_ones : array-like, shape [n_test, 2]
145 | Indices of the test edges
146 | test_zeros : array-like, shape [n_test, 2]
147 | Indices of the test non-edges
148 | 
149 | """
150 | 
151 | assert p_val + p_test > 0
152 | assert A.max() == 1 # no weights
153 | assert A.min() == 0 # no negative edges
154 | assert A.diagonal().sum() == 0 # no self-loops
155 | assert not np.any(A.sum(0).A1 + A.sum(1).A1 == 0) # no dangling nodes
156 | 
157 | is_undirected = (A != A.T).nnz == 0
158 | 
159 | if undirected:
160 | assert is_undirected # make sure the graph is undirected
161 | A = sp.tril(A).tocsr() # keep only the lower triangle so each undirected edge is counted once
162 | A.eliminate_zeros()
163 | else:
164 | if is_undirected:
165 | warnings.warn('Graph appears to be undirected. Did you forget to set undirected=True?')
166 | 
167 | np.random.seed(seed)
168 | 
169 | E = A.nnz
170 | N = A.shape[0]
171 | 
172 | s_train = int(E * (1 - p_val - p_test))
173 | 
174 | idx = np.arange(N)
175 | 
176 | # hold out some edges so that each node appears at least once in the train set
177 | if every_node:
178 | if connected:
179 | assert connected_components(A)[0] == 1 # make sure the original graph is connected
180 | A_hold = minimum_spanning_tree(A)
181 | else:
182 | A.eliminate_zeros() # makes sure A.tolil().rows contains only indices of non-zero elements
183 | d = A.sum(1).A1
184 | 
185 | if use_edge_cover:
186 | hold_edges = np.array(list(nx.maximal_matching(nx.DiGraph(A)))) # maximal matching as an approximate edge cover
187 | not_in_cover = np.array(list(set(range(N)).difference(hold_edges.flatten())))
188 | 
189 | # makes sure the training percentage is not smaller than N/E when every_node is set to True
190 | min_size = hold_edges.shape[0] + len(not_in_cover)
191 | if min_size > s_train:
192 | raise ValueError('Training percentage too low to guarantee every node. Min train size needed {:.2f}'
193 | .format(min_size / E))
194 | 
195 | d_nic = d[not_in_cover]
196 | 
197 | hold_edges_d1 = np.column_stack((not_in_cover[d_nic > 0],
198 | np.row_stack(map(np.random.choice,
199 | A[not_in_cover[d_nic > 0]].tolil().rows))))
200 | 
201 | if np.any(d_nic == 0):
202 | hold_edges_d0 = np.column_stack((np.row_stack(map(np.random.choice, A[:, not_in_cover[d_nic == 0]].T.tolil().rows)),
203 | not_in_cover[d_nic == 0]))
204 | hold_edges = np.row_stack((hold_edges, hold_edges_d0, hold_edges_d1))
205 | else:
206 | hold_edges = np.row_stack((hold_edges, hold_edges_d1))
207 | 
208 | else:
209 | # makes sure the training percentage is not smaller than N/E when every_node is set to True
210 | if N > s_train:
211 | raise ValueError('Training percentage too low to guarantee every node. Min train size needed {:.2f}'
212 | .format(N / E))
213 | 
214 | hold_edges_d1 = np.column_stack(
215 | (idx[d > 0], np.row_stack(map(np.random.choice, A[d > 0].tolil().rows))))
216 | 
217 | if np.any(d == 0):
218 | hold_edges_d0 = np.column_stack((np.row_stack(map(np.random.choice, A[:, d == 0].T.tolil().rows)),
219 | idx[d == 0]))
220 | hold_edges = np.row_stack((hold_edges_d0, hold_edges_d1))
221 | else:
222 | hold_edges = hold_edges_d1
223 | 
224 | if asserts:
225 | assert np.all(A[hold_edges[:, 0], hold_edges[:, 1]])
226 | assert len(np.unique(hold_edges.flatten())) == N
227 | 
228 | A_hold = edges_to_sparse(hold_edges, N)
229 | 
230 | A_hold[A_hold > 1] = 1
231 | A_hold.eliminate_zeros()
232 | A_sample = A - A_hold
233 | 
234 | s_train = s_train - A_hold.nnz
235 | else:
236 | A_sample = A
237 | 
238 | idx_ones = np.random.permutation(A_sample.nnz)
239 | ones = np.column_stack(A_sample.nonzero())
240 | train_ones = ones[idx_ones[:s_train]]
241 | test_ones = ones[idx_ones[s_train:]]
242 | 
243 | # add the held-out edges back to the train set
244 | if every_node:
245 | train_ones = np.row_stack((train_ones, np.column_stack(A_hold.nonzero())))
246 | 
247 | n_test = len(test_ones) * neg_mul
248 | if set_ops:
249 | # generate slightly more random non-edge candidates than needed and discard any that hit an edge;
250 | # much faster compared to a while loop
251 | # in the future: estimate the multiplicity (currently fixed 1.3/2.3) based on A_obs.nnz
252 | if undirected:
253 | random_sample = np.random.randint(0, N, [int(2.3 * n_test), 2])
254 | random_sample = random_sample[random_sample[:, 0] > random_sample[:, 1]]
255 | else:
256 | random_sample = np.random.randint(0, N, [int(1.3 * n_test), 2])
257 | random_sample = random_sample[random_sample[:, 0] != random_sample[:, 1]]
258 | 
259 | test_zeros = random_sample[A[random_sample[:, 0], random_sample[:, 1]].A1 == 0]
260 | test_zeros = np.row_stack(test_zeros)[:n_test]
261 | assert test_zeros.shape[0] == n_test
262 | else:
263 | test_zeros = []
264 | while len(test_zeros) < n_test:
265 | i, j = np.random.randint(0, N, 2)
266 | if A[i, j] == 0 and (not undirected or i > j) and (i, j) not in test_zeros:
267 | test_zeros.append((i, j))
268 | test_zeros = np.array(test_zeros)
269 | 
270 | # split the test set into validation and test sets
271 | s_val_ones = int(len(test_ones) * p_val / (p_val + p_test))
272 | s_val_zeros = int(len(test_zeros) * p_val / (p_val + p_test))
273 | 
274 | val_ones = test_ones[:s_val_ones]
275 | test_ones = test_ones[s_val_ones:]
276 | 
277 | val_zeros = test_zeros[:s_val_zeros]
278 | test_zeros = test_zeros[s_val_zeros:]
279 | 
280 | if undirected:
281 | # put (j, i) edges for every (i, j) edge in the respective sets and restore the original A
282 | symmetrize = lambda x: np.row_stack((x, np.column_stack((x[:, 1], x[:, 0]))))
283 | train_ones = symmetrize(train_ones)
284 | val_ones = symmetrize(val_ones)
285 | val_zeros = symmetrize(val_zeros)
286 | test_ones = symmetrize(test_ones)
287 | test_zeros = symmetrize(test_zeros)
288 | A = A.maximum(A.T)
289 | 
290 | if asserts:
291 | set_of_train_ones = set(map(tuple, train_ones))
292 | assert train_ones.shape[0] + test_ones.shape[0] + val_ones.shape[0] == A.nnz
293 | assert (edges_to_sparse(np.row_stack((train_ones, test_ones, val_ones)), N) != A).nnz == 0
294 | assert set_of_train_ones.intersection(set(map(tuple, test_ones))) == set()
295 | assert set_of_train_ones.intersection(set(map(tuple, val_ones))) == set()
296 | assert set_of_train_ones.intersection(set(map(tuple, test_zeros))) == set()
297 | assert set_of_train_ones.intersection(set(map(tuple, val_zeros))) == set()
298 | assert len(set(map(tuple, test_zeros))) == len(test_ones) * neg_mul
299 | assert len(set(map(tuple, val_zeros))) == len(val_ones) * neg_mul
300 | assert not connected or connected_components(A_hold)[0] == 1
301 | assert not every_node or ((A_hold - A) > 0).sum() == 0
302 | 
303 | 
304 | return train_ones, val_ones, val_zeros, test_ones, test_zeros
305 | 
306 | 
307 | def score_matrix_from_random_walks(random_walks, N, weights_walks, symmetric=True):
308 | """
309 | Compute the transition scores, i.e. how often a transition occurs, for all node pairs from
310 | the random walks provided.
311 | Parameters
312 | ----------
313 | random_walks: np.array of shape (n_walks, rw_len)
314 | The input random walks (node indices) to count the transitions in.
315 | N: int
316 | The number of nodes
317 | weights_walks: array-like, shape (n_walks, rw_len, 1). The edge weights for each random-walk step (the first entry of each walk is ignored).
318 | symmetric: bool, default: True. Whether to symmetrize the resulting scores matrix.
319 | 
320 | Returns
321 | -------
322 | scores_matrix: sparse matrix, shape (N, N)
323 | Matrix whose entries (i,j) count how often a transition from node i to j was observed in the input random walks.
324 | weight_matrix: np.array, shape (N, N)
325 | Mean generated weight per observed transition (zero where no transition was observed).
326 | """
327 | random_walks = np.array(random_walks)
328 | weights_walks = np.vstack(weights_walks)[:, 1:, :].flatten()
329 | bigrams = np.array(list(zip(random_walks[:, :-1], random_walks[:, 1:])))
330 | bigrams = np.transpose(bigrams, [0, 2, 1])
331 | bigrams = bigrams.reshape([-1, 2])
332 | weight_matrix = np.zeros((N, N))
333 | for i, (x, y) in enumerate(bigrams):
334 | weight_matrix[x, y] += weights_walks[i]
335 | weight_matrix = weight_matrix + weight_matrix.T
336 | if symmetric:
337 | bigrams = np.row_stack((bigrams, bigrams[:, ::-1]))
338 | 
339 | mat = sp.coo_matrix((np.ones(bigrams.shape[0]), (bigrams[:, 0], bigrams[:, 1])),
340 | shape=[N, N])
341 | weight_matrix = np.divide(weight_matrix, mat.toarray(), where=mat.toarray()!=0)
342 | return mat, weight_matrix
343 | 
344 | #@jit(nopython=True)
345 | def random_walk(edges, graph_weighted, node_ixs, rwlen, p=1, q=1, n_walks=1):
346 | N = len(node_ixs)
347 | # node2vec-style second-order random walk: p biases returning to the previous node, q biases exploring further away
348 | walk = []
349 | prev_nbs = None
350 | for w in range(n_walks):
351 | source_node = np.random.choice(N)
352 | walk.append(source_node)
353 | for it in range(rwlen-1):
354 | 
355 | if walk[-1] == N-1:
356 | nbs = edges[node_ixs[walk[-1]]::,1]
357 | else:
358 | nbs = edges[node_ixs[walk[-1]]:node_ixs[walk[-1]+1],1]
359 | 
360 | if it == 0:
361 | walk.append(np.random.choice(nbs))
362 | prev_nbs = set(nbs)
363 | continue
364 | 
365 | is_dist_1 = []
366 | for n in nbs:
367 | is_dist_1.append(int(n in set(prev_nbs)))
368 | 
369 | is_dist_1_np = np.array(is_dist_1)
370 | is_dist_0 = nbs == walk[-2]
371 | is_dist_2 = 1 - is_dist_1_np - is_dist_0
372 | 
373 | alpha_pq = is_dist_0 / p + is_dist_1_np + is_dist_2/q
374 | alpha_pq_norm = alpha_pq/np.sum(alpha_pq)
375 | rdm_num = np.random.rand()
376 | cumsum = np.cumsum(alpha_pq_norm)
377 | nxt = nbs[np.sum(1-(cumsum > rdm_num))]
378 | walk.append(nxt)
379 | prev_nbs = set(nbs)
380 | rw_complete = np.array(walk).reshape([-1, rwlen])
381 | weights = np.zeros((rw_complete.shape[0], rw_complete.shape[1], 1))
382 | weights[:, 1:, :] = np.expand_dims(long_rw_to_single_rw(rw_complete, graph_weighted), axis=2)
383 | return rw_complete, weights
384 | 
385 | class RandomWalker:
386 | """
387 | Helper class to generate random walks on the input adjacency matrix.
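Example (hypothetical usage; assumes a scipy.sparse adjacency matrix `adj` and a dense
weighted adjacency matrix `adj_w`, e.g. as returned by get_graph_weighted):
walker = RandomWalker(adj, adj_w, rw_len=16, batch_size=128)
walks, weights = walker.walk() # walks: (batch_size, rw_len) node indices, weights: (batch_size, rw_len, 1)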
388 | """ 389 | def __init__(self, adj, adj_weighted, rw_len, p=1, q=1, batch_size=128): 390 | self.adj = adj 391 | #if not "lil" in str(type(adj)): 392 | # warnings.warn("Input adjacency matrix not in lil format. Converting it to lil.") 393 | # self.adj = self.adj.tolil() 394 | 395 | self.rw_len = rw_len 396 | self.p = p 397 | self.q = q 398 | self.edges = np.array(self.adj.nonzero()).T 399 | self.node_ixs = np.unique(self.edges[:, 0], return_index=True)[1] 400 | self.batch_size = batch_size 401 | self.adj_weighted = adj_weighted 402 | 403 | def walk(self): 404 | return random_walk(self.edges, self.adj_weighted, self.node_ixs, self.rw_len, self.p, self.q, self.batch_size) 405 | #while True: 406 | # yield random_walk(self.edges, self.adj_weighted, self.node_ixs, self.rw_len, self.p, self.q, self.batch_size).reshape([-1, self.rw_len]) 407 | 408 | 409 | 410 | def edge_overlap(A, B): 411 | """ 412 | Compute edge overlap between input graphs A and B, i.e. how many edges in A are also present in graph B. Assumes 413 | that both graphs contain the same number of edges. 414 | 415 | Parameters 416 | ---------- 417 | A: sparse matrix or np.array of shape (N,N). 418 | First input adjacency matrix. 419 | B: sparse matrix or np.array of shape (N,N). 420 | Second input adjacency matrix. 421 | 422 | Returns 423 | ------- 424 | float, the edge overlap. 425 | """ 426 | 427 | return ((A == B) & (A == 1)).sum() 428 | 429 | 430 | def graph_from_scores(scores, n_edges): 431 | """ 432 | Assemble a symmetric binary graph from the input score matrix. Ensures that there will be no singleton nodes. 433 | See the paper for details. 434 | 435 | Parameters 436 | ---------- 437 | scores: np.array of shape (N,N) 438 | The input transition scores. 439 | n_edges: int 440 | The desired number of edges in the target graph. 441 | 442 | Returns 443 | ------- 444 | target_g: symmettic binary sparse matrix of shape (N,N) 445 | The assembled graph. 446 | 447 | """ 448 | 449 | if len(scores.nonzero()[0]) < n_edges: 450 | return symmetric(scores) > 0 451 | target_g = np.zeros(scores.shape) # initialize target graph 452 | scores_int = scores.toarray().copy() # internal copy of the scores matrix 453 | scores_int[np.diag_indices_from(scores_int)] = 0 # set diagonal to zero 454 | degrees_int = scores_int.sum(0) # The row sum over the scores. 455 | 456 | N = scores.shape[0] 457 | 458 | for n in np.random.choice(N, replace=False, size=N): # Iterate the nodes in random order 459 | 460 | row = scores_int[n,:].copy() 461 | if row.sum() == 0: 462 | continue 463 | 464 | probs = row / row.sum() 465 | 466 | target = np.random.choice(N, p=probs) 467 | target_g[n, target] = 1 468 | target_g[target, n] = 1 469 | 470 | diff = np.round((n_edges - target_g.sum())/2) 471 | if diff > 0: 472 | triu = np.triu(scores_int) 473 | triu[target_g > 0] = 0 474 | triu[np.diag_indices_from(scores_int)] = 0 475 | triu = triu / triu.sum() 476 | 477 | triu_ixs = np.triu_indices_from(scores_int) 478 | extra_edges = np.random.choice(triu_ixs[0].shape[0], replace=False, p=triu[triu_ixs], size=int(diff)) 479 | 480 | target_g[(triu_ixs[0][extra_edges], triu_ixs[1][extra_edges])] = 1 481 | target_g[(triu_ixs[1][extra_edges], triu_ixs[0][extra_edges])] = 1 482 | 483 | target_g = symmetric(target_g) 484 | return target_g 485 | 486 | 487 | def symmetric(directed_adjacency, clip_to_one=True): 488 | """ 489 | Symmetrize the input adjacency matrix. 
490 | Parameters
491 | ----------
492 | directed_adjacency: sparse matrix or np.array of shape (N,N)
493 | Input adjacency matrix.
494 | clip_to_one: bool, default: True
495 | Whether the output should be binarized (i.e. clipped to 1)
496 | 
497 | Returns
498 | -------
499 | A_symmetric: sparse matrix or np.array of the same shape as the input
500 | Symmetrized adjacency matrix.
501 | 
502 | """
503 | 
504 | A_symmetric = directed_adjacency + directed_adjacency.T
505 | if clip_to_one:
506 | A_symmetric[A_symmetric > 1] = 1
507 | return A_symmetric
508 | 
509 | def squares(g):
510 | """
511 | Count the number of squares for each node
512 | Parameters
513 | ----------
514 | g: igraph Graph object
515 | The input graph.
516 | 
517 | Returns
518 | -------
519 | List with N entries (N is the number of nodes) that give the number of squares a node is part of.
520 | """
521 | 
522 | cliques = g.cliques(min=4, max=4)
523 | result = [0] * g.vcount()
524 | for i, j, k, l in cliques:
525 | result[i] += 1
526 | result[j] += 1
527 | result[k] += 1
528 | result[l] += 1
529 | return result
530 | 
531 | 
532 | def statistics_degrees(A_in):
533 | """
534 | Compute min, max, mean degree
535 | 
536 | Parameters
537 | ----------
538 | A_in: sparse matrix or np.array
539 | The input adjacency matrix.
540 | Returns
541 | -------
542 | d_max, d_min, d_mean
543 | """
544 | 
545 | degrees = A_in.sum(axis=0)
546 | return np.max(degrees), np.min(degrees), np.mean(degrees)
547 | 
548 | 
549 | def statistics_LCC(A_in):
550 | """
551 | Compute the nodes of the largest connected component (LCC)
552 | 
553 | Parameters
554 | ----------
555 | A_in: sparse matrix or np.array
556 | The input adjacency matrix.
557 | Returns
558 | -------
559 | Array with the node indices of the LCC
560 | 
561 | """
562 | 
563 | unique, counts = np.unique(connected_components(A_in)[1], return_counts=True)
564 | LCC = np.where(connected_components(A_in)[1] == np.argmax(counts))[0]
565 | return LCC
566 | 
567 | 
568 | def statistics_wedge_count(A_in):
569 | """
570 | Compute the wedge count of the input graph
571 | 
572 | Parameters
573 | ----------
574 | A_in: sparse matrix or np.array
575 | The input adjacency matrix.
576 | 
577 | Returns
578 | -------
579 | The wedge count.
580 | """
581 | 
582 | degrees = A_in.sum(axis=0)
583 | return float(np.sum(np.array([0.5 * x * (x - 1) for x in degrees])))
584 | 
585 | 
586 | def statistics_claw_count(A_in):
587 | """
588 | Compute the claw count of the input graph
589 | 
590 | Parameters
591 | ----------
592 | A_in: sparse matrix or np.array
593 | The input adjacency matrix.
594 | 
595 | Returns
596 | -------
597 | Claw count
598 | """
599 | 
600 | degrees = A_in.sum(axis=0)
601 | return float(np.sum(np.array([1 / 6. * x * (x - 1) * (x - 2) for x in degrees])))
602 | 
603 | 
604 | def statistics_triangle_count(A_in):
605 | """
606 | Compute the triangle count of the input graph
607 | 
608 | Parameters
609 | ----------
610 | A_in: sparse matrix or np.array
611 | The input adjacency matrix.
612 | Returns
613 | -------
614 | Triangle count
615 | """
616 | 
617 | A_graph = nx.from_numpy_matrix(A_in)
618 | triangles = nx.triangles(A_graph)
619 | t = np.sum(list(triangles.values())) / 3
620 | return int(t)
621 | 
622 | 
623 | #def statistics_square_count(A_in):
624 | # """
625 | # Compute the square count of the input graph
626 | # 
627 | # Parameters
628 | # ----------
629 | # A_in: sparse matrix or np.array
630 | # The input adjacency matrix.
631 | # Returns
632 | # -------
633 | # Square count
634 | # """
635 | # 
636 | # A_igraph = igraph.Graph.Adjacency((A_in > 0).tolist()).as_undirected()
637 | # return int(np.sum(squares(A_igraph)) / 4)
638 | 
639 | 
640 | def statistics_power_law_alpha(A_in):
641 | """
642 | Compute the power law coefficient of the degree distribution of the input graph
643 | 
644 | Parameters
645 | ----------
646 | A_in: sparse matrix or np.array
647 | The input adjacency matrix.
648 | 
649 | Returns
650 | -------
651 | Power law coefficient
652 | """
653 | import powerlaw # local import: the module-level import is commented out, and the optional 'powerlaw' package is only needed here
654 | degrees = A_in.sum(axis=0)
655 | return powerlaw.Fit(degrees, xmin=max(np.min(degrees),1)).power_law.alpha
656 | 
657 | 
658 | def statistics_gini(A_in):
659 | """
660 | Compute the Gini coefficient of the degree distribution of the input graph
661 | 
662 | Parameters
663 | ----------
664 | A_in: sparse matrix or np.array
665 | The input adjacency matrix.
666 | 
667 | Returns
668 | -------
669 | Gini coefficient
670 | """
671 | 
672 | n = A_in.shape[0]
673 | degrees = A_in.sum(axis=0)
674 | degrees_sorted = np.sort(degrees)
675 | G = (2 * np.sum(np.array([i * degrees_sorted[i] for i in range(len(degrees))]))) / (n * np.sum(degrees)) - (
676 | n + 1) / n
677 | return float(G)
678 | 
679 | 
680 | def statistics_edge_distribution_entropy(A_in):
681 | """
682 | Compute the relative edge distribution entropy of the input graph.
683 | 
684 | Parameters
685 | ----------
686 | A_in: sparse matrix or np.array
687 | The input adjacency matrix.
688 | 
689 | Returns
690 | -------
691 | Rel. edge distribution entropy
692 | """
693 | 
694 | degrees = A_in.sum(axis=0)
695 | m = 0.5 * np.sum(np.square(A_in))
696 | n = A_in.shape[0]
697 | 
698 | H_er = 1 / np.log(n) * np.sum(-degrees / (2 * float(m)) * np.log((degrees+.0001) / (2 * float(m))))
699 | return H_er
700 | 
701 | def statistics_cluster_props(A, Z_obs):
702 | def get_blocks(A_in, Z_obs, normalize=True):
703 | block = Z_obs.T.dot(A_in.dot(Z_obs))
704 | counts = np.sum(Z_obs, axis=0)
705 | blocks_outer = counts[:,None].dot(counts[None,:])
706 | if normalize:
707 | blocks_outer = np.multiply(block, 1/blocks_outer)
708 | return blocks_outer
709 | 
710 | in_blocks = get_blocks(A, Z_obs)
711 | diag_mean = np.multiply(in_blocks, np.eye(in_blocks.shape[0])).mean()
712 | offdiag_mean = np.multiply(in_blocks, 1-np.eye(in_blocks.shape[0])).mean()
713 | return diag_mean, offdiag_mean
714 | 
715 | def statistics_compute_cpl(A):
716 | """Compute the characteristic path length."""
717 | P = sp.csgraph.shortest_path(sp.csr_matrix(A))
718 | return P[((1 - np.isinf(P)) * (1 - np.eye(P.shape[0]))).astype(bool)].mean() # plain bool: np.bool was removed in newer NumPy versions
719 | 
720 | 
721 | def compute_graph_statistics(A_in, Z_obs=None):
722 | """
723 | 
724 | Parameters
725 | ----------
726 | A_in: sparse matrix
727 | The input adjacency matrix.
728 | Z_obs: np.matrix [N, K], where K is the number of classes.
729 | Matrix whose rows are one-hot vectors indicating the class membership of the respective node.
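Optional; if Z_obs is not passed, the intra-/inter-community densities are omitted from the result.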
730 | 
731 | Returns
732 | -------
733 | Dictionary containing the following statistics:
734 | * Maximum, minimum, mean degree of nodes
735 | * Size of the largest connected component (LCC)
736 | * Wedge count
737 | * Claw count
738 | * Triangle count
739 | * Square count (disabled below; requires igraph)
740 | * Power law exponent
741 | * Gini coefficient
742 | * Relative edge distribution entropy
743 | * Assortativity
744 | * Clustering coefficient
745 | * Number of connected components
746 | * Intra- and inter-community density (if Z_obs is passed)
747 | * Characteristic path length
748 | """
749 | 
750 | A = A_in.copy()
751 | 
752 | assert ((A == A.T).all())
753 | A_graph = nx.from_numpy_matrix(A).to_undirected()
754 | 
755 | statistics = {}
756 | 
757 | d_max, d_min, d_mean = statistics_degrees(A)
758 | 
759 | # Degree statistics
760 | statistics['d_max'] = d_max
761 | statistics['d_min'] = d_min
762 | statistics['d'] = d_mean
763 | 
764 | # largest connected component
765 | LCC = statistics_LCC(A)
766 | 
767 | statistics['LCC'] = LCC.shape[0]
768 | # wedge count
769 | statistics['wedge_count'] = statistics_wedge_count(A)
770 | 
771 | # claw count
772 | statistics['claw_count'] = statistics_claw_count(A)
773 | 
774 | # triangle count
775 | statistics['triangle_count'] = statistics_triangle_count(A)
776 | 
777 | # square count (disabled: statistics_square_count is commented out above because it requires igraph)
778 | # statistics['square_count'] = statistics_square_count(A)
779 | 
780 | # power law exponent
781 | statistics['power_law_exp'] = statistics_power_law_alpha(A)
782 | 
783 | # gini coefficient
784 | statistics['gini'] = statistics_gini(A)
785 | 
786 | # relative edge distribution entropy
787 | statistics['rel_edge_distr_entropy'] = statistics_edge_distribution_entropy(A)
788 | 
789 | # assortativity
790 | statistics['assortativity'] = nx.degree_assortativity_coefficient(A_graph)
791 | 
792 | # clustering coefficient
793 | statistics['clustering_coefficient'] = 3 * statistics['triangle_count'] / statistics['claw_count']
794 | 
795 | # number of connected components
796 | statistics['n_components'] = connected_components(A)[0]
797 | 
798 | if Z_obs is not None:
799 | # inter- and intra-community density
800 | intra, inter = statistics_cluster_props(A, Z_obs)
801 | statistics['intra_community_density'] = intra
802 | statistics['inter_community_density'] = inter
803 | 
804 | statistics['cpl'] = statistics_compute_cpl(A)
805 | 
806 | return statistics
807 | 
808 | def get_graph(path, tag): # build an unweighted sparse adjacency matrix for the grid with 'Network' == tag
809 | data = pd.read_csv(path, index_col=0)
810 | data = data[data['Network'] == tag]
811 | keys = np.unique(data[['fbus', 'tbus']].values)
812 | dicto = dict(zip(keys, np.arange(len(keys)))) # map the original bus ids to consecutive node indices
813 | data = data[['fbus', 'tbus']].values.reshape(-1)
814 | data = np.array([dicto[key] for key in data])
815 | data = data.reshape(-1, 2)
816 | edges = [list(edge) for edge in data]
817 | nodes = np.sort(np.unique(data))
818 | G = nx.Graph()
819 | G.add_nodes_from(nodes)
820 | G.add_edges_from(edges)
821 | G = nx.to_scipy_sparse_matrix(G)
822 | return G
823 | 
824 | def get_graph_weighted(path_folder, tag):
825 | path_branch = path_folder + r'\branch.csv'
826 | path_bus = path_folder + r'\bus.csv'
827 | data_branch = pd.read_csv(path_branch)
828 | data_bus = pd.read_csv(path_bus)
829 | 
830 | data_branch = data_branch[data_branch['Network'] == tag]
831 | data_bus = data_bus[data_bus['Network'] == tag]
832 | 
833 | data_branch = data_branch.drop_duplicates(subset=['fbus', 'tbus']).sort_values(by=['fbus', 'tbus'])
834 | scaler = StandardScaler()
835 | data_branch['l_scaled'] = scaler.fit_transform(data_branch['l'].values.reshape(-1, 1)) # standardize the line lengths
836 | 
837 | edges_weighted = data_branch[['fbus', 'tbus', 'l_scaled']].values
838 | edges = data_branch[['fbus', 'tbus']].values
839 | nodes = np.arange(data_bus.shape[0])
840 | 
841 | G = nx.Graph()
842 | G.add_nodes_from(nodes)
843 | G.add_edges_from(edges)
844 | G = nx.to_scipy_sparse_matrix(G)
845 | 
846 | G_weighted = nx.Graph()
847 | G_weighted.add_nodes_from(nodes)
848 | G_weighted.add_weighted_edges_from(edges_weighted)
849 | G_weighted = nx.to_scipy_sparse_matrix(G_weighted)
850 | G_weighted = G_weighted.todense()
851 | return G, G_weighted, scaler
852 | 
853 | def calc_lines_mse(rw_list, graph_weighted_real, graph_weighted_fake): # MSE between real and generated edge weights over the given edge list
854 | weights_diff = np.zeros(len(rw_list))
855 | for i, rw_single in enumerate(rw_list):
856 | weights_diff[i] = graph_weighted_real[rw_single[0], rw_single[1]] - graph_weighted_fake[rw_single[0], rw_single[1]]
857 | loss = (weights_diff**2).sum()/len(rw_list)
858 | return loss
859 | 
860 | def get_weights_from_rw_list(rw_list, graph_weighted): # look up the weight of each edge (x, y) in the weighted adjacency matrix
861 | weights = np.zeros(len(rw_list))
862 | for i, (x, y) in enumerate(rw_list):
863 | weights[i] = graph_weighted[x, y]
864 | return weights
865 | 
866 | def long_rw_to_single_rw(rw_complete, graph_weighted): # split the walks into single transitions and collect the corresponding edge weights
867 | rw_list = []
868 | for i in range(rw_complete.shape[0]):
869 | for j in range(rw_complete.shape[1]-1):
870 | rw_list.append(np.sort(rw_complete[i, j:j+2]))
871 | rw_list = np.vstack(rw_list)
872 | weight_list = get_weights_from_rw_list(rw_list, graph_weighted)
873 | weight_matrix = weight_list.reshape(-1, rw_complete.shape[1] - 1)
874 | return weight_matrix
875 | 
876 | def create_train_graph(path_folder, tag, trainer_best):
877 | path_branch = path_folder + r'\branch.csv'
878 | path_bus = path_folder + r'\bus.csv'
879 | 
880 | data_branch = pd.read_csv(path_branch)
881 | data_bus = pd.read_csv(path_bus)
882 | 
883 | data_branch = data_branch[data_branch['Network'] == tag]
884 | data_bus = data_bus[data_bus['Network'] == tag]
885 | data_branch = data_branch.drop_duplicates(subset=['fbus', 'tbus']).sort_values(by=['fbus', 'tbus'])
886 | 
887 | data_branch.loc[data_branch['l']==0.0, 'l'] = 0.1 # avoid division by zero for zero-length lines
888 | 
889 | data_branch['l_inverse'] = 1/data_branch['l']
890 | 
891 | edges_weighted = data_branch[['fbus', 'tbus', 'l_inverse']].values
892 | nodes = np.arange(data_bus.shape[0])
893 | 
894 | G_weighted = nx.Graph()
895 | G_weighted.add_nodes_from(nodes)
896 | G_weighted.add_weighted_edges_from(edges_weighted)
897 | 
898 | train_graph = trainer_best.train_graph.toarray()*nx.convert_matrix.to_numpy_array(G_weighted)
899 | train_graph = nx.convert_matrix.from_numpy_array(train_graph)
900 | 
901 | nx.write_gexf(train_graph, "originals.gexf")
902 | 
903 | 
904 | def clean_weighted_graph(graph, graph_weighted, scaler):
905 | if type(graph) is np.ndarray or type(graph) is np.matrix:
906 | graph = sp.csr_matrix(graph)
907 | if type(graph_weighted) is np.ndarray or type(graph_weighted) is np.matrix:
908 | graph_weighted = sp.csr_matrix(graph_weighted)
909 | lcc = largest_connected_components(graph)
910 | graph = graph[lcc, :][:, lcc]
911 | graph_weighted = graph_weighted[lcc, :][:, lcc]
912 | graph_weighted = scaler.inverse_transform(graph_weighted.toarray()) * graph.toarray() # undo the standardization and mask non-edges
913 | graph_weighted[(graph_weighted <= 0.0)] = 0.0
914 | lcc = largest_connected_components(graph_weighted)
915 | graph_weighted = graph_weighted[lcc, :][:, lcc]
916 | graph_weighted_inverse = graph_weighted.copy()
917 | graph_weighted_inverse[graph_weighted_inverse > 0.0] = 1 / graph_weighted_inverse[graph_weighted_inverse > 0.0]
918 | 
919 | graph = nx.convert_matrix.from_scipy_sparse_matrix(graph)
920 | graph_weighted = nx.convert_matrix.from_numpy_array(graph_weighted)
921 | graph_weighted_inverse = nx.convert_matrix.from_numpy_array(graph_weighted_inverse)
922 | return graph, graph_weighted, graph_weighted_inverse
--------------------------------------------------------------------------------