├── GPT_GNN
├── __init__.py
├── __pycache__
│ ├── conv.cpython-37.pyc
│ ├── data.cpython-37.pyc
│ ├── model.cpython-37.pyc
│ ├── utils.cpython-37.pyc
│ └── __init__.cpython-37.pyc
├── utils.py
├── conv.py
├── model.py
└── data.py
├── example_OAG
├── GPT_GNN
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── conv.cpython-37.pyc
│ │ ├── data.cpython-37.pyc
│ │ ├── model.cpython-37.pyc
│ │ ├── utils.cpython-37.pyc
│ │ └── __init__.cpython-37.pyc
│ ├── utils.py
│ ├── conv.py
│ ├── model.py
│ └── data.py
├── preprocess_OAG.py
├── finetune_OAG_PV.py
├── finetune_OAG_PF.py
└── pretrain_OAG.py
├── example_reddit
├── GPT_GNN
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── conv.cpython-37.pyc
│ │ ├── data.cpython-37.pyc
│ │ ├── model.cpython-37.pyc
│ │ ├── utils.cpython-37.pyc
│ │ └── __init__.cpython-37.pyc
│ ├── utils.py
│ ├── conv.py
│ ├── model.py
│ └── data.py
├── preprocess_reddit.py
├── finetune_reddit.py
├── .ipynb_checkpoints
│ └── pretrain_reddit-checkpoint.py
└── pretrain_reddit.py
├── images
├── gpt-intro.png
└── pretrain_OAG.gif
├── requirements.txt
├── LICENSE
└── README.md
/GPT_GNN/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/example_OAG/GPT_GNN/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/example_reddit/GPT_GNN/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/images/gpt-intro.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acbull/GPT-GNN/HEAD/images/gpt-intro.png
--------------------------------------------------------------------------------
/images/pretrain_OAG.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acbull/GPT-GNN/HEAD/images/pretrain_OAG.gif
--------------------------------------------------------------------------------
/GPT_GNN/__pycache__/conv.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acbull/GPT-GNN/HEAD/GPT_GNN/__pycache__/conv.cpython-37.pyc
--------------------------------------------------------------------------------
/GPT_GNN/__pycache__/data.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acbull/GPT-GNN/HEAD/GPT_GNN/__pycache__/data.cpython-37.pyc
--------------------------------------------------------------------------------
/GPT_GNN/__pycache__/model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acbull/GPT-GNN/HEAD/GPT_GNN/__pycache__/model.cpython-37.pyc
--------------------------------------------------------------------------------
/GPT_GNN/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acbull/GPT-GNN/HEAD/GPT_GNN/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/GPT_GNN/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acbull/GPT-GNN/HEAD/GPT_GNN/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/example_OAG/GPT_GNN/__pycache__/conv.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acbull/GPT-GNN/HEAD/example_OAG/GPT_GNN/__pycache__/conv.cpython-37.pyc
--------------------------------------------------------------------------------
/example_OAG/GPT_GNN/__pycache__/data.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acbull/GPT-GNN/HEAD/example_OAG/GPT_GNN/__pycache__/data.cpython-37.pyc
--------------------------------------------------------------------------------
/example_OAG/GPT_GNN/__pycache__/model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acbull/GPT-GNN/HEAD/example_OAG/GPT_GNN/__pycache__/model.cpython-37.pyc
--------------------------------------------------------------------------------
/example_OAG/GPT_GNN/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acbull/GPT-GNN/HEAD/example_OAG/GPT_GNN/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/example_reddit/GPT_GNN/__pycache__/conv.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acbull/GPT-GNN/HEAD/example_reddit/GPT_GNN/__pycache__/conv.cpython-37.pyc
--------------------------------------------------------------------------------
/example_reddit/GPT_GNN/__pycache__/data.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acbull/GPT-GNN/HEAD/example_reddit/GPT_GNN/__pycache__/data.cpython-37.pyc
--------------------------------------------------------------------------------
/example_OAG/GPT_GNN/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acbull/GPT-GNN/HEAD/example_OAG/GPT_GNN/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/example_reddit/GPT_GNN/__pycache__/model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acbull/GPT-GNN/HEAD/example_reddit/GPT_GNN/__pycache__/model.cpython-37.pyc
--------------------------------------------------------------------------------
/example_reddit/GPT_GNN/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acbull/GPT-GNN/HEAD/example_reddit/GPT_GNN/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/example_reddit/GPT_GNN/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acbull/GPT-GNN/HEAD/example_reddit/GPT_GNN/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | dill==0.3.0
2 | numpy==1.22.0
3 | pandas==0.24.2
4 | torch==1.3.0
5 | torch-scatter==1.3.2
6 | torch-cluster==1.4.5
7 | torch-sparse==0.4.3
8 | torch-spline-conv==1.1.1
9 | torch-geometric==1.3.2
10 | torchvision==0.4.1
11 | tqdm==4.31.1
12 | seaborn==0.9.0
13 | matplotlib==3.0.3
14 | transformers==4.30.0
15 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License
2 |
3 | Copyright (c) 2020 acbull
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in
13 | all copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21 | THE SOFTWARE.
22 |
--------------------------------------------------------------------------------
/example_reddit/preprocess_reddit.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.datasets import Reddit
2 | from GPT_GNN.data import *
3 |
dataset = Reddit(root='/datadrive/dataset')
graph_reddit = Graph()
# Adjacency as el[target_id][source_id] = time; every edge below is given
# the constant time 1.
# NOTE(review): the innermost factory returns the ``int`` type itself rather
# than 0; every entry here is assigned explicitly, so this script never
# relies on it.
el = defaultdict(           # target_id
    lambda: defaultdict(    # source_id
        lambda: int         # time
    ))
for i, j in tqdm(dataset.data.edge_index.t()):
    el[i.item()][j.item()] = 1

target_type = 'def'
graph_reddit.edge_list['def']['def']['def'] = el

# Append log-degree as one extra feature column.  The degree vector is sized
# by the number of feature rows (not by the largest node id that has edges)
# so the concatenation cannot fail when trailing nodes have no edges, and the
# degree is clamped at 1 so such nodes get log-degree 0 instead of -inf.
x_raw = dataset.data.x.numpy()
degree = np.zeros(x_raw.shape[0])
for i, neighbours in el.items():
    degree[i] = len(neighbours)
x = np.concatenate((x_raw, np.log(np.maximum(degree, 1)).reshape(-1, 1)), axis=-1)
graph_reddit.node_feature['def'] = pd.DataFrame({'emb': list(x)})

# Random (seeded) 70% / 10% / 10% / 10% split into pre-training, train,
# validation and test target nodes.
idx = np.arange(len(graph_reddit.node_feature[target_type]))
np.random.seed(43)
np.random.shuffle(idx)

graph_reddit.pre_target_nodes = idx[ : int(len(idx) * 0.7)]
graph_reddit.train_target_nodes = idx[int(len(idx) * 0.7) : int(len(idx) * 0.8)]
graph_reddit.valid_target_nodes = idx[int(len(idx) * 0.8) : int(len(idx) * 0.9)]
graph_reddit.test_target_nodes = idx[int(len(idx) * 0.9) : ]

graph_reddit.y = dataset.data.y
# Context manager so the pickle file is flushed and closed deterministically.
with open('/datadrive/dataset/graph_reddit.pk', 'wb') as f:
    dill.dump(graph_reddit, f)
33 |
--------------------------------------------------------------------------------
/GPT_GNN/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.sparse as sp
3 | import torch
4 |
5 |
def dcg_at_k(r, k):
    """Return the discounted cumulative gain of the first ``k`` relevance scores.

    The first item is not discounted; the item at 0-based position ``i >= 1``
    is divided by ``log2(i + 1)``.

    Args:
        r: Relevance scores in ranked order (anything convertible to floats).
        k: Number of leading positions to include.

    Returns:
        The DCG, or ``0.`` when ``r`` is empty.
    """
    # np.asfarray was deprecated in NumPy 1.25 and removed in 2.0;
    # asarray(..., dtype=float) performs the same conversion.
    r = np.asarray(r, dtype=float)[:k]
    if r.size:
        return r[0] + np.sum(r[1:] / np.log2(np.arange(2, r.size + 1)))
    return 0.
11 |
def ndcg_at_k(r, k):
    """Return the normalised DCG@k of the ranking ``r``.

    The DCG@k of ``r`` is divided by the DCG@k of ``r`` sorted in descending
    order (the ideal ranking).  Returns ``0.`` when the ideal DCG is zero.
    """
    ideal_dcg = dcg_at_k(sorted(r, reverse=True), k)
    return dcg_at_k(r, k) / ideal_dcg if ideal_dcg else 0.
17 |
18 |
def mean_reciprocal_rank(rs):
    """Return the reciprocal rank of the first relevant item of each ranking.

    For every ranking in ``rs`` the result holds ``1 / (position + 1)`` of its
    first non-zero entry, or ``0.`` when it has none.  Despite the name, the
    values are returned as a list and are not averaged here.
    """
    reciprocal_ranks = []
    for ranking in rs:
        hit_positions = np.asarray(ranking).nonzero()[0]
        if hit_positions.size:
            reciprocal_ranks.append(1. / (hit_positions[0] + 1))
        else:
            reciprocal_ranks.append(0.)
    return reciprocal_ranks
22 |
23 |
def normalize(mx):
    """Row-normalize a sparse matrix so every non-empty row sums to 1.

    Rows whose sum is zero stay all-zero instead of becoming inf/NaN.

    Args:
        mx: A scipy sparse matrix (integer or floating point).

    Returns:
        ``diag(1 / rowsum) @ mx``.
    """
    rowsum = np.array(mx.sum(1))
    # Use -1.0, not -1: NumPy refuses integer arrays raised to negative
    # integer powers, while a Python float exponent keeps float32 as float32.
    # Zero rows would emit a divide-by-zero warning; their inverse is zeroed
    # explicitly below.
    with np.errstate(divide='ignore'):
        r_inv = np.power(rowsum, -1.0).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    return r_mat_inv.dot(mx)
32 |
33 |
def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a float32 torch sparse COO tensor."""
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(
        np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
    values = torch.from_numpy(sparse_mx.data)
    shape = torch.Size(sparse_mx.shape)
    # torch.sparse.FloatTensor is a deprecated legacy constructor;
    # sparse_coo_tensor infers float32 from ``values``.
    return torch.sparse_coo_tensor(indices, values, shape)
42 |
def randint():
    """Draw a random integer in ``[0, 2**32 - 1)`` from NumPy's global RNG."""
    # Request int64 explicitly: the default integer dtype is the C ``long``,
    # which is 32-bit on Windows with NumPy < 2, where this upper bound
    # raises "high is out of bounds for int32".  int() keeps the return type
    # a plain Python int, as with the default dtype.
    return int(np.random.randint(2**32 - 1, dtype=np.int64))
45 |
def feature_OAG(layer_data, graph):
    """Assemble per-type input features for a sampled OAG subgraph.

    Args:
        layer_data: Mapping ``node_type -> {node_id: record}``; ``record[1]``
            is used as the node's time.  NOTE(review): the rest of the record
            layout is not visible here -- confirm against the sampler.
        graph: Object whose ``node_feature[node_type]`` is a DataFrame indexed
            by node id with columns ``emb`` and ``citation``, optionally
            ``node_emb``, and (for ``'paper'``) ``title``.

    Returns:
        ``(feature, times, indxs, attr)``. The first three are dicts keyed by
        node type holding, for every non-empty type, the feature matrix
        ``[node_emb | emb | log10(citation + 0.01)]``, the node times and
        the node ids. ``attr`` is the array of paper titles, or ``None`` when
        no paper nodes were sampled.
    """
    feature = {}
    times = {}
    indxs = {}
    attr = None  # avoids UnboundLocalError when no 'paper' nodes are present
    for _type, nodes in layer_data.items():
        if len(nodes) == 0:
            continue
        df = graph.node_feature[_type]
        idxs = np.array(list(nodes.keys()))
        tims = np.array(list(nodes.values()))[:, 1]

        # Types without a 'node_emb' column get a 400-wide zero block,
        # presumably matching the width of 'node_emb' -- TODO confirm.
        # (np.float / np.str were removed in NumPy 1.24; use the builtins.)
        if 'node_emb' in df:
            node_emb = np.array(list(df.loc[idxs, 'node_emb']), dtype=float)
        else:
            node_emb = np.zeros([len(idxs), 400])
        citation = np.array(list(df.loc[idxs, 'citation'])).reshape(-1, 1)
        feature[_type] = np.concatenate(
            (node_emb, list(df.loc[idxs, 'emb']), np.log10(citation + 0.01)),
            axis=1)

        times[_type] = tims
        indxs[_type] = idxs

        if _type == 'paper':
            attr = np.array(list(df.loc[idxs, 'title']), dtype=str)
    return feature, times, indxs, attr
70 |
def feature_reddit(layer_data, graph):
    """Assemble per-type input features for a sampled Reddit subgraph.

    Args:
        layer_data: Mapping ``node_type -> {node_id: record}``; ``record[1]``
            is used as the node's time.  NOTE(review): confirm the record
            layout against the sampler.
        graph: Object whose ``node_feature[node_type]`` is a DataFrame indexed
            by node id with an ``emb`` column.

    Returns:
        ``(feature, times, indxs, attr)``. The first three are dicts keyed by
        node type holding the float feature matrix, the times and the node ids
        of each non-empty type. ``attr`` is the feature matrix of the ``'def'``
        type (the same object, not a copy), or ``None`` if it is absent.
    """
    feature = {}
    times = {}
    indxs = {}
    attr = None  # avoids UnboundLocalError when no 'def' nodes are present
    for _type, nodes in layer_data.items():
        if len(nodes) == 0:
            continue
        idxs = np.array(list(nodes.keys()))
        tims = np.array(list(nodes.values()))[:, 1]

        # np.float was removed in NumPy 1.24; the builtin float is equivalent.
        feature[_type] = np.array(list(graph.node_feature[_type].loc[idxs, 'emb']), dtype=float)
        times[_type] = tims
        indxs[_type] = idxs

        if _type == 'def':
            attr = feature[_type]
    return feature, times, indxs, attr
--------------------------------------------------------------------------------
/example_OAG/GPT_GNN/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.sparse as sp
3 | import torch
4 | from texttable import Texttable
5 | from collections import OrderedDict
6 |
def args_print(args):
    """Print every attribute of ``args`` as a two-column text table.

    The first row is the literal ``["Parameter", "Value"]`` pair, followed by
    one row per entry of ``vars(args)`` in iteration order.
    """
    table = Texttable()
    rows = [["Parameter", "Value"]]
    rows.extend([name, value] for name, value in vars(args).items())
    for row in rows:
        table.add_row(row)
    print(table.draw())
14 |
def dcg_at_k(r, k):
    """Return the discounted cumulative gain of the first ``k`` relevance scores.

    The first item is not discounted; the item at 0-based position ``i >= 1``
    is divided by ``log2(i + 1)``.

    Args:
        r: Relevance scores in ranked order (anything convertible to floats).
        k: Number of leading positions to include.

    Returns:
        The DCG, or ``0.`` when ``r`` is empty.
    """
    # np.asfarray was deprecated in NumPy 1.25 and removed in 2.0;
    # asarray(..., dtype=float) performs the same conversion.
    r = np.asarray(r, dtype=float)[:k]
    if r.size:
        return r[0] + np.sum(r[1:] / np.log2(np.arange(2, r.size + 1)))
    return 0.
20 |
def ndcg_at_k(r, k):
    """Return the normalised DCG@k of the ranking ``r``.

    The DCG@k of ``r`` is divided by the DCG@k of ``r`` sorted in descending
    order (the ideal ranking).  Returns ``0.`` when the ideal DCG is zero.
    """
    ideal_dcg = dcg_at_k(sorted(r, reverse=True), k)
    return dcg_at_k(r, k) / ideal_dcg if ideal_dcg else 0.
26 |
27 |
def mean_reciprocal_rank(rs):
    """Return the reciprocal rank of the first relevant item of each ranking.

    For every ranking in ``rs`` the result holds ``1 / (position + 1)`` of its
    first non-zero entry, or ``0.`` when it has none.  Despite the name, the
    values are returned as a list and are not averaged here.
    """
    reciprocal_ranks = []
    for ranking in rs:
        hit_positions = np.asarray(ranking).nonzero()[0]
        if hit_positions.size:
            reciprocal_ranks.append(1. / (hit_positions[0] + 1))
        else:
            reciprocal_ranks.append(0.)
    return reciprocal_ranks
31 |
32 |
def normalize(mx):
    """Row-normalize a sparse matrix so every non-empty row sums to 1.

    Rows whose sum is zero stay all-zero instead of becoming inf/NaN.

    Args:
        mx: A scipy sparse matrix (integer or floating point).

    Returns:
        ``diag(1 / rowsum) @ mx``.
    """
    rowsum = np.array(mx.sum(1))
    # Use -1.0, not -1: NumPy refuses integer arrays raised to negative
    # integer powers, while a Python float exponent keeps float32 as float32.
    # Zero rows would emit a divide-by-zero warning; their inverse is zeroed
    # explicitly below.
    with np.errstate(divide='ignore'):
        r_inv = np.power(rowsum, -1.0).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    return r_mat_inv.dot(mx)
41 |
42 |
def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a float32 torch sparse COO tensor."""
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(
        np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
    values = torch.from_numpy(sparse_mx.data)
    shape = torch.Size(sparse_mx.shape)
    # torch.sparse.FloatTensor is a deprecated legacy constructor;
    # sparse_coo_tensor infers float32 from ``values``.
    return torch.sparse_coo_tensor(indices, values, shape)
51 |
def randint():
    """Draw a random integer in ``[0, 2**32 - 1)`` from NumPy's global RNG."""
    # Request int64 explicitly: the default integer dtype is the C ``long``,
    # which is 32-bit on Windows with NumPy < 2, where this upper bound
    # raises "high is out of bounds for int32".  int() keeps the return type
    # a plain Python int, as with the default dtype.
    return int(np.random.randint(2**32 - 1, dtype=np.int64))
54 |
def feature_OAG(layer_data, graph):
    """Assemble per-type input features for a sampled OAG subgraph.

    Args:
        layer_data: Mapping ``node_type -> {node_id: record}``; ``record[1]``
            is used as the node's time.  NOTE(review): the rest of the record
            layout is not visible here -- confirm against the sampler.
        graph: Object whose ``node_feature[node_type]`` is a DataFrame indexed
            by node id with columns ``emb`` and ``citation``, optionally
            ``node_emb``, and (for ``'paper'``) ``title``.

    Returns:
        ``(feature, times, indxs, attr)``. The first three are dicts keyed by
        node type holding, for every non-empty type, the feature matrix
        ``[node_emb | emb | log10(citation + 0.01)]``, the node times and
        the node ids. ``attr`` is the array of paper titles, or ``None`` when
        no paper nodes were sampled.
    """
    feature = {}
    times = {}
    indxs = {}
    attr = None  # avoids UnboundLocalError when no 'paper' nodes are present
    for _type, nodes in layer_data.items():
        if len(nodes) == 0:
            continue
        df = graph.node_feature[_type]
        idxs = np.array(list(nodes.keys()))
        tims = np.array(list(nodes.values()))[:, 1]

        # Types without a 'node_emb' column get a 400-wide zero block,
        # presumably matching the width of 'node_emb' -- TODO confirm.
        # (np.float / np.str were removed in NumPy 1.24; use the builtins.)
        if 'node_emb' in df:
            node_emb = np.array(list(df.loc[idxs, 'node_emb']), dtype=float)
        else:
            node_emb = np.zeros([len(idxs), 400])
        citation = np.array(list(df.loc[idxs, 'citation'])).reshape(-1, 1)
        feature[_type] = np.concatenate(
            (node_emb, list(df.loc[idxs, 'emb']), np.log10(citation + 0.01)),
            axis=1)

        times[_type] = tims
        indxs[_type] = idxs

        if _type == 'paper':
            attr = np.array(list(df.loc[idxs, 'title']), dtype=str)
    return feature, times, indxs, attr
79 |
def feature_reddit(layer_data, graph):
    """Assemble per-type input features for a sampled Reddit subgraph.

    Args:
        layer_data: Mapping ``node_type -> {node_id: record}``; ``record[1]``
            is used as the node's time.  NOTE(review): confirm the record
            layout against the sampler.
        graph: Object whose ``node_feature[node_type]`` is a DataFrame indexed
            by node id with an ``emb`` column.

    Returns:
        ``(feature, times, indxs, attr)``. The first three are dicts keyed by
        node type holding the float feature matrix, the times and the node ids
        of each non-empty type. ``attr`` is the feature matrix of the ``'def'``
        type (the same object, not a copy), or ``None`` if it is absent.
    """
    feature = {}
    times = {}
    indxs = {}
    attr = None  # avoids UnboundLocalError when no 'def' nodes are present
    for _type, nodes in layer_data.items():
        if len(nodes) == 0:
            continue
        idxs = np.array(list(nodes.keys()))
        tims = np.array(list(nodes.values()))[:, 1]

        # np.float was removed in NumPy 1.24; the builtin float is equivalent.
        feature[_type] = np.array(list(graph.node_feature[_type].loc[idxs, 'emb']), dtype=float)
        times[_type] = tims
        indxs[_type] = idxs

        if _type == 'def':
            attr = feature[_type]
    return feature, times, indxs, attr
98 |
def load_gnn(_dict):
    """Extract the GNN sub-module's entries from a full model state dict.

    Keeps only keys that start with ``'gnn.'`` and strips that prefix,
    presumably so the result can be loaded into the bare GNN module.

    Args:
        _dict: A state-dict-like mapping from parameter names to values.

    Returns:
        An ``OrderedDict`` of the matching entries, in their original order.
    """
    prefix = 'gnn.'
    out_dict = OrderedDict()
    for key in _dict:
        # Match the prefix exactly: a substring test would also pick up (and
        # mis-slice) keys that merely contain "gnn" somewhere.
        if key.startswith(prefix):
            out_dict[key[len(prefix):]] = _dict[key]
    return out_dict
--------------------------------------------------------------------------------
/example_reddit/GPT_GNN/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.sparse as sp
3 | import torch
4 | from texttable import Texttable
5 |
def args_print(args):
    """Print every attribute of ``args`` as a two-column text table.

    The first row is the literal ``["Parameter", "Value"]`` pair, followed by
    one row per entry of ``vars(args)`` in iteration order.
    """
    table = Texttable()
    rows = [["Parameter", "Value"]]
    rows.extend([name, value] for name, value in vars(args).items())
    for row in rows:
        table.add_row(row)
    print(table.draw())
13 |
def dcg_at_k(r, k):
    """Return the discounted cumulative gain of the first ``k`` relevance scores.

    The first item is not discounted; the item at 0-based position ``i >= 1``
    is divided by ``log2(i + 1)``.

    Args:
        r: Relevance scores in ranked order (anything convertible to floats).
        k: Number of leading positions to include.

    Returns:
        The DCG, or ``0.`` when ``r`` is empty.
    """
    # np.asfarray was deprecated in NumPy 1.25 and removed in 2.0;
    # asarray(..., dtype=float) performs the same conversion.
    r = np.asarray(r, dtype=float)[:k]
    if r.size:
        return r[0] + np.sum(r[1:] / np.log2(np.arange(2, r.size + 1)))
    return 0.
19 |
def ndcg_at_k(r, k):
    """Return the normalised DCG@k of the ranking ``r``.

    The DCG@k of ``r`` is divided by the DCG@k of ``r`` sorted in descending
    order (the ideal ranking).  Returns ``0.`` when the ideal DCG is zero.
    """
    ideal_dcg = dcg_at_k(sorted(r, reverse=True), k)
    return dcg_at_k(r, k) / ideal_dcg if ideal_dcg else 0.
25 |
26 |
def mean_reciprocal_rank(rs):
    """Return the reciprocal rank of the first relevant item of each ranking.

    For every ranking in ``rs`` the result holds ``1 / (position + 1)`` of its
    first non-zero entry, or ``0.`` when it has none.  Despite the name, the
    values are returned as a list and are not averaged here.
    """
    reciprocal_ranks = []
    for ranking in rs:
        hit_positions = np.asarray(ranking).nonzero()[0]
        if hit_positions.size:
            reciprocal_ranks.append(1. / (hit_positions[0] + 1))
        else:
            reciprocal_ranks.append(0.)
    return reciprocal_ranks
30 |
31 |
def normalize(mx):
    """Row-normalize a sparse matrix so every non-empty row sums to 1.

    Rows whose sum is zero stay all-zero instead of becoming inf/NaN.

    Args:
        mx: A scipy sparse matrix (integer or floating point).

    Returns:
        ``diag(1 / rowsum) @ mx``.
    """
    rowsum = np.array(mx.sum(1))
    # Use -1.0, not -1: NumPy refuses integer arrays raised to negative
    # integer powers, while a Python float exponent keeps float32 as float32.
    # Zero rows would emit a divide-by-zero warning; their inverse is zeroed
    # explicitly below.
    with np.errstate(divide='ignore'):
        r_inv = np.power(rowsum, -1.0).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    return r_mat_inv.dot(mx)
40 |
41 |
def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a float32 torch sparse COO tensor."""
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(
        np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
    values = torch.from_numpy(sparse_mx.data)
    shape = torch.Size(sparse_mx.shape)
    # torch.sparse.FloatTensor is a deprecated legacy constructor;
    # sparse_coo_tensor infers float32 from ``values``.
    return torch.sparse_coo_tensor(indices, values, shape)
50 |
def randint():
    """Draw a random integer in ``[0, 2**32 - 1)`` from NumPy's global RNG."""
    # Request int64 explicitly: the default integer dtype is the C ``long``,
    # which is 32-bit on Windows with NumPy < 2, where this upper bound
    # raises "high is out of bounds for int32".  int() keeps the return type
    # a plain Python int, as with the default dtype.
    return int(np.random.randint(2**32 - 1, dtype=np.int64))
53 |
def feature_OAG(layer_data, graph):
    """Assemble per-type input features for a sampled OAG subgraph.

    Args:
        layer_data: Mapping ``node_type -> {node_id: record}``; ``record[1]``
            is used as the node's time.  NOTE(review): the rest of the record
            layout is not visible here -- confirm against the sampler.
        graph: Object whose ``node_feature[node_type]`` is a DataFrame indexed
            by node id with columns ``emb`` and ``citation``, optionally
            ``node_emb``, and (for ``'paper'``) ``title``.

    Returns:
        ``(feature, times, indxs, attr)``. The first three are dicts keyed by
        node type holding, for every non-empty type, the feature matrix
        ``[node_emb | emb | log10(citation + 0.01)]``, the node times and
        the node ids. ``attr`` is the array of paper titles, or ``None`` when
        no paper nodes were sampled.
    """
    feature = {}
    times = {}
    indxs = {}
    attr = None  # avoids UnboundLocalError when no 'paper' nodes are present
    for _type, nodes in layer_data.items():
        if len(nodes) == 0:
            continue
        df = graph.node_feature[_type]
        idxs = np.array(list(nodes.keys()))
        tims = np.array(list(nodes.values()))[:, 1]

        # Types without a 'node_emb' column get a 400-wide zero block,
        # presumably matching the width of 'node_emb' -- TODO confirm.
        # (np.float / np.str were removed in NumPy 1.24; use the builtins.)
        if 'node_emb' in df:
            node_emb = np.array(list(df.loc[idxs, 'node_emb']), dtype=float)
        else:
            node_emb = np.zeros([len(idxs), 400])
        citation = np.array(list(df.loc[idxs, 'citation'])).reshape(-1, 1)
        feature[_type] = np.concatenate(
            (node_emb, list(df.loc[idxs, 'emb']), np.log10(citation + 0.01)),
            axis=1)

        times[_type] = tims
        indxs[_type] = idxs

        if _type == 'paper':
            attr = np.array(list(df.loc[idxs, 'title']), dtype=str)
    return feature, times, indxs, attr
78 |
def feature_reddit(layer_data, graph):
    """Assemble per-type input features for a sampled Reddit subgraph.

    Args:
        layer_data: Mapping ``node_type -> {node_id: record}``; ``record[1]``
            is used as the node's time.  NOTE(review): confirm the record
            layout against the sampler.
        graph: Object whose ``node_feature[node_type]`` is a DataFrame indexed
            by node id with an ``emb`` column.

    Returns:
        ``(feature, times, indxs, attr)``. The first three are dicts keyed by
        node type holding the float feature matrix, the times and the node ids
        of each non-empty type. ``attr`` is the feature matrix of the ``'def'``
        type (the same object, not a copy), or ``None`` if it is absent.
    """
    feature = {}
    times = {}
    indxs = {}
    attr = None  # avoids UnboundLocalError when no 'def' nodes are present
    for _type, nodes in layer_data.items():
        if len(nodes) == 0:
            continue
        idxs = np.array(list(nodes.keys()))
        tims = np.array(list(nodes.values()))[:, 1]

        # np.float was removed in NumPy 1.24; the builtin float is equivalent.
        feature[_type] = np.array(list(graph.node_feature[_type].loc[idxs, 'emb']), dtype=float)
        times[_type] = tims
        indxs[_type] = idxs

        if _type == 'def':
            attr = feature[_type]
    return feature, times, indxs, attr
97 |
def load_gnn(_dict):
    """Extract the GNN sub-module's entries from a full model state dict.

    Keeps only keys that start with ``'gnn.'`` and strips that prefix,
    presumably so the result can be loaded into the bare GNN module.

    Args:
        _dict: A state-dict-like mapping from parameter names to values.

    Returns:
        An ``OrderedDict`` of the matching entries, in their original order.
    """
    # OrderedDict is not imported at the top of this module, so import it
    # here (previously every call raised NameError).
    from collections import OrderedDict
    prefix = 'gnn.'
    out_dict = OrderedDict()
    for key in _dict:
        # Match the prefix exactly: a substring test would also pick up (and
        # mis-slice) keys that merely contain "gnn" somewhere.
        if key.startswith(prefix):
            out_dict[key[len(prefix):]] = _dict[key]
    return out_dict
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GPT-GNN: Generative Pre-Training of Graph Neural Networks
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 | GPT-GNN is a pre-training framework that initializes GNNs through generative pre-training. It can be applied to large-scale and heterogeneous graphs.
11 |
12 | You can see our KDD 2020 paper [“**Generative Pre-Training of Graph Neural Networks**”](https://arxiv.org/abs/2006.15437) for more details.
13 |
14 |
15 | ## Overview
16 | The key package is GPT_GNN, which contains the high-level GPT-GNN pre-training framework, the base GNN models, and the base graph structure and data loader.
17 |
18 | To illustrate how to apply the GPT_GNN framework to arbitrary graphs, we provide examples of pre-training on both a heterogeneous graph (OAG) and a homogeneous graph (Reddit). Both graphs are large-scale.
19 |
20 | Within each `example_*` package, there is a `pretrain_*.py` file for pre-training a GNN on the given graph, and also multiple `finetune_*.py` files for training and validating on downstream tasks.
21 |
22 | ## DataSet
23 | For **Open Academic Graph (OAG)**, we provide a heterogeneous graph containing highly-cited CS papers (8.1G) spanning from 1900-2020. You can download the preprocessed graph via [this link](https://drive.google.com/open?id=1a85skqsMBwnJ151QpurLFSa9o2ymc_rq). We split the data by their time: Pre-training ( t < 2014 ); Training ( 2014 <= t < 2017); Validation ( t = 2017 ); Testing ( 2018 <= t ). As we use the raw-text as attribute generation task for OAG, we provide a pre-trained word2vec model via [this link](https://drive.google.com/file/d/1ArdaMlPKVqdRGyiw4YSdUOV6CeFb2AmD/view?usp=sharing).
24 |
25 | If you want to directly process from raw data, you can download via [this link](https://drive.google.com/open?id=1yDdVaartOCOSsQlUZs8cJcAUhmvRiBSz). After downloading it, run `preprocess_OAG.py` to extract features and store them in our data structure.
26 |
27 | For **Reddit**, we simply download the preprocessed graph using pyG.datasets API, and then turn it into our own data structure using `preprocess_reddit.py`. We randomly split the data into different sets.
28 |
29 | ## Setup
30 |
31 | This implementation is based on pytorch_geometric. To run the code, you need the following dependencies:
32 |
33 | - [Pytorch 1.3.0](https://pytorch.org/)
34 | - [pytorch_geometric 1.3.2](https://pytorch-geometric.readthedocs.io/)
35 | - torch-cluster==1.4.5
36 | - torch-scatter==1.3.2
37 | - torch-sparse==0.4.3
38 | - [gensim](https://github.com/RaRe-Technologies/gensim)
39 | - [sklearn](https://github.com/scikit-learn/scikit-learn)
40 | - [tqdm](https://github.com/tqdm/tqdm)
41 | - [dill](https://github.com/uqfoundation/dill)
42 | - [pandas](https://github.com/pandas-dev/pandas)
43 |
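44 | You can run ```pip install -r requirements.txt``` to install the pinned packages. Note that gensim, scikit-learn, and texttable (imported by the example packages' `utils.py`) are not listed in `requirements.txt` and must be installed separately.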
45 |
46 | ## Usage
47 | We first introduce the arguments that control hyperparameters. There are three main types of arguments: for pre-training, for the dataset, and for the model and optimization.
48 |
49 | For pre-training, we provide arguments to control different modules for attribute and edge generation tasks:
50 | ```
51 | --attr_ratio FLOAT The ratio (0~1) of attribute generation loss . Default is 0.5.
52 | --attr_type STR type of attribute decoder ['text' or 'vec'] Default is 'vec'
53 | --neg_samp_num INT Maximum number of negative samples for each target node.
54 | --queue_size INT Max size of adaptive embedding queue. Default is 256.
55 | ```
56 |
57 | For datasets, we provide arguments to control mini-batch sampling:
58 | ```
59 | --data_dir STR The address of preprocessed graph.
60 | --pretrain_model_dir STR The address for storing the pre-trained models.
61 | --sample_depth INT How many layers within a mini-batch subgraph Default is 6.
62 | --sample_width INT How many nodes to be sampled per layer per type Default is 128.
63 | ```
64 |
65 | For both pre-training and fine-tuning, we provide arguments to control model and optimizer hyperparameters. We highlight some key arguments below:
66 |
67 | ```
68 | --conv_name STR Name of GNN filter (model) Default is hgt.
69 | --scheduler STR Name of learning rate scheduler Default is cycle (for pretrain) and cosine (for fine-tuning)
70 | --n_hid INT Number of hidden dimension Default is 400.
71 | --n_layers INT Number of GNN layers Default is 3.
72 | --prev_norm BOOL Whether to use layer-norm on previous layers. Default is False.
73 | --last_norm BOOL Whether to use layer-norm on the last layer. Default is False.
74 | --max_lr FLOAT Maximum learning rate. Default is 1e-3 (for pretrain) and 5e-4 (for fine-tuning).
75 | ```
76 |
77 | The following command pre-trains a 3-layer HGT over OAG-CS:
78 | ```bash
79 | python pretrain_OAG.py --attr_type text --conv_name hgt --n_layers 3 --pretrain_model_dir /datadrive/models/gta_all_cs3
80 | ```
81 |
82 |
83 |
84 |
85 |
86 | The following command uses the pre-trained model as initialization and fine-tunes it on the paper-field classification task using 10% of the training and validation data:
87 | ```bash
88 | python finetune_OAG_PF.py --use_pretrain --pretrain_model_dir /datadrive/models/gta_all_cs3 --n_layers 3 --data_percentage 0.1
89 | ```
90 |
91 |
92 | ## Pre-trained Models
93 |
94 | 1. The 3-layer HGT model pre-trained over OAG-CS under Time-Transfer Setting via [this link](https://drive.google.com/file/d/1OyIRfpNXjaD0TiRF-_Upfl5hix3is5ca/view?usp=sharing)
95 | 2. The 3-layer HGT model pre-trained over Reddit via [this link](https://drive.google.com/file/d/1Ja4PJT2bkFH0qgoWXjGBjByIFPco4h-S/view?usp=sharing)
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 | ### Citation
115 |
116 | Please consider citing the following paper when using our code for your application.
117 |
118 | ```bibtex
119 | @inproceedings{gpt_gnn,
120 | title={GPT-GNN: Generative Pre-Training of Graph Neural Networks},
121 | author={Ziniu Hu and Yuxiao Dong and Kuansan Wang and Kai-Wei Chang and Yizhou Sun},
122 | booktitle={Proceedings of the 26th ACM SIGKDD Conference on Knowledge Discovery and Data Mining},
123 | year={2020}
124 | }
125 | ```
126 |
127 |
128 | This implementation is mainly based on [pyHGT](https://github.com/acbull/pyHGT) API.
129 |
--------------------------------------------------------------------------------
/example_reddit/GPT_GNN/conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | from torch_geometric.nn import GCNConv, GATConv
6 | from torch_geometric.nn.conv import MessagePassing
7 | from torch_geometric.nn.inits import glorot, uniform
8 | from torch_geometric.utils import softmax
9 | import math
10 |
class HGTConv(MessagePassing):
    """Heterogeneous Graph Transformer convolution layer.

    For every edge (j: source, i: target) the layer:

    1. scores the target's query against the source's relation-transformed
       key (Heterogeneous Mutual Attention),
    2. builds a relation-transformed message from the source's value
       (Heterogeneous Message Passing),
    3. softmax-normalises the scores over edges sharing the same target and
       sums the weighted messages ('add' aggregation), then applies a
       type-specific output projection with a learnable gated skip
       connection (Target-specific Aggregation).

    Args:
        in_dim: Input feature size.  NOTE(review): ``update`` adds
            ``node_inp`` to an ``out_dim``-wide tensor, so this appears to
            require ``in_dim == out_dim``.
        out_dim: Output feature size, split into ``n_heads`` heads of size
            ``out_dim // n_heads``.
        num_types: Number of node types.
        num_relations: Number of relation (edge) types.
        n_heads: Number of attention heads.
        dropout: Dropout probability applied to the layer output.
        use_norm: Apply a per-type LayerNorm after the skip connection.
        use_RTE: Add the relative temporal encoding of ``edge_time`` to source
            representations before computing keys and values.
    """
    def __init__(self, in_dim, out_dim, num_types, num_relations, n_heads, dropout = 0.2, use_norm = True, use_RTE = True, **kwargs):
        super(HGTConv, self).__init__(node_dim=0, aggr='add', **kwargs)

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.num_types = num_types
        self.num_relations = num_relations
        self.total_rel = num_types * num_relations * num_types
        self.n_heads = n_heads
        self.d_k = out_dim // n_heads
        self.sqrt_dk = math.sqrt(self.d_k)
        self.use_norm = use_norm
        # Previously accepted but ignored: the encoding was always applied.
        self.use_RTE = use_RTE
        self.att = None  # last softmaxed attention, kept for visualisation

        # One key/query/value/output projection (and optional LayerNorm) per node type.
        self.k_linears = nn.ModuleList()
        self.q_linears = nn.ModuleList()
        self.v_linears = nn.ModuleList()
        self.a_linears = nn.ModuleList()
        self.norms = nn.ModuleList()

        for t in range(num_types):
            self.k_linears.append(nn.Linear(in_dim, out_dim))
            self.q_linears.append(nn.Linear(in_dim, out_dim))
            self.v_linears.append(nn.Linear(in_dim, out_dim))
            self.a_linears.append(nn.Linear(out_dim, out_dim))
            if use_norm:
                self.norms.append(nn.LayerNorm(out_dim))
        # TODO: make relation_pri smaller, as not all pairs exist in the meta relation list.
        self.relation_pri = nn.Parameter(torch.ones(num_relations, self.n_heads))
        self.relation_att = nn.Parameter(torch.Tensor(num_relations, n_heads, self.d_k, self.d_k))
        self.relation_msg = nn.Parameter(torch.Tensor(num_relations, n_heads, self.d_k, self.d_k))
        self.skip = nn.Parameter(torch.ones(num_types))
        self.drop = nn.Dropout(dropout)
        # Created even when use_RTE is False so state-dict keys (and saved
        # checkpoints) stay identical across both settings.
        self.emb = RelTemporalEncoding(in_dim)

        glorot(self.relation_att)
        glorot(self.relation_msg)

    def forward(self, node_inp, node_type, edge_index, edge_type, edge_time):
        return self.propagate(edge_index, node_inp=node_inp, node_type=node_type,
                              edge_type=edge_type, edge_time = edge_time)

    def message(self, edge_index_i, node_inp_i, node_inp_j, node_type_i, node_type_j, edge_type, edge_time):
        """Compute attention-weighted messages per edge (j: source, i: target)."""
        data_size = edge_index_i.size(0)
        # Attention scores and messages are filled in slice by slice, one
        # (source type, target type, relation) combination at a time.
        res_att = torch.zeros(data_size, self.n_heads, device=node_inp_i.device)
        res_msg = torch.zeros(data_size, self.n_heads, self.d_k, device=node_inp_i.device)

        for source_type in range(self.num_types):
            sb = (node_type_j == int(source_type))
            k_linear = self.k_linears[source_type]
            v_linear = self.v_linears[source_type]
            for target_type in range(self.num_types):
                tb = (node_type_i == int(target_type)) & sb
                q_linear = self.q_linears[target_type]
                for relation_type in range(self.num_relations):
                    # idx selects all edges with this meta relation.
                    idx = (edge_type == int(relation_type)) & tb
                    if idx.sum() == 0:
                        continue
                    target_node_vec = node_inp_i[idx]
                    source_node_vec = node_inp_j[idx]
                    if self.use_RTE:
                        # Add the temporal encoding to the source representation (j).
                        source_node_vec = self.emb(source_node_vec, edge_time[idx])

                    # Step 1: Heterogeneous Mutual Attention
                    q_mat = q_linear(target_node_vec).view(-1, self.n_heads, self.d_k)
                    k_mat = k_linear(source_node_vec).view(-1, self.n_heads, self.d_k)
                    k_mat = torch.bmm(k_mat.transpose(1,0), self.relation_att[relation_type]).transpose(1,0)
                    res_att[idx] = (q_mat * k_mat).sum(dim=-1) * self.relation_pri[relation_type] / self.sqrt_dk

                    # Step 2: Heterogeneous Message Passing
                    v_mat = v_linear(source_node_vec).view(-1, self.n_heads, self.d_k)
                    res_msg[idx] = torch.bmm(v_mat.transpose(1,0), self.relation_msg[relation_type]).transpose(1,0)
        # Softmax over edges sharing the same target node (edge_index_i); the
        # result is kept in self.att for later visualisation.
        self.att = softmax(res_att, edge_index_i)
        res = res_msg * self.att.view(-1, self.n_heads, 1)
        del res_att, res_msg
        return res.view(-1, self.out_dim)

    def update(self, aggr_out, node_inp, node_type):
        """Step 3: Target-specific Aggregation.

        x = W[node_type] * gelu(Agg(x)) mixed with x through the learnable
        gate sigmoid(self.skip[node_type]).
        """
        aggr_out = F.gelu(aggr_out)
        res = torch.zeros(aggr_out.size(0), self.out_dim, device=node_inp.device)
        for target_type in range(self.num_types):
            idx = (node_type == int(target_type))
            if idx.sum() == 0:
                continue
            trans_out = self.a_linears[target_type](aggr_out[idx])
            # Skip connection with learnable weight self.skip[target_type].
            alpha = torch.sigmoid(self.skip[target_type])
            if self.use_norm:
                res[idx] = self.norms[target_type](trans_out * alpha + node_inp[idx] * (1 - alpha))
            else:
                res[idx] = trans_out * alpha + node_inp[idx] * (1 - alpha)
        return self.drop(res)

    def __repr__(self):
        # The second field previously printed num_relations under the label num_types.
        return '{}(in_dim={}, out_dim={}, num_types={}, num_relations={})'.format(
            self.__class__.__name__, self.in_dim, self.out_dim,
            self.num_types, self.num_relations)
136 |
137 |
class RelTemporalEncoding(nn.Module):
    '''
    Implement the Temporal Encoding (Sinusoid) function.

    A fixed sinusoidal table of shape (max_len, 2 * n_hid) is looked up with the
    integer relative time t. Dropout and a trainable Linear(2 * n_hid -> n_hid)
    are applied to the lookup, and the result is added to x.
    t must be < max_len, because it indexes the embedding table.
    '''
    def __init__(self, n_hid, max_len = 240, dropout = 0.2):
        super(RelTemporalEncoding, self).__init__()
        self.drop = nn.Dropout(dropout)
        position = torch.arange(0., max_len).unsqueeze(1)
        # Transformer frequencies 1 / 10000^(2i / d) with d = 2 * n_hid. The exponent
        # must be divided by d *before* raising 10000 to it; computing
        # 1 / (10000 ** 2i / d) overflows float32 for 2i >= 10 and zeroes those frequencies.
        div_term = 1 / (10000 ** (torch.arange(0., n_hid * 2, 2.) / n_hid / 2))
        self.emb = nn.Embedding(max_len, n_hid * 2)
        self.emb.weight.data[:, 0::2] = torch.sin(position * div_term) / math.sqrt(n_hid)
        self.emb.weight.data[:, 1::2] = torch.cos(position * div_term) / math.sqrt(n_hid)
        # Freeze the sinusoid table. The flag must be set on the weight Parameter;
        # setting it on the Module only creates an unrelated attribute.
        self.emb.weight.requires_grad = False
        self.lin = nn.Linear(n_hid * 2, n_hid)
    def forward(self, x, t):
        return x + self.lin(self.drop(self.emb(t)))
154 |
155 |
156 |
class GeneralConv(nn.Module):
    '''
    Dispatcher over the supported graph convolution layers.

    conv_name selects the wrapped layer:
      'hgt' -> HGTConv (uses node types, edge types and edge times)
      'gcn' -> GCNConv(in_hid, out_hid)
      'gat' -> GATConv with n_heads heads of width out_hid // n_heads
    Any other name raises ValueError. Previously it left base_conv unset, and
    forward() silently returned None.
    '''
    def __init__(self, conv_name, in_hid, out_hid, num_types, num_relations, n_heads, dropout, use_norm = True, use_RTE = True):
        super(GeneralConv, self).__init__()
        self.conv_name = conv_name
        if self.conv_name == 'hgt':
            self.base_conv = HGTConv(in_hid, out_hid, num_types, num_relations, n_heads, dropout, use_norm, use_RTE)
        elif self.conv_name == 'gcn':
            self.base_conv = GCNConv(in_hid, out_hid)
        elif self.conv_name == 'gat':
            self.base_conv = GATConv(in_hid, out_hid // n_heads, heads=n_heads)
        else:
            raise ValueError('Unsupported conv_name: {!r} (expected one of: hgt, gcn, gat)'.format(conv_name))
    def forward(self, meta_xs, node_type, edge_index, edge_type, edge_time):
        if self.conv_name == 'hgt':
            return self.base_conv(meta_xs, node_type, edge_index, edge_type, edge_time)
        # 'gcn' and 'gat' only consume node features and connectivity.
        return self.base_conv(meta_xs, edge_index)
174 |
175 |
176 |
--------------------------------------------------------------------------------
/GPT_GNN/conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | from torch_geometric.nn import GCNConv, GATConv
6 | from torch_geometric.nn.conv import MessagePassing
7 | from torch_geometric.nn.inits import glorot, uniform
8 | from torch_geometric.utils import softmax
9 | import math
10 |
class HGTConv(MessagePassing):
    '''
    Heterogeneous Graph Transformer (HGT) convolution layer.

    Every edge j -> i is handled according to its (source type, target type,
    relation type) triple:
      * keys/values use source-type-specific linears (k_linears / v_linears),
        and queries use target-type-specific linears (q_linears);
      * keys and values are further transformed per head by relation-specific
        matrices (relation_att / relation_msg), and attention logits are scaled
        by a learnable relation prior (relation_pri);
      * attention is softmax-normalised over the incoming edges of each target
        node, and the weighted messages are summed (aggr='add');
      * update() applies a target-type-specific output linear with a learnable
        gated skip connection, optional LayerNorm, and dropout.
    When use_RTE is True, a relative temporal encoding of edge_time is added to
    the source representation before the key/value projections.
    '''
    def __init__(self, in_dim, out_dim, num_types, num_relations, n_heads, dropout = 0.2, use_norm = True, use_RTE = True, **kwargs):
        super(HGTConv, self).__init__(node_dim=0, aggr='add', **kwargs)

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.num_types = num_types
        self.num_relations = num_relations
        self.total_rel = num_types * num_relations * num_types
        self.n_heads = n_heads
        # Per-head width. The .view(-1, n_heads, d_k) calls in message() require
        # out_dim to be divisible by n_heads.
        self.d_k = out_dim // n_heads
        self.sqrt_dk = math.sqrt(self.d_k)
        self.use_norm = use_norm
        self.use_RTE = use_RTE
        # Attention weights from the most recent message() call (kept for visualisation).
        self.att = None

        self.k_linears = nn.ModuleList()
        self.q_linears = nn.ModuleList()
        self.v_linears = nn.ModuleList()
        self.a_linears = nn.ModuleList()
        self.norms = nn.ModuleList()

        for t in range(num_types):
            self.k_linears.append(nn.Linear(in_dim, out_dim))
            self.q_linears.append(nn.Linear(in_dim, out_dim))
            self.v_linears.append(nn.Linear(in_dim, out_dim))
            self.a_linears.append(nn.Linear(out_dim, out_dim))
            if use_norm:
                self.norms.append(nn.LayerNorm(out_dim))
        # TODO: make relation_pri smaller, as not all pair exist in meta relation list.
        self.relation_pri = nn.Parameter(torch.ones(num_relations, self.n_heads))
        self.relation_att = nn.Parameter(torch.Tensor(num_relations, n_heads, self.d_k, self.d_k))
        self.relation_msg = nn.Parameter(torch.Tensor(num_relations, n_heads, self.d_k, self.d_k))
        # Skip-gate logits per target type; update() uses sigmoid(skip[t]) as the weight of the new message.
        self.skip = nn.Parameter(torch.ones(num_types))
        self.drop = nn.Dropout(dropout)

        if self.use_RTE:
            self.emb = RelTemporalEncoding(in_dim)

        glorot(self.relation_att)
        glorot(self.relation_msg)

    def forward(self, node_inp, node_type, edge_index, edge_type, edge_time):
        return self.propagate(edge_index, node_inp=node_inp, node_type=node_type, \
                              edge_type=edge_type, edge_time = edge_time)

    def message(self, edge_index_i, node_inp_i, node_inp_j, node_type_i, node_type_j, edge_type, edge_time):
        '''
        Compute per-edge messages (j: source, i: target).

        Edges are processed in groups that share the same (source type, target
        type, relation type). Each group fills its rows of res_att (E, n_heads)
        and res_msg (E, n_heads, d_k), where E = edge_index_i.size(0).
        Returns the attention-weighted messages flattened to (E, out_dim).
        '''
        data_size = edge_index_i.size(0)
        # Attention logits and messages for every edge, filled group by group below.
        res_att = torch.zeros(data_size, self.n_heads).to(node_inp_i.device)
        res_msg = torch.zeros(data_size, self.n_heads, self.d_k).to(node_inp_i.device)

        for source_type in range(self.num_types):
            sb = (node_type_j == int(source_type))
            k_linear = self.k_linears[source_type]
            v_linear = self.v_linears[source_type]
            for target_type in range(self.num_types):
                tb = (node_type_i == int(target_type)) & sb
                q_linear = self.q_linears[target_type]
                for relation_type in range(self.num_relations):
                    # idx selects all the edges with this exact meta relation.
                    idx = (edge_type == int(relation_type)) & tb
                    if idx.sum() == 0:
                        continue
                    # Gather endpoint representations, then optionally add the
                    # temporal encoding to the source side (j).
                    target_node_vec = node_inp_i[idx]
                    source_node_vec = node_inp_j[idx]
                    if self.use_RTE:
                        source_node_vec = self.emb(source_node_vec, edge_time[idx])
                    # Step 1: Heterogeneous Mutual Attention.
                    q_mat = q_linear(target_node_vec).view(-1, self.n_heads, self.d_k)
                    k_mat = k_linear(source_node_vec).view(-1, self.n_heads, self.d_k)
                    # (heads, edges, d_k) x (heads, d_k, d_k): per-head relation transform.
                    k_mat = torch.bmm(k_mat.transpose(1,0), self.relation_att[relation_type]).transpose(1,0)
                    res_att[idx] = (q_mat * k_mat).sum(dim=-1) * self.relation_pri[relation_type] / self.sqrt_dk
                    # Step 2: Heterogeneous Message Passing.
                    v_mat = v_linear(source_node_vec).view(-1, self.n_heads, self.d_k)
                    res_msg[idx] = torch.bmm(v_mat.transpose(1,0), self.relation_msg[relation_type]).transpose(1,0)
        # Softmax over the incoming edges of each target node (grouped by edge_index_i).
        # The attention value is stored in self.att for later visualization.
        self.att = softmax(res_att, edge_index_i)
        res = res_msg * self.att.view(-1, self.n_heads, 1)
        del res_att, res_msg
        return res.view(-1, self.out_dim)

    def update(self, aggr_out, node_inp, node_type):
        '''
        Step 3: Target-specific Aggregation
        x = W[node_type] * gelu(Agg(x)) + x

        The skip connection is gated by sigmoid(self.skip[t]). Rows whose type is
        not in range(num_types) stay zero.
        '''
        aggr_out = F.gelu(aggr_out)
        res = torch.zeros(aggr_out.size(0), self.out_dim).to(node_inp.device)
        for target_type in range(self.num_types):
            idx = (node_type == int(target_type))
            if idx.sum() == 0:
                continue
            trans_out = self.a_linears[target_type](aggr_out[idx])
            # Skip connection with learnable weight self.skip[target_type].
            alpha = torch.sigmoid(self.skip[target_type])
            if self.use_norm:
                res[idx] = self.norms[target_type](trans_out * alpha + node_inp[idx] * (1 - alpha))
            else:
                res[idx] = trans_out * alpha + node_inp[idx] * (1 - alpha)
        return self.drop(res)

    def __repr__(self):
        return '{}(in_dim={}, out_dim={}, num_types={}, num_relations={})'.format(
            self.__class__.__name__, self.in_dim, self.out_dim,
            self.num_types, self.num_relations)
140 |
141 |
class RelTemporalEncoding(nn.Module):
    '''
    Implement the Temporal Encoding (Sinusoid) function.

    A fixed sinusoidal table of shape (max_len, 2 * n_hid) is looked up with the
    integer relative time t. Dropout and a trainable Linear(2 * n_hid -> n_hid)
    are applied to the lookup, and the result is added to x.
    t must be < max_len, because it indexes the embedding table.
    '''
    def __init__(self, n_hid, max_len = 240, dropout = 0.2):
        super(RelTemporalEncoding, self).__init__()
        self.drop = nn.Dropout(dropout)
        position = torch.arange(0., max_len).unsqueeze(1)
        # Transformer frequencies 1 / 10000^(2i / d) with d = 2 * n_hid. The exponent
        # must be divided by d *before* raising 10000 to it; computing
        # 1 / (10000 ** 2i / d) overflows float32 for 2i >= 10 and zeroes those frequencies.
        div_term = 1 / (10000 ** (torch.arange(0., n_hid * 2, 2.) / n_hid / 2))
        self.emb = nn.Embedding(max_len, n_hid * 2)
        self.emb.weight.data[:, 0::2] = torch.sin(position * div_term) / math.sqrt(n_hid)
        self.emb.weight.data[:, 1::2] = torch.cos(position * div_term) / math.sqrt(n_hid)
        # Freeze the sinusoid table. The flag must be set on the weight Parameter;
        # setting it on the Module only creates an unrelated attribute.
        self.emb.weight.requires_grad = False
        self.lin = nn.Linear(n_hid * 2, n_hid)
    def forward(self, x, t):
        return x + self.lin(self.drop(self.emb(t)))
158 |
159 |
160 |
class GeneralConv(nn.Module):
    '''
    Dispatcher over the supported graph convolution layers.

    conv_name selects the wrapped layer:
      'hgt' -> HGTConv (uses node types, edge types and edge times)
      'gcn' -> GCNConv(in_hid, out_hid)
      'gat' -> GATConv with n_heads heads of width out_hid // n_heads
    Any other name raises ValueError. Previously it left base_conv unset, and
    forward() silently returned None.
    '''
    def __init__(self, conv_name, in_hid, out_hid, num_types, num_relations, n_heads, dropout, use_norm = True, use_RTE = True):
        super(GeneralConv, self).__init__()
        self.conv_name = conv_name
        if self.conv_name == 'hgt':
            self.base_conv = HGTConv(in_hid, out_hid, num_types, num_relations, n_heads, dropout, use_norm, use_RTE)
        elif self.conv_name == 'gcn':
            self.base_conv = GCNConv(in_hid, out_hid)
        elif self.conv_name == 'gat':
            self.base_conv = GATConv(in_hid, out_hid // n_heads, heads=n_heads)
        else:
            raise ValueError('Unsupported conv_name: {!r} (expected one of: hgt, gcn, gat)'.format(conv_name))
    def forward(self, meta_xs, node_type, edge_index, edge_type, edge_time):
        if self.conv_name == 'hgt':
            return self.base_conv(meta_xs, node_type, edge_index, edge_type, edge_time)
        # 'gcn' and 'gat' only consume node features and connectivity.
        return self.base_conv(meta_xs, edge_index)
178 |
179 |
--------------------------------------------------------------------------------
/example_OAG/GPT_GNN/conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | from torch_geometric.nn import GCNConv, GATConv, RGCNConv
6 | from torch_geometric.nn.conv import MessagePassing
7 | from torch_geometric.nn.inits import glorot, uniform
8 | from torch_geometric.utils import softmax
9 | import math
10 |
class HGTConv(MessagePassing):
    '''
    Heterogeneous Graph Transformer (HGT) convolution layer.

    Every edge j -> i is handled according to its (source type, target type,
    relation type) triple:
      * keys/values use source-type-specific linears (k_linears / v_linears),
        and queries use target-type-specific linears (q_linears);
      * keys and values are further transformed per head by relation-specific
        matrices (relation_att / relation_msg), and attention logits are scaled
        by a learnable relation prior (relation_pri);
      * attention is softmax-normalised over the incoming edges of each target
        node, and the weighted messages are summed (aggr='add');
      * update() applies a target-type-specific output linear with a learnable
        gated skip connection, optional LayerNorm, and dropout.
    When use_RTE is True, a relative temporal encoding of edge_time is added to
    the source representation before the key/value projections. The flag used
    to be accepted but ignored, so the encoding was always applied.
    '''
    def __init__(self, in_dim, out_dim, num_types, num_relations, n_heads, dropout = 0.2, use_norm = True, use_RTE = True, **kwargs):
        super(HGTConv, self).__init__(node_dim=0, aggr='add', **kwargs)

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.num_types = num_types
        self.num_relations = num_relations
        self.total_rel = num_types * num_relations * num_types
        self.n_heads = n_heads
        # Per-head width. The .view(-1, n_heads, d_k) calls in message() require
        # out_dim to be divisible by n_heads.
        self.d_k = out_dim // n_heads
        self.sqrt_dk = math.sqrt(self.d_k)
        self.use_norm = use_norm
        self.use_RTE = use_RTE
        # Attention weights from the most recent message() call (kept for visualisation).
        self.att = None

        self.k_linears = nn.ModuleList()
        self.q_linears = nn.ModuleList()
        self.v_linears = nn.ModuleList()
        self.a_linears = nn.ModuleList()
        self.norms = nn.ModuleList()

        for t in range(num_types):
            self.k_linears.append(nn.Linear(in_dim, out_dim))
            self.q_linears.append(nn.Linear(in_dim, out_dim))
            self.v_linears.append(nn.Linear(in_dim, out_dim))
            self.a_linears.append(nn.Linear(out_dim, out_dim))
            if use_norm:
                self.norms.append(nn.LayerNorm(out_dim))
        # TODO: make relation_pri smaller, as not all pair exist in meta relation list.
        self.relation_pri = nn.Parameter(torch.ones(num_relations, self.n_heads))
        self.relation_att = nn.Parameter(torch.Tensor(num_relations, n_heads, self.d_k, self.d_k))
        self.relation_msg = nn.Parameter(torch.Tensor(num_relations, n_heads, self.d_k, self.d_k))
        # Skip-gate logits per target type; update() uses sigmoid(skip[t]) as the weight of the new message.
        self.skip = nn.Parameter(torch.ones(num_types))
        self.drop = nn.Dropout(dropout)

        if self.use_RTE:
            self.emb = RelTemporalEncoding(in_dim)

        glorot(self.relation_att)
        glorot(self.relation_msg)

    def forward(self, node_inp, node_type, edge_index, edge_type, edge_time):
        return self.propagate(edge_index, node_inp=node_inp, node_type=node_type, \
                              edge_type=edge_type, edge_time = edge_time)

    def message(self, edge_index_i, node_inp_i, node_inp_j, node_type_i, node_type_j, edge_type, edge_time):
        '''
        Compute per-edge messages (j: source, i: target).

        Edges are processed in groups that share the same (source type, target
        type, relation type). Each group fills its rows of res_att (E, n_heads)
        and res_msg (E, n_heads, d_k), where E = edge_index_i.size(0).
        Returns the attention-weighted messages flattened to (E, out_dim).
        '''
        data_size = edge_index_i.size(0)
        # Attention logits and messages for every edge, filled group by group below.
        res_att = torch.zeros(data_size, self.n_heads).to(node_inp_i.device)
        res_msg = torch.zeros(data_size, self.n_heads, self.d_k).to(node_inp_i.device)

        for source_type in range(self.num_types):
            sb = (node_type_j == int(source_type))
            k_linear = self.k_linears[source_type]
            v_linear = self.v_linears[source_type]
            for target_type in range(self.num_types):
                tb = (node_type_i == int(target_type)) & sb
                q_linear = self.q_linears[target_type]
                for relation_type in range(self.num_relations):
                    # idx selects all the edges with this exact meta relation.
                    idx = (edge_type == int(relation_type)) & tb
                    if idx.sum() == 0:
                        continue
                    # Gather endpoint representations, then optionally add the
                    # temporal encoding to the source side (j).
                    target_node_vec = node_inp_i[idx]
                    source_node_vec = node_inp_j[idx]
                    if self.use_RTE:
                        source_node_vec = self.emb(source_node_vec, edge_time[idx])
                    # Step 1: Heterogeneous Mutual Attention.
                    q_mat = q_linear(target_node_vec).view(-1, self.n_heads, self.d_k)
                    k_mat = k_linear(source_node_vec).view(-1, self.n_heads, self.d_k)
                    # (heads, edges, d_k) x (heads, d_k, d_k): per-head relation transform.
                    k_mat = torch.bmm(k_mat.transpose(1,0), self.relation_att[relation_type]).transpose(1,0)
                    res_att[idx] = (q_mat * k_mat).sum(dim=-1) * self.relation_pri[relation_type] / self.sqrt_dk
                    # Step 2: Heterogeneous Message Passing.
                    v_mat = v_linear(source_node_vec).view(-1, self.n_heads, self.d_k)
                    res_msg[idx] = torch.bmm(v_mat.transpose(1,0), self.relation_msg[relation_type]).transpose(1,0)
        # Softmax over the incoming edges of each target node (grouped by edge_index_i).
        # The attention value is stored in self.att for later visualization.
        self.att = softmax(res_att, edge_index_i)
        res = res_msg * self.att.view(-1, self.n_heads, 1)
        del res_att, res_msg
        return res.view(-1, self.out_dim)

    def update(self, aggr_out, node_inp, node_type):
        '''
        Step 3: Target-specific Aggregation
        x = W[node_type] * gelu(Agg(x)) + x

        The skip connection is gated by sigmoid(self.skip[t]). Rows whose type is
        not in range(num_types) stay zero.
        '''
        aggr_out = F.gelu(aggr_out)
        res = torch.zeros(aggr_out.size(0), self.out_dim).to(node_inp.device)
        for target_type in range(self.num_types):
            idx = (node_type == int(target_type))
            if idx.sum() == 0:
                continue
            trans_out = self.a_linears[target_type](aggr_out[idx])
            # Skip connection with learnable weight self.skip[target_type].
            alpha = torch.sigmoid(self.skip[target_type])
            if self.use_norm:
                res[idx] = self.norms[target_type](trans_out * alpha + node_inp[idx] * (1 - alpha))
            else:
                res[idx] = trans_out * alpha + node_inp[idx] * (1 - alpha)
        return self.drop(res)

    def __repr__(self):
        return '{}(in_dim={}, out_dim={}, num_types={}, num_relations={})'.format(
            self.__class__.__name__, self.in_dim, self.out_dim,
            self.num_types, self.num_relations)
136 |
137 |
class RelTemporalEncoding(nn.Module):
    '''
    Implement the Temporal Encoding (Sinusoid) function.

    A fixed sinusoidal table of shape (max_len, 2 * n_hid) is looked up with the
    integer relative time t. Dropout and a trainable Linear(2 * n_hid -> n_hid)
    are applied to the lookup, and the result is added to x.
    t must be < max_len, because it indexes the embedding table.
    '''
    def __init__(self, n_hid, max_len = 240, dropout = 0.2):
        super(RelTemporalEncoding, self).__init__()
        self.drop = nn.Dropout(dropout)
        position = torch.arange(0., max_len).unsqueeze(1)
        # Transformer frequencies 1 / 10000^(2i / d) with d = 2 * n_hid. The exponent
        # must be divided by d *before* raising 10000 to it; computing
        # 1 / (10000 ** 2i / d) overflows float32 for 2i >= 10 and zeroes those frequencies.
        div_term = 1 / (10000 ** (torch.arange(0., n_hid * 2, 2.) / n_hid / 2))
        self.emb = nn.Embedding(max_len, n_hid * 2)
        self.emb.weight.data[:, 0::2] = torch.sin(position * div_term) / math.sqrt(n_hid)
        self.emb.weight.data[:, 1::2] = torch.cos(position * div_term) / math.sqrt(n_hid)
        # Freeze the sinusoid table. The flag must be set on the weight Parameter;
        # setting it on the Module only creates an unrelated attribute.
        self.emb.weight.requires_grad = False
        self.lin = nn.Linear(n_hid * 2, n_hid)
    def forward(self, x, t):
        return x + self.lin(self.drop(self.emb(t)))
154 |
155 |
156 |
class GeneralConv(nn.Module):
    '''
    Dispatcher over the supported graph convolution layers.

    conv_name selects the wrapped layer:
      'hgt'  -> HGTConv (uses node types, edge types and edge times)
      'gcn'  -> GCNConv(in_hid, out_hid)
      'gat'  -> GATConv with n_heads heads of width out_hid // n_heads
      'rgcn' -> RGCNConv(in_hid, out_hid, num_relations) (uses edge types)
    Any other name raises ValueError. Previously it left base_conv unset, and
    forward() silently returned None.
    '''
    def __init__(self, conv_name, in_hid, out_hid, num_types, num_relations, n_heads, dropout, use_norm = True, use_RTE = True):
        super(GeneralConv, self).__init__()
        self.conv_name = conv_name
        if self.conv_name == 'hgt':
            self.base_conv = HGTConv(in_hid, out_hid, num_types, num_relations, n_heads, dropout, use_norm, use_RTE)
        elif self.conv_name == 'gcn':
            self.base_conv = GCNConv(in_hid, out_hid)
        elif self.conv_name == 'gat':
            self.base_conv = GATConv(in_hid, out_hid // n_heads, heads=n_heads)
        elif self.conv_name == 'rgcn':
            self.base_conv = RGCNConv(in_hid, out_hid, num_relations)
        else:
            raise ValueError('Unsupported conv_name: {!r} (expected one of: hgt, gcn, gat, rgcn)'.format(conv_name))
    def forward(self, meta_xs, node_type, edge_index, edge_type, edge_time):
        if self.conv_name == 'hgt':
            return self.base_conv(meta_xs, node_type, edge_index, edge_type, edge_time)
        if self.conv_name == 'rgcn':
            return self.base_conv(meta_xs, edge_index, edge_type)
        # 'gcn' and 'gat' only consume node features and connectivity.
        return self.base_conv(meta_xs, edge_index)
178 |
--------------------------------------------------------------------------------
/GPT_GNN/model.py:
--------------------------------------------------------------------------------
1 | from .conv import *
2 | import numpy as np
3 | from gensim.parsing.preprocessing import *
4 |
5 |
class GPT_GNN(nn.Module):
    '''
    Generative pre-training wrapper around a GNN.

    Holds:
      * one Matcher per (source_type, relation_type) in rem_edge_list, used by
        link_loss (edge generation);
      * attr_decoder, used by text_loss / feat_loss (attribute generation);
      * per-relation queues of detached source embeddings that link_loss can
        reuse as extra negatives.
    '''
    def __init__(self, gnn, rem_edge_list, attr_decoder, types, neg_samp_num, device, neg_queue_size = 0):
        super(GPT_GNN, self).__init__()
        self.types = types
        self.gnn = gnn
        self.params = nn.ModuleList()
        self.neg_queue_size = neg_queue_size
        self.link_dec_dict = {}
        self.neg_queue = {}
        for source_type in rem_edge_list:
            self.link_dec_dict[source_type] = {}
            self.neg_queue[source_type] = {}
            for relation_type in rem_edge_list[source_type]:
                print(source_type, relation_type)
                matcher = Matcher(gnn.n_hid, gnn.n_hid)
                self.neg_queue[source_type][relation_type] = torch.FloatTensor([]).to(device)
                self.link_dec_dict[source_type][relation_type] = matcher
                # link_dec_dict is a plain dict, so the matchers are registered
                # through this ModuleList to expose their weights to .parameters().
                self.params.append(matcher)
        self.attr_decoder = attr_decoder
        self.init_emb = nn.Parameter(torch.randn(gnn.in_dim))
        self.ce = nn.CrossEntropyLoss(reduction = 'none')
        self.neg_samp_num = neg_samp_num

    def neg_sample(self, souce_node_list, pos_node_list):
        '''
        Pick up to self.neg_samp_num negative node ids.

        souce_node_list is shuffled IN PLACE with np.random.shuffle and scanned in
        its new order. Ids that appear in pos_node_list are skipped.
        '''
        np.random.shuffle(souce_node_list)
        positives = set(pos_node_list)
        neg_nodes = []
        for node_id in souce_node_list:
            if node_id in positives:
                continue
            neg_nodes.append(node_id)
            if len(neg_nodes) == self.neg_samp_num:
                break
        return neg_nodes

    def forward(self, node_feature, node_type, edge_time, edge_index, edge_type):
        return self.gnn(node_feature, node_type, edge_time, edge_index, edge_type)

    def link_loss(self, node_emb, rem_edge_list, ori_edge_list, node_dict, target_type, use_queue = True, update_queue = False):
        '''
        Contrastive edge-generation loss over the removed edges.

        Only (source_type, relation_type) pairs that have a Matcher and more than
        8 removed edges are used. Each removed edge provides a target id
        (column 0) and its positive source id (column 1). The positive is scored
        against sn negatives drawn from the sources in ori_edge_list, excluding
        that target's own sources. When use_queue is set, queued embeddings from
        earlier calls are also used. The positive is always column 0 of the
        score matrix.

        node_dict[type][0] is added as the offset of that type's ids in node_emb.

        Returns (loss, ress):
          * loss is minus the mean log-softmax score of the positive, averaged
            over the relations used;
          * ress is the list of detached (n_nodes, rep_size) score matrices.
        If no relation qualifies, loss is a zero tensor and ress is empty.
        Previously this case raised ZeroDivisionError.
        '''
        losses = 0
        ress = []
        for source_type in rem_edge_list:
            if source_type not in self.link_dec_dict:
                continue
            for relation_type in rem_edge_list[source_type]:
                if relation_type not in self.link_dec_dict[source_type]:
                    continue
                rem_edges = rem_edge_list[source_type][relation_type]
                if len(rem_edges) <= 8:
                    continue
                ori_edges = ori_edge_list[source_type][relation_type]
                matcher = self.link_dec_dict[source_type][relation_type]

                target_ids, positive_source_ids = rem_edges[:,0].reshape(-1, 1), rem_edges[:,1].reshape(-1, 1)
                n_nodes = len(target_ids)
                source_node_ids = np.unique(ori_edges[:, 1])

                negative_source_ids = [self.neg_sample(source_node_ids, \
                                       ori_edges[ori_edges[:, 0] == t_id][:, 1].tolist()) for t_id in target_ids]
                # Truncate to a common negative count so the ids form a rectangular array.
                sn = min([len(neg_ids) for neg_ids in negative_source_ids])

                negative_source_ids = [neg_ids[:sn] for neg_ids in negative_source_ids]

                source_ids = torch.LongTensor(np.concatenate((positive_source_ids, negative_source_ids), axis=-1) + node_dict[source_type][0])
                emb = node_emb[source_ids]

                if use_queue and len(self.neg_queue[source_type][relation_type]) // n_nodes > 0:
                    # Append stx queued embeddings per target as extra negatives.
                    tmp = self.neg_queue[source_type][relation_type]
                    stx = len(tmp) // n_nodes
                    tmp = tmp[: stx * n_nodes].reshape(n_nodes, stx, -1)
                    rep_size = sn + 1 + stx
                    source_emb = torch.cat([emb, tmp], dim=1)
                    source_emb = source_emb.reshape(n_nodes * rep_size, -1)
                else:
                    rep_size = sn + 1
                    source_emb = emb.reshape(source_ids.shape[0] * rep_size, -1)

                # numpy repeat along axis 1: each target id repeated rep_size times,
                # aligned row-major with source_emb.
                target_ids = target_ids.repeat(rep_size, 1) + node_dict[target_type][0]
                target_emb = node_emb[target_ids.reshape(-1)]
                res = matcher.forward(target_emb, source_emb)
                res = res.reshape(n_nodes, rep_size)
                ress += [res.detach()]
                losses += F.log_softmax(res, dim=-1)[:,0].mean()
                if update_queue and 'L1' not in relation_type and 'L2' not in relation_type:
                    tmp = self.neg_queue[source_type][relation_type]
                    self.neg_queue[source_type][relation_type] = \
                        torch.cat([node_emb[source_node_ids].detach(), tmp], dim=0)[:int(self.neg_queue_size * n_nodes)]
        if not ress:
            return node_emb.new_zeros(()), ress
        return -losses / len(ress), ress

    def text_loss(self, reps, texts, w2v_model, device):
        '''
        Token-level cross-entropy of attr_decoder generating each text from reps.

        Texts are tokenised with preprocess_string and wrapped in 'bos'/'eos'.
        Out-of-vocabulary words are dropped. Sequences are padded with the
        'eos' index and laid out as (length, batch).
        NOTE(review): uses the gensim < 4.0 `wv.vocab` API.
        '''
        def parse_text(texts, w2v_model, device):
            idxs = []
            pad = w2v_model.wv.vocab['eos'].index
            for text in texts:
                idx = []
                for word in ['bos'] + preprocess_string(text) + ['eos']:
                    if word in w2v_model.wv.vocab:
                        idx += [w2v_model.wv.vocab[word].index]
                idxs += [idx]
            mxl = np.max([len(s) for s in idxs]) + 1
            inp_idxs = []
            out_idxs = []
            masks = []
            for i, idx in enumerate(idxs):
                inp_idxs += [idx + [pad for _ in range(mxl - len(idx) - 1)]]
                # Targets are the inputs shifted left by one position.
                out_idxs += [idx[1:] + [pad for _ in range(mxl - len(idx))]]
                masks += [[1 for _ in range(len(idx))] + [0 for _ in range(mxl - len(idx) - 1)]]
            return torch.LongTensor(inp_idxs).transpose(0, 1).to(device), \
                   torch.LongTensor(out_idxs).transpose(0, 1).to(device), torch.BoolTensor(masks).transpose(0, 1).to(device)
        inp_idxs, out_idxs, masks = parse_text(texts, w2v_model, device)
        pred_prob = self.attr_decoder(inp_idxs, reps.repeat(inp_idxs.shape[0], 1, 1))
        return self.ce(pred_prob[masks], out_idxs[masks]).mean()

    def feat_loss(self, reps, out):
        '''Return minus the mean of attr_decoder(reps, out) (presumably a similarity score — confirm).'''
        return -self.attr_decoder(reps, out).mean()
123 |
class Classifier(nn.Module):
    '''
    Linear classification head that outputs log-probabilities.

    forward applies Linear(n_hid -> n_out), squeezes singleton dimensions and
    returns log_softmax over the last dimension.
    '''
    def __init__(self, n_hid, n_out):
        super(Classifier, self).__init__()
        self.n_hid, self.n_out = n_hid, n_out
        self.linear = nn.Linear(n_hid, n_out)

    def forward(self, x):
        logits = self.linear(x).squeeze()
        return logits.log_softmax(dim=-1)

    def __repr__(self):
        return '{}(n_hid={}, n_out={})'.format(type(self).__name__, self.n_hid, self.n_out)
136 |
137 |
class Matcher(nn.Module):
    '''
    Scores node pairs for link prediction.

    The first argument is projected by Linear(n_hid -> n_out) and passed through
    dropout. It is then compared with the second argument in one of two ways:
      * use_norm=True (default): cosine similarity divided by `temperature`;
      * use_norm=False: dot product divided by sqrt(n_out).
    '''
    def __init__(self, n_hid, n_out, temperature = 0.1):
        super(Matcher, self).__init__()
        self.n_hid = n_hid
        self.linear = nn.Linear(n_hid, n_out)
        self.sqrt_hd = math.sqrt(n_out)
        self.drop = nn.Dropout(0.2)
        # Cosine is computed over dim 1, i.e. across features of 2-D (pairs, features) inputs.
        self.cosine = nn.CosineSimilarity(dim=1)
        self.cache = None  # not read anywhere in this class
        self.temperature = temperature

    def forward(self, x, ty, use_norm = True):
        projected = self.drop(self.linear(x))
        if not use_norm:
            return (projected * ty).sum(dim=-1) / self.sqrt_hd
        return self.cosine(projected, ty) / self.temperature

    def __repr__(self):
        return '{}(n_hid={})'.format(type(self).__name__, self.n_hid)
162 |
163 |
class GNN(nn.Module):
    '''
    Stack of GeneralConv layers on top of type-specific input adapters.

    Each node's raw features are mapped to n_hid by its type's own Linear layer
    followed by tanh (adapt_ws). Dropout is applied, then the result runs
    through the GeneralConv layers. There are max(n_layers, 1) layers: all but
    the last use prev_norm, and the last uses last_norm.
    '''
    def __init__(self, in_dim, n_hid, num_types, num_relations, n_heads, n_layers, dropout = 0.2, conv_name = 'hgt', prev_norm = False, last_norm = False, use_RTE = True):
        super(GNN, self).__init__()
        # Keep gcs assigned before adapt_ws: ModuleList registration order fixes the
        # order of .parameters() and state_dict entries.
        self.gcs = nn.ModuleList()
        self.num_types = num_types
        self.in_dim = in_dim
        self.n_hid = n_hid
        self.adapt_ws = nn.ModuleList()
        self.drop = nn.Dropout(dropout)
        self.adapt_ws.extend(nn.Linear(in_dim, n_hid) for _ in range(num_types))
        layer_norm_flags = [prev_norm] * (n_layers - 1) + [last_norm]
        self.gcs.extend(
            GeneralConv(conv_name, n_hid, n_hid, num_types, num_relations, n_heads, dropout, use_norm = flag, use_RTE = use_RTE)
            for flag in layer_norm_flags)

    def forward(self, node_feature, node_type, edge_time, edge_index, edge_type):
        hidden = torch.zeros(node_feature.size(0), self.n_hid).to(node_feature.device)
        for t_id in range(self.num_types):
            mask = node_type == int(t_id)
            if mask.any():
                hidden[mask] = torch.tanh(self.adapt_ws[t_id](node_feature[mask]))
        meta_xs = self.drop(hidden)
        del hidden
        for gc in self.gcs:
            meta_xs = gc(meta_xs, node_type, edge_index, edge_type, edge_time)
        return meta_xs
191 |
192 |
class RNNModel(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder.

    Each token id is embedded to nhid and concatenated with an ninp-dim
    conditioning vector (`hidden`) at the same position. The result is mixed
    back to nhid by `adp` + GELU, run through an nlayers LSTM, and decoded to
    n_word logits.
    """
    def __init__(self, n_word, ninp, nhid, nlayers, dropout=0.2):
        super(RNNModel, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.rnn = nn.LSTM(nhid, nhid, nlayers)
        self.encoder = nn.Embedding(n_word, nhid)
        self.decoder = nn.Linear(nhid, n_word)
        self.adp = nn.Linear(ninp + nhid, nhid)
    def forward(self, inp, hidden = None):
        emb = self.encoder(inp)
        if hidden is None:
            # adp always expects ninp + nhid input features. Without a conditioning
            # vector, feed zeros in its place instead of failing on a shape mismatch.
            n_cond = self.adp.in_features - emb.size(-1)
            hidden = emb.new_zeros(tuple(emb.shape[:-1]) + (n_cond,))
        emb = torch.cat((emb, hidden), dim=-1)
        emb = F.gelu(self.adp(emb))
        output, _ = self.rnn(emb)
        decoded = self.decoder(self.drop(output))
        return decoded
    def from_w2v(self, w2v):
        '''
        Load pretrained word vectors into the encoder, tie the decoder to the same
        weight, and freeze that shared weight.

        Assumes w2v is a (n_word, nhid) float tensor — TODO confirm against caller.
        '''
        self.encoder.weight.data = w2v
        self.decoder.weight = self.encoder.weight

        self.encoder.weight.requires_grad = False
        self.decoder.weight.requires_grad = False
217 |
--------------------------------------------------------------------------------
/example_reddit/GPT_GNN/model.py:
--------------------------------------------------------------------------------
1 | from .conv import *
2 | import numpy as np
3 | from gensim.parsing.preprocessing import *
4 |
5 |
class GPT_GNN(nn.Module):
    '''
    Generative pre-training wrapper around a GNN.

    Holds:
      * one Matcher per (source_type, relation_type) in rem_edge_list, used by
        link_loss (edge generation);
      * attr_decoder, used by text_loss / feat_loss (attribute generation);
      * per-relation queues of detached source embeddings that link_loss can
        reuse as extra negatives.
    '''
    def __init__(self, gnn, rem_edge_list, attr_decoder, types, neg_samp_num, device, neg_queue_size = 0):
        super(GPT_GNN, self).__init__()
        self.types = types
        self.gnn = gnn
        self.params = nn.ModuleList()
        self.neg_queue_size = neg_queue_size
        self.link_dec_dict = {}
        self.neg_queue = {}
        for source_type in rem_edge_list:
            self.link_dec_dict[source_type] = {}
            self.neg_queue[source_type] = {}
            for relation_type in rem_edge_list[source_type]:
                print(source_type, relation_type)
                matcher = Matcher(gnn.n_hid, gnn.n_hid)
                self.neg_queue[source_type][relation_type] = torch.FloatTensor([]).to(device)
                self.link_dec_dict[source_type][relation_type] = matcher
                # link_dec_dict is a plain dict, so the matchers are registered
                # through this ModuleList to expose their weights to .parameters().
                self.params.append(matcher)
        self.attr_decoder = attr_decoder
        self.init_emb = nn.Parameter(torch.randn(gnn.in_dim))
        self.ce = nn.CrossEntropyLoss(reduction = 'none')
        self.neg_samp_num = neg_samp_num

    def neg_sample(self, souce_node_list, pos_node_list):
        '''
        Pick up to self.neg_samp_num negative node ids.

        souce_node_list is shuffled IN PLACE with np.random.shuffle and scanned in
        its new order. Ids that appear in pos_node_list are skipped.
        '''
        np.random.shuffle(souce_node_list)
        positives = set(pos_node_list)
        neg_nodes = []
        for node_id in souce_node_list:
            if node_id in positives:
                continue
            neg_nodes.append(node_id)
            if len(neg_nodes) == self.neg_samp_num:
                break
        return neg_nodes

    def forward(self, node_feature, node_type, edge_time, edge_index, edge_type):
        return self.gnn(node_feature, node_type, edge_time, edge_index, edge_type)

    def link_loss(self, node_emb, rem_edge_list, ori_edge_list, node_dict, target_type, use_queue = True, update_queue = False):
        '''
        Contrastive edge-generation loss over the removed edges.

        Only (source_type, relation_type) pairs that have a Matcher and more than
        8 removed edges are used. Each removed edge provides a target id
        (column 0) and its positive source id (column 1). The positive is scored
        against sn negatives drawn from the sources in ori_edge_list, excluding
        that target's own sources. When use_queue is set, queued embeddings from
        earlier calls are also used. The positive is always column 0 of the
        score matrix.

        node_dict[type][0] is added as the offset of that type's ids in node_emb.

        Returns (loss, ress):
          * loss is minus the mean log-softmax score of the positive, averaged
            over the relations used;
          * ress is the list of detached (n_nodes, rep_size) score matrices.
        If no relation qualifies, loss is a zero tensor and ress is empty.
        Previously this case raised ZeroDivisionError.
        '''
        losses = 0
        ress = []
        for source_type in rem_edge_list:
            if source_type not in self.link_dec_dict:
                continue
            for relation_type in rem_edge_list[source_type]:
                if relation_type not in self.link_dec_dict[source_type]:
                    continue
                rem_edges = rem_edge_list[source_type][relation_type]
                if len(rem_edges) <= 8:
                    continue
                ori_edges = ori_edge_list[source_type][relation_type]
                matcher = self.link_dec_dict[source_type][relation_type]

                target_ids, positive_source_ids = rem_edges[:,0].reshape(-1, 1), rem_edges[:,1].reshape(-1, 1)
                n_nodes = len(target_ids)
                source_node_ids = np.unique(ori_edges[:, 1])

                negative_source_ids = [self.neg_sample(source_node_ids, \
                                       ori_edges[ori_edges[:, 0] == t_id][:, 1].tolist()) for t_id in target_ids]
                # Truncate to a common negative count so the ids form a rectangular array.
                sn = min([len(neg_ids) for neg_ids in negative_source_ids])

                negative_source_ids = [neg_ids[:sn] for neg_ids in negative_source_ids]

                source_ids = torch.LongTensor(np.concatenate((positive_source_ids, negative_source_ids), axis=-1) + node_dict[source_type][0])
                emb = node_emb[source_ids]

                if use_queue and len(self.neg_queue[source_type][relation_type]) // n_nodes > 0:
                    # Append stx queued embeddings per target as extra negatives.
                    tmp = self.neg_queue[source_type][relation_type]
                    stx = len(tmp) // n_nodes
                    tmp = tmp[: stx * n_nodes].reshape(n_nodes, stx, -1)
                    rep_size = sn + 1 + stx
                    source_emb = torch.cat([emb, tmp], dim=1)
                    source_emb = source_emb.reshape(n_nodes * rep_size, -1)
                else:
                    rep_size = sn + 1
                    source_emb = emb.reshape(source_ids.shape[0] * rep_size, -1)

                # numpy repeat along axis 1: each target id repeated rep_size times,
                # aligned row-major with source_emb.
                target_ids = target_ids.repeat(rep_size, 1) + node_dict[target_type][0]
                target_emb = node_emb[target_ids.reshape(-1)]
                res = matcher.forward(target_emb, source_emb)
                res = res.reshape(n_nodes, rep_size)
                ress += [res.detach()]
                losses += F.log_softmax(res, dim=-1)[:,0].mean()
                if update_queue and 'L1' not in relation_type and 'L2' not in relation_type:
                    tmp = self.neg_queue[source_type][relation_type]
                    self.neg_queue[source_type][relation_type] = \
                        torch.cat([node_emb[source_node_ids].detach(), tmp], dim=0)[:int(self.neg_queue_size * n_nodes)]
        if not ress:
            return node_emb.new_zeros(()), ress
        return -losses / len(ress), ress

    def text_loss(self, reps, texts, w2v_model, device):
        '''
        Token-level cross-entropy of attr_decoder generating each text from reps.

        Texts are tokenised with preprocess_string and wrapped in 'bos'/'eos'.
        Out-of-vocabulary words are dropped. Sequences are padded with the
        'eos' index and laid out as (length, batch).
        NOTE(review): uses the gensim < 4.0 `wv.vocab` API.
        '''
        def parse_text(texts, w2v_model, device):
            idxs = []
            pad = w2v_model.wv.vocab['eos'].index
            for text in texts:
                idx = []
                for word in ['bos'] + preprocess_string(text) + ['eos']:
                    if word in w2v_model.wv.vocab:
                        idx += [w2v_model.wv.vocab[word].index]
                idxs += [idx]
            mxl = np.max([len(s) for s in idxs]) + 1
            inp_idxs = []
            out_idxs = []
            masks = []
            for i, idx in enumerate(idxs):
                inp_idxs += [idx + [pad for _ in range(mxl - len(idx) - 1)]]
                # Targets are the inputs shifted left by one position.
                out_idxs += [idx[1:] + [pad for _ in range(mxl - len(idx))]]
                masks += [[1 for _ in range(len(idx))] + [0 for _ in range(mxl - len(idx) - 1)]]
            return torch.LongTensor(inp_idxs).transpose(0, 1).to(device), \
                   torch.LongTensor(out_idxs).transpose(0, 1).to(device), torch.BoolTensor(masks).transpose(0, 1).to(device)
        inp_idxs, out_idxs, masks = parse_text(texts, w2v_model, device)
        pred_prob = self.attr_decoder(inp_idxs, reps.repeat(inp_idxs.shape[0], 1, 1))
        return self.ce(pred_prob[masks], out_idxs[masks]).mean()

    def feat_loss(self, reps, out):
        '''Return minus the mean of attr_decoder(reps, out) (presumably a similarity score — confirm).'''
        return -self.attr_decoder(reps, out).mean()
122 |
123 |
class Classifier(nn.Module):
    '''
    Linear classification head that outputs log-probabilities.

    forward applies Linear(n_hid -> n_out), squeezes singleton dimensions and
    returns log_softmax over the last dimension.
    '''
    def __init__(self, n_hid, n_out):
        super(Classifier, self).__init__()
        self.n_hid, self.n_out = n_hid, n_out
        self.linear = nn.Linear(n_hid, n_out)

    def forward(self, x):
        logits = self.linear(x).squeeze()
        return logits.log_softmax(dim=-1)

    def __repr__(self):
        return '{}(n_hid={}, n_out={})'.format(type(self).__name__, self.n_hid, self.n_out)
136 |
137 |
class Matcher(nn.Module):
    '''
    Score (target, source) node pairs for link prediction.

    The target representation is linearly projected (followed by dropout) and
    compared with the source representation, either by temperature-scaled cosine
    similarity or by a dot product scaled by sqrt(n_out).
    '''

    def __init__(self, n_hid, n_out, temperature = 0.1):
        super(Matcher, self).__init__()
        self.n_hid = n_hid
        self.linear = nn.Linear(n_hid, n_out)
        self.sqrt_hd = math.sqrt(n_out)
        self.drop = nn.Dropout(0.2)
        # Reduce over the last (feature) axis, exactly like the dot-product branch,
        # so both scoring modes agree and accept arbitrary leading batch dims.
        self.cosine = nn.CosineSimilarity(dim=-1)
        self.cache = None  # not read by any code visible here
        self.temperature = temperature

    def forward(self, x, ty, use_norm = True):
        '''
        Return one score per pair.

        Args:
            x: target representations, last dim n_hid.
            ty: source representations, last dim n_out, broadcastable with the projected x.
            use_norm: if True, cosine similarity / temperature; else dot product / sqrt(n_out).
        '''
        tx = self.drop(self.linear(x))
        if use_norm:
            return self.cosine(tx, ty) / self.temperature
        return (tx * ty).sum(dim=-1) / self.sqrt_hd

    def __repr__(self):
        return '{}(n_hid={})'.format(
            self.__class__.__name__, self.n_hid)
162 |
163 |
class GNN(nn.Module):
    '''
    Heterogeneous GNN encoder: a per-node-type input projection followed by a
    stack of GeneralConv layers.

    Args:
        in_dim: raw node-feature width (shared by all node types).
        n_hid: hidden width of every layer.
        num_types / num_relations: numbers of node types and relation ids.
        n_heads, dropout, conv_name, use_RTE: forwarded to each GeneralConv.
        n_layers: number of GeneralConv layers.
        prev_norm: use_norm flag for all but the last layer.
        last_norm: use_norm flag for the last layer.
    '''
    def __init__(self, in_dim, n_hid, num_types, num_relations, n_heads, n_layers, dropout = 0.2, conv_name = 'hgt', prev_norm = False, last_norm = False, use_RTE = True):
        super(GNN, self).__init__()
        self.gcs = nn.ModuleList()
        self.num_types = num_types
        self.in_dim = in_dim
        self.n_hid = n_hid
        self.adapt_ws = nn.ModuleList()
        self.drop = nn.Dropout(dropout)
        for t in range(num_types):
            self.adapt_ws.append(nn.Linear(in_dim, n_hid))
        for l in range(n_layers - 1):
            self.gcs.append(GeneralConv(conv_name, n_hid, n_hid, num_types, num_relations, n_heads, dropout, use_norm = prev_norm, use_RTE = use_RTE))
        self.gcs.append(GeneralConv(conv_name, n_hid, n_hid, num_types, num_relations, n_heads, dropout, use_norm = last_norm, use_RTE = use_RTE))

    def forward(self, node_feature, node_type, edge_time, edge_index, edge_type):
        '''
        Encode all nodes of a sampled sub-graph; returns a (num_nodes, n_hid) tensor.

        Rows whose node_type is outside range(num_types) are left at zero before
        the convolutions.
        '''
        # Allocate directly on the input's device instead of on the CPU plus a copy.
        res = torch.zeros(node_feature.size(0), self.n_hid, device=node_feature.device)
        for t_id in range(self.num_types):
            idx = (node_type == t_id)
            if not idx.any():
                continue
            res[idx] = torch.tanh(self.adapt_ws[t_id](node_feature[idx]))
        meta_xs = self.drop(res)
        del res
        for gc in self.gcs:
            # NOTE: GeneralConv takes (x, node_type, edge_index, edge_type, edge_time).
            meta_xs = gc(meta_xs, node_type, edge_index, edge_type, edge_time)
        return meta_xs
191 |
192 |
class RNNModel(nn.Module):
    """Conditional LSTM text decoder.

    Token embeddings (width nhid) are concatenated with a per-position context
    vector (width ninp), mixed back to nhid by a GELU-activated linear layer and
    run through an nlayers LSTM; a linear layer maps outputs to vocabulary logits.
    Inputs are time-major (nn.LSTM's default batch_first=False).
    """
    def __init__(self, n_word, ninp, nhid, nlayers, dropout=0.2):
        super(RNNModel, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.rnn = nn.LSTM(nhid, nhid, nlayers)
        self.encoder = nn.Embedding(n_word, nhid)
        self.decoder = nn.Linear(nhid, n_word)
        self.adp = nn.Linear(ninp + nhid, nhid)

    def forward(self, inp, hidden = None):
        """Return vocabulary logits for token indices `inp` given context `hidden`.

        When `hidden` is None an all-zero context is used, since self.adp always
        expects ninp + nhid input features.
        """
        emb = self.encoder(inp)
        if hidden is None:
            ctx_dim = self.adp.in_features - emb.size(-1)
            hidden = emb.new_zeros((*emb.shape[:-1], ctx_dim))
        emb = torch.cat((emb, hidden), dim=-1)
        emb = F.gelu(self.adp(emb))
        output, _ = self.rnn(emb)
        return self.decoder(self.drop(output))

    def from_w2v(self, w2v):
        """Load frozen word2vec embeddings and tie the output projection to them.

        `w2v` presumably is an (n_word, nhid) float tensor — TODO confirm.
        """
        self.encoder.weight.data = w2v
        self.decoder.weight = self.encoder.weight

        self.encoder.weight.requires_grad = False
        self.decoder.weight.requires_grad = False
217 |
--------------------------------------------------------------------------------
/example_reddit/finetune_reddit.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from GPT_GNN.data import *
3 | from GPT_GNN.model import *
4 | from warnings import filterwarnings
5 |
6 | from sklearn.metrics import f1_score
filterwarnings("ignore")

import argparse

parser = argparse.ArgumentParser(description='Fine-Tuning on Reddit classification task')

# Dataset arguments
parser.add_argument('--data_dir', type=str, default='/datadrive/dataset',
                    help='The address of preprocessed graph.')
parser.add_argument('--use_pretrain', help='Whether to use pre-trained model', action='store_true')
parser.add_argument('--pretrain_model_dir', type=str, default='/datadrive/models/gpt_all_cs',
                    help='The address for pretrained model.')
parser.add_argument('--model_dir', type=str, default='/datadrive/models/gpt_all_reddit',
                    help='The address for storing the models and optimization results.')
parser.add_argument('--task_name', type=str, default='reddit',
                    help='The name of the stored models and optimization results.')
parser.add_argument('--cuda', type=int, default=2,
                    help='Avaiable GPU ID')
parser.add_argument('--sample_depth', type=int, default=6,
                    help='How many numbers to sample the graph')
parser.add_argument('--sample_width', type=int, default=128,
                    help='How many nodes to be sampled per layer per type')

# Model arguments
parser.add_argument('--conv_name', type=str, default='hgt',
                    choices=['hgt', 'gcn', 'gat', 'rgcn', 'han', 'hetgnn'],
                    help='The name of GNN filter. By default is Heterogeneous Graph Transformer (hgt)')
parser.add_argument('--n_hid', type=int, default=400,
                    help='Number of hidden dimension')
parser.add_argument('--n_heads', type=int, default=8,
                    help='Number of attention head')
parser.add_argument('--n_layers', type=int, default=3,
                    help='Number of GNN layers')
parser.add_argument('--prev_norm', help='Whether to add layer-norm on the previous layers', action='store_true')
parser.add_argument('--last_norm', help='Whether to add layer-norm on the last layers', action='store_true')
# Fractional hyper-parameters must be parsed as float (type=int rejects e.g. "0.3").
parser.add_argument('--dropout', type=float, default=0.2,
                    help='Dropout ratio')

# Optimization arguments
parser.add_argument('--optimizer', type=str, default='adamw',
                    choices=['adamw', 'adam', 'sgd', 'adagrad'],
                    help='optimizer to use.')
parser.add_argument('--scheduler', type=str, default='cosine',
                    help='Name of learning rate scheduler.' , choices=['cycle', 'cosine'])
# Read by the 'cycle' (OneCycleLR) scheduler; default equals the base learning rate.
parser.add_argument('--max_lr', type=float, default=5e-4,
                    help='Maximum learning rate for the cycle scheduler')
parser.add_argument('--data_percentage', type=float, default=0.1,
                    help='Percentage of training and validation data to use')
parser.add_argument('--n_epoch', type=int, default=50,
                    help='Number of epoch to run')
parser.add_argument('--n_pool', type=int, default=8,
                    help='Number of process to sample subgraph')
parser.add_argument('--n_batch', type=int, default=16,
                    help='Number of batch (sampled graphs) for each epoch')
parser.add_argument('--batch_size', type=int, default=256,
                    help='Number of output nodes for training')
parser.add_argument('--clip', type=float, default=0.5,
                    help='Gradient Norm Clipping')

args = parser.parse_args()
args_print(args)

if args.cuda != -1:
    device = torch.device("cuda:" + str(args.cuda))
else:
    device = torch.device("cpu")

# Close the pickle file deterministically once the graph is loaded.
with open(os.path.join(args.data_dir, 'graph_reddit.pk'), 'rb') as graph_file:
    graph = dill.load(graph_file)

target_type = 'def'
train_target_nodes = graph.train_target_nodes
valid_target_nodes = graph.valid_target_nodes
test_target_nodes = graph.test_target_nodes

types = graph.get_types()
criterion = nn.NLLLoss()
87 |
def node_classification_sample(seed, nodes, time_range):
    '''
    Sample one sub-graph around a random batch of target nodes and return it as tensors.

    NumPy is seeded with `seed`; `args.batch_size` distinct nodes are drawn from
    `nodes` (each given the time 1) and used as seeds for sample_subgraph.

    Returns:
        node_feature, node_type, edge_time, edge_index, edge_type (from to_torch),
        x_ids: 0..batch_size-1, the seed positions in the sub-graph (assumes the
            target type is first in graph.get_types() — TODO confirm),
        labels: graph.y[samp_nodes].
    '''
    np.random.seed(seed)
    samp_nodes = np.random.choice(nodes, args.batch_size, replace = False)
    # One (node_id, time) row per seed node.
    seed_rows = np.stack([samp_nodes, np.ones(args.batch_size)], axis=1)
    feature, times, edge_list, _, _ = sample_subgraph(
        graph, time_range,
        inp = {target_type: seed_rows},
        sampled_depth = args.sample_depth,
        sampled_number = args.sample_width,
        feature_extractor = feature_reddit)

    tensors = to_torch(feature, times, edge_list, graph)
    node_feature, node_type, edge_time, edge_index, edge_type = tensors[:5]

    x_ids = np.arange(args.batch_size)
    labels = graph.y[samp_nodes]
    return node_feature, node_type, edge_time, edge_index, edge_type, x_ids, labels
104 |
105 |
def prepare_data(pool):
    '''
    Queue one epoch of sampling work on `pool`.

    Returns a list of AsyncResult handles: `args.n_batch` training batches
    followed by a single validation batch as the last element.
    '''
    def submit(nodes):
        return pool.apply_async(node_classification_sample, args=(randint(), nodes, {1: True}))

    jobs = [submit(train_target_nodes) for _ in range(args.n_batch)]
    jobs.append(submit(valid_target_nodes))
    return jobs
117 |
stats = []
res = []
# Any valid micro-F1 is >= 0, so starting below it guarantees the first epoch
# writes a checkpoint and the final test stage always has a model to load.
best_val = -1
train_step = 0

pool = mp.Pool(args.n_pool)
st = time.time()
jobs = prepare_data(pool)


# Initialize GNN (model is specified by conv_name) and Classifier.
gnn = GNN(conv_name = args.conv_name, in_dim = len(graph.node_feature[target_type]['emb'].values[0]), n_hid = args.n_hid, \
          n_heads = args.n_heads, n_layers = args.n_layers, dropout = args.dropout, num_types = len(types), \
          num_relations = len(graph.get_meta_graph()) + 1, prev_norm = args.prev_norm, last_norm = args.last_norm, use_RTE = False)
if args.use_pretrain:
    gnn.load_state_dict(load_gnn(torch.load(args.pretrain_model_dir)), strict = False)
    print('Load Pre-trained Model from (%s)' % args.pretrain_model_dir)
# Number of classes = largest label + 1 (presumably labels are 0..C-1 — TODO confirm).
classifier = Classifier(args.n_hid, graph.y.max().item() + 1)

model = nn.Sequential(gnn, classifier).to(device)

# Honour --optimizer; every choice uses the same base learning rate.
optimizer_classes = {
    'adamw': torch.optim.AdamW,
    'adam': torch.optim.Adam,
    'sgd': torch.optim.SGD,
    'adagrad': torch.optim.Adagrad,
}
optimizer = optimizer_classes[args.optimizer](model.parameters(), lr = 5e-4)

if args.scheduler == 'cycle':
    # Fall back to the base lr when the parser does not define --max_lr.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, pct_start=0.02, anneal_strategy='linear', final_div_factor=100,\
                    max_lr = getattr(args, 'max_lr', 5e-4), total_steps = args.n_batch * args.n_epoch + 1)
elif args.scheduler == 'cosine':
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 500, eta_min=1e-6)
151 |
152 |
for epoch in np.arange(args.n_epoch) + 1:
    # Collect this epoch's batches (sampled asynchronously during the previous epoch).
    train_data = [job.get() for job in jobs[:-1]]
    valid_data = jobs[-1].get()
    pool.close()
    pool.join()
    # Re-open the pool and start sampling the next epoch while this one trains.
    # After the final epoch nothing would consume those batches, so no pool is
    # started and no worker processes are left running.
    if epoch < args.n_epoch:
        pool = mp.Pool(args.n_pool)
        jobs = prepare_data(pool)
    et = time.time()
    print('Data Preparation: %.1fs' % (et - st))

    # Train
    model.train()
    train_losses = []
    for node_feature, node_type, edge_time, edge_index, edge_type, x_ids, ylabel in train_data:
        node_rep = gnn.forward(node_feature.to(device), node_type.to(device), \
                               edge_time.to(device), edge_index.to(device), edge_type.to(device))
        res = classifier.forward(node_rep[x_ids])
        loss = criterion(res, ylabel.to(device))

        optimizer.zero_grad()
        torch.cuda.empty_cache()
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        optimizer.step()

        train_losses += [loss.item()]
        train_step += 1
        scheduler.step(train_step)
        del res, loss

    # Valid
    model.eval()
    with torch.no_grad():
        node_feature, node_type, edge_time, edge_index, edge_type, x_ids, ylabel = valid_data
        node_rep = gnn.forward(node_feature.to(device), node_type.to(device), \
                               edge_time.to(device), edge_index.to(device), edge_type.to(device))
        res = classifier.forward(node_rep[x_ids])
        loss = criterion(res, ylabel.to(device))
        valid_loss = loss.item()

        # Valid F1; keep the model with the highest score.
        # sklearn's signature is f1_score(y_true, y_pred).
        valid_f1 = f1_score(ylabel.tolist(), res.argmax(dim=1).cpu().tolist(), average='micro')

        if valid_f1 > best_val:
            best_val = valid_f1
            os.makedirs(args.model_dir, exist_ok=True)
            torch.save(model, os.path.join(args.model_dir, args.task_name + '_' + args.conv_name))
            print('UPDATE!!!')

        st = time.time()
        print(("Epoch: %d (%.1fs) LR: %.5f Train Loss: %.2f Valid Loss: %.2f Valid F1: %.4f") % \
              (epoch, (st-et), optimizer.param_groups[0]['lr'], np.average(train_losses), \
               valid_loss, valid_f1))
        stats += [[np.average(train_losses), valid_loss]]
        del res, loss
    del train_data, valid_data
219 |
220 |
221 |
# Reload the checkpoint with the best validation F1 and report the mean test
# micro-F1 over 10 independently sampled test batches.
best_model = torch.load(os.path.join(args.model_dir, args.task_name + '_' + args.conv_name)).to(device)
best_model.eval()
gnn, classifier = best_model
with torch.no_grad():
    test_res = []
    for _ in range(10):
        node_feature, node_type, edge_time, edge_index, edge_type, x_ids, ylabel = \
            node_classification_sample(randint(), test_target_nodes, {1: True})
        target_rep = gnn.forward(node_feature.to(device), node_type.to(device), \
                                 edge_time.to(device), edge_index.to(device), edge_type.to(device))[x_ids]
        res = classifier.forward(target_rep)
        # sklearn's signature is f1_score(y_true, y_pred).
        test_f1 = f1_score(ylabel.tolist(), res.argmax(dim=1).cpu().tolist(), average='micro')
        test_res += [test_f1]
    print('Best Test F1: %.4f' % np.average(test_res))
236 |
--------------------------------------------------------------------------------
/example_OAG/GPT_GNN/model.py:
--------------------------------------------------------------------------------
1 | from .conv import *
2 | import numpy as np
3 | from gensim.parsing.preprocessing import *
4 |
5 |
class GPT_GNN(nn.Module):
    '''
    Generative pre-training wrapper around a heterogeneous GNN encoder.

    Keeps one Matcher per (source_type, relation_type) of `rem_edge_list` for the
    edge-generation objective (`link_loss`) and an `attr_decoder` for the
    attribute-generation objective (`text_loss` or `feat_loss`).
    '''
    def __init__(self, gnn, rem_edge_list, attr_decoder, types, neg_samp_num, device, neg_queue_size = 0):
        super(GPT_GNN, self).__init__()
        if gnn is None:
            return
        self.types = types
        self.gnn = gnn
        # Registers the matchers so their parameters are trained and saved.
        self.params = nn.ModuleList()
        self.neg_queue_size = neg_queue_size
        # source_type -> relation_type -> Matcher
        self.link_dec_dict = {}
        # source_type -> relation_type -> detached source embeddings reused as extra negatives
        self.neg_queue = {}
        for source_type in rem_edge_list:
            self.link_dec_dict[source_type] = {}
            self.neg_queue[source_type] = {}
            for relation_type in rem_edge_list[source_type]:
                print(source_type, relation_type)
                matcher = Matcher(gnn.n_hid, gnn.n_hid)
                self.neg_queue[source_type][relation_type] = torch.FloatTensor([]).to(device)
                self.link_dec_dict[source_type][relation_type] = matcher
                self.params.append(matcher)
        self.attr_decoder = attr_decoder
        self.init_emb = nn.Parameter(torch.randn(gnn.in_dim))
        self.ce = nn.CrossEntropyLoss(reduction = 'none')
        self.neg_samp_num = neg_samp_num

    def neg_sample(self, souce_node_list, pos_node_list):
        '''
        Return up to self.neg_samp_num ids from `souce_node_list` not in `pos_node_list`.

        NOTE: `souce_node_list` is shuffled in place to randomise the choice.
        '''
        np.random.shuffle(souce_node_list)
        neg_nodes = []
        keys = {key : True for key in pos_node_list}
        tot = 0
        for node_id in souce_node_list:
            if node_id not in keys:
                neg_nodes += [node_id]
                tot += 1
            if tot == self.neg_samp_num:
                break
        return neg_nodes

    def forward(self, node_feature, node_type, edge_time, edge_index, edge_type):
        return self.gnn(node_feature, node_type, edge_time, edge_index, edge_type)

    def link_loss(self, node_emb, rem_edge_list, ori_edge_list, node_dict, target_type, use_queue = True, update_queue = False):
        '''
        Edge-generation (link prediction) loss.

        For each (source_type, relation_type) that has a Matcher and more than 8
        edges in `rem_edge_list`, every (target, positive source) edge is scored
        against `sn` negative sources. Negatives are sources not linked to that
        target in `ori_edge_list`, capped at neg_samp_num and truncated to a common
        count. If `use_queue` is set, queued source embeddings are scored as well.
        The loss is the negative mean log-softmax probability of the positive
        (column 0), averaged over the contributing relations.

        Args:
            node_emb: node representations indexed by global sub-graph position.
            rem_edge_list / ori_edge_list: [source_type][relation_type] -> array of
                (target_id, source_id) rows with type-local ids.
            node_dict: {node_type: [offset, type_id]} as produced by to_torch.
            target_type: node type of the targets.
            use_queue: score queued negatives when the queue holds at least one per target.
            update_queue: push this call's source embeddings onto the queue (skipped
                for relation names containing 'L1' or 'L2').

        Returns:
            (loss, ress): scalar loss and a list of detached (n_targets, rep_size)
            score matrices. When no relation contributes, loss is a zero tensor and
            ress is empty.
        '''
        losses = 0
        ress = []
        for source_type in rem_edge_list:
            if source_type not in self.link_dec_dict:
                continue
            for relation_type in rem_edge_list[source_type]:
                if relation_type not in self.link_dec_dict[source_type]:
                    continue
                rem_edges = rem_edge_list[source_type][relation_type]
                if len(rem_edges) <= 8:
                    continue
                ori_edges = ori_edge_list[source_type][relation_type]
                matcher = self.link_dec_dict[source_type][relation_type]

                target_ids, positive_source_ids = rem_edges[:,0].reshape(-1, 1), rem_edges[:,1].reshape(-1, 1)
                n_nodes = len(target_ids)
                # Type-local ids of every source seen in ori_edges.
                source_node_ids = np.unique(ori_edges[:, 1])

                negative_source_ids = [self.neg_sample(source_node_ids, \
                    ori_edges[ori_edges[:, 0] == t_id][:, 1].tolist()) for t_id in target_ids]
                # Truncate to a common count so negatives form a rectangular array.
                sn = min([len(neg_ids) for neg_ids in negative_source_ids])
                negative_source_ids = [neg_ids[:sn] for neg_ids in negative_source_ids]

                # Shift type-local ids to global rows of node_emb.
                source_offset = node_dict[source_type][0]
                source_ids = torch.LongTensor(np.concatenate((positive_source_ids, negative_source_ids), axis=-1) + source_offset)
                emb = node_emb[source_ids]

                if use_queue and len(self.neg_queue[source_type][relation_type]) // n_nodes > 0:
                    tmp = self.neg_queue[source_type][relation_type]
                    stx = len(tmp) // n_nodes
                    tmp = tmp[: stx * n_nodes].reshape(n_nodes, stx, -1)
                    rep_size = sn + 1 + stx
                    source_emb = torch.cat([emb, tmp], dim=1)
                    source_emb = source_emb.reshape(n_nodes * rep_size, -1)
                else:
                    rep_size = sn + 1
                    source_emb = emb.reshape(source_ids.shape[0] * rep_size, -1)

                target_ids = target_ids.repeat(rep_size, 1) + node_dict[target_type][0]
                target_emb = node_emb[target_ids.reshape(-1)]
                res = matcher.forward(target_emb, source_emb)
                res = res.reshape(n_nodes, rep_size)
                ress += [res.detach()]
                losses += F.log_softmax(res, dim=-1)[:,0].mean()
                if update_queue and 'L1' not in relation_type and 'L2' not in relation_type:
                    # Apply the same type offset as source_ids above so the queue
                    # holds embeddings of the intended source nodes.
                    queued = node_emb[source_node_ids + source_offset].detach()
                    old_queue = self.neg_queue[source_type][relation_type]
                    if len(old_queue) > 0:
                        queued = torch.cat([queued, old_queue], dim=0)
                    self.neg_queue[source_type][relation_type] = queued[:int(self.neg_queue_size * n_nodes)]
        if not ress:
            # No relation had enough edges; return a zero loss instead of dividing by zero.
            return torch.zeros((), device=node_emb.device), ress
        return -losses / len(ress), ress

    def text_loss(self, reps, texts, w2v_model, device):
        '''
        Token-level language-modelling loss for regenerating node texts with
        self.attr_decoder (teacher forcing, masked mean cross-entropy).

        `w2v_model` is a gensim Word2Vec model exposing `wv.vocab` (gensim < 4 API);
        out-of-vocabulary tokens are dropped and 'eos' doubles as the pad index.
        '''
        def parse_text(texts, w2v_model, device):
            idxs = []
            pad = w2v_model.wv.vocab['eos'].index
            for text in texts:
                idx = []
                for word in ['bos'] + preprocess_string(text) + ['eos']:
                    if word in w2v_model.wv.vocab:
                        idx += [w2v_model.wv.vocab[word].index]
                idxs += [idx]
            mxl = np.max([len(s) for s in idxs]) + 1
            inp_idxs = []
            out_idxs = []
            masks = []
            for idx in idxs:
                inp_idxs += [idx + [pad for _ in range(mxl - len(idx) - 1)]]
                # Targets are the inputs shifted left by one position.
                out_idxs += [idx[1:] + [pad for _ in range(mxl - len(idx))]]
                masks += [[1 for _ in range(len(idx))] + [0 for _ in range(mxl - len(idx) - 1)]]
            # Transposed to time-major (seq, batch) layout.
            return torch.LongTensor(inp_idxs).transpose(0, 1).to(device), \
                   torch.LongTensor(out_idxs).transpose(0, 1).to(device), torch.BoolTensor(masks).transpose(0, 1).to(device)
        inp_idxs, out_idxs, masks = parse_text(texts, w2v_model, device)
        pred_prob = self.attr_decoder(inp_idxs, reps.repeat(inp_idxs.shape[0], 1, 1))
        return self.ce(pred_prob[masks], out_idxs[masks]).mean()

    def feat_loss(self, reps, out):
        '''Feature-reconstruction loss: negated mean of self.attr_decoder(reps, out).'''
        return -self.attr_decoder(reps, out).mean()
124 |
125 |
class Classifier(nn.Module):
    """Linear classification head that emits log-probabilities (pair with nn.NLLLoss).

    NOTE: logits are ``squeeze()``-d before the log-softmax, so an input with a
    batch of one (or a model with n_out == 1) loses that dimension in the output.
    """

    def __init__(self, n_hid, n_out):
        super(Classifier, self).__init__()
        self.n_hid, self.n_out = n_hid, n_out
        self.linear = nn.Linear(n_hid, n_out)

    def forward(self, x):
        logits = self.linear(x).squeeze()
        return torch.log_softmax(logits, dim=-1)

    def __repr__(self):
        return '{}(n_hid={}, n_out={})'.format(type(self).__name__, self.n_hid, self.n_out)
138 |
139 |
class Matcher(nn.Module):
    '''
    Score (target, source) node pairs for link prediction.

    The target representation is linearly projected (followed by dropout) and
    compared with the source representation, either by temperature-scaled cosine
    similarity or by a dot product scaled by sqrt(n_out).
    '''

    def __init__(self, n_hid, n_out, temperature = 0.1):
        super(Matcher, self).__init__()
        self.n_hid = n_hid
        self.linear = nn.Linear(n_hid, n_out)
        self.sqrt_hd = math.sqrt(n_out)
        self.drop = nn.Dropout(0.2)
        # Reduce over the last (feature) axis, exactly like the dot-product branch,
        # so both scoring modes agree and accept arbitrary leading batch dims.
        self.cosine = nn.CosineSimilarity(dim=-1)
        self.cache = None  # not read by any code visible here
        self.temperature = temperature

    def forward(self, x, ty, use_norm = True):
        '''
        Return one score per pair.

        Args:
            x: target representations, last dim n_hid.
            ty: source representations, last dim n_out, broadcastable with the projected x.
            use_norm: if True, cosine similarity / temperature; else dot product / sqrt(n_out).
        '''
        tx = self.drop(self.linear(x))
        if use_norm:
            return self.cosine(tx, ty) / self.temperature
        return (tx * ty).sum(dim=-1) / self.sqrt_hd

    def __repr__(self):
        return '{}(n_hid={})'.format(
            self.__class__.__name__, self.n_hid)
164 |
165 |
class GNN(nn.Module):
    '''
    Heterogeneous GNN encoder: a per-node-type input projection followed by a
    stack of GeneralConv layers.

    Args:
        in_dim: raw node-feature width (shared by all node types).
        n_hid: hidden width of every layer.
        num_types / num_relations: numbers of node types and relation ids.
        n_heads, dropout, conv_name, use_RTE: forwarded to each GeneralConv.
        n_layers: number of GeneralConv layers.
        prev_norm: use_norm flag for all but the last layer.
        last_norm: use_norm flag for the last layer.
    '''
    def __init__(self, in_dim, n_hid, num_types, num_relations, n_heads, n_layers, dropout = 0.2, conv_name = 'hgt', prev_norm = False, last_norm = False, use_RTE = True):
        super(GNN, self).__init__()
        self.gcs = nn.ModuleList()
        self.num_types = num_types
        self.in_dim = in_dim
        self.n_hid = n_hid
        self.adapt_ws = nn.ModuleList()
        self.drop = nn.Dropout(dropout)
        for t in range(num_types):
            self.adapt_ws.append(nn.Linear(in_dim, n_hid))
        for l in range(n_layers - 1):
            self.gcs.append(GeneralConv(conv_name, n_hid, n_hid, num_types, num_relations, n_heads, dropout, use_norm = prev_norm, use_RTE = use_RTE))
        self.gcs.append(GeneralConv(conv_name, n_hid, n_hid, num_types, num_relations, n_heads, dropout, use_norm = last_norm, use_RTE = use_RTE))

    def forward(self, node_feature, node_type, edge_time, edge_index, edge_type):
        '''
        Encode all nodes of a sampled sub-graph; returns a (num_nodes, n_hid) tensor.

        Rows whose node_type is outside range(num_types) are left at zero before
        the convolutions.
        '''
        # Allocate directly on the input's device instead of on the CPU plus a copy.
        res = torch.zeros(node_feature.size(0), self.n_hid, device=node_feature.device)
        for t_id in range(self.num_types):
            idx = (node_type == t_id)
            if not idx.any():
                continue
            res[idx] = torch.tanh(self.adapt_ws[t_id](node_feature[idx]))
        meta_xs = self.drop(res)
        del res
        for gc in self.gcs:
            # NOTE: GeneralConv takes (x, node_type, edge_index, edge_type, edge_time).
            meta_xs = gc(meta_xs, node_type, edge_index, edge_type, edge_time)
        return meta_xs
193 |
194 |
class RNNModel(nn.Module):
    """Conditional LSTM text decoder.

    Token embeddings (width nhid) are concatenated with a per-position context
    vector (width ninp), mixed back to nhid by a GELU-activated linear layer and
    run through an nlayers LSTM; a linear layer maps outputs to vocabulary logits.
    Inputs are time-major (nn.LSTM's default batch_first=False).
    """
    def __init__(self, n_word, ninp, nhid, nlayers, dropout=0.2):
        super(RNNModel, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.rnn = nn.LSTM(nhid, nhid, nlayers)
        self.encoder = nn.Embedding(n_word, nhid)
        self.decoder = nn.Linear(nhid, n_word)
        self.adp = nn.Linear(ninp + nhid, nhid)

    def forward(self, inp, hidden = None):
        """Return vocabulary logits for token indices `inp` given context `hidden`.

        When `hidden` is None an all-zero context is used, since self.adp always
        expects ninp + nhid input features.
        """
        emb = self.encoder(inp)
        if hidden is None:
            ctx_dim = self.adp.in_features - emb.size(-1)
            hidden = emb.new_zeros((*emb.shape[:-1], ctx_dim))
        emb = torch.cat((emb, hidden), dim=-1)
        emb = F.gelu(self.adp(emb))
        output, _ = self.rnn(emb)
        return self.decoder(self.drop(output))

    def from_w2v(self, w2v):
        """Load frozen word2vec embeddings and tie the output projection to them.

        `w2v` presumably is an (n_word, nhid) float tensor — TODO confirm.
        """
        self.encoder.weight.data = w2v
        self.decoder.weight = self.encoder.weight

        self.encoder.weight.requires_grad = False
        self.decoder.weight.requires_grad = False
219 |
--------------------------------------------------------------------------------
/GPT_GNN/data.py:
--------------------------------------------------------------------------------
1 | import json, os
2 | import math, copy, time
3 | import numpy as np
4 | from collections import defaultdict
5 | import pandas as pd
6 |
7 | import math
8 | from tqdm import tqdm
9 |
10 | import seaborn as sb
11 | import matplotlib.pyplot as plt
12 | import matplotlib.cm as cm
13 |
14 | from .utils import *
15 |
16 | import dill
17 | from functools import partial
18 | import multiprocessing as mp
19 |
class Graph():
    '''
    Heterogeneous graph with timestamped, typed edges.

    node_forward and node_bacward are only used while the graph is being built;
    afterwards the per-type features are stored in node_feature.
    '''
    def __init__(self):
        super(Graph, self).__init__()
        # node_forward: node_type -> {node id: serial index within that type}
        # node_bacward: node_type -> list of node dicts, indexed by serial index
        # node_feature: node_type -> feature table (a DataFrame after preprocessing)
        self.node_forward = defaultdict(lambda: {})
        self.node_bacward = defaultdict(lambda: [])
        self.node_feature = defaultdict(lambda: [])

        # edge_list[target_type][source_type][relation_type][target_id][source_id] = time
        # NOTE(review): the innermost factory returns the `int` type itself (not 0) for a
        # missing source_id; readers in this file check membership before indexing.
        self.edge_list = defaultdict(                         #target_type
                            lambda: defaultdict(              #source_type
                                lambda: defaultdict(          #relation_type
                                    lambda: defaultdict(      #target_id
                                        lambda: defaultdict(  #source_id
                                            lambda: int       # time
                                        )))))
        self.times = {}

    def add_node(self, node):
        '''Register `node` (dict with 'type' and 'id') if unseen; return its per-type serial index.'''
        nfl = self.node_forward[node['type']]
        if node['id'] not in nfl:
            self.node_bacward[node['type']] += [node]
            ser = len(nfl)
            nfl[node['id']] = ser
            return ser
        return nfl[node['id']]

    def add_edge(self, source_node, target_node, time = None, relation_type = None, directed = True):
        '''
        Add the edge source -> target under `relation_type`, plus its reverse.

        With directed=True the reverse edge is stored as 'rev_' + relation_type
        (so relation_type must then be a str); otherwise the same relation name is reused.
        '''
        edge = [self.add_node(source_node), self.add_node(target_node)]
        self.edge_list[target_node['type']][source_node['type']][relation_type][edge[1]][edge[0]] = time
        if directed:
            self.edge_list[source_node['type']][target_node['type']]['rev_' + relation_type][edge[0]][edge[1]] = time
        else:
            self.edge_list[source_node['type']][target_node['type']][relation_type][edge[0]][edge[1]] = time
        self.times[time] = True

    def update_node(self, node):
        '''Register `node` if needed and copy over keys missing from the stored node dict.'''
        nbl = self.node_bacward[node['type']]
        ser = self.add_node(node)
        for k in node:
            if k not in nbl[ser]:
                nbl[ser][k] = node[k]

    def get_meta_graph(self):
        '''Return every (target_type, source_type, relation_type) triple present in edge_list.'''
        return [(target_type, source_type, r_type)
                for target_type in self.edge_list
                for source_type in self.edge_list[target_type]
                for r_type in self.edge_list[target_type][source_type]]

    def get_types(self):
        return list(self.node_feature.keys())
85 |
86 |
87 |
def sample_subgraph(graph, time_range, sampled_depth = 2, sampled_number = 8, inp = None, feature_extractor = feature_OAG):
    '''
    Sample a sub-graph by budget-based expansion from the seed nodes in `inp`.

    A per-type budget records, for every not-yet-sampled neighbour of the sampled
    nodes, an accumulated normalised-degree score and a time. In each of the
    `sampled_depth` rounds, up to `sampled_number` nodes per type are drawn from the
    budget with probability proportional to the squared score, and their neighbours
    are added to the budget. A node with many neighbours under one relation contributes
    only a random subset of `sampled_number` of them. Neighbours whose time exceeds
    max(time_range keys) are ignored. Sampled nodes are stored in layer_data; the
    sampled adjacency is then rebuilt from the original graph.

    Args:
        graph: Graph to sample from.
        time_range: non-empty dict whose keys are admissible times (only the maximum is used).
        sampled_depth: number of expansion rounds.
        sampled_number: per-type sample size per round, and per-relation neighbour cap.
        inp: {node_type: iterable of (node_id, time)} seed nodes.
        feature_extractor: callable(layer_data, graph) -> (feature, times, indxs, texts).

    Returns:
        feature, times, edge_list, indxs, texts, where
        edge_list[target_type][source_type][relation_type] is a list of
        [target_ser, source_ser] pairs (plus a 'self' loop per sampled node).
    '''
    # Latest admissible time; loop-invariant, so computed once.
    max_time = np.max(list(time_range.keys()))

    layer_data = defaultdict( #target_type
        lambda: {} # {target_id: [ser, time]}
    )
    budget = defaultdict( #source_type
        lambda: defaultdict( #source_id
            lambda: [0., 0] #[sampled_score, time]
        ))

    def add_budget(te, target_id, target_time, layer_data, budget):
        # Add the neighbours of target_id (over every non-'self' relation in te) to the budget.
        for source_type in te:
            tes = te[source_type]
            for relation_type in tes:
                if relation_type == 'self' or target_id not in tes[relation_type]:
                    continue
                adl = tes[relation_type][target_id]
                if len(adl) < sampled_number:
                    sampled_ids = list(adl.keys())
                else:
                    sampled_ids = np.random.choice(list(adl.keys()), sampled_number, replace = False)
                for source_id in sampled_ids:
                    source_time = adl[source_id]
                    if source_time is None:
                        # Untimed edges inherit the time of the node being expanded.
                        source_time = target_time
                    if source_time > max_time or source_id in layer_data[source_type]:
                        continue
                    budget[source_type][source_id][0] += 1. / len(sampled_ids)
                    budget[source_type][source_id][1] = source_time

    # First add the seed nodes, then update the budget with their neighbours.
    for _type in inp:
        for _id, _time in inp[_type]:
            layer_data[_type][_id] = [len(layer_data[_type]), _time]
    for _type in inp:
        te = graph.edge_list[_type]
        for _id, _time in inp[_type]:
            add_budget(te, _id, _time, layer_data, budget)

    # Expand the sampled graph for sampled_depth rounds; each round draws a fixed
    # number of nodes per type from the budget, weighted by accumulated degree.
    for layer in range(sampled_depth):
        sts = list(budget.keys())
        for source_type in sts:
            te = graph.edge_list[source_type]
            keys = np.array(list(budget[source_type].keys()))
            if sampled_number > len(keys):
                # Few candidates: take them all.
                sampled_ids = np.arange(len(keys))
            else:
                # Sample without replacement, probability proportional to score squared.
                score = np.array(list(budget[source_type].values()))[:,0] ** 2
                score = score / np.sum(score)
                sampled_ids = np.random.choice(len(score), sampled_number, p = score, replace = False)
            sampled_keys = keys[sampled_ids]
            # First add the sampled nodes, then update the budget.
            for k in sampled_keys:
                layer_data[source_type][k] = [len(layer_data[source_type]), budget[source_type][k][1]]
            for k in sampled_keys:
                add_budget(te, k, budget[source_type][k][1], layer_data, budget)
                budget[source_type].pop(k)

    # Features and times for the sampled nodes.
    feature, times, indxs, texts = feature_extractor(layer_data, graph)

    edge_list = defaultdict( #target_type
        lambda: defaultdict( #source_type
            lambda: defaultdict( #relation_type
                lambda: [] # [target_id, source_id]
            )))
    for _type in layer_data:
        for _key in layer_data[_type]:
            _ser = layer_data[_type][_key][0]
            edge_list[_type][_type]['self'] += [[_ser, _ser]]

    # Rebuild the sampled adjacency: keep every original edge whose two
    # endpoints were both sampled.
    for target_type in graph.edge_list:
        te = graph.edge_list[target_type]
        tld = layer_data[target_type]
        for source_type in te:
            tes = te[source_type]
            sld = layer_data[source_type]
            for relation_type in tes:
                tesr = tes[relation_type]
                for target_key in tld:
                    if target_key not in tesr:
                        continue
                    target_ser = tld[target_key][0]
                    for source_key in tesr[target_key]:
                        if source_key in sld:
                            source_ser = sld[source_key][0]
                            edge_list[target_type][source_type][relation_type] += [[target_ser, source_ser]]
    return feature, times, edge_list, indxs, texts
212 |
def to_torch(feature, time, edge_list, graph):
    '''
    Transform a sampled sub-graph into PyTorch tensors.

    Nodes are laid out type by type in graph.get_types() order, followed by a
    'fake_paper' block (sharing the 'paper' type id) when present.

    Returns:
        node_feature: FloatTensor, one row per node.
        node_type: LongTensor of type ids.
        edge_time: LongTensor, time[target] - time[source] + 120 per edge.
        edge_index: LongTensor (2, E) when E > 0; row 0 = source, row 1 = target.
        edge_type: LongTensor of relation ids.
        node_dict: {node_type: [start offset, type id]} to trace nodes back to the original graph.
        edge_dict: {relation_type: relation id}; 'self' gets an id no other relation uses.
    '''
    node_dict = {}
    node_feature = []
    node_type = []
    node_time = []
    edge_index = []
    edge_type = []
    edge_time = []

    node_num = 0
    types = graph.get_types()
    for t in types:
        node_dict[t] = [node_num, len(node_dict)]
        node_num += len(feature[t])

    if 'fake_paper' in feature:
        node_dict['fake_paper'] = [node_num, node_dict['paper'][1]]
        node_num += len(feature['fake_paper'])
        types += ['fake_paper']

    for t in types:
        node_feature += list(feature[t])
        node_time += list(time[t])
        node_type += [node_dict[t][1] for _ in range(len(feature[t]))]

    meta_graph = graph.get_meta_graph()
    edge_dict = {e[2]: i for i, e in enumerate(meta_graph)}
    # Relation names may repeat across type pairs, leaving edge_dict smaller than
    # meta_graph; len(meta_graph) is then still unused and below num_relations.
    edge_dict['self'] = len(meta_graph)

    for target_type in edge_list:
        for source_type in edge_list[target_type]:
            for relation_type in edge_list[target_type][source_type]:
                for ti, si in edge_list[target_type][source_type][relation_type]:
                    tid, sid = ti + node_dict[target_type][0], si + node_dict[source_type][0]
                    edge_index += [[sid, tid]]
                    edge_type += [edge_dict[relation_type]]
                    # Times range over 1900 - 2020, so the largest span is 120;
                    # the offset keeps relative times non-negative.
                    edge_time += [node_time[tid] - node_time[sid] + 120]
    # Stack into one ndarray first: building a tensor from a list of ndarrays is slow.
    node_feature = torch.from_numpy(np.asarray(node_feature, dtype=np.float32))
    node_type = torch.LongTensor(node_type)
    edge_time = torch.LongTensor(edge_time)
    edge_index = torch.LongTensor(edge_index).t()
    edge_type = torch.LongTensor(edge_type)
    return node_feature, node_type, edge_time, edge_index, edge_type, node_dict, edge_dict
263 |
264 |
--------------------------------------------------------------------------------
/example_reddit/GPT_GNN/data.py:
--------------------------------------------------------------------------------
1 | import json, os
2 | import math, copy, time
3 | import numpy as np
4 | from collections import defaultdict
5 | import pandas as pd
6 |
7 | import math
8 | from tqdm import tqdm
9 |
10 | import seaborn as sb
11 | import matplotlib.pyplot as plt
12 | import matplotlib.cm as cm
13 |
14 | from .utils import *
15 |
16 | import dill
17 | from functools import partial
18 | import multiprocessing as mp
19 |
class Graph():
    '''
    Heterogeneous graph with typed nodes and typed, timestamped edges.

    node_forward and node_bacward are only used when building the data.
    Afterwards the node dicts are transformed into node_feature (e.g. one
    DataFrame per node type).

    node_forward: {node_type: {node id: serial number within that type}}
    node_bacward: {node_type: [node dict, ...]}, indexed by serial number
    node_feature: {node_type: features of all nodes of that type}
    edge_list:    edge_list[target_type][source_type][relation_type][target_ser][source_ser] = time
    times:        every time value passed to add_edge, stored as dict keys (value True)
    '''
    def __init__(self):
        super(Graph, self).__init__()
        self.node_forward = defaultdict(lambda: {})
        self.node_bacward = defaultdict(lambda: [])
        self.node_feature = defaultdict(lambda: [])
        self.edge_list = defaultdict(                 # target_type
            lambda: defaultdict(                      # source_type
                lambda: defaultdict(                  # relation_type
                    lambda: defaultdict(              # target_ser
                        lambda: defaultdict(          # source_ser -> time
                            # NOTE(review): this factory returns the `int` type
                            # object itself, not 0; add_edge always assigns the time.
                            lambda: int
                        )))))
        self.times = {}

    def add_node(self, node):
        '''
        Register node (a dict with at least 'id' and 'type') if its id is new
        for its type, storing the dict itself in node_bacward.

        Returns the node's serial number within its type.
        '''
        nfl = self.node_forward[node['type']]
        if node['id'] not in nfl:
            self.node_bacward[node['type']] += [node]
            ser = len(nfl)
            nfl[node['id']] = ser
            return ser
        return nfl[node['id']]

    def add_edge(self, source_node, target_node, time = None, relation_type = None, directed = True):
        '''
        Add the edge source_node -> target_node (relation_type, time) and the
        opposite edge target_node -> source_node, named 'rev_' + relation_type
        when directed, otherwise relation_type. Both nodes are registered via
        add_node, and time is recorded in self.times.

        Raises:
            TypeError: if directed and relation_type is not a str (the reverse
                relation name cannot be built). Nothing is added in that case.
        '''
        # Validate before touching any state so a bad call never leaves a
        # half-written (forward-only) edge behind.
        if directed and not isinstance(relation_type, str):
            raise TypeError('add_edge: a directed edge needs a str relation_type, got %r'
                            % (relation_type,))
        edge = [self.add_node(source_node), self.add_node(target_node)]
        self.edge_list[target_node['type']][source_node['type']][relation_type][edge[1]][edge[0]] = time
        if directed:
            self.edge_list[source_node['type']][target_node['type']]['rev_' + relation_type][edge[0]][edge[1]] = time
        else:
            self.edge_list[source_node['type']][target_node['type']][relation_type][edge[0]][edge[1]] = time
        self.times[time] = True

    def update_node(self, node):
        '''
        Register node if needed, then copy every key of node that the stored
        node dict does not have yet (existing values are kept).
        '''
        nbl = self.node_bacward[node['type']]
        ser = self.add_node(node)
        for k in node:
            if k not in nbl[ser]:
                nbl[ser][k] = node[k]

    def get_meta_graph(self):
        '''Return every (target_type, source_type, relation_type) triple in edge_list.'''
        return [(target_type, source_type, r_type)
                for target_type, by_source in self.edge_list.items()
                for source_type, by_relation in by_source.items()
                for r_type in by_relation]

    def get_types(self):
        '''Return the node types that have an entry in node_feature.'''
        return list(self.node_feature.keys())
85 |
86 |
87 |
def sample_subgraph(graph, time_range, sampled_depth = 2, sampled_number = 8, inp = None, feature_extractor = feature_OAG):
    '''
    Sample a sub-graph based on the connection of other nodes with the
    currently sampled nodes.

    A budget is kept per node type: budget[type][id] = [score, time]. Every
    sampled node adds 1 / k to the score of each neighbour in its (sub-sampled)
    neighbour list of size k. In each of the sampled_depth rounds, up to
    sampled_number budget nodes per type are drawn with probability
    proportional to score ** 2, moved into layer_data, and their own
    neighbours are added to the budget. Finally, the adjacency among all
    sampled nodes is rebuilt from graph.edge_list.

    Args:
        graph: Graph whose edge_list is indexed
               [target_type][source_type][relation_type][target_id][source_id] -> time
        time_range: dict whose keys are times; neighbours whose time is later
                    than the largest key are never added to the budget.
                    Must be non-empty.
        sampled_depth: number of expansion rounds
        sampled_number: cap on neighbours taken per neighbour list, and on
                        nodes drawn per type per round
        inp: {node_type: iterable of (node_id, time)} seed nodes
        feature_extractor: called as feature_extractor(layer_data, graph);
                           must return (feature, times, indxs, texts)

    Returns:
        feature, times, edge_list, indxs, texts -- edge_list is
        {target_type: {source_type: {relation_type: [[target_ser, source_ser], ...]}}}
        using the serial numbers stored in layer_data, plus a 'self' loop for
        every sampled node.
    '''
    layer_data = defaultdict(     # node_type
        lambda: {}                # {node_id: [ser, time]}
    )
    budget = defaultdict(         # node_type
        lambda: defaultdict(      # node_id
            lambda: [0., 0]       # [sampled_score, time]
        ))
    # Latest admissible neighbour time; computed once instead of per neighbour.
    max_time = np.max(list(time_range.keys()))

    def add_budget(te, target_id, target_time, layer_data, budget):
        '''
        Add the neighbours of target_id found in te (= graph.edge_list[target_type])
        to the budget. High-degree nodes (e.g. fields, venues) only contribute
        sampled_number randomly chosen neighbours per relation. Neighbours that
        are already sampled or later than max_time are skipped; a neighbour
        whose edge time is None inherits target_time.
        '''
        for source_type in te:
            tes = te[source_type]
            for relation_type in tes:
                if relation_type == 'self' or target_id not in tes[relation_type]:
                    continue
                adl = tes[relation_type][target_id]
                if len(adl) < sampled_number:
                    sampled_ids = list(adl.keys())
                else:
                    sampled_ids = np.random.choice(list(adl.keys()), sampled_number, replace = False)
                for source_id in sampled_ids:
                    source_time = adl[source_id]
                    if source_time is None:
                        source_time = target_time
                    if source_time > max_time or source_id in layer_data[source_type]:
                        continue
                    budget[source_type][source_id][0] += 1. / len(sampled_ids)
                    budget[source_type][source_id][1] = source_time

    # First add the seed nodes, then update the budget from them.
    for _type in inp:
        for _id, _time in inp[_type]:
            layer_data[_type][_id] = [len(layer_data[_type]), _time]
    for _type in inp:
        te = graph.edge_list[_type]
        for _id, _time in inp[_type]:
            add_budget(te, _id, _time, layer_data, budget)

    # Recursively expand the sampled graph sampled_depth times; each round
    # samples a fixed number of nodes per budget type by accumulated degree.
    for _ in range(sampled_depth):
        sts = list(budget.keys())
        for source_type in sts:
            te = graph.edge_list[source_type]
            keys = np.array(list(budget[source_type].keys()))
            if sampled_number > len(keys):
                # Few enough candidates: take all of them.
                sampled_ids = np.arange(len(keys))
            else:
                # Sample based on the squared accumulated degree.
                score = np.array(list(budget[source_type].values()))[:,0] ** 2
                score = score / np.sum(score)
                sampled_ids = np.random.choice(len(score), sampled_number, p = score, replace = False)
            sampled_keys = keys[sampled_ids]
            # First add the sampled nodes, then update the budget.
            for k in sampled_keys:
                layer_data[source_type][k] = [len(layer_data[source_type]), budget[source_type][k][1]]
            for k in sampled_keys:
                add_budget(te, k, budget[source_type][k][1], layer_data, budget)
                budget[source_type].pop(k)

    # Prepare feature, time and adjacency matrix for the sampled graph.
    feature, times, indxs, texts = feature_extractor(layer_data, graph)

    edge_list = defaultdict(      # target_type
        lambda: defaultdict(      # source_type
            lambda: defaultdict(  # relation_type
                lambda: []        # [[target_ser, source_ser], ...]
            )))
    for _type in layer_data:
        for _key in layer_data[_type]:
            _ser = layer_data[_type][_key][0]
            edge_list[_type][_type]['self'] += [[_ser, _ser]]

    # Reconstruct the sampled adjacency by checking whether each link
    # (target_id, source_id) exists in the original graph.
    for target_type in graph.edge_list:
        te = graph.edge_list[target_type]
        tld = layer_data[target_type]
        for source_type in te:
            tes = te[source_type]
            sld = layer_data[source_type]
            for relation_type in tes:
                tesr = tes[relation_type]
                for target_key in tld:
                    if target_key not in tesr:
                        continue
                    target_ser = tld[target_key][0]
                    for source_key in tesr[target_key]:
                        if source_key in sld:
                            source_ser = sld[source_key][0]
                            edge_list[target_type][source_type][relation_type] += [[target_ser, source_ser]]
    return feature, times, edge_list, indxs, texts
212 |
def to_torch(feature, time, edge_list, graph):
    '''
    Transform a sampled sub-graph into pytorch Tensors.

    Nodes of all types are stacked into one index space, ordered by
    graph.get_types(), with 'fake_paper' appended last when it is a key of
    feature.

    Args:
        feature: {node_type: sequence of per-node feature rows}
        time: {node_type: sequence of per-node times}, aligned with feature
        edge_list: {target_type: {source_type: {relation_type:
                   [[target_ser, source_ser], ...]}}}, serials local to each type
        graph: object providing get_types() and get_meta_graph()

    Returns:
        node_feature: FloatTensor built from the stacked feature rows
        node_type: LongTensor with the type id (node_dict[t][1]) of every node
        edge_time: LongTensor, node_time[target] - node_time[source] + 120 per edge
        edge_index: LongTensor of shape (2, num_edges); row 0 = source, row 1 = target
        edge_type: LongTensor with the edge_dict id of every edge
        node_dict: {node_type: [offset in the stacked index space, type_id]};
                   used to trace nodes back to the original graph
        edge_dict: {relation_type: relation_id}; 'self' gets the last id
    '''
    node_dict = {}
    node_feature = []
    node_type = []
    node_time = []
    edge_index = []
    edge_type = []
    edge_time = []

    node_num = 0
    # Copy so that appending 'fake_paper' below never mutates a list owned by graph.
    types = list(graph.get_types())
    for t in types:
        node_dict[t] = [node_num, len(node_dict)]
        node_num += len(feature[t])

    if 'fake_paper' in feature:
        # Fake papers share the 'paper' type id but get their own offset.
        node_dict['fake_paper'] = [node_num, node_dict['paper'][1]]
        node_num += len(feature['fake_paper'])
        types += ['fake_paper']

    for t in types:
        node_feature += list(feature[t])
        node_time += list(time[t])
        node_type += [node_dict[t][1]] * len(feature[t])

    # Dense ids in first-seen order: a relation name that appears under several
    # (target, source) type pairs gets a single id, so 'self' = len(edge_dict)
    # can never collide with an id already in use.
    edge_dict = {}
    for _, _, relation_type in graph.get_meta_graph():
        if relation_type not in edge_dict:
            edge_dict[relation_type] = len(edge_dict)
    edge_dict['self'] = len(edge_dict)

    for target_type in edge_list:
        for source_type in edge_list[target_type]:
            for relation_type in edge_list[target_type][source_type]:
                for ti, si in edge_list[target_type][source_type][relation_type]:
                    tid = ti + node_dict[target_type][0]
                    sid = si + node_dict[source_type][0]
                    edge_index += [[sid, tid]]
                    edge_type += [edge_dict[relation_type]]
                    # Our time ranges from 1900 - 2020, largest span is 120,
                    # so the +120 shift keeps relative edge times non-negative.
                    edge_time += [node_time[tid] - node_time[sid] + 120]
    node_feature = torch.FloatTensor(node_feature)
    node_type = torch.LongTensor(node_type)
    edge_time = torch.LongTensor(edge_time)
    # view(-1, 2) keeps the (2, num_edges) layout even when there are no edges.
    edge_index = torch.LongTensor(edge_index).view(-1, 2).t()
    edge_type = torch.LongTensor(edge_type)
    return node_feature, node_type, edge_time, edge_index, edge_type, node_dict, edge_dict
263 |
264 |
--------------------------------------------------------------------------------
/example_OAG/GPT_GNN/data.py:
--------------------------------------------------------------------------------
1 | import json, os
2 | import math, copy, time
3 | import numpy as np
4 | from collections import defaultdict
5 | import pandas as pd
6 |
7 | import math
8 | from tqdm import tqdm
9 |
10 | import seaborn as sb
11 | import matplotlib.pyplot as plt
12 | import matplotlib.cm as cm
13 |
14 | from .utils import *
15 |
16 | import dill
17 | from functools import partial
18 | import multiprocessing as mp
19 |
class Graph():
    '''
    Heterogeneous graph with typed nodes and typed, timestamped edges.

    node_forward and node_bacward are only used when building the data.
    Afterwards the node dicts are transformed into node_feature (e.g. one
    DataFrame per node type).

    node_forward: {node_type: {node id: serial number within that type}}
    node_bacward: {node_type: [node dict, ...]}, indexed by serial number
    node_feature: {node_type: features of all nodes of that type}
    edge_list:    edge_list[target_type][source_type][relation_type][target_ser][source_ser] = time
    times:        every time value passed to add_edge, stored as dict keys (value True)
    '''
    def __init__(self):
        super(Graph, self).__init__()
        self.node_forward = defaultdict(lambda: {})
        self.node_bacward = defaultdict(lambda: [])
        self.node_feature = defaultdict(lambda: [])
        self.edge_list = defaultdict(                 # target_type
            lambda: defaultdict(                      # source_type
                lambda: defaultdict(                  # relation_type
                    lambda: defaultdict(              # target_ser
                        lambda: defaultdict(          # source_ser -> time
                            # NOTE(review): this factory returns the `int` type
                            # object itself, not 0; add_edge always assigns the time.
                            lambda: int
                        )))))
        self.times = {}

    def add_node(self, node):
        '''
        Register node (a dict with at least 'id' and 'type') if its id is new
        for its type, storing the dict itself in node_bacward.

        Returns the node's serial number within its type.
        '''
        nfl = self.node_forward[node['type']]
        if node['id'] not in nfl:
            self.node_bacward[node['type']] += [node]
            ser = len(nfl)
            nfl[node['id']] = ser
            return ser
        return nfl[node['id']]

    def add_edge(self, source_node, target_node, time = None, relation_type = None, directed = True):
        '''
        Add the edge source_node -> target_node (relation_type, time) and the
        opposite edge target_node -> source_node, named 'rev_' + relation_type
        when directed, otherwise relation_type. Both nodes are registered via
        add_node, and time is recorded in self.times.

        Raises:
            TypeError: if directed and relation_type is not a str (the reverse
                relation name cannot be built). Nothing is added in that case.
        '''
        # Validate before touching any state so a bad call never leaves a
        # half-written (forward-only) edge behind.
        if directed and not isinstance(relation_type, str):
            raise TypeError('add_edge: a directed edge needs a str relation_type, got %r'
                            % (relation_type,))
        edge = [self.add_node(source_node), self.add_node(target_node)]
        self.edge_list[target_node['type']][source_node['type']][relation_type][edge[1]][edge[0]] = time
        if directed:
            self.edge_list[source_node['type']][target_node['type']]['rev_' + relation_type][edge[0]][edge[1]] = time
        else:
            self.edge_list[source_node['type']][target_node['type']][relation_type][edge[0]][edge[1]] = time
        self.times[time] = True

    def update_node(self, node):
        '''
        Register node if needed, then copy every key of node that the stored
        node dict does not have yet (existing values are kept).
        '''
        nbl = self.node_bacward[node['type']]
        ser = self.add_node(node)
        for k in node:
            if k not in nbl[ser]:
                nbl[ser][k] = node[k]

    def get_meta_graph(self):
        '''Return every (target_type, source_type, relation_type) triple in edge_list.'''
        return [(target_type, source_type, r_type)
                for target_type, by_source in self.edge_list.items()
                for source_type, by_relation in by_source.items()
                for r_type in by_relation]

    def get_types(self):
        '''Return the node types that have an entry in node_feature.'''
        return list(self.node_feature.keys())
85 |
86 |
87 |
def sample_subgraph(graph, time_range, sampled_depth = 2, sampled_number = 8, inp = None, feature_extractor = feature_OAG):
    '''
    Sample a sub-graph based on the connection of other nodes with the
    currently sampled nodes.

    A budget is kept per node type: budget[type][id] = [score, time]. Every
    sampled node adds 1 / k to the score of each neighbour in its (sub-sampled)
    neighbour list of size k. In each of the sampled_depth rounds, up to
    sampled_number budget nodes per type are drawn with probability
    proportional to score ** 2, moved into layer_data, and their own
    neighbours are added to the budget. Finally, the adjacency among all
    sampled nodes is rebuilt from graph.edge_list.

    Args:
        graph: Graph whose edge_list is indexed
               [target_type][source_type][relation_type][target_id][source_id] -> time
        time_range: dict whose keys are times; neighbours whose time is later
                    than the largest key are never added to the budget.
                    Must be non-empty.
        sampled_depth: number of expansion rounds
        sampled_number: cap on neighbours taken per neighbour list, and on
                        nodes drawn per type per round
        inp: {node_type: iterable of (node_id, time)} seed nodes
        feature_extractor: called as feature_extractor(layer_data, graph);
                           must return (feature, times, indxs, texts)

    Returns:
        feature, times, edge_list, indxs, texts -- edge_list is
        {target_type: {source_type: {relation_type: [[target_ser, source_ser], ...]}}}
        using the serial numbers stored in layer_data, plus a 'self' loop for
        every sampled node.
    '''
    layer_data = defaultdict(     # node_type
        lambda: {}                # {node_id: [ser, time]}
    )
    budget = defaultdict(         # node_type
        lambda: defaultdict(      # node_id
            lambda: [0., 0]       # [sampled_score, time]
        ))
    # Latest admissible neighbour time; computed once instead of per neighbour.
    max_time = np.max(list(time_range.keys()))

    def add_budget(te, target_id, target_time, layer_data, budget):
        '''
        Add the neighbours of target_id found in te (= graph.edge_list[target_type])
        to the budget. High-degree nodes (e.g. fields, venues) only contribute
        sampled_number randomly chosen neighbours per relation. Neighbours that
        are already sampled or later than max_time are skipped; a neighbour
        whose edge time is None inherits target_time.
        '''
        for source_type in te:
            tes = te[source_type]
            for relation_type in tes:
                if relation_type == 'self' or target_id not in tes[relation_type]:
                    continue
                adl = tes[relation_type][target_id]
                if len(adl) < sampled_number:
                    sampled_ids = list(adl.keys())
                else:
                    sampled_ids = np.random.choice(list(adl.keys()), sampled_number, replace = False)
                for source_id in sampled_ids:
                    source_time = adl[source_id]
                    if source_time is None:
                        source_time = target_time
                    if source_time > max_time or source_id in layer_data[source_type]:
                        continue
                    budget[source_type][source_id][0] += 1. / len(sampled_ids)
                    budget[source_type][source_id][1] = source_time

    # First add the seed nodes, then update the budget from them.
    for _type in inp:
        for _id, _time in inp[_type]:
            layer_data[_type][_id] = [len(layer_data[_type]), _time]
    for _type in inp:
        te = graph.edge_list[_type]
        for _id, _time in inp[_type]:
            add_budget(te, _id, _time, layer_data, budget)

    # Recursively expand the sampled graph sampled_depth times; each round
    # samples a fixed number of nodes per budget type by accumulated degree.
    for _ in range(sampled_depth):
        sts = list(budget.keys())
        for source_type in sts:
            te = graph.edge_list[source_type]
            keys = np.array(list(budget[source_type].keys()))
            if sampled_number > len(keys):
                # Few enough candidates: take all of them.
                sampled_ids = np.arange(len(keys))
            else:
                # Sample based on the squared accumulated degree.
                score = np.array(list(budget[source_type].values()))[:,0] ** 2
                score = score / np.sum(score)
                sampled_ids = np.random.choice(len(score), sampled_number, p = score, replace = False)
            sampled_keys = keys[sampled_ids]
            # First add the sampled nodes, then update the budget.
            for k in sampled_keys:
                layer_data[source_type][k] = [len(layer_data[source_type]), budget[source_type][k][1]]
            for k in sampled_keys:
                add_budget(te, k, budget[source_type][k][1], layer_data, budget)
                budget[source_type].pop(k)

    # Prepare feature, time and adjacency matrix for the sampled graph.
    feature, times, indxs, texts = feature_extractor(layer_data, graph)

    edge_list = defaultdict(      # target_type
        lambda: defaultdict(      # source_type
            lambda: defaultdict(  # relation_type
                lambda: []        # [[target_ser, source_ser], ...]
            )))
    for _type in layer_data:
        for _key in layer_data[_type]:
            _ser = layer_data[_type][_key][0]
            edge_list[_type][_type]['self'] += [[_ser, _ser]]

    # Reconstruct the sampled adjacency by checking whether each link
    # (target_id, source_id) exists in the original graph.
    for target_type in graph.edge_list:
        te = graph.edge_list[target_type]
        tld = layer_data[target_type]
        for source_type in te:
            tes = te[source_type]
            sld = layer_data[source_type]
            for relation_type in tes:
                tesr = tes[relation_type]
                for target_key in tld:
                    if target_key not in tesr:
                        continue
                    target_ser = tld[target_key][0]
                    for source_key in tesr[target_key]:
                        if source_key in sld:
                            source_ser = sld[source_key][0]
                            edge_list[target_type][source_type][relation_type] += [[target_ser, source_ser]]
    return feature, times, edge_list, indxs, texts
212 |
def to_torch(feature, time, edge_list, graph):
    '''
    Transform a sampled sub-graph into pytorch Tensors.

    Nodes of all types are stacked into one index space, ordered by
    graph.get_types(), with 'fake_paper' appended last when it is a key of
    feature.

    Args:
        feature: {node_type: sequence of per-node feature rows}
        time: {node_type: sequence of per-node times}, aligned with feature
        edge_list: {target_type: {source_type: {relation_type:
                   [[target_ser, source_ser], ...]}}}, serials local to each type
        graph: object providing get_types() and get_meta_graph()

    Returns:
        node_feature: FloatTensor built from the stacked feature rows
        node_type: LongTensor with the type id (node_dict[t][1]) of every node
        edge_time: LongTensor, node_time[target] - node_time[source] + 120 per edge
        edge_index: LongTensor of shape (2, num_edges); row 0 = source, row 1 = target
        edge_type: LongTensor with the edge_dict id of every edge
        node_dict: {node_type: [offset in the stacked index space, type_id]};
                   used to trace nodes back to the original graph
        edge_dict: {relation_type: relation_id}; 'self' gets the last id
    '''
    node_dict = {}
    node_feature = []
    node_type = []
    node_time = []
    edge_index = []
    edge_type = []
    edge_time = []

    node_num = 0
    # Copy so that appending 'fake_paper' below never mutates a list owned by graph.
    types = list(graph.get_types())
    for t in types:
        node_dict[t] = [node_num, len(node_dict)]
        node_num += len(feature[t])

    if 'fake_paper' in feature:
        # Fake papers share the 'paper' type id but get their own offset.
        node_dict['fake_paper'] = [node_num, node_dict['paper'][1]]
        node_num += len(feature['fake_paper'])
        types += ['fake_paper']

    for t in types:
        node_feature += list(feature[t])
        node_time += list(time[t])
        node_type += [node_dict[t][1]] * len(feature[t])

    # Dense ids in first-seen order: a relation name that appears under several
    # (target, source) type pairs gets a single id, so 'self' = len(edge_dict)
    # can never collide with an id already in use.
    edge_dict = {}
    for _, _, relation_type in graph.get_meta_graph():
        if relation_type not in edge_dict:
            edge_dict[relation_type] = len(edge_dict)
    edge_dict['self'] = len(edge_dict)

    for target_type in edge_list:
        for source_type in edge_list[target_type]:
            for relation_type in edge_list[target_type][source_type]:
                for ti, si in edge_list[target_type][source_type][relation_type]:
                    tid = ti + node_dict[target_type][0]
                    sid = si + node_dict[source_type][0]
                    edge_index += [[sid, tid]]
                    edge_type += [edge_dict[relation_type]]
                    # Our time ranges from 1900 - 2020, largest span is 120,
                    # so the +120 shift keeps relative edge times non-negative.
                    edge_time += [node_time[tid] - node_time[sid] + 120]
    node_feature = torch.FloatTensor(node_feature)
    node_type = torch.LongTensor(node_type)
    edge_time = torch.LongTensor(edge_time)
    # view(-1, 2) keeps the (2, num_edges) layout even when there are no edges.
    edge_index = torch.LongTensor(edge_index).view(-1, 2).t()
    edge_type = torch.LongTensor(edge_type)
    return node_feature, node_type, edge_time, edge_index, edge_type, node_dict, edge_dict
263 |
264 |
class RenameUnpickler(dill.Unpickler):
    '''
    dill Unpickler that resolves classes pickled under the legacy module
    names "pyHGT.data" or "data" from "GPT_GNN.data" instead, so graphs
    dumped by older scripts can still be loaded. Every other module name is
    passed through unchanged.

    NOTE(review): unpickling executes code chosen by the file's author; only
    load files from trusted sources.
    '''
    def find_class(self, module, name):
        legacy_modules = ("pyHGT.data", "data")
        resolved_module = "GPT_GNN.data" if module in legacy_modules else module
        return super(RenameUnpickler, self).find_class(resolved_module, name)
271 |
272 |
def renamed_load(file_obj):
    '''
    Load and return one object from the binary file object file_obj using
    RenameUnpickler, so legacy module names are remapped to GPT_GNN.data.
    '''
    unpickler = RenameUnpickler(file_obj)
    return unpickler.load()
275 |
--------------------------------------------------------------------------------
/example_OAG/preprocess_OAG.py:
--------------------------------------------------------------------------------
1 | from transformers import *
2 |
3 | from data import *
4 | import gensim
5 | from gensim.models import Word2Vec
6 | from tqdm import tqdm
7 | # from tqdm import tqdm_notebook as tqdm # Comment this line if using jupyter notebook
8 |
9 |
10 | import argparse
11 |
parser = argparse.ArgumentParser(description='Preprocess OAG (CS/Med/All) Data')

'''
Dataset arguments
'''
parser.add_argument('--input_dir', type=str, default='./data/oag_raw',
                    help='The address to store the original data directory.')
parser.add_argument('--output_dir', type=str, default='./data/oag_output',
                    help='The address to output the preprocessed graph.')
parser.add_argument('--cuda', type=int, default=0,
                    help='Available GPU ID (-1 to run on CPU)')
parser.add_argument('--domain', type=str, default='_CS',
                    help='CS, Medical or All: _CS or _Med or (empty)')
parser.add_argument('--citation_bar', type=int, default=1,
                    help='Only consider papers with at least min(2020 - year, 20) * citation_bar citations')

args = parser.parse_args()
29 |
30 |
# Only paper links with time <= test_time_bar are used when paper embeddings
# are propagated to other node types.
test_time_bar = 2016

# Count how many PR rows name each paper in column 1 (presumably the number of
# citations it received: the PR rows feed the 'PP_cite' edges).
_pr_path = args.input_dir + '/PR%s_20190919.tsv' % args.domain
with open(_pr_path) as fin:
    _n_lines = sum(1 for _ in fin)
cite_dict = defaultdict(lambda: 0)
with open(_pr_path) as fin:
    fin.readline()  # skip the header row
    for l in tqdm(fin, total = _n_lines):
        l = l[:-1].split('\t')
        cite_dict[l[1]] += 1


# Keep papers that have an id, a year >= 1900, a title and a venue id or a
# language (rows missing either one are dropped again when venue edges are
# built), and whose count reaches min(2020 - year, 20) * citation_bar.
pfl = defaultdict(lambda: {})
_papers_path = args.input_dir + '/Papers%s_20190919.tsv' % args.domain
with open(_papers_path) as fin:
    _n_lines = sum(1 for _ in fin)
with open(_papers_path) as fin:
    fin.readline()  # skip the header row
    for l in tqdm(fin, total = _n_lines):
        l = l[:-1].split('\t')
        # Check for empty fields before int(l[1]), which raises on an empty year.
        if l[0] == '' or l[1] == '' or l[2] == '' or (l[3] == '' and l[4] == ''):
            continue
        year = int(l[1])
        if year < 1900:
            continue
        bound = min(2020 - year, 20) * args.citation_bar
        if cite_dict[l[0]] < bound:
            continue
        pfl[l[0]] = {'id': l[0], 'title': l[2], 'type': 'paper', 'time': year}
51 |
52 |
if args.cuda != -1:
    device = torch.device("cuda:" + str(args.cuda))
else:
    device = torch.device("cpu")

tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
model = XLNetModel.from_pretrained('xlnet-base-cased',
                                   output_hidden_states=True,
                                   output_attentions=True).to(device)
# Inference only; from_pretrained already returns eval mode, made explicit here.
model.eval()

# For every paper id (column 0) of the PAb file that is in pfl, embed its
# *title* with XLNet: the second-to-last layer's token states, weighted by
# that layer's attention averaged over the first two attention axes
# (presumably heads and query positions), summed over tokens. Titles are
# truncated to 64 tokens; titles shorter than 4 tokens get no embedding.
_pab_path = args.input_dir + '/PAb%s_20190919.tsv' % args.domain
with open(_pab_path, 'r') as fin:
    _n_lines = sum(1 for _ in fin)
# no_grad: the embeddings are only read out, so no autograd graph is needed.
with open(_pab_path) as fin, torch.no_grad():
    fin.readline()  # skip the header row
    for l in tqdm(fin, total = _n_lines):
        try:
            l = l.split('\t')
            if l[0] in pfl:
                input_ids = torch.tensor([tokenizer.encode(pfl[l[0]]['title'])]).to(device)[:, :64]
                if len(input_ids[0]) < 4:
                    continue
                all_hidden_states, all_attentions = model(input_ids)[-2:]
                rep = (all_hidden_states[-2][0] * all_attentions[-2][0].mean(dim=0).mean(dim=0).view(-1, 1)).sum(dim=0)
                pfl[l[0]]['emb'] = rep.tolist()
        except Exception as e:
            # Best effort: report and skip papers whose title fails to embed.
            print(e)
78 |
79 |
80 |
# vfi_vector.tsv: column 0 is a venue/field/affiliation id, column 1 its vector.
# Collect the ids that have a vector.
_vfi_path = args.input_dir + '/vfi_vector.tsv'
with open(_vfi_path) as fin:
    _n_lines = sum(1 for _ in fin)
vfi_ids = {}
with open(_vfi_path) as fin:
    for l in tqdm(fin, total = _n_lines):
        l = l[:-1].split('\t')
        vfi_ids[l[0]] = True


# Build the graph, starting with paper -> venue edges ('PV_' + venue kind in
# column -2). Only papers with language 'en', an 'emb' and a venue that has a
# vector are kept; pfl is then restricted to exactly those papers.
graph = Graph()
rem = []
_papers_path = args.input_dir + '/Papers%s_20190919.tsv' % args.domain
with open(_papers_path, 'r') as fin:
    _n_lines = sum(1 for _ in fin)
with open(_papers_path) as fin:
    fin.readline()  # skip the header row
    for l in tqdm(fin, total = _n_lines):
        l = l[:-1].split('\t')
        if l[0] not in pfl or l[4] != 'en' or 'emb' not in pfl[l[0]] or l[3] not in vfi_ids:
            continue
        rem += [l[0]]
        vi = {'id': l[3], 'type': 'venue', 'attr': l[-2]}
        graph.add_edge(pfl[l[0]], vi, time = int(l[1]), relation_type = 'PV_' + l[-2])
pfl = {i: pfl[i] for i in rem}
print(len(pfl))
101 |
102 |
# Paper -> paper 'PP_cite' edges: row (l[0], l[1]) links p1 = l[0] to
# p2 = l[1], kept only when both papers survived and p1 is not older than p2.
_pr_path = args.input_dir + '/PR%s_20190919.tsv' % args.domain
with open(_pr_path) as fin:
    _n_lines = sum(1 for _ in fin)
with open(_pr_path) as fin:
    fin.readline()  # skip the header row
    for l in tqdm(fin, total = _n_lines):
        l = l[:-1].split('\t')
        if l[0] in pfl and l[1] in pfl:
            p1 = pfl[l[0]]
            p2 = pfl[l[1]]
            if p1['time'] >= p2['time']:
                graph.add_edge(p1, p2, time = p1['time'], relation_type = 'PP_cite')
112 |
113 |
114 |
# Mark every field id (column 1 of PF) that has a vector and is attached to a
# kept paper.
_pf_path = args.input_dir + '/PF%s_20190919.tsv' % args.domain
with open(_pf_path) as fin:
    _pf_lines = sum(1 for _ in fin)
ffl = {}
with open(_pf_path) as fin:
    fin.readline()  # skip the header row
    for l in tqdm(fin, total = _pf_lines):
        l = l[:-1].split('\t')
        if l[0] in pfl and l[1] in vfi_ids:
            ffl[l[1]] = True


# Field hierarchy: add 'FF_in' edges between marked fields. Fields that take
# part in at least one such edge get a node dict in ffl (with 'attr' from
# columns 2/3, presumably the field level).
_fh_path = args.input_dir + '/FHierarchy_20190919.tsv'
with open(_fh_path) as fin:
    _n_lines = sum(1 for _ in fin)
with open(_fh_path) as fin:
    fin.readline()  # skip the header row
    for l in tqdm(fin, total = _n_lines):
        l = l[:-1].split('\t')
        if l[0] in ffl and l[1] in ffl:
            fi = {'id': l[0], 'type': 'field', 'attr': l[2]}
            fj = {'id': l[1], 'type': 'field', 'attr': l[3]}
            graph.add_edge(fi, fj, relation_type = 'FF_in')
            ffl[l[0]] = fi
            ffl[l[1]] = fj


# Paper -> field edges ('PF_in_' + field attr), only for fields that became
# node dicts above; marked fields still set to True get no paper edges.
with open(_pf_path) as fin:
    fin.readline()  # skip the header row
    for l in tqdm(fin, total = _pf_lines):
        l = l[:-1].split('\t')
        if l[0] in pfl and l[1] in ffl and isinstance(ffl[l[1]], dict):
            pi = pfl[l[0]]
            fi = ffl[l[1]]
            graph.add_edge(pi, fi, time = pi['time'], relation_type = 'PF_in_' + fi['attr'])
148 |
149 |
150 |
151 |
# PAuAf rows: column 0 paper id, 1 author id, 2 affiliation id, last column the
# author's position on the paper. Rows whose paper was kept and whose
# affiliation has a vector yield an author -> affiliation 'in' edge.
_pauaf_path = args.input_dir + '/PAuAf%s_20190919.tsv' % args.domain
with open(_pauaf_path) as fin:
    _n_lines = sum(1 for _ in fin)
coa = defaultdict(lambda: {})   # {paper id: {position: author node}}
with open(_pauaf_path) as fin:
    fin.readline()  # skip the header row
    for l in tqdm(fin, total = _n_lines):
        l = l[:-1].split('\t')
        if l[0] in pfl and l[2] in vfi_ids:
            ai = {'id': l[1], 'type': 'author'}
            fi = {'id': l[2], 'type': 'affiliation'}
            coa[l[0]][int(l[-1])] = ai
            graph.add_edge(ai, fi, relation_type = 'in')

# Author -> paper edges, typed by position: 1 -> first, the largest position
# seen for that paper -> last, anything else -> other.
for pid in tqdm(coa):
    pi = pfl[pid]
    max_seq = max(coa[pid].keys())
    for seq_i in coa[pid]:
        ai = coa[pid][seq_i]
        if seq_i == 1:
            relation_type = 'AP_write_first'
        elif seq_i == max_seq:
            relation_type = 'AP_write_last'
        else:
            relation_type = 'AP_write_other'
        graph.add_edge(ai, pi, time = pi['time'], relation_type = relation_type)
175 |
176 |
177 |
178 |
# Attach the vectors from vfi_vector.tsv as 'node_emb' to existing venue,
# field and affiliation nodes.
_vfi_path = args.input_dir + '/vfi_vector.tsv'
with open(_vfi_path) as fin:
    _n_lines = sum(1 for _ in fin)
with open(_vfi_path) as fin:
    for l in tqdm(fin, total = _n_lines):
        l = l[:-1].split('\t')
        ser = l[0]
        for idx in ['venue', 'field', 'affiliation']:
            if ser in graph.node_forward[idx]:
                # NOTE(review): this stores an array of *strings*; consumers must cast to float.
                graph.node_bacward[idx][graph.node_forward[idx][ser]]['node_emb'] = np.array(l[1].split(' '))


# Attach display names from SeqName (column 0 id, 1 name, 2 entity kind).
# Conference/journal/repository/patent map to 'venue' and 'fos' to 'field'.
_seq_path = args.input_dir + '/SeqName%s_20190919.tsv' % args.domain
with open(_seq_path) as fin:
    _n_lines = sum(1 for _ in fin)
with open(_seq_path) as fin:
    for l in tqdm(fin, total = _n_lines):
        l = l[:-1].split('\t')
        key = l[2]
        if key in ['conference', 'journal', 'repository', 'patent']:
            key = 'venue'
        if key == 'fos':
            key = 'field'
        if l[0] in graph.node_forward[key]:
            s = graph.node_bacward[key][graph.node_forward[key][l[0]]]
            s['name'] = l[1]
201 |
# Calculate the total citation information as the node attribute 'citation':
#   paper: number of 'PP_cite' links into it (presumably papers citing it);
#   author / venue / field: sum over all papers linked to the node;
#   affiliation: sum over its authors ('in' links).
def _sum_linked_paper_citations(node_type, idx):
    '''Sum 'citation' over every paper linked to node idx of node_type, across all relations.'''
    rel_map = graph.edge_list[node_type]['paper']
    return sum(graph.node_bacward['paper'][pid]['citation']
               for rel in rel_map.keys()
               for pid in rel_map[rel][idx])


for idx, pi in enumerate(graph.node_bacward['paper']):
    pi['citation'] = len(graph.edge_list['paper']['paper']['PP_cite'][idx])
for idx, ai in enumerate(graph.node_bacward['author']):
    ai['citation'] = _sum_linked_paper_citations('author', idx)
for idx, fi in enumerate(graph.node_bacward['affiliation']):
    fi['citation'] = sum(graph.node_bacward['author'][aid]['citation']
                         for aid in graph.edge_list['affiliation']['author']['in'][idx])
for node_type in ('venue', 'field'):
    for idx, node in enumerate(graph.node_bacward[node_type]):
        node['citation'] = _sum_linked_paper_citations(node_type, idx)
231 |
232 |
233 |
234 |
# Only papers have a text embedding ('emb', the XLNet title embedding), so
# every other node type gets its linked papers' embeddings combined via
# normalize() (from the utils star-import; presumably row normalization, i.e.
# an average), using only links with time <= test_time_bar. Then a DataFrame
# is built per node type. Types without any usable paper link get no entry.
d = pd.DataFrame(graph.node_bacward['paper'])
graph.node_feature = {'paper': d}
cv = np.array(list(d['emb']))
for _type in graph.node_bacward:
    if _type not in ['paper', 'affiliation']:
        d = pd.DataFrame(graph.node_bacward[_type])
        i = []
        for _rel in graph.edge_list[_type]['paper']:
            for t in graph.edge_list[_type]['paper'][_rel]:
                for s in graph.edge_list[_type]['paper'][_rel][t]:
                    if graph.edge_list[_type]['paper'][_rel][t][s] <= test_time_bar:
                        i += [[t, s]]
        if len(i) == 0:
            continue
        i = np.array(i).T
        v = np.ones(i.shape[1])
        m = normalize(sp.coo_matrix((v, i),
            shape=(len(graph.node_bacward[_type]), len(graph.node_bacward['paper']))))
        out = m.dot(cv)
        d['emb'] = list(out)
        graph.node_feature[_type] = d

# Affiliation is not directly linked with Paper, so we average the author
# embedding. Skipped (like the types above) when there is no author link;
# an empty link list cannot form the sparse matrix.
i = []
for _rel in graph.edge_list['affiliation']['author']:
    for j in graph.edge_list['affiliation']['author'][_rel]:
        for t in graph.edge_list['affiliation']['author'][_rel][j]:
            i += [[j, t]]
if len(i) != 0:
    cv = np.array(list(graph.node_feature['author']['emb']))
    d = pd.DataFrame(graph.node_bacward['affiliation'])
    i = np.array(i).T
    v = np.ones(i.shape[1])
    m = normalize(sp.coo_matrix((v, i),
        shape=(len(graph.node_bacward['affiliation']), len(graph.node_bacward['author']))))
    out = m.dot(cv)
    d['emb'] = list(out)
    graph.node_feature['affiliation'] = d
278 |
279 |
# Replace the nested defaultdicts in graph.edge_list by plain dicts, dropping
# target entries without any source (e.g. ones created by defaultdict
# lookups), and print the number of target nodes per
# (target_type, source_type, relation_type).
edg = {}
for k1, by_source in graph.edge_list.items():
    edg[k1] = {}
    for k2, by_relation in by_source.items():
        edg[k1][k2] = {}
        for k3, by_target in by_relation.items():
            edg[k1][k2][k3] = {e1: dict(by_source_id)
                               for e1, by_source_id in by_target.items()
                               if len(by_source_id) != 0}
            print(k1, k2, k3, len(edg[k1][k2][k3]))
graph.edge_list = edg


# node_bacward was only needed while building; node data now lives in node_feature.
del graph.node_bacward
with open(args.output_dir + '/graph%s.pk' % args.domain, 'wb') as fout:
    dill.dump(graph, fout)
302 |
303 |
304 |
305 |
--------------------------------------------------------------------------------
/example_OAG/finetune_OAG_PV.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from GPT_GNN.data import *
3 | from GPT_GNN.model import *
4 | from warnings import filterwarnings
# NOTE(review): silences every warning in the process, not just noisy ones.
filterwarnings("ignore")

import argparse

parser = argparse.ArgumentParser(description='Fine-Tuning on Paper-Venue (Journal) classification task')

'''
Dataset arguments
'''
parser.add_argument('--data_dir', type=str, default='/datadrive/dataset',
                    help='The address of preprocessed graph.')
parser.add_argument('--use_pretrain', help='Whether to use pre-trained model', action='store_true')
parser.add_argument('--pretrain_model_dir', type=str, default='/datadrive/models/gpt_all_cs',
                    help='The address for pretrained model.')
parser.add_argument('--model_dir', type=str, default='/datadrive/models',
                    help='The address for storing the models and optimization results.')
parser.add_argument('--task_name', type=str, default='PV',
                    help='The name of the stored models and optimization results.')
parser.add_argument('--cuda', type=int, default=2,
                    help='Available GPU ID')
parser.add_argument('--domain', type=str, default='_CS',
                    help='CS, Medical or All: _CS or _Med or (empty)')
parser.add_argument('--sample_depth', type=int, default=6,
                    help='How many numbers to sample the graph')
parser.add_argument('--sample_width', type=int, default=128,
                    help='How many nodes to be sampled per layer per type')

'''
Model arguments
'''
parser.add_argument('--conv_name', type=str, default='hgt',
                    choices=['hgt', 'gcn', 'gat', 'rgcn', 'han', 'hetgnn'],
                    help='The name of GNN filter. By default is Heterogeneous Graph Transformer (hgt)')
parser.add_argument('--n_hid', type=int, default=400,
                    help='Number of hidden dimension')
parser.add_argument('--n_heads', type=int, default=8,
                    help='Number of attention head')
parser.add_argument('--n_layers', type=int, default=3,
                    help='Number of GNN layers')
parser.add_argument('--prev_norm', help='Whether to add layer-norm on the previous layers', action='store_true')
parser.add_argument('--last_norm', help='Whether to add layer-norm on the last layers', action='store_true')
# Dropout is a ratio in [0, 1): it must be parsed as float (int rejects "0.3").
parser.add_argument('--dropout', type=float, default=0.2,
                    help='Dropout ratio')

'''
Optimization arguments
'''
52 | parser.add_argument('--optimizer', type=str, default='adamw',
53 | choices=['adamw', 'adam', 'sgd', 'adagrad'],
54 | help='optimizer to use.')
55 | parser.add_argument('--scheduler', type=str, default='cycle',
56 | help='Name of learning rate scheduler.' , choices=['cycle', 'cosine'])
57 | parser.add_argument('--data_percentage', type=int, default=0.1,
58 | help='Percentage of training and validation data to use')
59 | parser.add_argument('--n_epoch', type=int, default=50,
60 | help='Number of epoch to run')
61 | parser.add_argument('--n_pool', type=int, default=8,
62 | help='Number of process to sample subgraph')
63 | parser.add_argument('--n_batch', type=int, default=16,
64 | help='Number of batch (sampled graphs) for each epoch')
65 | parser.add_argument('--batch_size', type=int, default=256,
66 | help='Number of output nodes for training')
67 | parser.add_argument('--clip', type=int, default=0.5,
68 | help='Gradient Norm Clipping')
69 |
args = parser.parse_args()
args_print(args)

# --cuda -1 selects the CPU; any other value is a CUDA device index.
if args.cuda != -1:
    device = torch.device("cuda:" + str(args.cuda))
else:
    device = torch.device("cpu")

print('Start Loading Graph Data...')
# Close the pickle file once loaded (assumes renamed_load reads the object
# eagerly, as a pickle-style loader does — confirm in GPT_GNN).
with open(os.path.join(args.data_dir, 'graph%s.pk' % args.domain), 'rb') as graph_file:
    graph = renamed_load(graph_file)
print('Finish Loading Graph Data!')
81 |
target_type = 'paper'

types = graph.get_types()
# Label space: every journal reachable through the 'PV_Journal' relation.
cand_list = list(graph.edge_list['venue']['paper']['PV_Journal'])
# Single-label task (one venue per paper): negative log-likelihood, which
# together with a log-softmax classifier output amounts to cross-entropy
# (assumes Classifier emits log-probabilities — confirm in GPT_GNN.model).
criterion = nn.NLLLoss()
93 |
def node_classification_sample(seed, pairs, time_range):
    '''
    Build one labelled mini-batch for paper -> journal (single-label) classification.

    pairs maps paper_id -> [journal_id, time]; time_range is forwarded to
    sample_subgraph. Returns
    (node_feature, node_type, edge_time, edge_index, edge_type, x_ids, ylabel),
    where x_ids are the drawn papers' row indices in the sampled graph and
    ylabel[k] is the position of paper k's journal in cand_list.
    '''
    # (1) Reproducibly draw batch_size distinct papers and attach their times.
    np.random.seed(seed)
    target_ids = np.random.choice(list(pairs.keys()), args.batch_size, replace=False)
    target_info = []
    for paper_id in target_ids:
        _journal, paper_time = pairs[paper_id]
        target_info.append([paper_id, paper_time])

    # (2) Sample a subgraph seeded by those papers.
    seed_nodes = {'paper': np.array(target_info)}
    feature, times, edge_list, _, _ = sample_subgraph(
        graph, time_range, inp=seed_nodes,
        sampled_depth=args.sample_depth, sampled_number=args.sample_width)

    # (3) Hide the labels: drop every journal edge whose paper endpoint is one
    #     of the first args.batch_size papers (the drawn targets).
    edge_list['paper']['venue']['rev_PV_Journal'] = [
        edge for edge in edge_list['paper']['venue']['rev_PV_Journal']
        if edge[0] >= args.batch_size
    ]
    edge_list['venue']['paper']['PV_Journal'] = [
        edge for edge in edge_list['venue']['paper']['PV_Journal']
        if edge[1] >= args.batch_size
    ]

    # (4) Tensorize the subgraph (edge_index follows the pytorch_geometric layout).
    (node_feature, node_type, edge_time, edge_index,
     edge_type, node_dict, edge_dict) = to_torch(feature, times, edge_list, graph)

    # (5) Labels and row positions; node_dict[t][0] is where node type t starts.
    ylabel = torch.zeros(args.batch_size, dtype=torch.long)
    for row, paper_id in enumerate(target_ids):
        ylabel[row] = cand_list.index(pairs[paper_id][0])
    x_ids = np.arange(args.batch_size) + node_dict['paper'][0]
    return node_feature, node_type, edge_time, edge_index, edge_type, x_ids, ylabel
143 |
def prepare_data(pool):
    '''
    Queue one epoch of sampling work on `pool`.

    Submits args.n_batch training samples (from sel_train_pairs / train_range)
    followed by one validation sample (sel_valid_pairs / valid_range) and
    returns the AsyncResult handles in that order, so the last one is the
    validation batch.
    '''
    jobs = [
        pool.apply_async(node_classification_sample,
                         args=(randint(), sel_train_pairs, train_range))
        for _ in range(args.n_batch)
    ]
    jobs.append(pool.apply_async(node_classification_sample,
                                 args=(randint(), sel_valid_pairs, valid_range)))
    return jobs
157 |
158 |
# Temporal split of the graph's time values: train 2014-2016, valid 2017,
# test after 2017 (matching the Train / Valid / Test comments further down).
# prepare_data and the final test loop read these as module-level globals,
# so they must be defined before either runs.
train_range = {t: True for t in graph.times if t is not None and 2014 <= t <= 2016}
valid_range = {t: True for t in graph.times if t is not None and 2016 < t <= 2017}
test_range = {t: True for t in graph.times if t is not None and t > 2017}

train_pairs = {}
valid_pairs = {}
test_pairs = {}
# For each paper, keep [journal_id, time] of the first journal edge seen in
# each split.  Edges whose time lies outside all three ranges (e.g. before
# 2014, or missing) are ignored rather than being counted as test data.
for target_id in graph.edge_list['paper']['venue']['rev_PV_Journal']:
    for source_id in graph.edge_list['paper']['venue']['rev_PV_Journal'][target_id]:
        _time = graph.edge_list['paper']['venue']['rev_PV_Journal'][target_id][source_id]
        if _time in train_range:
            if target_id not in train_pairs:
                train_pairs[target_id] = [source_id, _time]
        elif _time in valid_range:
            if target_id not in valid_pairs:
                valid_pairs[target_id] = [source_id, _time]
        elif _time in test_range:
            if target_id not in test_pairs:
                test_pairs[target_id] = [source_id, _time]
177 |
178 |
def _subsample_pairs(all_pairs, fraction):
    '''Return a dict holding a random int(len(all_pairs) * fraction) subset of all_pairs.'''
    n_keep = int(len(all_pairs) * fraction)
    chosen = np.random.choice(list(all_pairs.keys()), n_keep, replace=False)
    return {p: all_pairs[p] for p in chosen}


# Fixed seed so the same training / validation subset is used on every run.
np.random.seed(43)
sel_train_pairs = _subsample_pairs(train_pairs, args.data_percentage)
sel_valid_pairs = _subsample_pairs(valid_pairs, args.data_percentage)
185 |
186 |
187 |
188 |
# GNN encoder (architecture selected by --conv_name) followed by the journal
# classifier.
# NOTE(review): the extra 401 input dimensions presumably match features that
# to_torch appends to each node's 'emb' vector — confirm in GPT_GNN.data.
gnn = GNN(conv_name = args.conv_name, in_dim = len(graph.node_feature[target_type]['emb'].values[0]) + 401, n_hid = args.n_hid, \
          n_heads = args.n_heads, n_layers = args.n_layers, dropout = args.dropout, num_types = len(types), \
          num_relations = len(graph.get_meta_graph()) + 1, prev_norm = args.prev_norm, last_norm = args.last_norm)
if args.use_pretrain:
    # strict=False: keys missing from or unexpected in the checkpoint are skipped.
    gnn.load_state_dict(load_gnn(torch.load(args.pretrain_model_dir)), strict = False)
    print('Load Pre-trained Model from (%s)' % args.pretrain_model_dir)
classifier = Classifier(args.n_hid, len(cand_list))

# Move BOTH modules to `device` (nn.Module.to is in-place, so `gnn` and
# `classifier` are moved too): the training loop sends every input tensor to
# `device`, so a GNN left on the CPU would fail with a device mismatch.
model = nn.Sequential(gnn, classifier).to(device)
201 |
202 |
# NOTE(review): --optimizer / --scheduler are not consulted; AdamW and cosine
# annealing are hard-coded here.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)

stats, res = [], []
best_val, train_step = 0, 0

# Kick off sampling of the first epoch's batches in worker processes.
pool = mp.Pool(args.n_pool)
st = time.time()
jobs = prepare_data(pool)

# Cosine schedule with a 500-step period, bottoming out at 1e-6.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 500, eta_min=1e-6)
215 |
for epoch in np.arange(args.n_epoch) + 1:
    # Collect this epoch's batches, which were sampled in the background.
    train_data = [job.get() for job in jobs[:-1]]
    valid_data = jobs[-1].get()
    pool.close()
    pool.join()
    # Start sampling the next epoch's batches in a fresh pool while this epoch
    # trains.  After the final epoch nothing would consume them, so no pool is
    # started and no worker processes are left running.
    if epoch < args.n_epoch:
        pool = mp.Pool(args.n_pool)
        jobs = prepare_data(pool)
    et = time.time()
    print('Data Preparation: %.1fs' % (et - st))

    # Train (2014 <= time <= 2016)
    model.train()
    train_losses = []
    torch.cuda.empty_cache()
    for node_feature, node_type, edge_time, edge_index, edge_type, x_ids, ylabel in train_data:
        node_rep = gnn.forward(node_feature.to(device), node_type.to(device), \
                               edge_time.to(device), edge_index.to(device), edge_type.to(device))
        res = classifier.forward(node_rep[x_ids])
        loss = criterion(res, ylabel.to(device))

        optimizer.zero_grad()
        torch.cuda.empty_cache()
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        optimizer.step()

        train_losses += [loss.cpu().detach().tolist()]
        train_step += 1
        # One scheduler step per optimizer step; step() without the deprecated
        # epoch argument advances the cosine schedule by exactly one step.
        scheduler.step()
        del res, loss

    # Valid (2017 <= time <= 2017)
    model.eval()
    with torch.no_grad():
        node_feature, node_type, edge_time, edge_index, edge_type, x_ids, ylabel = valid_data
        node_rep = gnn.forward(node_feature.to(device), node_type.to(device), \
                               edge_time.to(device), edge_index.to(device), edge_type.to(device))
        res = classifier.forward(node_rep[x_ids])
        loss = criterion(res, ylabel.to(device))

        # Valid NDCG: per paper, a 0/1 vector over the predicted venue ranking
        # marking the position of the true venue.  Keep the checkpoint with the
        # highest NDCG.
        valid_res = []
        for ai, bi in zip(ylabel, res.argsort(descending = True)):
            valid_res += [(bi == ai).int().tolist()]
        valid_ndcg = np.average([ndcg_at_k(resi, len(resi)) for resi in valid_res])

        if valid_ndcg > best_val:
            best_val = valid_ndcg
            torch.save(model, os.path.join(args.model_dir, args.task_name + '_' + args.conv_name))
            print('UPDATE!!!')

        st = time.time()
        print(("Epoch: %d (%.1fs) LR: %.5f Train Loss: %.2f Valid Loss: %.2f Valid NDCG: %.4f") % \
              (epoch, (st-et), optimizer.param_groups[0]['lr'], np.average(train_losses), \
               loss.cpu().detach().tolist(), valid_ndcg))
        stats += [[np.average(train_losses), loss.cpu().detach().tolist()]]
        del res, loss
    del train_data, valid_data
286 |
287 |
# ---- Evaluate the best checkpoint on the test split (time >= 2018) ----

best_model = torch.load(os.path.join(args.model_dir, args.task_name + '_' + args.conv_name))
best_model.eval()
gnn, classifier = best_model
with torch.no_grad():
    test_res = []
    for _ in range(10):
        test_batch = node_classification_sample(randint(), test_pairs, test_range)
        node_feature, node_type, edge_time, edge_index, edge_type, x_ids, ylabel = test_batch
        paper_rep = gnn.forward(node_feature.to(device), node_type.to(device),
                                edge_time.to(device), edge_index.to(device),
                                edge_type.to(device))[x_ids]
        res = classifier.forward(paper_rep)
        # One 0/1 hit vector per paper over the predicted venue ranking.
        ranking = res.argsort(descending = True)
        test_res.extend((bi == ai).int().tolist() for ai, bi in zip(ylabel, ranking))
    test_ndcg = [ndcg_at_k(resi, len(resi)) for resi in test_res]
    print('Best Test NDCG: %.4f' % np.average(test_ndcg))
    test_mrr = mean_reciprocal_rank(test_res)
    print('Best Test MRR: %.4f' % np.average(test_mrr))
309 |
--------------------------------------------------------------------------------
/example_OAG/finetune_OAG_PF.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from GPT_GNN.data import *
3 | from GPT_GNN.model import *
4 | from warnings import filterwarnings
5 | filterwarnings("ignore")
6 |
7 | import argparse
8 |
parser = argparse.ArgumentParser(description='Fine-Tuning on OAG Paper-Field (L2) classification task')

# Dataset arguments
parser.add_argument('--data_dir', type=str, default='/datadrive/dataset',
                    help='The address of preprocessed graph.')
parser.add_argument('--use_pretrain', help='Whether to use pre-trained model', action='store_true')
parser.add_argument('--pretrain_model_dir', type=str, default='/datadrive/models/gpt_all_cs',
                    help='The address for pretrained model.')
parser.add_argument('--model_dir', type=str, default='/datadrive/models',
                    help='The address for storing the models and optimization results.')
parser.add_argument('--task_name', type=str, default='PF',
                    help='The name of the stored models and optimization results.')
parser.add_argument('--cuda', type=int, default=2,
                    help='Avaiable GPU ID')
parser.add_argument('--domain', type=str, default='_CS',
                    help='CS, Medicion or All: _CS or _Med or (empty)')
parser.add_argument('--sample_depth', type=int, default=6,
                    help='How many numbers to sample the graph')
parser.add_argument('--sample_width', type=int, default=128,
                    help='How many nodes to be sampled per layer per type')

# Model arguments
parser.add_argument('--conv_name', type=str, default='hgt',
                    choices=['hgt', 'gcn', 'gat', 'rgcn', 'han', 'hetgnn'],
                    help='The name of GNN filter. By default is Heterogeneous Graph Transformer (hgt)')
parser.add_argument('--n_hid', type=int, default=400,
                    help='Number of hidden dimension')
parser.add_argument('--n_heads', type=int, default=8,
                    help='Number of attention head')
parser.add_argument('--n_layers', type=int, default=3,
                    help='Number of GNN layers')
parser.add_argument('--prev_norm', help='Whether to add layer-norm on the previous layers', action='store_true')
parser.add_argument('--last_norm', help='Whether to add layer-norm on the last layers', action='store_true')
# Fractional options are parsed as float: with type=int, argparse rejects
# values such as "0.3" given on the command line.
parser.add_argument('--dropout', type=float, default=0.2,
                    help='Dropout ratio')


# Optimization arguments
# NOTE(review): --optimizer and --scheduler are not consulted by this script;
# it always builds AdamW with a cosine-annealing schedule.
parser.add_argument('--optimizer', type=str, default='adamw',
                    choices=['adamw', 'adam', 'sgd', 'adagrad'],
                    help='optimizer to use.')
parser.add_argument('--scheduler', type=str, default='cycle',
                    help='Name of learning rate scheduler.' , choices=['cycle', 'cosine'])
# A fraction in [0, 1] (despite the "Percentage" wording).
parser.add_argument('--data_percentage', type=float, default=0.1,
                    help='Percentage of training and validation data to use')
parser.add_argument('--n_epoch', type=int, default=50,
                    help='Number of epoch to run')
parser.add_argument('--n_pool', type=int, default=8,
                    help='Number of process to sample subgraph')
parser.add_argument('--n_batch', type=int, default=16,
                    help='Number of batch (sampled graphs) for each epoch')
parser.add_argument('--batch_size', type=int, default=256,
                    help='Number of output nodes for training')
parser.add_argument('--clip', type=float, default=0.5,
                    help='Gradient Norm Clipping')
70 |
args = parser.parse_args()
args_print(args)

# --cuda -1 selects the CPU; any other value is a CUDA device index.
if args.cuda != -1:
    device = torch.device("cuda:" + str(args.cuda))
else:
    device = torch.device("cpu")

print('Start Loading Graph Data...')
# Close the pickle file once loaded (assumes renamed_load reads the object
# eagerly, as a pickle-style loader does — confirm in GPT_GNN).
with open(os.path.join(args.data_dir, 'graph%s.pk' % args.domain), 'rb') as graph_file:
    graph = renamed_load(graph_file)
print('Finish Loading Graph Data!')
82 |
target_type = 'paper'

types = graph.get_types()
# Label space: every L2 field reachable through the 'PF_in_L2' relation.
cand_list = list(graph.edge_list['field']['paper']['PF_in_L2'])
# Multi-label task (a paper can belong to several fields): the labels are
# turned into a per-paper distribution and compared with KL divergence,
# averaged over the batch.
criterion = nn.KLDivLoss(reduction='batchmean')
def node_classification_sample(seed, pairs, time_range):
    '''
    Build one labelled mini-batch for paper -> L2-field (multi-label) classification.

    (1) Seed numpy's global RNG and draw args.batch_size distinct papers from
        `pairs` ({paper_id: [list_of_field_ids, time]}), keeping their times.
    (2) Sample a subgraph around them with sample_subgraph (time_range is
        forwarded to it).
    (3) Remove the paper<->L2-field edges of the drawn papers so their labels
        are not visible to the GNN.
    (4) Convert the subgraph into tensors with to_torch.
    (5) Build a (batch_size, len(cand_list)) label matrix whose rows are each
        paper's field set normalized to sum to 1, plus the papers' row indices
        in the sampled graph.

    Returns (node_feature, node_type, edge_time, edge_index, edge_type, x_ids, ylabel).
    Raises ValueError if one of a paper's fields is not in cand_list.
    '''
    np.random.seed(seed)
    target_ids = np.random.choice(list(pairs.keys()), args.batch_size, replace = False)
    target_info = []
    for target_id in target_ids:
        _, _time = pairs[target_id]
        target_info += [[target_id, _time]]

    # (2) Sample a subgraph with 'sampled_depth' and 'sampled_number' around the seed papers.
    feature, times, edge_list, _, _ = sample_subgraph(graph, time_range, \
                inp = {'paper': np.array(target_info)}, \
                sampled_depth = args.sample_depth, sampled_number = args.sample_width)

    # (3) Mask out the edges between the target papers (the first
    #     args.batch_size paper rows) and their L2 fields.
    masked_edge_list = []
    for i in edge_list['paper']['field']['rev_PF_in_L2']:
        if i[0] >= args.batch_size:
            masked_edge_list += [i]
    edge_list['paper']['field']['rev_PF_in_L2'] = masked_edge_list

    masked_edge_list = []
    for i in edge_list['field']['paper']['PF_in_L2']:
        if i[1] >= args.batch_size:
            masked_edge_list += [i]
    edge_list['field']['paper']['PF_in_L2'] = masked_edge_list

    # (4) Transform the subgraph into torch tensors (edge_index in pytorch_geometric format).
    node_feature, node_type, edge_time, edge_index, edge_type, node_dict, edge_dict = \
            to_torch(feature, times, edge_list, graph)

    # (5) Multi-hot labels over cand_list.  target_ids are drawn from pairs'
    #     keys, so every lookup below succeeds; an unknown field is reported
    #     with the offending ids instead of a bare list.index error.
    ylabel = np.zeros([args.batch_size, len(cand_list)])
    for x_id, target_id in enumerate(target_ids):
        for source_id in pairs[target_id][0]:
            try:
                field_idx = cand_list.index(source_id)
            except ValueError:
                raise ValueError('L2 field %s of paper %s is not in cand_list'
                                 % (source_id, target_id)) from None
            ylabel[x_id][field_idx] = 1

    # Every paper in pairs has at least one field, so no row sum is zero.
    ylabel /= ylabel.sum(axis=1).reshape(-1, 1)
    # node_dict[type][0] is the start index of that node type in the sampled graph.
    x_ids = np.arange(args.batch_size) + node_dict['paper'][0]
    return node_feature, node_type, edge_time, edge_index, edge_type, x_ids, ylabel
148 |
def prepare_data(pool):
    '''
    Queue one epoch of sampling work on `pool`.

    Submits args.n_batch training samples (from sel_train_pairs / train_range)
    followed by one validation sample (sel_valid_pairs / valid_range) and
    returns the AsyncResult handles in that order, so the last one is the
    validation batch.
    '''
    jobs = [
        pool.apply_async(node_classification_sample,
                         args=(randint(), sel_train_pairs, train_range))
        for _ in range(args.n_batch)
    ]
    jobs.append(pool.apply_async(node_classification_sample,
                                 args=(randint(), sel_valid_pairs, valid_range)))
    return jobs
162 |
# Temporal split of the graph's time values (None entries are skipped):
# pre-training before 2014, train 2014-2016, valid 2017, test after 2017.
# Each range is a dict used for O(1) membership tests.
pre_range = {t: True for t in graph.times if t is not None and t < 2014}
train_range = {t: True for t in graph.times if t is not None and 2014 <= t <= 2016}
valid_range = {t: True for t in graph.times if t is not None and 2016 < t <= 2017}
test_range = {t: True for t in graph.times if t is not None and t > 2017}
167 |
168 |
train_pairs = {}
valid_pairs = {}
test_pairs = {}
# For each paper, collect all of its L2 fields together with the time of the
# first field edge placed in that split; each edge is routed by its own time.
# Edges whose time lies outside the train / valid / test ranges (e.g. times
# in pre_range, or missing times) are ignored instead of being counted as
# test data, keeping the test split consistent with test_range.
for target_id in graph.edge_list['paper']['field']['rev_PF_in_L2']:
    for source_id in graph.edge_list['paper']['field']['rev_PF_in_L2'][target_id]:
        _time = graph.edge_list['paper']['field']['rev_PF_in_L2'][target_id][source_id]
        if _time in train_range:
            bucket = train_pairs
        elif _time in valid_range:
            bucket = valid_pairs
        elif _time in test_range:
            bucket = test_pairs
        else:
            continue
        if target_id not in bucket:
            bucket[target_id] = [[], _time]
        bucket[target_id][0] += [source_id]
190 |
191 |
def _subsample_pairs(all_pairs, fraction):
    '''Return a dict holding a random int(len(all_pairs) * fraction) subset of all_pairs.'''
    n_keep = int(len(all_pairs) * fraction)
    chosen = np.random.choice(list(all_pairs.keys()), n_keep, replace=False)
    return {p: all_pairs[p] for p in chosen}


# Fixed seed so the same training / validation subset is used on every run.
np.random.seed(43)
sel_train_pairs = _subsample_pairs(train_pairs, args.data_percentage)
sel_valid_pairs = _subsample_pairs(valid_pairs, args.data_percentage)
198 |
199 |
200 |
# GNN encoder (architecture selected by --conv_name) followed by the field
# classifier.
# NOTE(review): the extra 401 input dimensions presumably match features that
# to_torch appends to each node's 'emb' vector — confirm in GPT_GNN.data.
gnn = GNN(
    conv_name=args.conv_name,
    in_dim=len(graph.node_feature[target_type]['emb'].values[0]) + 401,
    n_hid=args.n_hid,
    n_heads=args.n_heads,
    n_layers=args.n_layers,
    dropout=args.dropout,
    num_types=len(types),
    num_relations=len(graph.get_meta_graph()) + 1,
    prev_norm=args.prev_norm,
    last_norm=args.last_norm,
)
if args.use_pretrain:
    # strict=False: keys missing from or unexpected in the checkpoint are skipped.
    gnn.load_state_dict(load_gnn(torch.load(args.pretrain_model_dir)), strict=False)
    print('Load Pre-trained Model from (%s)' % args.pretrain_model_dir)
classifier = Classifier(args.n_hid, len(cand_list))

# nn.Module.to moves parameters in place, so gnn and classifier end up on device too.
model = nn.Sequential(gnn, classifier).to(device)
213 |
214 |
# NOTE(review): --optimizer / --scheduler are not consulted; AdamW and cosine
# annealing are hard-coded here.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)

stats, res = [], []
best_val, train_step = 0, 0

# Kick off sampling of the first epoch's batches in worker processes.
pool = mp.Pool(args.n_pool)
st = time.time()
jobs = prepare_data(pool)

# Cosine schedule with a 500-step period, bottoming out at 1e-6.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 500, eta_min=1e-6)
227 |
for epoch in np.arange(args.n_epoch) + 1:
    # Collect this epoch's batches, which were sampled in the background.
    train_data = [job.get() for job in jobs[:-1]]
    valid_data = jobs[-1].get()
    pool.close()
    pool.join()
    # Start sampling the next epoch's batches in a fresh pool while this epoch
    # trains.  After the final epoch nothing would consume them, so no pool is
    # started and no worker processes are left running.
    if epoch < args.n_epoch:
        pool = mp.Pool(args.n_pool)
        jobs = prepare_data(pool)
    et = time.time()
    print('Data Preparation: %.1fs' % (et - st))

    # Train (2014 <= time <= 2016)
    model.train()
    train_losses = []
    for node_feature, node_type, edge_time, edge_index, edge_type, x_ids, ylabel in train_data:
        node_rep = gnn.forward(node_feature.to(device), node_type.to(device), \
                               edge_time.to(device), edge_index.to(device), edge_type.to(device))
        res = classifier.forward(node_rep[x_ids])
        loss = criterion(res, torch.FloatTensor(ylabel).to(device))

        optimizer.zero_grad()
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        optimizer.step()

        train_losses += [loss.cpu().detach().tolist()]
        train_step += 1
        # One scheduler step per optimizer step; step() without the deprecated
        # epoch argument advances the cosine schedule by exactly one step.
        scheduler.step()
        del res, loss

    # Valid (2017 <= time <= 2017)
    model.eval()
    with torch.no_grad():
        node_feature, node_type, edge_time, edge_index, edge_type, x_ids, ylabel = valid_data
        node_rep = gnn.forward(node_feature.to(device), node_type.to(device), \
                               edge_time.to(device), edge_index.to(device), edge_type.to(device))
        res = classifier.forward(node_rep[x_ids])
        loss = criterion(res, torch.FloatTensor(ylabel).to(device))

        # Valid NDCG: each paper's label distribution reordered by the
        # predicted field ranking.  Keep the checkpoint with the highest NDCG.
        valid_res = []
        for ai, bi in zip(ylabel, res.argsort(descending = True)):
            valid_res += [ai[bi.cpu().numpy()]]
        valid_ndcg = np.average([ndcg_at_k(resi, len(resi)) for resi in valid_res])
        if valid_ndcg > best_val:
            best_val = valid_ndcg
            torch.save(model, os.path.join(args.model_dir, args.task_name + '_' + args.conv_name))
            print('UPDATE!!!')

        st = time.time()
        print(("Epoch: %d (%.1fs) LR: %.5f Train Loss: %.2f Valid Loss: %.2f Valid NDCG: %.4f") % \
              (epoch, (st-et), optimizer.param_groups[0]['lr'], np.average(train_losses), \
               loss.cpu().detach().tolist(), valid_ndcg))
        stats += [[np.average(train_losses), loss.cpu().detach().tolist()]]
        del res, loss
    del train_data, valid_data
295 |
296 |
# ---- Evaluate the best checkpoint on the test split (time >= 2018) ----


best_model = torch.load(os.path.join(args.model_dir, args.task_name + '_' + args.conv_name))
best_model.eval()
gnn, classifier = best_model
with torch.no_grad():
    test_res = []
    for _ in range(10):
        test_batch = node_classification_sample(randint(), test_pairs, test_range)
        node_feature, node_type, edge_time, edge_index, edge_type, x_ids, ylabel = test_batch
        paper_rep = gnn.forward(node_feature.to(device), node_type.to(device),
                                edge_time.to(device), edge_index.to(device),
                                edge_type.to(device))[x_ids]
        res = classifier.forward(paper_rep)
        # Each paper's label distribution reordered by the predicted ranking.
        ranking = res.argsort(descending = True)
        test_res.extend(ai[bi.cpu().numpy()] for ai, bi in zip(ylabel, ranking))
    test_ndcg = [ndcg_at_k(resi, len(resi)) for resi in test_res]
    print('Best Test NDCG: %.4f' % np.average(test_ndcg))
    test_mrr = mean_reciprocal_rank(test_res)
    print('Best Test MRR: %.4f' % np.average(test_mrr))
319 |
--------------------------------------------------------------------------------
/example_reddit/.ipynb_checkpoints/pretrain_reddit-checkpoint.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from GPT_GNN.data import *
3 | from GPT_GNN.model import *
4 | from warnings import filterwarnings
5 | filterwarnings("ignore")
6 |
7 | import argparse
8 |
parser = argparse.ArgumentParser(description='Pre-training HGT on a given graph (heterogeneous / homogeneous)')

# GPT-GNN arguments
parser.add_argument('--attr_ratio', type=float, default=0.5,
                    help='Ratio of attr-loss against link-loss, range: [0-1]')
parser.add_argument('--attr_type', type=str, default='vec',
                    choices=['text', 'vec'],
                    help='The type of attribute decoder')
parser.add_argument('--neg_samp_num', type=int, default=255,
                    help='Maximum number of negative sample for each target node.')
parser.add_argument('--queue_size', type=int, default=256,
                    help='Max size of adaptive embedding queue.')
# NOTE(review): this help text looks copied from --data_dir.
parser.add_argument('--w2v_dir', type=str, default='/datadrive/dataset/w2v_all',
                    help='The address of preprocessed graph.')

# Dataset arguments
parser.add_argument('--data_dir', type=str, default='/datadrive/dataset/graph_reddit.pk',
                    help='The address of preprocessed graph.')
parser.add_argument('--pretrain_model_dir', type=str, default='/datadrive/models/gpt_all_reddit',
                    help='The address for storing the pre-trained models.')
parser.add_argument('--cuda', type=int, default=1,
                    help='Avaiable GPU ID')
parser.add_argument('--sample_depth', type=int, default=6,
                    help='How many layers within a mini-batch subgraph')
parser.add_argument('--sample_width', type=int, default=128,
                    help='How many nodes to be sampled per layer per type')

# Model arguments
parser.add_argument('--conv_name', type=str, default='hgt',
                    choices=['hgt', 'gcn', 'gat', 'rgcn', 'han', 'hetgnn'],
                    help='The name of GNN filter. By default is Heterogeneous Graph Transformer (hgt)')
parser.add_argument('--n_hid', type=int, default=400,
                    help='Number of hidden dimension')
parser.add_argument('--n_heads', type=int, default=8,
                    help='Number of attention head')
parser.add_argument('--n_layers', type=int, default=3,
                    help='Number of GNN layers')
parser.add_argument('--prev_norm', help='Whether to add layer-norm on the previous layers', action='store_true')
parser.add_argument('--last_norm', help='Whether to add layer-norm on the last layers', action='store_true')
# Parsed as float: with type=int, argparse rejects values such as "0.3".
parser.add_argument('--dropout', type=float, default=0.2,
                    help='Dropout ratio')

# Optimization arguments
parser.add_argument('--max_lr', type=float, default=1e-3,
                    help='Maximum learning rate.')
parser.add_argument('--scheduler', type=str, default='cycle',
                    help='Name of learning rate scheduler.' , choices=['cycle', 'cosine'])
parser.add_argument('--n_epoch', type=int, default=20,
                    help='Number of epoch to run')
parser.add_argument('--n_pool', type=int, default=8,
                    help='Number of process to sample subgraph')
parser.add_argument('--n_batch', type=int, default=32,
                    help='Number of batch (sampled graphs) for each epoch')
parser.add_argument('--batch_size', type=int, default=256,
                    help='Number of output nodes for training')
parser.add_argument('--clip', type=float, default=0.5,
                    help='Gradient Norm Clipping')
74 |
args = parser.parse_args()
args_print(args)


# --cuda -1 selects the CPU; any other value is a CUDA device index.
if args.cuda != -1:
    device = torch.device("cuda:" + str(args.cuda))
else:
    device = torch.device("cpu")


print('Start Loading Graph Data...')
with open(args.data_dir, 'rb') as graph_file:
    graph_reddit = dill.load(graph_file)
# GPT_sample and the model set-up below refer to the graph through the
# module-level name `graph`; bind it to the loaded Reddit graph so those
# references resolve.
graph = graph_reddit
print('Finish Loading Graph Data!')

target_type = 'def'
# Relations whose edges GPT_sample never removes.
rel_stop_list = ['self']

pre_target_nodes = graph_reddit.pre_target_nodes
train_target_nodes = graph_reddit.train_target_nodes

# Pair each node id with a constant 1 -> rows of [node_id, 1]; the 1 matches
# the {1: True} time range used when sampling (assumes the *_target_nodes
# attributes are 1-D id arrays — confirm).
pre_target_nodes = np.concatenate([pre_target_nodes, np.ones(len(pre_target_nodes))]).reshape(2, -1).transpose()
train_target_nodes = np.concatenate([train_target_nodes, np.ones(len(train_target_nodes))]).reshape(2, -1).transpose()
97 |
98 |
def GPT_sample(seed, target_nodes, time_range, batch_size, feature_extractor):
    '''
    Sample one GPT-GNN pre-training subgraph around `batch_size` random target nodes.

    (1) Seed numpy's global RNG, draw `batch_size` rows of `target_nodes`
        (np.random.choice without replace=False, i.e. with replacement) and
        sample a subgraph around them with sample_subgraph.
    (2) For every edge list whose target side is `target_type`: keep an
        unmodified copy in `ori_list`, then rebuild the list from the pairs
        with target_ser < source_ser, adding each kept pair in both
        directions.  Edges of target nodes with index < batch_size whose
        relation is not in rel_stop_list are instead moved to
        `rem_edge_list` when np.random.random() > threshold.
    (3) Append `batch_size` zero-feature copies of the target nodes with
        index < batch_size (reusing their times) and connect each copy to the
        sources its original node still has after step (2).

    Returns a 5-tuple:
        - the output of to_torch(feature, times, edge_list, graph),
        - the removed edges as {source_type: {relation_type: ndarray of [target_ser, source_ser]}},
        - ori_list: the sampled target_type edge lists before step (2), as ndarrays,
        - attr[:batch_size] as returned by sample_subgraph,
        - (n_target_nodes, n_target_nodes + batch_size): the index range of the
          appended copies among the target_type nodes.

    NOTE(review): `graph`, `target_type`, `rel_stop_list` and `args` are read as
    module-level globals — confirm `graph` refers to the loaded Reddit graph.
    '''
    np.random.seed(seed)
    samp_target_nodes = target_nodes[np.random.choice(len(target_nodes), batch_size)]
    # An eligible edge is removed when np.random.random() > threshold.
    threshold = 0.5
    feature, times, edge_list, _, attr = sample_subgraph(graph, time_range, \
                inp = {target_type: samp_target_nodes}, feature_extractor = feature_extractor, \
                sampled_depth = args.sample_depth, sampled_number = args.sample_width)
    # Removed edges, grouped by source type and relation type; presumably the
    # link-prediction targets consumed by the caller.
    rem_edge_list = defaultdict( #source_type
                        lambda: defaultdict(  #relation_type
                            lambda: [] # [target_id, source_id] 
                                ))
    
    ori_list = {}
    for source_type in edge_list[target_type]:
        ori_list[source_type] = {}
        for relation_type in edge_list[target_type][source_type]:
            # Snapshot of the sampled edges before any removal.
            ori_list[source_type][relation_type] = np.array(edge_list[target_type][source_type][relation_type])
            el = []
            for target_ser, source_ser in edge_list[target_type][source_type][relation_type]:
                # NOTE(review): only pairs with target_ser < source_ser are
                # re-added (in both directions), so self-loops
                # (target_ser == source_ser) are dropped from every relation
                # here — including a 'self' relation if present, which leaves
                # the 'self' branch of the feature-node step below with no
                # edges to copy.  Confirm this is intended.
                if target_ser < source_ser:
                    if relation_type not in rel_stop_list and target_ser < batch_size and \
                           np.random.random() > threshold:
                        # Hide this edge of a batch target node from the GNN input.
                        rem_edge_list[source_type][relation_type] += [[target_ser, source_ser]]
                        continue
                    el += [[target_ser, source_ser]]
                    el += [[source_ser, target_ser]]
            el = np.array(el)
            edge_list[target_type][source_type][relation_type] = el
            
            # NOTE(review): no-op — this is the last statement of the loop body.
            if relation_type == 'self':
                continue
            
    '''
        Adding feature nodes:
    '''
    # Append batch_size copies of the first batch_size target nodes with
    # all-zero features and the same times as their originals.
    n_target_nodes = len(feature[target_type])
    feature[target_type] = np.concatenate((feature[target_type], np.zeros([batch_size, feature[target_type].shape[1]])))
    times[target_type]   = np.concatenate((times[target_type], times[target_type][:batch_size]))

    # Give each copy (index target_ser + n_target_nodes) the edges its original
    # still has; for 'self' the copy gets a self-loop on itself instead.
    for source_type in edge_list[target_type]:
        for relation_type in edge_list[target_type][source_type]:
            el = []
            for target_ser, source_ser in edge_list[target_type][source_type][relation_type]:
                if target_ser < batch_size:
                    if relation_type == 'self':
                        el += [[target_ser + n_target_nodes, target_ser + n_target_nodes]]
                    else:
                        el += [[target_ser + n_target_nodes, source_ser]]
            if len(el) > 0:
                edge_list[target_type][source_type][relation_type] = \
                    np.concatenate((edge_list[target_type][source_type][relation_type], el))

    # Convert the nested defaultdicts (built from lambdas) into plain dicts of
    # arrays; presumably so the result can be pickled back from the pool
    # workers that run this function (lambdas are not picklable).
    rem_edge_lists = {}
    for source_type in rem_edge_list:
        rem_edge_lists[source_type] = {}
        for relation_type in rem_edge_list[source_type]:
            rem_edge_lists[source_type][relation_type] = np.array(rem_edge_list[source_type][relation_type])
    del rem_edge_list

    return to_torch(feature, times, edge_list, graph), rem_edge_lists, ori_list, \
            attr[:batch_size], (n_target_nodes, n_target_nodes + batch_size)
161 |
162 |
163 |
164 |
def prepare_data(pool):
    """Queue one round of subgraph samples on ``pool``.

    Returns the AsyncResult handles in submission order. The first
    ``args.n_batch - 1`` draw from ``pre_target_nodes`` (consumed as training
    data). The final one draws from ``train_target_nodes`` (consumed as the
    validation sample).
    """
    def submit(nodes):
        return pool.apply_async(
            GPT_sample,
            args=(randint(), nodes, {1: True}, args.batch_size, feature_reddit))

    jobs = [submit(pre_target_nodes) for _ in range(args.n_batch - 1)]
    jobs.append(submit(train_target_nodes))
    return jobs
171 |
172 |
# Pre-training driver. Subgraph sampling runs in a worker pool one round ahead
# of the optimiser. The first n_batch - 1 samples of each round are trained
# on; the last one is used for validation and checkpoint selection.
pool = mp.Pool(args.n_pool)
st = time.time()
jobs = prepare_data(pool)
repeat_num = int(len(pre_target_nodes) / args.batch_size // args.n_batch)


# One sample drawn in the main process supplies the edge lists passed to
# GPT_GNN and the target-type features used to initialise init_emb.
data, rem_edge_list, ori_edge_list, _, _ = GPT_sample(randint(), pre_target_nodes, {1: True}, args.batch_size, feature_reddit)
node_feature, node_type, edge_time, edge_index, edge_type, node_dict, edge_dict = data
types = graph.get_types()


gnn = GNN(conv_name = args.conv_name, in_dim = len(graph.node_feature[target_type]['emb'].values[0]), n_hid = args.n_hid, \
          n_heads = args.n_heads, n_layers = args.n_layers, dropout = args.dropout, num_types = len(types), \
          num_relations = len(graph.get_meta_graph()) + 1, prev_norm = args.prev_norm, last_norm = args.last_norm, use_RTE = False)

if args.attr_type == 'text':
    from gensim.models import Word2Vec
    w2v_model = Word2Vec.load(args.w2v_dir)
    # NOTE(review): `wv.vocab` is the gensim < 4.0 API (removed in 4.0).
    n_tokens = len(w2v_model.wv.vocab)
    attr_decoder = RNNModel(n_word = n_tokens, ninp = gnn.n_hid, \
                            nhid = w2v_model.vector_size, nlayers = 2)
    attr_decoder.from_w2v(torch.FloatTensor(w2v_model.wv.vectors))
else:
    attr_decoder = Matcher(gnn.n_hid, gnn.in_dim)

gpt_gnn = GPT_GNN(gnn = gnn, rem_edge_list = rem_edge_list, attr_decoder = attr_decoder, \
                  types = types, neg_samp_num = args.neg_samp_num, device = device)
# Start the placeholder embedding at the mean feature of the target-type nodes.
gpt_gnn.init_emb.data = node_feature[node_type == node_dict[target_type][1]].mean(dim=0).detach()
gpt_gnn = gpt_gnn.to(device)


best_val = 100000
train_step = 0
stats = []
optimizer = torch.optim.AdamW(gpt_gnn.parameters(), weight_decay = 1e-2, eps=1e-06, lr = args.max_lr)

if args.scheduler == 'cycle':
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, pct_start=0.02, anneal_strategy='linear', final_div_factor=100,\
                        max_lr = args.max_lr, total_steps = repeat_num * args.n_batch * args.n_epoch + 1)
elif args.scheduler == 'cosine':
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, repeat_num * args.n_batch, eta_min=1e-6)

print('Start Pretraining...')
for epoch in np.arange(args.n_epoch) + 1:
    # Grow the negative-embedding queue linearly, reaching queue_size in the final epoch.
    gpt_gnn.neg_queue_size = args.queue_size * epoch // args.n_epoch
    for batch in np.arange(repeat_num) + 1:
        # Collect the round prepared in the background, then start sampling
        # the next round in a fresh pool while this one is trained on.
        train_data = [job.get() for job in jobs[:-1]]
        valid_data = jobs[-1].get()
        pool.close()
        pool.join()
        pool = mp.Pool(args.n_pool)
        jobs = prepare_data(pool)
        et = time.time()
        print('Data Preparation: %.1fs' % (et - st))

        train_link_losses = []
        train_attr_losses = []
        gpt_gnn.train()
        for data, rem_edge_list, ori_edge_list, attr, (start_idx, end_idx) in train_data:
            node_feature, node_type, edge_time, edge_index, edge_type, node_dict, edge_dict = data
            node_feature = node_feature.detach()
            # The appended copy nodes get the learnable placeholder embedding.
            node_feature[start_idx : end_idx] = gpt_gnn.init_emb
            node_emb = gpt_gnn.gnn(node_feature.to(device), node_type.to(device), edge_time.to(device), \
                                   edge_index.to(device), edge_type.to(device))

            loss_link, _ = gpt_gnn.link_loss(node_emb, rem_edge_list, ori_edge_list, node_dict, target_type, use_queue = True, update_queue=True)
            if args.attr_type == 'text':
                loss_attr = gpt_gnn.text_loss(node_emb[start_idx : end_idx], attr, w2v_model, device)
            else:
                loss_attr = gpt_gnn.feat_loss(node_emb[start_idx : end_idx], torch.FloatTensor(attr).to(device))

            loss = loss_link * (1 - args.attr_ratio) + loss_attr * args.attr_ratio

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(gpt_gnn.parameters(), args.clip)
            optimizer.step()

            train_link_losses += [loss_link.item()]
            train_attr_losses += [loss_attr.item()]
            scheduler.step()

        # Validation on this round's held-out sample.
        gpt_gnn.eval()
        with torch.no_grad():
            data, rem_edge_list, ori_edge_list, attr, (start_idx, end_idx) = valid_data
            node_feature, node_type, edge_time, edge_index, edge_type, node_dict, edge_dict = data
            node_feature = node_feature.detach()
            node_feature[start_idx : end_idx] = gpt_gnn.init_emb
            node_emb = gpt_gnn.gnn(node_feature.to(device), node_type.to(device), edge_time.to(device), \
                                   edge_index.to(device), edge_type.to(device))
            # NOTE(review): update_queue=True lets validation embeddings enter
            # the negative queue later used in training; confirm intended.
            loss_link, ress = gpt_gnn.link_loss(node_emb, rem_edge_list, ori_edge_list, node_dict, target_type, use_queue = False, update_queue=True)
            loss_link = loss_link.item()
            if args.attr_type == 'text':
                loss_attr = gpt_gnn.text_loss(node_emb[start_idx : end_idx], attr, w2v_model, device)
            else:
                loss_attr = gpt_gnn.feat_loss(node_emb[start_idx : end_idx], torch.FloatTensor(attr).to(device))
            # Keep a plain float (like loss_link) so best_val and stats do not
            # hold device tensors for the whole run.
            loss_attr = loss_attr.item()

            ndcgs = []
            for i in ress:
                # Relevance vector: only column 0 of each score row is relevant.
                ai = np.zeros(len(i[0]))
                ai[0] = 1
                ndcgs += [ndcg_at_k(ai[j.cpu().numpy()], len(j)) for j in i.argsort(descending = True)]

            valid_loss = loss_link * (1 - args.attr_ratio) + loss_attr * args.attr_ratio
            st = time.time()
            print(("Epoch: %d, (%d / %d) %.1fs LR: %.5f Train Loss: (%.3f, %.3f) Valid Loss: (%.3f, %.3f) NDCG: %.3f Norm: %.3f queue: %d") % \
                  (epoch, batch, repeat_num, (st-et), optimizer.param_groups[0]['lr'], np.average(train_link_losses), np.average(train_attr_losses), \
                   loss_link, loss_attr, np.average(ndcgs), node_emb.norm(dim=1).mean(), gpt_gnn.neg_queue_size))

            if valid_loss < best_val:
                best_val = valid_loss
                print('UPDATE!!!')
                torch.save(gpt_gnn.state_dict(), args.pretrain_model_dir)
            stats += [[np.average(train_link_losses), loss_link, loss_attr, valid_loss]]

# The round prefetched during the final iteration is never consumed; stop its workers.
pool.terminate()
pool.join()
291 |
--------------------------------------------------------------------------------
/example_reddit/pretrain_reddit.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from GPT_GNN.data import *
3 | from GPT_GNN.model import *
4 | from warnings import filterwarnings
5 | filterwarnings("ignore")
6 |
7 | import argparse
8 |
parser = argparse.ArgumentParser(description='Pre-training HGT on a given graph (heterogeneous / homogeneous)')

# GPT-GNN arguments
parser.add_argument('--attr_ratio', type=float, default=0.5,
                    help='Ratio of attr-loss against link-loss, range: [0-1]')
parser.add_argument('--attr_type', type=str, default='vec',
                    choices=['text', 'vec'],
                    help='The type of attribute decoder')
parser.add_argument('--neg_samp_num', type=int, default=255,
                    help='Maximum number of negative sample for each target node.')
parser.add_argument('--queue_size', type=int, default=256,
                    help='Max size of adaptive embedding queue.')
parser.add_argument('--w2v_dir', type=str, default='/datadrive/dataset/w2v_all',
                    help='The address of the pre-trained word2vec model (used when --attr_type is text).')

# Dataset arguments
parser.add_argument('--data_dir', type=str, default='/datadrive/dataset/graph_reddit.pk',
                    help='The address of preprocessed graph.')
parser.add_argument('--pretrain_model_dir', type=str, default='/datadrive/models/gpt_all_reddit',
                    help='The address for storing the pre-trained models.')
parser.add_argument('--cuda', type=int, default=1,
                    help='Available GPU ID (-1 for CPU)')
parser.add_argument('--sample_depth', type=int, default=6,
                    help='How many layers within a mini-batch subgraph')
parser.add_argument('--sample_width', type=int, default=128,
                    help='How many nodes to be sampled per layer per type')

# Model arguments
parser.add_argument('--conv_name', type=str, default='hgt',
                    choices=['hgt', 'gcn', 'gat', 'rgcn', 'han', 'hetgnn'],
                    help='The name of GNN filter. By default is Heterogeneous Graph Transformer (hgt)')
parser.add_argument('--n_hid', type=int, default=400,
                    help='Number of hidden dimension')
parser.add_argument('--n_heads', type=int, default=8,
                    help='Number of attention head')
parser.add_argument('--n_layers', type=int, default=3,
                    help='Number of GNN layers')
parser.add_argument('--prev_norm', help='Whether to add layer-norm on the previous layers', action='store_true')
parser.add_argument('--last_norm', help='Whether to add layer-norm on the last layers', action='store_true')
# type=float: the ratio is fractional (type=int rejected e.g. `--dropout 0.3`).
parser.add_argument('--dropout', type=float, default=0.2,
                    help='Dropout ratio')

# Optimization arguments
parser.add_argument('--max_lr', type=float, default=1e-3,
                    help='Maximum learning rate.')
parser.add_argument('--scheduler', type=str, default='cycle',
                    help='Name of learning rate scheduler.' , choices=['cycle', 'cosine'])
parser.add_argument('--n_epoch', type=int, default=20,
                    help='Number of epoch to run')
parser.add_argument('--n_pool', type=int, default=8,
                    help='Number of process to sample subgraph')
parser.add_argument('--n_batch', type=int, default=32,
                    help='Number of batch (sampled graphs) for each epoch')
parser.add_argument('--batch_size', type=int, default=256,
                    help='Number of output nodes for training')
parser.add_argument('--clip', type=float, default=0.5,
                    help='Gradient Norm Clipping')

args = parser.parse_args()
args_print(args)


if args.cuda != -1:
    device = torch.device("cuda:" + str(args.cuda))
else:
    device = torch.device("cpu")


print('Start Loading Graph Data...')
# NOTE: dill (like pickle) can execute code while loading; only load trusted files.
with open(args.data_dir, 'rb') as graph_file:
    graph_reddit: Graph = dill.load(graph_file)
print('Finish Loading Graph Data!')

target_type = 'def'
# Relations whose edges GPT_sample never hides for link prediction.
rel_stop_list = ['self']

pre_target_nodes = graph_reddit.pre_target_nodes
train_target_nodes = graph_reddit.train_target_nodes

# Each row becomes [node_id, 1]: a placeholder time matching the {1: True}
# time range used for sampling.
pre_target_nodes = np.concatenate([pre_target_nodes, np.ones(len(pre_target_nodes))]).reshape(2, -1).transpose()
train_target_nodes = np.concatenate([train_target_nodes, np.ones(len(train_target_nodes))]).reshape(2, -1).transpose()
97 |
98 |
def GPT_sample(seed, target_nodes, time_range, batch_size, feature_extractor):
    """Sample one subgraph for a GPT-GNN pre-training step.

    Draws ``batch_size`` rows of ``target_nodes`` (with replacement) and
    samples their neighbourhood. It then hides a random ~half of the edges
    (outside ``rel_stop_list``) of the first ``batch_size`` target-type nodes,
    for link prediction. Finally it appends ``batch_size`` zero-feature copies
    of those nodes, for attribute generation.

    Args:
        seed: seed for NumPy's global RNG, so the sample is reproducible.
        target_nodes: array of [node_id, time] rows to draw from.
        time_range: time filter forwarded to ``sample_subgraph``.
        batch_size: number of target nodes to draw.
        feature_extractor: feature function forwarded to ``sample_subgraph``.

    Returns:
        ``(to_torch(...), rem_edge_lists, ori_list, attr[:batch_size],
        (start, end))``:

        - ``rem_edge_lists[source_type][relation_type]`` is an array of the
          hidden [target, source] pairs.
        - ``ori_list`` holds the target-type edge lists before masking.
        - ``[start, end)`` is the row range of the appended copies within the
          target type.
    """
    np.random.seed(seed)
    samp_target_nodes = target_nodes[np.random.choice(len(target_nodes), batch_size)]
    threshold = 0.5
    feature, times, edge_list, _, attr = sample_subgraph(
        graph_reddit, time_range,
        inp={target_type: samp_target_nodes}, feature_extractor=feature_extractor,
        sampled_depth=args.sample_depth, sampled_number=args.sample_width)
    # rem_edge_list[source_type][relation_type] -> list of [target_id, source_id]
    rem_edge_list = defaultdict(lambda: defaultdict(list))

    ori_list = {}
    for source_type in edge_list[target_type]:
        ori_list[source_type] = {}
        for relation_type in edge_list[target_type][source_type]:
            ori_list[source_type][relation_type] = np.array(edge_list[target_type][source_type][relation_type])
            el = []
            for target_ser, source_ser in edge_list[target_type][source_type][relation_type]:
                if target_ser == source_ser:
                    # Self-loop: keep it exactly once. The strict `<` test
                    # below used to drop every self-loop, which left the
                    # 'self' branch of the copy-node loop with nothing to copy.
                    el += [[target_ser, source_ser]]
                elif target_ser < source_ser:
                    # Only the (smaller, larger) orientation is examined and
                    # both orientations are re-added unless hidden; this
                    # presumes the sampler lists each undirected edge in both
                    # directions (TODO confirm).
                    if relation_type not in rel_stop_list and target_ser < batch_size and \
                            np.random.random() > threshold:
                        rem_edge_list[source_type][relation_type] += [[target_ser, source_ser]]
                        continue
                    el += [[target_ser, source_ser]]
                    el += [[source_ser, target_ser]]
            edge_list[target_type][source_type][relation_type] = np.array(el)

    # Append batch_size placeholder copies of the first batch_size target-type
    # nodes (presumably the sampled seeds): zero features, same times. They
    # occupy rows [n_target_nodes, n_target_nodes + batch_size).
    n_target_nodes = len(feature[target_type])
    feature[target_type] = np.concatenate((feature[target_type], np.zeros([batch_size, feature[target_type].shape[1]])))
    times[target_type] = np.concatenate((times[target_type], times[target_type][:batch_size]))

    # Give each copy the visible (non-hidden) edges of its original node.
    for source_type in edge_list[target_type]:
        for relation_type in edge_list[target_type][source_type]:
            el = []
            for target_ser, source_ser in edge_list[target_type][source_type][relation_type]:
                if target_ser < batch_size:
                    if relation_type == 'self':
                        el += [[target_ser + n_target_nodes, target_ser + n_target_nodes]]
                    else:
                        el += [[target_ser + n_target_nodes, source_ser]]
            if len(el) > 0:
                edge_list[target_type][source_type][relation_type] = \
                    np.concatenate((edge_list[target_type][source_type][relation_type], el))

    rem_edge_lists = {
        source_type: {relation_type: np.array(edges) for relation_type, edges in relations.items()}
        for source_type, relations in rem_edge_list.items()
    }

    return to_torch(feature, times, edge_list, graph_reddit), rem_edge_lists, ori_list, \
        attr[:batch_size], (n_target_nodes, n_target_nodes + batch_size)
161 |
162 |
163 |
164 |
def prepare_data(pool):
    """Queue one round of subgraph samples on ``pool``.

    Returns the AsyncResult handles in submission order. The first
    ``args.n_batch - 1`` draw from ``pre_target_nodes`` (consumed as training
    data). The final one draws from ``train_target_nodes`` (consumed as the
    validation sample).
    """
    def submit(nodes):
        return pool.apply_async(
            GPT_sample,
            args=(randint(), nodes, {1: True}, args.batch_size, feature_reddit))

    jobs = [submit(pre_target_nodes) for _ in range(args.n_batch - 1)]
    jobs.append(submit(train_target_nodes))
    return jobs
171 |
172 |
# Pre-training driver. Subgraph sampling runs in a worker pool one round ahead
# of the optimiser. The first n_batch - 1 samples of each round are trained
# on; the last one is used for validation and checkpoint selection.
pool = mp.Pool(args.n_pool)
st = time.time()
jobs = prepare_data(pool)
repeat_num = int(len(pre_target_nodes) / args.batch_size // args.n_batch)


# One sample drawn in the main process supplies the edge lists passed to
# GPT_GNN and the target-type features used to initialise init_emb.
data, rem_edge_list, ori_edge_list, _, _ = GPT_sample(randint(), pre_target_nodes, {1: True}, args.batch_size, feature_reddit)
node_feature, node_type, edge_time, edge_index, edge_type, node_dict, edge_dict = data
types = graph_reddit.get_types()


gnn = GNN(conv_name = args.conv_name, in_dim = len(graph_reddit.node_feature[target_type]['emb'].values[0]), n_hid = args.n_hid, \
          n_heads = args.n_heads, n_layers = args.n_layers, dropout = args.dropout, num_types = len(types), \
          num_relations = len(graph_reddit.get_meta_graph()) + 1, prev_norm = args.prev_norm, last_norm = args.last_norm, use_RTE = False)

if args.attr_type == 'text':
    from gensim.models import Word2Vec
    w2v_model = Word2Vec.load(args.w2v_dir)
    # NOTE(review): `wv.vocab` is the gensim < 4.0 API (removed in 4.0).
    n_tokens = len(w2v_model.wv.vocab)
    attr_decoder = RNNModel(n_word = n_tokens, ninp = gnn.n_hid, \
                            nhid = w2v_model.vector_size, nlayers = 2)
    attr_decoder.from_w2v(torch.FloatTensor(w2v_model.wv.vectors))
else:
    attr_decoder = Matcher(gnn.n_hid, gnn.in_dim)

gpt_gnn = GPT_GNN(gnn = gnn, rem_edge_list = rem_edge_list, attr_decoder = attr_decoder, \
                  types = types, neg_samp_num = args.neg_samp_num, device = device)
# Start the placeholder embedding at the mean feature of the target-type nodes.
gpt_gnn.init_emb.data = node_feature[node_type == node_dict[target_type][1]].mean(dim=0).detach()
gpt_gnn = gpt_gnn.to(device)


best_val = 100000
train_step = 0
stats = []
optimizer = torch.optim.AdamW(gpt_gnn.parameters(), weight_decay = 1e-2, eps=1e-06, lr = args.max_lr)

if args.scheduler == 'cycle':
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, pct_start=0.02, anneal_strategy='linear', final_div_factor=100,\
                        max_lr = args.max_lr, total_steps = repeat_num * args.n_batch * args.n_epoch + 1)
elif args.scheduler == 'cosine':
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, repeat_num * args.n_batch, eta_min=1e-6)

print('Start Pretraining...')
for epoch in np.arange(args.n_epoch) + 1:
    # Grow the negative-embedding queue linearly, reaching queue_size in the final epoch.
    gpt_gnn.neg_queue_size = args.queue_size * epoch // args.n_epoch
    for batch in np.arange(repeat_num) + 1:
        # Collect the round prepared in the background, then start sampling
        # the next round in a fresh pool while this one is trained on.
        train_data = [job.get() for job in jobs[:-1]]
        valid_data = jobs[-1].get()
        pool.close()
        pool.join()
        pool = mp.Pool(args.n_pool)
        jobs = prepare_data(pool)
        et = time.time()
        print('Data Preparation: %.1fs' % (et - st))

        train_link_losses = []
        train_attr_losses = []
        gpt_gnn.train()
        for data, rem_edge_list, ori_edge_list, attr, (start_idx, end_idx) in train_data:
            node_feature, node_type, edge_time, edge_index, edge_type, node_dict, edge_dict = data
            node_feature = node_feature.detach()
            # The appended copy nodes get the learnable placeholder embedding.
            node_feature[start_idx : end_idx] = gpt_gnn.init_emb
            node_emb = gpt_gnn.gnn(node_feature.to(device), node_type.to(device), edge_time.to(device), \
                                   edge_index.to(device), edge_type.to(device))

            loss_link, _ = gpt_gnn.link_loss(node_emb, rem_edge_list, ori_edge_list, node_dict, target_type, use_queue = True, update_queue=True)
            if args.attr_type == 'text':
                loss_attr = gpt_gnn.text_loss(node_emb[start_idx : end_idx], attr, w2v_model, device)
            else:
                loss_attr = gpt_gnn.feat_loss(node_emb[start_idx : end_idx], torch.FloatTensor(attr).to(device))

            loss = loss_link * (1 - args.attr_ratio) + loss_attr * args.attr_ratio

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(gpt_gnn.parameters(), args.clip)
            optimizer.step()

            train_link_losses += [loss_link.item()]
            train_attr_losses += [loss_attr.item()]
            scheduler.step()

        # Validation on this round's held-out sample.
        gpt_gnn.eval()
        with torch.no_grad():
            data, rem_edge_list, ori_edge_list, attr, (start_idx, end_idx) = valid_data
            node_feature, node_type, edge_time, edge_index, edge_type, node_dict, edge_dict = data
            node_feature = node_feature.detach()
            node_feature[start_idx : end_idx] = gpt_gnn.init_emb
            node_emb = gpt_gnn.gnn(node_feature.to(device), node_type.to(device), edge_time.to(device), \
                                   edge_index.to(device), edge_type.to(device))
            # NOTE(review): update_queue=True lets validation embeddings enter
            # the negative queue later used in training; confirm intended.
            loss_link, ress = gpt_gnn.link_loss(node_emb, rem_edge_list, ori_edge_list, node_dict, target_type, use_queue = False, update_queue=True)
            loss_link = loss_link.item()
            if args.attr_type == 'text':
                loss_attr = gpt_gnn.text_loss(node_emb[start_idx : end_idx], attr, w2v_model, device)
            else:
                loss_attr = gpt_gnn.feat_loss(node_emb[start_idx : end_idx], torch.FloatTensor(attr).to(device))
            # Keep a plain float (like loss_link) so best_val and stats do not
            # hold device tensors for the whole run.
            loss_attr = loss_attr.item()

            ndcgs = []
            for i in ress:
                # Relevance vector: only column 0 of each score row is relevant.
                ai = np.zeros(len(i[0]))
                ai[0] = 1
                ndcgs += [ndcg_at_k(ai[j.cpu().numpy()], len(j)) for j in i.argsort(descending = True)]

            valid_loss = loss_link * (1 - args.attr_ratio) + loss_attr * args.attr_ratio
            st = time.time()
            print(("Epoch: %d, (%d / %d) %.1fs LR: %.5f Train Loss: (%.3f, %.3f) Valid Loss: (%.3f, %.3f) NDCG: %.3f Norm: %.3f queue: %d") % \
                  (epoch, batch, repeat_num, (st-et), optimizer.param_groups[0]['lr'], np.average(train_link_losses), np.average(train_attr_losses), \
                   loss_link, loss_attr, np.average(ndcgs), node_emb.norm(dim=1).mean(), gpt_gnn.neg_queue_size))

            if valid_loss < best_val:
                best_val = valid_loss
                print('UPDATE!!!')
                torch.save(gpt_gnn.state_dict(), args.pretrain_model_dir)
            stats += [[np.average(train_link_losses), loss_link, loss_attr, valid_loss]]

# The round prefetched during the final iteration is never consumed; stop its workers.
pool.terminate()
pool.join()
291 |
--------------------------------------------------------------------------------
/example_OAG/pretrain_OAG.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from GPT_GNN.data import *
3 | from GPT_GNN.model import *
4 | from warnings import filterwarnings
5 | filterwarnings("ignore")
6 |
7 | import argparse
8 |
parser = argparse.ArgumentParser(description='Pre-training HGT on a given graph (heterogeneous / homogeneous)')

# GPT-GNN arguments
parser.add_argument('--attr_ratio', type=float, default=0.5,
                    help='Ratio of attr-loss against link-loss, range: [0-1]')
parser.add_argument('--attr_type', type=str, default='text',
                    choices=['text', 'vec'],
                    help='The type of attribute decoder')
parser.add_argument('--neg_samp_num', type=int, default=255,
                    help='Maximum number of negative sample for each target node.')
parser.add_argument('--queue_size', type=int, default=256,
                    help='Max size of adaptive embedding queue.')
parser.add_argument('--w2v_dir', type=str, default='/datadrive/dataset/w2v_all',
                    help='The address of the pre-trained word2vec model (used when --attr_type is text).')

# Dataset arguments
parser.add_argument('--data_dir', type=str, default='/datadrive/dataset/graph_CS.pk',
                    help='The address of preprocessed graph.')
parser.add_argument('--pretrain_model_dir', type=str, default='/datadrive/models/test',
                    help='The address for storing the models and optimization results.')
parser.add_argument('--cuda', type=int, default=2,
                    help='Available GPU ID (-1 for CPU)')
parser.add_argument('--sample_depth', type=int, default=6,
                    help='How many layers within a mini-batch subgraph')
parser.add_argument('--sample_width', type=int, default=128,
                    help='How many nodes to be sampled per layer per type')

# Model arguments
parser.add_argument('--conv_name', type=str, default='hgt',
                    choices=['hgt', 'gcn', 'gat', 'rgcn', 'han', 'hetgnn'],
                    help='The name of GNN filter. By default is Heterogeneous Graph Transformer (hgt)')
parser.add_argument('--n_hid', type=int, default=400,
                    help='Number of hidden dimension')
parser.add_argument('--n_heads', type=int, default=8,
                    help='Number of attention head')
parser.add_argument('--n_layers', type=int, default=3,
                    help='Number of GNN layers')
parser.add_argument('--prev_norm', help='Whether to add layer-norm on the previous layers', action='store_true')
parser.add_argument('--last_norm', help='Whether to add layer-norm on the last layers', action='store_true')
# type=float: the ratio is fractional (type=int rejected e.g. `--dropout 0.3`).
parser.add_argument('--dropout', type=float, default=0.2,
                    help='Dropout ratio')

# Optimization arguments
parser.add_argument('--max_lr', type=float, default=1e-3,
                    help='Maximum learning rate.')
parser.add_argument('--scheduler', type=str, default='cycle',
                    help='Name of learning rate scheduler.' , choices=['cycle', 'cosine'])
parser.add_argument('--n_epoch', type=int, default=20,
                    help='Number of epoch to run')
parser.add_argument('--n_pool', type=int, default=8,
                    help='Number of process to sample subgraph')
parser.add_argument('--n_batch', type=int, default=32,
                    help='Number of batch (sampled graphs) for each epoch')
parser.add_argument('--batch_size', type=int, default=256,
                    help='Number of output nodes for training')
parser.add_argument('--clip', type=float, default=0.5,
                    help='Gradient Norm Clipping')


args = parser.parse_args()
args_print(args)

if args.cuda != -1:
    device = torch.device("cuda:" + str(args.cuda))
else:
    device = torch.device("cpu")

print('Start Loading Graph Data...')
with open(args.data_dir, 'rb') as graph_file:
    graph = renamed_load(graph_file)
print('Finish Loading Graph Data!')

# Split graph.times into disjoint buckets (values look like years):
# pre-training < 2014, training 2014-2016, validation 2017, test > 2017.
pre_range = {t: True for t in graph.times if t is not None and t < 2014}
train_range = {t: True for t in graph.times if t is not None and 2014 <= t <= 2016}
valid_range = {t: True for t in graph.times if t is not None and 2016 < t <= 2017}
test_range = {t: True for t in graph.times if t is not None and t > 2017}

pre_target_nodes = []
train_target_nodes = []
target_type = 'paper'
# Relations whose edges GPT_sample never hides for link prediction.
rel_stop_list = ['self', 'rev_PF_in_L0', 'rev_PF_in_L5', 'rev_PV_Repository', 'rev_PV_Patent']


# Series.iteritems() was removed in pandas 2.0; items() yields the same pairs.
for p_id, _time in graph.node_feature[target_type]['time'].items():
    if _time in pre_range:
        pre_target_nodes += [[p_id, _time]]
    elif _time in train_range:
        train_target_nodes += [[p_id, _time]]
pre_target_nodes = np.array(pre_target_nodes)
train_target_nodes = np.array(train_target_nodes)
106 |
107 |
def GPT_sample(seed, target_nodes, time_range, batch_size, feature_extractor):
    """Sample one subgraph for a GPT-GNN pre-training step.

    Draws ``batch_size`` rows of ``target_nodes`` (with replacement) and
    samples their neighbourhood. It then hides a random ~half of the edges
    (outside ``rel_stop_list``) of the first ``batch_size`` target-type nodes,
    for link prediction. Each reverse relation is rebuilt from the edges that
    stay visible. Finally it appends ``batch_size`` zero-feature copies of
    those nodes, for attribute generation.

    Args:
        seed: seed for NumPy's global RNG, so the sample is reproducible.
        target_nodes: array of [node_id, time] rows to draw from.
        time_range: time filter forwarded to ``sample_subgraph``.
        batch_size: number of target nodes to draw.
        feature_extractor: feature function forwarded to ``sample_subgraph``.

    Returns:
        ``(to_torch(...), rem_edge_lists, ori_list, attr[:batch_size],
        (start, end))``:

        - ``rem_edge_lists[source_type][relation_type]`` is an array of the
          hidden [target, source] pairs.
        - ``ori_list`` holds the target-type edge lists before masking.
        - ``[start, end)`` is the row range of the appended copies within the
          target type.
    """
    np.random.seed(seed)
    samp_target_nodes = target_nodes[np.random.choice(len(target_nodes), batch_size)]
    threshold = 0.5
    feature, times, edge_list, _, attr = sample_subgraph(
        graph, time_range,
        inp={target_type: samp_target_nodes}, feature_extractor=feature_extractor,
        sampled_depth=args.sample_depth, sampled_number=args.sample_width)
    # rem_edge_list[source_type][relation_type] -> list of [target_id, source_id]
    rem_edge_list = defaultdict(lambda: defaultdict(list))

    ori_list = {}
    for source_type in edge_list[target_type]:
        ori_list[source_type] = {}
        # Snapshot the relation names: when source_type == target_type, the
        # reverse-relation assignment below writes into this same dict, and
        # adding a new key while iterating it raises RuntimeError.
        for relation_type in list(edge_list[target_type][source_type]):
            ori_list[source_type][relation_type] = np.array(edge_list[target_type][source_type][relation_type])
            el = []
            for target_ser, source_ser in edge_list[target_type][source_type][relation_type]:
                if relation_type not in rel_stop_list and target_ser < batch_size and np.random.random() > threshold:
                    rem_edge_list[source_type][relation_type] += [[target_ser, source_ser]]
                    continue
                el += [[target_ser, source_ser]]
            # reshape keeps a 2-D (0, 2) shape when every edge was hidden, so
            # the column slicing below does not fail on an empty 1-D array.
            el = np.array(el).reshape(-1, 2)
            edge_list[target_type][source_type][relation_type] = el

            if relation_type == 'self':
                continue
            # Overwrite the reverse relation with the reversed visible edges
            # (presumably so hidden edges are not exposed in the other direction).
            if 'rev_' in relation_type:
                rev_relation = relation_type[4:]
            else:
                rev_relation = 'rev_' + relation_type
            edge_list[source_type][target_type][rev_relation] = list(np.stack((el[:, 1], el[:, 0])).T)

    # Append batch_size placeholder copies of the first batch_size target-type
    # nodes (presumably the sampled seeds): zero features, same times. They
    # occupy rows [n_target_nodes, n_target_nodes + batch_size).
    n_target_nodes = len(feature[target_type])
    feature[target_type] = np.concatenate((feature[target_type], np.zeros([batch_size, feature[target_type].shape[1]])))
    times[target_type] = np.concatenate((times[target_type], times[target_type][:batch_size]))

    # Give each copy the visible (non-hidden) edges of its original node.
    for source_type in edge_list[target_type]:
        for relation_type in edge_list[target_type][source_type]:
            el = []
            for target_ser, source_ser in edge_list[target_type][source_type][relation_type]:
                if target_ser < batch_size:
                    if relation_type == 'self':
                        el += [[target_ser + n_target_nodes, target_ser + n_target_nodes]]
                    else:
                        el += [[target_ser + n_target_nodes, source_ser]]
            if len(el) > 0:
                edge_list[target_type][source_type][relation_type] = \
                    np.concatenate((edge_list[target_type][source_type][relation_type], el))

    rem_edge_lists = {
        source_type: {relation_type: np.array(edges) for relation_type, edges in relations.items()}
        for source_type, relations in rem_edge_list.items()
    }

    return to_torch(feature, times, edge_list, graph), rem_edge_lists, ori_list, \
        attr[:batch_size], (n_target_nodes, n_target_nodes + batch_size)
173 |
174 |
175 |
def prepare_data(pool):
    """Queue one round of subgraph samples on ``pool``.

    Returns the AsyncResult handles in submission order:

    - The first ``args.n_batch - 1`` sample ``pre_target_nodes`` within
      ``pre_range`` (consumed as training data).
    - The final one samples ``train_target_nodes`` within ``train_range``
      (consumed as the validation sample).
    """
    specs = [(pre_target_nodes, pre_range)] * (args.n_batch - 1)
    specs.append((train_target_nodes, train_range))
    return [
        pool.apply_async(GPT_sample, args=(randint(), nodes, span, args.batch_size, feature_OAG))
        for nodes, span in specs
    ]
182 |
183 |
# Start background sampling of the first round before the model is built.
pool = mp.Pool(args.n_pool)
st = time.time()
jobs = prepare_data(pool)
# Rounds per epoch: one pass over pre_target_nodes in rounds of n_batch samples.
repeat_num = int(len(pre_target_nodes) / args.batch_size // args.n_batch)


# A sample drawn in the main process provides the edge lists handed to
# GPT_GNN and the features used to initialise gpt_gnn.init_emb.
data, rem_edge_list, ori_edge_list, _, _ = GPT_sample(randint(), pre_target_nodes, pre_range, args.batch_size, feature_OAG)
node_feature, node_type, edge_time, edge_index, edge_type, node_dict, edge_dict = data
types = graph.get_types()


gnn = GNN(
    conv_name=args.conv_name,
    # NOTE(review): the +401 presumably covers feature columns that
    # feature_OAG adds on top of 'emb' -- confirm against feature_OAG.
    in_dim=len(graph.node_feature[target_type]['emb'].values[0]) + 401,
    n_hid=args.n_hid,
    n_heads=args.n_heads,
    n_layers=args.n_layers,
    dropout=args.dropout,
    num_types=len(types),
    # +1 presumably reserves an id for the 'self' relation.
    num_relations=len(graph.get_meta_graph()) + 1,
    prev_norm=args.prev_norm,
    last_norm=args.last_norm,
)


if args.attr_type != 'text':
    attr_decoder = Matcher(gnn.n_hid, gnn.in_dim)
else:
    from gensim.models import Word2Vec
    w2v_model = Word2Vec.load(args.w2v_dir)
    n_tokens = len(w2v_model.wv.vocab)
    attr_decoder = RNNModel(n_word=n_tokens, ninp=gnn.n_hid,
                            nhid=w2v_model.vector_size, nlayers=2)
    attr_decoder.from_w2v(torch.FloatTensor(w2v_model.wv.vectors))

gpt_gnn = GPT_GNN(gnn=gnn, rem_edge_list=rem_edge_list, attr_decoder=attr_decoder,
                  neg_queue_size=0, types=types, neg_samp_num=args.neg_samp_num, device=device)
# Start the placeholder embedding at the mean feature of the target-type nodes.
gpt_gnn.init_emb.data = node_feature[node_type == node_dict[target_type][1]].mean(dim=0).detach()
gpt_gnn = gpt_gnn.to(device)


best_val = 100000
train_step = 0
stats = []
optimizer = torch.optim.AdamW(gpt_gnn.parameters(), lr=args.max_lr, eps=1e-06, weight_decay=1e-2)

if args.scheduler == 'cosine':
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, repeat_num * args.n_batch, eta_min=1e-6)
elif args.scheduler == 'cycle':
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=args.max_lr,
        total_steps=repeat_num * args.n_batch * args.n_epoch + 1,
        pct_start=0.02,
        anneal_strategy='linear',
        final_div_factor=100,
    )

print('Start Pretraining...')
229 | for epoch in np.arange(args.n_epoch) + 1:
230 | gpt_gnn.neg_queue_size = args.queue_size * epoch // args.n_epoch
231 | for batch in np.arange(repeat_num) + 1:
232 | train_data = [job.get() for job in jobs[:-1]]
233 | valid_data = jobs[-1].get()
234 | pool.close()
235 | pool.join()
236 | pool = mp.Pool(args.n_pool)
237 | jobs = prepare_data(pool)
238 | et = time.time()
239 | print('Data Preparation: %.1fs' % (et - st))
240 |
241 | train_link_losses = []
242 | train_attr_losses = []
243 | gpt_gnn.train()
244 | for data, rem_edge_list, ori_edge_list, attr, (start_idx, end_idx) in train_data:
245 | node_feature, node_type, edge_time, edge_index, edge_type, node_dict, edge_dict = data
246 | node_feature = node_feature.detach()
247 | node_feature[start_idx : end_idx] = gpt_gnn.init_emb
248 | node_emb = gpt_gnn.gnn(node_feature.to(device), node_type.to(device), edge_time.to(device), \
249 | edge_index.to(device), edge_type.to(device))
250 |
251 | loss_link, _ = gpt_gnn.link_loss(node_emb, rem_edge_list, ori_edge_list, node_dict, target_type, use_queue = True, update_queue=True)
252 | if args.attr_type == 'text':
253 | loss_attr = gpt_gnn.text_loss(node_emb[start_idx : end_idx], attr, w2v_model, device)
254 | else:
255 | loss_attr = gpt_gnn.feat_loss(node_emb[start_idx : end_idx], torch.FloatTensor(attr).to(device))
256 |
257 |
258 | loss = loss_link + loss_attr * args.attr_ratio
259 |
260 |
261 | optimizer.zero_grad()
262 | loss.backward()
263 | torch.nn.utils.clip_grad_norm_(gpt_gnn.parameters(), args.clip)
264 | optimizer.step()
265 |
266 | train_link_losses += [loss_link.item()]
267 | train_attr_losses += [loss_attr.item()]
268 | scheduler.step()
269 | '''
270 | Valid
271 | '''
272 | gpt_gnn.eval()
273 | with torch.no_grad():
274 | data, rem_edge_list, ori_edge_list, attr, (start_idx, end_idx) = valid_data
275 | node_feature, node_type, edge_time, edge_index, edge_type, node_dict, edge_dict = data
276 | node_feature = node_feature.detach()
277 | node_feature[start_idx : end_idx] = gpt_gnn.init_emb
278 | node_emb = gpt_gnn.gnn(node_feature.to(device), node_type.to(device), edge_time.to(device), \
279 | edge_index.to(device), edge_type.to(device))
280 | loss_link, ress = gpt_gnn.link_loss(node_emb, rem_edge_list, ori_edge_list, node_dict, target_type, use_queue = False, update_queue=True)
281 | loss_link = loss_link.item()
282 | if args.attr_type == 'text':
283 | loss_attr = gpt_gnn.text_loss(node_emb[start_idx : end_idx], attr, w2v_model, device)
284 | else:
285 | loss_attr = gpt_gnn.feat_loss(node_emb[start_idx : end_idx], torch.FloatTensor(attr).to(device))
286 |
287 | ndcgs = []
288 | for i in ress:
289 | ai = np.zeros(len(i[0]))
290 | ai[0] = 1
291 | ndcgs += [ndcg_at_k(ai[j.cpu().numpy()], len(j)) for j in i.argsort(descending = True)]
292 |
293 | valid_loss = loss_link + loss_attr * args.attr_ratio
294 | st = time.time()
295 | print(("Epoch: %d, (%d / %d) %.1fs LR: %.5f Train Loss: (%.3f, %.3f) Valid Loss: (%.3f, %.3f) NDCG: %.3f Norm: %.3f queue: %d") % \
296 | (epoch, batch, repeat_num, (st-et), optimizer.param_groups[0]['lr'], np.average(train_link_losses), np.average(train_attr_losses), \
297 | loss_link, loss_attr, np.average(ndcgs), node_emb.norm(dim=1).mean(), gpt_gnn.neg_queue_size))
298 |
299 | if valid_loss < best_val:
300 | best_val = valid_loss
301 | print('UPDATE!!!')
302 | torch.save(gpt_gnn.state_dict(), args.pretrain_model_dir)
303 | stats += [[np.average(train_link_losses), loss_link, loss_attr, valid_loss]]
304 |
--------------------------------------------------------------------------------