├── MAXP 2021初赛数据探索和处理-1.ipynb
├── MAXP 2021初赛数据探索和处理-2.ipynb
├── MAXP 2021初赛数据探索和处理-3.ipynb
├── MAXP 2021初赛数据探索和处理-4.ipynb
├── README.md
└── gnn
├── __init__.py
├── model_train.py
├── model_utils.py
├── models.py
└── utils.py
/MAXP 2021初赛数据探索和处理-1.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# MAXP 2021初赛数据探索和处理-1\n",
8 | "\n",
9 | "由于节点的Feature维度比较高,所以先处理节点的ID,并保存。Feature的处理放到第2部分。"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 1,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "import pandas as pd\n",
19 | "import numpy as np\n",
20 | "import os\n",
21 | "import gc"
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": 2,
27 | "metadata": {},
28 | "outputs": [],
29 | "source": [
30 | "# path\n",
31 | "base_path = '/Users/jamezhan/PycharmProjects/MAXP/final_dataset'\n",
32 | "publish_path = 'publish'\n",
33 | "\n",
34 | "link_p1_path = os.path.join(base_path, publish_path, 'link_phase1.csv')\n",
35 | "train_nodes_path = os.path.join(base_path, publish_path, 'train_nodes.csv')\n",
36 | "val_nodes_path = os.path.join(base_path, publish_path, 'validation_nodes.csv')"
37 | ]
38 | },
39 | {
40 | "cell_type": "markdown",
41 | "metadata": {},
42 | "source": [
43 | "### 读取边列表并统计节点数量"
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": 3,
49 | "metadata": {},
50 | "outputs": [
51 | {
52 | "name": "stdout",
53 | "output_type": "stream",
54 | "text": [
55 | "(29168650, 3)\n"
56 | ]
57 | },
58 | {
59 | "data": {
60 | "text/html": [
61 | "
\n",
62 | "\n",
75 | "
\n",
76 | " \n",
77 | " \n",
78 | " | \n",
79 | " paper_id | \n",
80 | " reference_paper_id | \n",
81 | " phase | \n",
82 | "
\n",
83 | " \n",
84 | " \n",
85 | " \n",
86 | " 0 | \n",
87 | " f10da75ad1eaf16eb2ffe0d85b76b332 | \n",
88 | " 711ef25bdb2c2421c0131af77b3ede1d | \n",
89 | " phase1 | \n",
90 | "
\n",
91 | " \n",
92 | " 1 | \n",
93 | " 9ac5a4327bd4f3dcb424c93ca9b84087 | \n",
94 | " 2d91c73304c5e8a94a0e5b4956093f71 | \n",
95 | " phase1 | \n",
96 | "
\n",
97 | " \n",
98 | " 2 | \n",
99 | " 9d91bfd4703e55dd814dfffb3d63fc33 | \n",
100 | " 33d4fdfe3967a1ffde9311bfe6827ef9 | \n",
101 | " phase1 | \n",
102 | "
\n",
103 | " \n",
104 | " 3 | \n",
105 | " e1bdbce05528952ed6579795373782d4 | \n",
106 | " 4bda690abec912b3b7b228b01fb6819a | \n",
107 | " phase1 | \n",
108 | "
\n",
109 | " \n",
110 | " 4 | \n",
111 | " eb623ac4b10df96835921edabbde2951 | \n",
112 | " c1a05bdfc88a73bf2830e705b2f39dbb | \n",
113 | " phase1 | \n",
114 | "
\n",
115 | " \n",
116 | "
\n",
117 | "
"
118 | ],
119 | "text/plain": [
120 | " paper_id reference_paper_id phase\n",
121 | "0 f10da75ad1eaf16eb2ffe0d85b76b332 711ef25bdb2c2421c0131af77b3ede1d phase1\n",
122 | "1 9ac5a4327bd4f3dcb424c93ca9b84087 2d91c73304c5e8a94a0e5b4956093f71 phase1\n",
123 | "2 9d91bfd4703e55dd814dfffb3d63fc33 33d4fdfe3967a1ffde9311bfe6827ef9 phase1\n",
124 | "3 e1bdbce05528952ed6579795373782d4 4bda690abec912b3b7b228b01fb6819a phase1\n",
125 | "4 eb623ac4b10df96835921edabbde2951 c1a05bdfc88a73bf2830e705b2f39dbb phase1"
126 | ]
127 | },
128 | "execution_count": 3,
129 | "metadata": {},
130 | "output_type": "execute_result"
131 | }
132 | ],
133 | "source": [
134 | "edge_df = pd.read_csv(link_p1_path)\n",
135 | "print(edge_df.shape)\n",
136 | "edge_df.head()"
137 | ]
138 | },
139 | {
140 | "cell_type": "code",
141 | "execution_count": 4,
142 | "metadata": {},
143 | "outputs": [
144 | {
145 | "data": {
146 | "text/plain": [
147 | "count 29168650\n",
148 | "unique 1\n",
149 | "top phase1\n",
150 | "freq 29168650\n",
151 | "Name: phase, dtype: object"
152 | ]
153 | },
154 | "execution_count": 4,
155 | "metadata": {},
156 | "output_type": "execute_result"
157 | }
158 | ],
159 | "source": [
160 | "edge_df.phase.describe()"
161 | ]
162 | },
163 | {
164 | "cell_type": "code",
165 | "execution_count": 5,
166 | "metadata": {},
167 | "outputs": [
168 | {
169 | "name": "stdout",
170 | "output_type": "stream",
171 | "text": [
172 | "(3031367, 1)\n"
173 | ]
174 | },
175 | {
176 | "data": {
177 | "text/html": [
178 | "\n",
179 | "\n",
192 | "
\n",
193 | " \n",
194 | " \n",
195 | " | \n",
196 | " paper_id | \n",
197 | "
\n",
198 | " \n",
199 | " \n",
200 | " \n",
201 | " 0 | \n",
202 | " f10da75ad1eaf16eb2ffe0d85b76b332 | \n",
203 | "
\n",
204 | " \n",
205 | " 1 | \n",
206 | " 9ac5a4327bd4f3dcb424c93ca9b84087 | \n",
207 | "
\n",
208 | " \n",
209 | " 2 | \n",
210 | " 9d91bfd4703e55dd814dfffb3d63fc33 | \n",
211 | "
\n",
212 | " \n",
213 | " 3 | \n",
214 | " e1bdbce05528952ed6579795373782d4 | \n",
215 | "
\n",
216 | " \n",
217 | "
\n",
218 | "
"
219 | ],
220 | "text/plain": [
221 | " paper_id\n",
222 | "0 f10da75ad1eaf16eb2ffe0d85b76b332\n",
223 | "1 9ac5a4327bd4f3dcb424c93ca9b84087\n",
224 | "2 9d91bfd4703e55dd814dfffb3d63fc33\n",
225 | "3 e1bdbce05528952ed6579795373782d4"
226 | ]
227 | },
228 | "execution_count": 5,
229 | "metadata": {},
230 | "output_type": "execute_result"
231 | }
232 | ],
233 | "source": [
234 | "nodes = pd.concat([edge_df['paper_id'], edge_df['reference_paper_id']])\n",
235 | "nodes = pd.DataFrame(nodes.drop_duplicates())\n",
236 | "nodes.rename(columns={0:'paper_id'}, inplace=True)\n",
237 | "\n",
238 | "print(nodes.shape)\n",
239 | "nodes.head(4)"
240 | ]
241 | },
242 | {
243 | "cell_type": "markdown",
244 | "metadata": {},
245 | "source": [
246 | "#### 在边列表,一共出现了3,031,367个节点(paper_id)\n",
247 | "\n",
248 | "## 读取并查看train_nodes和validation_nodes里面的节点"
249 | ]
250 | },
251 | {
252 | "cell_type": "code",
253 | "execution_count": 6,
254 | "metadata": {},
255 | "outputs": [],
256 | "source": [
257 | "def process_node(line):\n",
258 | " nid, feat_json, label = line.strip().split('\\\"')\n",
259 | " \n",
260 | " feat_list = [float(feat[1:-1]) for feat in feat_json[1:-1].split(', ')]\n",
261 | " \n",
262 | " if len(feat_list) != 300:\n",
263 | " print('此行数据有问题 {}'.format(line))\n",
264 | " \n",
265 | " return nid[:-1], feat_list, label[1:]"
266 | ]
267 | },
268 | {
269 | "cell_type": "code",
270 | "execution_count": 7,
271 | "metadata": {
272 | "scrolled": true
273 | },
274 | "outputs": [
275 | {
276 | "name": "stdout",
277 | "output_type": "stream",
278 | "text": [
279 | "Processed 100000 train rows\n",
280 | "Processed 200000 train rows\n",
281 | "Processed 300000 train rows\n",
282 | "Processed 400000 train rows\n",
283 | "Processed 500000 train rows\n",
284 | "Processed 600000 train rows\n",
285 | "Processed 700000 train rows\n",
286 | "Processed 800000 train rows\n",
287 | "Processed 900000 train rows\n",
288 | "Processed 1000000 train rows\n",
289 | "Processed 1100000 train rows\n",
290 | "Processed 1200000 train rows\n",
291 | "Processed 1300000 train rows\n",
292 | "Processed 1400000 train rows\n",
293 | "Processed 1500000 train rows\n",
294 | "Processed 1600000 train rows\n",
295 | "Processed 1700000 train rows\n",
296 | "Processed 1800000 train rows\n",
297 | "Processed 1900000 train rows\n",
298 | "Processed 2000000 train rows\n",
299 | "Processed 2100000 train rows\n",
300 | "Processed 2200000 train rows\n",
301 | "Processed 2300000 train rows\n",
302 | "Processed 2400000 train rows\n",
303 | "Processed 2500000 train rows\n",
304 | "Processed 2600000 train rows\n",
305 | "Processed 2700000 train rows\n",
306 | "Processed 2800000 train rows\n",
307 | "Processed 2900000 train rows\n",
308 | "Processed 3000000 train rows\n",
309 | "Processed 100000 validation rows\n",
310 | "Processed 200000 validation rows\n",
311 | "Processed 300000 validation rows\n",
312 | "Processed 400000 validation rows\n",
313 | "Processed 500000 validation rows\n"
314 | ]
315 | }
316 | ],
317 | "source": [
318 | "# 先构建ID和Label的关系,保证ID的顺序和Feature的顺序一致即可\n",
319 | "nid_list = []\n",
320 | "label_list = []\n",
321 | "tr_val_list = []\n",
322 | "\n",
323 | "with open(train_nodes_path, 'r') as f:\n",
324 | " i = 0\n",
325 | " \n",
326 | " for line in f:\n",
327 | " if i > 0:\n",
328 | " nid, _, label = process_node(line)\n",
329 | " nid_list.append(nid)\n",
330 | " label_list.append(label)\n",
331 | " tr_val_list.append(0) # 0表示train的点\n",
332 | " i += 1\n",
333 | " if i % 100000 == 0:\n",
334 | " print('Processed {} train rows'.format(i))\n",
335 | "\n",
336 | "with open(val_nodes_path, 'r') as f:\n",
337 | " i = 0\n",
338 | " \n",
339 | " for line in f:\n",
340 | " if i > 0:\n",
341 | " nid, _, label = process_node(line)\n",
342 | " nid_list.append(nid)\n",
343 | " label_list.append(label)\n",
344 | " tr_val_list.append(1) # 1表示validation的点\n",
345 | " i += 1\n",
346 | " if i % 100000 == 0:\n",
347 | " print('Processed {} validation rows'.format(i))\n",
348 | " \n",
349 | "nid_arr = np.array(nid_list)\n",
350 | "label_arr = np.array(label_list)\n",
351 | "tr_val_arr = np.array(tr_val_list)\n",
352 | " \n",
353 | "nid_label_df = pd.DataFrame({'paper_id':nid_arr, 'Label': label_arr, 'Split_ID':tr_val_arr})"
354 | ]
355 | },
356 | {
357 | "cell_type": "code",
358 | "execution_count": 8,
359 | "metadata": {},
360 | "outputs": [
361 | {
362 | "name": "stdout",
363 | "output_type": "stream",
364 | "text": [
365 | "(3655033, 4)\n"
366 | ]
367 | },
368 | {
369 | "data": {
370 | "text/html": [
371 | "\n",
372 | "\n",
385 | "
\n",
386 | " \n",
387 | " \n",
388 | " | \n",
389 | " node_idx | \n",
390 | " paper_id | \n",
391 | " Label | \n",
392 | " Split_ID | \n",
393 | "
\n",
394 | " \n",
395 | " \n",
396 | " \n",
397 | " 0 | \n",
398 | " 0 | \n",
399 | " bfdee5ab86ef5e68da974d48a138c28e | \n",
400 | " S | \n",
401 | " 0 | \n",
402 | "
\n",
403 | " \n",
404 | " 1 | \n",
405 | " 1 | \n",
406 | " 78f43b8b62f040347fec0be44e5f08bd | \n",
407 | " | \n",
408 | " 0 | \n",
409 | "
\n",
410 | " \n",
411 | " 2 | \n",
412 | " 2 | \n",
413 | " a971601a0286d2701aa5cde46e63a9fd | \n",
414 | " G | \n",
415 | " 0 | \n",
416 | "
\n",
417 | " \n",
418 | " 3 | \n",
419 | " 3 | \n",
420 | " ac4b88a72146bae66cedfd1c13e1552d | \n",
421 | " | \n",
422 | " 0 | \n",
423 | "
\n",
424 | " \n",
425 | "
\n",
426 | "
"
427 | ],
428 | "text/plain": [
429 | " node_idx paper_id Label Split_ID\n",
430 | "0 0 bfdee5ab86ef5e68da974d48a138c28e S 0\n",
431 | "1 1 78f43b8b62f040347fec0be44e5f08bd 0\n",
432 | "2 2 a971601a0286d2701aa5cde46e63a9fd G 0\n",
433 | "3 3 ac4b88a72146bae66cedfd1c13e1552d 0"
434 | ]
435 | },
436 | "execution_count": 8,
437 | "metadata": {},
438 | "output_type": "execute_result"
439 | }
440 | ],
441 | "source": [
442 | "nid_label_df.reset_index(inplace=True)\n",
443 | "nid_label_df.rename(columns={'index':'node_idx'}, inplace=True)\n",
444 | "print(nid_label_df.shape)\n",
445 | "nid_label_df.head(4)"
446 | ]
447 | },
448 | {
449 | "cell_type": "code",
450 | "execution_count": 9,
451 | "metadata": {},
452 | "outputs": [
453 | {
454 | "data": {
455 | "text/plain": [
456 | "(3655033,)"
457 | ]
458 | },
459 | "execution_count": 9,
460 | "metadata": {},
461 | "output_type": "execute_result"
462 | }
463 | ],
464 | "source": [
465 | "# 检查ID在Train和Validation是否有重复\n",
466 | "ids = nid_label_df.paper_id.drop_duplicates()\n",
467 | "ids.shape"
468 | ]
469 | },
470 | {
471 | "cell_type": "markdown",
472 | "metadata": {},
473 | "source": [
474 | "#### train和validation一共有3,655,033个节点"
475 | ]
476 | },
477 | {
478 | "cell_type": "markdown",
479 | "metadata": {},
480 | "source": [
481 | "#### 下面交叉比对边列表里的paper id和节点列表里的ID,检查是否有匹配不上的节点"
482 | ]
483 | },
484 | {
485 | "cell_type": "code",
486 | "execution_count": 10,
487 | "metadata": {},
488 | "outputs": [
489 | {
490 | "name": "stdout",
491 | "output_type": "stream",
492 | "text": [
493 | "(3030948, 4)\n"
494 | ]
495 | }
496 | ],
497 | "source": [
498 | "inboth = nid_label_df.merge(nodes, on='paper_id', how='inner')\n",
499 | "print(inboth.shape)"
500 | ]
501 | },
502 | {
503 | "cell_type": "code",
504 | "execution_count": 11,
505 | "metadata": {},
506 | "outputs": [
507 | {
508 | "name": "stdout",
509 | "output_type": "stream",
510 | "text": [
511 | "(3031367, 4)\n",
512 | "共有419边列表的节点在给出的节点列表里没有对应,缺乏特征\n"
513 | ]
514 | },
515 | {
516 | "data": {
517 | "text/html": [
518 | "\n",
519 | "\n",
532 | "
\n",
533 | " \n",
534 | " \n",
535 | " | \n",
536 | " paper_id | \n",
537 | " node_idx | \n",
538 | " Label | \n",
539 | " Split_ID | \n",
540 | "
\n",
541 | " \n",
542 | " \n",
543 | " \n",
544 | " 1124 | \n",
545 | " cc388eaec8838ce383d8a8792014fedb | \n",
546 | " NaN | \n",
547 | " NaN | \n",
548 | " NaN | \n",
549 | "
\n",
550 | " \n",
551 | " 1184 | \n",
552 | " 5d899f41e52f751fef843cf7b1d05b4a | \n",
553 | " NaN | \n",
554 | " NaN | \n",
555 | " NaN | \n",
556 | "
\n",
557 | " \n",
558 | " 14342 | \n",
559 | " 2b2004ec3c99a44b5cb6045ca547453e | \n",
560 | " NaN | \n",
561 | " NaN | \n",
562 | " NaN | \n",
563 | "
\n",
564 | " \n",
565 | " 15803 | \n",
566 | " d657c4451a9617f4eec96d3b2e6092c7 | \n",
567 | " NaN | \n",
568 | " NaN | \n",
569 | " NaN | \n",
570 | "
\n",
571 | " \n",
572 | "
\n",
573 | "
"
574 | ],
575 | "text/plain": [
576 | " paper_id node_idx Label Split_ID\n",
577 | "1124 cc388eaec8838ce383d8a8792014fedb NaN NaN NaN\n",
578 | "1184 5d899f41e52f751fef843cf7b1d05b4a NaN NaN NaN\n",
579 | "14342 2b2004ec3c99a44b5cb6045ca547453e NaN NaN NaN\n",
580 | "15803 d657c4451a9617f4eec96d3b2e6092c7 NaN NaN NaN"
581 | ]
582 | },
583 | "execution_count": 11,
584 | "metadata": {},
585 | "output_type": "execute_result"
586 | }
587 | ],
588 | "source": [
589 | "edge_node = nodes.merge(nid_label_df, on='paper_id', how='left')\n",
590 | "print(edge_node.shape)\n",
591 | "print('共有{}边列表的节点在给出的节点列表里没有对应,缺乏特征'.format(edge_node[edge_node.node_idx.isna()].shape[0]))\n",
592 | "edge_node[edge_node.node_idx.isna()].head(4)"
593 | ]
594 | },
595 | {
596 | "cell_type": "markdown",
597 | "metadata": {},
598 | "source": [
599 | "#### 合并边列表里独特的节点和train和validation的节点到一起,构成全部节点列表"
600 | ]
601 | },
602 | {
603 | "cell_type": "code",
604 | "execution_count": 12,
605 | "metadata": {},
606 | "outputs": [
607 | {
608 | "name": "stderr",
609 | "output_type": "stream",
610 | "text": [
611 | "/Users/jamezhan/anaconda3/envs/dgl/lib/python3.6/site-packages/ipykernel_launcher.py:3: UserWarning: Pandas doesn't allow columns to be created via a new attribute name - see https://pandas.pydata.org/pandas-docs/stable/indexing.html#attribute-access\n",
612 | " This is separate from the ipykernel package so we can avoid doing imports until\n",
613 | "/Users/jamezhan/anaconda3/envs/dgl/lib/python3.6/site-packages/pandas/core/generic.py:5170: SettingWithCopyWarning: \n",
614 | "A value is trying to be set on a copy of a slice from a DataFrame.\n",
615 | "Try using .loc[row_indexer,col_indexer] = value instead\n",
616 | "\n",
617 | "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
618 | " self[name] = value\n",
619 | "/Users/jamezhan/anaconda3/envs/dgl/lib/python3.6/site-packages/pandas/core/frame.py:4174: SettingWithCopyWarning: \n",
620 | "A value is trying to be set on a copy of a slice from a DataFrame\n",
621 | "\n",
622 | "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
623 | " errors=errors,\n"
624 | ]
625 | },
626 | {
627 | "data": {
628 | "text/html": [
629 | "\n",
630 | "\n",
643 | "
\n",
644 | " \n",
645 | " \n",
646 | " | \n",
647 | " node_idx | \n",
648 | " paper_id | \n",
649 | " Label | \n",
650 | " Split_ID | \n",
651 | "
\n",
652 | " \n",
653 | " \n",
654 | " \n",
655 | " 0 | \n",
656 | " 3655033 | \n",
657 | " cc388eaec8838ce383d8a8792014fedb | \n",
658 | " NaN | \n",
659 | " 1 | \n",
660 | "
\n",
661 | " \n",
662 | " 1 | \n",
663 | " 3655034 | \n",
664 | " 5d899f41e52f751fef843cf7b1d05b4a | \n",
665 | " NaN | \n",
666 | " 1 | \n",
667 | "
\n",
668 | " \n",
669 | " 2 | \n",
670 | " 3655035 | \n",
671 | " 2b2004ec3c99a44b5cb6045ca547453e | \n",
672 | " NaN | \n",
673 | " 1 | \n",
674 | "
\n",
675 | " \n",
676 | " 3 | \n",
677 | " 3655036 | \n",
678 | " d657c4451a9617f4eec96d3b2e6092c7 | \n",
679 | " NaN | \n",
680 | " 1 | \n",
681 | "
\n",
682 | " \n",
683 | "
\n",
684 | "
"
685 | ],
686 | "text/plain": [
687 | " node_idx paper_id Label Split_ID\n",
688 | "0 3655033 cc388eaec8838ce383d8a8792014fedb NaN 1\n",
689 | "1 3655034 5d899f41e52f751fef843cf7b1d05b4a NaN 1\n",
690 | "2 3655035 2b2004ec3c99a44b5cb6045ca547453e NaN 1\n",
691 | "3 3655036 d657c4451a9617f4eec96d3b2e6092c7 NaN 1"
692 | ]
693 | },
694 | "execution_count": 12,
695 | "metadata": {},
696 | "output_type": "execute_result"
697 | }
698 | ],
699 | "source": [
700 | "# 获取未能匹配上的节点,并构建新的节点DataFrame,然后和原有的Train/Validation节点Concat起来\n",
701 | "diff_nodes = edge_node[edge_node.node_idx.isna()]\n",
702 | "diff_nodes.ID = diff_nodes.paper_id\n",
703 | "diff_nodes.Split_ID = 1\n",
704 | "diff_nodes.node_idx = 0\n",
705 | "diff_nodes.reset_index(inplace=True)\n",
706 | "diff_nodes.drop(['index'], axis=1, inplace=True)\n",
707 | "diff_nodes.node_idx = diff_nodes.node_idx + diff_nodes.index + 3655033\n",
708 | "diff_nodes = diff_nodes[['node_idx', 'paper_id', 'Label', 'Split_ID']]\n",
709 | "diff_nodes.head(4)"
710 | ]
711 | },
712 | {
713 | "cell_type": "code",
714 | "execution_count": 13,
715 | "metadata": {},
716 | "outputs": [
717 | {
718 | "data": {
719 | "text/html": [
720 | "\n",
721 | "\n",
734 | "
\n",
735 | " \n",
736 | " \n",
737 | " | \n",
738 | " node_idx | \n",
739 | " paper_id | \n",
740 | " Label | \n",
741 | " Split_ID | \n",
742 | "
\n",
743 | " \n",
744 | " \n",
745 | " \n",
746 | " 415 | \n",
747 | " 3655448 | \n",
748 | " caed47d55d1e193ecb1fa97a415c13dd | \n",
749 | " NaN | \n",
750 | " 1 | \n",
751 | "
\n",
752 | " \n",
753 | " 416 | \n",
754 | " 3655449 | \n",
755 | " c82eb6be79a245392fb626b9a7e1f246 | \n",
756 | " NaN | \n",
757 | " 1 | \n",
758 | "
\n",
759 | " \n",
760 | " 417 | \n",
761 | " 3655450 | \n",
762 | " 926a31f6b378575204aae30b5dfa6dd3 | \n",
763 | " NaN | \n",
764 | " 1 | \n",
765 | "
\n",
766 | " \n",
767 | " 418 | \n",
768 | " 3655451 | \n",
769 | " bbace2419c3f827158ea4602f3eb35fa | \n",
770 | " NaN | \n",
771 | " 1 | \n",
772 | "
\n",
773 | " \n",
774 | "
\n",
775 | "
"
776 | ],
777 | "text/plain": [
778 | " node_idx paper_id Label Split_ID\n",
779 | "415 3655448 caed47d55d1e193ecb1fa97a415c13dd NaN 1\n",
780 | "416 3655449 c82eb6be79a245392fb626b9a7e1f246 NaN 1\n",
781 | "417 3655450 926a31f6b378575204aae30b5dfa6dd3 NaN 1\n",
782 | "418 3655451 bbace2419c3f827158ea4602f3eb35fa NaN 1"
783 | ]
784 | },
785 | "execution_count": 13,
786 | "metadata": {},
787 | "output_type": "execute_result"
788 | }
789 | ],
790 | "source": [
791 | "# Concatenate这419个未匹配到的节点到总的node的最后,从而让nid能接上\n",
792 | "nid_label_df = pd.concat([nid_label_df, diff_nodes])\n",
793 | "nid_label_df.tail(4)"
794 | ]
795 | },
796 | {
797 | "cell_type": "code",
798 | "execution_count": 14,
799 | "metadata": {},
800 | "outputs": [],
801 | "source": [
802 | "# 保存ID和Label到本地文件\n",
803 | "nid_label_df.to_csv(os.path.join(base_path, publish_path, './IDandLabels.csv'), index=False)\n",
804 | "# 保存未匹配上的节点用于feature的处理\n",
805 | "diff_nodes.to_csv(os.path.join(base_path, publish_path, './diff_nodes.csv'), index=False)\n"
806 | ]
807 | }
808 | ],
809 | "metadata": {
810 | "kernelspec": {
811 | "display_name": "Python [conda env:dgl]",
812 | "language": "python",
813 | "name": "conda-env-dgl-py"
814 | },
815 | "language_info": {
816 | "codemirror_mode": {
817 | "name": "ipython",
818 | "version": 3
819 | },
820 | "file_extension": ".py",
821 | "mimetype": "text/x-python",
822 | "name": "python",
823 | "nbconvert_exporter": "python",
824 | "pygments_lexer": "ipython3",
825 | "version": "3.6.10"
826 | }
827 | },
828 | "nbformat": 4,
829 | "nbformat_minor": 2
830 | }
831 |
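A quick note on the `process_node` helper above: it splits each raw row on the double-quote character, on the assumption that the node files store the 300-dimensional feature vector as a quoted, JSON-like list between the paper id and the label. A minimal, hypothetical sketch of that parsing (the sample row below is made up and only 3-dimensional for brevity):

```python
# Hypothetical row shaped like those in train_nodes.csv / validation_nodes.csv:
#   <paper_id>,"['f1', 'f2', ...]",<label>
line = 'abc123,"[\'0.10\', \'0.20\', \'0.30\']",A\n'

nid, feat_json, label = line.strip().split('"')
# nid ends with the separating comma and label starts with one,
# hence the nid[:-1] / label[1:] trimming in process_node.
feat_list = [float(feat[1:-1]) for feat in feat_json[1:-1].split(', ')]
print(nid[:-1], feat_list, label[1:])  # -> abc123 [0.1, 0.2, 0.3] A
```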
--------------------------------------------------------------------------------
/MAXP 2021初赛数据探索和处理-2.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# MAXP 2021初赛数据探索和处理-2\n",
8 | "\n",
9 | "处理Feature,并结合第1部分里得到的不匹配的节点,生成新的Feature,并保存。"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 4,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "import pandas as pd\n",
19 | "import numpy as np\n",
20 | "import os\n",
21 | "import gc"
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": 5,
27 | "metadata": {},
28 | "outputs": [],
29 | "source": [
30 | "# path\n",
31 | "base_path = '/Users/jamezhan/PycharmProjects/MAXP/final_dataset'\n",
32 | "publish_path = 'publish'\n",
33 | "\n",
34 | "diff_node_path = os.path.join(base_path, publish_path, 'diff_nodes.csv')\n",
35 | "train_nodes_path = os.path.join(base_path, publish_path, 'train_nodes.csv')\n",
36 | "val_nodes_path = os.path.join(base_path, publish_path, 'validation_nodes.csv')"
37 | ]
38 | },
39 | {
40 | "cell_type": "markdown",
41 | "metadata": {},
42 | "source": [
43 | "### 处理节点的特征"
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": 6,
49 | "metadata": {},
50 | "outputs": [],
51 | "source": [
52 | "def process_node(line):\n",
53 | " nid, feat_json, label = line.strip().split('\\\"')\n",
54 | " \n",
55 | " feat_list = [float(feat[1:-1]) for feat in feat_json[1:-1].split(', ')]\n",
56 | " \n",
57 | " if len(feat_list) != 300:\n",
58 | " print('此行数据有问题 {}'.format(line))\n",
59 | " \n",
60 | " return nid[:-1], feat_list, label[1:]"
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": 7,
66 | "metadata": {
67 | "scrolled": true
68 | },
69 | "outputs": [
70 | {
71 | "name": "stdout",
72 | "output_type": "stream",
73 | "text": [
74 | "Processed 100000 train rows\n",
75 | "Processed 200000 train rows\n",
76 | "Processed 300000 train rows\n",
77 | "Processed 400000 train rows\n",
78 | "Processed 500000 train rows\n",
79 | "Processed 600000 train rows\n",
80 | "Processed 700000 train rows\n",
81 | "Processed 800000 train rows\n",
82 | "Processed 900000 train rows\n",
83 | "Processed 1000000 train rows\n",
84 | "Processed 1100000 train rows\n",
85 | "Processed 1200000 train rows\n",
86 | "Processed 1300000 train rows\n",
87 | "Processed 1400000 train rows\n",
88 | "Processed 1500000 train rows\n",
89 | "Processed 1600000 train rows\n",
90 | "Processed 1700000 train rows\n",
91 | "Processed 1800000 train rows\n",
92 | "Processed 1900000 train rows\n",
93 | "Processed 2000000 train rows\n",
94 | "Processed 2100000 train rows\n",
95 | "Processed 2200000 train rows\n",
96 | "Processed 2300000 train rows\n",
97 | "Processed 2400000 train rows\n",
98 | "Processed 2500000 train rows\n",
99 | "Processed 2600000 train rows\n",
100 | "Processed 2700000 train rows\n",
101 | "Processed 2800000 train rows\n",
102 | "Processed 2900000 train rows\n",
103 | "Processed 3000000 train rows\n",
104 | "Processed 100000 validation rows\n",
105 | "Processed 200000 validation rows\n",
106 | "Processed 300000 validation rows\n",
107 | "Processed 400000 validation rows\n",
108 | "Processed 500000 validation rows\n"
109 | ]
110 | }
111 | ],
112 | "source": [
113 | "# 下面处理特征Feature并存储亿备后用\n",
114 | "# nid_list = []\n",
115 | "feature_list = []\n",
116 | "\n",
117 | "with open(train_nodes_path, 'r') as f:\n",
118 | " i = 0\n",
119 | " \n",
120 | " for line in f:\n",
121 | " if i > 0:\n",
122 | " _, features, _ = process_node(line)\n",
123 | "# nid_list.append(nid)\n",
124 | " feature_list.append(features)\n",
125 | " i += 1\n",
126 | " if i % 100000 == 0:\n",
127 | " print('Processed {} train rows'.format(i))\n",
128 | " \n",
129 | "with open(val_nodes_path, 'r') as f:\n",
130 | " i = 0\n",
131 | " \n",
132 | " for line in f:\n",
133 | " if i > 0:\n",
134 | " _, features, _ = process_node(line)\n",
135 | "# nid_list.append(nid)\n",
136 | " feature_list.append(features)\n",
137 | " i += 1\n",
138 | " if i % 100000 == 0:\n",
139 | " print('Processed {} validation rows'.format(i))\n",
140 | " \n",
141 | "# nid_arr = np.array(nid_list)\n",
142 | "feat_arr = np.array(feature_list)"
143 | ]
144 | },
145 | {
146 | "cell_type": "code",
147 | "execution_count": 8,
148 | "metadata": {},
149 | "outputs": [
150 | {
151 | "data": {
152 | "text/plain": [
153 | "0"
154 | ]
155 | },
156 | "execution_count": 8,
157 | "metadata": {},
158 | "output_type": "execute_result"
159 | }
160 | ],
161 | "source": [
162 | "# 删除list以节省内存\n",
163 | "del feature_list\n",
164 | "gc.collect()"
165 | ]
166 | },
167 | {
168 | "cell_type": "code",
169 | "execution_count": 9,
170 | "metadata": {},
171 | "outputs": [],
172 | "source": [
173 | "# 给未匹配上的419个节点造300维的特征,这里用其他所有节点的300维的平均值来作为他们的特征\n",
174 | "# 更好的方法是用每个节点的所有邻居的特征的平均,这里就不搞这么复杂了。\n",
175 | "diff_node_feat_arr = np.tile(np.mean(feat_arr, axis=0),(419, 1))"
176 | ]
177 | },
178 | {
179 | "cell_type": "code",
180 | "execution_count": 11,
181 | "metadata": {},
182 | "outputs": [
183 | {
184 | "data": {
185 | "text/plain": [
186 | "(3655452, 300)"
187 | ]
188 | },
189 | "execution_count": 11,
190 | "metadata": {},
191 | "output_type": "execute_result"
192 | }
193 | ],
194 | "source": [
195 | "feat_arr = np.concatenate((feat_arr, diff_node_feat_arr), axis=0)\n",
196 | "feat_arr.shape"
197 | ]
198 | },
199 | {
200 | "cell_type": "code",
201 | "execution_count": null,
202 | "metadata": {},
203 | "outputs": [],
204 | "source": [
205 | "# 使用Numpy保存特征为.npy格式,以节省存储空间和提高读写速度\n",
206 | "with open(os.path.join(base_path, publish_path, './features.npy'), 'wb') as f:\n",
207 | " np.save(f, feat_arr)"
208 | ]
209 | },
210 | {
211 | "cell_type": "code",
212 | "execution_count": null,
213 | "metadata": {},
214 | "outputs": [],
215 | "source": []
216 | }
217 | ],
218 | "metadata": {
219 | "kernelspec": {
220 | "display_name": "Python [conda env:dgl]",
221 | "language": "python",
222 | "name": "conda-env-dgl-py"
223 | },
224 | "language_info": {
225 | "codemirror_mode": {
226 | "name": "ipython",
227 | "version": 3
228 | },
229 | "file_extension": ".py",
230 | "mimetype": "text/x-python",
231 | "name": "python",
232 | "nbconvert_exporter": "python",
233 | "pygments_lexer": "ipython3",
234 | "version": "3.6.10"
235 | }
236 | },
237 | "nbformat": 4,
238 | "nbformat_minor": 2
239 | }
240 |
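The comment in the imputation cell above mentions that averaging each unmatched node's neighbors' features would be better than the global mean. A rough, unoptimized sketch of that idea, assuming the variables from notebooks 1 and 2 are in scope (`edge_df`, `nid_label_df`, `feat_arr`, `diff_nodes`); this is an illustration, not part of the baseline:

```python
import pandas as pd
import numpy as np

# Assumed inputs: edge_df (edge list), nid_label_df (ID/Label table from part 1),
# feat_arr (features of the 3,655,033 matched nodes), diff_nodes (the 419 unmatched nodes).
id2idx = dict(zip(nid_label_df.paper_id, nid_label_df.node_idx))
global_mean = feat_arr.mean(axis=0)

def neighbor_mean_feature(paper_id):
    # Collect the paper's cited and citing neighbours from the edge list.
    nbrs = pd.concat([
        edge_df.loc[edge_df.paper_id == paper_id, 'reference_paper_id'],
        edge_df.loc[edge_df.reference_paper_id == paper_id, 'paper_id'],
    ])
    # Keep only neighbours that actually have features, then average them.
    nbr_idx = [id2idx[p] for p in nbrs.unique() if p in id2idx and id2idx[p] < feat_arr.shape[0]]
    return feat_arr[nbr_idx].mean(axis=0) if nbr_idx else global_mean

# diff_node_feat_arr = np.stack([neighbor_mean_feature(p) for p in diff_nodes.paper_id])
```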
--------------------------------------------------------------------------------
/MAXP 2021初赛数据探索和处理-3.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# MAXP 2021初赛数据探索和处理-3\n",
8 | "\n",
9 | "使用步骤1里处理好的节点的ID,来构建DGL的graph所需要的边列表。"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 1,
15 | "metadata": {},
16 | "outputs": [
17 | {
18 | "name": "stderr",
19 | "output_type": "stream",
20 | "text": [
21 | "Using backend: pytorch\n"
22 | ]
23 | }
24 | ],
25 | "source": [
26 | "import pandas as pd\n",
27 | "import numpy as np\n",
28 | "import os\n",
29 | "\n",
30 | "import dgl"
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "execution_count": 2,
36 | "metadata": {},
37 | "outputs": [],
38 | "source": [
39 | "# path\n",
40 | "base_path = '/Users/jamezhan/PycharmProjects/MAXP/final_dataset'\n",
41 | "publish_path = 'publish'\n",
42 | "\n",
43 | "link_p1_path = os.path.join(base_path, publish_path, 'link_phase1.csv')\n",
44 | "nodes_path = os.path.join(base_path, publish_path, 'IDandLabels.csv')"
45 | ]
46 | },
47 | {
48 | "cell_type": "markdown",
49 | "metadata": {},
50 | "source": [
51 | "### 读取节点列表"
52 | ]
53 | },
54 | {
55 | "cell_type": "code",
56 | "execution_count": 3,
57 | "metadata": {},
58 | "outputs": [
59 | {
60 | "name": "stdout",
61 | "output_type": "stream",
62 | "text": [
63 | "(3655452, 4)\n"
64 | ]
65 | },
66 | {
67 | "data": {
68 | "text/html": [
69 | "\n",
70 | "\n",
83 | "
\n",
84 | " \n",
85 | " \n",
86 | " | \n",
87 | " node_idx | \n",
88 | " paper_id | \n",
89 | " Label | \n",
90 | " Split_ID | \n",
91 | "
\n",
92 | " \n",
93 | " \n",
94 | " \n",
95 | " 3655448 | \n",
96 | " 3655448 | \n",
97 | " caed47d55d1e193ecb1fa97a415c13dd | \n",
98 | " NaN | \n",
99 | " 1 | \n",
100 | "
\n",
101 | " \n",
102 | " 3655449 | \n",
103 | " 3655449 | \n",
104 | " c82eb6be79a245392fb626b9a7e1f246 | \n",
105 | " NaN | \n",
106 | " 1 | \n",
107 | "
\n",
108 | " \n",
109 | " 3655450 | \n",
110 | " 3655450 | \n",
111 | " 926a31f6b378575204aae30b5dfa6dd3 | \n",
112 | " NaN | \n",
113 | " 1 | \n",
114 | "
\n",
115 | " \n",
116 | " 3655451 | \n",
117 | " 3655451 | \n",
118 | " bbace2419c3f827158ea4602f3eb35fa | \n",
119 | " NaN | \n",
120 | " 1 | \n",
121 | "
\n",
122 | " \n",
123 | "
\n",
124 | "
"
125 | ],
126 | "text/plain": [
127 | " node_idx paper_id Label Split_ID\n",
128 | "3655448 3655448 caed47d55d1e193ecb1fa97a415c13dd NaN 1\n",
129 | "3655449 3655449 c82eb6be79a245392fb626b9a7e1f246 NaN 1\n",
130 | "3655450 3655450 926a31f6b378575204aae30b5dfa6dd3 NaN 1\n",
131 | "3655451 3655451 bbace2419c3f827158ea4602f3eb35fa NaN 1"
132 | ]
133 | },
134 | "execution_count": 3,
135 | "metadata": {},
136 | "output_type": "execute_result"
137 | }
138 | ],
139 | "source": [
140 | "nodes_df = pd.read_csv(nodes_path, dtype={'Label':str})\n",
141 | "print(nodes_df.shape)\n",
142 | "nodes_df.tail(4)"
143 | ]
144 | },
145 | {
146 | "cell_type": "markdown",
147 | "metadata": {},
148 | "source": [
149 | "### 读取边列表"
150 | ]
151 | },
152 | {
153 | "cell_type": "code",
154 | "execution_count": 4,
155 | "metadata": {},
156 | "outputs": [
157 | {
158 | "name": "stdout",
159 | "output_type": "stream",
160 | "text": [
161 | "(29168650, 3)\n"
162 | ]
163 | },
164 | {
165 | "data": {
166 | "text/html": [
167 | "\n",
168 | "\n",
181 | "
\n",
182 | " \n",
183 | " \n",
184 | " | \n",
185 | " paper_id | \n",
186 | " reference_paper_id | \n",
187 | " phase | \n",
188 | "
\n",
189 | " \n",
190 | " \n",
191 | " \n",
192 | " 0 | \n",
193 | " f10da75ad1eaf16eb2ffe0d85b76b332 | \n",
194 | " 711ef25bdb2c2421c0131af77b3ede1d | \n",
195 | " phase1 | \n",
196 | "
\n",
197 | " \n",
198 | " 1 | \n",
199 | " 9ac5a4327bd4f3dcb424c93ca9b84087 | \n",
200 | " 2d91c73304c5e8a94a0e5b4956093f71 | \n",
201 | " phase1 | \n",
202 | "
\n",
203 | " \n",
204 | " 2 | \n",
205 | " 9d91bfd4703e55dd814dfffb3d63fc33 | \n",
206 | " 33d4fdfe3967a1ffde9311bfe6827ef9 | \n",
207 | " phase1 | \n",
208 | "
\n",
209 | " \n",
210 | " 3 | \n",
211 | " e1bdbce05528952ed6579795373782d4 | \n",
212 | " 4bda690abec912b3b7b228b01fb6819a | \n",
213 | " phase1 | \n",
214 | "
\n",
215 | " \n",
216 | " 4 | \n",
217 | " eb623ac4b10df96835921edabbde2951 | \n",
218 | " c1a05bdfc88a73bf2830e705b2f39dbb | \n",
219 | " phase1 | \n",
220 | "
\n",
221 | " \n",
222 | "
\n",
223 | "
"
224 | ],
225 | "text/plain": [
226 | " paper_id reference_paper_id phase\n",
227 | "0 f10da75ad1eaf16eb2ffe0d85b76b332 711ef25bdb2c2421c0131af77b3ede1d phase1\n",
228 | "1 9ac5a4327bd4f3dcb424c93ca9b84087 2d91c73304c5e8a94a0e5b4956093f71 phase1\n",
229 | "2 9d91bfd4703e55dd814dfffb3d63fc33 33d4fdfe3967a1ffde9311bfe6827ef9 phase1\n",
230 | "3 e1bdbce05528952ed6579795373782d4 4bda690abec912b3b7b228b01fb6819a phase1\n",
231 | "4 eb623ac4b10df96835921edabbde2951 c1a05bdfc88a73bf2830e705b2f39dbb phase1"
232 | ]
233 | },
234 | "execution_count": 4,
235 | "metadata": {},
236 | "output_type": "execute_result"
237 | }
238 | ],
239 | "source": [
240 | "edges_df = pd.read_csv(link_p1_path)\n",
241 | "print(edges_df.shape)\n",
242 | "edges_df.head()"
243 | ]
244 | },
245 | {
246 | "cell_type": "markdown",
247 | "metadata": {},
248 | "source": [
249 | "## Join点列表和边列表以生成从0开始的边列表\n",
250 | "\n",
251 | "DGL默认节点是从0开始,并以最大的ID为容量构建Graph,因此这里我们先构建从0开始的边列表。"
252 | ]
253 | },
254 | {
255 | "cell_type": "code",
256 | "execution_count": 5,
257 | "metadata": {},
258 | "outputs": [
259 | {
260 | "name": "stdout",
261 | "output_type": "stream",
262 | "text": [
263 | "(29168650, 10)\n"
264 | ]
265 | },
266 | {
267 | "data": {
268 | "text/html": [
269 | "\n",
270 | "\n",
283 | "
\n",
284 | " \n",
285 | " \n",
286 | " | \n",
287 | " paper_id_x | \n",
288 | " reference_paper_id | \n",
289 | " phase | \n",
290 | " node_idx_x | \n",
291 | " Label_x | \n",
292 | " Split_ID_x | \n",
293 | " node_idx_y | \n",
294 | " paper_id_y | \n",
295 | " Label_y | \n",
296 | " Split_ID_y | \n",
297 | "
\n",
298 | " \n",
299 | " \n",
300 | " \n",
301 | " 0 | \n",
302 | " f10da75ad1eaf16eb2ffe0d85b76b332 | \n",
303 | " 711ef25bdb2c2421c0131af77b3ede1d | \n",
304 | " phase1 | \n",
305 | " 529879 | \n",
306 | " NaN | \n",
307 | " 0 | \n",
308 | " 2364950 | \n",
309 | " 711ef25bdb2c2421c0131af77b3ede1d | \n",
310 | " NaN | \n",
311 | " 0 | \n",
312 | "
\n",
313 | " \n",
314 | " 1 | \n",
315 | " 9ac5a4327bd4f3dcb424c93ca9b84087 | \n",
316 | " 2d91c73304c5e8a94a0e5b4956093f71 | \n",
317 | " phase1 | \n",
318 | " 410481 | \n",
319 | " D | \n",
320 | " 0 | \n",
321 | " 384023 | \n",
322 | " 2d91c73304c5e8a94a0e5b4956093f71 | \n",
323 | " K | \n",
324 | " 0 | \n",
325 | "
\n",
326 | " \n",
327 | " 2 | \n",
328 | " 9d91bfd4703e55dd814dfffb3d63fc33 | \n",
329 | " 33d4fdfe3967a1ffde9311bfe6827ef9 | \n",
330 | " phase1 | \n",
331 | " 2196044 | \n",
332 | " D | \n",
333 | " 0 | \n",
334 | " 1895619 | \n",
335 | " 33d4fdfe3967a1ffde9311bfe6827ef9 | \n",
336 | " N | \n",
337 | " 0 | \n",
338 | "
\n",
339 | " \n",
340 | " 3 | \n",
341 | " e1bdbce05528952ed6579795373782d4 | \n",
342 | " 4bda690abec912b3b7b228b01fb6819a | \n",
343 | " phase1 | \n",
344 | " 2545623 | \n",
345 | " NaN | \n",
346 | " 0 | \n",
347 | " 2175977 | \n",
348 | " 4bda690abec912b3b7b228b01fb6819a | \n",
349 | " NaN | \n",
350 | " 0 | \n",
351 | "
\n",
352 | " \n",
353 | "
\n",
354 | "
"
355 | ],
356 | "text/plain": [
357 | " paper_id_x reference_paper_id phase \\\n",
358 | "0 f10da75ad1eaf16eb2ffe0d85b76b332 711ef25bdb2c2421c0131af77b3ede1d phase1 \n",
359 | "1 9ac5a4327bd4f3dcb424c93ca9b84087 2d91c73304c5e8a94a0e5b4956093f71 phase1 \n",
360 | "2 9d91bfd4703e55dd814dfffb3d63fc33 33d4fdfe3967a1ffde9311bfe6827ef9 phase1 \n",
361 | "3 e1bdbce05528952ed6579795373782d4 4bda690abec912b3b7b228b01fb6819a phase1 \n",
362 | "\n",
363 | " node_idx_x Label_x Split_ID_x node_idx_y \\\n",
364 | "0 529879 NaN 0 2364950 \n",
365 | "1 410481 D 0 384023 \n",
366 | "2 2196044 D 0 1895619 \n",
367 | "3 2545623 NaN 0 2175977 \n",
368 | "\n",
369 | " paper_id_y Label_y Split_ID_y \n",
370 | "0 711ef25bdb2c2421c0131af77b3ede1d NaN 0 \n",
371 | "1 2d91c73304c5e8a94a0e5b4956093f71 K 0 \n",
372 | "2 33d4fdfe3967a1ffde9311bfe6827ef9 N 0 \n",
373 | "3 4bda690abec912b3b7b228b01fb6819a NaN 0 "
374 | ]
375 | },
376 | "execution_count": 5,
377 | "metadata": {},
378 | "output_type": "execute_result"
379 | }
380 | ],
381 | "source": [
382 | "# Merge paper_id列\n",
383 | "edges = edges_df.merge(nodes_df, on='paper_id', how='left')\n",
384 | "# Merge reference_paper_id列\n",
385 | "edges = edges.merge(nodes_df, left_on='reference_paper_id', right_on='paper_id', how='left')\n",
386 | "print(edges.shape)\n",
387 | "edges.head(4)"
388 | ]
389 | },
390 | {
391 | "cell_type": "markdown",
392 | "metadata": {},
393 | "source": [
394 | "#### 修改node_idx_* 列的名称作为新的node id,并只保留需要的列"
395 | ]
396 | },
397 | {
398 | "cell_type": "code",
399 | "execution_count": 6,
400 | "metadata": {},
401 | "outputs": [
402 | {
403 | "data": {
404 | "text/html": [
405 | "\n",
406 | "\n",
419 | "
\n",
420 | " \n",
421 | " \n",
422 | " | \n",
423 | " src_nid | \n",
424 | " dst_nid | \n",
425 | " paper_id | \n",
426 | " reference_paper_id | \n",
427 | "
\n",
428 | " \n",
429 | " \n",
430 | " \n",
431 | " 0 | \n",
432 | " 529879 | \n",
433 | " 2364950 | \n",
434 | " f10da75ad1eaf16eb2ffe0d85b76b332 | \n",
435 | " 711ef25bdb2c2421c0131af77b3ede1d | \n",
436 | "
\n",
437 | " \n",
438 | " 1 | \n",
439 | " 410481 | \n",
440 | " 384023 | \n",
441 | " 9ac5a4327bd4f3dcb424c93ca9b84087 | \n",
442 | " 2d91c73304c5e8a94a0e5b4956093f71 | \n",
443 | "
\n",
444 | " \n",
445 | " 2 | \n",
446 | " 2196044 | \n",
447 | " 1895619 | \n",
448 | " 9d91bfd4703e55dd814dfffb3d63fc33 | \n",
449 | " 33d4fdfe3967a1ffde9311bfe6827ef9 | \n",
450 | "
\n",
451 | " \n",
452 | " 3 | \n",
453 | " 2545623 | \n",
454 | " 2175977 | \n",
455 | " e1bdbce05528952ed6579795373782d4 | \n",
456 | " 4bda690abec912b3b7b228b01fb6819a | \n",
457 | "
\n",
458 | " \n",
459 | "
\n",
460 | "
"
461 | ],
462 | "text/plain": [
463 | " src_nid dst_nid paper_id \\\n",
464 | "0 529879 2364950 f10da75ad1eaf16eb2ffe0d85b76b332 \n",
465 | "1 410481 384023 9ac5a4327bd4f3dcb424c93ca9b84087 \n",
466 | "2 2196044 1895619 9d91bfd4703e55dd814dfffb3d63fc33 \n",
467 | "3 2545623 2175977 e1bdbce05528952ed6579795373782d4 \n",
468 | "\n",
469 | " reference_paper_id \n",
470 | "0 711ef25bdb2c2421c0131af77b3ede1d \n",
471 | "1 2d91c73304c5e8a94a0e5b4956093f71 \n",
472 | "2 33d4fdfe3967a1ffde9311bfe6827ef9 \n",
473 | "3 4bda690abec912b3b7b228b01fb6819a "
474 | ]
475 | },
476 | "execution_count": 6,
477 | "metadata": {},
478 | "output_type": "execute_result"
479 | }
480 | ],
481 | "source": [
482 | "edges.rename(columns={'paper_id_x': 'paper_id', 'node_idx_x':'src_nid', 'node_idx_y':'dst_nid'}, inplace=True)\n",
483 | "edges = edges[['src_nid', 'dst_nid', 'paper_id', 'reference_paper_id']]\n",
484 | "edges.head(4)"
485 | ]
486 | },
487 | {
488 | "cell_type": "markdown",
489 | "metadata": {},
490 | "source": [
491 | "## 构建DGL的Graph"
492 | ]
493 | },
494 | {
495 | "cell_type": "code",
496 | "execution_count": 7,
497 | "metadata": {},
498 | "outputs": [],
499 | "source": [
500 | "# 讲源节点和目标节点转换成Numpy的NDArray\n",
501 | "src_nid = edges.src_nid.to_numpy()\n",
502 | "dst_nid = edges.dst_nid.to_numpy()"
503 | ]
504 | },
505 | {
506 | "cell_type": "code",
507 | "execution_count": 8,
508 | "metadata": {},
509 | "outputs": [
510 | {
511 | "name": "stdout",
512 | "output_type": "stream",
513 | "text": [
514 | "Graph(num_nodes=3655452, num_edges=29168650,\n",
515 | " ndata_schemes={}\n",
516 | " edata_schemes={})\n"
517 | ]
518 | }
519 | ],
520 | "source": [
521 | "# 构建一个DGL的graph\n",
522 | "graph = dgl.graph((src_nid, dst_nid))\n",
523 | "print(graph)"
524 | ]
525 | },
526 | {
527 | "cell_type": "code",
528 | "execution_count": 10,
529 | "metadata": {},
530 | "outputs": [],
531 | "source": [
532 | "# 保存Graph为二进制格式方便后面建模时的快速读取\n",
533 | "graph_path = os.path.join(base_path, publish_path, 'graph.bin')\n",
534 | "dgl.data.utils.save_graphs(graph_path, [graph])"
535 | ]
536 | },
537 | {
538 | "cell_type": "code",
539 | "execution_count": null,
540 | "metadata": {},
541 | "outputs": [],
542 | "source": []
543 | }
544 | ],
545 | "metadata": {
546 | "kernelspec": {
547 | "display_name": "Python [conda env:dgl]",
548 | "language": "python",
549 | "name": "conda-env-dgl-py"
550 | },
551 | "language_info": {
552 | "codemirror_mode": {
553 | "name": "ipython",
554 | "version": 3
555 | },
556 | "file_extension": ".py",
557 | "mimetype": "text/x-python",
558 | "name": "python",
559 | "nbconvert_exporter": "python",
560 | "pygments_lexer": "ipython3",
561 | "version": "3.6.10"
562 | }
563 | },
564 | "nbformat": 4,
565 | "nbformat_minor": 2
566 | }
567 |
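As a quick sanity check, the saved binary graph can be reloaded with DGL's `load_graphs`; a small round-trip sketch using the paths defined in this notebook (the expected counts come from the cell output above):

```python
import os
import dgl

graph_path = os.path.join(base_path, publish_path, 'graph.bin')
# save_graphs stored a list containing a single graph; load_graphs returns (graphs, label_dict)
graphs, _ = dgl.data.utils.load_graphs(graph_path)
graph = graphs[0]
print(graph.num_nodes(), graph.num_edges())  # expected: 3655452 29168650
```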
--------------------------------------------------------------------------------
/MAXP 2021初赛数据探索和处理-4.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# MAXP 2021初赛数据探索和处理-4\n",
8 | "\n",
9 | "把原始数据的标签转换成数字形式,并完成Train/Validation/Test的分割。这里的划分是用于比赛模型训练和模型选择用的,并不是原始的文件名。"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 28,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "import pandas as pd\n",
19 | "import numpy as np\n",
20 | "import os\n",
21 | "import pickle\n",
22 | "\n",
23 | "import dgl"
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": 2,
29 | "metadata": {},
30 | "outputs": [],
31 | "source": [
32 | "# path\n",
33 | "base_path = '/Users/jamezhan/PycharmProjects/MAXP/final_dataset'\n",
34 | "publish_path = 'publish'\n",
35 | "\n",
36 | "nodes_path = os.path.join(base_path, publish_path, 'IDandLabels.csv')"
37 | ]
38 | },
39 | {
40 | "cell_type": "markdown",
41 | "metadata": {},
42 | "source": [
43 | "### 读取节点列表"
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": 3,
49 | "metadata": {},
50 | "outputs": [
51 | {
52 | "name": "stdout",
53 | "output_type": "stream",
54 | "text": [
55 | "(3655452, 4)\n"
56 | ]
57 | },
58 | {
59 | "data": {
60 | "text/html": [
61 | "\n",
62 | "\n",
75 | "
\n",
76 | " \n",
77 | " \n",
78 | " | \n",
79 | " node_idx | \n",
80 | " paper_id | \n",
81 | " Label | \n",
82 | " Split_ID | \n",
83 | "
\n",
84 | " \n",
85 | " \n",
86 | " \n",
87 | " 3655448 | \n",
88 | " 3655448 | \n",
89 | " caed47d55d1e193ecb1fa97a415c13dd | \n",
90 | " NaN | \n",
91 | " 1 | \n",
92 | "
\n",
93 | " \n",
94 | " 3655449 | \n",
95 | " 3655449 | \n",
96 | " c82eb6be79a245392fb626b9a7e1f246 | \n",
97 | " NaN | \n",
98 | " 1 | \n",
99 | "
\n",
100 | " \n",
101 | " 3655450 | \n",
102 | " 3655450 | \n",
103 | " 926a31f6b378575204aae30b5dfa6dd3 | \n",
104 | " NaN | \n",
105 | " 1 | \n",
106 | "
\n",
107 | " \n",
108 | " 3655451 | \n",
109 | " 3655451 | \n",
110 | " bbace2419c3f827158ea4602f3eb35fa | \n",
111 | " NaN | \n",
112 | " 1 | \n",
113 | "
\n",
114 | " \n",
115 | "
\n",
116 | "
"
117 | ],
118 | "text/plain": [
119 | " node_idx paper_id Label Split_ID\n",
120 | "3655448 3655448 caed47d55d1e193ecb1fa97a415c13dd NaN 1\n",
121 | "3655449 3655449 c82eb6be79a245392fb626b9a7e1f246 NaN 1\n",
122 | "3655450 3655450 926a31f6b378575204aae30b5dfa6dd3 NaN 1\n",
123 | "3655451 3655451 bbace2419c3f827158ea4602f3eb35fa NaN 1"
124 | ]
125 | },
126 | "execution_count": 3,
127 | "metadata": {},
128 | "output_type": "execute_result"
129 | }
130 | ],
131 | "source": [
132 | "nodes_df = pd.read_csv(nodes_path, dtype={'Label':str})\n",
133 | "print(nodes_df.shape)\n",
134 | "nodes_df.tail(4)"
135 | ]
136 | },
137 | {
138 | "cell_type": "markdown",
139 | "metadata": {},
140 | "source": [
141 | "### 转换标签为数字"
142 | ]
143 | },
144 | {
145 | "cell_type": "code",
146 | "execution_count": 6,
147 | "metadata": {
148 | "scrolled": true
149 | },
150 | "outputs": [
151 | {
152 | "name": "stdout",
153 | "output_type": "stream",
154 | "text": [
155 | "(23, 3)\n"
156 | ]
157 | },
158 | {
159 | "data": {
160 | "text/html": [
161 | "\n",
162 | "\n",
175 | "
\n",
176 | " \n",
177 | " \n",
178 | " | \n",
179 | " node_idx | \n",
180 | " paper_id | \n",
181 | " Split_ID | \n",
182 | "
\n",
183 | " \n",
184 | " Label | \n",
185 | " | \n",
186 | " | \n",
187 | " | \n",
188 | "
\n",
189 | " \n",
190 | " \n",
191 | " \n",
192 | " A | \n",
193 | " 2670 | \n",
194 | " 2670 | \n",
195 | " 2670 | \n",
196 | "
\n",
197 | " \n",
198 | " B | \n",
199 | " 65303 | \n",
200 | " 65303 | \n",
201 | " 65303 | \n",
202 | "
\n",
203 | " \n",
204 | " C | \n",
205 | " 111502 | \n",
206 | " 111502 | \n",
207 | " 111502 | \n",
208 | "
\n",
209 | " \n",
210 | " D | \n",
211 | " 104005 | \n",
212 | " 104005 | \n",
213 | " 104005 | \n",
214 | "
\n",
215 | " \n",
216 | " E | \n",
217 | " 45014 | \n",
218 | " 45014 | \n",
219 | " 45014 | \n",
220 | "
\n",
221 | " \n",
222 | " F | \n",
223 | " 32876 | \n",
224 | " 32876 | \n",
225 | " 32876 | \n",
226 | "
\n",
227 | " \n",
228 | " G | \n",
229 | " 43452 | \n",
230 | " 43452 | \n",
231 | " 43452 | \n",
232 | "
\n",
233 | " \n",
234 | " H | \n",
235 | " 71824 | \n",
236 | " 71824 | \n",
237 | " 71824 | \n",
238 | "
\n",
239 | " \n",
240 | " I | \n",
241 | " 23994 | \n",
242 | " 23994 | \n",
243 | " 23994 | \n",
244 | "
\n",
245 | " \n",
246 | " J | \n",
247 | " 25241 | \n",
248 | " 25241 | \n",
249 | " 25241 | \n",
250 | "
\n",
251 | " \n",
252 | " K | \n",
253 | " 32762 | \n",
254 | " 32762 | \n",
255 | " 32762 | \n",
256 | "
\n",
257 | " \n",
258 | " L | \n",
259 | " 53391 | \n",
260 | " 53391 | \n",
261 | " 53391 | \n",
262 | "
\n",
263 | " \n",
264 | " M | \n",
265 | " 83971 | \n",
266 | " 83971 | \n",
267 | " 83971 | \n",
268 | "
\n",
269 | " \n",
270 | " N | \n",
271 | " 103472 | \n",
272 | " 103472 | \n",
273 | " 103472 | \n",
274 | "
\n",
275 | " \n",
276 | " O | \n",
277 | " 17593 | \n",
278 | " 17593 | \n",
279 | " 17593 | \n",
280 | "
\n",
281 | " \n",
282 | " P | \n",
283 | " 52166 | \n",
284 | " 52166 | \n",
285 | " 52166 | \n",
286 | "
\n",
287 | " \n",
288 | " Q | \n",
289 | " 19676 | \n",
290 | " 19676 | \n",
291 | " 19676 | \n",
292 | "
\n",
293 | " \n",
294 | " R | \n",
295 | " 32610 | \n",
296 | " 32610 | \n",
297 | " 32610 | \n",
298 | "
\n",
299 | " \n",
300 | " S | \n",
301 | " 24609 | \n",
302 | " 24609 | \n",
303 | " 24609 | \n",
304 | "
\n",
305 | " \n",
306 | " T | \n",
307 | " 20878 | \n",
308 | " 20878 | \n",
309 | " 20878 | \n",
310 | "
\n",
311 | " \n",
312 | " U | \n",
313 | " 24740 | \n",
314 | " 24740 | \n",
315 | " 24740 | \n",
316 | "
\n",
317 | " \n",
318 | " V | \n",
319 | " 39557 | \n",
320 | " 39557 | \n",
321 | " 39557 | \n",
322 | "
\n",
323 | " \n",
324 | " W | \n",
325 | " 13111 | \n",
326 | " 13111 | \n",
327 | " 13111 | \n",
328 | "
\n",
329 | " \n",
330 | "
\n",
331 | "
"
332 | ],
333 | "text/plain": [
334 | " node_idx paper_id Split_ID\n",
335 | "Label \n",
336 | "A 2670 2670 2670\n",
337 | "B 65303 65303 65303\n",
338 | "C 111502 111502 111502\n",
339 | "D 104005 104005 104005\n",
340 | "E 45014 45014 45014\n",
341 | "F 32876 32876 32876\n",
342 | "G 43452 43452 43452\n",
343 | "H 71824 71824 71824\n",
344 | "I 23994 23994 23994\n",
345 | "J 25241 25241 25241\n",
346 | "K 32762 32762 32762\n",
347 | "L 53391 53391 53391\n",
348 | "M 83971 83971 83971\n",
349 | "N 103472 103472 103472\n",
350 | "O 17593 17593 17593\n",
351 | "P 52166 52166 52166\n",
352 | "Q 19676 19676 19676\n",
353 | "R 32610 32610 32610\n",
354 | "S 24609 24609 24609\n",
355 | "T 20878 20878 20878\n",
356 | "U 24740 24740 24740\n",
357 | "V 39557 39557 39557\n",
358 | "W 13111 13111 13111"
359 | ]
360 | },
361 | "execution_count": 6,
362 | "metadata": {},
363 | "output_type": "execute_result"
364 | }
365 | ],
366 | "source": [
367 | "# 先检查一下标签的分布\n",
368 | "label_dist = nodes_df.groupby(by='Label').count()\n",
369 | "print(label_dist.shape)\n",
370 | "label_dist"
371 | ]
372 | },
373 | {
374 | "cell_type": "markdown",
375 | "metadata": {},
376 | "source": [
377 | "#### 可以看到一共有23个标签,A类最少,C类最多,基本每类都有几万个。下面从0开始,重够标签\n"
378 | ]
379 | },
380 | {
381 | "cell_type": "code",
382 | "execution_count": 16,
383 | "metadata": {},
384 | "outputs": [
385 | {
386 | "data": {
387 | "text/html": [
388 | "\n",
389 | "\n",
402 | "
\n",
403 | " \n",
404 | " \n",
405 | " | \n",
406 | " node_idx | \n",
407 | " paper_id | \n",
408 | " Label | \n",
409 | " Split_ID | \n",
410 | " label | \n",
411 | "
\n",
412 | " \n",
413 | " \n",
414 | " \n",
415 | " 0 | \n",
416 | " 0 | \n",
417 | " bfdee5ab86ef5e68da974d48a138c28e | \n",
418 | " S | \n",
419 | " 0 | \n",
420 | " 18 | \n",
421 | "
\n",
422 | " \n",
423 | " 1 | \n",
424 | " 1 | \n",
425 | " 78f43b8b62f040347fec0be44e5f08bd | \n",
426 | " NaN | \n",
427 | " 0 | \n",
428 | " -1 | \n",
429 | "
\n",
430 | " \n",
431 | " 2 | \n",
432 | " 2 | \n",
433 | " a971601a0286d2701aa5cde46e63a9fd | \n",
434 | " G | \n",
435 | " 0 | \n",
436 | " 6 | \n",
437 | "
\n",
438 | " \n",
439 | " 3 | \n",
440 | " 3 | \n",
441 | " ac4b88a72146bae66cedfd1c13e1552d | \n",
442 | " NaN | \n",
443 | " 0 | \n",
444 | " -1 | \n",
445 | "
\n",
446 | " \n",
447 | "
\n",
448 | "
"
449 | ],
450 | "text/plain": [
451 | " node_idx paper_id Label Split_ID label\n",
452 | "0 0 bfdee5ab86ef5e68da974d48a138c28e S 0 18\n",
453 | "1 1 78f43b8b62f040347fec0be44e5f08bd NaN 0 -1\n",
454 | "2 2 a971601a0286d2701aa5cde46e63a9fd G 0 6\n",
455 | "3 3 ac4b88a72146bae66cedfd1c13e1552d NaN 0 -1"
456 | ]
457 | },
458 | "execution_count": 16,
459 | "metadata": {},
460 | "output_type": "execute_result"
461 | }
462 | ],
463 | "source": [
464 | "# 按A-W的顺序,从0开始转换\n",
465 | "for i, l in enumerate(label_dist.index.to_list()):\n",
466 | " nodes_df.loc[(nodes_df.Label==l), 'label'] = i\n",
467 | "\n",
468 | "nodes_df.label.fillna(-1, inplace=True)\n",
469 | "nodes_df.label = nodes_df.label.astype('int')\n",
470 | "nodes_df.head(4)"
471 | ]
472 | },
473 | {
474 | "cell_type": "markdown",
475 | "metadata": {},
476 | "source": [
477 | "#### 只保留新的node index、标签和原始的分割标签"
478 | ]
479 | },
480 | {
481 | "cell_type": "code",
482 | "execution_count": 19,
483 | "metadata": {},
484 | "outputs": [
485 | {
486 | "data": {
487 | "text/html": [
488 | "\n",
489 | "\n",
502 | "
\n",
503 | " \n",
504 | " \n",
505 | " | \n",
506 | " node_idx | \n",
507 | " label | \n",
508 | " Split_ID | \n",
509 | "
\n",
510 | " \n",
511 | " \n",
512 | " \n",
513 | " 3655448 | \n",
514 | " 3655448 | \n",
515 | " -1 | \n",
516 | " 1 | \n",
517 | "
\n",
518 | " \n",
519 | " 3655449 | \n",
520 | " 3655449 | \n",
521 | " -1 | \n",
522 | " 1 | \n",
523 | "
\n",
524 | " \n",
525 | " 3655450 | \n",
526 | " 3655450 | \n",
527 | " -1 | \n",
528 | " 1 | \n",
529 | "
\n",
530 | " \n",
531 | " 3655451 | \n",
532 | " 3655451 | \n",
533 | " -1 | \n",
534 | " 1 | \n",
535 | "
\n",
536 | " \n",
537 | "
\n",
538 | "
"
539 | ],
540 | "text/plain": [
541 | " node_idx label Split_ID\n",
542 | "3655448 3655448 -1 1\n",
543 | "3655449 3655449 -1 1\n",
544 | "3655450 3655450 -1 1\n",
545 | "3655451 3655451 -1 1"
546 | ]
547 | },
548 | "execution_count": 19,
549 | "metadata": {},
550 | "output_type": "execute_result"
551 | }
552 | ],
553 | "source": [
554 | "nodes = nodes_df[['node_idx', 'label', 'Split_ID']]\n",
555 | "nodes.tail(4)"
556 | ]
557 | },
558 | {
559 | "cell_type": "markdown",
560 | "metadata": {},
561 | "source": [
562 | "## 划分Train/Validation/Test\n",
563 | "\n",
564 | "由于只有原始的Train_nodes文件里面包括了标签,所以这里的Train/Validation是对原始的分割。\n",
565 | "\n",
566 | "这里按照9:1的比例划分Train/Validation。Test就是原来的validation_nodes里面的index。"
567 | ]
568 | },
569 | {
570 | "cell_type": "code",
571 | "execution_count": 30,
572 | "metadata": {},
573 | "outputs": [],
574 | "source": [
575 | "# 获取所有的标签\n",
576 | "tr_val_labels_df = nodes[(nodes.Split_ID == 0) & (nodes.label >= 0)]\n",
577 | "test_label_df = nodes[nodes.Split_ID == 1]\n",
578 | "\n",
579 | "# 按照0~22每个标签划分train/validation\n",
580 | "tr_labels_idx = np.array([0])\n",
581 | "val_labels_idx = np.array([0])\n",
582 | "split_ratio = 0.9\n",
583 | "\n",
584 | "for label in range(23):\n",
585 | " label_idx = tr_val_labels_df[tr_val_labels_df.label == label].node_idx.to_numpy()\n",
586 | " split_point = int(label_idx.shape[0] * split_ratio)\n",
587 | " \n",
588 | " # 把每个标签的train和validation的index添加到整个列表\n",
589 | " tr_labels_idx = np.append(tr_labels_idx, label_idx[: split_point])\n",
590 | " val_labels_idx = np.append(val_labels_idx, label_idx[split_point: ])"
591 | ]
592 | },
593 | {
594 | "cell_type": "code",
595 | "execution_count": 31,
596 | "metadata": {},
597 | "outputs": [],
598 | "source": [
599 | "# 获取Train/Validation/Test标签index\n",
600 | "tr_labels_idx = tr_labels_idx[1: ]\n",
601 | "val_labels_idx = val_labels_idx[1: ]\n",
602 | "\n",
603 | "test_labels_idx = test_label_df.node_idx.to_numpy()"
604 | ]
605 | },
606 | {
607 | "cell_type": "code",
608 | "execution_count": 32,
609 | "metadata": {},
610 | "outputs": [],
611 | "source": [
612 | "# 获取完整的标签列表\n",
613 | "labels = nodes.label.to_numpy()"
614 | ]
615 | },
616 | {
617 | "cell_type": "code",
618 | "execution_count": 33,
619 | "metadata": {},
620 | "outputs": [],
621 | "source": [
622 | "# 保存标签以及Train/Validation/Test的index为二进制格式方便后面建模时的快速读取\n",
623 | "label_path = os.path.join(base_path, publish_path, 'labels.pkl')\n",
624 | "\n",
625 | "with open(label_path, 'wb') as f:\n",
626 | " pickle.dump({'tr_label_idx': tr_labels_idx, \n",
627 | " 'val_label_idx': val_labels_idx, \n",
628 | " 'test_label_idx': test_labels_idx,\n",
629 | " 'label': labels}, f)"
630 | ]
631 | },
632 | {
633 | "cell_type": "code",
634 | "execution_count": null,
635 | "metadata": {},
636 | "outputs": [],
637 | "source": []
638 | }
639 | ],
640 | "metadata": {
641 | "kernelspec": {
642 | "display_name": "Python [conda env:dgl]",
643 | "language": "python",
644 | "name": "conda-env-dgl-py"
645 | },
646 | "language_info": {
647 | "codemirror_mode": {
648 | "name": "ipython",
649 | "version": 3
650 | },
651 | "file_extension": ".py",
652 | "mimetype": "text/x-python",
653 | "name": "python",
654 | "nbconvert_exporter": "python",
655 | "pygments_lexer": "ipython3",
656 | "version": "3.6.10"
657 | }
658 | },
659 | "nbformat": 4,
660 | "nbformat_minor": 2
661 | }
662 |
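For later use in training, the pickled split can be read back in one call; a minimal sketch that simply reverses the dump above (key names match the cell that wrote the file):

```python
import os
import pickle

label_path = os.path.join(base_path, publish_path, 'labels.pkl')
with open(label_path, 'rb') as f:
    label_data = pickle.load(f)

tr_label_idx = label_data['tr_label_idx']
val_label_idx = label_data['val_label_idx']
test_label_idx = label_data['test_label_idx']
labels = label_data['label']
print(tr_label_idx.shape, val_label_idx.shape, test_label_idx.shape, labels.shape)
```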
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # MAXP Competition - DGL Graph Data Baseline Model
 2 | 
 3 | This repository provides the baseline models for the DGL graph data of the 2021 MAXP competition, as a reference for contestants learning to build GNN models with DGL.
 4 | 
 5 | The repository has two parts:
 6 | ---------------
 7 | 1. Four Jupyter Notebooks for data preprocessing.
 8 | 2. Three GNN models built with DGL (GCN, GraphSage, and GAT), together with the training code and helper functions.
9 |
 10 | Dependencies:
11 | ------
12 | - dgl==0.7.1
13 | - pytorch==1.7.0
14 | - pandas
15 | - numpy
16 | - datetime
17 |
 18 | How to run:
 19 | -------
 20 | Run the four Jupyter Notebooks in a Jupyter environment, replacing the folder that holds the competition data files with the folder where you saved your own copy of the data.
 21 | Note down where the processed data files end up; that location is needed for the model training below.
 22 | 
 23 | **Note:** Running *MAXP 2021初赛数据探索和处理-2* uses a fairly large amount of memory. It ran without problems on a Mac, but has not been tested on Windows or Linux.
 24 | If you run into memory problems in those environments, either use a machine with more memory or modify the code to process the data piece by piece.
25 |
26 | ---------
 27 | For the GNN models, first cd into the gnn directory, then run:
28 |
29 | ```bash
30 | python model_train.py --data_path path/to/processed_data --gnn_model graphsage --hidden_dim 64 --n_layers 2 --fanout 20,20 --batch_size 4096 --GPU -1 --out_path ./
31 | ```
32 |
 33 | **Note**: Replace the --data_path value with the path to the data produced by the Jupyter Notebooks. For the remaining arguments, see the argument descriptions in model_train.py.
34 |
 35 | To train the model on a single GPU, set the `--GPU` argument to that GPU's id, e.g.:
36 | ```bash
37 | --GPU 0
38 | ```
39 |
 40 | To train on multiple GPUs on a single machine, set the `--GPU` argument to the ids of the available GPUs, separated by spaces, e.g.:
41 | ```bash
42 | --GPU 0 1 2 3
43 | ```
44 |
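One way to follow the "process the data piece by piece" advice above is to stream the node files and save the features in blocks instead of building one giant in-memory list. A rough sketch under those assumptions (it reuses `process_node` from the notebooks; `chunk_size` and the output file names are hypothetical):

```python
import numpy as np

def save_features_in_chunks(path, out_prefix, chunk_size=500_000):
    """Stream a node CSV and save its 300-d features in .npy blocks."""
    buf, part = [], 0
    with open(path, 'r') as f:
        next(f)                               # skip the header row
        for line in f:
            _, feats, _ = process_node(line)  # parser from the notebooks
            buf.append(feats)
            if len(buf) == chunk_size:
                np.save('{}_{}.npy'.format(out_prefix, part), np.asarray(buf, dtype=np.float32))
                buf, part = [], part + 1
    if buf:
        np.save('{}_{}.npy'.format(out_prefix, part), np.asarray(buf, dtype=np.float32))

# Usage: save_features_in_chunks(train_nodes_path, 'train_features')
# The blocks can later be np.load-ed and np.concatenate-d (or memory-mapped).
```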
--------------------------------------------------------------------------------
/gnn/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dglai/maxp_baseline_model/00ccc6aa2cdd7fce83e6468380bc33d45a870431/gnn/__init__.py
--------------------------------------------------------------------------------
/gnn/model_train.py:
--------------------------------------------------------------------------------
1 | #-*- coding:utf-8 -*-
2 |
3 | # Author:james Zhang
4 | """
 5 | Minibatch training with node neighbor sampling on multiple GPUs
6 | """
7 |
8 | import os
9 | import argparse
10 | import datetime as dt
11 | import numpy as np
12 | import torch as th
13 | import torch.nn as thnn
14 | import torch.nn.functional as F
15 | import torch.optim as optim
16 | import torch.distributed as dist
17 |
18 | import dgl
19 | from dgl.dataloading.neighbor import MultiLayerNeighborSampler
20 | from dgl.dataloading.pytorch import NodeDataLoader
21 | import dgl.multiprocessing as mp
22 |
23 | from models import GraphSageModel, GraphConvModel, GraphAttnModel
24 | from utils import load_dgl_graph, time_diff
25 | from model_utils import early_stopper, thread_wrapped_func
26 |
27 |
28 | def load_subtensor(node_feats, labels, seeds, input_nodes, device):
29 | """
 30 |     Copies the features and labels of a set of nodes onto the GPU.
31 | """
32 | batch_inputs = node_feats[input_nodes].to(device)
33 | batch_labels = labels[seeds].to(device)
34 | return batch_inputs, batch_labels
35 |
36 |
37 | def cleanup():
38 | dist.destroy_process_group()
39 |
40 |
41 | def cpu_train(graph_data,
42 | gnn_model,
43 | hidden_dim,
44 | n_layers,
45 | n_classes,
46 | fanouts,
47 | batch_size,
48 | device,
49 | num_workers,
50 | epochs,
51 | out_path):
52 | """
 53 |     Training code that runs on the CPU.
 54 |     Because the competition data is fairly large, this code is recommended for debugging only.
 55 |     If you have GPUs, use the GPU training code below for faster training.
56 | """
57 | graph, labels, train_nid, val_nid, test_nid, node_feat = graph_data
58 |
59 | sampler = MultiLayerNeighborSampler(fanouts)
60 | train_dataloader = NodeDataLoader(graph,
61 | train_nid,
62 | sampler,
63 | batch_size=batch_size,
64 | shuffle=True,
65 | drop_last=False,
66 | num_workers=num_workers)
67 |
68 | # 2 initialize GNN model
69 | in_feat = node_feat.shape[1]
70 |
71 | if gnn_model == 'graphsage':
72 | model = GraphSageModel(in_feat, hidden_dim, n_layers, n_classes)
73 | elif gnn_model == 'graphconv':
74 | model = GraphConvModel(in_feat, hidden_dim, n_layers, n_classes,
75 | norm='both', activation=F.relu, dropout=0)
76 | elif gnn_model == 'graphattn':
77 | model = GraphAttnModel(in_feat, hidden_dim, n_layers, n_classes,
78 | heads=([5] * n_layers), activation=F.relu, feat_drop=0, attn_drop=0)
79 | else:
 80 |         raise NotImplementedError('So far, only three algorithms are supported: GraphSage, GraphConv, and GraphAttn')
81 |
82 | model = model.to(device)
83 |
84 | # 3 define loss function and optimizer
85 | loss_fn = thnn.CrossEntropyLoss().to(device)
86 | optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
87 |
88 | # 4 train epoch
89 | avg = 0
90 | iter_tput = []
91 | start_t = dt.datetime.now()
92 |
93 | print('Start training at: {}-{} {}:{}:{}'.format(start_t.month,
94 | start_t.day,
95 | start_t.hour,
96 | start_t.minute,
97 | start_t.second))
98 |
99 | for epoch in range(epochs):
100 |
101 | for step, (input_nodes, seeds, mfgs) in enumerate(train_dataloader):
102 |
103 | start_t = dt.datetime.now()
104 |
105 | batch_inputs, batch_labels = load_subtensor(node_feat, labels, seeds, input_nodes, device)
106 | mfgs = [mfg.to(device) for mfg in mfgs]
107 |
108 | batch_logit = model(mfgs, batch_inputs)
109 | loss = loss_fn(batch_logit, batch_labels)
110 | pred = th.sum(th.argmax(batch_logit, dim=1) == batch_labels) / th.tensor(batch_labels.shape[0])
111 |
112 | optimizer.zero_grad()
113 | loss.backward()
114 | optimizer.step()
115 |
116 | e_t1 = dt.datetime.now()
117 | h, m, s = time_diff(e_t1, start_t)
118 |
119 | print('In epoch:{:03d}|batch:{}, loss:{:4f}, acc:{:4f}, time:{}h{}m{}s'.format(epoch,
120 | step,
121 | loss,
122 | pred.detach(),
123 | h, m, s))
124 |
125 |     # 5 Save the model
126 |     # omitted here
127 |
128 |
129 | def gpu_train(proc_id, n_gpus, GPUS,
130 | graph_data, gnn_model,
131 | hidden_dim, n_layers, n_classes, fanouts,
132 | batch_size=32, num_workers=4, epochs=100, message_queue=None,
133 | output_folder='./output'):
134 |
135 | device_id = GPUS[proc_id]
136 | device = th.device('cuda:{}'.format(device_id))
137 |
138 | print('Use GPU {} for training ......'.format(device_id))
139 |
140 | if n_gpus > 1:
141 | dist_init_method = 'tcp://{}:{}'.format('127.0.0.1', '23456')
142 | world_size = n_gpus
143 | dist.init_process_group(backend='nccl',
144 | init_method=dist_init_method,
145 | world_size=world_size,
146 | rank=proc_id)
147 |
148 | th.cuda.set_device(device_id)
149 |
150 | # ------------------- 1. Prepare data and split for multiple GPUs ------------------- #
151 | start_t = dt.datetime.now()
152 | print('Start graph building at: {}-{} {}:{}:{}'.format(start_t.month,
153 | start_t.day,
154 | start_t.hour,
155 | start_t.minute,
156 | start_t.second))
157 |
158 | graph, labels, train_nid, val_nid, test_nid, node_feat = graph_data
159 |
160 | train_div, _ = divmod(train_nid.shape[0], n_gpus)
161 | val_div, _ = divmod(val_nid.shape[0], n_gpus)
162 |
163 |     # the last process takes its slice plus any remainder (with a single GPU this is the whole index set)
164 | if proc_id == (n_gpus - 1):
165 | train_nid_per_gpu = train_nid[proc_id * train_div: ]
166 | val_nid_per_gpu = val_nid[proc_id * val_div: ]
167 |     # the remaining processes each get an equal-sized slice of the training/validation indices
168 | else:
169 | train_nid_per_gpu = train_nid[proc_id * train_div: (proc_id + 1) * train_div]
170 | val_nid_per_gpu = val_nid[proc_id * val_div: (proc_id + 1) * val_div]
171 |
172 | train_sampler = MultiLayerNeighborSampler(fanouts)
173 | train_dataloader = NodeDataLoader(graph,
174 | train_nid_per_gpu,
175 | train_sampler,
176 | device=device,
177 | use_ddp=n_gpus > 1,
178 | batch_size=batch_size,
179 | shuffle=True,
180 | drop_last=False,
181 | num_workers=num_workers,
182 | )
183 | val_sampler = MultiLayerNeighborSampler(fanouts)
184 | val_dataloader = NodeDataLoader(graph,
185 | val_nid_per_gpu,
186 | val_sampler,
187 | use_ddp=n_gpus > 1,
188 | device=device,
189 | batch_size=batch_size,
190 | shuffle=True,
191 | drop_last=False,
192 | num_workers=num_workers,
193 | )
194 | e_t1 = dt.datetime.now()
195 | h, m, s = time_diff(e_t1, start_t)
196 |     print('Data loaders built in: {:02d}h {:02d}m {:02d}s'.format(h, m, s))
197 |
198 | # ------------------- 2. Build model for multiple GPUs ------------------------------ #
199 | start_t = dt.datetime.now()
200 | print('Start Model building at: {}-{} {}:{}:{}'.format(start_t.month,
201 | start_t.day,
202 | start_t.hour,
203 | start_t.minute,
204 | start_t.second))
205 |
206 | in_feat = node_feat.shape[1]
207 | if gnn_model == 'graphsage':
208 | model = GraphSageModel(in_feat, hidden_dim, n_layers, n_classes)
209 | elif gnn_model == 'graphconv':
210 | model = GraphConvModel(in_feat, hidden_dim, n_layers, n_classes,
211 | norm='both', activation=F.relu, dropout=0)
212 | elif gnn_model == 'graphattn':
213 | model = GraphAttnModel(in_feat, hidden_dim, n_layers, n_classes,
214 | heads=([5] * n_layers), activation=F.relu, feat_drop=0, attn_drop=0)
215 | else:
216 |         raise NotImplementedError('So far, only three algorithms are supported: GraphSage, GraphConv, and GraphAttn')
217 |
218 | model = model.to(device_id)
219 |
220 | if n_gpus > 1:
221 | model = thnn.parallel.DistributedDataParallel(model,
222 | device_ids=[device_id],
223 | output_device=device_id)
224 | e_t1 = dt.datetime.now()
225 | h, m, s = time_diff(e_t1, start_t)
226 |     print('Model built in: {:02d}h {:02d}m {:02d}s'.format(h, m, s))
227 |
228 | # ------------------- 3. Build loss function and optimizer -------------------------- #
229 | loss_fn = thnn.CrossEntropyLoss().to(device_id)
230 | optimizer = optim.Adam(model.parameters(), lr=0.004, weight_decay=5e-4)
231 |
232 | earlystoper = early_stopper(patience=2, verbose=False)
233 |
234 | # ------------------- 4. Train model ----------------------------------------------- #
235 |     print('Plan to train {} epochs \n'.format(epochs))
236 |
237 | for epoch in range(epochs):
238 | if n_gpus > 1:
239 | train_dataloader.set_epoch(epoch)
240 | val_dataloader.set_epoch(epoch)
241 |
242 | # mini-batch for training
243 | train_loss_list = []
244 | # train_acc_list = []
245 | model.train()
246 | for step, (input_nodes, seeds, blocks) in enumerate(train_dataloader):
247 | # forward
248 | batch_inputs, batch_labels = load_subtensor(node_feat, labels, seeds, input_nodes, device_id)
249 | blocks = [block.to(device_id) for block in blocks]
250 | # metric and loss
251 | train_batch_logits = model(blocks, batch_inputs)
252 | train_loss = loss_fn(train_batch_logits, batch_labels)
253 | # backward
254 | optimizer.zero_grad()
255 | train_loss.backward()
256 | optimizer.step()
257 |
258 | train_loss_list.append(train_loss.cpu().detach().numpy())
259 | tr_batch_pred = th.sum(th.argmax(train_batch_logits, dim=1) == batch_labels) / th.tensor(batch_labels.shape[0])
260 |
261 | if step % 10 == 0:
262 | print('In epoch:{:03d}|batch:{:04d}, train_loss:{:4f}, train_acc:{:.4f}'.format(epoch,
263 | step,
264 | np.mean(train_loss_list),
265 | tr_batch_pred.detach()))
266 |
267 | # mini-batch for validation
268 | val_loss_list = []
269 | val_acc_list = []
270 | model.eval()
271 | for step, (input_nodes, seeds, blocks) in enumerate(val_dataloader):
272 | # forward
273 | batch_inputs, batch_labels = load_subtensor(node_feat, labels, seeds, input_nodes, device_id)
274 | blocks = [block.to(device_id) for block in blocks]
275 | # metric and loss
276 | val_batch_logits = model(blocks, batch_inputs)
277 | val_loss = loss_fn(val_batch_logits, batch_labels)
278 |
279 | val_loss_list.append(val_loss.detach().cpu().numpy())
280 | val_batch_pred = th.sum(th.argmax(val_batch_logits, dim=1) == batch_labels) / th.tensor(batch_labels.shape[0])
281 |
282 | if step % 10 == 0:
283 | print('In epoch:{:03d}|batch:{:04d}, val_loss:{:4f}, val_acc:{:.4f}'.format(epoch,
284 | step,
285 | np.mean(val_loss_list),
286 | val_batch_pred.detach()))
287 |
288 | # put validation results into message queue and aggregate at device 0
289 |         if n_gpus > 1 and message_queue is not None:
290 | message_queue.put(val_loss_list)
291 |
292 | if proc_id == 0:
293 | for i in range(n_gpus):
294 | loss = message_queue.get()
295 | print(loss)
296 | del loss
297 | else:
298 | print(val_loss_list)
299 |
300 | # -------------------------5. Collect stats ------------------------------------#
301 | # best_preds = earlystoper.val_preds
302 | # best_logits = earlystoper.val_logits
303 | #
304 | # best_precision, best_recall, best_f1 = get_f1_score(val_y.cpu().numpy(), best_preds)
305 | # best_auc = get_auc_score(val_y.cpu().numpy(), best_logits[:, 1])
306 | # best_recall_at_99precision = recall_at_perc_precision(val_y.cpu().numpy(), best_logits[:, 1], threshold=0.99)
307 | # best_recall_at_90precision = recall_at_perc_precision(val_y.cpu().numpy(), best_logits[:, 1], threshold=0.9)
308 |
309 | # plot_roc(val_y.cpu().numpy(), best_logits[:, 1])
310 | # plot_p_r_curve(val_y.cpu().numpy(), best_logits[:, 1])
311 |
312 | # -------------------------6. Save models --------------------------------------#
313 | model_path = os.path.join(output_folder, 'dgl_model-' + '{:06d}'.format(np.random.randint(100000)) + '.pth')
314 |
315 | if n_gpus > 1:
316 | if proc_id == 0:
317 | model_para_dict = model.state_dict()
318 | th.save(model_para_dict, model_path)
319 |             # after training, remember to clean up and release resources
320 | cleanup()
321 | else:
322 | model_para_dict = model.state_dict()
323 | th.save(model_para_dict, model_path)
324 |
325 |
326 | if __name__ == '__main__':
327 | parser = argparse.ArgumentParser(description='DGL_SamplingTrain')
328 | parser.add_argument('--data_path', type=str, help="Path of saved processed data files.")
329 | parser.add_argument('--gnn_model', type=str, choices=['graphsage', 'graphconv', 'graphattn'],
330 | required=True, default='graphsage')
331 | parser.add_argument('--hidden_dim', type=int, required=True)
332 | parser.add_argument('--n_layers', type=int, default=2)
333 | parser.add_argument("--fanout", type=str, required=True, help="fanout numbers", default='20,20')
334 | parser.add_argument('--batch_size', type=int, required=True, default=1)
335 | parser.add_argument('--GPU', nargs='+', type=int, required=True)
336 | parser.add_argument('--num_workers_per_gpu', type=int, default=4)
337 | parser.add_argument('--epochs', type=int, default=100)
338 | parser.add_argument('--out_path', type=str, required=True, help="Absolute path for saving model parameters")
339 | args = parser.parse_args()
340 |
341 | # parse arguments
342 | BASE_PATH = args.data_path
343 | MODEL_CHOICE = args.gnn_model
344 | HID_DIM = args.hidden_dim
345 | N_LAYERS = args.n_layers
346 | FANOUTS = [int(i) for i in args.fanout.split(',')]
347 | BATCH_SIZE = args.batch_size
348 | GPUS = args.GPU
349 | WORKERS = args.num_workers_per_gpu
350 | EPOCHS = args.epochs
351 | OUT_PATH = args.out_path
352 |
353 | # output arguments for logging
354 | print('Data path: {}'.format(BASE_PATH))
355 | print('Used algorithm: {}'.format(MODEL_CHOICE))
356 | print('Hidden dimensions: {}'.format(HID_DIM))
357 | print('number of hidden layers: {}'.format(N_LAYERS))
358 | print('Fanout list: {}'.format(FANOUTS))
359 | print('Batch size: {}'.format(BATCH_SIZE))
360 | print('GPU list: {}'.format(GPUS))
361 | print('Number of workers per GPU: {}'.format(WORKERS))
362 | print('Max number of epochs: {}'.format(EPOCHS))
363 | print('Output path: {}'.format(OUT_PATH))
364 |
365 | # Retrieve preprocessed data and add reverse edge and self-loop
366 | graph, labels, train_nid, val_nid, test_nid, node_feat = load_dgl_graph(BASE_PATH)
367 | graph = dgl.to_bidirected(graph, copy_ndata=True)
368 | graph = dgl.add_self_loop(graph)
369 |
370 | graph.create_formats_()
371 |
372 | # call train with CPU, one GPU, or multiple GPUs
373 | if GPUS[0] < 0:
374 | cpu_device = th.device('cpu')
375 | cpu_train(graph_data=(graph, labels, train_nid, val_nid, test_nid, node_feat),
376 | gnn_model=MODEL_CHOICE,
377 | n_layers=N_LAYERS,
378 | hidden_dim=HID_DIM,
379 | n_classes=23,
380 | fanouts=FANOUTS,
381 | batch_size=BATCH_SIZE,
382 | num_workers=WORKERS,
383 | device=cpu_device,
384 | epochs=EPOCHS,
385 | out_path=OUT_PATH)
386 | else:
387 | n_gpus = len(GPUS)
388 |
389 | if n_gpus == 1:
390 | gpu_train(0, n_gpus, GPUS,
391 | graph_data=(graph, labels, train_nid, val_nid, test_nid, node_feat),
392 | gnn_model=MODEL_CHOICE, hidden_dim=HID_DIM, n_layers=N_LAYERS, n_classes=23,
393 | fanouts=FANOUTS, batch_size=BATCH_SIZE, num_workers=WORKERS, epochs=EPOCHS,
394 | message_queue=None, output_folder=OUT_PATH)
395 | else:
396 | message_queue = mp.Queue()
397 | procs = []
398 | for proc_id in range(n_gpus):
399 | p = mp.Process(target=gpu_train,
400 | args=(proc_id, n_gpus, GPUS,
401 | (graph, labels, train_nid, val_nid, test_nid, node_feat),
402 | MODEL_CHOICE, HID_DIM, N_LAYERS, 23,
403 | FANOUTS, BATCH_SIZE, WORKERS, EPOCHS,
404 | message_queue, OUT_PATH))
405 | p.start()
406 | procs.append(p)
407 | for p in procs:
408 | p.join()
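The script saves only the model weights (state_dict) at the end of training. Below is a minimal sketch, not part of the baseline, for reloading such a checkpoint for inference; the feature dimension and file name are placeholders and must match what was used at training time. Note that a checkpoint written by the DistributedDataParallel-wrapped model stores its parameter keys under a `module.` prefix, which has to be stripped before loading into an unwrapped model.
```python
import torch as th
from models import GraphSageModel

# Placeholders -- these must match the values used when the checkpoint was trained.
IN_FEAT, HID_DIM, N_LAYERS, N_CLASSES = 300, 64, 2, 23

model = GraphSageModel(IN_FEAT, HID_DIM, N_LAYERS, N_CLASSES)
state_dict = th.load('dgl_model-012345.pth', map_location='cpu')
# DDP checkpoints prefix every key with 'module.'; strip it for a bare model.
state_dict = {k[len('module.'):] if k.startswith('module.') else k: v
              for k, v in state_dict.items()}
model.load_state_dict(state_dict)
model.eval()
```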
--------------------------------------------------------------------------------
/gnn/model_utils.py:
--------------------------------------------------------------------------------
1 | #-*- coding:utf-8 -*-
2 |
3 | # Author:james Zhang
4 | """
5 | utilities file for Pytorch models
6 | """
7 |
8 | from functools import wraps
9 | import traceback
10 | from _thread import start_new_thread
11 | import torch.multiprocessing as mp
12 |
13 |
14 | class early_stopper(object):
15 |
16 | def __init__(self, patience=10, verbose=False, delta=0):
17 |
18 | self.patience = patience
19 | self.verbose = verbose
20 | self.delta = delta
21 |
22 | self.best_value = None
23 | self.is_earlystop = False
24 | self.count = 0
25 | self.val_preds = []
26 | self.val_logits = []
27 |
28 | def earlystop(self, loss, preds, logits):
29 |
30 | value = -loss
31 |
32 | if self.best_value is None:
33 | self.best_value = value
34 | self.val_preds = preds
35 | self.val_logits = logits
36 | elif value < self.best_value + self.delta:
37 | self.count += 1
38 | if self.verbose:
39 |                 print('EarlyStopper count: {:02d}'.format(self.count))
40 | if self.count >= self.patience:
41 | self.is_earlystop = True
42 | else:
43 | self.best_value = value
44 | self.val_preds = preds
45 | self.val_logits = logits
46 | self.count = 0
47 |
48 |
49 | # According to https://github.com/pytorch/pytorch/issues/17199, this decorator
50 | # is necessary to make fork() and openmp work together.
51 | def thread_wrapped_func(func):
52 | """
53 |     Wrapper needed for PyTorch with OpenMP. Wraps a process entry point to make it work with OpenMP.
54 | """
55 | @wraps(func)
56 | def decorated_function(*args, **kwargs):
57 | queue = mp.Queue()
58 | def _queue_result():
59 | exception, trace, res = None, None, None
60 | try:
61 | res = func(*args, **kwargs)
62 | except Exception as e:
63 | exception = e
64 | trace = traceback.format_exc()
65 | queue.put((res, exception, trace))
66 |
67 | start_new_thread(_queue_result, ())
68 | result, exception, trace = queue.get()
69 | if exception is None:
70 | return result
71 | else:
72 | assert isinstance(exception, Exception)
73 | raise exception.__class__(trace)
74 | return decorated_function
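model_train.py instantiates an early_stopper but the baseline loop never actually calls earlystop(); the following is a minimal sketch of the intended call pattern, with made-up loss values purely to show when stopping would trigger.
```python
from model_utils import early_stopper

stopper = early_stopper(patience=2, verbose=True)

for epoch in range(100):
    # Fake validation loss: improves until epoch 5, then degrades.
    val_loss = 0.5 + abs(5 - epoch) * 0.1
    # In model_train.py these would be the validation predictions and logits.
    stopper.earlystop(val_loss, preds=[], logits=[])
    if stopper.is_earlystop:
        print('Early stopping triggered at epoch {}'.format(epoch))
        break
```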
--------------------------------------------------------------------------------
/gnn/models.py:
--------------------------------------------------------------------------------
1 | #-*- coding:utf-8 -*-
2 |
3 | # Author:james Zhang
4 |
5 | """
6 | Three common GNN models.
7 | """
8 |
9 | import torch.nn as thnn
10 | import torch.nn.functional as F
11 | import dgl.nn as dglnn
12 |
13 |
14 | class GraphSageModel(thnn.Module):
15 |
16 | def __init__(self,
17 | in_feats,
18 | hidden_dim,
19 | n_layers,
20 | n_classes,
21 | activation=F.relu,
22 | dropout=0):
23 | super(GraphSageModel, self).__init__()
24 | self.in_feats = in_feats
25 | self.n_layers = n_layers
26 | self.hidden_dim = hidden_dim
27 | self.n_classes = n_classes
28 | self.activation = activation
29 | self.dropout = thnn.Dropout(dropout)
30 |
31 | self.layers = thnn.ModuleList()
32 |
33 | # build multiple layers
34 | self.layers.append(dglnn.SAGEConv(in_feats=self.in_feats,
35 | out_feats=self.hidden_dim,
36 | aggregator_type='mean'))
37 | # aggregator_type = 'pool'))
38 | for l in range(1, (self.n_layers - 1)):
39 | self.layers.append(dglnn.SAGEConv(in_feats=self.hidden_dim,
40 | out_feats=self.hidden_dim,
41 | aggregator_type='mean'))
42 | # aggregator_type='pool'))
43 | self.layers.append(dglnn.SAGEConv(in_feats=self.hidden_dim,
44 | out_feats=self.n_classes,
45 | aggregator_type='mean'))
46 | # aggregator_type = 'pool'))
47 |
48 | def forward(self, blocks, features):
49 | h = features
50 |
51 | for l, (layer, block) in enumerate(zip(self.layers, blocks)):
52 | h = layer(block, h)
53 | if l != len(self.layers) - 1:
54 | h = self.activation(h)
55 | h = self.dropout(h)
56 |
57 | return h
58 |
59 |
60 | class GraphConvModel(thnn.Module):
61 |
62 | def __init__(self,
63 | in_feats,
64 | hidden_dim,
65 | n_layers,
66 | n_classes,
67 | norm,
68 | activation,
69 | dropout):
70 | super(GraphConvModel, self).__init__()
71 | self.in_feats = in_feats
72 | self.n_layers = n_layers
73 | self.hidden_dim = hidden_dim
74 | self.n_classes = n_classes
75 | self.norm = norm
76 | self.activation = activation
77 | self.dropout = thnn.Dropout(dropout)
78 |
79 | self.layers = thnn.ModuleList()
80 |
81 | # build multiple layers
82 | self.layers.append(dglnn.GraphConv(in_feats=self.in_feats,
83 | out_feats=self.hidden_dim,
84 | norm=self.norm,
85 | activation=self.activation,))
86 | for l in range(1, (self.n_layers - 1)):
87 | self.layers.append(dglnn.GraphConv(in_feats=self.hidden_dim,
88 | out_feats=self.hidden_dim,
89 | norm=self.norm,
90 | activation=self.activation))
91 | self.layers.append(dglnn.GraphConv(in_feats=self.hidden_dim,
92 | out_feats=self.n_classes,
93 | norm=self.norm,
94 | activation=self.activation))
95 |
96 | def forward(self, blocks, features):
97 | h = features
98 |
99 | for l, (layer, block) in enumerate(zip(self.layers, blocks)):
100 | h = layer(block, h)
101 | if l != len(self.layers) - 1:
102 | h = self.dropout(h)
103 |
104 | return h
105 |
106 |
107 | class GraphAttnModel(thnn.Module):
108 |
109 | def __init__(self,
110 | in_feats,
111 | hidden_dim,
112 | n_layers,
113 | n_classes,
114 | heads,
115 | activation,
116 | feat_drop,
117 | attn_drop
118 | ):
119 | super(GraphAttnModel, self).__init__()
120 | self.in_feats = in_feats
121 | self.hidden_dim = hidden_dim
122 | self.n_layers = n_layers
123 | self.n_classes = n_classes
124 | self.heads = heads
125 | self.feat_dropout = feat_drop
126 | self.attn_dropout = attn_drop
127 | self.activation = activation
128 |
129 | self.layers = thnn.ModuleList()
130 |
131 | # build multiple layers
132 | self.layers.append(dglnn.GATConv(in_feats=self.in_feats,
133 | out_feats=self.hidden_dim,
134 | num_heads=self.heads[0],
135 | feat_drop=self.feat_dropout,
136 | attn_drop=self.attn_dropout,
137 | activation=self.activation))
138 |
139 | for l in range(1, (self.n_layers - 1)):
140 | # due to multi-head, the in_dim = num_hidden * num_heads
141 | self.layers.append(dglnn.GATConv(in_feats=self.hidden_dim * self.heads[l - 1],
142 | out_feats=self.hidden_dim,
143 | num_heads=self.heads[l],
144 | feat_drop=self.feat_dropout,
145 | attn_drop=self.attn_dropout,
146 | activation=self.activation))
147 |
148 | self.layers.append(dglnn.GATConv(in_feats=self.hidden_dim * self.heads[-2],
149 | out_feats=self.n_classes,
150 | num_heads=self.heads[-1],
151 | feat_drop=self.feat_dropout,
152 | attn_drop=self.attn_dropout,
153 | activation=None))
154 |
155 | def forward(self, blocks, features):
156 | h = features
157 |
158 | for l in range(self.n_layers - 1):
159 | h = self.layers[l](blocks[l], h).flatten(1)
160 |
161 |         logits = self.layers[-1](blocks[-1], h).mean(1)
162 |
163 | return logits
164 |
165 |
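Each model's forward() expects one message-flow graph (block) per layer. As a quick smoke test (all sizes below are made up, not from the competition data), the same full graph can be passed for every layer to run a full-graph forward:
```python
import torch as th
import dgl
from models import GraphSageModel

# Tiny random graph with self-loops so every node has at least one neighbor.
g = dgl.add_self_loop(dgl.rand_graph(100, 500))
feats = th.randn(100, 16)

model = GraphSageModel(in_feats=16, hidden_dim=32, n_layers=2, n_classes=23)
logits = model([g, g], feats)   # one "block" per layer
print(logits.shape)             # expected: torch.Size([100, 23])
```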
--------------------------------------------------------------------------------
/gnn/utils.py:
--------------------------------------------------------------------------------
1 | #-*- coding:utf-8 -*-
2 |
3 | """
4 | Utilities to handle graph data
5 | """
6 |
7 | import os
8 | import dgl
9 | import pickle
10 | import numpy as np
11 | import torch as th
12 |
13 |
14 | def load_dgl_graph(base_path):
15 | """
16 |     Load the preprocessed graph, feature, and label files and build the data objects used by the training code.
17 |
18 | :param base_path:
19 | :return:
20 | """
21 | graphs, _ = dgl.load_graphs(os.path.join(base_path, 'graph.bin'))
22 | graph = graphs[0]
23 | print('################ Graph info: ###############')
24 | print(graph)
25 |
26 | with open(os.path.join(base_path, 'labels.pkl'), 'rb') as f:
27 | label_data = pickle.load(f)
28 |
29 | labels = th.from_numpy(label_data['label'])
30 | tr_label_idx = th.from_numpy(label_data['tr_label_idx']).long()
31 | val_label_idx = th.from_numpy(label_data['val_label_idx']).long()
32 | test_label_idx = th.from_numpy(label_data['test_label_idx']).long()
33 | print('################ Label info: ################')
34 | print('Total labels (including not labeled): {}'.format(labels.shape[0]))
35 | print(' Training label number: {}'.format(tr_label_idx.shape[0]))
36 | print(' Validation label number: {}'.format(val_label_idx.shape[0]))
37 | print(' Test label number: {}'.format(test_label_idx.shape[0]))
38 |
39 | # get node features
40 | features = np.load(os.path.join(base_path, 'features.npy'))
41 | node_feat = th.from_numpy(features).float()
42 | print('################ Feature info: ###############')
43 | print('Node\'s feature shape:{}'.format(node_feat.shape))
44 |
45 | return graph, labels, tr_label_idx, val_label_idx, test_label_idx, node_feat
46 |
47 |
48 | def time_diff(t_end, t_start):
49 | """
50 |     Compute the time difference. t_end and t_start are datetime objects, so the difference is a timedelta.
51 | Parameters
52 | ----------
53 | t_end
54 | t_start
55 |
56 | Returns
57 | -------
58 | """
59 | diff_sec = (t_end - t_start).seconds
60 | diff_min, rest_sec = divmod(diff_sec, 60)
61 | diff_hrs, rest_min = divmod(diff_min, 60)
62 | return (diff_hrs, rest_min, rest_sec)
63 |
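For reference, load_dgl_graph assumes the preprocessing notebooks have written three files into the --data_path directory; a short sketch of the expected layout and call (the path is a placeholder):
```python
from utils import load_dgl_graph

# Expected contents of the processed-data directory (produced by the notebooks):
#   graph.bin    -- DGL graph saved with dgl.save_graphs
#   labels.pkl   -- dict with keys 'label', 'tr_label_idx', 'val_label_idx', 'test_label_idx'
#   features.npy -- node feature matrix, one row per node
graph, labels, train_nid, val_nid, test_nid, node_feat = load_dgl_graph('/path/to/processed_data')
```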
--------------------------------------------------------------------------------