├── MAXP 2021初赛数据探索和处理-1.ipynb ├── MAXP 2021初赛数据探索和处理-2.ipynb ├── MAXP 2021初赛数据探索和处理-3.ipynb ├── MAXP 2021初赛数据探索和处理-4.ipynb ├── README.md └── gnn ├── __init__.py ├── model_train.py ├── model_utils.py ├── models.py └── utils.py /MAXP 2021初赛数据探索和处理-1.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# MAXP 2021初赛数据探索和处理-1\n", 8 | "\n", 9 | "由于节点的Feature维度比较高,所以先处理节点的ID,并保存。Feature的处理放到第2部分。" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 1, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import pandas as pd\n", 19 | "import numpy as np\n", 20 | "import os\n", 21 | "import gc" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 2, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "# path\n", 31 | "base_path = '/Users/jamezhan/PycharmProjects/MAXP/final_dataset'\n", 32 | "publish_path = 'publish'\n", 33 | "\n", 34 | "link_p1_path = os.path.join(base_path, publish_path, 'link_phase1.csv')\n", 35 | "train_nodes_path = os.path.join(base_path, publish_path, 'train_nodes.csv')\n", 36 | "val_nodes_path = os.path.join(base_path, publish_path, 'validation_nodes.csv')" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "metadata": {}, 42 | "source": [ 43 | "### 读取边列表并统计节点数量" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 3, 49 | "metadata": {}, 50 | "outputs": [ 51 | { 52 | "name": "stdout", 53 | "output_type": "stream", 54 | "text": [ 55 | "(29168650, 3)\n" 56 | ] 57 | }, 58 | { 59 | "data": { 60 | "text/html": [ 61 | "
" 118 | ], 119 | "text/plain": [ 120 | " paper_id reference_paper_id phase\n", 121 | "0 f10da75ad1eaf16eb2ffe0d85b76b332 711ef25bdb2c2421c0131af77b3ede1d phase1\n", 122 | "1 9ac5a4327bd4f3dcb424c93ca9b84087 2d91c73304c5e8a94a0e5b4956093f71 phase1\n", 123 | "2 9d91bfd4703e55dd814dfffb3d63fc33 33d4fdfe3967a1ffde9311bfe6827ef9 phase1\n", 124 | "3 e1bdbce05528952ed6579795373782d4 4bda690abec912b3b7b228b01fb6819a phase1\n", 125 | "4 eb623ac4b10df96835921edabbde2951 c1a05bdfc88a73bf2830e705b2f39dbb phase1" 126 | ] 127 | }, 128 | "execution_count": 3, 129 | "metadata": {}, 130 | "output_type": "execute_result" 131 | } 132 | ], 133 | "source": [ 134 | "edge_df = pd.read_csv(link_p1_path)\n", 135 | "print(edge_df.shape)\n", 136 | "edge_df.head()" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": 4, 142 | "metadata": {}, 143 | "outputs": [ 144 | { 145 | "data": { 146 | "text/plain": [ 147 | "count 29168650\n", 148 | "unique 1\n", 149 | "top phase1\n", 150 | "freq 29168650\n", 151 | "Name: phase, dtype: object" 152 | ] 153 | }, 154 | "execution_count": 4, 155 | "metadata": {}, 156 | "output_type": "execute_result" 157 | } 158 | ], 159 | "source": [ 160 | "edge_df.phase.describe()" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": 5, 166 | "metadata": {}, 167 | "outputs": [ 168 | { 169 | "name": "stdout", 170 | "output_type": "stream", 171 | "text": [ 172 | "(3031367, 1)\n" 173 | ] 174 | }, 175 | { 176 | "data": { 177 | "text/html": [ 178 | "
" 219 | ], 220 | "text/plain": [ 221 | " paper_id\n", 222 | "0 f10da75ad1eaf16eb2ffe0d85b76b332\n", 223 | "1 9ac5a4327bd4f3dcb424c93ca9b84087\n", 224 | "2 9d91bfd4703e55dd814dfffb3d63fc33\n", 225 | "3 e1bdbce05528952ed6579795373782d4" 226 | ] 227 | }, 228 | "execution_count": 5, 229 | "metadata": {}, 230 | "output_type": "execute_result" 231 | } 232 | ], 233 | "source": [ 234 | "nodes = pd.concat([edge_df['paper_id'], edge_df['reference_paper_id']])\n", 235 | "nodes = pd.DataFrame(nodes.drop_duplicates())\n", 236 | "nodes.rename(columns={0:'paper_id'}, inplace=True)\n", 237 | "\n", 238 | "print(nodes.shape)\n", 239 | "nodes.head(4)" 240 | ] 241 | }, 242 | { 243 | "cell_type": "markdown", 244 | "metadata": {}, 245 | "source": [ 246 | "#### 在边列表,一共出现了3,031,367个节点(paper_id)\n", 247 | "\n", 248 | "## 读取并查看train_nodes和validation_nodes里面的节点" 249 | ] 250 | }, 251 | { 252 | "cell_type": "code", 253 | "execution_count": 6, 254 | "metadata": {}, 255 | "outputs": [], 256 | "source": [ 257 | "def process_node(line):\n", 258 | " nid, feat_json, label = line.strip().split('\\\"')\n", 259 | " \n", 260 | " feat_list = [float(feat[1:-1]) for feat in feat_json[1:-1].split(', ')]\n", 261 | " \n", 262 | " if len(feat_list) != 300:\n", 263 | " print('此行数据有问题 {}'.format(line))\n", 264 | " \n", 265 | " return nid[:-1], feat_list, label[1:]" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": 7, 271 | "metadata": { 272 | "scrolled": true 273 | }, 274 | "outputs": [ 275 | { 276 | "name": "stdout", 277 | "output_type": "stream", 278 | "text": [ 279 | "Processed 100000 train rows\n", 280 | "Processed 200000 train rows\n", 281 | "Processed 300000 train rows\n", 282 | "Processed 400000 train rows\n", 283 | "Processed 500000 train rows\n", 284 | "Processed 600000 train rows\n", 285 | "Processed 700000 train rows\n", 286 | "Processed 800000 train rows\n", 287 | "Processed 900000 train rows\n", 288 | "Processed 1000000 train rows\n", 289 | "Processed 1100000 train rows\n", 290 | "Processed 1200000 train rows\n", 291 | "Processed 1300000 train rows\n", 292 | "Processed 1400000 train rows\n", 293 | "Processed 1500000 train rows\n", 294 | "Processed 1600000 train rows\n", 295 | "Processed 1700000 train rows\n", 296 | "Processed 1800000 train rows\n", 297 | "Processed 1900000 train rows\n", 298 | "Processed 2000000 train rows\n", 299 | "Processed 2100000 train rows\n", 300 | "Processed 2200000 train rows\n", 301 | "Processed 2300000 train rows\n", 302 | "Processed 2400000 train rows\n", 303 | "Processed 2500000 train rows\n", 304 | "Processed 2600000 train rows\n", 305 | "Processed 2700000 train rows\n", 306 | "Processed 2800000 train rows\n", 307 | "Processed 2900000 train rows\n", 308 | "Processed 3000000 train rows\n", 309 | "Processed 100000 validation rows\n", 310 | "Processed 200000 validation rows\n", 311 | "Processed 300000 validation rows\n", 312 | "Processed 400000 validation rows\n", 313 | "Processed 500000 validation rows\n" 314 | ] 315 | } 316 | ], 317 | "source": [ 318 | "# 先构建ID和Label的关系,保证ID的顺序和Feature的顺序一致即可\n", 319 | "nid_list = []\n", 320 | "label_list = []\n", 321 | "tr_val_list = []\n", 322 | "\n", 323 | "with open(train_nodes_path, 'r') as f:\n", 324 | " i = 0\n", 325 | " \n", 326 | " for line in f:\n", 327 | " if i > 0:\n", 328 | " nid, _, label = process_node(line)\n", 329 | " nid_list.append(nid)\n", 330 | " label_list.append(label)\n", 331 | " tr_val_list.append(0) # 0表示train的点\n", 332 | " i += 1\n", 333 | " if i % 100000 == 0:\n", 334 | " print('Processed {} train 
rows'.format(i))\n", 335 | "\n", 336 | "with open(val_nodes_path, 'r') as f:\n", 337 | " i = 0\n", 338 | " \n", 339 | " for line in f:\n", 340 | " if i > 0:\n", 341 | " nid, _, label = process_node(line)\n", 342 | " nid_list.append(nid)\n", 343 | " label_list.append(label)\n", 344 | " tr_val_list.append(1) # 1表示validation的点\n", 345 | " i += 1\n", 346 | " if i % 100000 == 0:\n", 347 | " print('Processed {} validation rows'.format(i))\n", 348 | " \n", 349 | "nid_arr = np.array(nid_list)\n", 350 | "label_arr = np.array(label_list)\n", 351 | "tr_val_arr = np.array(tr_val_list)\n", 352 | " \n", 353 | "nid_label_df = pd.DataFrame({'paper_id':nid_arr, 'Label': label_arr, 'Split_ID':tr_val_arr})" 354 | ] 355 | }, 356 | { 357 | "cell_type": "code", 358 | "execution_count": 8, 359 | "metadata": {}, 360 | "outputs": [ 361 | { 362 | "name": "stdout", 363 | "output_type": "stream", 364 | "text": [ 365 | "(3655033, 4)\n" 366 | ] 367 | }, 368 | { 369 | "data": { 370 | "text/html": [ 371 | "
" 427 | ], 428 | "text/plain": [ 429 | " node_idx paper_id Label Split_ID\n", 430 | "0 0 bfdee5ab86ef5e68da974d48a138c28e S 0\n", 431 | "1 1 78f43b8b62f040347fec0be44e5f08bd 0\n", 432 | "2 2 a971601a0286d2701aa5cde46e63a9fd G 0\n", 433 | "3 3 ac4b88a72146bae66cedfd1c13e1552d 0" 434 | ] 435 | }, 436 | "execution_count": 8, 437 | "metadata": {}, 438 | "output_type": "execute_result" 439 | } 440 | ], 441 | "source": [ 442 | "nid_label_df.reset_index(inplace=True)\n", 443 | "nid_label_df.rename(columns={'index':'node_idx'}, inplace=True)\n", 444 | "print(nid_label_df.shape)\n", 445 | "nid_label_df.head(4)" 446 | ] 447 | }, 448 | { 449 | "cell_type": "code", 450 | "execution_count": 9, 451 | "metadata": {}, 452 | "outputs": [ 453 | { 454 | "data": { 455 | "text/plain": [ 456 | "(3655033,)" 457 | ] 458 | }, 459 | "execution_count": 9, 460 | "metadata": {}, 461 | "output_type": "execute_result" 462 | } 463 | ], 464 | "source": [ 465 | "# 检查ID在Train和Validation是否有重复\n", 466 | "ids = nid_label_df.paper_id.drop_duplicates()\n", 467 | "ids.shape" 468 | ] 469 | }, 470 | { 471 | "cell_type": "markdown", 472 | "metadata": {}, 473 | "source": [ 474 | "#### train和validation一共有3,655,033个节点" 475 | ] 476 | }, 477 | { 478 | "cell_type": "markdown", 479 | "metadata": {}, 480 | "source": [ 481 | "#### 下面交叉比对边列表里的paper id和节点列表里的ID,检查是否有匹配不上的节点" 482 | ] 483 | }, 484 | { 485 | "cell_type": "code", 486 | "execution_count": 10, 487 | "metadata": {}, 488 | "outputs": [ 489 | { 490 | "name": "stdout", 491 | "output_type": "stream", 492 | "text": [ 493 | "(3030948, 4)\n" 494 | ] 495 | } 496 | ], 497 | "source": [ 498 | "inboth = nid_label_df.merge(nodes, on='paper_id', how='inner')\n", 499 | "print(inboth.shape)" 500 | ] 501 | }, 502 | { 503 | "cell_type": "code", 504 | "execution_count": 11, 505 | "metadata": {}, 506 | "outputs": [ 507 | { 508 | "name": "stdout", 509 | "output_type": "stream", 510 | "text": [ 511 | "(3031367, 4)\n", 512 | "共有419边列表的节点在给出的节点列表里没有对应,缺乏特征\n" 513 | ] 514 | }, 515 | { 516 | "data": { 517 | "text/html": [ 518 | "
" 574 | ], 575 | "text/plain": [ 576 | " paper_id node_idx Label Split_ID\n", 577 | "1124 cc388eaec8838ce383d8a8792014fedb NaN NaN NaN\n", 578 | "1184 5d899f41e52f751fef843cf7b1d05b4a NaN NaN NaN\n", 579 | "14342 2b2004ec3c99a44b5cb6045ca547453e NaN NaN NaN\n", 580 | "15803 d657c4451a9617f4eec96d3b2e6092c7 NaN NaN NaN" 581 | ] 582 | }, 583 | "execution_count": 11, 584 | "metadata": {}, 585 | "output_type": "execute_result" 586 | } 587 | ], 588 | "source": [ 589 | "edge_node = nodes.merge(nid_label_df, on='paper_id', how='left')\n", 590 | "print(edge_node.shape)\n", 591 | "print('共有{}边列表的节点在给出的节点列表里没有对应,缺乏特征'.format(edge_node[edge_node.node_idx.isna()].shape[0]))\n", 592 | "edge_node[edge_node.node_idx.isna()].head(4)" 593 | ] 594 | }, 595 | { 596 | "cell_type": "markdown", 597 | "metadata": {}, 598 | "source": [ 599 | "#### 合并边列表里独特的节点和train和validation的节点到一起,构成全部节点列表" 600 | ] 601 | }, 602 | { 603 | "cell_type": "code", 604 | "execution_count": 12, 605 | "metadata": {}, 606 | "outputs": [ 607 | { 608 | "name": "stderr", 609 | "output_type": "stream", 610 | "text": [ 611 | "/Users/jamezhan/anaconda3/envs/dgl/lib/python3.6/site-packages/ipykernel_launcher.py:3: UserWarning: Pandas doesn't allow columns to be created via a new attribute name - see https://pandas.pydata.org/pandas-docs/stable/indexing.html#attribute-access\n", 612 | " This is separate from the ipykernel package so we can avoid doing imports until\n", 613 | "/Users/jamezhan/anaconda3/envs/dgl/lib/python3.6/site-packages/pandas/core/generic.py:5170: SettingWithCopyWarning: \n", 614 | "A value is trying to be set on a copy of a slice from a DataFrame.\n", 615 | "Try using .loc[row_indexer,col_indexer] = value instead\n", 616 | "\n", 617 | "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", 618 | " self[name] = value\n", 619 | "/Users/jamezhan/anaconda3/envs/dgl/lib/python3.6/site-packages/pandas/core/frame.py:4174: SettingWithCopyWarning: \n", 620 | "A value is trying to be set on a copy of a slice from a DataFrame\n", 621 | "\n", 622 | "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", 623 | " errors=errors,\n" 624 | ] 625 | }, 626 | { 627 | "data": { 628 | "text/html": [ 629 | "
" 685 | ], 686 | "text/plain": [ 687 | " node_idx paper_id Label Split_ID\n", 688 | "0 3655033 cc388eaec8838ce383d8a8792014fedb NaN 1\n", 689 | "1 3655034 5d899f41e52f751fef843cf7b1d05b4a NaN 1\n", 690 | "2 3655035 2b2004ec3c99a44b5cb6045ca547453e NaN 1\n", 691 | "3 3655036 d657c4451a9617f4eec96d3b2e6092c7 NaN 1" 692 | ] 693 | }, 694 | "execution_count": 12, 695 | "metadata": {}, 696 | "output_type": "execute_result" 697 | } 698 | ], 699 | "source": [ 700 | "# 获取未能匹配上的节点,并构建新的节点DataFrame,然后和原有的Train/Validation节点Concat起来\n", 701 | "diff_nodes = edge_node[edge_node.node_idx.isna()]\n", 702 | "diff_nodes.ID = diff_nodes.paper_id\n", 703 | "diff_nodes.Split_ID = 1\n", 704 | "diff_nodes.node_idx = 0\n", 705 | "diff_nodes.reset_index(inplace=True)\n", 706 | "diff_nodes.drop(['index'], axis=1, inplace=True)\n", 707 | "diff_nodes.node_idx = diff_nodes.node_idx + diff_nodes.index + 3655033\n", 708 | "diff_nodes = diff_nodes[['node_idx', 'paper_id', 'Label', 'Split_ID']]\n", 709 | "diff_nodes.head(4)" 710 | ] 711 | }, 712 | { 713 | "cell_type": "code", 714 | "execution_count": 13, 715 | "metadata": {}, 716 | "outputs": [ 717 | { 718 | "data": { 719 | "text/html": [ 720 | "
" 776 | ], 777 | "text/plain": [ 778 | " node_idx paper_id Label Split_ID\n", 779 | "415 3655448 caed47d55d1e193ecb1fa97a415c13dd NaN 1\n", 780 | "416 3655449 c82eb6be79a245392fb626b9a7e1f246 NaN 1\n", 781 | "417 3655450 926a31f6b378575204aae30b5dfa6dd3 NaN 1\n", 782 | "418 3655451 bbace2419c3f827158ea4602f3eb35fa NaN 1" 783 | ] 784 | }, 785 | "execution_count": 13, 786 | "metadata": {}, 787 | "output_type": "execute_result" 788 | } 789 | ], 790 | "source": [ 791 | "# Concatenate这419个未匹配到的节点到总的node的最后,从而让nid能接上\n", 792 | "nid_label_df = pd.concat([nid_label_df, diff_nodes])\n", 793 | "nid_label_df.tail(4)" 794 | ] 795 | }, 796 | { 797 | "cell_type": "code", 798 | "execution_count": 14, 799 | "metadata": {}, 800 | "outputs": [], 801 | "source": [ 802 | "# 保存ID和Label到本地文件\n", 803 | "nid_label_df.to_csv(os.path.join(base_path, publish_path, './IDandLabels.csv'), index=False)\n", 804 | "# 保存未匹配上的节点用于feature的处理\n", 805 | "diff_nodes.to_csv(os.path.join(base_path, publish_path, './diff_nodes.csv'), index=False)\n" 806 | ] 807 | } 808 | ], 809 | "metadata": { 810 | "kernelspec": { 811 | "display_name": "Python [conda env:dgl]", 812 | "language": "python", 813 | "name": "conda-env-dgl-py" 814 | }, 815 | "language_info": { 816 | "codemirror_mode": { 817 | "name": "ipython", 818 | "version": 3 819 | }, 820 | "file_extension": ".py", 821 | "mimetype": "text/x-python", 822 | "name": "python", 823 | "nbconvert_exporter": "python", 824 | "pygments_lexer": "ipython3", 825 | "version": "3.6.10" 826 | } 827 | }, 828 | "nbformat": 4, 829 | "nbformat_minor": 2 830 | } 831 | -------------------------------------------------------------------------------- /MAXP 2021初赛数据探索和处理-2.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# MAXP 2021初赛数据探索和处理-2\n", 8 | "\n", 9 | "处理Feature,并结合第1部分里得到的不匹配的节点,生成新的Feature,并保存。" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 4, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import pandas as pd\n", 19 | "import numpy as np\n", 20 | "import os\n", 21 | "import gc" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 5, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "# path\n", 31 | "base_path = '/Users/jamezhan/PycharmProjects/MAXP/final_dataset'\n", 32 | "publish_path = 'publish'\n", 33 | "\n", 34 | "diff_node_path = os.path.join(base_path, publish_path, 'diff_nodes.csv')\n", 35 | "train_nodes_path = os.path.join(base_path, publish_path, 'train_nodes.csv')\n", 36 | "val_nodes_path = os.path.join(base_path, publish_path, 'validation_nodes.csv')" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "metadata": {}, 42 | "source": [ 43 | "### 处理节点的特征" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 6, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "def process_node(line):\n", 53 | " nid, feat_json, label = line.strip().split('\\\"')\n", 54 | " \n", 55 | " feat_list = [float(feat[1:-1]) for feat in feat_json[1:-1].split(', ')]\n", 56 | " \n", 57 | " if len(feat_list) != 300:\n", 58 | " print('此行数据有问题 {}'.format(line))\n", 59 | " \n", 60 | " return nid[:-1], feat_list, label[1:]" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 7, 66 | "metadata": { 67 | "scrolled": true 68 | }, 69 | "outputs": [ 70 | { 71 | "name": "stdout", 72 | "output_type": "stream", 73 | "text": [ 74 | "Processed 100000 train rows\n", 75 
| "Processed 200000 train rows\n", 76 | "Processed 300000 train rows\n", 77 | "Processed 400000 train rows\n", 78 | "Processed 500000 train rows\n", 79 | "Processed 600000 train rows\n", 80 | "Processed 700000 train rows\n", 81 | "Processed 800000 train rows\n", 82 | "Processed 900000 train rows\n", 83 | "Processed 1000000 train rows\n", 84 | "Processed 1100000 train rows\n", 85 | "Processed 1200000 train rows\n", 86 | "Processed 1300000 train rows\n", 87 | "Processed 1400000 train rows\n", 88 | "Processed 1500000 train rows\n", 89 | "Processed 1600000 train rows\n", 90 | "Processed 1700000 train rows\n", 91 | "Processed 1800000 train rows\n", 92 | "Processed 1900000 train rows\n", 93 | "Processed 2000000 train rows\n", 94 | "Processed 2100000 train rows\n", 95 | "Processed 2200000 train rows\n", 96 | "Processed 2300000 train rows\n", 97 | "Processed 2400000 train rows\n", 98 | "Processed 2500000 train rows\n", 99 | "Processed 2600000 train rows\n", 100 | "Processed 2700000 train rows\n", 101 | "Processed 2800000 train rows\n", 102 | "Processed 2900000 train rows\n", 103 | "Processed 3000000 train rows\n", 104 | "Processed 100000 validation rows\n", 105 | "Processed 200000 validation rows\n", 106 | "Processed 300000 validation rows\n", 107 | "Processed 400000 validation rows\n", 108 | "Processed 500000 validation rows\n" 109 | ] 110 | } 111 | ], 112 | "source": [ 113 | "# 下面处理特征Feature并存储亿备后用\n", 114 | "# nid_list = []\n", 115 | "feature_list = []\n", 116 | "\n", 117 | "with open(train_nodes_path, 'r') as f:\n", 118 | " i = 0\n", 119 | " \n", 120 | " for line in f:\n", 121 | " if i > 0:\n", 122 | " _, features, _ = process_node(line)\n", 123 | "# nid_list.append(nid)\n", 124 | " feature_list.append(features)\n", 125 | " i += 1\n", 126 | " if i % 100000 == 0:\n", 127 | " print('Processed {} train rows'.format(i))\n", 128 | " \n", 129 | "with open(val_nodes_path, 'r') as f:\n", 130 | " i = 0\n", 131 | " \n", 132 | " for line in f:\n", 133 | " if i > 0:\n", 134 | " _, features, _ = process_node(line)\n", 135 | "# nid_list.append(nid)\n", 136 | " feature_list.append(features)\n", 137 | " i += 1\n", 138 | " if i % 100000 == 0:\n", 139 | " print('Processed {} validation rows'.format(i))\n", 140 | " \n", 141 | "# nid_arr = np.array(nid_list)\n", 142 | "feat_arr = np.array(feature_list)" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": 8, 148 | "metadata": {}, 149 | "outputs": [ 150 | { 151 | "data": { 152 | "text/plain": [ 153 | "0" 154 | ] 155 | }, 156 | "execution_count": 8, 157 | "metadata": {}, 158 | "output_type": "execute_result" 159 | } 160 | ], 161 | "source": [ 162 | "# 删除list以节省内存\n", 163 | "del feature_list\n", 164 | "gc.collect()" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": 9, 170 | "metadata": {}, 171 | "outputs": [], 172 | "source": [ 173 | "# 给未匹配上的419个节点造300维的特征,这里用其他所有节点的300维的平均值来作为他们的特征\n", 174 | "# 更好的方法是用每个节点的所有邻居的特征的平均,这里就不搞这么复杂了。\n", 175 | "diff_node_feat_arr = np.tile(np.mean(feat_arr, axis=0),(419, 1))" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": 11, 181 | "metadata": {}, 182 | "outputs": [ 183 | { 184 | "data": { 185 | "text/plain": [ 186 | "(3655452, 300)" 187 | ] 188 | }, 189 | "execution_count": 11, 190 | "metadata": {}, 191 | "output_type": "execute_result" 192 | } 193 | ], 194 | "source": [ 195 | "feat_arr = np.concatenate((feat_arr, diff_node_feat_arr), axis=0)\n", 196 | "feat_arr.shape" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": null, 202 
| "metadata": {}, 203 | "outputs": [], 204 | "source": [ 205 | "# 使用Numpy保存特征为.npy格式,以节省存储空间和提高读写速度\n", 206 | "with open(os.path.join(base_path, publish_path, './features.npy'), 'wb') as f:\n", 207 | " np.save(f, feat_arr)" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": null, 213 | "metadata": {}, 214 | "outputs": [], 215 | "source": [] 216 | } 217 | ], 218 | "metadata": { 219 | "kernelspec": { 220 | "display_name": "Python [conda env:dgl]", 221 | "language": "python", 222 | "name": "conda-env-dgl-py" 223 | }, 224 | "language_info": { 225 | "codemirror_mode": { 226 | "name": "ipython", 227 | "version": 3 228 | }, 229 | "file_extension": ".py", 230 | "mimetype": "text/x-python", 231 | "name": "python", 232 | "nbconvert_exporter": "python", 233 | "pygments_lexer": "ipython3", 234 | "version": "3.6.10" 235 | } 236 | }, 237 | "nbformat": 4, 238 | "nbformat_minor": 2 239 | } 240 | -------------------------------------------------------------------------------- /MAXP 2021初赛数据探索和处理-3.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# MAXP 2021初赛数据探索和处理-3\n", 8 | "\n", 9 | "使用步骤1里处理好的节点的ID,来构建DGL的graph所需要的边列表。" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 1, 15 | "metadata": {}, 16 | "outputs": [ 17 | { 18 | "name": "stderr", 19 | "output_type": "stream", 20 | "text": [ 21 | "Using backend: pytorch\n" 22 | ] 23 | } 24 | ], 25 | "source": [ 26 | "import pandas as pd\n", 27 | "import numpy as np\n", 28 | "import os\n", 29 | "\n", 30 | "import dgl" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 2, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "# path\n", 40 | "base_path = '/Users/jamezhan/PycharmProjects/MAXP/final_dataset'\n", 41 | "publish_path = 'publish'\n", 42 | "\n", 43 | "link_p1_path = os.path.join(base_path, publish_path, 'link_phase1.csv')\n", 44 | "nodes_path = os.path.join(base_path, publish_path, 'IDandLabels.csv')" 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "metadata": {}, 50 | "source": [ 51 | "### 读取节点列表" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 3, 57 | "metadata": {}, 58 | "outputs": [ 59 | { 60 | "name": "stdout", 61 | "output_type": "stream", 62 | "text": [ 63 | "(3655452, 4)\n" 64 | ] 65 | }, 66 | { 67 | "data": { 68 | "text/html": [ 69 | "
" 125 | ], 126 | "text/plain": [ 127 | " node_idx paper_id Label Split_ID\n", 128 | "3655448 3655448 caed47d55d1e193ecb1fa97a415c13dd NaN 1\n", 129 | "3655449 3655449 c82eb6be79a245392fb626b9a7e1f246 NaN 1\n", 130 | "3655450 3655450 926a31f6b378575204aae30b5dfa6dd3 NaN 1\n", 131 | "3655451 3655451 bbace2419c3f827158ea4602f3eb35fa NaN 1" 132 | ] 133 | }, 134 | "execution_count": 3, 135 | "metadata": {}, 136 | "output_type": "execute_result" 137 | } 138 | ], 139 | "source": [ 140 | "nodes_df = pd.read_csv(nodes_path, dtype={'Label':str})\n", 141 | "print(nodes_df.shape)\n", 142 | "nodes_df.tail(4)" 143 | ] 144 | }, 145 | { 146 | "cell_type": "markdown", 147 | "metadata": {}, 148 | "source": [ 149 | "### 读取边列表" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": 4, 155 | "metadata": {}, 156 | "outputs": [ 157 | { 158 | "name": "stdout", 159 | "output_type": "stream", 160 | "text": [ 161 | "(29168650, 3)\n" 162 | ] 163 | }, 164 | { 165 | "data": { 166 | "text/html": [ 167 | "
" 224 | ], 225 | "text/plain": [ 226 | " paper_id reference_paper_id phase\n", 227 | "0 f10da75ad1eaf16eb2ffe0d85b76b332 711ef25bdb2c2421c0131af77b3ede1d phase1\n", 228 | "1 9ac5a4327bd4f3dcb424c93ca9b84087 2d91c73304c5e8a94a0e5b4956093f71 phase1\n", 229 | "2 9d91bfd4703e55dd814dfffb3d63fc33 33d4fdfe3967a1ffde9311bfe6827ef9 phase1\n", 230 | "3 e1bdbce05528952ed6579795373782d4 4bda690abec912b3b7b228b01fb6819a phase1\n", 231 | "4 eb623ac4b10df96835921edabbde2951 c1a05bdfc88a73bf2830e705b2f39dbb phase1" 232 | ] 233 | }, 234 | "execution_count": 4, 235 | "metadata": {}, 236 | "output_type": "execute_result" 237 | } 238 | ], 239 | "source": [ 240 | "edges_df = pd.read_csv(link_p1_path)\n", 241 | "print(edges_df.shape)\n", 242 | "edges_df.head()" 243 | ] 244 | }, 245 | { 246 | "cell_type": "markdown", 247 | "metadata": {}, 248 | "source": [ 249 | "## Join点列表和边列表以生成从0开始的边列表\n", 250 | "\n", 251 | "DGL默认节点是从0开始,并以最大的ID为容量构建Graph,因此这里我们先构建从0开始的边列表。" 252 | ] 253 | }, 254 | { 255 | "cell_type": "code", 256 | "execution_count": 5, 257 | "metadata": {}, 258 | "outputs": [ 259 | { 260 | "name": "stdout", 261 | "output_type": "stream", 262 | "text": [ 263 | "(29168650, 10)\n" 264 | ] 265 | }, 266 | { 267 | "data": { 268 | "text/html": [ 269 | "
" 355 | ], 356 | "text/plain": [ 357 | " paper_id_x reference_paper_id phase \\\n", 358 | "0 f10da75ad1eaf16eb2ffe0d85b76b332 711ef25bdb2c2421c0131af77b3ede1d phase1 \n", 359 | "1 9ac5a4327bd4f3dcb424c93ca9b84087 2d91c73304c5e8a94a0e5b4956093f71 phase1 \n", 360 | "2 9d91bfd4703e55dd814dfffb3d63fc33 33d4fdfe3967a1ffde9311bfe6827ef9 phase1 \n", 361 | "3 e1bdbce05528952ed6579795373782d4 4bda690abec912b3b7b228b01fb6819a phase1 \n", 362 | "\n", 363 | " node_idx_x Label_x Split_ID_x node_idx_y \\\n", 364 | "0 529879 NaN 0 2364950 \n", 365 | "1 410481 D 0 384023 \n", 366 | "2 2196044 D 0 1895619 \n", 367 | "3 2545623 NaN 0 2175977 \n", 368 | "\n", 369 | " paper_id_y Label_y Split_ID_y \n", 370 | "0 711ef25bdb2c2421c0131af77b3ede1d NaN 0 \n", 371 | "1 2d91c73304c5e8a94a0e5b4956093f71 K 0 \n", 372 | "2 33d4fdfe3967a1ffde9311bfe6827ef9 N 0 \n", 373 | "3 4bda690abec912b3b7b228b01fb6819a NaN 0 " 374 | ] 375 | }, 376 | "execution_count": 5, 377 | "metadata": {}, 378 | "output_type": "execute_result" 379 | } 380 | ], 381 | "source": [ 382 | "# Merge paper_id列\n", 383 | "edges = edges_df.merge(nodes_df, on='paper_id', how='left')\n", 384 | "# Merge reference_paper_id列\n", 385 | "edges = edges.merge(nodes_df, left_on='reference_paper_id', right_on='paper_id', how='left')\n", 386 | "print(edges.shape)\n", 387 | "edges.head(4)" 388 | ] 389 | }, 390 | { 391 | "cell_type": "markdown", 392 | "metadata": {}, 393 | "source": [ 394 | "#### 修改node_idx_* 列的名称作为新的node id,并只保留需要的列" 395 | ] 396 | }, 397 | { 398 | "cell_type": "code", 399 | "execution_count": 6, 400 | "metadata": {}, 401 | "outputs": [ 402 | { 403 | "data": { 404 | "text/html": [ 405 | "
" 461 | ], 462 | "text/plain": [ 463 | " src_nid dst_nid paper_id \\\n", 464 | "0 529879 2364950 f10da75ad1eaf16eb2ffe0d85b76b332 \n", 465 | "1 410481 384023 9ac5a4327bd4f3dcb424c93ca9b84087 \n", 466 | "2 2196044 1895619 9d91bfd4703e55dd814dfffb3d63fc33 \n", 467 | "3 2545623 2175977 e1bdbce05528952ed6579795373782d4 \n", 468 | "\n", 469 | " reference_paper_id \n", 470 | "0 711ef25bdb2c2421c0131af77b3ede1d \n", 471 | "1 2d91c73304c5e8a94a0e5b4956093f71 \n", 472 | "2 33d4fdfe3967a1ffde9311bfe6827ef9 \n", 473 | "3 4bda690abec912b3b7b228b01fb6819a " 474 | ] 475 | }, 476 | "execution_count": 6, 477 | "metadata": {}, 478 | "output_type": "execute_result" 479 | } 480 | ], 481 | "source": [ 482 | "edges.rename(columns={'paper_id_x': 'paper_id', 'node_idx_x':'src_nid', 'node_idx_y':'dst_nid'}, inplace=True)\n", 483 | "edges = edges[['src_nid', 'dst_nid', 'paper_id', 'reference_paper_id']]\n", 484 | "edges.head(4)" 485 | ] 486 | }, 487 | { 488 | "cell_type": "markdown", 489 | "metadata": {}, 490 | "source": [ 491 | "## 构建DGL的Graph" 492 | ] 493 | }, 494 | { 495 | "cell_type": "code", 496 | "execution_count": 7, 497 | "metadata": {}, 498 | "outputs": [], 499 | "source": [ 500 | "# 讲源节点和目标节点转换成Numpy的NDArray\n", 501 | "src_nid = edges.src_nid.to_numpy()\n", 502 | "dst_nid = edges.dst_nid.to_numpy()" 503 | ] 504 | }, 505 | { 506 | "cell_type": "code", 507 | "execution_count": 8, 508 | "metadata": {}, 509 | "outputs": [ 510 | { 511 | "name": "stdout", 512 | "output_type": "stream", 513 | "text": [ 514 | "Graph(num_nodes=3655452, num_edges=29168650,\n", 515 | " ndata_schemes={}\n", 516 | " edata_schemes={})\n" 517 | ] 518 | } 519 | ], 520 | "source": [ 521 | "# 构建一个DGL的graph\n", 522 | "graph = dgl.graph((src_nid, dst_nid))\n", 523 | "print(graph)" 524 | ] 525 | }, 526 | { 527 | "cell_type": "code", 528 | "execution_count": 10, 529 | "metadata": {}, 530 | "outputs": [], 531 | "source": [ 532 | "# 保存Graph为二进制格式方便后面建模时的快速读取\n", 533 | "graph_path = os.path.join(base_path, publish_path, 'graph.bin')\n", 534 | "dgl.data.utils.save_graphs(graph_path, [graph])" 535 | ] 536 | }, 537 | { 538 | "cell_type": "code", 539 | "execution_count": null, 540 | "metadata": {}, 541 | "outputs": [], 542 | "source": [] 543 | } 544 | ], 545 | "metadata": { 546 | "kernelspec": { 547 | "display_name": "Python [conda env:dgl]", 548 | "language": "python", 549 | "name": "conda-env-dgl-py" 550 | }, 551 | "language_info": { 552 | "codemirror_mode": { 553 | "name": "ipython", 554 | "version": 3 555 | }, 556 | "file_extension": ".py", 557 | "mimetype": "text/x-python", 558 | "name": "python", 559 | "nbconvert_exporter": "python", 560 | "pygments_lexer": "ipython3", 561 | "version": "3.6.10" 562 | } 563 | }, 564 | "nbformat": 4, 565 | "nbformat_minor": 2 566 | } 567 | -------------------------------------------------------------------------------- /MAXP 2021初赛数据探索和处理-4.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# MAXP 2021初赛数据探索和处理-4\n", 8 | "\n", 9 | "把原始数据的标签转换成数字形式,并完成Train/Validation/Test的分割。这里的划分是用于比赛模型训练和模型选择用的,并不是原始的文件名。" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 28, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import pandas as pd\n", 19 | "import numpy as np\n", 20 | "import os\n", 21 | "import pickle\n", 22 | "\n", 23 | "import dgl" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 2, 29 | "metadata": {}, 30 | "outputs": 
[], 31 | "source": [ 32 | "# path\n", 33 | "base_path = '/Users/jamezhan/PycharmProjects/MAXP/final_dataset'\n", 34 | "publish_path = 'publish'\n", 35 | "\n", 36 | "nodes_path = os.path.join(base_path, publish_path, 'IDandLabels.csv')" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "metadata": {}, 42 | "source": [ 43 | "### 读取节点列表" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 3, 49 | "metadata": {}, 50 | "outputs": [ 51 | { 52 | "name": "stdout", 53 | "output_type": "stream", 54 | "text": [ 55 | "(3655452, 4)\n" 56 | ] 57 | }, 58 | { 59 | "data": { 60 | "text/html": [ 61 | "
" 117 | ], 118 | "text/plain": [ 119 | " node_idx paper_id Label Split_ID\n", 120 | "3655448 3655448 caed47d55d1e193ecb1fa97a415c13dd NaN 1\n", 121 | "3655449 3655449 c82eb6be79a245392fb626b9a7e1f246 NaN 1\n", 122 | "3655450 3655450 926a31f6b378575204aae30b5dfa6dd3 NaN 1\n", 123 | "3655451 3655451 bbace2419c3f827158ea4602f3eb35fa NaN 1" 124 | ] 125 | }, 126 | "execution_count": 3, 127 | "metadata": {}, 128 | "output_type": "execute_result" 129 | } 130 | ], 131 | "source": [ 132 | "nodes_df = pd.read_csv(nodes_path, dtype={'Label':str})\n", 133 | "print(nodes_df.shape)\n", 134 | "nodes_df.tail(4)" 135 | ] 136 | }, 137 | { 138 | "cell_type": "markdown", 139 | "metadata": {}, 140 | "source": [ 141 | "### 转换标签为数字" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": 6, 147 | "metadata": { 148 | "scrolled": true 149 | }, 150 | "outputs": [ 151 | { 152 | "name": "stdout", 153 | "output_type": "stream", 154 | "text": [ 155 | "(23, 3)\n" 156 | ] 157 | }, 158 | { 159 | "data": { 160 | "text/html": [ 161 | "
" 332 | ], 333 | "text/plain": [ 334 | " node_idx paper_id Split_ID\n", 335 | "Label \n", 336 | "A 2670 2670 2670\n", 337 | "B 65303 65303 65303\n", 338 | "C 111502 111502 111502\n", 339 | "D 104005 104005 104005\n", 340 | "E 45014 45014 45014\n", 341 | "F 32876 32876 32876\n", 342 | "G 43452 43452 43452\n", 343 | "H 71824 71824 71824\n", 344 | "I 23994 23994 23994\n", 345 | "J 25241 25241 25241\n", 346 | "K 32762 32762 32762\n", 347 | "L 53391 53391 53391\n", 348 | "M 83971 83971 83971\n", 349 | "N 103472 103472 103472\n", 350 | "O 17593 17593 17593\n", 351 | "P 52166 52166 52166\n", 352 | "Q 19676 19676 19676\n", 353 | "R 32610 32610 32610\n", 354 | "S 24609 24609 24609\n", 355 | "T 20878 20878 20878\n", 356 | "U 24740 24740 24740\n", 357 | "V 39557 39557 39557\n", 358 | "W 13111 13111 13111" 359 | ] 360 | }, 361 | "execution_count": 6, 362 | "metadata": {}, 363 | "output_type": "execute_result" 364 | } 365 | ], 366 | "source": [ 367 | "# 先检查一下标签的分布\n", 368 | "label_dist = nodes_df.groupby(by='Label').count()\n", 369 | "print(label_dist.shape)\n", 370 | "label_dist" 371 | ] 372 | }, 373 | { 374 | "cell_type": "markdown", 375 | "metadata": {}, 376 | "source": [ 377 | "#### 可以看到一共有23个标签,A类最少,C类最多,基本每类都有几万个。下面从0开始,重够标签\n" 378 | ] 379 | }, 380 | { 381 | "cell_type": "code", 382 | "execution_count": 16, 383 | "metadata": {}, 384 | "outputs": [ 385 | { 386 | "data": { 387 | "text/html": [ 388 | "
" 449 | ], 450 | "text/plain": [ 451 | " node_idx paper_id Label Split_ID label\n", 452 | "0 0 bfdee5ab86ef5e68da974d48a138c28e S 0 18\n", 453 | "1 1 78f43b8b62f040347fec0be44e5f08bd NaN 0 -1\n", 454 | "2 2 a971601a0286d2701aa5cde46e63a9fd G 0 6\n", 455 | "3 3 ac4b88a72146bae66cedfd1c13e1552d NaN 0 -1" 456 | ] 457 | }, 458 | "execution_count": 16, 459 | "metadata": {}, 460 | "output_type": "execute_result" 461 | } 462 | ], 463 | "source": [ 464 | "# 按A-W的顺序,从0开始转换\n", 465 | "for i, l in enumerate(label_dist.index.to_list()):\n", 466 | " nodes_df.loc[(nodes_df.Label==l), 'label'] = i\n", 467 | "\n", 468 | "nodes_df.label.fillna(-1, inplace=True)\n", 469 | "nodes_df.label = nodes_df.label.astype('int')\n", 470 | "nodes_df.head(4)" 471 | ] 472 | }, 473 | { 474 | "cell_type": "markdown", 475 | "metadata": {}, 476 | "source": [ 477 | "#### 只保留新的node index、标签和原始的分割标签" 478 | ] 479 | }, 480 | { 481 | "cell_type": "code", 482 | "execution_count": 19, 483 | "metadata": {}, 484 | "outputs": [ 485 | { 486 | "data": { 487 | "text/html": [ 488 | "
" 539 | ], 540 | "text/plain": [ 541 | " node_idx label Split_ID\n", 542 | "3655448 3655448 -1 1\n", 543 | "3655449 3655449 -1 1\n", 544 | "3655450 3655450 -1 1\n", 545 | "3655451 3655451 -1 1" 546 | ] 547 | }, 548 | "execution_count": 19, 549 | "metadata": {}, 550 | "output_type": "execute_result" 551 | } 552 | ], 553 | "source": [ 554 | "nodes = nodes_df[['node_idx', 'label', 'Split_ID']]\n", 555 | "nodes.tail(4)" 556 | ] 557 | }, 558 | { 559 | "cell_type": "markdown", 560 | "metadata": {}, 561 | "source": [ 562 | "## 划分Train/Validation/Test\n", 563 | "\n", 564 | "由于只有原始的Train_nodes文件里面包括了标签,所以这里的Train/Validation是对原始的分割。\n", 565 | "\n", 566 | "这里按照9:1的比例划分Train/Validation。Test就是原来的validation_nodes里面的index。" 567 | ] 568 | }, 569 | { 570 | "cell_type": "code", 571 | "execution_count": 30, 572 | "metadata": {}, 573 | "outputs": [], 574 | "source": [ 575 | "# 获取所有的标签\n", 576 | "tr_val_labels_df = nodes[(nodes.Split_ID == 0) & (nodes.label >= 0)]\n", 577 | "test_label_df = nodes[nodes.Split_ID == 1]\n", 578 | "\n", 579 | "# 按照0~22每个标签划分train/validation\n", 580 | "tr_labels_idx = np.array([0])\n", 581 | "val_labels_idx = np.array([0])\n", 582 | "split_ratio = 0.9\n", 583 | "\n", 584 | "for label in range(23):\n", 585 | " label_idx = tr_val_labels_df[tr_val_labels_df.label == label].node_idx.to_numpy()\n", 586 | " split_point = int(label_idx.shape[0] * split_ratio)\n", 587 | " \n", 588 | " # 把每个标签的train和validation的index添加到整个列表\n", 589 | " tr_labels_idx = np.append(tr_labels_idx, label_idx[: split_point])\n", 590 | " val_labels_idx = np.append(val_labels_idx, label_idx[split_point: ])" 591 | ] 592 | }, 593 | { 594 | "cell_type": "code", 595 | "execution_count": 31, 596 | "metadata": {}, 597 | "outputs": [], 598 | "source": [ 599 | "# 获取Train/Validation/Test标签index\n", 600 | "tr_labels_idx = tr_labels_idx[1: ]\n", 601 | "val_labels_idx = val_labels_idx[1: ]\n", 602 | "\n", 603 | "test_labels_idx = test_label_df.node_idx.to_numpy()" 604 | ] 605 | }, 606 | { 607 | "cell_type": "code", 608 | "execution_count": 32, 609 | "metadata": {}, 610 | "outputs": [], 611 | "source": [ 612 | "# 获取完整的标签列表\n", 613 | "labels = nodes.label.to_numpy()" 614 | ] 615 | }, 616 | { 617 | "cell_type": "code", 618 | "execution_count": 33, 619 | "metadata": {}, 620 | "outputs": [], 621 | "source": [ 622 | "# 保存标签以及Train/Validation/Test的index为二进制格式方便后面建模时的快速读取\n", 623 | "label_path = os.path.join(base_path, publish_path, 'labels.pkl')\n", 624 | "\n", 625 | "with open(label_path, 'wb') as f:\n", 626 | " pickle.dump({'tr_label_idx': tr_labels_idx, \n", 627 | " 'val_label_idx': val_labels_idx, \n", 628 | " 'test_label_idx': test_labels_idx,\n", 629 | " 'label': labels}, f)" 630 | ] 631 | }, 632 | { 633 | "cell_type": "code", 634 | "execution_count": null, 635 | "metadata": {}, 636 | "outputs": [], 637 | "source": [] 638 | } 639 | ], 640 | "metadata": { 641 | "kernelspec": { 642 | "display_name": "Python [conda env:dgl]", 643 | "language": "python", 644 | "name": "conda-env-dgl-py" 645 | }, 646 | "language_info": { 647 | "codemirror_mode": { 648 | "name": "ipython", 649 | "version": 3 650 | }, 651 | "file_extension": ".py", 652 | "mimetype": "text/x-python", 653 | "name": "python", 654 | "nbconvert_exporter": "python", 655 | "pygments_lexer": "ipython3", 656 | "version": "3.6.10" 657 | } 658 | }, 659 | "nbformat": 4, 660 | "nbformat_minor": 2 661 | } 662 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 
| # MAXP竞赛——DGL图数据Baseline模型 2 | 3 | 本代码库是为2021 MAXP竞赛的DGL图数据所准备的Baseline模型,供参赛选手参考学习使用DGL来构建GNN模型。 4 | 5 | 代码库包括2个部分: 6 | --------------- 7 | 1. 用于数据预处理的4个Jupyter Notebook 8 | 2. 用DGL构建的3个GNN模型(GCN,GraphSage和GAT),以及训练模型所用的代码和辅助函数。 9 | 10 | 依赖包: 11 | ------ 12 | - dgl==0.7.1 13 | - pytorch==1.7.0 14 | - pandas 15 | - numpy 16 | - datetime 17 | 18 | 如何运行: 19 | ------- 20 | 对于4个Jupyter Notebook文件,请使用Jupyter环境运行,并注意把其中的竞赛数据文件所在的文件夹替换为你自己保存数据文件的文件夹。 21 | 并记录下你处理完成后的数据文件所在的位置,供下面模型训练使用。 22 | 23 | **注意:** 在运行*MAXP 2021初赛数据探索和处理-2*时,内存的使用量会比较高。这个在Mac上运行没有出现问题,但是尚未在Windows和Linux环境测试。 24 | 如果在这两种环境下遇到内存问题,建议找一个内存大一些的机器处理,或者修改代码,一部分一部分的处理。 25 | 26 | --------- 27 | 对于GNN的模型,需要先cd到gnn目录,然后运行: 28 | 29 | ```bash 30 | python model_train.py --data_path path/to/processed_data --gnn_model graphsage --hidden_dim 64 --n_layers 2 --fanout 20,20 --batch_size 4096 --GPU -1 --out_path ./ 31 | ``` 32 | 33 | **注意**:请把--data_path的路径替换成用Jupyter Notebook文件处理后数据所在的位置路径。其余的参数,请参考model_train.py里面的入参说明修改。 34 | 35 | 如果希望使用单GPU进行模型训练,则需要修改入参 `--GPU`的输入值为单个GPU的编号,如: 36 | ```bash 37 | --GPU 0 38 | ``` 39 | 40 | 如果希望使用单机多GPU进行模型训练,则需要修改入参 `--GPU`的输入值为多个可用的GPU的编号,并用空格分割,如: 41 | ```bash 42 | --GPU 0 1 2 3 43 | ``` 44 | -------------------------------------------------------------------------------- /gnn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dglai/maxp_baseline_model/00ccc6aa2cdd7fce83e6468380bc33d45a870431/gnn/__init__.py -------------------------------------------------------------------------------- /gnn/model_train.py: -------------------------------------------------------------------------------- 1 | #-*- coding:utf-8 -*- 2 | 3 | # Author:james Zhang 4 | """ 5 | Minibatch training with node neighbor sampling in multiple GPUs 6 | """ 7 | 8 | import os 9 | import argparse 10 | import datetime as dt 11 | import numpy as np 12 | import torch as th 13 | import torch.nn as thnn 14 | import torch.nn.functional as F 15 | import torch.optim as optim 16 | import torch.distributed as dist 17 | 18 | import dgl 19 | from dgl.dataloading.neighbor import MultiLayerNeighborSampler 20 | from dgl.dataloading.pytorch import NodeDataLoader 21 | import dgl.multiprocessing as mp 22 | 23 | from models import GraphSageModel, GraphConvModel, GraphAttnModel 24 | from utils import load_dgl_graph, time_diff 25 | from model_utils import early_stopper, thread_wrapped_func 26 | 27 | 28 | def load_subtensor(node_feats, labels, seeds, input_nodes, device): 29 | """ 30 | Copys features and labels of a set of nodes onto GPU. 
31 | """ 32 | batch_inputs = node_feats[input_nodes].to(device) 33 | batch_labels = labels[seeds].to(device) 34 | return batch_inputs, batch_labels 35 | 36 | 37 | def cleanup(): 38 | dist.destroy_process_group() 39 | 40 | 41 | def cpu_train(graph_data, 42 | gnn_model, 43 | hidden_dim, 44 | n_layers, 45 | n_classes, 46 | fanouts, 47 | batch_size, 48 | device, 49 | num_workers, 50 | epochs, 51 | out_path): 52 | """ 53 | 运行在CPU设备上的训练代码。 54 | 由于比赛数据量比较大,因此这个部分的代码建议仅用于代码调试。 55 | 有GPU的,请使用下面的GPU设备训练的代码来提高训练速度。 56 | """ 57 | graph, labels, train_nid, val_nid, test_nid, node_feat = graph_data 58 | 59 | sampler = MultiLayerNeighborSampler(fanouts) 60 | train_dataloader = NodeDataLoader(graph, 61 | train_nid, 62 | sampler, 63 | batch_size=batch_size, 64 | shuffle=True, 65 | drop_last=False, 66 | num_workers=num_workers) 67 | 68 | # 2 initialize GNN model 69 | in_feat = node_feat.shape[1] 70 | 71 | if gnn_model == 'graphsage': 72 | model = GraphSageModel(in_feat, hidden_dim, n_layers, n_classes) 73 | elif gnn_model == 'graphconv': 74 | model = GraphConvModel(in_feat, hidden_dim, n_layers, n_classes, 75 | norm='both', activation=F.relu, dropout=0) 76 | elif gnn_model == 'graphattn': 77 | model = GraphAttnModel(in_feat, hidden_dim, n_layers, n_classes, 78 | heads=([5] * n_layers), activation=F.relu, feat_drop=0, attn_drop=0) 79 | else: 80 | raise NotImplementedError('So far, only support three algorithms: GraphSage, GraphConv, and GraphAttn') 81 | 82 | model = model.to(device) 83 | 84 | # 3 define loss function and optimizer 85 | loss_fn = thnn.CrossEntropyLoss().to(device) 86 | optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4) 87 | 88 | # 4 train epoch 89 | avg = 0 90 | iter_tput = [] 91 | start_t = dt.datetime.now() 92 | 93 | print('Start training at: {}-{} {}:{}:{}'.format(start_t.month, 94 | start_t.day, 95 | start_t.hour, 96 | start_t.minute, 97 | start_t.second)) 98 | 99 | for epoch in range(epochs): 100 | 101 | for step, (input_nodes, seeds, mfgs) in enumerate(train_dataloader): 102 | 103 | start_t = dt.datetime.now() 104 | 105 | batch_inputs, batch_labels = load_subtensor(node_feat, labels, seeds, input_nodes, device) 106 | mfgs = [mfg.to(device) for mfg in mfgs] 107 | 108 | batch_logit = model(mfgs, batch_inputs) 109 | loss = loss_fn(batch_logit, batch_labels) 110 | pred = th.sum(th.argmax(batch_logit, dim=1) == batch_labels) / th.tensor(batch_labels.shape[0]) 111 | 112 | optimizer.zero_grad() 113 | loss.backward() 114 | optimizer.step() 115 | 116 | e_t1 = dt.datetime.now() 117 | h, m, s = time_diff(e_t1, start_t) 118 | 119 | print('In epoch:{:03d}|batch:{}, loss:{:4f}, acc:{:4f}, time:{}h{}m{}s'.format(epoch, 120 | step, 121 | loss, 122 | pred.detach(), 123 | h, m, s)) 124 | 125 | # 5 保存模型 126 | # 此处就省略了 127 | 128 | 129 | def gpu_train(proc_id, n_gpus, GPUS, 130 | graph_data, gnn_model, 131 | hidden_dim, n_layers, n_classes, fanouts, 132 | batch_size=32, num_workers=4, epochs=100, message_queue=None, 133 | output_folder='./output'): 134 | 135 | device_id = GPUS[proc_id] 136 | device = th.device('cuda:{}'.format(device_id)) 137 | 138 | print('Use GPU {} for training ......'.format(device_id)) 139 | 140 | if n_gpus > 1: 141 | dist_init_method = 'tcp://{}:{}'.format('127.0.0.1', '23456') 142 | world_size = n_gpus 143 | dist.init_process_group(backend='nccl', 144 | init_method=dist_init_method, 145 | world_size=world_size, 146 | rank=proc_id) 147 | 148 | th.cuda.set_device(device_id) 149 | 150 | # ------------------- 1. 
Prepare data and split for multiple GPUs ------------------- # 151 | start_t = dt.datetime.now() 152 | print('Start graph building at: {}-{} {}:{}:{}'.format(start_t.month, 153 | start_t.day, 154 | start_t.hour, 155 | start_t.minute, 156 | start_t.second)) 157 | 158 | graph, labels, train_nid, val_nid, test_nid, node_feat = graph_data 159 | 160 | train_div, _ = divmod(train_nid.shape[0], n_gpus) 161 | val_div, _ = divmod(val_nid.shape[0], n_gpus) 162 | 163 | # just use one GPU, give all training/validation index to the one GPU 164 | if proc_id == (n_gpus - 1): 165 | train_nid_per_gpu = train_nid[proc_id * train_div: ] 166 | val_nid_per_gpu = val_nid[proc_id * val_div: ] 167 | # in case of multiple GPUs, split training/validation index to different GPUs 168 | else: 169 | train_nid_per_gpu = train_nid[proc_id * train_div: (proc_id + 1) * train_div] 170 | val_nid_per_gpu = val_nid[proc_id * val_div: (proc_id + 1) * val_div] 171 | 172 | train_sampler = MultiLayerNeighborSampler(fanouts) 173 | train_dataloader = NodeDataLoader(graph, 174 | train_nid_per_gpu, 175 | train_sampler, 176 | device=device, 177 | use_ddp=n_gpus > 1, 178 | batch_size=batch_size, 179 | shuffle=True, 180 | drop_last=False, 181 | num_workers=num_workers, 182 | ) 183 | val_sampler = MultiLayerNeighborSampler(fanouts) 184 | val_dataloader = NodeDataLoader(graph, 185 | val_nid_per_gpu, 186 | val_sampler, 187 | use_ddp=n_gpus > 1, 188 | device=device, 189 | batch_size=batch_size, 190 | shuffle=True, 191 | drop_last=False, 192 | num_workers=num_workers, 193 | ) 194 | e_t1 = dt.datetime.now() 195 | h, m, s = time_diff(e_t1, start_t) 196 | print('Model built used: {:02d}h {:02d}m {:02}s'.format(h, m, s)) 197 | 198 | # ------------------- 2. Build model for multiple GPUs ------------------------------ # 199 | start_t = dt.datetime.now() 200 | print('Start Model building at: {}-{} {}:{}:{}'.format(start_t.month, 201 | start_t.day, 202 | start_t.hour, 203 | start_t.minute, 204 | start_t.second)) 205 | 206 | in_feat = node_feat.shape[1] 207 | if gnn_model == 'graphsage': 208 | model = GraphSageModel(in_feat, hidden_dim, n_layers, n_classes) 209 | elif gnn_model == 'graphconv': 210 | model = GraphConvModel(in_feat, hidden_dim, n_layers, n_classes, 211 | norm='both', activation=F.relu, dropout=0) 212 | elif gnn_model == 'graphattn': 213 | model = GraphAttnModel(in_feat, hidden_dim, n_layers, n_classes, 214 | heads=([5] * n_layers), activation=F.relu, feat_drop=0, attn_drop=0) 215 | else: 216 | raise NotImplementedError('So far, only support three algorithms: GraphSage, GraphConv, and GraphAttn') 217 | 218 | model = model.to(device_id) 219 | 220 | if n_gpus > 1: 221 | model = thnn.parallel.DistributedDataParallel(model, 222 | device_ids=[device_id], 223 | output_device=device_id) 224 | e_t1 = dt.datetime.now() 225 | h, m, s = time_diff(e_t1, start_t) 226 | print('Model built used: {:02d}h {:02d}m {:02}s'.format(h, m, s)) 227 | 228 | # ------------------- 3. Build loss function and optimizer -------------------------- # 229 | loss_fn = thnn.CrossEntropyLoss().to(device_id) 230 | optimizer = optim.Adam(model.parameters(), lr=0.004, weight_decay=5e-4) 231 | 232 | earlystoper = early_stopper(patience=2, verbose=False) 233 | 234 | # ------------------- 4. 
Train model ----------------------------------------------- # 235 | print('Plan to train {} epoches \n'.format(epochs)) 236 | 237 | for epoch in range(epochs): 238 | if n_gpus > 1: 239 | train_dataloader.set_epoch(epoch) 240 | val_dataloader.set_epoch(epoch) 241 | 242 | # mini-batch for training 243 | train_loss_list = [] 244 | # train_acc_list = [] 245 | model.train() 246 | for step, (input_nodes, seeds, blocks) in enumerate(train_dataloader): 247 | # forward 248 | batch_inputs, batch_labels = load_subtensor(node_feat, labels, seeds, input_nodes, device_id) 249 | blocks = [block.to(device_id) for block in blocks] 250 | # metric and loss 251 | train_batch_logits = model(blocks, batch_inputs) 252 | train_loss = loss_fn(train_batch_logits, batch_labels) 253 | # backward 254 | optimizer.zero_grad() 255 | train_loss.backward() 256 | optimizer.step() 257 | 258 | train_loss_list.append(train_loss.cpu().detach().numpy()) 259 | tr_batch_pred = th.sum(th.argmax(train_batch_logits, dim=1) == batch_labels) / th.tensor(batch_labels.shape[0]) 260 | 261 | if step % 10 == 0: 262 | print('In epoch:{:03d}|batch:{:04d}, train_loss:{:4f}, train_acc:{:.4f}'.format(epoch, 263 | step, 264 | np.mean(train_loss_list), 265 | tr_batch_pred.detach())) 266 | 267 | # mini-batch for validation 268 | val_loss_list = [] 269 | val_acc_list = [] 270 | model.eval() 271 | for step, (input_nodes, seeds, blocks) in enumerate(val_dataloader): 272 | # forward 273 | batch_inputs, batch_labels = load_subtensor(node_feat, labels, seeds, input_nodes, device_id) 274 | blocks = [block.to(device_id) for block in blocks] 275 | # metric and loss 276 | val_batch_logits = model(blocks, batch_inputs) 277 | val_loss = loss_fn(val_batch_logits, batch_labels) 278 | 279 | val_loss_list.append(val_loss.detach().cpu().numpy()) 280 | val_batch_pred = th.sum(th.argmax(val_batch_logits, dim=1) == batch_labels) / th.tensor(batch_labels.shape[0]) 281 | 282 | if step % 10 == 0: 283 | print('In epoch:{:03d}|batch:{:04d}, val_loss:{:4f}, val_acc:{:.4f}'.format(epoch, 284 | step, 285 | np.mean(val_loss_list), 286 | val_batch_pred.detach())) 287 | 288 | # put validation results into message queue and aggregate at device 0 289 | if n_gpus > 1 and message_queue != None: 290 | message_queue.put(val_loss_list) 291 | 292 | if proc_id == 0: 293 | for i in range(n_gpus): 294 | loss = message_queue.get() 295 | print(loss) 296 | del loss 297 | else: 298 | print(val_loss_list) 299 | 300 | # -------------------------5. Collect stats ------------------------------------# 301 | # best_preds = earlystoper.val_preds 302 | # best_logits = earlystoper.val_logits 303 | # 304 | # best_precision, best_recall, best_f1 = get_f1_score(val_y.cpu().numpy(), best_preds) 305 | # best_auc = get_auc_score(val_y.cpu().numpy(), best_logits[:, 1]) 306 | # best_recall_at_99precision = recall_at_perc_precision(val_y.cpu().numpy(), best_logits[:, 1], threshold=0.99) 307 | # best_recall_at_90precision = recall_at_perc_precision(val_y.cpu().numpy(), best_logits[:, 1], threshold=0.9) 308 | 309 | # plot_roc(val_y.cpu().numpy(), best_logits[:, 1]) 310 | # plot_p_r_curve(val_y.cpu().numpy(), best_logits[:, 1]) 311 | 312 | # -------------------------6. 
Save models --------------------------------------# 313 | model_path = os.path.join(output_folder, 'dgl_model-' + '{:06d}'.format(np.random.randint(100000)) + '.pth') 314 | 315 | if n_gpus > 1: 316 | if proc_id == 0: 317 | model_para_dict = model.state_dict() 318 | th.save(model_para_dict, model_path) 319 | # after trainning, remember to cleanup and release resouces 320 | cleanup() 321 | else: 322 | model_para_dict = model.state_dict() 323 | th.save(model_para_dict, model_path) 324 | 325 | 326 | if __name__ == '__main__': 327 | parser = argparse.ArgumentParser(description='DGL_SamplingTrain') 328 | parser.add_argument('--data_path', type=str, help="Path of saved processed data files.") 329 | parser.add_argument('--gnn_model', type=str, choices=['graphsage', 'graphconv', 'graphattn'], 330 | required=True, default='graphsage') 331 | parser.add_argument('--hidden_dim', type=int, required=True) 332 | parser.add_argument('--n_layers', type=int, default=2) 333 | parser.add_argument("--fanout", type=str, required=True, help="fanout numbers", default='20,20') 334 | parser.add_argument('--batch_size', type=int, required=True, default=1) 335 | parser.add_argument('--GPU', nargs='+', type=int, required=True) 336 | parser.add_argument('--num_workers_per_gpu', type=int, default=4) 337 | parser.add_argument('--epochs', type=int, default=100) 338 | parser.add_argument('--out_path', type=str, required=True, help="Absolute path for saving model parameters") 339 | args = parser.parse_args() 340 | 341 | # parse arguments 342 | BASE_PATH = args.data_path 343 | MODEL_CHOICE = args.gnn_model 344 | HID_DIM = args.hidden_dim 345 | N_LAYERS = args.n_layers 346 | FANOUTS = [int(i) for i in args.fanout.split(',')] 347 | BATCH_SIZE = args.batch_size 348 | GPUS = args.GPU 349 | WORKERS = args.num_workers_per_gpu 350 | EPOCHS = args.epochs 351 | OUT_PATH = args.out_path 352 | 353 | # output arguments for logging 354 | print('Data path: {}'.format(BASE_PATH)) 355 | print('Used algorithm: {}'.format(MODEL_CHOICE)) 356 | print('Hidden dimensions: {}'.format(HID_DIM)) 357 | print('number of hidden layers: {}'.format(N_LAYERS)) 358 | print('Fanout list: {}'.format(FANOUTS)) 359 | print('Batch size: {}'.format(BATCH_SIZE)) 360 | print('GPU list: {}'.format(GPUS)) 361 | print('Number of workers per GPU: {}'.format(WORKERS)) 362 | print('Max number of epochs: {}'.format(EPOCHS)) 363 | print('Output path: {}'.format(OUT_PATH)) 364 | 365 | # Retrieve preprocessed data and add reverse edge and self-loop 366 | graph, labels, train_nid, val_nid, test_nid, node_feat = load_dgl_graph(BASE_PATH) 367 | graph = dgl.to_bidirected(graph, copy_ndata=True) 368 | graph = dgl.add_self_loop(graph) 369 | 370 | graph.create_formats_() 371 | 372 | # call train with CPU, one GPU, or multiple GPUs 373 | if GPUS[0] < 0: 374 | cpu_device = th.device('cpu') 375 | cpu_train(graph_data=(graph, labels, train_nid, val_nid, test_nid, node_feat), 376 | gnn_model=MODEL_CHOICE, 377 | n_layers=N_LAYERS, 378 | hidden_dim=HID_DIM, 379 | n_classes=23, 380 | fanouts=FANOUTS, 381 | batch_size=BATCH_SIZE, 382 | num_workers=WORKERS, 383 | device=cpu_device, 384 | epochs=EPOCHS, 385 | out_path=OUT_PATH) 386 | else: 387 | n_gpus = len(GPUS) 388 | 389 | if n_gpus == 1: 390 | gpu_train(0, n_gpus, GPUS, 391 | graph_data=(graph, labels, train_nid, val_nid, test_nid, node_feat), 392 | gnn_model=MODEL_CHOICE, hidden_dim=HID_DIM, n_layers=N_LAYERS, n_classes=23, 393 | fanouts=FANOUTS, batch_size=BATCH_SIZE, num_workers=WORKERS, epochs=EPOCHS, 394 | message_queue=None, 
output_folder=OUT_PATH) 395 | else: 396 | message_queue = mp.Queue() 397 | procs = [] 398 | for proc_id in range(n_gpus): 399 | p = mp.Process(target=gpu_train, 400 | args=(proc_id, n_gpus, GPUS, 401 | (graph, labels, train_nid, val_nid, test_nid, node_feat), 402 | MODEL_CHOICE, HID_DIM, N_LAYERS, 23, 403 | FANOUTS, BATCH_SIZE, WORKERS, EPOCHS, 404 | message_queue, OUT_PATH)) 405 | p.start() 406 | procs.append(p) 407 | for p in procs: 408 | p.join() -------------------------------------------------------------------------------- /gnn/model_utils.py: -------------------------------------------------------------------------------- 1 | #-*- coding:utf-8 -*- 2 | 3 | # Author:james Zhang 4 | """ 5 | utilities file for Pytorch models 6 | """ 7 | 8 | from functools import wraps 9 | import traceback 10 | from _thread import start_new_thread 11 | import torch.multiprocessing as mp 12 | 13 | 14 | class early_stopper(object): 15 | 16 | def __init__(self, patience=10, verbose=False, delta=0): 17 | 18 | self.patience = patience 19 | self.verbose = verbose 20 | self.delta = delta 21 | 22 | self.best_value = None 23 | self.is_earlystop = False 24 | self.count = 0 25 | self.val_preds = [] 26 | self.val_logits = [] 27 | 28 | def earlystop(self, loss, preds, logits): 29 | 30 | value = -loss 31 | 32 | if self.best_value is None: 33 | self.best_value = value 34 | self.val_preds = preds 35 | self.val_logits = logits 36 | elif value < self.best_value + self.delta: 37 | self.count += 1 38 | if self.verbose: 39 | print('EarlyStoper count: {:02d}'.format(self.count)) 40 | if self.count >= self.patience: 41 | self.is_earlystop = True 42 | else: 43 | self.best_value = value 44 | self.val_preds = preds 45 | self.val_logits = logits 46 | self.count = 0 47 | 48 | 49 | # According to https://github.com/pytorch/pytorch/issues/17199, this decorator 50 | # is necessary to make fork() and openmp work together. 51 | def thread_wrapped_func(func): 52 | """ 53 | 用于Pytorch的OpenMP的包装方法。Wraps a process entry point to make it work with OpenMP. 54 | """ 55 | @wraps(func) 56 | def decorated_function(*args, **kwargs): 57 | queue = mp.Queue() 58 | def _queue_result(): 59 | exception, trace, res = None, None, None 60 | try: 61 | res = func(*args, **kwargs) 62 | except Exception as e: 63 | exception = e 64 | trace = traceback.format_exc() 65 | queue.put((res, exception, trace)) 66 | 67 | start_new_thread(_queue_result, ()) 68 | result, exception, trace = queue.get() 69 | if exception is None: 70 | return result 71 | else: 72 | assert isinstance(exception, Exception) 73 | raise exception.__class__(trace) 74 | return decorated_function -------------------------------------------------------------------------------- /gnn/models.py: -------------------------------------------------------------------------------- 1 | #-*- coding:utf-8 -*- 2 | 3 | # Author:james Zhang 4 | 5 | """ 6 | Three common GNN models. 
7 | """ 8 | 9 | import torch.nn as thnn 10 | import torch.nn.functional as F 11 | import dgl.nn as dglnn 12 | 13 | 14 | class GraphSageModel(thnn.Module): 15 | 16 | def __init__(self, 17 | in_feats, 18 | hidden_dim, 19 | n_layers, 20 | n_classes, 21 | activation=F.relu, 22 | dropout=0): 23 | super(GraphSageModel, self).__init__() 24 | self.in_feats = in_feats 25 | self.n_layers = n_layers 26 | self.hidden_dim = hidden_dim 27 | self.n_classes = n_classes 28 | self.activation = activation 29 | self.dropout = thnn.Dropout(dropout) 30 | 31 | self.layers = thnn.ModuleList() 32 | 33 | # build multiple layers 34 | self.layers.append(dglnn.SAGEConv(in_feats=self.in_feats, 35 | out_feats=self.hidden_dim, 36 | aggregator_type='mean')) 37 | # aggregator_type = 'pool')) 38 | for l in range(1, (self.n_layers - 1)): 39 | self.layers.append(dglnn.SAGEConv(in_feats=self.hidden_dim, 40 | out_feats=self.hidden_dim, 41 | aggregator_type='mean')) 42 | # aggregator_type='pool')) 43 | self.layers.append(dglnn.SAGEConv(in_feats=self.hidden_dim, 44 | out_feats=self.n_classes, 45 | aggregator_type='mean')) 46 | # aggregator_type = 'pool')) 47 | 48 | def forward(self, blocks, features): 49 | h = features 50 | 51 | for l, (layer, block) in enumerate(zip(self.layers, blocks)): 52 | h = layer(block, h) 53 | if l != len(self.layers) - 1: 54 | h = self.activation(h) 55 | h = self.dropout(h) 56 | 57 | return h 58 | 59 | 60 | class GraphConvModel(thnn.Module): 61 | 62 | def __init__(self, 63 | in_feats, 64 | hidden_dim, 65 | n_layers, 66 | n_classes, 67 | norm, 68 | activation, 69 | dropout): 70 | super(GraphConvModel, self).__init__() 71 | self.in_feats = in_feats 72 | self.n_layers = n_layers 73 | self.hidden_dim = hidden_dim 74 | self.n_classes = n_classes 75 | self.norm = norm 76 | self.activation = activation 77 | self.dropout = thnn.Dropout(dropout) 78 | 79 | self.layers = thnn.ModuleList() 80 | 81 | # build multiple layers 82 | self.layers.append(dglnn.GraphConv(in_feats=self.in_feats, 83 | out_feats=self.hidden_dim, 84 | norm=self.norm, 85 | activation=self.activation,)) 86 | for l in range(1, (self.n_layers - 1)): 87 | self.layers.append(dglnn.GraphConv(in_feats=self.hidden_dim, 88 | out_feats=self.hidden_dim, 89 | norm=self.norm, 90 | activation=self.activation)) 91 | self.layers.append(dglnn.GraphConv(in_feats=self.hidden_dim, 92 | out_feats=self.n_classes, 93 | norm=self.norm, 94 | activation=self.activation)) 95 | 96 | def forward(self, blocks, features): 97 | h = features 98 | 99 | for l, (layer, block) in enumerate(zip(self.layers, blocks)): 100 | h = layer(block, h) 101 | if l != len(self.layers) - 1: 102 | h = self.dropout(h) 103 | 104 | return h 105 | 106 | 107 | class GraphAttnModel(thnn.Module): 108 | 109 | def __init__(self, 110 | in_feats, 111 | hidden_dim, 112 | n_layers, 113 | n_classes, 114 | heads, 115 | activation, 116 | feat_drop, 117 | attn_drop 118 | ): 119 | super(GraphAttnModel, self).__init__() 120 | self.in_feats = in_feats 121 | self.hidden_dim = hidden_dim 122 | self.n_layers = n_layers 123 | self.n_classes = n_classes 124 | self.heads = heads 125 | self.feat_dropout = feat_drop 126 | self.attn_dropout = attn_drop 127 | self.activation = activation 128 | 129 | self.layers = thnn.ModuleList() 130 | 131 | # build multiple layers 132 | self.layers.append(dglnn.GATConv(in_feats=self.in_feats, 133 | out_feats=self.hidden_dim, 134 | num_heads=self.heads[0], 135 | feat_drop=self.feat_dropout, 136 | attn_drop=self.attn_dropout, 137 | activation=self.activation)) 138 | 139 | for l in range(1, 
(self.n_layers - 1)): 140 | # due to multi-head, the in_dim = num_hidden * num_heads 141 | self.layers.append(dglnn.GATConv(in_feats=self.hidden_dim * self.heads[l - 1], 142 | out_feats=self.hidden_dim, 143 | num_heads=self.heads[l], 144 | feat_drop=self.feat_dropout, 145 | attn_drop=self.attn_dropout, 146 | activation=self.activation)) 147 | 148 | self.layers.append(dglnn.GATConv(in_feats=self.hidden_dim * self.heads[-2], 149 | out_feats=self.n_classes, 150 | num_heads=self.heads[-1], 151 | feat_drop=self.feat_dropout, 152 | attn_drop=self.attn_dropout, 153 | activation=None)) 154 | 155 | def forward(self, blocks, features): 156 | h = features 157 | 158 | for l in range(self.n_layers - 1): 159 | h = self.layers[l](blocks[l], h).flatten(1) 160 | 161 | logits = self.layers[-1](blocks[-1], h).mean(1) 162 | 163 | return logits 164 | 165 | -------------------------------------------------------------------------------- /gnn/utils.py: -------------------------------------------------------------------------------- 1 | #-*- coding:utf-8 -*- 2 | 3 | """ 4 | Utilities to handle graph data 5 | """ 6 | 7 | import os 8 | import dgl 9 | import pickle 10 | import numpy as np 11 | import torch as th 12 | 13 | 14 | def load_dgl_graph(base_path): 15 | """ 16 | Load the preprocessed graph, feature, and label files, and build the data structures used by the training code. 17 | 18 | :param base_path: 19 | :return: 20 | """ 21 | graphs, _ = dgl.load_graphs(os.path.join(base_path, 'graph.bin')) 22 | graph = graphs[0] 23 | print('################ Graph info: ###############') 24 | print(graph) 25 | 26 | with open(os.path.join(base_path, 'labels.pkl'), 'rb') as f: 27 | label_data = pickle.load(f) 28 | 29 | labels = th.from_numpy(label_data['label']) 30 | tr_label_idx = th.from_numpy(label_data['tr_label_idx']).long() 31 | val_label_idx = th.from_numpy(label_data['val_label_idx']).long() 32 | test_label_idx = th.from_numpy(label_data['test_label_idx']).long() 33 | print('################ Label info: ################') 34 | print('Total labels (including not labeled): {}'.format(labels.shape[0])) 35 | print('    Training label number: {}'.format(tr_label_idx.shape[0])) 36 | print('    Validation label number: {}'.format(val_label_idx.shape[0])) 37 | print('    Test label number: {}'.format(test_label_idx.shape[0])) 38 | 39 | # get node features 40 | features = np.load(os.path.join(base_path, 'features.npy')) 41 | node_feat = th.from_numpy(features).float() 42 | print('################ Feature info: ###############') 43 | print('Node\'s feature shape:{}'.format(node_feat.shape)) 44 | 45 | return graph, labels, tr_label_idx, val_label_idx, test_label_idx, node_feat 46 | 47 | 48 | def time_diff(t_end, t_start): 49 | """ 50 | Compute the elapsed time. t_end and t_start are datetime objects, so use timedelta. 51 | Parameters 52 | ---------- 53 | t_end 54 | t_start 55 | 56 | Returns 57 | ------- 58 | """ 59 | diff_sec = (t_end - t_start).seconds 60 | diff_min, rest_sec = divmod(diff_sec, 60) 61 | diff_hrs, rest_min = divmod(diff_min, 60) 62 | return (diff_hrs, rest_min, rest_sec) 63 | --------------------------------------------------------------------------------
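
**Inference sketch (not part of the original repository).** model_train.py only trains and saves a checkpoint; it never produces predictions for the test nodes. The snippet below is a minimal sketch of how the existing pieces could be combined for prediction, assuming the DGL 0.7 APIs already used in model_train.py, the processed-data layout produced by the notebooks, and hyperparameters that match the trained checkpoint (hidden_dim=64, n_layers=2, fanout 20,20, 23 classes). `BASE_PATH` and `CKPT_PATH` are placeholders.

```python
# Hypothetical inference sketch (not part of the baseline): restore a saved
# checkpoint and predict labels for the test nodes on CPU.
import torch as th
import dgl
from dgl.dataloading.neighbor import MultiLayerNeighborSampler
from dgl.dataloading.pytorch import NodeDataLoader

from models import GraphSageModel
from utils import load_dgl_graph

BASE_PATH = 'path/to/processed_data'        # folder produced by the notebooks (placeholder)
CKPT_PATH = 'path/to/dgl_model-000000.pth'  # file written by model_train.py (placeholder)

graph, labels, train_nid, val_nid, test_nid, node_feat = load_dgl_graph(BASE_PATH)
# apply the same graph transforms as training
graph = dgl.to_bidirected(graph, copy_ndata=True)
graph = dgl.add_self_loop(graph)

# hyperparameters must match the ones used for the checkpoint
model = GraphSageModel(node_feat.shape[1], hidden_dim=64, n_layers=2, n_classes=23)
state_dict = th.load(CKPT_PATH, map_location='cpu')
# a checkpoint from the multi-GPU path is saved from a DistributedDataParallel
# wrapper, so its keys may carry a 'module.' prefix that has to be stripped
state_dict = {k.replace('module.', '', 1): v for k, v in state_dict.items()}
model.load_state_dict(state_dict)
model.eval()

sampler = MultiLayerNeighborSampler([20, 20])   # one fanout entry per layer
dataloader = NodeDataLoader(graph, test_nid, sampler,
                            batch_size=4096, shuffle=False, drop_last=False)

preds = []
with th.no_grad():
    for input_nodes, seeds, blocks in dataloader:
        batch_logits = model(blocks, node_feat[input_nodes])
        preds.append(th.argmax(batch_logits, dim=1))
preds = th.cat(preds)
print('Predicted labels for {} test nodes'.format(preds.shape[0]))
```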
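**A note on gnn/model_utils.py.** gpu_train creates an `early_stopper` but never calls its `earlystop` method, so training always runs for the full number of epochs. The sketch below (not from the repository) shows how the class is meant to be driven, here with a made-up sequence of epoch-level validation losses; in gpu_train the mean of `val_loss_list` for the epoch would play that role.

```python
# Hypothetical illustration only: drive early_stopper with dummy epoch losses.
# Run from the gnn/ directory so that model_utils is importable.
from model_utils import early_stopper

stopper = early_stopper(patience=2, verbose=True)

for epoch, val_loss in enumerate([0.93, 0.88, 0.87, 0.89, 0.90, 0.91]):
    # preds/logits are placeholders here; in gpu_train they would be the
    # concatenated validation predictions and logits for the epoch.
    stopper.earlystop(val_loss, preds=None, logits=None)
    print('epoch {}: val_loss={:.2f}, best={:.2f}'.format(epoch, val_loss, -stopper.best_value))
    if stopper.is_earlystop:
        print('no improvement for {} epochs, stopping'.format(stopper.patience))
        break
```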