├── Figures
│   └── CGT.jpg
├── datasets
│   ├── __pycache__
│   │   ├── bgp.cpython-39.pyc
│   │   ├── film.cpython-39.pyc
│   │   ├── webkb.cpython-39.pyc
│   │   ├── wiki.cpython-39.pyc
│   │   ├── __init__.cpython-39.pyc
│   │   └── airports.cpython-39.pyc
│   ├── __init__.py
│   ├── airports.py
│   ├── webkb.py
│   ├── wiki.py
│   ├── film.py
│   └── bgp.py
├── exp_logs
│   ├── cora-20231228-141613.log
│   ├── cora-20231228-141930.log
│   ├── cora-20231228-142159.log
│   └── cora-20231228-142343.log
├── CITATION.cff
├── LICENSE
├── metrics.py
├── environment.yml
├── util.py
├── README.md
├── dataset.py
├── exp.py
├── script_classification.py
├── models.py
└── gnnutils.py
/Figures/CGT.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NSLab-CUK/Community-aware-Graph-Transformer/HEAD/Figures/CGT.jpg
--------------------------------------------------------------------------------
/datasets/__pycache__/bgp.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NSLab-CUK/Community-aware-Graph-Transformer/HEAD/datasets/__pycache__/bgp.cpython-39.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/film.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NSLab-CUK/Community-aware-Graph-Transformer/HEAD/datasets/__pycache__/film.cpython-39.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/webkb.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NSLab-CUK/Community-aware-Graph-Transformer/HEAD/datasets/__pycache__/webkb.cpython-39.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/wiki.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NSLab-CUK/Community-aware-Graph-Transformer/HEAD/datasets/__pycache__/wiki.cpython-39.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NSLab-CUK/Community-aware-Graph-Transformer/HEAD/datasets/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/airports.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NSLab-CUK/Community-aware-Graph-Transformer/HEAD/datasets/__pycache__/airports.cpython-39.pyc
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from datasets.airports import Airports
2 | from datasets.bgp import BGP
3 | from datasets.film import FilmNetwork
4 | from datasets.webkb import WebKB
5 | from datasets.wiki import WikipediaNetwork
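
A minimal usage sketch for the dataset classes exported above (the root paths are placeholders; the loaders actually used by `exp.py` live in `gnnutils.py`):

```
from datasets import WebKB, WikipediaNetwork

cornell = WebKB(root='data/webkb', name='cornell')                # 'cornell', 'texas', or 'wisconsin'
chameleon = WikipediaNetwork(root='data/wiki', name='chameleon')  # 'chameleon' or 'squirrel'
data = cornell[0]  # a torch_geometric Data object with features, labels, and split masks
```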
--------------------------------------------------------------------------------
/exp_logs/cora-20231228-141613.log:
--------------------------------------------------------------------------------
1 | INFO:root:Starting on device: cuda:0
2 | INFO:root:Config: Namespace(dataset='cora', model='Transformer', run_times=1, drop=0.5, custom_masks=True, device='cuda:0', device_2='cuda:0', lr=0.001, epochs=3, dims=64, out_size=64, k_transition=3, num_layers=4, num_heads=2, output_path='outputs/', pretrain_check=0, aug_check=1, sim_check=1, phi_check=1, alfa=0.1, beta=0.9, run_times_fine=300, index_excel=-1, file_name='outputs_excels/cora.xlsx', alpha_1=1, alpha_2=1, alpha_3=1, test_node_degree=0)
3 |
--------------------------------------------------------------------------------
/exp_logs/cora-20231228-141930.log:
--------------------------------------------------------------------------------
1 | INFO:root:Starting on device: cuda:0
2 | INFO:root:Config: Namespace(dataset='cora', model='Transformer', run_times=1, drop=0.5, custom_masks=True, device='cuda:0', device_2='cuda:0', lr=0.001, epochs=3, dims=64, out_size=64, k_transition=3, num_layers=4, num_heads=2, output_path='outputs/', pretrain_check=0, aug_check=1, sim_check=1, phi_check=1, alfa=0.2, beta=0.95, run_times_fine=300, index_excel=-1, file_name='outputs_excels/cora.xlsx', alpha_1=1, alpha_2=1, alpha_3=1, test_node_degree=0)
3 |
--------------------------------------------------------------------------------
/exp_logs/cora-20231228-142159.log:
--------------------------------------------------------------------------------
1 | INFO:root:Starting on device: cuda:0
2 | INFO:root:Config: Namespace(dataset='cora', model='Transformer', run_times=1, drop=0.5, custom_masks=True, device='cuda:0', device_2='cuda:0', lr=0.001, epochs=3, dims=64, out_size=64, k_transition=3, num_layers=4, num_heads=2, output_path='outputs/', pretrain_check=0, aug_check=1, sim_check=1, phi_check=1, alfa=0.1, beta=0.9, run_times_fine=200, index_excel=-1, file_name='outputs_excels/cora.xlsx', alpha_1=1, alpha_2=1, alpha_3=1, test_node_degree=0)
3 |
--------------------------------------------------------------------------------
/exp_logs/cora-20231228-142343.log:
--------------------------------------------------------------------------------
1 | INFO:root:Starting on device: cuda:0
2 | INFO:root:Config: Namespace(dataset='cora', model='Transformer', run_times=1, drop=0.5, custom_masks=True, device='cuda:0', device_2='cuda:0', lr=0.001, epochs=3, dims=64, out_size=64, k_transition=3, num_layers=4, num_heads=2, output_path='outputs/', pretrain_check=0, aug_check=1, sim_check=1, phi_check=1, alfa=0.1, beta=0.9, run_times_fine=200, index_excel=-1, file_name='outputs_excels/cora.xlsx', alpha_1=1, alpha_2=1, alpha_3=1, test_node_degree=0)
3 |
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.0.0
2 | date-released: 2025-04
3 | message: "If you use this software, please cite it as below."
4 | authors:
5 | - family-names: "Hoang"
6 | given-names: "Van Thuy"
7 | - family-names: "Jeon"
8 | given-names: "Hyeon-Ju"
9 | - family-names: "Lee"
10 | given-names: "O-Joun"
11 | title: "Mitigating Degree Bias in Graph Representation Learning With Learnable Structural Augmentation and Structural Self-Attention"
12 | url: "https://ieeexplore.ieee.org/document/10974679"
13 | preferred-citation:
14 | type: article
15 | journal: "IEEE Transactions on Network Science and Engineering"
16 | authors:
17 | - family-names: "Hoang"
18 | given-names: "Van Thuy"
19 | - family-names: "Jeon"
20 | given-names: "Hyeon-Ju"
21 | - family-names: "Lee"
22 | given-names: "O-Joun"
23 | title: "Mitigating Degree Bias in Graph Representation Learning With Learnable Structural Augmentation and Structural Self-Attention"
24 | url: "https://ieeexplore.ieee.org/document/10974679"
25 | year: 2025
26 | publisher: "IEEE"
27 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 NS Lab@CUK
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/datasets/airports.py:
--------------------------------------------------------------------------------
1 | import shutil
2 |
3 | import networkx as nx
4 | import numpy as np
5 | import torch
6 | from torch_geometric.data import InMemoryDataset
7 | from torch_geometric.utils import *
8 |
9 |
10 | def get_degrees(G):
11 | num_nodes = G.number_of_nodes()
12 | return np.array([G.degree[i] for i in range(num_nodes)])
13 |
14 |
15 | class Airports(InMemoryDataset):
16 | def __init__(self, root, dataset_name, transform=None, pre_transform=None):
17 | self.dataset_name = dataset_name
18 | self.dump_location = "raw_data_src/airports_dataset_dump"
19 | super(Airports, self).__init__(root,transform, pre_transform)
20 | self.data, self.slices = torch.load(self.processed_paths[0])
21 |
22 | @property
23 | def raw_file_names(self):
24 | return [self.dataset_name+'-airports.edgelist', 'labels-'+self.dataset_name+'-airports.txt']
25 |
26 | @property
27 | def processed_file_names(self):
28 | return 'data.pt'
29 |
30 | def download(self):
31 | for name in self.raw_file_names:
32 | source = self.dump_location + '/' + name
33 | shutil.copy(source, self.raw_dir)
34 |
35 | def process(self):
36 |
37 | fin_labels = open(self.raw_paths[1])
38 | labels = []
39 | node_id_mapping = dict()
40 | node_id_labels_dict = dict()
41 | for new_id, line in enumerate(fin_labels.readlines()[1:]): # first line is header so ignore
42 | old_id, label = line.strip().split()
43 | labels.append(int(label))
44 | node_id_mapping[old_id] = new_id
45 | node_id_labels_dict[new_id] = int(label)
46 | fin_labels.close()
47 |
48 | edges = []
49 | fin_edges = open(self.raw_paths[0])
50 | for line in fin_edges.readlines():
51 | node1, node2 = line.strip().split()[:2]
52 | edges.append([node_id_mapping[node1], node_id_mapping[node2]])
53 | fin_edges.close()
54 |
55 | networkx_graph = nx.Graph(edges)
56 |
57 | print("No. of Nodes: ",networkx_graph.number_of_nodes())
58 | print("No. of edges: ",networkx_graph.number_of_edges())
59 |
60 |
61 | attr = {}
62 | for node in networkx_graph.nodes():
63 | deg = networkx_graph.degree(node)
64 | attr[node] = {"y": node_id_labels_dict[node], "x": [float(deg)]}
65 | nx.set_node_attributes(networkx_graph, attr)
66 | data = from_networkx(networkx_graph)
67 |
68 |
69 | if self.pre_filter is not None:
70 | data = self.pre_filter(data)
71 |
72 | if self.pre_transform is not None:
73 | data = self.pre_transform(data)
74 |
75 | torch.save(self.collate([data]), self.processed_paths[0])
76 |
77 |
78 | def __repr__(self):
79 | return '{}()'.format(self.__class__.__name__)
--------------------------------------------------------------------------------
/datasets/webkb.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 |
3 | import numpy as np
4 | import torch
5 | from torch_geometric.data import InMemoryDataset, download_url, Data
6 | from torch_geometric.utils import to_undirected
7 | from torch_sparse import coalesce
8 |
9 |
10 | class WebKB(InMemoryDataset):
11 |
12 | url = 'https://raw.githubusercontent.com/graphdml-uiuc-jlu/geom-gcn/master'
13 |
14 | def __init__(self, root, name, transform=None, pre_transform=None):
15 | self.name = name.lower()
16 | assert self.name in ['cornell', 'texas', 'wisconsin']
17 |
18 | super(WebKB, self).__init__(root, transform, pre_transform)
19 | self.data, self.slices = torch.load(self.processed_paths[0])
20 |
21 | @property
22 | def raw_dir(self):
23 | return osp.join(self.root, self.name, 'raw')
24 |
25 | @property
26 | def processed_dir(self):
27 | return osp.join(self.root, self.name, 'processed')
28 |
29 | @property
30 | def raw_file_names(self):
31 | return ['out1_node_feature_label.txt', 'out1_graph_edges.txt'] + [
32 | '{}_split_0.6_0.2_{}.npz'.format(self.name, i) for i in range(10)
33 | ]
34 |
35 | @property
36 | def processed_file_names(self):
37 | return 'data.pt'
38 |
39 | def download(self):
40 | for f in self.raw_file_names[:2]:
41 | download_url(f'{self.url}/new_data/{self.name}/{f}', self.raw_dir)
42 | for f in self.raw_file_names[2:]:
43 | download_url(f'{self.url}/splits/{f}', self.raw_dir)
44 |
45 | def process(self):
46 | with open(self.raw_paths[0], 'r') as f:
47 | data = f.read().split('\n')[1:-1]
48 | x = [[float(v) for v in r.split('\t')[1].split(',')] for r in data]
49 | x = torch.tensor(x, dtype=torch.float)
50 |
51 | y = [int(r.split('\t')[2]) for r in data]
52 | y = torch.tensor(y, dtype=torch.long)
53 |
54 | with open(self.raw_paths[1], 'r') as f:
55 | data = f.read().split('\n')[1:-1]
56 | data = [[int(v) for v in r.split('\t')] for r in data]
57 | edge_index = torch.tensor(data, dtype=torch.long).t().contiguous()
58 | edge_index, _ = coalesce(edge_index, None, x.size(0), x.size(0))
59 |
60 | train_masks, val_masks, test_masks = [], [], []
61 | for f in self.raw_paths[2:]:
62 | tmp = np.load(f)
63 | train_masks += [torch.from_numpy(tmp['train_mask']).to(torch.bool)]
64 | val_masks += [torch.from_numpy(tmp['val_mask']).to(torch.bool)]
65 | test_masks += [torch.from_numpy(tmp['test_mask']).to(torch.bool)]
66 | train_mask = torch.stack(train_masks, dim=1)
67 | val_mask = torch.stack(val_masks, dim=1)
68 | test_mask = torch.stack(test_masks, dim=1)
69 |
70 | data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask,
71 | val_mask=val_mask, test_mask=test_mask)
72 |
73 | data.edge_index = to_undirected(data.edge_index)
74 | data = data if self.pre_transform is None else self.pre_transform(data)
75 | torch.save(self.collate([data]), self.processed_paths[0])
76 |
77 | def __repr__(self):
78 | return '{}()'.format(self.name)
--------------------------------------------------------------------------------
/datasets/wiki.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 |
3 | import numpy as np
4 | import torch
5 | from torch_geometric.data import InMemoryDataset, download_url, Data
6 | from torch_geometric.utils import to_undirected
7 | from torch_sparse import coalesce
8 |
9 |
10 | class WikipediaNetwork(InMemoryDataset):
11 |
12 | url = 'https://raw.githubusercontent.com/graphdml-uiuc-jlu/geom-gcn/master'
13 |
14 | def __init__(self, root, name, transform=None, pre_transform=None):
15 | self.name = name.lower()
16 | assert self.name in ['chameleon', 'squirrel']
17 |
18 | super(WikipediaNetwork, self).__init__(root, transform, pre_transform)
19 | self.data, self.slices = torch.load(self.processed_paths[0])
20 |
21 | @property
22 | def raw_dir(self):
23 | return osp.join(self.root, self.name, 'raw')
24 |
25 | @property
26 | def processed_dir(self):
27 | return osp.join(self.root, self.name, 'processed')
28 |
29 | @property
30 | def raw_file_names(self):
31 | return ['out1_node_feature_label.txt', 'out1_graph_edges.txt'] + [
32 | '{}_split_0.6_0.2_{}.npz'.format(self.name, i) for i in range(10)
33 | ]
34 |
35 | @property
36 | def processed_file_names(self):
37 | return 'data.pt'
38 |
39 | def download(self):
40 | for f in self.raw_file_names[:2]:
41 | download_url(f'{self.url}/new_data/{self.name}/{f}', self.raw_dir)
42 | for f in self.raw_file_names[2:]:
43 | download_url(f'{self.url}/splits/{f}', self.raw_dir)
44 |
45 | def process(self):
46 | with open(self.raw_paths[0], 'r') as f:
47 | data = f.read().split('\n')[1:-1]
48 | x = [[float(v) for v in r.split('\t')[1].split(',')] for r in data]
49 | x = torch.tensor(x, dtype=torch.float)
50 |
51 | y = [int(r.split('\t')[2]) for r in data]
52 | y = torch.tensor(y, dtype=torch.long)
53 |
54 | with open(self.raw_paths[1], 'r') as f:
55 | data = f.read().split('\n')[1:-1]
56 | data = [[int(v) for v in r.split('\t')] for r in data]
57 | edge_index = torch.tensor(data, dtype=torch.long).t().contiguous()
58 | edge_index, _ = coalesce(edge_index, None, x.size(0), x.size(0))
59 |
60 | train_masks, val_masks, test_masks = [], [], []
61 | for f in self.raw_paths[2:]:
62 | tmp = np.load(f)
63 | train_masks += [torch.from_numpy(tmp['train_mask']).to(torch.bool)]
64 | val_masks += [torch.from_numpy(tmp['val_mask']).to(torch.bool)]
65 | test_masks += [torch.from_numpy(tmp['test_mask']).to(torch.bool)]
66 | train_mask = torch.stack(train_masks, dim=1)
67 | val_mask = torch.stack(val_masks, dim=1)
68 | test_mask = torch.stack(test_masks, dim=1)
69 |
70 | data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask,
71 | val_mask=val_mask, test_mask=test_mask)
72 |
73 | data.edge_index = to_undirected(data.edge_index)
74 |
75 | data = data if self.pre_transform is None else self.pre_transform(data)
76 | torch.save(self.collate([data]), self.processed_paths[0])
77 |
78 | def __repr__(self):
79 | return '{}()'.format(self.name)
80 |
--------------------------------------------------------------------------------
/datasets/film.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 |
3 | import numpy as np
4 | import torch
5 | from torch_geometric.data import InMemoryDataset, download_url, Data
6 | from torch_geometric.utils import to_undirected
7 | from torch_sparse import coalesce
8 |
9 |
10 | class FilmNetwork(InMemoryDataset):
11 |
12 | url = 'https://raw.githubusercontent.com/graphdml-uiuc-jlu/geom-gcn/master'
13 |
14 | def __init__(self, root, name, transform=None, pre_transform=None):
15 | self.name = name.lower()
16 | assert self.name in ['film']
17 |
18 | super(FilmNetwork, self).__init__(root, transform, pre_transform)
19 | self.data, self.slices = torch.load(self.processed_paths[0])
20 |
21 | @property
22 | def raw_dir(self):
23 | return osp.join(self.root, self.name, 'raw')
24 |
25 | @property
26 | def processed_dir(self):
27 | return osp.join(self.root, self.name, 'processed')
28 |
29 | @property
30 | def raw_file_names(self):
31 | return ['out1_node_feature_label.txt', 'out1_graph_edges.txt'] + [
32 | '{}_split_0.6_0.2_{}.npz'.format(self.name, i) for i in range(10)
33 | ]
34 |
35 | @property
36 | def processed_file_names(self):
37 | return 'data.pt'
38 |
39 | def download(self):
40 | for f in self.raw_file_names[:2]:
41 | download_url(f'{self.url}/new_data/{self.name}/{f}', self.raw_dir)
42 | for f in self.raw_file_names[2:]:
43 | download_url(f'{self.url}/splits/{f}', self.raw_dir)
44 |
45 | def process(self):
46 | with open(self.raw_paths[0], 'r') as f:
47 | data = f.read().split('\n')[1:-1]
48 |
49 | x = []
50 | for r in data:
51 | f = np.zeros(932, dtype=np.float32)
52 | idx = list(map(int, r.split('\t')[1].split(",")))
53 | f[idx] = 1.0
54 | x.append(f)
55 |
56 | x = torch.tensor(x, dtype=torch.float)
57 |
58 | y = [int(r.split('\t')[2]) for r in data]
59 | y = torch.tensor(y, dtype=torch.long)
60 |
61 | with open(self.raw_paths[1], 'r') as f:
62 | data = f.read().split('\n')[1:-1]
63 | data = [[int(v) for v in r.split('\t')] for r in data]
64 | edge_index = torch.tensor(data, dtype=torch.long).t().contiguous()
65 | edge_index, _ = coalesce(edge_index, None, x.size(0), x.size(0))
66 |
67 | train_masks, val_masks, test_masks = [], [], []
68 | for f in self.raw_paths[2:]:
69 | tmp = np.load(f)
70 | train_masks += [torch.from_numpy(tmp['train_mask']).to(torch.bool)]
71 | val_masks += [torch.from_numpy(tmp['val_mask']).to(torch.bool)]
72 | test_masks += [torch.from_numpy(tmp['test_mask']).to(torch.bool)]
73 | train_mask = torch.stack(train_masks, dim=1)
74 | val_mask = torch.stack(val_masks, dim=1)
75 | test_mask = torch.stack(test_masks, dim=1)
76 |
77 | data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask,
78 | val_mask=val_mask, test_mask=test_mask)
79 |
80 | data.edge_index = to_undirected(data.edge_index)
81 |
82 | data = data if self.pre_transform is None else self.pre_transform(data)
83 | torch.save(self.collate([data]), self.processed_paths[0])
84 |
85 | def __repr__(self):
86 | return '{}()'.format(self.name)
87 |
--------------------------------------------------------------------------------
/datasets/bgp.py:
--------------------------------------------------------------------------------
1 | import json
2 | import shutil
3 |
4 | import networkx as nx
5 | import numpy as np
6 | import torch
7 | from torch_geometric.data import InMemoryDataset
8 | from torch_geometric.utils import *
9 |
10 |
11 | def convert_ndarray(x):
12 | y = list(range(len(x)))
13 | for k, v in x.items():
14 | y[int(k)] = v
15 | return np.array(y)
16 |
17 |
18 | def check_rm(neighbors_set, unlabeled_nodes):
19 | for node in neighbors_set:
20 | if node not in unlabeled_nodes:
21 | return False
22 | return True
23 |
24 |
25 | def rm_useless(G, feats, class_map, unlabeled_nodes, num_layers):
26 | # find useless nodes
27 | print('start to check and remove {} unlabeled nodes'.format(len(unlabeled_nodes)))
28 |
29 | rm_nodes = unlabeled_nodes
30 | if len(rm_nodes):
31 | for node in rm_nodes:
32 | G.remove_node(node)
33 | G_new = nx.relabel.convert_node_labels_to_integers(G, ordering='sorted')
34 | feats = np.delete(feats, rm_nodes, 0)
35 | class_map = np.delete(class_map, rm_nodes, 0)
36 | print('remove {} '.format(len(rm_nodes)), 'useless unlabeled nodes')
37 | return G_new, feats, class_map
38 |
39 | class BGP(InMemoryDataset):
40 |
41 |
42 | def __init__(self, root, transform=None, pre_transform=None):
43 |
44 | self.dump_location = "raw_data_src/bgp_data_dump"
45 | super(BGP, self).__init__(root, transform, pre_transform)
46 | self.data, self.slices = torch.load(self.processed_paths[0])
47 |
48 | @property
49 | def raw_file_names(self):
50 | return ['as-G.json', 'as-class_map.json', 'as-feats.npy','as-feats_t.npy', 'as-edge_list']
51 |
52 | @property
53 | def processed_file_names(self):
54 | return 'data.pt'
55 |
56 | def download(self):
57 | for name in self.raw_file_names:
58 | # download_url('{}/{}'.format(self.url, name), self.raw_dir)
59 | source = self.dump_location + '/' + name
60 | shutil.copy(source, self.raw_dir)
61 |
62 | def process(self):
63 | G = nx.json_graph.node_link_graph(json.load(open(self.raw_paths[0])), False)
64 | class_map = json.load(open(self.raw_paths[1]))
65 | feats = np.load(self.raw_paths[2])
66 | feats_t = np.load(self.raw_paths[3])
67 |
68 |
69 | train_nodes = [n for n in G.nodes() if not G.nodes[n]['test'] and not G.nodes[n]['val']]
70 | val_nodes = [n for n in G.nodes() if not G.nodes[n]['test'] and G.nodes[n]['val']]
71 | test_nodes = [n for n in G.nodes() if G.nodes[n]['test'] and not G.nodes[n]['val']]
72 | unlabeled_nodes = [n for n in G.nodes() if G.nodes[n]['test'] and G.nodes[n]['val']]
73 | class_map = convert_ndarray(class_map)
74 |
75 | G, feats, class_map = rm_useless(G, feats, class_map, unlabeled_nodes, 1)
76 | train_nodes = [n for n in G.nodes() if not G.nodes[n]['test'] and not G.nodes[n]['val']]
77 | val_nodes = [n for n in G.nodes() if not G.nodes[n]['test'] and G.nodes[n]['val']]
78 | test_nodes = [n for n in G.nodes() if G.nodes[n]['test'] and not G.nodes[n]['val']]
79 | unlabeled_nodes = [n for n in G.nodes() if G.nodes[n]['test'] and G.nodes[n]['val']]
80 |
81 |
82 | data = from_networkx(G)
83 | data.train_mask = ~(data.test | data.val)
84 | data.val_mask = data.val
85 | data.test_mask = data.test
86 | data.test = None
87 | data.val = None
88 | data.x = torch.FloatTensor(feats)
89 | data.y = torch.LongTensor(np.argmax(class_map, axis=1))
90 |
91 | if self.pre_filter is not None:
92 | data = self.pre_filter(data)
93 |
94 | if self.pre_transform is not None:
95 | data.edge_index = to_undirected(data.edge_index)
96 | data = self.pre_transform(data)
97 |
98 | torch.save(self.collate([data]), self.processed_paths[0])
99 |
100 |
101 | def __repr__(self):
102 | return '{}()'.format(self.__class__.__name__)
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from sklearn.metrics import confusion_matrix
6 | from sklearn.metrics import f1_score
7 | import numpy as np
8 |
9 |
10 | def MAE(scores, targets):
11 | MAE = F.l1_loss(scores, targets)
12 | MAE = MAE.detach().item()
13 | return MAE
14 |
15 |
16 | def accuracy_TU(scores, targets):
17 | scores = scores.detach().argmax(dim=1)
18 | acc = (scores == targets).float().sum().item()
19 | return acc
20 |
21 |
22 | def accuracy_MNIST_CIFAR(scores, targets):
23 | scores = scores.detach().argmax(dim=1)
24 | acc = (scores == targets).float().sum().item()
25 | return acc
26 |
27 |
28 | def accuracy_CITATION_GRAPH(scores, targets):
29 | scores = scores.detach().argmax(dim=1)
30 | acc = (scores == targets).float().sum().item()
31 | acc = acc / len(targets)
32 | return acc
33 |
34 |
35 | def accuracy_SBM(scores, targets):
36 | S = targets.cpu().numpy()
37 | C = np.argmax(torch.nn.Softmax(dim=1)(scores).cpu().detach().numpy(), axis=1)
38 | CM = confusion_matrix(S, C).astype(np.float32)
39 | nb_classes = CM.shape[0]
40 | targets = targets.cpu().detach().numpy()
41 | nb_non_empty_classes = 0
42 | pr_classes = np.zeros(nb_classes)
43 | for r in range(nb_classes):
44 | cluster = np.where(targets == r)[0]
45 | if cluster.shape[0] != 0:
46 | pr_classes[r] = CM[r, r] / float(cluster.shape[0])
47 | if CM[r, r] > 0:
48 | nb_non_empty_classes += 1
49 | else:
50 | pr_classes[r] = 0.0
51 | acc = 100. * np.sum(pr_classes) / float(nb_classes)
52 | return acc
53 |
54 |
55 | def binary_f1_score(scores, targets):
56 | """Computes the F1 score using scikit-learn for binary class labels.
57 |
58 | Returns the F1 score for the positive class, i.e. labelled '1'.
59 | """
60 | y_true = targets.cpu().numpy()
61 | y_pred = scores.argmax(dim=1).cpu().numpy()
62 | return f1_score(y_true, y_pred, average='binary')
63 |
64 |
65 | def accuracy_VOC(scores, targets):
66 | scores = scores.detach().argmax(dim=1).cpu()
67 | targets = targets.cpu().detach().numpy()
68 | acc = f1_score(targets, scores, average='weighted')  # f1_score expects (y_true, y_pred)
69 | return acc
70 |
71 |
72 | # clustering
73 | from sklearn.metrics.cluster import contingency_matrix
74 |
75 |
76 | def _compute_counts(y_true, y_pred): # TODO(tsitsulin): add docstring pylint: disable=missing-function-docstring
77 | contingency = contingency_matrix(y_true, y_pred)
78 | same_class_true = np.max(contingency, 1)
79 | same_class_pred = np.max(contingency, 0)
80 | diff_class_true = contingency.sum(axis=1) - same_class_true
81 | diff_class_pred = contingency.sum(axis=0) - same_class_pred
82 | total = contingency.sum()
83 |
84 | true_positives = (same_class_true * (same_class_true - 1)).sum()
85 | false_positives = (diff_class_true * same_class_true * 2).sum()
86 | false_negatives = (diff_class_pred * same_class_pred * 2).sum()
87 | true_negatives = total * (total - 1) - true_positives - false_positives - false_negatives
88 |
89 | return true_positives, false_positives, false_negatives, true_negatives
90 |
91 |
92 | def modularity(adjacency, clusters):
93 | """Computes graph modularity.
94 | Args:
95 | adjacency: Input graph in terms of its sparse adjacency matrix.
96 | clusters: An (n,) int cluster vector.
97 |
98 | Returns:
99 | The value of graph modularity.
100 | https://en.wikipedia.org/wiki/Modularity_(networks)
101 | """
102 | degrees = adjacency.sum(axis=0).A1
103 | n_edges = degrees.sum() # Note that it's actually 2*n_edges.
104 | result = 0
105 | for cluster_id in np.unique(clusters):
106 | cluster_indices = np.where(clusters == cluster_id)[0]
107 | adj_submatrix = adjacency[cluster_indices, :][:, cluster_indices]
108 | degrees_submatrix = degrees[cluster_indices]
109 | result += np.sum(adj_submatrix) - (np.sum(degrees_submatrix) ** 2) / n_edges
110 | return result / n_edges
111 |
112 |
113 | def precision(y_true, y_pred):
114 | true_positives, false_positives, _, _ = _compute_counts(y_true, y_pred)
115 | return true_positives / (true_positives + false_positives)
116 |
117 |
118 | def recall(y_true, y_pred):
119 | true_positives, _, false_negatives, _ = _compute_counts(y_true, y_pred)
120 | return true_positives / (true_positives + false_negatives)
121 |
122 |
123 | def accuracy_score(y_true, y_pred):
124 | true_positives, false_positives, false_negatives, true_negatives = _compute_counts(
125 | y_true, y_pred)
126 | return (true_positives + true_negatives) / (true_positives + false_positives + false_negatives + true_negatives)
127 |
128 |
129 | def conductance(adjacency, clusters): # TODO(tsitsulin): add docstring pylint: disable=missing-function-docstring
130 | inter = 0
131 | intra = 0
132 | cluster_idx = np.zeros(adjacency.shape[0], dtype=bool)
133 | for cluster_id in np.unique(clusters):
134 | cluster_idx[:] = 0
135 | cluster_idx[np.where(clusters == cluster_id)[0]] = 1
136 | adj_submatrix = adjacency[cluster_idx, :]
137 | inter += np.sum(adj_submatrix[:, cluster_idx])
138 | intra += np.sum(adj_submatrix[:, ~cluster_idx])
139 | return intra / (inter + intra)
140 |
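A minimal sketch of the clustering metrics above on a toy graph made of two triangles joined by a single edge (the adjacency here is illustrative; per the docstring, these functions expect the graph's sparse adjacency matrix):

```
import numpy as np
import scipy.sparse as sp
from metrics import modularity, conductance

edges = [(0, 1), (1, 2), (0, 2), (3, 4), (4, 5), (3, 5), (2, 3)]
rows = [u for u, v in edges] + [v for u, v in edges]
cols = [v for u, v in edges] + [u for u, v in edges]
adj = sp.csr_matrix((np.ones(len(rows)), (rows, cols)), shape=(6, 6))
clusters = np.array([0, 0, 0, 1, 1, 1])  # one cluster per triangle
print(modularity(adj, clusters), conductance(adj, clusters))
```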
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: gt01
2 | channels:
3 | - pytorch
4 | - dglteam
5 | - dglteam/label/cu117
6 | - nvidia
7 | - defaults
8 | dependencies:
9 | - _libgcc_mutex=0.1=main
10 | - _openmp_mutex=5.1=1_gnu
11 | - appdirs=1.4.4=pyhd3eb1b0_0
12 | - blas=1.0=mkl
13 | - brotlipy=0.7.0=py39h27cfd23_1003
14 | - bzip2=1.0.8=h7b6447c_0
15 | - ca-certificates=2023.05.30=h06a4308_0
16 | - certifi=2023.5.7=py39h06a4308_0
17 | - cffi=1.15.1=py39h5eee18b_3
18 | - cryptography=39.0.1=py39h9ce1e76_1
19 | - cuda-cudart=11.7.99=0
20 | - cuda-cupti=11.7.101=0
21 | - cuda-libraries=11.7.1=0
22 | - cuda-nvrtc=11.7.99=0
23 | - cuda-nvtx=11.7.91=0
24 | - cuda-runtime=11.7.1=0
25 | - dgl=1.1.0.cu117=py39_0
26 | - dgl-cuda10.2=0.9.1post1=py39_0
27 | - ffmpeg=4.3=hf484d3e_0
28 | - freetype=2.12.1=h4a9f257_0
29 | - giflib=5.2.1=h5eee18b_3
30 | - gmp=6.2.1=h295c915_3
31 | - gmpy2=2.1.2=py39heeb90bb_0
32 | - gnutls=3.6.15=he1e5248_0
33 | - jinja2=3.1.2=py39h06a4308_0
34 | - jpeg=9e=h5eee18b_1
35 | - lame=3.100=h7b6447c_0
36 | - lcms2=2.12=h3be6417_0
37 | - ld_impl_linux-64=2.38=h1181459_1
38 | - lerc=3.0=h295c915_0
39 | - libcublas=11.10.3.66=0
40 | - libcufft=10.7.2.124=h4fbf590_0
41 | - libcufile=1.6.1.9=0
42 | - libcurand=10.3.2.106=0
43 | - libcusolver=11.4.0.1=0
44 | - libcusparse=11.7.4.91=0
45 | - libdeflate=1.17=h5eee18b_0
46 | - libffi=3.4.4=h6a678d5_0
47 | - libgcc-ng=11.2.0=h1234567_1
48 | - libgfortran-ng=11.2.0=h00389a5_1
49 | - libgfortran5=11.2.0=h1234567_1
50 | - libgomp=11.2.0=h1234567_1
51 | - libiconv=1.16=h7f8727e_2
52 | - libidn2=2.3.4=h5eee18b_0
53 | - libnpp=11.7.4.75=0
54 | - libnvjpeg=11.8.0.2=0
55 | - libpng=1.6.39=h5eee18b_0
56 | - libstdcxx-ng=11.2.0=h1234567_1
57 | - libtasn1=4.19.0=h5eee18b_0
58 | - libtiff=4.5.0=h6a678d5_2
59 | - libunistring=0.9.10=h27cfd23_0
60 | - libwebp=1.2.4=h11a3e52_1
61 | - libwebp-base=1.2.4=h5eee18b_1
62 | - lz4-c=1.9.4=h6a678d5_0
63 | - mkl-service=2.4.0=py39h5eee18b_1
64 | - mkl_fft=1.3.6=py39h417a72b_1
65 | - mkl_random=1.2.2=py39h417a72b_1
66 | - mpc=1.1.0=h10f8cd9_1
67 | - mpfr=4.0.2=hb69a4c5_1
68 | - ncurses=6.4=h6a678d5_0
69 | - nettle=3.7.3=hbbd107a_1
70 | - numpy-base=1.24.3=py39h060ed82_1
71 | - openh264=2.1.1=h4ff587b_0
72 | - openssl=3.0.8=h7f8727e_0
73 | - pooch=1.4.0=pyhd3eb1b0_0
74 | - pycparser=2.21=pyhd3eb1b0_0
75 | - pyopenssl=23.0.0=py39h06a4308_0
76 | - pysocks=1.7.1=py39h06a4308_0
77 | - python=3.9.16=h955ad1f_3
78 | - pytorch=2.0.0=py3.9_cuda11.7_cudnn8.5.0_0
79 | - pytorch-cuda=11.7=h778d358_5
80 | - pytorch-mutex=1.0=cuda
81 | - readline=8.2=h5eee18b_0
82 | - setuptools=67.8.0=py39h06a4308_0
83 | - sqlite=3.41.2=h5eee18b_0
84 | - tbb=2021.8.0=hdb19cb5_0
85 | - tk=8.6.12=h1ccaba5_0
86 | - torchtriton=2.0.0=py39
87 | - typing_extensions=4.5.0=py39h06a4308_0
88 | - wheel=0.38.4=py39h06a4308_0
89 | - xz=5.4.2=h5eee18b_0
90 | - zlib=1.2.13=h5eee18b_0
91 | - zstd=1.5.5=hc292b87_0
92 | - pip:
93 | - asgiref==3.7.2
94 | - backcall==0.2.0
95 | - chardet==3.0.4
96 | - charset-normalizer==3.1.0
97 | - cmake==3.26.4
98 | - contourpy==1.0.7
99 | - cycler==0.11.0
100 | - decorator==4.4.2
101 | - django==4.2.2
102 | - et-xmlfile==1.1.0
103 | - fastdtw==0.3.4
104 | - fastjsonschema==2.17.1
105 | - filelock==3.12.1
106 | - fonttools==4.39.4
107 | - idna==2.8
108 | - image==1.5.33
109 | - importlib-resources==5.12.0
110 | - intel-openmp==2023.1.0
111 | - ipython-genutils==0.2.0
112 | - joblib==1.2.0
113 | - kiwisolver==1.4.4
114 | - lit==16.0.5.post0
115 | - littleutils==0.2.2
116 | - markupsafe==2.1.3
117 | - matplotlib==3.7.1
118 | - mistune==0.8.4
119 | - mkl==2019.0
120 | - mpmath==1.3.0
121 | - networkx==2.5.1
122 | - numpy==1.22.4
123 | - oauthlib==3.2.2
124 | - ogb==1.3.6
125 | - openpyxl==3.1.2
126 | - outdated==0.2.2
127 | - packaging==23.1
128 | - pandas==2.0.2
129 | - pandocfilters==1.5.0
130 | - parso==0.8.3
131 | - pexpect==4.8.0
132 | - pickleshare==0.7.5
133 | - pillow==9.5.0
134 | - pip==23.1.2
135 | - prometheus-client==0.17.0
136 | - prompt-toolkit==2.0.10
137 | - protobuf==4.23.2
138 | - psutil==5.9.5
139 | - ptyprocess==0.7.0
140 | - pyasn1==0.5.0
141 | - pyg-lib==0.2.0+pt20cu117
142 | - pygments==2.15.1
143 | - pyparsing==3.0.9
144 | - pyrsistent==0.19.3
145 | - python-dateutil==2.8.2
146 | - pytz==2023.3
147 | - pyzmq==25.1.0
148 | - requests==2.31.0
149 | - retrying==1.3.4
150 | - scikit-learn==0.24.1
151 | - scipy==1.6.2
152 | - send2trash==1.8.2
153 | - six==1.15.0
154 | - sqlparse==0.4.4
155 | - sympy==1.12
156 | - testpath==0.6.0
157 | - threadpoolctl==3.1.0
158 | - torch==2.0.1+cu117
159 | - torch-cluster==1.6.1+pt20cu117
160 | - torch-geometric==2.3.1
161 | - torch-scatter==2.1.1+pt20cu117
162 | - torch-sparse==0.6.17+pt20cu117
163 | - torch-spline-conv==1.2.2+pt20cu117
164 | - torchaudio==2.0.2+cu117
165 | - torchvision==0.15.2+cu117
166 | - tornado==6.3.2
167 | - tqdm==4.60.0
168 | - traitlets==5.9.0
169 | - triton==2.0.0
170 | - typing-extensions==4.6.3
171 | - tzdata==2023.3
172 | - urllib3==1.25.11
173 | - wcwidth==0.2.6
174 | - webencodings==0.5.1
175 | - zipp==3.15.0
176 | prefix: /root/anaconda3/envs/gt01
177 |
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
1 | import networkx as nx
2 | import numpy as np
3 | from sklearn.preprocessing import normalize
4 | from sklearn.cluster import KMeans
5 | from itertools import chain
6 | import copy, torch, dgl
7 |
8 |
9 | def GetProbTranMat(Ak, num_node):
10 | num_node, num_node2 = Ak.shape
11 | if (num_node != num_node2):
12 | print('M must be a square matrix!')
13 | Ak_sum = np.sum(Ak, axis=0).reshape(1, -1)
14 | Ak_sum = np.repeat(Ak_sum, num_node, axis=0)
15 | probTranMat = np.log(np.divide(Ak, Ak_sum)) - np.log(1. / num_node)
16 | probTranMat[probTranMat < 0] = 0  # set zero for negative and -inf elements
17 | probTranMat[np.isnan(probTranMat)] = 0  # set zero for nan elements (the isolated nodes)
18 | return probTranMat
19 |
20 |
21 | def getM_logM(nx_g, num_nodes, kstep=3):
22 | tran_M = []
23 | tran_logM = []
24 | Adj = np.zeros((num_nodes, num_nodes))
25 | for src in nx_g.nodes():
26 | src_degree = nx_g.degree(src)
27 | for dst in nx_g.nodes():
28 | if nx_g.has_edge(src, dst):
29 | Adj[src][dst] = round(1 / src_degree, 3)
30 |
31 | Ak = np.matrix(np.identity(num_nodes))
32 | for i in range(kstep):
33 | Ak = np.dot(Ak, Adj)
34 | tran_M.append(Ak)
35 | probTranMat = GetProbTranMat(Ak, num_nodes)
36 | tran_logM.append(probTranMat)
37 |
38 | return tran_M, tran_logM
39 |
40 |
41 | def get_distance(deg_A, deg_B):
42 | damp = 1 / np.sqrt(deg_A * deg_B)
43 | return damp
44 |
45 |
46 | def get_B_sim_phi(nx_g, tran_M, num_nodes, n_class, X, kstep=5):
47 | print(f'processing get_B_sim_phi')
48 | count = 0
49 | B = np.zeros((num_nodes, num_nodes))
50 | colour = np.zeros((num_nodes, num_nodes))
51 | phi = np.zeros((num_nodes, num_nodes, 1))
52 | sim = np.zeros((num_nodes, num_nodes, kstep))
53 |
54 | trans_check = tran_M[kstep - 1]
55 | not_adj = tran_M[0]
56 |
57 | kmeans = KMeans(n_clusters=n_class, init='k-means++', max_iter=50, n_init=10, random_state=0)
58 |
59 | y_kmeans = kmeans.fit_predict(X)
60 | count = 0
61 | count_1 = 0
62 | for src in nx_g.nodes():
63 | if count % 50 == 0:
64 | print(f' processing node_th {src}/{num_nodes}')
65 | for dst in nx_g.nodes():
66 |
67 | if src == dst:
68 | continue
69 |
70 | if not_adj[src, dst] > 0:
71 | continue
72 |
73 | if colour[src, dst] == 1 or colour[dst, src] == 1:
74 | continue
75 | if trans_check[src, dst] > 0.001:
76 |
77 | src_d = nx_g.degree(src)
78 | dst_d = nx_g.degree(dst)
79 |
80 | if np.abs(src_d - dst_d) > 1:
81 | continue
82 |
83 | if y_kmeans[src] != y_kmeans[dst]:
84 | continue
85 | else:
86 | count_1 += 1
87 | d = get_distance(src_d, dst_d)
88 | # B i, j
89 | B[src, dst] = d
90 | B[dst, src] = d
91 | # phi i,j
92 | if phi[src, dst] == 0:
93 | phi[src, dst] = d
94 | phi[dst, src] = d
95 |
96 | colour[src, dst] = 1
97 | colour[dst, src] = 1
98 | B[src, src] = 0
99 | count += 1
100 |
101 |
102 |
103 | sim = compute_sim(tran_M, num_nodes, k_step=kstep)
104 |
105 | return B, sim, phi
106 |
107 |
108 | def compute_sim(tran_M, num_nodes, k_step=5):
109 | sim = np.zeros((num_nodes, num_nodes, k_step))
110 | trans_check = tran_M[k_step - 1]
111 |
112 | for step in range(k_step):
113 | print(f'compute_sim transition step {step + 1}/{k_step}')
114 | colour = np.zeros((num_nodes, num_nodes))
115 | trans_k = copy.deepcopy(tran_M[step])
116 | trans_k[trans_k >= 0.001] = 1
117 | trans_k[trans_k < 0.001] = 0
118 | trans_k = np.array(trans_k)
119 |
120 | row_sums = trans_k.sum(axis=1)
121 | trans_mul = trans_k @ trans_k.T
122 | for i in range(num_nodes):
123 |
124 | for j in range(i + 1, num_nodes):
125 | if trans_check[i, j] < 0.0001:
126 | continue
127 | if colour[i, j] == 1 or colour[j, i] == 1:
128 | continue
129 |
130 | score = np.round(trans_mul[i, j] / (row_sums[i] + row_sums[j] - trans_mul[i, j]), 4)
131 | if score < 0.001:
132 | score = 0
133 | sim[i, j, step] = score
134 | sim[j, i, step] = score
135 |
136 | colour[i, j] = 1
137 | colour[j, i] = 1
138 | return sim
139 |
140 |
141 | def get_A_D(nx_g, num_nodes):
142 | num_edges = nx_g.number_of_edges()
143 |
144 | d = np.zeros((num_nodes))
145 |
146 | Adj = np.zeros((num_nodes, num_nodes))
147 |
148 | for src in nx_g.nodes():
149 | src_degree = nx_g.degree(src)
150 | d[src] = src_degree
151 | for dst in nx_g.nodes():
152 | if nx_g.has_edge(src, dst):
153 | Adj[src][dst] = 1
154 |
155 |
156 | return Adj, d, num_edges
157 |
158 |
159 | def load_dgl(nx_g, x, sim, phi):
160 | print('loading dgl...')
161 | count = 0
162 | edge_idx1 = []
163 | edge_idx2 = []
164 | for e in nx_g.edges:
165 | edge_idx1.append(e[0])
166 | edge_idx2.append(e[1])
167 |
168 | edge_idx1.append(e[1])
169 | edge_idx2.append(e[0])
170 |
171 | s_vals = []
172 | phi_vals = []
173 | for i in range(len(edge_idx1)):
174 | count += 1
175 | n1 = edge_idx1[i]
176 | n2 = edge_idx2[i]
177 |
178 | s = np.asarray(sim[n1][n2], dtype=float)
179 | s_vals.append(s)
180 |
181 | p = np.asarray(phi[n1][n2], dtype=float)
182 | phi_vals.append(p)
183 |
184 | print(f'networkx: number edges: {count}')
185 | s_vals = np.array(s_vals)
186 | phi_vals = np.array(phi_vals)
187 |
188 | s_vals[np.isnan(s_vals)] = 0
189 | s_vals = normalize(s_vals, axis=0, norm='max')
190 | phi_vals = normalize(phi_vals, axis=0, norm='max')
191 |
192 | s_vals = torch.tensor(s_vals)
193 | phi_vals = torch.tensor(phi_vals)
194 |
195 | g = dgl.graph((edge_idx1, edge_idx2))
196 |
197 | s_vals[torch.isnan(s_vals)] = 0
198 | phi_vals[torch.isnan(phi_vals)] = 0
199 |
200 | g.ndata['x'] = x
201 | g.edata['sim'] = s_vals
202 | g.edata['phi'] = phi_vals
203 |
204 | print(f'loading dgl, done, DGL graph edges: {g.number_of_edges()}')
205 | return g
206 |
207 |
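A minimal sketch of the k-step transition helpers above on a toy graph (the path graph is illustrative only; in the pipeline, `exp.py` calls these functions on the loaded datasets):

```
import networkx as nx
from util import getM_logM, compute_sim

toy_g = nx.path_graph(4)                      # nodes 0-1-2-3
tran_M, tran_logM = getM_logM(toy_g, num_nodes=4, kstep=3)
sim = compute_sim(tran_M, num_nodes=4, k_step=3)
print(len(tran_M), sim.shape)                 # 3 (4, 4, 3)
```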
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Community-aware Graph Transformer
2 |
3 | Community-aware Graph Transformer (**CGT**) is a novel Graph Transformer model that utilizes community structures to address node degree biases in the message-passing mechanism. It was developed by [NS Lab, CUK](https://nslab-cuk.github.io/) on a pure [PyTorch](https://github.com/pytorch/pytorch) backend. The paper is available on [arXiv](https://arxiv.org/abs/2312.16788).
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 | ## 1. Overview
25 |
26 | Recent augmentation-based methods showed that message-passing (MP) neural networks often perform poorly on low-degree nodes, leading to degree biases due to a lack of messages reaching low-degree nodes. Despite their success, most methods use heuristic or uniform random augmentations, which are non-differentiable and may not always generate valuable edges for learning representations. In this paper, we propose the Community-aware Graph Transformer, namely **CGT**, to learn degree-unbiased representations based on learnable augmentations and graph Transformers by extracting within-community structures. We first design a learnable graph augmentation to generate more within-community edges connecting low-degree nodes through edge perturbation. Second, we propose an improved self-attention to learn the underlying proximity and the roles of nodes within the community. Third, we propose a self-supervised learning task that learns representations that preserve the global graph structure and regularize the graph augmentations. Extensive experiments on various benchmark datasets showed that CGT outperforms state-of-the-art baselines and significantly mitigates node degree biases.
27 |
28 |
29 |
30 |
31 |
32 |
33 | The overall architecture of Community-aware Graph Transformer.
34 |
35 |
36 |
37 | ## 2. Reproducibility
38 |
39 | ### Datasets and Tasks
40 |
41 | We used six publicly available datasets grouped into three domains: citation networks (Cora, Citeseer, and Pubmed), co-purchase networks (Amazon Computers and Photo), and a reference network (WikiCS). The datasets are automatically downloaded via PyTorch Geometric.
42 |
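As a minimal illustration (the root path is a placeholder; `exp.py` loads the datasets through the helpers in `gnnutils.py`, which rely on the same PyTorch Geometric classes), fetching one of these datasets looks like this:

```
from torch_geometric.datasets import Planetoid

dataset = Planetoid(root='data/cora', name='Cora')  # downloaded on first use
data = dataset[0]
print(data.num_nodes, data.num_edges, dataset.num_classes)
```
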
43 | ### Requirements and Environment Setup
44 |
45 | The source code was developed in Python 3.8.8, and the provided Conda environment (```environment.yml```) installs Python 3.9.16. CGT is built on PyTorch Geometric 2.3.1 and DGL 1.1.0. Please refer to the official websites for installation and setup.
46 | All the requirements are included in the ```environment.yml``` file.
47 |
48 | ```
49 | # Conda installation
50 |
51 | # Install python environment
52 |
53 | conda env create -f environment.yml
54 | ```
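
After the environment is created, activate it before running any scripts (the environment name ```gt01``` comes from ```environment.yml```):

```
conda activate gt01
```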
55 | ### Hyperparameters
56 |
57 | The following options can be passed to `exp.py`:
58 |
59 | ```--dataset:``` The name of the input dataset. For example: ```--dataset cora```
60 |
61 | ```--lr:``` Learning rate for training the model. For example: ```--lr 0.001```
62 |
63 | ```--epochs:``` Number of epochs for pre-training the model. For example: ```--epochs 500```
64 |
65 | ```--run_times_fine:``` Number of epochs for fine-tuning the model. For example: ```--run_times_fine 500```
66 |
67 | ```--num_layers:``` Number of layers for model training. For example: ```--num_layers 4```
68 |
69 | ```--drop:``` Dropout rate. For example: ```--drop 0.5```
70 |
71 | ```--dims:``` The dimension of the hidden vectors. For example: ```--dims 64```.
72 |
73 | ```--k_transition:``` The number of transition steps. For example: ```--k_transition 3```.
74 |
75 | ```--alfa:``` Hyperparameter for the degree-related score. For example: ```--alfa 0.1```.
76 |
77 | ```--beta:``` Hyperparameter for the adjacency matrix score. For example: ```--beta 0.95```.
78 |
79 | ```--alpha_1:``` Hyperparameter for the transition construction loss. For example: ```--alpha_1 0.5```.
80 |
81 | ```--alpha_2:``` Hyperparameter for the feature construction loss. For example: ```--alpha_2 0.5```.
82 |
83 | ```--alpha_3:``` Hyperparameter for the augmentation loss. For example: ```--alpha_3 0.5```.
84 |
85 |
86 | ### How to run
87 |
88 | The source code contains both pre-training and fine-tuning processes.
89 | The following command runs the pre-training process and fine-tunes **CGT** on the Cora dataset for both node classification and clustering tasks.
90 |
91 | ```
92 |
93 | python exp.py --dataset cora
94 |
95 | ```
96 |
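The hyperparameters listed above can be combined on the same command line. The values below are illustrative only and mirror the settings recorded in the logs under ```exp_logs/```; the exact flag names and defaults are defined in ```exp.py```:

```
python exp.py --dataset cora --lr 0.001 --epochs 500 --run_times_fine 300 --num_layers 4 --dims 64 --k_transition 3 --drop 0.5 --alfa 0.1 --beta 0.9
```
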
97 | ## 3. Reference
98 |
99 | :page_with_curl: Paper [on IEEE TNSE](https://ieeexplore.ieee.org/xpl/RecentIssue.jsp?punumber=6488902):
100 | * [DOI: 10.1109/TNSE.2025.3563697](https://doi.org/10.1109/TNSE.2025.3563697)
101 |
102 | :page_with_curl: Paper [on arXiv](https://arxiv.org/):
103 | * [arXiv:2504.15075](https://arxiv.org/abs/2504.15075)
104 | [arXiv:2312.16788](https://arxiv.org/abs/2312.16788)
105 |
106 | :chart_with_upwards_trend: Experimental results [on Papers With Code](https://paperswithcode.com/):
107 | * [Mitigating Degree Bias in Graph Representation Learning (Papers With Code)](https://paperswithcode.com/paper/mitigating-degree-bias-in-graph)
108 | [Mitigating Degree Biases in Message Passing Mechanism (Papers With Code)](https://paperswithcode.com/paper/mitigating-degree-biases-in-message-passing)
109 |
110 | :pencil: Blog post [on Network Science Lab](https://nslab-cuk.github.io/2023/08/17/UGT/):
111 | * [CGT post on the Network Science Lab blog](https://nslab-cuk.github.io/2023/12/27/CGT/)
112 |
113 |
114 | ## 4. Citing CGT
115 |
116 | Please cite our [paper](https://ieeexplore.ieee.org/document/10974679) if you find *CGT* useful in your work:
117 | ```
118 | @Article{hoang2025mitigating_TNSE,
119 | author = {Van Thuy Hoang and Hyeon-Ju Jeon and O-Joun Lee},
120 | journal = {IEEE Transactions on Network Science and Engineering},
121 | title = {Mitigating Degree Bias in Graph Representation Learning with Learnable Structural Augmentation and Structural Self-Attention},
122 | year = {2025},
123 | issn = {2327-4697},
124 | volume = {12},
125 | number = {5},
126 | pages = {3656--3670},
127 | doi = {10.1109/TNSE.2025.3563697},
128 | }
129 |
130 | @misc{hoang2025mitigating,
131 | title={Mitigating Degree Bias in Graph Representation Learning with Learnable Structural Augmentation and Structural Self-Attention},
132 | author={Van Thuy Hoang and Hyeon-Ju Jeon and O-Joun Lee},
133 | year={2025},
134 | eprint={2504.15075},
135 | archivePrefix={arXiv},
136 | primaryClass={cs.AI}
137 | }
138 |
139 | @misc{hoang2023mitigating,
140 | title={Mitigating Degree Biases in Message Passing Mechanism by Utilizing Community Structures},
141 | author={Van Thuy Hoang and O-Joun Lee},
142 | year={2023},
143 | eprint={2312.16788},
144 | archivePrefix={arXiv},
145 | primaryClass={cs.LG}
146 | }
147 | ```
148 |
149 | Please also take a look at our structure-preserving graph Transformer model, [**UGT**](https://github.com/NSLab-CUK/Unified-Graph-Transformer), whose expressive power is as high as that of the 3-WL test.
150 |
151 | ## 5. Contributors
152 |
153 |
154 |
155 |
156 |
157 |
158 |
159 |
160 |
161 | ***
162 |
163 |
164 |
165 | ***
166 |
167 |
168 |
169 |
170 |
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | #%%
2 | from typing import Optional, Callable
3 |
4 | import os.path as osp
5 |
6 | import torch
7 | import numpy as np
8 |
9 | from torch_geometric.utils import to_undirected
10 | from torch_geometric.data import InMemoryDataset, download_url, Data
11 | from torch_geometric.datasets import Planetoid, Amazon, WikipediaNetwork, Coauthor, WikiCS, SNAPDataset
12 |
13 |
14 |
15 | def mask_init(self, num_train_per_class=20, num_val_per_class=30, seed=12345):
16 | num_nodes = self.data.y.size(0)
17 | self.train_mask = torch.zeros([num_nodes], dtype=torch.bool)
18 | self.val_mask = torch.zeros([num_nodes], dtype=torch.bool)
19 | self.test_mask = torch.ones([num_nodes], dtype=torch.bool)
20 | np.random.seed(seed)
21 | for c in range(self.num_classes):
22 | samples_idx = (self.data.y == c).nonzero().squeeze()
23 | perm = list(range(samples_idx.size(0)))
24 | np.random.shuffle(perm)
25 | perm = torch.as_tensor(perm).long()
26 | self.train_mask[samples_idx[perm][:num_train_per_class]] = True
27 | self.val_mask[samples_idx[perm][num_train_per_class:num_train_per_class + num_val_per_class]] = True
28 | self.test_mask[self.train_mask] = False
29 | self.test_mask[self.val_mask] = False
30 |
31 |
32 | def mask_getitem(self, datum):
33 | datum.__setitem__("train_mask", self.train_mask)
34 | datum.__setitem__("val_mask", self.val_mask)
35 | datum.__setitem__("test_mask", self.test_mask)
36 | return datum
37 |
38 |
39 | class DigitizeY(object):
40 |
41 | def __init__(self, bins, transform_y=None):
42 | self.bins = np.asarray(bins)
43 | self.transform_y = transform_y
44 |
45 | def __call__(self, data):
46 | y = self.transform_y(data.y).numpy()
47 | digitized_y = np.digitize(y, self.bins)
48 | data.y = torch.from_numpy(digitized_y)
49 | return data
50 |
51 | def __repr__(self):
52 | return '{}(bins={})'.format(self.__class__.__name__, self.bins.tolist())
53 |
54 |
55 |
56 | class WikipediaNetwork_crocodile(InMemoryDataset):
57 | r"""The Wikipedia networks introduced in the
58 | "Multi-scale Attributed Node Embedding"
59 | paper.
60 | Nodes represent web pages and edges represent hyperlinks between them.
61 | Node features represent several informative nouns in the Wikipedia pages.
62 | The task is to predict the average daily traffic of the web page.
63 | Args:
64 | root (string): Root directory where the dataset should be saved.
65 | name (string): The name of the dataset (:obj:`"chameleon"`,
66 | :obj:`"crocodile"`, :obj:`"squirrel"`).
67 | geom_gcn_preprocess (bool): If set to :obj:`True`, will load the
68 | pre-processing data as introduced in the "Geom-GCN: Geometric
69 | Graph Convolutional Networks" paper,
70 | in which the average monthly traffic of the web page is converted
71 | into five categories to predict.
72 | If set to :obj:`True`, the dataset :obj:`"crocodile"` is not
73 | available.
74 | transform (callable, optional): A function/transform that takes in an
75 | :obj:`torch_geometric.data.Data` object and returns a transformed
76 | version. The data object will be transformed before every access.
77 | (default: :obj:`None`)
78 | pre_transform (callable, optional): A function/transform that takes in
79 | an :obj:`torch_geometric.data.Data` object and returns a
80 | transformed version. The data object will be transformed before
81 | being saved to disk. (default: :obj:`None`)
82 | """
83 |
84 | raw_url = 'https://graphmining.ai/datasets/ptg/wiki'
85 | processed_url = ('https://raw.githubusercontent.com/graphdml-uiuc-jlu/'
86 | 'geom-gcn/master')
87 |
88 | def __init__(self, root: str, name: str, geom_gcn_preprocess: bool = True,
89 | transform: Optional[Callable] = None,
90 | pre_transform: Optional[Callable] = None):
91 | self.name = name.lower()
92 | self.geom_gcn_preprocess = geom_gcn_preprocess
93 | assert self.name in ['chameleon', 'crocodile', 'squirrel']
94 | if geom_gcn_preprocess and self.name == 'crocodile':
95 | raise AttributeError("The dataset 'crocodile' is not available in "
96 | "case 'geom_gcn_preprocess=True'")
97 | super().__init__(root, transform, pre_transform)
98 | self.data, self.slices = torch.load(self.processed_paths[0])
99 |
100 | @property
101 | def raw_dir(self) -> str:
102 | if self.geom_gcn_preprocess:
103 | return osp.join(self.root, self.name, 'geom_gcn', 'raw')
104 | else:
105 | return osp.join(self.root, self.name, 'raw')
106 |
107 | @property
108 | def processed_dir(self) -> str:
109 | if self.geom_gcn_preprocess:
110 | return osp.join(self.root, self.name, 'geom_gcn', 'processed')
111 | else:
112 | return osp.join(self.root, self.name, 'processed')
113 |
114 | @property
115 | def raw_file_names(self) -> str:
116 | if self.geom_gcn_preprocess:
117 | return (['out1_node_feature_label.txt', 'out1_graph_edges.txt'] +
118 | [f'{self.name}_split_0.6_0.2_{i}.npz' for i in range(10)])
119 | else:
120 | return f'{self.name}.npz'
121 |
122 | @property
123 | def processed_file_names(self) -> str:
124 | return 'data.pt'
125 |
126 | def download(self):
127 | if self.geom_gcn_preprocess:
128 | for filename in self.raw_file_names[:2]:
129 | url = f'{self.processed_url}/new_data/{self.name}/{filename}'
130 | download_url(url, self.raw_dir)
131 | for filename in self.raw_file_names[2:]:
132 | url = f'{self.processed_url}/splits/{filename}'
133 | download_url(url, self.raw_dir)
134 | else:
135 | download_url(f'{self.raw_url}/{self.name}.npz', self.raw_dir)
136 |
137 | def process(self):
138 | if self.geom_gcn_preprocess:
139 | with open(self.raw_paths[0], 'r') as f:
140 | data = f.read().split('\n')[1:-1]
141 | x = [[float(v) for v in r.split('\t')[1].split(',')] for r in data]
142 | x = torch.tensor(x, dtype=torch.float)
143 | y = [int(r.split('\t')[2]) for r in data]
144 | y = torch.tensor(y, dtype=torch.long)
145 |
146 | with open(self.raw_paths[1], 'r') as f:
147 | data = f.read().split('\n')[1:-1]
148 | data = [[int(v) for v in r.split('\t')] for r in data]
149 | edge_index = torch.tensor(data, dtype=torch.long).t().contiguous()
150 | # edge_index = to_undirected(edge_index, num_nodes=x.size(0))
151 | print('test')
152 | train_masks, val_masks, test_masks = [], [], []
153 | for filepath in self.raw_paths[2:]:
154 | f = np.load(filepath)
155 | train_masks += [torch.from_numpy(f['train_mask'])]
156 | val_masks += [torch.from_numpy(f['val_mask'])]
157 | test_masks += [torch.from_numpy(f['test_mask'])]
158 | train_mask = torch.stack(train_masks, dim=1).to(torch.bool)
159 | val_mask = torch.stack(val_masks, dim=1).to(torch.bool)
160 | test_mask = torch.stack(test_masks, dim=1).to(torch.bool)
161 |
162 | data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask,
163 | val_mask=val_mask, test_mask=test_mask)
164 |
165 | else:
166 | data = np.load(self.raw_paths[0], 'r', allow_pickle=True)
167 | x = torch.from_numpy(data['features']).to(torch.float)
168 | edge_index = torch.from_numpy(data['edges']).to(torch.long)
169 | edge_index = edge_index.t().contiguous()
170 | # edge_index = to_undirected(edge_index, num_nodes=x.size(0))
171 | y = torch.from_numpy(data['label']).to(torch.float)
172 | train_mask = torch.from_numpy(data['train_mask']).to(torch.bool)
173 | test_mask = torch.from_numpy(data['test_mask']).to(torch.bool)
174 | val_mask = torch.from_numpy(data['val_mask']).to(torch.bool)
175 | data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask,
176 | val_mask=val_mask, test_mask=test_mask)
177 | if self.pre_transform is not None:
178 | data = self.pre_transform(data)
179 |
180 | torch.save(self.collate([data]), self.processed_paths[0])
181 |
182 | #%%
183 | from ogb.nodeproppred import NodePropPredDataset
184 |
185 | def even_quantile_labels(vals, nclasses, verbose=True):
186 | """ partitions vals into nclasses by a quantile based split,
187 | where the first class is less than the 1/nclasses quantile,
188 | second class is less than the 2/nclasses quantile, and so on
189 |
190 | vals is np array
191 | returns an np array of int class labels
192 | """
193 | label = -1 * np.ones(vals.shape[0], dtype=int)
194 | interval_lst = []
195 | lower = -np.inf
196 | for k in range(nclasses - 1):
197 | upper = np.nanquantile(vals, (k + 1) / nclasses)
198 | interval_lst.append((lower, upper))
199 | inds = (vals >= lower) * (vals < upper)
200 | label[inds] = k
201 | lower = upper
202 | label[vals >= lower] = nclasses - 1
203 | interval_lst.append((lower, np.inf))
204 | if verbose:
205 | print('Class Label Intervals:')
206 | for class_idx, interval in enumerate(interval_lst):
207 | print(f'Class {class_idx}: [{interval[0]}, {interval[1]})')
208 | return label
209 |
210 | def load_arxiv_year_dataset(root):
211 | ogb_dataset = NodePropPredDataset(name='ogbn-arxiv',root=root)
212 | graph = ogb_dataset.graph
213 | graph['edge_index'] = torch.as_tensor(graph['edge_index'])
214 | graph['node_feat'] = torch.as_tensor(graph['node_feat'])
215 |
216 | label = even_quantile_labels(graph['node_year'].flatten(), 5, verbose=False)
217 | label = torch.as_tensor(label).reshape(-1, 1)
218 | import os
219 | split_idx_lst = load_fixed_splits("arxiv-year",os.path.join(root,"splits"))
220 |
221 | train_mask = torch.stack([split["train"] for split in split_idx_lst],dim=1)
222 | val_mask = torch.stack([split["valid"] for split in split_idx_lst],dim=1)
223 | test_mask = torch.stack([split["test"] for split in split_idx_lst],dim=1)
224 | data = Data(x=graph["node_feat"],y=torch.squeeze(label.long()),edge_index=graph["edge_index"],\
225 | train_mask=train_mask,val_mask=val_mask,test_mask=test_mask)
226 | return data
227 |
228 |
229 |
230 | def load_fixed_splits(dataset,split_dir):
231 | """ loads saved fixed splits for dataset
232 | """
233 | name = dataset
234 | import os
235 | splits_lst = np.load(os.path.join(split_dir,"{}-splits.npy".format(name)), allow_pickle=True)
236 | for i in range(len(splits_lst)):
237 | for key in splits_lst[i]:
238 | if not torch.is_tensor(splits_lst[i][key]):
239 | splits_lst[i][key] = torch.as_tensor(splits_lst[i][key])
240 | return splits_lst
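    | # The splits file "<dataset>-splits.npy" is expected to hold a list of dicts with
    | # "train"/"valid"/"test" index arrays (one dict per random split), matching how
    | # load_arxiv_year_dataset above stacks them into masks; this is inferred from the
    | # consuming code rather than documented anywhere in the repo.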
241 | # %%
--------------------------------------------------------------------------------
/exp.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import copy
3 | import logging
4 | import math
5 | import time
6 | from pathlib import Path
7 | from torch_geometric.utils import to_networkx
8 | import networkx as nx
9 | import numpy as np
10 | import scipy.sparse as sp
11 | import torch
12 | from tqdm import tqdm
13 | from sklearn.preprocessing import normalize
14 | import os
15 | import random
16 | import dgl
17 | import warnings
18 | warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
19 | warnings.filterwarnings('ignore', category=RuntimeWarning, message='scipy._lib.messagestream.MessageStream')
20 | from gnnutils import make_masks, train, test, add_original_graph, load_webkb, load_planetoid, load_wiki, load_bgp, \
21 | load_film, load_airports, load_amazon, load_coauthor, load_WikiCS, load_crocodile, load_Cora_ML
22 |
23 | from util import get_B_sim_phi, getM_logM, load_dgl, get_A_D
24 |
25 | from models import Transformer
26 |
27 | MODEl_DICT = {"Transformer": Transformer}
28 | from script_classification import run_node_classification, run_node_clustering, update_evaluation_value
29 |
30 |
31 | def filter_rels(data, r):
32 | data = copy.deepcopy(data)
33 | mask = data.edge_color <= r
34 | data.edge_index = data.edge_index[:, mask]
35 | data.edge_weight = data.edge_weight[mask]
36 | data.edge_color = data.edge_color[mask]
37 | return data
38 |
39 |
40 | def run(i, data, num_features, num_classes, g, M, adj_org, adj, logM, B, sim, phi, degree, n_edges, aug_check,
41 | sim_check, phi_check, pretrain_check,test_node_degree):
42 | if args.model in MODEl_DICT:
43 | model = MODEl_DICT[args.model](num_features,
44 | args.out_size,
45 | num_classes,
46 | hidden_dim=args.dims,
47 | num_layers=args.num_layers,
48 | num_heads=args.num_heads,
49 | k_transition=args.k_transition,
50 | aug_check=aug_check,
51 | sim_check=sim_check,
52 | phi_check=phi_check,
53 | alfa=args.alfa, beta=args.beta,
54 | )
55 | dataset_1 = ['cora', 'citeseer', 'Photo','WikiCS']
56 |
57 | if args.dataset in dataset_1:
58 | model.to(device)
59 | else:
60 | if torch.cuda.device_count() > 1:
61 |
62 |             id_2 = int(str(args.device_2).split(":")[1])
63 | print(f'Running multi-GPUs {id_2}')
64 | model = torch.nn.DataParallel(model, device_ids=[id_2])
65 |
66 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=5e-5)
67 |
68 | if args.custom_masks:
69 | # create random mask
70 | data = make_masks(data, val_test_ratio = 0.0)
71 | train_mask = data.train_mask
72 | val_mask = data.val_mask
73 | test_mask = data.test_mask
74 | best_val_acc = 0
75 |
76 | pat = 20
77 | best_model = model
78 | graph_name = args.dataset
79 | best_loss = 100000000
80 |
81 | best_epoch = 0
82 |
83 | num_epochs = args.epochs + 1
84 |
85 | if args.pretrain_check == 0:
86 |         print('pretrain_check: False, fine-tuning')
87 | epoch = 1
88 | torch.save(model,
89 | '{}{}_{}_{}_{}_{}_{}_{}.pt'.format(args.output_path, args.dataset, args.lr, args.dims,
90 | args.num_layers, args.k_transition, args.alfa, args.beta))
91 | else: # pre-training
92 | print(f'pre-training')
93 | for epoch in range(1, num_epochs):
94 |
95 | train_loss = train(model, data, train_mask, optimizer, device,device_2, g=g, adj_org=adj_org,
96 | trans_logM=logM, sim=sim, phi=phi, B=B, k_transition=args.k_transition, current_epoch=epoch,
97 | alpha_1=args.alpha_1, alpha_2=args.alpha_2, alpha_3=args.alpha_3)
98 |
99 | if best_loss >= train_loss:
100 | best_model = model
101 | best_epoch = epoch
102 | best_loss = train_loss
103 |
104 | if epoch - best_epoch > 200:
105 | break
106 | if epoch % 1 == 0:
107 | print('Epoch: {:02d}, Best Epoch {:02d}, Train Loss: {:0.4f}'.format(epoch, best_epoch, train_loss))
108 |
109 | print(' saving model and embeddings')
110 | torch.save(best_model,
111 | '{}{}_{}_{}_{}_{}_{}_{}.pt'.format(args.output_path, args.dataset, args.lr, args.dims,
112 | args.num_layers, args.k_transition, args.alfa, args.beta))
113 | time.sleep(1)
114 | run_node_classification(args, args.index_excel, args.dataset, args.output_path, args.file_name, data, num_features,
115 | args.out_size,
116 | num_classes, g, adj_org, M, logM, sim, phi, B, degree, args.k_transition, device,device_2,
117 | args.run_times_fine, current_epoch=epoch, aug_check=aug_check,sim_check=sim_check,phi_check=phi_check,test_node_degree = test_node_degree)
118 |
119 | time.sleep(1)
120 |
121 | run_node_clustering(args, args.index_excel, args.dataset, args.output_path, args.file_name, data, num_features,
122 | args.out_size,
123 | num_classes, g, M, logM, sim, phi, B, args.k_transition, device,device_2, args.run_times_fine, adj,
124 | degree, n_edges, current_epoch=epoch, aug_check=aug_check,sim_check=sim_check,phi_check=phi_check)
125 |
126 | time.sleep(1)
127 |
128 |
129 | import collections
130 | from collections import defaultdict
131 |
132 |
133 | def main():
134 | timestr = time.strftime("%Y%m%d-%H%M%S")
135 | log_file = args.dataset + "-" + timestr + ".log"
136 | Path("./exp_logs").mkdir(parents=True, exist_ok=True)
137 | logging.basicConfig(filename="exp_logs/" + log_file, filemode="w", level=logging.INFO)
138 | logging.info("Starting on device: %s", device)
139 | logging.info("Config: %s ", args)
140 | isbgp = False
141 | if args.dataset in ['cora', 'citeseer', 'pubmed']:
142 | og_data, _ = load_planetoid(args.dataset)
143 |
144 | elif args.dataset in ['Computers', 'Photo']:
145 | print('computers and photo processing...')
146 | assert args.custom_masks == True
147 | og_data, _ = load_amazon(args.dataset)
148 | elif args.dataset in ['WikiCS']:
149 | print('WikiCS processing...')
150 | assert args.custom_masks == True
151 | og_data, _ = load_WikiCS(args.dataset)
152 | else:
153 | raise NotImplementedError
154 |
155 | data = og_data
156 |
157 | num_classes = len(data.y.unique())
158 |
159 | num_features = data.x.shape[1]
160 | nx_g = to_networkx(data, to_undirected=True)
161 |
162 |     print(f'Loading dataset: {args.dataset}, data size: {data.x.size()}')
163 |
164 |
165 | num_of_nodes = nx_g.number_of_nodes()
166 | print(f'number of nodes: {num_of_nodes}')
167 |
168 | adj, degree, n_edges = get_A_D(nx_g, num_of_nodes)
169 |
170 | ### load information
171 |
172 | path = "pts/" + args.dataset + "_kstep_" + str(args.k_transition) + ".pt"
173 | if not os.path.exists(path):
174 | M, logM, B, sim, phi = load_bias(nx_g, num_of_nodes, num_classes, data.x)
175 |
176 | M = torch.from_numpy(np.array(M)).float()
177 | logM = torch.from_numpy(np.array(logM)).float()
178 | B = torch.from_numpy(B).float()
179 | sim = torch.from_numpy(np.array(sim)).float()
180 | phi = torch.from_numpy(phi).float()
181 |
182 | print('saving M, logM,B, sim, phi')
183 | save_info(M, logM, B, sim, phi)
184 |
185 | else:
186 | print('file exist, loading M, logM, B, sim, phi')
187 | M, logM, B, sim, phi = load_info(path)
188 |
189 | sim[torch.isnan(sim)] = 0
190 |
191 | g = load_dgl(nx_g, data.x, sim, phi)
192 |
193 | adj_org = torch.from_numpy(adj)
194 |
195 | runs_acc = []
196 |
197 | for i in tqdm(range(args.run_times)):
198 | acc = run(i, data, num_features, num_classes, g, M, adj_org, adj, logM, B, sim, phi, degree, n_edges,
199 | args.aug_check, args.sim_check, args.phi_check, args.pretrain_check, args.test_node_degree)
200 | runs_acc.append(acc)
201 |
202 |
203 | def save_info(M, logM, B, sim, phi):
204 | path = "pts/"
205 | torch.save({"M": M, "logM": logM, "B": B, "sim": sim, "phi": phi},
206 | path + args.dataset + "_kstep_" + str(args.k_transition) + '.pt')
207 | print('save_info done.')
208 |
209 |
210 | def load_info(path):
211 | dic = torch.load(path)
212 | M = dic['M']
213 | logM = dic['logM']
214 | B = dic['B']
215 | sim = dic['sim']
216 |
217 | phi = dic['phi']
218 | print('load_info done.')
219 | return M, logM, B, sim, phi
220 |
221 |
222 | def load_bias(nx_g, num_nodes, n_class, X):
223 | M, logM = getM_logM(nx_g, num_nodes, kstep=args.k_transition)
224 | B, sim, phi = get_B_sim_phi(nx_g, M, num_nodes, n_class, X, kstep=args.k_transition)
225 |
226 |
227 | return M, logM, B, sim, phi
228 |
229 | if __name__ == '__main__':
230 | parser = argparse.ArgumentParser(description="Experiments")
231 | #
232 | parser.add_argument("--dataset", default="citeseer", help="Dataset")
233 | parser.add_argument("--model", default="Transformer", help="GNN Model")
234 |
235 | parser.add_argument("--run_times", type=int, default=1)
236 |
237 | parser.add_argument("--drop", type=float, default=0.5, help="dropout")
238 | parser.add_argument("--custom_masks", default=True, action='store_true', help="custom train/val/test masks")
239 |
240 | # adding args
241 | parser.add_argument("--device", default="cuda:0", help="GPU ids")
242 | parser.add_argument("--device_2", default="cuda:0", help="GPU ids")
243 |
244 | parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
245 | parser.add_argument("--epochs", type=int, default=3)
246 |
247 | parser.add_argument("--dims", type=int, default=64, help="hidden dims")
248 | parser.add_argument("--out_size", type=int, default=64, help="outsize dims")
249 |
250 | parser.add_argument("--k_transition", type=int, default=3)
251 | parser.add_argument("--num_layers", type=int, default=4)
252 | parser.add_argument("--num_heads", type=int, default=2)
253 | parser.add_argument("--output_path", default="outputs/", help="outputs model")
254 |
255 | parser.add_argument("--pretrain_check", type=int, default=0)
256 | parser.add_argument("--aug_check", type=int, default=1)
257 | parser.add_argument("--sim_check", type=int, default=1)
258 | parser.add_argument("--phi_check", type=int, default=1)
259 |
260 | parser.add_argument("--alfa", type=float, default=0.1)
261 | parser.add_argument("--beta", type=float, default=0.9)
262 |
263 | parser.add_argument("--run_times_fine", type=int, default=200)
264 | parser.add_argument("--index_excel", type=int, default="-1", help="index_excel")
265 | parser.add_argument("--file_name", default="outputs_excels/cora.xlsx", help="file_name dataset")
266 |
267 | parser.add_argument("--alpha_1", type=float, default=1)
268 | parser.add_argument("--alpha_2", type=float, default=1)
269 | parser.add_argument("--alpha_3", type=float, default=1)
270 | parser.add_argument("--test_node_degree", type=int, default= 0)
271 |
272 |
273 | args = parser.parse_args()
274 | print(args)
275 | device = torch.device(args.device)
276 | device_2 = torch.device(args.device_2)
277 | print(f"Using device:{device}, device_2: {device_2}")
278 | main()
--------------------------------------------------------------------------------
/script_classification.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import copy
3 | import logging
4 | import math
5 | import time
6 | from pathlib import Path
7 | import numpy as np
8 | import pandas as pd
9 | import matplotlib.pyplot as plt
10 | # import seaborn as sns
11 | import scipy.sparse as sp
12 | import os
13 | import os.path
14 | from torch import Tensor
15 | import torch_geometric
16 | from torch_geometric.utils import to_networkx
17 | from torch_geometric.datasets import Planetoid
18 | import networkx as nx
19 | from torch_geometric.data import Data
20 | from torch.utils.data import Dataset
21 | import dgl.sparse as dglsp
22 | import numpy as np
23 | import torch
24 | from tqdm import tqdm
25 |
26 | import dgl
27 | from dgl import LaplacianPE
28 | # from dgl.nn import LaplacianPosEnc
29 | import networkx as nx
30 | import numpy as np
31 | from sklearn.preprocessing import normalize
32 | from collections import defaultdict
33 | import itertools
34 |
35 | from collections import deque
36 |
37 | from gnnutils import make_masks, test, add_original_graph, load_webkb, load_planetoid, \
38 | load_wiki, load_bgp, load_film, load_airports, train_finetuning_class, train_finetuning_cluster, test_cluster
39 |
40 | from collections import defaultdict
41 | from collections import deque
42 | from torch.utils.data import DataLoader, ConcatDataset
43 | import random as rnd
44 | import warnings
45 |
46 | warnings.filterwarnings("ignore", message="scipy._lib.messagestream.MessageStream size changed")
47 | warnings.filterwarnings("ignore", message="scipy._lib.messagestream.MessageStream size changed")
48 |
49 | from models import Transformer_class, Transformer_cluster
50 |
51 | MODEl_DICT = {"Transformer_class": Transformer_class}
52 |
53 | db_name = 0
54 |
55 |
56 | def update_evaluation_value(file_path, colume, row, value):
57 | try:
58 | df = pd.read_excel(file_path)
59 |
60 |         df.loc[row, colume] = value
61 |
62 | df.to_excel(file_path, sheet_name='data', index=False)
63 |
64 | return
65 |     except Exception:
66 | print("Error when saving results! Save again!")
67 | time.sleep(3)
68 |
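    | # Usage sketch (file path and cell coordinates are illustrative only):
    | #   update_evaluation_value("outputs_excels/cora.xlsx", colume="acc", row=0, value=81.25)
    | # On failure it only prints a warning and sleeps; the caller is expected to invoke it again.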
69 |
70 | def run_node_classification(args, index_excel, ds_name, output_path, file_name, data_all, num_features, out_size,
71 | num_classes, g, adj_org, M, trans_logM, sim, phi, B, degree, k_transition, device,device_2,
72 | num_epochs, current_epoch, aug_check,sim_check,phi_check,test_node_degree):
73 | print("running run_node_classification")
74 |
75 | index_excel = index_excel
76 |
77 | if True:
78 | dataset = args.dataset
79 | lr = args.lr
80 | dims =args.dims
81 | out_size = args.out_size
82 | num_layers = args.num_layers
83 | k_transition = args.k_transition
84 | alfa =args.alfa
85 | beta = args.beta
86 |
87 | print(f"Node class process - {index_excel}")
88 |
89 | cp_filename = output_path + f'{dataset}_{lr}_{dims}_{num_layers}_{k_transition}_{alfa}_{beta}.pt'
90 |
91 | if not os.path.exists(cp_filename):
92 | print(f"run_node_classification: no file: {cp_filename}")
93 | print(f"run_node_classification: no file: {cp_filename}")
94 | return None
95 |
96 | runs_acc = []
97 | for i in tqdm(range(1)):
98 | print(f'run_node_classification, run time: {i}')
99 | acc = run_epoch_node_classification(i, data_all, num_features, out_size, num_classes,
100 | g, adj_org, M, trans_logM, sim, phi, B, degree, k_transition,
101 | cp_filename, dims,
102 | num_layers, lr, device, device_2,num_epochs, current_epoch, aug_check,sim_check,phi_check,test_node_degree)
103 | runs_acc.append(acc)
104 |
105 | runs_acc = np.array(runs_acc) * 100
106 |
107 |
108 | final_msg = "Node Classification: Mean %0.4f, Std %0.4f" % (runs_acc.mean(), runs_acc.std())
109 | print(final_msg)
110 |
111 |
112 | def run_epoch_node_classification(i, data, num_features, out_size, num_classes,
113 | g, adj_org, M, trans_logM, sim, phi, B, degree, k_transition, cp_filename, dims,
114 | num_layers, lr, device,device_2, num_epochs, current_epoch,aug_check,sim_check,phi_check,test_node_degree):
115 | graph_name = ""
116 | best_val_acc = 0
117 | best_model = None
118 | pat = 20
119 | best_epoch = 0
120 |
121 | # fine tuning & testing
122 | print('fine tuning ...')
123 |
124 | model = Transformer_class(num_features, out_size, num_classes, hidden_dim=dims,
125 | num_layers=num_layers, num_heads=4, graph_name=graph_name,
126 |                               cp_filename=cp_filename, aug_check=aug_check, sim_check=sim_check, phi_check=phi_check)  # .to(device)
127 |
128 | dataset_1 = ['cora', 'citeseer', 'Photo','WikiCS']
129 | for ds in dataset_1:
130 | if ds in cp_filename:
131 | model.to(device)
132 | else:
133 | if torch.cuda.device_count() > 1:
134 |
135 | id_2 = int(str(device_2).split(":")[1])
136 | model = torch.nn.DataParallel(model, device_ids=[id_2])
137 |
138 | best_model = model
139 |
140 | optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
141 |
142 | print("creating random mask")
143 | data = make_masks(data, val_test_ratio=0.2)
144 | train_mask = data.train_mask
145 | val_mask = data.val_mask
146 | test_mask = data.test_mask
147 |
148 | # save dataload
149 | g.ndata["train_mask"] = train_mask
150 | g.ndata["val_mask"] = val_mask
151 | g.ndata["test_mask"] = test_mask
152 |
153 | test_check = 0
154 | for epoch in range(1, num_epochs):
155 | current_epoch = epoch
156 | train_loss, train_acc = train_finetuning_class(model, data, train_mask, optimizer, device,device_2, g=g, adj_org=adj_org,
157 | M=M, trans_logM=trans_logM, sim=sim, phi=phi, B=B,
158 | k_transition=k_transition,
159 | pre_train=0, current_epoch=current_epoch)
160 |
161 | if epoch % 1 == 0:
162 | valid_acc, valid_f1 = test(model, data, val_mask, device,device_2, g=g, adj_org=adj_org, trans_logM=trans_logM,
163 | sim=sim, phi=phi, B=B, degree=degree,
164 | k_transition=k_transition, current_epoch=current_epoch, test_node_degree=0)
165 |
166 | if valid_acc > best_val_acc:
167 | best_val_acc = valid_acc
168 | best_model = model
169 | best_epoch = epoch
170 | pat = (pat + 1) if (pat < 5) else pat
171 | else:
172 | pat -= 1
173 |
174 | if epoch % 1 == 0:
175 | print(
176 | 'Epoch: {:02d}, best_epoch: {:02d}, Train Loss: {:0.4f}, Train Acc: {:0.4f}, Val Acc: {:0.4f} '.format(
177 | epoch, best_epoch, train_loss, train_acc, valid_acc))
178 |
179 | if epoch - best_epoch > 100:
180 | print("1 validation patience reached ... finish training")
181 | break
182 |
183 | # Testing
184 | test_check = 1
185 | test_acc, test_f1 = test(best_model, data, test_mask, device,device_2, g=g, adj_org=adj_org, trans_logM=trans_logM, sim=sim,
186 | phi=phi, B=B, degree=degree,
187 | k_transition=k_transition, current_epoch=current_epoch, test_node_degree=test_node_degree)
188 | print('Best Val Epoch: {:03d}, Best Val Acc: {:0.4f}, Test Acc: {:0.4f}, F1_test: {:0.4f}'.format(
189 | best_epoch, best_val_acc, test_acc, test_f1))
190 | return test_acc, test_f1
191 | def run_node_clustering(args, index_excel, ds_name, output_path, file_name, data_all, num_features, out_size,
192 | num_classes, g, M, logM, sim, phi, B, k_transition, device,device_2, num_epochs, adj, d, n_edges,
193 | current_epoch, aug_check,sim_check,phi_check):
194 | index_excel = index_excel
195 |
196 | if True:
197 | dataset = args.dataset
198 | lr = args.lr
199 | dims =args.dims
200 | out_size = args.out_size
201 | num_layers = args.num_layers
202 | k_transition = args.k_transition
203 | alfa =args.alfa
204 | beta = args.beta
205 |
206 | print(f"Node clustering process - {index_excel}")
207 |
208 | cp_filename = output_path + f'{dataset}_{lr}_{dims}_{num_layers}_{k_transition}_{alfa}_{beta}.pt'
209 |
210 |     if not os.path.isfile(cp_filename):
211 | print(f"run_node_clustering: no file {cp_filename}")
212 | return None
213 |
214 | runs_acc = []
215 | for i in tqdm(range(1)):
216 | print(f'run time: {i}')
217 | _, _, _, _, q, c = run_epoch_node_clustering(i, data_all, num_features, out_size,
218 | num_classes,
219 | g, M, logM, sim, phi, B, k_transition,
220 | cp_filename, dims, num_layers, lr, device,device_2,
221 | num_epochs, adj, d, n_edges, current_epoch, aug_check,sim_check,phi_check)
222 | #runs_acc.append(acc)
223 | time.sleep(1)
224 |     print('Node clustering results: \n Q: {:0.4f}, C: {:0.4f}'.format(q, c))
225 |
226 | def run_epoch_node_clustering(i, data, num_features, out_size, num_classes, g, M, logM, sim, phi, B, k_transition,
227 | cp_filename, dims, num_layers, lr, device,device_2, num_epochs, adj, d, n_edges, current_epoch,aug_check,sim_check,phi_check):
228 | graph_name = ""
229 | best_val_acc = 0
230 | best_model = None
231 | pat = 20
232 | best_epoch = 0
233 |
234 | # fine tuning & testing
235 | print('fine tuning run_epoch_node_clustering...')
236 |
237 | model = Transformer_cluster(num_features, out_size, num_classes, hidden_dim=dims,
238 | num_layers=num_layers, num_heads=4, graph_name=graph_name,
239 | cp_filename=cp_filename, aug_check=aug_check,sim_check=sim_check,phi_check=phi_check) #.to(device)
240 |
241 | dataset_1 = ['cora', 'citeseer', 'Photo','WikiCS']
242 | for ds in dataset_1:
243 | if ds in cp_filename:
244 | model.to(device)
245 | else:
246 | if torch.cuda.device_count() > 1:
247 | id_2 = int(str(device_2).split(":")[1])
248 | model = torch.nn.DataParallel(model, device_ids=[id_2])
249 |
250 | best_model = model
251 |
252 | optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
253 |
254 | print("creating random mask")
255 | data = make_masks(data, val_test_ratio=0.0)
256 | train_mask = data.train_mask
257 | val_mask = data.val_mask
258 | test_mask = data.test_mask
259 |
260 | # save dataload.
261 | g.ndata["train_mask"] = train_mask
262 | g.ndata["val_mask"] = val_mask
263 | g.ndata["test_mask"] = test_mask
264 | adj = torch.FloatTensor(adj).to(device)
265 | for epoch in range(1, num_epochs):
266 | current_epoch = epoch
267 | train_loss = train_finetuning_cluster(model, data, train_mask, optimizer, device, device_2,g=g,
268 | M=M, trans_logM=logM, sim=sim, phi=phi, B=B, k_transition=k_transition,
269 | pre_train=0, adj=adj, d=d, n_edges=n_edges, current_epoch=current_epoch)
270 |
271 | if epoch % 1 == 0:
272 | print('Epoch: {:02d}, Train Loss: {:0.4f}'.format(epoch, train_loss))
273 |
274 | # Testing
275 |
276 | acc, precision, recall, nmi, q, c = test_cluster(best_model, data, train_mask, optimizer, device,device_2, g=g,
277 | M=M, trans_logM=logM, sim=sim, phi=phi, B=B,
278 | k_transition=k_transition, pre_train=0, adj=adj, d=d,
279 | n_edges=n_edges, current_epoch=current_epoch)
280 |
281 |
282 | return acc, precision, recall, nmi, q, c
283 |
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.nn import Sequential, Linear, ReLU
4 | from torch_geometric.nn import GCNConv, GINConv, SAGEConv
5 |
6 | import torch.nn as nn
7 | import dgl
8 | import dgl.nn as dglnn
9 | import dgl.sparse as dglsp
10 | import torch
11 | import dgl.function as fn
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | import os
15 | import os.path
16 | import torch.optim as optim
17 | import numpy as np
18 | from dgl.data import AsGraphPredDataset
19 | from dgl.dataloading import GraphDataLoader
20 | from ogb.graphproppred import collate_dgl, DglGraphPropPredDataset, Evaluator
21 | from ogb.graphproppred.mol_encoder import AtomEncoder
22 | from tqdm import tqdm
23 | import dgl.function as fn
24 | import pyro
25 |
26 |
27 | class Transformer(nn.Module):
28 | def __init__(self, in_dim, out_dim, n_classes, hidden_dim, num_layers, num_heads, k_transition, aug_check,
29 | sim_check,
30 | phi_check, alfa, beta):
31 | super().__init__()
32 | self.h = None
33 | self.embedding_h = nn.Linear(in_dim, hidden_dim, bias=False)
34 | self.in_dim = in_dim
35 | self.hidden_dim = hidden_dim
36 | self.k_transition = k_transition
37 |
38 | self.aug_check = aug_check
39 | self.sim_check = sim_check
40 | self.phi_check = phi_check
41 | self.afla = alfa
42 | self.beta = beta
43 |
44 |
45 | self.gcn = GCN(in_dim, hidden_dim, self.afla, self.beta)
46 |
47 | self.layers = nn.ModuleList(
48 | [GraphTransformerLayer(hidden_dim, hidden_dim, num_heads, sim_check, phi_check) for _ in range(num_layers)])
49 |
50 | self.layers.append(GraphTransformerLayer(hidden_dim, out_dim, num_heads, sim_check, phi_check))
51 |
52 | self.MLP_layer_x = Reconstruct_X(out_dim, in_dim)
53 |
54 | self.embedding_phi = nn.Linear(1, hidden_dim)
55 | self.embedding_sim = nn.Linear(k_transition, hidden_dim)
56 |
57 | def extract_features(self, g, adj_org, X, current_epoch, edge_index, sim, phi, B, k_transition, device, device_2):
58 |
59 | adj_sampled = None
60 |
61 | if self.aug_check == 1:
62 | edge_index_sampled, x_gcn, adj_sampled, check_nan = self.gcn(g, adj_org, X, B, edge_index, current_epoch,device_2)
63 | g = dgl_renew(g, X, edge_index_sampled, sim, phi, k_transition, device,device_2)
64 |
65 | h = self.embedding_h(X)
66 |
67 | phi = g.edata['phi']
68 | sim = g.edata['sim']
69 | phi = self.embedding_phi(phi.float())
70 | sim = self.embedding_sim(sim.float())
71 |
72 |
73 | for layer in self.layers:
74 | h = layer(h, g, phi, sim, current_epoch)
75 |
76 | return h, adj_sampled
77 |
78 | def forward(self, g, adj_org, sim, phi, B, k_transition, current_epoch, device, device_2):
79 |
80 | X = g.ndata['x'].to(device_2)
81 |
82 | edge_index = torch.stack([g.edges()[0], g.edges()[1]]).to(device_2)
83 |
84 | h, adj_sampled = self.extract_features(g, adj_org, X, current_epoch, edge_index, sim, phi, B, k_transition, device, device_2)
85 |
86 | x_hat = self.MLP_layer_x(h)
87 |
88 | self.h = h
89 |
90 | return h, x_hat, adj_sampled
91 |
92 |
93 | def dgl_renew(g, x0, edge_index_sampled, sim, phi, k_transition, device,device_2):
94 | g = dgl.graph((edge_index_sampled[0], edge_index_sampled[1]))
95 |
96 |
97 | g.edata['sim'] = sim[edge_index_sampled[0], edge_index_sampled[1]]
98 | g.edata['phi'] = phi[edge_index_sampled[0], edge_index_sampled[1]]
99 |
100 | return g
101 |
102 | class GraphTransformerLayer(nn.Module):
103 | """Graph Transformer Layer"""
104 |
105 | def __init__(self, in_dim, out_dim, num_heads, sim_check, phi_check):
106 | super().__init__()
107 |
108 | self.sim_check = sim_check
109 | self.phi_check = phi_check
110 |
111 | self.in_channels = in_dim
112 | self.out_channels = out_dim
113 | self.num_heads = num_heads
114 |
115 | self.attention = MultiHeadAttentionLayer(in_dim, out_dim // num_heads, num_heads, sim_check, phi_check)
116 |
117 | self.O = nn.Linear(out_dim, out_dim)
118 |
119 | self.batchnorm1 = nn.BatchNorm1d(out_dim)
120 | self.batchnorm2 = nn.BatchNorm1d(out_dim)
121 | self.layer_norm1 = nn.LayerNorm(out_dim)
122 | self.layer_norm2 = nn.LayerNorm(out_dim)
123 |
124 | self.FFN_layer1 = nn.Linear(out_dim, out_dim * 2)
125 | self.FFN_layer2 = nn.Linear(out_dim * 2, out_dim)
126 |
127 | def forward(self, h, g, phi, sim, current_epoch):
128 | h_in1 = h # for first residual connection
129 |
130 | attn_out = self.attention(h, g, phi, sim, current_epoch)
131 |
132 | h = attn_out.view(-1, self.out_channels)
133 |
134 | h = self.O(h)
135 |
136 | h = h_in1 + h # residual connection
137 |
138 | h = self.layer_norm1(h)
139 |
140 | h_in2 = h # for second residual connection
141 |
142 | # FFN
143 | h = self.FFN_layer1(h)
144 | h = F.relu(h)
145 |
146 | h = F.dropout(h, 0.5, training=self.training)
147 | h = self.FFN_layer2(h)
148 | h = h_in2 + h # residual connection
149 | h = self.layer_norm2(h)
150 |
151 | return h
152 |
153 |
154 | class MultiHeadAttentionLayer(nn.Module):
155 | # in_dim, out_dim, num_heads
156 | def __init__(self, in_dim, out_dim, num_heads, sim_check, phi_check):
157 | super().__init__()
158 |
159 | self.sim_check = sim_check
160 | self.phi_check = phi_check
161 |
162 | self.out_dim = out_dim
163 | self.num_heads = num_heads
164 | self.Q = nn.Linear(in_dim, out_dim * num_heads, bias=True)
165 | self.K = nn.Linear(in_dim, out_dim * num_heads, bias=True)
166 | self.V = nn.Linear(in_dim, out_dim * num_heads, bias=True)
167 |
168 | self.hidden_size = in_dim # 80
169 | self.num_heads = num_heads # 8
170 | self.head_dim = out_dim // num_heads # 10
171 |
172 | self.scaling = self.head_dim ** -0.5
173 |
174 | self.q_proj = nn.Linear(in_dim, in_dim)
175 | self.k_proj = nn.Linear(in_dim, in_dim)
176 | self.v_proj = nn.Linear(in_dim, in_dim)
177 |
178 | self.proj_phi = nn.Linear(in_dim, out_dim * num_heads, bias=True)
179 |
180 | self.sim = nn.Linear(in_dim, out_dim * num_heads, bias=True)
181 |
182 | def propagate_attention(self, g):
183 | # Compute attention score
184 | if self.sim_check == 1:
185 | g.apply_edges(src_dot_dst_sim('K_h', 'Q_h', 'sim_h', 'score'))
186 | else:
187 | g.apply_edges(src_dot_dst('K_h', 'Q_h', 'score'))
188 |
189 | g.apply_edges(scaling('score', np.sqrt(self.out_dim)))
190 |
191 | if self.phi_check == 1:
192 | g.apply_edges(imp_add_attn('score', 'proj_phi'))
193 |
194 | # softmax
195 | g.apply_edges(exp('score'))
196 |
197 | eids = g.edges()
198 | g.send_and_recv(eids, dgl.function.u_mul_e('V_h', 'score', 'V_h'), fn.sum('V_h', 'wV')) # src_mul_edge
199 | g.send_and_recv(eids, dgl.function.copy_e('score', 'score'), fn.sum('score', 'z')) # copy_edge
200 |
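    | # The two send_and_recv calls above accumulate, per destination node i,
    | #   wV_i = sum_j exp(score_ij) * V_j   and   z_i = sum_j exp(score_ij),
    | # so the h_out = wV / (z + eps) division in forward() below amounts to an
    | # edge-wise softmax attention over each node's incoming neighbours.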
201 |
202 | def forward(self, h, g, phi, sim, current_epoch):
203 | Q_h = self.Q(h)
204 | K_h = self.K(h)
205 | V_h = self.V(h)
206 |
207 | sim_h = self.sim(sim)
208 |
209 | proj_phi = self.proj_phi(phi)
210 | # proj_sim = self.proj_sim(sim)
211 |
212 | g.ndata['Q_h'] = Q_h.view(-1, self.num_heads, self.out_dim)
213 | g.ndata['K_h'] = K_h.view(-1, self.num_heads, self.out_dim)
214 | g.ndata['V_h'] = V_h.view(-1, self.num_heads, self.out_dim)
215 | g.edata['sim_h'] = sim_h.view(-1, self.num_heads, self.out_dim)
216 |
217 | g.edata['proj_phi'] = proj_phi.view(-1, self.num_heads, self.out_dim)
218 |
219 | self.propagate_attention(g)
220 |
221 | h_out = g.ndata['wV'] / (g.ndata['z'] + torch.full_like(g.ndata['z'], 1e-6)) # adding eps to all values here
222 |
223 | return h_out
224 |
225 |
226 | class GCN(torch.nn.Module):
227 | # g,adj_org, X, B, edge_index, current_epoch, self.afla, self.beta)
228 | def __init__(self, num_features, hidden_dim=64, alfa=0.1, beta=0.95):
229 | super(GCN, self).__init__()
230 | self.conv1 = GCNConv(num_features, hidden_dim * 2)
231 | self.conv2 = GCNConv(hidden_dim * 2, hidden_dim)
232 | self.anfa = alfa
233 | self.beta = beta
234 |
235 | self.MLPA = torch.nn.Sequential(
236 | torch.nn.Linear(hidden_dim, hidden_dim),
237 | torch.nn.ReLU(),
238 | torch.nn.Linear(hidden_dim, hidden_dim))
239 |
240 | def forward(self, g, adj_org, x, B, edge_index, current_epoch,device_2):
241 |
242 | edge_probs = self.anfa * B + self.beta * adj_org
243 |
244 | edge_probs[edge_probs > 1] = 1
245 |
246 | edge_probs = edge_probs.cuda()
247 | check_nan = False
248 | while True:
249 | # try:
250 | adj_sampled = pyro.distributions.RelaxedBernoulliStraightThrough(temperature=1, probs=edge_probs).rsample()
251 |
252 | adj_sampled = adj_sampled.triu(1)
253 | adj_sampled = adj_sampled + adj_sampled.T
254 |
255 | edge_index_sampled = adj_sampled.to_sparse()._indices()
256 |
257 | g_new = dgl.graph((edge_index_sampled[0], edge_index_sampled[1]))
258 |
259 | check_nan = True
260 |
261 | if g.num_nodes() == g_new.num_nodes():
262 | if current_epoch % 50 == 0:
263 | print(f'adj_sampled size: {edge_index_sampled.size()}')
264 | break
265 | else:
266 | print("----------------------------------------")
267 | edge_index_sampled = edge_index_sampled.to(device_2)
268 | adj_sampled = adj_sampled.to(device_2)
269 |
270 | return edge_index_sampled, x, adj_sampled, check_nan
271 |
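    | # RelaxedBernoulliStraightThrough draws a (near-)binary edge mask from the mixed
    | # probabilities alfa*B + beta*A while keeping the sample differentiable (hard
    | # forward pass, relaxed gradients), so the edge sampler can be trained end-to-end
    | # with the reconstruction losses computed in gnnutils.train().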
272 |
273 | class Reconstruct_X(torch.nn.Module):
274 | def __init__(self, inp, outp, dims=128):
275 | super().__init__()
276 |
277 |
278 | self.mlp = torch.nn.Sequential(
279 | torch.nn.Linear(inp, dims * 2),
280 | torch.nn.SELU(),
281 | torch.nn.Linear(dims * 2, outp))
282 |
283 | def forward(self, x):
284 | x = self.mlp(x)
285 | return x
286 |
287 |
288 | class MLPA(torch.nn.Module):
289 |
290 | def __init__(self, in_feats, dim_h, dim_z):
291 | super(MLPA, self).__init__()
292 |
293 | self.gcn_mean = torch.nn.Sequential(
294 | torch.nn.Linear(in_feats, dim_h),
295 | torch.nn.ReLU(),
296 | torch.nn.Linear(dim_h, dim_z)
297 | )
298 |
299 | def forward(self, hidden):
300 | Z = self.gcn_mean(hidden)
301 |
302 | adj_logits = Z @ Z.T
303 | return adj_logits
304 |
305 |
306 | class MLP(torch.nn.Module):
307 |
308 | def __init__(self, num_features, num_classes, dims=16):
309 | super(MLP, self).__init__()
310 | self.mlp = torch.nn.Sequential(
311 | torch.nn.Linear(num_features, dims), torch.nn.ReLU(),
312 | torch.nn.Linear(dims, num_classes))
313 |
314 | def forward(self, x):
315 | x = self.mlp(x)
316 | return x
317 |
318 |
319 | class Transformer_class(nn.Module):
320 | def __init__(self, in_dim, out_dim, n_classes, hidden_dim, num_layers, num_heads, graph_name, cp_filename, aug_check ,sim_check ,phi_check ):
321 | super().__init__()
322 |
323 | print(f'Loading Transformer_class {cp_filename}')
324 | self.model = torch.load(cp_filename)
325 | if isinstance(self.model, torch.nn.DataParallel):
326 | self.model = self.model.module
327 |
328 | self.model.aug_check = aug_check
329 | self.model.sim_check = sim_check
330 | self.model.phi_check = phi_check
331 |
332 | for p in self.model.parameters():
333 | p.requires_grad = True
334 |
335 |
336 | self.MLP = MLP(out_dim, n_classes)
337 |
338 |
339 | def forward(self, g, adj_org, sim, phi, B, k_transition, current_epoch, device, device_2):
340 |
341 | X = g.ndata['x'].to(device_2)
342 | edge_index = torch.stack([g.edges()[0], g.edges()[1]])
343 |
344 | h, _ = self.model.extract_features(g, adj_org, X, current_epoch, edge_index, sim, phi, B, k_transition, device, device_2)
345 |
346 | h = self.MLP(h)
347 | h = F.softmax(h, dim=1)
348 |
349 | return h
350 |
351 |
352 | class MLPReadout(nn.Module):
353 |
354 | def __init__(self, input_dim, output_dim, L=2): # L = nb_hidden_layers
355 | super().__init__()
356 | list_FC_layers = [nn.Linear(input_dim // 2 ** l, input_dim // 2 ** (l + 1), bias=True) for l in range(L)]
357 | list_FC_layers.append(nn.Linear(input_dim // 2 ** L, output_dim, bias=True))
358 |
359 | self.FC_layers = nn.ModuleList(list_FC_layers)
360 | self.L = L
361 |
362 | def forward(self, x):
363 | y = x
364 | for l in range(self.L):
365 | y = self.FC_layers[l](y)
366 | y = F.relu(y)
367 | y = self.FC_layers[self.L](y)
368 |
369 | return y
370 |
371 |
372 | class Transformer_cluster(nn.Module):
373 | def __init__(self, in_dim, out_dim, n_classes, hidden_dim, num_layers, num_heads, graph_name, cp_filename, aug_check,sim_check,phi_check):
374 | super().__init__()
375 |
376 |         print(f'Loading Transformer_cluster {cp_filename}')
377 | self.model = torch.load(cp_filename)
378 | if isinstance(self.model, torch.nn.DataParallel):
379 | self.model = self.model.module
380 | self.model.aug_check = aug_check
381 | self.model.sim_check = sim_check
382 | self.model.phi_check = phi_check
383 |
384 | for p in self.model.parameters():
385 | p.requires_grad = True
386 |
387 | self.MLP = MLPReadout(out_dim, n_classes)
388 |
389 |
390 |
391 | def forward(self, g, adj_org, sim, phi, B, k_transition, current_epoch, device, device_2):
392 | X = g.ndata['x'].to(device_2)
393 | edge_index = torch.stack([g.edges()[0], g.edges()[1]])
394 |
395 |
396 | h, _ = self.model.extract_features(g, adj_org, X, current_epoch, edge_index, sim, phi, B, k_transition, device, device_2)
397 |
398 | h = self.MLP(h)
399 | h = F.softmax(h, dim=1)
400 |
401 | return h
402 |
403 |
404 | """
405 | Util functions
406 | """
407 |
408 |
409 | def exp(field):
410 | def func(edges):
411 | # clamp for softmax numerical stability
412 | return {field: torch.exp((edges.data[field].sum(-1, keepdim=True)).clamp(-8, 8))}
413 |
414 | return func
415 |
416 |
417 | def src_dot_dst(src_field, dst_field, out_field):
418 | def func(edges):
419 | return {out_field: (edges.src[src_field] * edges.dst[dst_field]).sum(-1, keepdim=True)}
420 |
421 | return func
422 |
423 |
424 | def src_dot_dst_sim(src_field, dst_field, edge_field, out_field):
425 | def func(edges):
426 | return {out_field: ((edges.src[src_field] + edges.data[edge_field]) * (
427 | edges.dst[dst_field] + edges.data[edge_field])).sum(-1, keepdim=True)}
428 |
429 | return func
430 |
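    | # src_dot_dst_sim biases the attention logit on edge (u, v) with the edge
    | # similarity feature:  score_uv = <K_u + sim_uv, Q_v + sim_uv>,
    | # whereas src_dot_dst is the plain dot product <K_u, Q_v>; both are divided by
    | # sqrt(out_dim) afterwards via scaling() in propagate_attention.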
431 |
432 | # Improving implicit attention scores with explicit edge features, if available
433 | def scaling(field, scale_constant):
434 | def func(edges):
435 | return {field: (((edges.data[field])) / scale_constant)}
436 |
437 | return func
438 |
439 |
440 | def imp_exp_attn(implicit_attn, explicit_edge):
441 | """
442 | implicit_attn: the output of K Q
443 | explicit_edge: the explicit edge features
444 | """
445 |
446 | def func(edges):
447 | return {implicit_attn: (edges.data[implicit_attn] * edges.data[explicit_edge])}
448 |
449 | return func
450 |
451 |
452 | def imp_add_attn(implicit_attn, explicit_edge):
453 | """
454 | implicit_attn: the output of K Q
455 | explicit_edge: the explicit edge features
456 | """
457 |
458 | def func(edges):
459 | return {implicit_attn: (edges.data[implicit_attn] + edges.data[explicit_edge])}
460 |
461 | return func
462 |
463 |
464 | # To copy edge features to be passed to FFN_e
465 | def out_edge_features(edge_feat):
466 | def func(edges):
467 | return {'e_out': edges.data[edge_feat]}
468 |
469 | return func
--------------------------------------------------------------------------------
/gnnutils.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import networkx as nx
4 | import numpy as np
5 | import scipy.sparse as sparse
6 | import torch
7 | import torch.nn.functional as F
8 | from networkx.utils import dict_to_numpy_array
9 | from sklearn.metrics import f1_score
10 | from sklearn.model_selection import train_test_split
11 | from torch_geometric.datasets import Planetoid, Airports
12 | from torch_geometric.utils import add_remaining_self_loops
13 | from tqdm import tqdm
14 | import torch_geometric as tg
15 | # from build_multigraph import build_pyg_struc_multigraph
16 | from datasets import WebKB, FilmNetwork, BGP
17 | from torch_geometric.datasets import Planetoid, Amazon, WikipediaNetwork, Coauthor, WikiCS, SNAPDataset, CitationFull
18 |
19 | from dataset import WikipediaNetwork_crocodile
20 | import warnings
21 | warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
22 |
23 | def DigitizeY(pyg_data):
24 | print("Checking ....")
25 | transform_y = np.log10
26 | bins = [0.2, 0.4, 0.6, 0.8, 1]
27 | y = transform_y(pyg_data.y).numpy()
28 |
29 | digitized_y = np.digitize(y, bins)
30 |
31 | # print(digitized_y)
32 | # raise SystemExit()
33 | # file1 = open("test.txt","a")
34 | # with open('test.txt', 'w') as f:
35 | # f.write()
36 |
37 | pyg_data.y = torch.from_numpy(digitized_y)
38 |
39 | # print(pyg_data.y )
40 |
41 | return pyg_data
42 |
43 |
44 | def load_crocodile(dataset):
45 | assert dataset in ["crocodile"]
46 | print("checking1: ")
47 |
48 | og = WikipediaNetwork_crocodile(root="original_datasets/wiki", name=dataset, geom_gcn_preprocess=False)[0]
49 |
50 | print(og.y)
51 | return og
52 |
53 |
54 | def load_Cora_ML(dataset):
55 | assert dataset in ["Cora_ML"]
56 | og = CitationFull(root="original_datasets/Cora_ML", name=dataset)[0]
57 |
58 | return og, None
59 |
60 |
61 | def load_WikiCS(dataset):
62 | assert dataset in ["WikiCS"]
63 | og = WikiCS(root="original_datasets/WikiCS")[0]
64 |
65 | return og, None
66 |
67 |
68 | def load_coauthor(dataset):
69 | assert dataset in ["CS", "Physics"]
70 | og = Coauthor(root="original_datasets/coauthor", name=dataset)[0]
71 |
72 | return og, None
73 |
74 |
75 | def load_amazon(dataset):
76 | assert dataset in ["Computers", "Photo"]
77 | og = Amazon(root="original_datasets/amazon", name=dataset)[0]
78 | # print("DONE og load_amazon")
79 | # st = Amazon(root="datasets_py_geom_format_10/amazon", name=dataset,pre_transform=build_pyg_struc_multigraph)[0]
80 |
81 | # print("DONE load_planetoid")
82 | # print(f'Number of edges og : {og.num_edges}')
83 | # print(f'Number of edges st : {st.num_edges}')
84 | return og, None
85 |
86 |
87 | def load_airports(dataset):
88 | assert dataset in ["brazil", "europe", "usa"]
89 | og = Airports("original_datasets/airports_dataset/" + dataset, name=dataset)[0]
90 | # st = Airports(root="datasets_py_geom_format/airports_dataset/"+dataset, dataset_name=dataset,pre_transform=build_pyg_struc_multigraph)[0]
91 | return og, None
92 |
93 |
94 | def load_bgp(dataset):
95 | assert dataset in ["bgp"]
96 | og = BGP(root="original_datasets/bgp_dataset")[0]
97 | # st = BGP(root="datasets_py_geom_format/bgp_dataset", pre_transform=build_pyg_struc_multigraph)[0]
98 | return og, None
99 |
100 |
101 | def load_film(dataset):
102 | assert dataset in ["film"]
103 | og = FilmNetwork(root="original_datasets/film", name=dataset)[0]
104 | # st = FilmNetwork(root="datasets_py_geom_format/film", name=dataset,
105 | # pre_transform=build_pyg_struc_multigraph)[0]
106 | return og, None
107 |
108 |
109 | def load_wiki(dataset):
110 | assert dataset in ["chameleon", "squirrel", "crocodile"]
111 | og = WikipediaNetwork(root="original_datasets/wiki", geom_gcn_preprocess=True, name=dataset)[0]
112 | # st = WikipediaNetwork(root="datasets_py_geom_format/wiki", name=dataset,
113 | # pre_transform=build_pyg_struc_multigraph)[0]
114 |
115 | return og, None
116 |
117 |
118 | def load_webkb(dataset):
119 | assert dataset in ["cornell", "texas", "wisconsin"]
120 | og = WebKB(root="original_datasets/webkb", name=dataset)[0]
121 | # st = WebKB(root="datasets_py_geom_format/webkb", name=dataset,pre_transform=build_pyg_struc_multigraph)[0]
122 |
123 | return og, None
124 |
125 |
126 | def load_planetoid(dataset):
127 | assert dataset in ["cora", "citeseer", "pubmed"]
128 | og = Planetoid(root="original_datasets/planetoid", name=dataset, split="public")[0]
129 | # st = Planetoid(root="datasets_py_geom_format/planetoid", name=dataset, split="public", pre_transform=build_pyg_struc_multigraph)[0]
130 |
131 | return og, None
132 |
133 |
134 | # def structure_edge_weight_threshold(data, threshold):
135 | # data = copy.deepcopy(data)
136 | # mask = data.edge_weight > threshold
137 | # data.edge_weight = data.edge_weight[mask]
138 | # data.edge_index = data.edge_index[:, mask]
139 | # data.edge_color = data.edge_color[mask]
140 | # return data
141 |
142 |
143 | def add_original_graph(og_data, st_data, weight=1.0):
144 | st_data = copy.deepcopy(st_data)
145 | e_i = torch.cat((og_data.edge_index, st_data.edge_index), dim=1)
146 | st_data.edge_color = st_data.edge_color + 1
147 | e_c = torch.cat((torch.zeros(og_data.edge_index.shape[1], dtype=torch.long), st_data.edge_color), dim=0)
148 | e_w = torch.cat((torch.ones(og_data.edge_index.shape[1], dtype=torch.float) * weight, st_data.edge_weight), dim=0)
149 | st_data.edge_index = e_i
150 | st_data.edge_color = e_c
151 | st_data.edge_weight = e_w
152 | return st_data
153 |
154 |
155 | def NormalizeTensor(data):
156 | return (data - torch.min(data)) / (torch.max(data) - torch.min(data))
157 |
158 |
159 | def cosinSim(x_hat):
160 | x_norm = torch.norm(x_hat, p=2, dim=1)
161 | nume = torch.mm(x_hat, x_hat.t())
162 | deno = torch.ger(x_norm, x_norm)
163 | cosine_similarity = nume / deno
164 | return cosine_similarity
165 |
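    | # cosinSim builds the full pairwise cosine-similarity matrix; an equivalent and
    | # numerically safer one-liner (illustrative sketch) would be:
    | #   x_n = torch.nn.functional.normalize(x_hat, p=2, dim=1); sims = x_n @ x_n.t()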
166 |
167 | def train_finetuning_class(model, train_data, mask, optimizer, device,device_2, g, adj_org, M, trans_logM, sim, phi, B,
168 | k_transition, pre_train, current_epoch):
169 | # pre_train = 0
170 | model.train()
171 | optimizer.zero_grad()
172 | mask = mask
173 | true_label = train_data.y
174 |
175 | out = model(g.to(device), adj_org.cuda(), sim.cuda(), phi.cuda(), B.cuda(), k_transition,
176 | current_epoch, device,device_2)
177 |
178 | # loss = F.nll_loss(out[mask], true_label[mask])
179 | criterion = torch.nn.CrossEntropyLoss()
180 | # print(out[mask])
181 | # print( true_label[mask])
182 | # raise SystemExit()
183 | loss = criterion(out[mask].cpu(), true_label[mask])
184 | # print(out)
185 | # print(true_label)
186 | # raise SystemExit()
187 | pred = out.max(1)[1][mask].cpu()
188 | acc = pred.eq(train_data.y[mask]).sum().item() / len(train_data.y[mask])
189 | loss.backward()
190 | optimizer.step()
191 | # print('training loss')
192 | return loss.item(), acc
193 |
194 |
195 | def train(model, train_data, mask, optimizer, device, device_2, g, adj_org, trans_logM, sim, phi, B, k_transition, current_epoch,
196 | alpha_1, alpha_2, alpha_3):
197 | model.train()
198 | optimizer.zero_grad()
199 | optimizer.zero_grad()
200 | mask = mask.to(device)
201 | if torch.cuda.device_count() > 1:
202 | h, x_hat, adj_sampled = model(g.to(device), adj_org.cuda(), sim.cuda(), phi.cuda(), B.cuda(),
203 | k_transition, current_epoch, device, device_2)
204 |
205 | loss_M = 0
206 | h = cosinSim(h).cpu()
207 | for i in range(k_transition):
208 | loss_M += torch.sum(((h - (torch.FloatTensor(trans_logM[i]))) ** 2))
209 |
210 | row_num, col_num = (torch.FloatTensor(trans_logM[i])).size()
211 | loss_M = loss_M / (k_transition * row_num * col_num)
212 |
213 | row_num, col_num = train_data.x.size()
214 |
215 | loss_X = F.mse_loss(x_hat.cpu(), train_data.x.cpu())
216 |
217 | # BCE:
218 |     if adj_sampled is None:
219 | adj_loss = 0
220 | else:
221 | adj_loss = F.binary_cross_entropy_with_logits(adj_sampled.cpu(), adj_org.cpu())
222 |
223 | loss_all = alpha_1 * loss_M + alpha_2 * loss_X + alpha_3 * adj_loss
224 | loss_all = loss_all.to(device_2)
225 | loss_all.backward()
226 | optimizer.step()
227 |
228 | return loss_M.item()
229 |
230 |
231 | import collections
232 | from collections import defaultdict
233 |
234 |
235 |
236 |
237 | import matplotlib.pylab as plt
238 |
239 |
240 |
241 |
242 | @torch.no_grad()
243 | def test(model, test_data, mask, device,device_2, g, adj_org, trans_logM, sim, phi, B, degree, k_transition, current_epoch,
244 | test_node_degree):
245 |
246 | model.eval()
247 | mask = mask.to(device)
248 | # n x 4 size
249 | logits = model(g.to(device), adj_org.to(device), sim.to(device), phi.to(device), B.to(device), k_transition,
250 | current_epoch, device,device_2)
251 |
252 |
253 | pred = torch.argmax(logits, dim=1)
254 | pred_before = pred
255 | pred = pred.to(device)[mask]
256 |
257 | acc = pred.eq(test_data.y.to(device)[mask]).sum().item() / len(test_data.y.to(device)[mask])
258 |
259 | pred = pred.cpu().numpy()
260 | y_true = test_data.y.to(device)[mask].cpu().numpy()
261 | f1 = f1_score(y_true, pred, average='micro')
262 |
263 | return acc, f1
264 |
265 |
266 | def filter_relations(data, num_relations, rel_last):
267 | if rel_last:
268 | l = data.edge_color.unique(sorted=True).tolist()
269 | mask_l = l[-num_relations:]
270 | mask = data.edge_color == mask_l[0]
271 | for c in mask_l[1:]:
272 | mask = mask | (data.edge_color == c)
273 | else:
274 | mask = data.edge_color < (num_relations + 1)
275 |
276 | data.edge_index = data.edge_index[:, mask]
277 | data.edge_weight = data.edge_weight[mask]
278 | data.edge_color = data.edge_color[mask]
279 | return data
280 |
281 |
282 | def make_masks(data, val_test_ratio=0.2, stratify=False):
283 | data = copy.deepcopy(data)
284 | n_nodes = data.x.shape[0]
285 | all_nodes_idx = np.arange(n_nodes)
286 | all_y = data.y.numpy()
287 | if stratify:
288 | train, test_idx, y_train, _ = train_test_split(
289 | all_nodes_idx, all_y, test_size=0.2, stratify=all_y)
290 |
291 | train_idx, val_idx, _, _ = train_test_split(
292 | train, y_train, test_size=0.25, stratify=y_train)
293 |
294 | else:
295 | val_test_num = 2 * int(val_test_ratio * data.x.shape[0])
296 | val_test_idx = np.random.choice(n_nodes, (val_test_num,), replace=False)
297 | val_idx = val_test_idx[:int(val_test_num / 2)]
298 | test_idx = val_test_idx[int(val_test_num / 2):]
299 |
300 | val_mask = np.zeros(n_nodes)
301 | val_mask[val_idx] = 1
302 | test_mask = np.zeros(n_nodes)
303 | test_mask[test_idx] = 1
304 | val_mask = torch.tensor(val_mask, dtype=torch.bool)
305 | test_mask = torch.tensor(test_mask, dtype=torch.bool)
306 | val_test_mask = val_mask | test_mask
307 | train_mask = ~val_test_mask
308 | data.train_mask = train_mask
309 | data.val_mask = val_mask
310 | data.test_mask = test_mask
311 | return data
312 |
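    | # Usage sketch (ratio value illustrative): with stratify=False (the default),
    | #   make_masks(data, val_test_ratio=0.2)
    | # draws 20% of nodes for validation and 20% for test uniformly at random and
    | # marks the remaining 60% as training nodes via boolean masks on `data`.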
313 |
314 | def create_self_loops(data):
315 | orig_relations = len(data.edge_color.unique())
316 | data.edge_index, data.edge_weight = add_remaining_self_loops(data.edge_index, data.edge_weight,
317 | fill_value=1.0)
318 | row, col = data.edge_index[0], data.edge_index[1]
319 | mask = row == col
320 | tmp = torch.full(mask.nonzero().shape, orig_relations + 1, dtype=torch.long).squeeze()
321 | data.edge_color = torch.cat([data.edge_color, tmp], dim=0)
322 | return data
323 |
324 |
325 | # create adjacency matrix and degree sequence
326 | def createA(E, n, m, undir=True):
327 | if undir:
328 | G = nx.Graph()
329 | G.add_nodes_from(range(n))
330 | G.add_edges_from(list(E))
331 | A = nx.to_scipy_sparse_matrix(G)
332 | else:
333 | A = sparse.coo_matrix((np.ones(m), (E[:, 0], E[:, 1])),
334 | shape=(n, n)).tocsc()
335 | degree = np.array(A.sum(1)).flatten()
336 |
337 | return A, degree
338 |
339 |
340 | def calculateRWRrange(A, degree, i, prs, n, trans=True, maxIter=1000):
341 | pr = prs[-1]
342 | D = sparse.diags(1. / degree, 0, format='csc')
343 | W = D * A
344 | diff = 1
345 | it = 1
346 |
347 | F = np.zeros(n)
348 | Fall = np.zeros((n, len(prs)))
349 | F[i] = 1
350 | Fall[i, :] = 1
351 | Fold = F.copy()
352 | T = F.copy()
353 |
354 | if trans:
355 | W = W.T
356 |
357 | oneminuspr = 1 - pr
358 |
359 | while diff > 1e-9:
360 | F = pr * W.dot(F)
361 | F[i] += oneminuspr
362 | Fall += np.outer((F - Fold), (prs / pr) ** it)
363 | T += (F - Fold) / ((it + 1) * (pr ** it))
364 |
365 | diff = np.sum((F - Fold) ** 2)
366 | it += 1
367 | if it > maxIter:
368 | print(i, "max iterations exceeded")
369 | diff = 0
370 | Fold = F.copy()
371 |
372 | return Fall, T, it
373 |
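    | # calculateRWRrange runs the personalized-PageRank power iteration
    | #   F <- pr * W F + (1 - pr) * e_i        (W transposed when trans=True)
    | # for a whole range of restart probabilities at once (Fall collects the per-pr
    | # series), stopping once the squared update norm drops below 1e-9 or maxIter is hit.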
374 |
375 | def localAssortF(G, M, pr=np.arange(0., 1., 0.1), undir=True, missingValue=-1, edge_attribute=None):
376 | n = len(M)
377 | ncomp = (M != missingValue).sum()
378 | # m = len(E)
379 | m = G.number_of_edges()
380 |
381 | if edge_attribute is None:
382 | A = nx.to_scipy_sparse_matrix(G, weight=None)
383 | else:
384 | A = nx.to_scipy_sparse_matrix(G, weight=edge_attribute)
385 |
386 | degree = np.array(A.sum(1)).flatten()
387 |
388 | D = sparse.diags(1. / degree, 0, format='csc')
389 | W = D.dot(A)
390 | c = len(np.unique(M))
391 | if ncomp < n:
392 | c -= 1
393 |
394 |
395 | Z = np.zeros(n)
396 | Z[M == missingValue] = 1.
397 | Z = W.dot(Z) / degree
398 |
399 | values = np.ones(ncomp)
400 | yi = (M != missingValue).nonzero()[0]
401 | yj = M[M != missingValue]
402 | Y = sparse.coo_matrix((values, (yi, yj)), shape=(n, c)).tocsc()
403 |
404 | assortM = np.empty((n, len(pr)))
405 | assortT = np.empty(n)
406 |
407 | eij_glob = np.array(Y.T.dot(A.dot(Y)).todense())
408 | eij_glob /= np.sum(eij_glob)
409 | ab_glob = np.sum(eij_glob.sum(1) * eij_glob.sum(0))
410 |
411 | WY = W.dot(Y).tocsc()
412 |
413 | print("start iteration")
414 |
415 | for i in tqdm(range(n)):
416 | pis, ti, it = calculateRWRrange(A, degree, i, pr, n)
417 |
418 | YPI = sparse.coo_matrix((ti[M != missingValue], (M[M != missingValue],
419 | np.arange(n)[M != missingValue])),
420 | shape=(c, n)).tocsr()
421 | e_gh = YPI.dot(WY).toarray()
422 | Z[i] = np.sum(e_gh)
423 | e_gh /= np.sum(e_gh)
424 | trace_e = np.trace(e_gh)
425 | assortT[i] = trace_e
426 |
427 | # assortM -= ab_glob
428 | # assortM /= (1. - ab_glob + 1e-200)
429 |
430 | assortT -= ab_glob
431 | assortT /= (1. - ab_glob + 1e-200)
432 |
433 | return assortM, assortT, Z
434 |
435 |
436 | def mixing_dict(xy, normalized=True):
437 | d = {}
438 | psum = 0.0
439 | for x, y, w in xy:
440 | if x not in d:
441 | d[x] = {}
442 | if y not in d:
443 | d[y] = {}
444 | v = d[x].get(y, 0)
445 | d[x][y] = v + w
446 | psum += w
447 |
448 | if normalized:
449 | for k, jdict in d.items():
450 | for j in jdict:
451 | jdict[j] /= psum
452 | return d
453 |
454 |
455 | def node_attribute_xy(G, attribute, edge_attribute=None, nodes=None):
456 | if nodes is None:
457 | nodes = set(G)
458 | else:
459 | nodes = set(nodes)
460 | Gnodes = G.nodes
461 | for u, nbrsdict in G.adjacency():
462 | if u not in nodes:
463 | continue
464 | uattr = Gnodes[u].get(attribute, None)
465 | if G.is_multigraph():
466 | raise NotImplementedError
467 | else:
468 | for v, eattr in nbrsdict.items():
469 | vattr = Gnodes[v].get(attribute, None)
470 | if edge_attribute is None:
471 | yield (uattr, vattr, 1)
472 | else:
473 | edge_data = G.get_edge_data(u, v)
474 | yield (uattr, vattr, edge_data[edge_attribute])
475 |
476 |
477 | def global_assortativity(networkx_graph, labels, weights=None):
478 | attr_dict = {}
479 | for i in networkx_graph.nodes():
480 | attr_dict[i] = labels[i]
481 |
482 | nx.set_node_attributes(networkx_graph, attr_dict, "label")
483 | if weights is None:
484 | xy_iter = node_attribute_xy(networkx_graph, "label", edge_attribute=None)
485 | d = mixing_dict(xy_iter)
486 | M = dict_to_numpy_array(d, mapping=None)
487 | s = (M @ M).sum()
488 | t = M.trace()
489 | r = (t - s) / (1 - s)
490 | else:
491 | edge_attr = {}
492 | for i, e in enumerate(networkx_graph.edges()):
493 | edge_attr[e] = weights[i]
494 | nx.set_edge_attributes(networkx_graph, edge_attr, "weight")
495 |
496 | xy_iter = node_attribute_xy(networkx_graph, "label", edge_attribute="weight")
497 | d = mixing_dict(xy_iter)
498 | M = dict_to_numpy_array(d, mapping=None)
499 | s = (M @ M).sum()
500 | t = M.trace()
501 | r = (t - s) / (1 - s)
502 |
503 | return r, M
504 |
505 |
506 | def local_assortativity(networkx_graph, labels, weights=None):
507 | if weights is None:
508 | assort_m, assort_t, z = localAssortF(networkx_graph, np.array(labels))
509 | else:
510 | edge_attr = {}
511 | for i, e in enumerate(networkx_graph.edges()):
512 | edge_attr[e] = weights[i]
513 | nx.set_edge_attributes(networkx_graph, edge_attr, "weight")
514 | assort_m, assort_t, z = localAssortF(networkx_graph, np.array(labels), edge_attribute="weight")
515 |
516 | return assort_m, assort_t, z
517 |
518 |
519 | # (model, data, train_mask, optimizer, device, g=g, M = M, trans_logM = logM, pre_train=0, adj = adj, n_edges= n_edges)
520 | def train_finetuning_cluster(model, train_data, mask, optimizer, device,device_2, g, M, trans_logM, sim, phi, B, k_transition,
521 | pre_train, adj, d, n_edges, current_epoch):
522 | model.train()
523 | optimizer.zero_grad()
524 | mask = mask.to(device)
525 | k = torch.numel(train_data.y.unique())
526 |
527 | C = model(g.to(device), adj.cuda(), sim.cuda(), phi.cuda(), B.cuda(), k_transition, current_epoch, device,device_2)
528 |
529 | n = g.number_of_nodes()
530 | # adj = torch.FloatTensor(adj)
531 | C = C.cpu()
532 | d = torch.FloatTensor(d).unsqueeze(1)
533 | C_t = C.t()
534 | # Computes the size [k, k] pooled graph as S^T*A*S in two multiplications.
535 | graph_pooled = torch.mm(C_t, adj.cpu())
536 | graph_pooled = torch.mm(graph_pooled, C)
537 |
538 | # Left part is [k, 1] tensor.
539 | normalizer_left = torch.mm(C_t, d)
540 |
541 | # Right part is [1, k] tensor.
542 | normalizer_right = torch.mm(d.t(), C)
543 |
544 | normalizer = torch.mm(normalizer_left, normalizer_right) / 2 / n_edges
545 |
546 | spectral_loss = - torch.trace(graph_pooled - normalizer) / 2 / n_edges
547 |
548 | cluster_sizes = torch.sum(C, axis=0) # Size [k]
549 | collapse_loss = (k / (2 * n)) * (torch.norm(cluster_sizes))
550 |
551 | loss = spectral_loss + collapse_loss # + wrt
552 |
553 | loss.backward()
554 | optimizer.step()
555 | return loss.item()
556 |
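    | # The two loss terms above follow a DMoN-style modularity objective (sketch):
    | #   spectral_loss = -Tr(C^T A C - C^T d d^T C / 2m) / 2m
    | #   collapse_loss = (k / 2n) * || sum_i C_i ||_2
    | # where C is the soft cluster assignment, d the degree vector, and m the edge count.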
557 |
558 | from sklearn.metrics import f1_score
559 | from torch_geometric.utils import to_networkx
560 | from sklearn.metrics.cluster import normalized_mutual_info_score
561 | from metrics import accuracy_score, precision, modularity, conductance, recall
562 |
563 |
564 | # (best_model, data, train_mask, device, g=g, M=M)
565 | @torch.no_grad()
566 | def test_cluster(model, test_data, mask, optimizer, device,device_2, g,
567 | M, trans_logM, sim, phi, B, k_transition, pre_train, adj, d, n_edges, current_epoch):
568 | # print('testing code')
569 | model.eval()
570 | mask = mask.to(device)
571 | # n x 4 size
572 |
573 | logits = model(g.to(device), adj, sim.to(device), phi.to(device), B.to(device), k_transition, current_epoch, device,device_2)
574 |
575 | pred = torch.argmax(logits, dim=1)
576 | pred = pred.to(device)[mask]
577 |
578 |
579 | pred = pred.cpu().numpy()
580 | y_true = test_data.y.to(device)[mask].cpu().numpy()
581 | f1 = f1_score(y_true, pred, average='micro')
582 |
583 | G_st = to_networkx(test_data, to_undirected=True)
584 | Adj = nx.adjacency_matrix(G_st)
585 | Adj = Adj.todense()
586 |
587 | acc = accuracy_score(y_true, pred)
588 | p = precision(y_true, pred)
589 | r = recall(y_true, pred)
590 | nmi = normalized_mutual_info_score(y_true, pred)
591 | q = modularity(Adj, pred)
592 | c = conductance(Adj, pred)
593 |
594 | return acc, p, r, nmi, q, c
595 |
--------------------------------------------------------------------------------