├── Figures └── CGT.jpg ├── datasets ├── __pycache__ │ ├── bgp.cpython-39.pyc │ ├── film.cpython-39.pyc │ ├── webkb.cpython-39.pyc │ ├── wiki.cpython-39.pyc │ ├── __init__.cpython-39.pyc │ └── airports.cpython-39.pyc ├── __init__.py ├── airports.py ├── webkb.py ├── wiki.py ├── film.py └── bgp.py ├── exp_logs ├── cora-20231228-141613.log ├── cora-20231228-141930.log ├── cora-20231228-142159.log └── cora-20231228-142343.log ├── CITATION.cff ├── LICENSE ├── metrics.py ├── environment.yml ├── util.py ├── README.md ├── dataset.py ├── exp.py ├── script_classification.py ├── models.py └── gnnutils.py /Figures/CGT.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NSLab-CUK/Community-aware-Graph-Transformer/HEAD/Figures/CGT.jpg -------------------------------------------------------------------------------- /datasets/__pycache__/bgp.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NSLab-CUK/Community-aware-Graph-Transformer/HEAD/datasets/__pycache__/bgp.cpython-39.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/film.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NSLab-CUK/Community-aware-Graph-Transformer/HEAD/datasets/__pycache__/film.cpython-39.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/webkb.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NSLab-CUK/Community-aware-Graph-Transformer/HEAD/datasets/__pycache__/webkb.cpython-39.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/wiki.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NSLab-CUK/Community-aware-Graph-Transformer/HEAD/datasets/__pycache__/wiki.cpython-39.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NSLab-CUK/Community-aware-Graph-Transformer/HEAD/datasets/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/airports.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NSLab-CUK/Community-aware-Graph-Transformer/HEAD/datasets/__pycache__/airports.cpython-39.pyc -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from datasets.airports import Airports 2 | from datasets.bgp import BGP 3 | from datasets.film import FilmNetwork 4 | from datasets.webkb import WebKB 5 | from datasets.wiki import WikipediaNetwork -------------------------------------------------------------------------------- /exp_logs/cora-20231228-141613.log: -------------------------------------------------------------------------------- 1 | INFO:root:Starting on device: cuda:0 2 | INFO:root:Config: Namespace(dataset='cora', model='Transformer', run_times=1, drop=0.5, custom_masks=True, device='cuda:0', 
device_2='cuda:0', lr=0.001, epochs=3, dims=64, out_size=64, k_transition=3, num_layers=4, num_heads=2, output_path='outputs/', pretrain_check=0, aug_check=1, sim_check=1, phi_check=1, alfa=0.1, beta=0.9, run_times_fine=300, index_excel=-1, file_name='outputs_excels/cora.xlsx', alpha_1=1, alpha_2=1, alpha_3=1, test_node_degree=0) 3 | -------------------------------------------------------------------------------- /exp_logs/cora-20231228-141930.log: -------------------------------------------------------------------------------- 1 | INFO:root:Starting on device: cuda:0 2 | INFO:root:Config: Namespace(dataset='cora', model='Transformer', run_times=1, drop=0.5, custom_masks=True, device='cuda:0', device_2='cuda:0', lr=0.001, epochs=3, dims=64, out_size=64, k_transition=3, num_layers=4, num_heads=2, output_path='outputs/', pretrain_check=0, aug_check=1, sim_check=1, phi_check=1, alfa=0.2, beta=0.95, run_times_fine=300, index_excel=-1, file_name='outputs_excels/cora.xlsx', alpha_1=1, alpha_2=1, alpha_3=1, test_node_degree=0) 3 | -------------------------------------------------------------------------------- /exp_logs/cora-20231228-142159.log: -------------------------------------------------------------------------------- 1 | INFO:root:Starting on device: cuda:0 2 | INFO:root:Config: Namespace(dataset='cora', model='Transformer', run_times=1, drop=0.5, custom_masks=True, device='cuda:0', device_2='cuda:0', lr=0.001, epochs=3, dims=64, out_size=64, k_transition=3, num_layers=4, num_heads=2, output_path='outputs/', pretrain_check=0, aug_check=1, sim_check=1, phi_check=1, alfa=0.1, beta=0.9, run_times_fine=200, index_excel=-1, file_name='outputs_excels/cora.xlsx', alpha_1=1, alpha_2=1, alpha_3=1, test_node_degree=0) 3 | -------------------------------------------------------------------------------- /exp_logs/cora-20231228-142343.log: -------------------------------------------------------------------------------- 1 | INFO:root:Starting on device: cuda:0 2 | INFO:root:Config: Namespace(dataset='cora', model='Transformer', run_times=1, drop=0.5, custom_masks=True, device='cuda:0', device_2='cuda:0', lr=0.001, epochs=3, dims=64, out_size=64, k_transition=3, num_layers=4, num_heads=2, output_path='outputs/', pretrain_check=0, aug_check=1, sim_check=1, phi_check=1, alfa=0.1, beta=0.9, run_times_fine=200, index_excel=-1, file_name='outputs_excels/cora.xlsx', alpha_1=1, alpha_2=1, alpha_3=1, test_node_degree=0) 3 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.0.0 2 | date-released: 2025-04 3 | message: "If you use this software, please cite it as below." 
4 | authors: 5 | - family-names: "Hoang" 6 | given-names: "Van Thuy" 7 | - family-names: "Jeon" 8 | given-names: "Hyeon-Ju" 9 | - family-names: "Lee" 10 | given-names: "O-Joun" 11 | title: "Mitigating Degree Bias in Graph Representation Learning With Learnable Structural Augmentation and Structural Self-Attention" 12 | url: "https://ieeexplore.ieee.org/document/10974679" 13 | preferred-citation: 14 | type: article 15 | journal: "IEEE Transactions on Network Science and Engineering" 16 | authors: 17 | - family-names: "Hoang" 18 | given-names: "Van Thuy" 19 | - family-names: "Jeon" 20 | given-names: "Hyeon-Ju" 21 | - family-names: "Lee" 22 | given-names: "O-Joun" 23 | title: "Mitigating Degree Bias in Graph Representation Learning With Learnable Structural Augmentation and Structural Self-Attention" 24 | url: "https://ieeexplore.ieee.org/document/10974679" 25 | year: 2025 26 | publisher: "IEEE" 27 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 NS Lab@CUK 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /datasets/airports.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | 3 | import networkx as nx 4 | import numpy as np 5 | import torch 6 | from torch_geometric.data import InMemoryDataset 7 | from torch_geometric.utils import * 8 | 9 | 10 | def get_degrees(G): 11 | num_nodes = G.number_of_nodes() 12 | return np.array([G.degree[i] for i in range(num_nodes)]) 13 | 14 | 15 | class Airports(InMemoryDataset): 16 | def __init__(self, root, dataset_name, transform=None, pre_transform=None): 17 | self.dataset_name = dataset_name 18 | self.dump_location = "raw_data_src/airports_dataset_dump" 19 | super(Airports, self).__init__(root,transform, pre_transform) 20 | self.data, self.slices = torch.load(self.processed_paths[0]) 21 | 22 | @property 23 | def raw_file_names(self): 24 | return [self.dataset_name+'-airports.edgelist', 'labels-'+self.dataset_name+'-airports.txt'] 25 | 26 | @property 27 | def processed_file_names(self): 28 | return 'data.pt' 29 | 30 | def download(self): 31 | for name in self.raw_file_names: 32 | source = self.dump_location + '/' + name 33 | shutil.copy(source, self.raw_dir) 34 | 35 | def process(self): 36 | 37 | fin_labels = open(self.raw_paths[1]) 38 | labels = [] 39 | node_id_mapping = dict() 40 | node_id_labels_dict = dict() 41 | for new_id, line in enumerate(fin_labels.readlines()[1:]): # first line is header so ignore 42 | old_id, label = line.strip().split() 43 | labels.append(int(label)) 44 | node_id_mapping[old_id] = new_id 45 | node_id_labels_dict[new_id] = int(label) 46 | fin_labels.close() 47 | 48 | edges = [] 49 | fin_edges = open(self.raw_paths[0]) 50 | for line in fin_edges.readlines(): 51 | node1, node2 = line.strip().split()[:2] 52 | edges.append([node_id_mapping[node1], node_id_mapping[node2]]) 53 | fin_edges.close() 54 | 55 | networkx_graph = nx.Graph(edges) 56 | 57 | print("No. of Nodes: ",networkx_graph.number_of_nodes()) 58 | print("No. 
of edges: ",networkx_graph.number_of_edges()) 59 | 60 | 61 | attr = {} 62 | for node in networkx_graph.nodes(): 63 | deg = networkx_graph.degree(node) 64 | attr[node] = {"y": node_id_labels_dict[node], "x": [float(deg)]} 65 | nx.set_node_attributes(networkx_graph, attr) 66 | data = from_networkx(networkx_graph) 67 | 68 | 69 | if self.pre_filter is not None: 70 | data = self.pre_filter(data) 71 | 72 | if self.pre_transform is not None: 73 | data = self.pre_transform(data) 74 | 75 | torch.save(self.collate([data]), self.processed_paths[0]) 76 | 77 | 78 | def __repr__(self): 79 | return '{}()'.format(self.__class__.__name__) -------------------------------------------------------------------------------- /datasets/webkb.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import numpy as np 4 | import torch 5 | from torch_geometric.data import InMemoryDataset, download_url, Data 6 | from torch_geometric.utils import to_undirected 7 | from torch_sparse import coalesce 8 | 9 | 10 | class WebKB(InMemoryDataset): 11 | 12 | url = 'https://raw.githubusercontent.com/graphdml-uiuc-jlu/geom-gcn/master' 13 | 14 | def __init__(self, root, name, transform=None, pre_transform=None): 15 | self.name = name.lower() 16 | assert self.name in ['cornell', 'texas', 'wisconsin'] 17 | 18 | super(WebKB, self).__init__(root, transform, pre_transform) 19 | self.data, self.slices = torch.load(self.processed_paths[0]) 20 | 21 | @property 22 | def raw_dir(self): 23 | return osp.join(self.root, self.name, 'raw') 24 | 25 | @property 26 | def processed_dir(self): 27 | return osp.join(self.root, self.name, 'processed') 28 | 29 | @property 30 | def raw_file_names(self): 31 | return ['out1_node_feature_label.txt', 'out1_graph_edges.txt'] + [ 32 | '{}_split_0.6_0.2_{}.npz'.format(self.name, i) for i in range(10) 33 | ] 34 | 35 | @property 36 | def processed_file_names(self): 37 | return 'data.pt' 38 | 39 | def download(self): 40 | for f in self.raw_file_names[:2]: 41 | download_url(f'{self.url}/new_data/{self.name}/{f}', self.raw_dir) 42 | for f in self.raw_file_names[2:]: 43 | download_url(f'{self.url}/splits/{f}', self.raw_dir) 44 | 45 | def process(self): 46 | with open(self.raw_paths[0], 'r') as f: 47 | data = f.read().split('\n')[1:-1] 48 | x = [[float(v) for v in r.split('\t')[1].split(',')] for r in data] 49 | x = torch.tensor(x, dtype=torch.float) 50 | 51 | y = [int(r.split('\t')[2]) for r in data] 52 | y = torch.tensor(y, dtype=torch.long) 53 | 54 | with open(self.raw_paths[1], 'r') as f: 55 | data = f.read().split('\n')[1:-1] 56 | data = [[int(v) for v in r.split('\t')] for r in data] 57 | edge_index = torch.tensor(data, dtype=torch.long).t().contiguous() 58 | edge_index, _ = coalesce(edge_index, None, x.size(0), x.size(0)) 59 | 60 | train_masks, val_masks, test_masks = [], [], [] 61 | for f in self.raw_paths[2:]: 62 | tmp = np.load(f) 63 | train_masks += [torch.from_numpy(tmp['train_mask']).to(torch.bool)] 64 | val_masks += [torch.from_numpy(tmp['val_mask']).to(torch.bool)] 65 | test_masks += [torch.from_numpy(tmp['test_mask']).to(torch.bool)] 66 | train_mask = torch.stack(train_masks, dim=1) 67 | val_mask = torch.stack(val_masks, dim=1) 68 | test_mask = torch.stack(test_masks, dim=1) 69 | 70 | data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask, 71 | val_mask=val_mask, test_mask=test_mask) 72 | 73 | data.edge_index = to_undirected(data.edge_index) 74 | data = data if self.pre_transform is None else self.pre_transform(data) 75 | 
torch.save(self.collate([data]), self.processed_paths[0]) 76 | 77 | def __repr__(self): 78 | return '{}()'.format(self.name) -------------------------------------------------------------------------------- /datasets/wiki.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import numpy as np 4 | import torch 5 | from torch_geometric.data import InMemoryDataset, download_url, Data 6 | from torch_geometric.utils import to_undirected 7 | from torch_sparse import coalesce 8 | 9 | 10 | class WikipediaNetwork(InMemoryDataset): 11 | 12 | url = 'https://raw.githubusercontent.com/graphdml-uiuc-jlu/geom-gcn/master' 13 | 14 | def __init__(self, root, name, transform=None, pre_transform=None): 15 | self.name = name.lower() 16 | assert self.name in ['chameleon', 'squirrel'] 17 | 18 | super(WikipediaNetwork, self).__init__(root, transform, pre_transform) 19 | self.data, self.slices = torch.load(self.processed_paths[0]) 20 | 21 | @property 22 | def raw_dir(self): 23 | return osp.join(self.root, self.name, 'raw') 24 | 25 | @property 26 | def processed_dir(self): 27 | return osp.join(self.root, self.name, 'processed') 28 | 29 | @property 30 | def raw_file_names(self): 31 | return ['out1_node_feature_label.txt', 'out1_graph_edges.txt'] + [ 32 | '{}_split_0.6_0.2_{}.npz'.format(self.name, i) for i in range(10) 33 | ] 34 | 35 | @property 36 | def processed_file_names(self): 37 | return 'data.pt' 38 | 39 | def download(self): 40 | for f in self.raw_file_names[:2]: 41 | download_url(f'{self.url}/new_data/{self.name}/{f}', self.raw_dir) 42 | for f in self.raw_file_names[2:]: 43 | download_url(f'{self.url}/splits/{f}', self.raw_dir) 44 | 45 | def process(self): 46 | with open(self.raw_paths[0], 'r') as f: 47 | data = f.read().split('\n')[1:-1] 48 | x = [[float(v) for v in r.split('\t')[1].split(',')] for r in data] 49 | x = torch.tensor(x, dtype=torch.float) 50 | 51 | y = [int(r.split('\t')[2]) for r in data] 52 | y = torch.tensor(y, dtype=torch.long) 53 | 54 | with open(self.raw_paths[1], 'r') as f: 55 | data = f.read().split('\n')[1:-1] 56 | data = [[int(v) for v in r.split('\t')] for r in data] 57 | edge_index = torch.tensor(data, dtype=torch.long).t().contiguous() 58 | edge_index, _ = coalesce(edge_index, None, x.size(0), x.size(0)) 59 | 60 | train_masks, val_masks, test_masks = [], [], [] 61 | for f in self.raw_paths[2:]: 62 | tmp = np.load(f) 63 | train_masks += [torch.from_numpy(tmp['train_mask']).to(torch.bool)] 64 | val_masks += [torch.from_numpy(tmp['val_mask']).to(torch.bool)] 65 | test_masks += [torch.from_numpy(tmp['test_mask']).to(torch.bool)] 66 | train_mask = torch.stack(train_masks, dim=1) 67 | val_mask = torch.stack(val_masks, dim=1) 68 | test_mask = torch.stack(test_masks, dim=1) 69 | 70 | data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask, 71 | val_mask=val_mask, test_mask=test_mask) 72 | 73 | data.edge_index = to_undirected(data.edge_index) 74 | 75 | data = data if self.pre_transform is None else self.pre_transform(data) 76 | torch.save(self.collate([data]), self.processed_paths[0]) 77 | 78 | def __repr__(self): 79 | return '{}()'.format(self.name) 80 | -------------------------------------------------------------------------------- /datasets/film.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import numpy as np 4 | import torch 5 | from torch_geometric.data import InMemoryDataset, download_url, Data 6 | from torch_geometric.utils import 
to_undirected 7 | from torch_sparse import coalesce 8 | 9 | 10 | class FilmNetwork(InMemoryDataset): 11 | 12 | url = 'https://raw.githubusercontent.com/graphdml-uiuc-jlu/geom-gcn/master' 13 | 14 | def __init__(self, root, name, transform=None, pre_transform=None): 15 | self.name = name.lower() 16 | assert self.name in ['film'] 17 | 18 | super(FilmNetwork, self).__init__(root, transform, pre_transform) 19 | self.data, self.slices = torch.load(self.processed_paths[0]) 20 | 21 | @property 22 | def raw_dir(self): 23 | return osp.join(self.root, self.name, 'raw') 24 | 25 | @property 26 | def processed_dir(self): 27 | return osp.join(self.root, self.name, 'processed') 28 | 29 | @property 30 | def raw_file_names(self): 31 | return ['out1_node_feature_label.txt', 'out1_graph_edges.txt'] + [ 32 | '{}_split_0.6_0.2_{}.npz'.format(self.name, i) for i in range(10) 33 | ] 34 | 35 | @property 36 | def processed_file_names(self): 37 | return 'data.pt' 38 | 39 | def download(self): 40 | for f in self.raw_file_names[:2]: 41 | download_url(f'{self.url}/new_data/{self.name}/{f}', self.raw_dir) 42 | for f in self.raw_file_names[2:]: 43 | download_url(f'{self.url}/splits/{f}', self.raw_dir) 44 | 45 | def process(self): 46 | with open(self.raw_paths[0], 'r') as f: 47 | data = f.read().split('\n')[1:-1] 48 | 49 | x = [] 50 | for r in data: 51 | f = np.zeros(932, dtype=np.float) 52 | idx = list(map(int, r.split('\t')[1].split(","))) 53 | f[idx] = 1.0 54 | x.append(f) 55 | 56 | x = torch.tensor(x, dtype=torch.float) 57 | 58 | y = [int(r.split('\t')[2]) for r in data] 59 | y = torch.tensor(y, dtype=torch.long) 60 | 61 | with open(self.raw_paths[1], 'r') as f: 62 | data = f.read().split('\n')[1:-1] 63 | data = [[int(v) for v in r.split('\t')] for r in data] 64 | edge_index = torch.tensor(data, dtype=torch.long).t().contiguous() 65 | edge_index, _ = coalesce(edge_index, None, x.size(0), x.size(0)) 66 | 67 | train_masks, val_masks, test_masks = [], [], [] 68 | for f in self.raw_paths[2:]: 69 | tmp = np.load(f) 70 | train_masks += [torch.from_numpy(tmp['train_mask']).to(torch.bool)] 71 | val_masks += [torch.from_numpy(tmp['val_mask']).to(torch.bool)] 72 | test_masks += [torch.from_numpy(tmp['test_mask']).to(torch.bool)] 73 | train_mask = torch.stack(train_masks, dim=1) 74 | val_mask = torch.stack(val_masks, dim=1) 75 | test_mask = torch.stack(test_masks, dim=1) 76 | 77 | data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask, 78 | val_mask=val_mask, test_mask=test_mask) 79 | 80 | data.edge_index = to_undirected(data.edge_index) 81 | 82 | data = data if self.pre_transform is None else self.pre_transform(data) 83 | torch.save(self.collate([data]), self.processed_paths[0]) 84 | 85 | def __repr__(self): 86 | return '{}()'.format(self.name) 87 | -------------------------------------------------------------------------------- /datasets/bgp.py: -------------------------------------------------------------------------------- 1 | import json 2 | import shutil 3 | 4 | import networkx as nx 5 | import numpy as np 6 | import torch 7 | from torch_geometric.data import InMemoryDataset 8 | from torch_geometric.utils import * 9 | 10 | 11 | def convert_ndarray(x): 12 | y = list(range(len(x))) 13 | for k, v in x.items(): 14 | y[int(k)] = v 15 | return np.array(y) 16 | 17 | 18 | def check_rm(neighbors_set, unlabeled_nodes): 19 | for node in neighbors_set: 20 | if node not in unlabeled_nodes: 21 | return False 22 | return True 23 | 24 | 25 | def rm_useless(G, feats, class_map, unlabeled_nodes, num_layers): 26 | # find useless 
nodes 27 | print('start to check and remove {} unlabeled nodes'.format(len(unlabeled_nodes))) 28 | 29 | rm_nodes = unlabeled_nodes 30 | if len(rm_nodes): 31 | for node in rm_nodes: 32 | G.remove_node(node) 33 | G_new = nx.relabel.convert_node_labels_to_integers(G, ordering='sorted') 34 | feats = np.delete(feats, rm_nodes, 0) 35 | class_map = np.delete(class_map, rm_nodes, 0) 36 | print('remove {} '.format(len(rm_nodes)), 'useless unlabeled nodes') 37 | return G_new, feats, class_map 38 | 39 | class BGP(InMemoryDataset): 40 | 41 | 42 | def __init__(self, root, transform=None, pre_transform=None): 43 | 44 | self.dump_location = "raw_data_src/bgp_data_dump" 45 | super(BGP, self).__init__(root, transform, pre_transform) 46 | self.data, self.slices = torch.load(self.processed_paths[0]) 47 | 48 | @property 49 | def raw_file_names(self): 50 | return ['as-G.json', 'as-class_map.json', 'as-feats.npy','as-feats_t.npy', 'as-edge_list'] 51 | 52 | @property 53 | def processed_file_names(self): 54 | return 'data.pt' 55 | 56 | def download(self): 57 | for name in self.raw_file_names: 58 | # download_url('{}/{}'.format(self.url, name), self.raw_dir) 59 | source = self.dump_location + '/' + name 60 | shutil.copy(source, self.raw_dir) 61 | 62 | def process(self): 63 | G = nx.json_graph.node_link_graph(json.load(open(self.raw_paths[0])), False) 64 | class_map = json.load(open(self.raw_paths[1])) 65 | feats = np.load(self.raw_paths[2]) 66 | feats_t = np.load(self.raw_paths[3]) 67 | 68 | 69 | train_nodes = [n for n in G.nodes() if not G.nodes[n]['test'] and not G.nodes[n]['val']] 70 | val_nodes = [n for n in G.nodes() if not G.nodes[n]['test'] and G.nodes[n]['val']] 71 | test_nodes = [n for n in G.nodes() if G.nodes[n]['test'] and not G.nodes[n]['val']] 72 | unlabeled_nodes = [n for n in G.nodes() if G.nodes[n]['test'] and G.nodes[n]['val']] 73 | class_map = convert_ndarray(class_map) 74 | 75 | G, feats, class_map = rm_useless(G, feats, class_map, unlabeled_nodes, 1) 76 | train_nodes = [n for n in G.nodes() if not G.nodes[n]['test'] and not G.nodes[n]['val']] 77 | val_nodes = [n for n in G.nodes() if not G.nodes[n]['test'] and G.nodes[n]['val']] 78 | test_nodes = [n for n in G.nodes() if G.nodes[n]['test'] and not G.nodes[n]['val']] 79 | unlabeled_nodes = [n for n in G.nodes() if G.nodes[n]['test'] and G.nodes[n]['val']] 80 | 81 | 82 | data = from_networkx(G) 83 | data.train_mask = ~(data.test | data.val) 84 | data.val_mask = data.val 85 | data.test_mask = data.test 86 | data.test = None 87 | data.val = None 88 | data.x = torch.FloatTensor(feats) 89 | data.y = torch.LongTensor(np.argmax(class_map, axis=1)) 90 | 91 | if self.pre_filter is not None: 92 | data = self.pre_filter(data) 93 | 94 | if self.pre_transform is not None: 95 | data.edge_index = to_undirected(data.edge_index) 96 | data = self.pre_transform(data) 97 | 98 | torch.save(self.collate([data]), self.processed_paths[0]) 99 | 100 | 101 | def __repr__(self): 102 | return '{}()'.format(self.__class__.__name__) -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from sklearn.metrics import confusion_matrix 6 | from sklearn.metrics import f1_score 7 | import numpy as np 8 | 9 | 10 | def MAE(scores, targets): 11 | MAE = F.l1_loss(scores, targets) 12 | MAE = MAE.detach().item() 13 | return MAE 14 | 15 | 16 | def accuracy_TU(scores, targets): 17 | 
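    """Return the raw count of correct predictions in the batch.

    `scores` holds per-class scores; the predicted class is their argmax.
    Unlike `accuracy_CITATION_GRAPH` below, the result is not divided by the
    number of targets, so callers typically accumulate it over batches and
    normalize by the dataset size themselves.
    """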
scores = scores.detach().argmax(dim=1) 18 | acc = (scores == targets).float().sum().item() 19 | return acc 20 | 21 | 22 | def accuracy_MNIST_CIFAR(scores, targets): 23 | scores = scores.detach().argmax(dim=1) 24 | acc = (scores == targets).float().sum().item() 25 | return acc 26 | 27 | 28 | def accuracy_CITATION_GRAPH(scores, targets): 29 | scores = scores.detach().argmax(dim=1) 30 | acc = (scores == targets).float().sum().item() 31 | acc = acc / len(targets) 32 | return acc 33 | 34 | 35 | def accuracy_SBM(scores, targets): 36 | S = targets.cpu().numpy() 37 | C = np.argmax(torch.nn.Softmax(dim=1)(scores).cpu().detach().numpy(), axis=1) 38 | CM = confusion_matrix(S, C).astype(np.float32) 39 | nb_classes = CM.shape[0] 40 | targets = targets.cpu().detach().numpy() 41 | nb_non_empty_classes = 0 42 | pr_classes = np.zeros(nb_classes) 43 | for r in range(nb_classes): 44 | cluster = np.where(targets == r)[0] 45 | if cluster.shape[0] != 0: 46 | pr_classes[r] = CM[r, r] / float(cluster.shape[0]) 47 | if CM[r, r] > 0: 48 | nb_non_empty_classes += 1 49 | else: 50 | pr_classes[r] = 0.0 51 | acc = 100. * np.sum(pr_classes) / float(nb_classes) 52 | return acc 53 | 54 | 55 | def binary_f1_score(scores, targets): 56 | """Computes the F1 score using scikit-learn for binary class labels. 57 | 58 | Returns the F1 score for the positive class, i.e. labelled '1'. 59 | """ 60 | y_true = targets.cpu().numpy() 61 | y_pred = scores.argmax(dim=1).cpu().numpy() 62 | return f1_score(y_true, y_pred, average='binary') 63 | 64 | 65 | def accuracy_VOC(scores, targets): 66 | scores = scores.detach().argmax(dim=1).cpu() 67 | targets = targets.cpu().detach().numpy() 68 | acc = f1_score(scores, targets, average='weighted') 69 | return acc 70 | 71 | 72 | # clustering 73 | from sklearn.metrics.cluster import contingency_matrix 74 | 75 | 76 | def _compute_counts(y_true, y_pred): # TODO(tsitsulin): add docstring pylint: disable=missing-function-docstring 77 | contingency = contingency_matrix(y_true, y_pred) 78 | same_class_true = np.max(contingency, 1) 79 | same_class_pred = np.max(contingency, 0) 80 | diff_class_true = contingency.sum(axis=1) - same_class_true 81 | diff_class_pred = contingency.sum(axis=0) - same_class_pred 82 | total = contingency.sum() 83 | 84 | true_positives = (same_class_true * (same_class_true - 1)).sum() 85 | false_positives = (diff_class_true * same_class_true * 2).sum() 86 | false_negatives = (diff_class_pred * same_class_pred * 2).sum() 87 | true_negatives = total * (total - 1) - true_positives - false_positives - false_negatives 88 | 89 | return true_positives, false_positives, false_negatives, true_negatives 90 | 91 | 92 | def modularity(adjacency, clusters): 93 | """Computes graph modularity. 94 | Args: 95 | adjacency: Input graph in terms of its sparse adjacency matrix. 96 | clusters: An (n,) int cluster vector. 97 | 98 | Returns: 99 | The value of graph modularity. 100 | https://en.wikipedia.org/wiki/Modularity_(networks) 101 | """ 102 | degrees = adjacency.sum(axis=0).A1 103 | n_edges = degrees.sum() # Note that it's actually 2*n_edges. 
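    # The loop below accumulates, for each cluster c, the quantity
    #     sum(A[c, c]) - (sum(deg[c]) ** 2) / (2m)
    # where sum(A[c, c]) counts intra-cluster adjacency entries (2 * internal
    # edges for an undirected graph) and 2m == n_edges as computed above.
    # Dividing the accumulated total by 2m at the end gives Newman's modularity
    #     Q = sum_c [ e_c / m - (K_c / (2m)) ** 2 ],
    # with e_c the number of internal edges and K_c the total degree of cluster c.
    # `adjacency` is assumed to be a scipy.sparse matrix (hence `.A1` above).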
104 | result = 0 105 | for cluster_id in np.unique(clusters): 106 | cluster_indices = np.where(clusters == cluster_id)[0] 107 | adj_submatrix = adjacency[cluster_indices, :][:, cluster_indices] 108 | degrees_submatrix = degrees[cluster_indices] 109 | result += np.sum(adj_submatrix) - (np.sum(degrees_submatrix) ** 2) / n_edges 110 | return result / n_edges 111 | 112 | 113 | def precision(y_true, y_pred): 114 | true_positives, false_positives, _, _ = _compute_counts(y_true, y_pred) 115 | return true_positives / (true_positives + false_positives) 116 | 117 | 118 | def recall(y_true, y_pred): 119 | true_positives, _, false_negatives, _ = _compute_counts(y_true, y_pred) 120 | return true_positives / (true_positives + false_negatives) 121 | 122 | 123 | def accuracy_score(y_true, y_pred): 124 | true_positives, false_positives, false_negatives, true_negatives = _compute_counts( 125 | y_true, y_pred) 126 | return (true_positives + true_negatives) / (true_positives + false_positives + false_negatives + true_negatives) 127 | 128 | 129 | def conductance(adjacency, clusters): # TODO(tsitsulin): add docstring pylint: disable=missing-function-docstring 130 | inter = 0 131 | intra = 0 132 | cluster_idx = np.zeros(adjacency.shape[0], dtype=bool) 133 | for cluster_id in np.unique(clusters): 134 | cluster_idx[:] = 0 135 | cluster_idx[np.where(clusters == cluster_id)[0]] = 1 136 | adj_submatrix = adjacency[cluster_idx, :] 137 | inter += np.sum(adj_submatrix[:, cluster_idx]) 138 | intra += np.sum(adj_submatrix[:, ~cluster_idx]) 139 | return intra / (inter + intra) 140 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: gt01 2 | channels: 3 | - pytorch 4 | - dglteam 5 | - dglteam/label/cu117 6 | - nvidia 7 | - defaults 8 | dependencies: 9 | - _libgcc_mutex=0.1=main 10 | - _openmp_mutex=5.1=1_gnu 11 | - appdirs=1.4.4=pyhd3eb1b0_0 12 | - blas=1.0=mkl 13 | - brotlipy=0.7.0=py39h27cfd23_1003 14 | - bzip2=1.0.8=h7b6447c_0 15 | - ca-certificates=2023.05.30=h06a4308_0 16 | - certifi=2023.5.7=py39h06a4308_0 17 | - cffi=1.15.1=py39h5eee18b_3 18 | - cryptography=39.0.1=py39h9ce1e76_1 19 | - cuda-cudart=11.7.99=0 20 | - cuda-cupti=11.7.101=0 21 | - cuda-libraries=11.7.1=0 22 | - cuda-nvrtc=11.7.99=0 23 | - cuda-nvtx=11.7.91=0 24 | - cuda-runtime=11.7.1=0 25 | - dgl=1.1.0.cu117=py39_0 26 | - dgl-cuda10.2=0.9.1post1=py39_0 27 | - ffmpeg=4.3=hf484d3e_0 28 | - freetype=2.12.1=h4a9f257_0 29 | - giflib=5.2.1=h5eee18b_3 30 | - gmp=6.2.1=h295c915_3 31 | - gmpy2=2.1.2=py39heeb90bb_0 32 | - gnutls=3.6.15=he1e5248_0 33 | - jinja2=3.1.2=py39h06a4308_0 34 | - jpeg=9e=h5eee18b_1 35 | - lame=3.100=h7b6447c_0 36 | - lcms2=2.12=h3be6417_0 37 | - ld_impl_linux-64=2.38=h1181459_1 38 | - lerc=3.0=h295c915_0 39 | - libcublas=11.10.3.66=0 40 | - libcufft=10.7.2.124=h4fbf590_0 41 | - libcufile=1.6.1.9=0 42 | - libcurand=10.3.2.106=0 43 | - libcusolver=11.4.0.1=0 44 | - libcusparse=11.7.4.91=0 45 | - libdeflate=1.17=h5eee18b_0 46 | - libffi=3.4.4=h6a678d5_0 47 | - libgcc-ng=11.2.0=h1234567_1 48 | - libgfortran-ng=11.2.0=h00389a5_1 49 | - libgfortran5=11.2.0=h1234567_1 50 | - libgomp=11.2.0=h1234567_1 51 | - libiconv=1.16=h7f8727e_2 52 | - libidn2=2.3.4=h5eee18b_0 53 | - libnpp=11.7.4.75=0 54 | - libnvjpeg=11.8.0.2=0 55 | - libpng=1.6.39=h5eee18b_0 56 | - libstdcxx-ng=11.2.0=h1234567_1 57 | - libtasn1=4.19.0=h5eee18b_0 58 | - libtiff=4.5.0=h6a678d5_2 59 | - libunistring=0.9.10=h27cfd23_0 60 | - 
libwebp=1.2.4=h11a3e52_1 61 | - libwebp-base=1.2.4=h5eee18b_1 62 | - lz4-c=1.9.4=h6a678d5_0 63 | - mkl-service=2.4.0=py39h5eee18b_1 64 | - mkl_fft=1.3.6=py39h417a72b_1 65 | - mkl_random=1.2.2=py39h417a72b_1 66 | - mpc=1.1.0=h10f8cd9_1 67 | - mpfr=4.0.2=hb69a4c5_1 68 | - ncurses=6.4=h6a678d5_0 69 | - nettle=3.7.3=hbbd107a_1 70 | - numpy-base=1.24.3=py39h060ed82_1 71 | - openh264=2.1.1=h4ff587b_0 72 | - openssl=3.0.8=h7f8727e_0 73 | - pooch=1.4.0=pyhd3eb1b0_0 74 | - pycparser=2.21=pyhd3eb1b0_0 75 | - pyopenssl=23.0.0=py39h06a4308_0 76 | - pysocks=1.7.1=py39h06a4308_0 77 | - python=3.9.16=h955ad1f_3 78 | - pytorch=2.0.0=py3.9_cuda11.7_cudnn8.5.0_0 79 | - pytorch-cuda=11.7=h778d358_5 80 | - pytorch-mutex=1.0=cuda 81 | - readline=8.2=h5eee18b_0 82 | - setuptools=67.8.0=py39h06a4308_0 83 | - sqlite=3.41.2=h5eee18b_0 84 | - tbb=2021.8.0=hdb19cb5_0 85 | - tk=8.6.12=h1ccaba5_0 86 | - torchtriton=2.0.0=py39 87 | - typing_extensions=4.5.0=py39h06a4308_0 88 | - wheel=0.38.4=py39h06a4308_0 89 | - xz=5.4.2=h5eee18b_0 90 | - zlib=1.2.13=h5eee18b_0 91 | - zstd=1.5.5=hc292b87_0 92 | - pip: 93 | - asgiref==3.7.2 94 | - backcall==0.2.0 95 | - chardet==3.0.4 96 | - charset-normalizer==3.1.0 97 | - cmake==3.26.4 98 | - contourpy==1.0.7 99 | - cycler==0.11.0 100 | - decorator==4.4.2 101 | - django==4.2.2 102 | - et-xmlfile==1.1.0 103 | - fastdtw==0.3.4 104 | - fastjsonschema==2.17.1 105 | - filelock==3.12.1 106 | - fonttools==4.39.4 107 | - idna==2.8 108 | - image==1.5.33 109 | - importlib-resources==5.12.0 110 | - intel-openmp==2023.1.0 111 | - ipython-genutils==0.2.0 112 | - joblib==1.2.0 113 | - kiwisolver==1.4.4 114 | - lit==16.0.5.post0 115 | - littleutils==0.2.2 116 | - markupsafe==2.1.3 117 | - matplotlib==3.7.1 118 | - mistune==0.8.4 119 | - mkl==2019.0 120 | - mpmath==1.3.0 121 | - networkx==2.5.1 122 | - numpy==1.22.4 123 | - oauthlib==3.2.2 124 | - ogb==1.3.6 125 | - openpyxl==3.1.2 126 | - outdated==0.2.2 127 | - packaging==23.1 128 | - pandas==2.0.2 129 | - pandocfilters==1.5.0 130 | - parso==0.8.3 131 | - pexpect==4.8.0 132 | - pickleshare==0.7.5 133 | - pillow==9.5.0 134 | - pip==23.1.2 135 | - prometheus-client==0.17.0 136 | - prompt-toolkit==2.0.10 137 | - protobuf==4.23.2 138 | - psutil==5.9.5 139 | - ptyprocess==0.7.0 140 | - pyasn1==0.5.0 141 | - pyg-lib==0.2.0+pt20cu117 142 | - pygments==2.15.1 143 | - pyparsing==3.0.9 144 | - pyrsistent==0.19.3 145 | - python-dateutil==2.8.2 146 | - pytz==2023.3 147 | - pyzmq==25.1.0 148 | - requests==2.31.0 149 | - retrying==1.3.4 150 | - scikit-learn==0.24.1 151 | - scipy==1.6.2 152 | - send2trash==1.8.2 153 | - six==1.15.0 154 | - sqlparse==0.4.4 155 | - sympy==1.12 156 | - testpath==0.6.0 157 | - threadpoolctl==3.1.0 158 | - torch==2.0.1+cu117 159 | - torch-cluster==1.6.1+pt20cu117 160 | - torch-geometric==2.3.1 161 | - torch-scatter==2.1.1+pt20cu117 162 | - torch-sparse==0.6.17+pt20cu117 163 | - torch-spline-conv==1.2.2+pt20cu117 164 | - torchaudio==2.0.2+cu117 165 | - torchvision==0.15.2+cu117 166 | - tornado==6.3.2 167 | - tqdm==4.60.0 168 | - traitlets==5.9.0 169 | - triton==2.0.0 170 | - typing-extensions==4.6.3 171 | - tzdata==2023.3 172 | - urllib3==1.25.11 173 | - wcwidth==0.2.6 174 | - webencodings==0.5.1 175 | - zipp==3.15.0 176 | prefix: /root/anaconda3/envs/gt01 177 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | import numpy as np 3 | from sklearn.preprocessing import normalize 4 | from 
sklearn.cluster import KMeans 5 | from itertools import chain 6 | import copy, torch, dgl 7 | 8 | 9 | def GetProbTranMat(Ak, num_node): 10 | num_node, num_node2 = Ak.shape 11 | if (num_node != num_node2): 12 | print('M must be a square matrix!') 13 | Ak_sum = np.sum(Ak, axis=0).reshape(1, -1) 14 | Ak_sum = np.repeat(Ak_sum, num_node, axis=0) 15 | probTranMat = np.log(np.divide(Ak, Ak_sum)) - np.log(1. / num_node) 16 | probTranMat[probTranMat < 0] = 0; # set zero for negative and -inf elements 17 | probTranMat[np.isnan(probTranMat)] = 0; # set zero for nan elements (the isolated nodes) 18 | return probTranMat 19 | 20 | 21 | def getM_logM(nx_g, num_nodes, kstep=3): 22 | tran_M = [] 23 | tran_logM = [] 24 | Adj = np.zeros((num_nodes, num_nodes)) 25 | for src in nx_g.nodes(): 26 | src_degree = nx_g.degree(src) 27 | for dst in nx_g.nodes(): 28 | if nx_g.has_edge(src, dst): 29 | Adj[src][dst] = round(1 / src_degree, 3) 30 | 31 | Ak = np.matrix(np.identity(num_nodes)) 32 | for i in range(kstep): 33 | Ak = np.dot(Ak, Adj) 34 | tran_M.append(Ak) 35 | probTranMat = GetProbTranMat(Ak, num_nodes) 36 | tran_logM.append(probTranMat) 37 | 38 | return tran_M, tran_logM 39 | 40 | 41 | def get_distance(deg_A, deg_B): 42 | damp = 1 / np.sqrt(deg_A * deg_B) 43 | return damp 44 | 45 | 46 | def get_B_sim_phi(nx_g, tran_M, num_nodes, n_class, X, kstep=5): 47 | print(f'processing get_B_sim_phi') 48 | count = 0 49 | B = np.zeros((num_nodes, num_nodes)) 50 | colour = np.zeros((num_nodes, num_nodes)) 51 | phi = np.zeros((num_nodes, num_nodes, 1)) 52 | sim = np.zeros((num_nodes, num_nodes, kstep)) 53 | 54 | trans_check = tran_M[kstep - 1] 55 | not_adj = tran_M[0] 56 | 57 | kmeans = KMeans(n_clusters=n_class, init='k-means++', max_iter=50, n_init=10, random_state=0) 58 | 59 | y_kmeans = kmeans.fit_predict(X) 60 | count = 0 61 | count_1 = 0 62 | for src in nx_g.nodes(): 63 | if count % 50 == 0: 64 | print(f' processing node_th {src}/{num_nodes}') 65 | for dst in nx_g.nodes(): 66 | 67 | if src == dst: 68 | continue 69 | 70 | if not_adj[src, dst] > 0: 71 | continue 72 | 73 | if colour[src, dst] == 1 or colour[src, dst] == 1: 74 | continue 75 | if trans_check[src, dst] > 0.001: 76 | 77 | src_d = nx_g.degree(src) 78 | dst_d = nx_g.degree(dst) 79 | 80 | if np.abs(src_d - dst_d) > 1: 81 | continue 82 | 83 | if y_kmeans[src] != y_kmeans[dst]: 84 | continue 85 | else: 86 | count_1 += 1 87 | d = get_distance(src_d, dst_d) 88 | # B i, j 89 | B[src, dst] = d 90 | B[dst, src] = d 91 | # phi i,j 92 | if phi[src, dst] == 0: 93 | phi[src, dst] = d 94 | phi[dst, src] = d 95 | 96 | colour[src, dst] = 1 97 | colour[dst, src] = 1 98 | B[src, src] = 0 99 | count += 1 100 | 101 | 102 | 103 | sim = compute_sim(tran_M, num_nodes, k_step=kstep) 104 | 105 | return B, sim, phi 106 | 107 | 108 | def compute_sim(tran_M, num_nodes, k_step=5): 109 | sim = np.zeros((num_nodes, num_nodes, k_step)) 110 | trans_check = tran_M[k_step - 1] 111 | 112 | for step in range(k_step): 113 | print(f'compute_sim transition step {step + 1}/{k_step}') 114 | colour = np.zeros((num_nodes, num_nodes)) 115 | trans_k = copy.deepcopy(tran_M[step]) 116 | trans_k[trans_k >= 0.001] = 1 117 | trans_k[trans_k < 0.001] = 0 118 | trans_k = np.array(trans_k) 119 | 120 | row_sums = trans_k.sum(axis=1) 121 | trans_mul = trans_k @ trans_k.T 122 | for i in range(num_nodes): 123 | 124 | for j in range(i + 1, num_nodes): 125 | if trans_check[i, j] < 0.0001: 126 | continue 127 | if colour[i, j] == 1 or colour[j, i] == 1: 128 | continue 129 | 130 | score = np.round(trans_mul[i, j] / 
(row_sums[i] + row_sums[j] - trans_mul[i, j]), 4) 131 | if score < 0.001: 132 | score = 0 133 | sim[i, j, step] = score 134 | sim[j, i, step] = score 135 | 136 | colour[i, j] = 1 137 | colour[j, i] = 1 138 | return sim 139 | 140 | 141 | def get_A_D(nx_g, num_nodes): 142 | num_edges = nx_g.number_of_edges() 143 | 144 | d = np.zeros((num_nodes)) 145 | 146 | Adj = np.zeros((num_nodes, num_nodes)) 147 | 148 | for src in nx_g.nodes(): 149 | src_degree = nx_g.degree(src) 150 | d[src] = src_degree 151 | for dst in nx_g.nodes(): 152 | if nx_g.has_edge(src, dst): 153 | Adj[src][dst] = 1 154 | 155 | 156 | return Adj, d, num_edges 157 | 158 | 159 | def load_dgl(nx_g, x, sim, phi): 160 | print('loading dgl...') 161 | count = 0 162 | edge_idx1 = [] 163 | edge_idx2 = [] 164 | for e in nx_g.edges: 165 | edge_idx1.append(e[0]) 166 | edge_idx2.append(e[1]) 167 | 168 | edge_idx1.append(e[1]) 169 | edge_idx2.append(e[0]) 170 | 171 | s_vals = [] 172 | phi_vals = [] 173 | for i in range(len(edge_idx1)): 174 | count += 1 175 | n1 = edge_idx1[i] 176 | n2 = edge_idx2[i] 177 | 178 | s = np.asarray(sim[n1][n2], dtype=float) 179 | s_vals.append(s) 180 | 181 | p = np.asarray(phi[n1][n2], dtype=float) 182 | phi_vals.append(p) 183 | 184 | print(f'networkx: number edges: {count}') 185 | s_vals = np.array(s_vals) 186 | phi_vals = np.array(phi_vals) 187 | 188 | s_vals[np.isnan(s_vals)] = 0 189 | s_vals = normalize(s_vals, axis=0, norm='max') 190 | phi_vals = normalize(phi_vals, axis=0, norm='max') 191 | 192 | s_vals = torch.tensor(s_vals) 193 | phi_vals = torch.tensor(phi_vals) 194 | 195 | g = dgl.graph((edge_idx1, edge_idx2)) 196 | 197 | s_vals[torch.isnan(s_vals)] = 0 198 | phi_vals[torch.isnan(phi_vals)] = 0 199 | 200 | g.ndata['x'] = x 201 | g.edata['sim'] = s_vals 202 | g.edata['phi'] = phi_vals 203 | 204 | print(f'loading dgl, done, DGL graph edges: {g.number_of_edges()}') 205 | return g 206 | 207 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Community-aware Graph Transformer 2 | 3 | Community-aware Graph Transformer (**CGT**) is a novel Graph Transformer model that utilizes community structures to address node degree biases in message-passing mechanism and developed by [NS Lab, CUK](https://nslab-cuk.github.io/) based on pure [PyTorch](https://github.com/pytorch/pytorch) backend. The paper is available on [arXiv](https://arxiv.org/abs/2312.16788). 4 | 5 |

6 | 7 | Python 8 | 9 | 10 | pytorch 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 |

20 | 21 |
22 | 23 | 24 | ## 1. Overview 25 | 26 | Recent augmentation-based methods showed that message-passing (MP) neural networks often perform poorly on low-degree nodes, leading to degree biases due to a lack of messages reaching low-degree nodes. Despite their success, most methods use heuristic or uniform random augmentations, which are non-differentiable and may not always generate valuable edges for learning representations. In this paper, we propose Community-aware Graph Transformers, namely **CGT**, to learn degree-unbiased representations based on learnable augmentations and graph transformers by extracting within community structures. We first design a learnable graph augmentation to generate more within-community edges connecting low-degree nodes through edge perturbation. Second, we propose an improved self-attention to learn underlying proximity and the roles of nodes within the community. Third, we propose a self-supervised learning task that could learn the representations to preserve the global graph structure and regularize the graph augmentations. Extensive experiments on various benchmark datasets showed CGT outperforms state-of-the-art baselines and significantly improves the node degree biases. 27 | 28 |
29 | 30 |

31 | Graph Transformer Architecture 32 |
33 | The overall architecture of Community-aware Graph Transformer. 34 |

35 | 36 | 37 | ## 2. Reproducibility 38 | 39 | ### Datasets and Tasks 40 | 41 | We used six publicly available datasets, which are grouped into three different domains, including citation network (Cora, Citeseer, and Pubmed datasets), Co-purchase network networks (Amazon Computers and Photo datasets), and reference network (WikiCS). The datasets are automatically downloaded from Pytorch Geometric. 42 | 43 | ### Requirements and Environment Setup 44 | 45 | The source code was developed in Python 3.8.8. CGT is built using Torch-geometric 2.3.1 and DGL 1.1.0. Please refer to the official websites for installation and setup. 46 | All the requirements are included in the ```environment.yml``` file. 47 | 48 | ``` 49 | # Conda installation 50 | 51 | # Install python environment 52 | 53 | conda env create -f environment.yml 54 | ``` 55 | ### Hyperparameters 56 | 57 | The following Options can be passed to exp.py: 58 | 59 | ```--dataset:``` The name of dataset inputs. For example: ```--dataset cora``` 60 | 61 | ```--lr:``` Learning rate for training the model. For example: ```--lr 0.001``` 62 | 63 | ```--epochs:``` Number of epochs for pre-training the model. For example: ```--epochs 500``` 64 | 65 | ```--run_times_fine:``` Number of epochs for fine-tuning the model. For example: ```--run_times_fine 500``` 66 | 67 | ```--layers:``` Number of layers for model training. For example: ```--layers 4``` 68 | 69 | ```--drop:``` Dropout rate. For example: ```--drop 0.5``` 70 | 71 | ```--dims:``` The dimmension of hidden vectors. For example: ```--dims 64```. 72 | 73 | ```--k_transition:``` The number of transition step. For example: ```--k_transition 3```. 74 | 75 | ```--alpha:``` Hyperparameters for degree-related score. For example: ```--alpha 0.1```. 76 | 77 | ```--beta:``` Hyperparameters for adjacency matrix score. For example: ```--beta 0.95```. 78 | 79 | ```--alpha_1:``` Hyperparameters for transition construction loss. For example: ```--alpha_1 0.5```. 80 | 81 | ```--alpha_2:``` Hyperparameters for feature construction loss. For example: ```--alpha_2 0.5```. 82 | 83 | ```--alpha_3:``` Hyperparameters for augmentation loss. For example: ```--alpha_3 0.5```. 84 | 85 | 86 | ### How to run 87 | 88 | The source code contains both pre-training and fine-tuning processes. 89 | The following commands will run the pre-training process and fine-tune the **CGT** on Cora dataset for both node classification and clustering tasks. 90 | 91 | ``` 92 | 93 | python exp.py --dataset cora 94 | 95 | ``` 96 | 97 | ## 3. 
Reference 98 | 99 | :page_with_curl: Paper [on IEEE TNSE](https://ieeexplore.ieee.org/xpl/RecentIssue.jsp?punumber=6488902): 100 | * [![DOI](http://img.shields.io/:DOI-10.1109/TNSE.2025.3563697-FAB70C?style=flat-square&logo=doi)](https://doi.org/10.1109/TNSE.2025.3563697) 101 | 102 | :page_with_curl: Paper [on arXiv](https://arxiv.org/): 103 | * [![arXiv](https://img.shields.io/badge/arXiv-2504.15075-b31b1b?style=flat-square&logo=arxiv&logoColor=red)](https://arxiv.org/abs/2504.15075) 104 | [![arXiv](https://img.shields.io/badge/arXiv--Previous-2312.16788-b31b1b?style=flat-square&logo=arxiv&logoColor=red)](https://arxiv.org/abs/2312.16788) 105 | 106 | :chart_with_upwards_trend: Experimental results [on Papers With Code](https://paperswithcode.com/): 107 | * [![PwC](https://custom-icon-badges.demolab.com/badge/Papers%20With%20Code-CGT-21CBCE?style=flat-square&logo=paperswithcode)](https://paperswithcode.com/paper/mitigating-degree-bias-in-graph) 108 | [![PwC](https://custom-icon-badges.demolab.com/badge/Papers%20With%20Code--Previous-CGT-21CBCE?style=flat-square&logo=paperswithcode)](https://paperswithcode.com/paper/mitigating-degree-biases-in-message-passing) 109 | 110 | :pencil: Blog post [on Network Science Lab](https://nslab-cuk.github.io/2023/08/17/UGT/): 111 | * [![Web](https://img.shields.io/badge/NS@CUK-Post-0C2E86?style=flat-square&logo=jekyll&logoColor=FFFFFF)](https://nslab-cuk.github.io/2023/12/27/CGT/) 112 | 113 | 114 | ## 4. Citing CGT 115 | 116 | Please cite our [paper](https://ieeexplore.ieee.org/document/10974679) if you find *CGT* useful in your work: 117 | ``` 118 | @Article{hoang2025mitigating_TNSE, 119 | author = {Van Thuy Hoang and Hyeon-Ju Jeon and O-Joun Lee}, 120 | journal = {IEEE Transactions on Network Science and Engineering}, 121 | title = {Mitigating Degree Bias in Graph Representation Learning with Learnable Structural Augmentation and Structural Self-Attention}, 122 | year = {2025}, 123 | issn = {2327-4697}, 124 | volume = {12}, 125 | number = {5}, 126 | pages = {3656--3670}, 127 | doi = {10.1109/TNSE.2025.3563697}, 128 | } 129 | 130 | @misc{hoang2025mitigating, 131 | title={Mitigating Degree Bias in Graph Representation Learning with Learnable Structural Augmentation and Structural Self-Attention}, 132 | author={Van Thuy Hoang and Hyeon-Ju Jeon and O-Joun Lee}, 133 | year={2025}, 134 | eprint={2504.15075}, 135 | archivePrefix={arXiv}, 136 | primaryClass={cs.AI} 137 | } 138 | 139 | @misc{hoang2023mitigating, 140 | title={Mitigating Degree Biases in Message Passing Mechanism by Utilizing Community Structures}, 141 | author={Van Thuy Hoang and O-Joun Lee}, 142 | year={2023}, 143 | eprint={2312.16788}, 144 | archivePrefix={arXiv}, 145 | primaryClass={cs.LG} 146 | } 147 | ``` 148 | 149 | Please take a look at our structure-preserving graph transformer model, [**UGT**](https://github.com/NSLab-CUK/Unified-Graph-Transformer), which has as high expessive power as 3-WL, together. 150 | 151 | ## 5. Contributors 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 |
160 | 161 | *** 162 | 163 | 164 | 165 | *** 166 | 167 | 168 | 169 | 170 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | #%% 2 | from typing import Optional, Callable 3 | 4 | import os.path as osp 5 | 6 | import torch 7 | import numpy as np 8 | 9 | from torch_geometric.utils import to_undirected 10 | from torch_geometric.data import InMemoryDataset, download_url, Data 11 | from torch_geometric.datasets import Planetoid, Amazon, WikipediaNetwork, Coauthor, WikiCS, SNAPDataset 12 | 13 | 14 | 15 | def mask_init(self, num_train_per_class=20, num_val_per_class=30, seed=12345): 16 | num_nodes = self.data.y.size(0) 17 | self.train_mask = torch.zeros([num_nodes], dtype=torch.bool) 18 | self.val_mask = torch.zeros([num_nodes], dtype=torch.bool) 19 | self.test_mask = torch.ones([num_nodes], dtype=torch.bool) 20 | np.random.seed(seed) 21 | for c in range(self.num_classes): 22 | samples_idx = (self.data.y == c).nonzero().squeeze() 23 | perm = list(range(samples_idx.size(0))) 24 | np.random.shuffle(perm) 25 | perm = torch.as_tensor(perm).long() 26 | self.train_mask[samples_idx[perm][:num_train_per_class]] = True 27 | self.val_mask[samples_idx[perm][num_train_per_class:num_train_per_class + num_val_per_class]] = True 28 | self.test_mask[self.train_mask] = False 29 | self.test_mask[self.val_mask] = False 30 | 31 | 32 | def mask_getitem(self, datum): 33 | datum.__setitem__("train_mask", self.train_mask) 34 | datum.__setitem__("val_mask", self.val_mask) 35 | datum.__setitem__("test_mask", self.test_mask) 36 | return datum 37 | 38 | 39 | class DigitizeY(object): 40 | 41 | def __init__(self, bins, transform_y=None): 42 | self.bins = np.asarray(bins) 43 | self.transform_y = transform_y 44 | 45 | def __call__(self, data): 46 | y = self.transform_y(data.y).numpy() 47 | digitized_y = np.digitize(y, self.bins) 48 | data.y = torch.from_numpy(digitized_y) 49 | return data 50 | 51 | def __repr__(self): 52 | return '{}(bins={})'.format(self.__class__.__name__, self.bins.tolist()) 53 | 54 | 55 | 56 | class WikipediaNetwork_crocodile(InMemoryDataset): 57 | r"""The Wikipedia networks introduced in the 58 | `"Multi-scale Attributed Node Embedding" 59 | `_ paper. 60 | Nodes represent web pages and edges represent hyperlinks between them. 61 | Node features represent several informative nouns in the Wikipedia pages. 62 | The task is to predict the average daily traffic of the web page. 63 | Args: 64 | root (string): Root directory where the dataset should be saved. 65 | name (string): The name of the dataset (:obj:`"chameleon"`, 66 | :obj:`"crocodile"`, :obj:`"squirrel"`). 67 | geom_gcn_preprocess (bool): If set to :obj:`True`, will load the 68 | pre-processing data as introduced in the `"Geom-GCN: Geometric 69 | Graph Convolutional Networks" _`, 70 | in which the average monthly traffic of the web page is converted 71 | into five categories to predict. 72 | If set to :obj:`True`, the dataset :obj:`"crocodile"` is not 73 | available. 74 | transform (callable, optional): A function/transform that takes in an 75 | :obj:`torch_geometric.data.Data` object and returns a transformed 76 | version. The data object will be transformed before every access. 77 | (default: :obj:`None`) 78 | pre_transform (callable, optional): A function/transform that takes in 79 | an :obj:`torch_geometric.data.Data` object and returns a 80 | transformed version. 
The data object will be transformed before 81 | being saved to disk. (default: :obj:`None`) 82 | """ 83 | 84 | raw_url = 'https://graphmining.ai/datasets/ptg/wiki' 85 | processed_url = ('https://raw.githubusercontent.com/graphdml-uiuc-jlu/' 86 | 'geom-gcn/master') 87 | 88 | def __init__(self, root: str, name: str, geom_gcn_preprocess: bool = True, 89 | transform: Optional[Callable] = None, 90 | pre_transform: Optional[Callable] = None): 91 | self.name = name.lower() 92 | self.geom_gcn_preprocess = geom_gcn_preprocess 93 | assert self.name in ['chameleon', 'crocodile', 'squirrel'] 94 | if geom_gcn_preprocess and self.name == 'crocodile': 95 | raise AttributeError("The dataset 'crocodile' is not available in " 96 | "case 'geom_gcn_preprocess=True'") 97 | super().__init__(root, transform, pre_transform) 98 | self.data, self.slices = torch.load(self.processed_paths[0]) 99 | 100 | @property 101 | def raw_dir(self) -> str: 102 | if self.geom_gcn_preprocess: 103 | return osp.join(self.root, self.name, 'geom_gcn', 'raw') 104 | else: 105 | return osp.join(self.root, self.name, 'raw') 106 | 107 | @property 108 | def processed_dir(self) -> str: 109 | if self.geom_gcn_preprocess: 110 | return osp.join(self.root, self.name, 'geom_gcn', 'processed') 111 | else: 112 | return osp.join(self.root, self.name, 'processed') 113 | 114 | @property 115 | def raw_file_names(self) -> str: 116 | if self.geom_gcn_preprocess: 117 | return (['out1_node_feature_label.txt', 'out1_graph_edges.txt'] + 118 | [f'{self.name}_split_0.6_0.2_{i}.npz' for i in range(10)]) 119 | else: 120 | return f'{self.name}.npz' 121 | 122 | @property 123 | def processed_file_names(self) -> str: 124 | return 'data.pt' 125 | 126 | def download(self): 127 | if self.geom_gcn_preprocess: 128 | for filename in self.raw_file_names[:2]: 129 | url = f'{self.processed_url}/new_data/{self.name}/{filename}' 130 | download_url(url, self.raw_dir) 131 | for filename in self.raw_file_names[2:]: 132 | url = f'{self.processed_url}/splits/{filename}' 133 | download_url(url, self.raw_dir) 134 | else: 135 | download_url(f'{self.raw_url}/{self.name}.npz', self.raw_dir) 136 | 137 | def process(self): 138 | if self.geom_gcn_preprocess: 139 | with open(self.raw_paths[0], 'r') as f: 140 | data = f.read().split('\n')[1:-1] 141 | x = [[float(v) for v in r.split('\t')[1].split(',')] for r in data] 142 | x = torch.tensor(x, dtype=torch.float) 143 | y = [int(r.split('\t')[2]) for r in data] 144 | y = torch.tensor(y, dtype=torch.long) 145 | 146 | with open(self.raw_paths[1], 'r') as f: 147 | data = f.read().split('\n')[1:-1] 148 | data = [[int(v) for v in r.split('\t')] for r in data] 149 | edge_index = torch.tensor(data, dtype=torch.long).t().contiguous() 150 | # edge_index = to_undirected(edge_index, num_nodes=x.size(0)) 151 | print('test') 152 | train_masks, val_masks, test_masks = [], [], [] 153 | for filepath in self.raw_paths[2:]: 154 | f = np.load(filepath) 155 | train_masks += [torch.from_numpy(f['train_mask'])] 156 | val_masks += [torch.from_numpy(f['val_mask'])] 157 | test_masks += [torch.from_numpy(f['test_mask'])] 158 | train_mask = torch.stack(train_masks, dim=1).to(torch.bool) 159 | val_mask = torch.stack(val_masks, dim=1).to(torch.bool) 160 | test_mask = torch.stack(test_masks, dim=1).to(torch.bool) 161 | 162 | data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask, 163 | val_mask=val_mask, test_mask=test_mask) 164 | 165 | else: 166 | data = np.load(self.raw_paths[0], 'r', allow_pickle=True) 167 | x = 
torch.from_numpy(data['features']).to(torch.float) 168 | edge_index = torch.from_numpy(data['edges']).to(torch.long) 169 | edge_index = edge_index.t().contiguous() 170 | # edge_index = to_undirected(edge_index, num_nodes=x.size(0)) 171 | y = torch.from_numpy(data['label']).to(torch.float) 172 | train_mask = torch.from_numpy(data['train_mask']).to(torch.bool) 173 | test_mask = torch.from_numpy(data['test_mask']).to(torch.bool) 174 | val_mask = torch.from_numpy(data['val_mask']).to(torch.bool) 175 | data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask, 176 | val_mask=val_mask, test_mask=test_mask) 177 | if self.pre_transform is not None: 178 | data = self.pre_transform(data) 179 | 180 | torch.save(self.collate([data]), self.processed_paths[0]) 181 | 182 | #%% 183 | from ogb.nodeproppred import NodePropPredDataset 184 | 185 | def even_quantile_labels(vals, nclasses, verbose=True): 186 | """ partitions vals into nclasses by a quantile based split, 187 | where the first class is less than the 1/nclasses quantile, 188 | second class is less than the 2/nclasses quantile, and so on 189 | 190 | vals is np array 191 | returns an np array of int class labels 192 | """ 193 | label = -1 * np.ones(vals.shape[0], dtype=np.int) 194 | interval_lst = [] 195 | lower = -np.inf 196 | for k in range(nclasses - 1): 197 | upper = np.nanquantile(vals, (k + 1) / nclasses) 198 | interval_lst.append((lower, upper)) 199 | inds = (vals >= lower) * (vals < upper) 200 | label[inds] = k 201 | lower = upper 202 | label[vals >= lower] = nclasses - 1 203 | interval_lst.append((lower, np.inf)) 204 | if verbose: 205 | print('Class Label Intervals:') 206 | for class_idx, interval in enumerate(interval_lst): 207 | print(f'Class {class_idx}: [{interval[0]}, {interval[1]})]') 208 | return label 209 | 210 | def load_arxiv_year_dataset(root): 211 | ogb_dataset = NodePropPredDataset(name='ogbn-arxiv',root=root) 212 | graph = ogb_dataset.graph 213 | graph['edge_index'] = torch.as_tensor(graph['edge_index']) 214 | graph['node_feat'] = torch.as_tensor(graph['node_feat']) 215 | 216 | label = even_quantile_labels(graph['node_year'].flatten(), 5, verbose=False) 217 | label = torch.as_tensor(label).reshape(-1, 1) 218 | import os 219 | split_idx_lst = load_fixed_splits("arxiv-year",os.path.join(root,"splits")) 220 | 221 | train_mask = torch.stack([split["train"] for split in split_idx_lst],dim=1) 222 | val_mask = torch.stack([split["valid"] for split in split_idx_lst],dim=1) 223 | test_mask = torch.stack([split["test"] for split in split_idx_lst],dim=1) 224 | data = Data(x=graph["node_feat"],y=torch.squeeze(label.long()),edge_index=graph["edge_index"],\ 225 | train_mask=train_mask,val_mask=val_mask,test_mask=test_mask) 226 | return data 227 | 228 | 229 | 230 | def load_fixed_splits(dataset,split_dir): 231 | """ loads saved fixed splits for dataset 232 | """ 233 | name = dataset 234 | import os 235 | splits_lst = np.load(os.path.join(split_dir,"{}-splits.npy".format(name)), allow_pickle=True) 236 | for i in range(len(splits_lst)): 237 | for key in splits_lst[i]: 238 | if not torch.is_tensor(splits_lst[i][key]): 239 | splits_lst[i][key] = torch.as_tensor(splits_lst[i][key]) 240 | return splits_lst 241 | # %% -------------------------------------------------------------------------------- /exp.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import logging 4 | import math 5 | import time 6 | from pathlib import Path 7 | from torch_geometric.utils import 
to_networkx 8 | import networkx as nx 9 | import numpy as np 10 | import scipy.sparse as sp 11 | import torch 12 | from tqdm import tqdm 13 | from sklearn.preprocessing import normalize 14 | import os 15 | import random 16 | import dgl 17 | import warnings 18 | warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated') 19 | warnings.filterwarnings('ignore', category=RuntimeWarning, message='scipy._lib.messagestream.MessageStream') 20 | from gnnutils import make_masks, train, test, add_original_graph, load_webkb, load_planetoid, load_wiki, load_bgp, \ 21 | load_film, load_airports, load_amazon, load_coauthor, load_WikiCS, load_crocodile, load_Cora_ML 22 | 23 | from util import get_B_sim_phi, getM_logM, load_dgl, get_A_D 24 | 25 | from models import Transformer 26 | 27 | MODEl_DICT = {"Transformer": Transformer} 28 | from script_classification import run_node_classification, run_node_clustering, update_evaluation_value 29 | 30 | 31 | def filter_rels(data, r): 32 | data = copy.deepcopy(data) 33 | mask = data.edge_color <= r 34 | data.edge_index = data.edge_index[:, mask] 35 | data.edge_weight = data.edge_weight[mask] 36 | data.edge_color = data.edge_color[mask] 37 | return data 38 | 39 | 40 | def run(i, data, num_features, num_classes, g, M, adj_org, adj, logM, B, sim, phi, degree, n_edges, aug_check, 41 | sim_check, phi_check, pretrain_check,test_node_degree): 42 | if args.model in MODEl_DICT: 43 | model = MODEl_DICT[args.model](num_features, 44 | args.out_size, 45 | num_classes, 46 | hidden_dim=args.dims, 47 | num_layers=args.num_layers, 48 | num_heads=args.num_heads, 49 | k_transition=args.k_transition, 50 | aug_check=aug_check, 51 | sim_check=sim_check, 52 | phi_check=phi_check, 53 | alfa=args.alfa, beta=args.beta, 54 | ) 55 | dataset_1 = ['cora', 'citeseer', 'Photo','WikiCS'] 56 | 57 | if args.dataset in dataset_1: 58 | model.to(device) 59 | else: 60 | if torch.cuda.device_count() > 1: 61 | 62 | id_2 = int ((str(args.device_2)).split(":")[1]) 63 | print(f'Running multi-GPUs {id_2}') 64 | model = torch.nn.DataParallel(model, device_ids=[id_2]) 65 | 66 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=5e-5) 67 | 68 | if args.custom_masks: 69 | # create random mask 70 | data = make_masks(data, val_test_ratio = 0.0) 71 | train_mask = data.train_mask 72 | val_mask = data.val_mask 73 | test_mask = data.test_mask 74 | best_val_acc = 0 75 | 76 | pat = 20 77 | best_model = model 78 | graph_name = args.dataset 79 | best_loss = 100000000 80 | 81 | best_epoch = 0 82 | 83 | num_epochs = args.epochs + 1 84 | 85 | if args.pretrain_check == 0: 86 | print('pretrain_check: False, fine tunning') 87 | epoch = 1 88 | torch.save(model, 89 | '{}{}_{}_{}_{}_{}_{}_{}.pt'.format(args.output_path, args.dataset, args.lr, args.dims, 90 | args.num_layers, args.k_transition, args.alfa, args.beta)) 91 | else: # pre-training 92 | print(f'pre-training') 93 | for epoch in range(1, num_epochs): 94 | 95 | train_loss = train(model, data, train_mask, optimizer, device,device_2, g=g, adj_org=adj_org, 96 | trans_logM=logM, sim=sim, phi=phi, B=B, k_transition=args.k_transition, current_epoch=epoch, 97 | alpha_1=args.alpha_1, alpha_2=args.alpha_2, alpha_3=args.alpha_3) 98 | 99 | if best_loss >= train_loss: 100 | best_model = model 101 | best_epoch = epoch 102 | best_loss = train_loss 103 | 104 | if epoch - best_epoch > 200: 105 | break 106 | if epoch % 1 == 0: 107 | print('Epoch: {:02d}, Best Epoch {:02d}, Train Loss: {:0.4f}'.format(epoch, best_epoch, train_loss)) 108 
| 109 | print(' saving model and embeddings') 110 | torch.save(best_model, 111 | '{}{}_{}_{}_{}_{}_{}_{}.pt'.format(args.output_path, args.dataset, args.lr, args.dims, 112 | args.num_layers, args.k_transition, args.alfa, args.beta)) 113 | time.sleep(1) 114 | run_node_classification(args, args.index_excel, args.dataset, args.output_path, args.file_name, data, num_features, 115 | args.out_size, 116 | num_classes, g, adj_org, M, logM, sim, phi, B, degree, args.k_transition, device,device_2, 117 | args.run_times_fine, current_epoch=epoch, aug_check=aug_check,sim_check=sim_check,phi_check=phi_check,test_node_degree = test_node_degree) 118 | 119 | time.sleep(1) 120 | 121 | run_node_clustering(args, args.index_excel, args.dataset, args.output_path, args.file_name, data, num_features, 122 | args.out_size, 123 | num_classes, g, M, logM, sim, phi, B, args.k_transition, device,device_2, args.run_times_fine, adj, 124 | degree, n_edges, current_epoch=epoch, aug_check=aug_check,sim_check=sim_check,phi_check=phi_check) 125 | 126 | time.sleep(1) 127 | 128 | 129 | import collections 130 | from collections import defaultdict 131 | 132 | 133 | def main(): 134 | timestr = time.strftime("%Y%m%d-%H%M%S") 135 | log_file = args.dataset + "-" + timestr + ".log" 136 | Path("./exp_logs").mkdir(parents=True, exist_ok=True) 137 | logging.basicConfig(filename="exp_logs/" + log_file, filemode="w", level=logging.INFO) 138 | logging.info("Starting on device: %s", device) 139 | logging.info("Config: %s ", args) 140 | isbgp = False 141 | if args.dataset in ['cora', 'citeseer', 'pubmed']: 142 | og_data, _ = load_planetoid(args.dataset) 143 | 144 | elif args.dataset in ['Computers', 'Photo']: 145 | print('computers and photo processing...') 146 | assert args.custom_masks == True 147 | og_data, _ = load_amazon(args.dataset) 148 | elif args.dataset in ['WikiCS']: 149 | print('WikiCS processing...') 150 | assert args.custom_masks == True 151 | og_data, _ = load_WikiCS(args.dataset) 152 | else: 153 | raise NotImplementedError 154 | 155 | data = og_data 156 | 157 | num_classes = len(data.y.unique()) 158 | 159 | num_features = data.x.shape[1] 160 | nx_g = to_networkx(data, to_undirected=True) 161 | 162 | print(f'Loading dataset: {args.dataset}, data zize: {data.x.size()}') 163 | 164 | 165 | num_of_nodes = nx_g.number_of_nodes() 166 | print(f'number of nodes: {num_of_nodes}') 167 | 168 | adj, degree, n_edges = get_A_D(nx_g, num_of_nodes) 169 | 170 | ### load information 171 | 172 | path = "pts/" + args.dataset + "_kstep_" + str(args.k_transition) + ".pt" 173 | if not os.path.exists(path): 174 | M, logM, B, sim, phi = load_bias(nx_g, num_of_nodes, num_classes, data.x) 175 | 176 | M = torch.from_numpy(np.array(M)).float() 177 | logM = torch.from_numpy(np.array(logM)).float() 178 | B = torch.from_numpy(B).float() 179 | sim = torch.from_numpy(np.array(sim)).float() 180 | phi = torch.from_numpy(phi).float() 181 | 182 | print('saving M, logM,B, sim, phi') 183 | save_info(M, logM, B, sim, phi) 184 | 185 | else: 186 | print('file exist, loading M, logM, B, sim, phi') 187 | M, logM, B, sim, phi = load_info(path) 188 | 189 | sim[torch.isnan(sim)] = 0 190 | 191 | g = load_dgl(nx_g, data.x, sim, phi) 192 | 193 | adj_org = torch.from_numpy(adj) 194 | 195 | runs_acc = [] 196 | 197 | for i in tqdm(range(args.run_times)): 198 | acc = run(i, data, num_features, num_classes, g, M, adj_org, adj, logM, B, sim, phi, degree, n_edges, 199 | args.aug_check, args.sim_check, args.phi_check, args.pretrain_check, args.test_node_degree) 200 | 
runs_acc.append(acc) 201 | 202 | 203 | def save_info(M, logM, B, sim, phi): 204 | path = "pts/" 205 | torch.save({"M": M, "logM": logM, "B": B, "sim": sim, "phi": phi}, 206 | path + args.dataset + "_kstep_" + str(args.k_transition) + '.pt') 207 | print('save_info done.') 208 | 209 | 210 | def load_info(path): 211 | dic = torch.load(path) 212 | M = dic['M'] 213 | logM = dic['logM'] 214 | B = dic['B'] 215 | sim = dic['sim'] 216 | 217 | phi = dic['phi'] 218 | print('load_info done.') 219 | return M, logM, B, sim, phi 220 | 221 | 222 | def load_bias(nx_g, num_nodes, n_class, X): 223 | M, logM = getM_logM(nx_g, num_nodes, kstep=args.k_transition) 224 | B, sim, phi = get_B_sim_phi(nx_g, M, num_nodes, n_class, X, kstep=args.k_transition) 225 | 226 | 227 | return M, logM, B, sim, phi 228 | 229 | if __name__ == '__main__': 230 | parser = argparse.ArgumentParser(description="Experiments") 231 | # 232 | parser.add_argument("--dataset", default="citeseer", help="Dataset") 233 | parser.add_argument("--model", default="Transformer", help="GNN Model") 234 | 235 | parser.add_argument("--run_times", type=int, default=1) 236 | 237 | parser.add_argument("--drop", type=float, default=0.5, help="dropout") 238 | parser.add_argument("--custom_masks", default=True, action='store_true', help="custom train/val/test masks") 239 | 240 | # adding args 241 | parser.add_argument("--device", default="cuda:0", help="GPU ids") 242 | parser.add_argument("--device_2", default="cuda:0", help="GPU ids") 243 | 244 | parser.add_argument("--lr", type=float, default=1e-3, help="learning rate") 245 | parser.add_argument("--epochs", type=int, default=3) 246 | 247 | parser.add_argument("--dims", type=int, default=64, help="hidden dims") 248 | parser.add_argument("--out_size", type=int, default=64, help="outsize dims") 249 | 250 | parser.add_argument("--k_transition", type=int, default=3) 251 | parser.add_argument("--num_layers", type=int, default=4) 252 | parser.add_argument("--num_heads", type=int, default=2) 253 | parser.add_argument("--output_path", default="outputs/", help="outputs model") 254 | 255 | parser.add_argument("--pretrain_check", type=int, default=0) 256 | parser.add_argument("--aug_check", type=int, default=1) 257 | parser.add_argument("--sim_check", type=int, default=1) 258 | parser.add_argument("--phi_check", type=int, default=1) 259 | 260 | parser.add_argument("--alfa", type=float, default=0.1) 261 | parser.add_argument("--beta", type=float, default=0.9) 262 | 263 | parser.add_argument("--run_times_fine", type=int, default=200) 264 | parser.add_argument("--index_excel", type=int, default="-1", help="index_excel") 265 | parser.add_argument("--file_name", default="outputs_excels/cora.xlsx", help="file_name dataset") 266 | 267 | parser.add_argument("--alpha_1", type=float, default=1) 268 | parser.add_argument("--alpha_2", type=float, default=1) 269 | parser.add_argument("--alpha_3", type=float, default=1) 270 | parser.add_argument("--test_node_degree", type=int, default= 0) 271 | 272 | 273 | args = parser.parse_args() 274 | print(args) 275 | device = torch.device(args.device) 276 | device_2 = torch.device(args.device_2) 277 | print(f"Using device:{device}, device_2: {device_2}") 278 | main() -------------------------------------------------------------------------------- /script_classification.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import logging 4 | import math 5 | import time 6 | from pathlib import Path 7 | import numpy as np 8 | 
import pandas as pd 9 | import matplotlib.pyplot as plt 10 | # import seaborn as sns 11 | import scipy.sparse as sp 12 | import os 13 | import os.path 14 | from torch import Tensor 15 | import torch_geometric 16 | from torch_geometric.utils import to_networkx 17 | from torch_geometric.datasets import Planetoid 18 | import networkx as nx 19 | from torch_geometric.data import Data 20 | from torch.utils.data import Dataset 21 | import dgl.sparse as dglsp 22 | import numpy as np 23 | import torch 24 | from tqdm import tqdm 25 | 26 | import dgl 27 | from dgl import LaplacianPE 28 | # from dgl.nn import LaplacianPosEnc 29 | import networkx as nx 30 | import numpy as np 31 | from sklearn.preprocessing import normalize 32 | from collections import defaultdict 33 | import itertools 34 | 35 | from collections import deque 36 | 37 | from gnnutils import make_masks, test, add_original_graph, load_webkb, load_planetoid, \ 38 | load_wiki, load_bgp, load_film, load_airports, train_finetuning_class, train_finetuning_cluster, test_cluster 39 | 40 | from collections import defaultdict 41 | from collections import deque 42 | from torch.utils.data import DataLoader, ConcatDataset 43 | import random as rnd 44 | import warnings 45 | 46 | warnings.filterwarnings("ignore", message="scipy._lib.messagestream.MessageStream size changed") 47 | warnings.filterwarnings("ignore", message="scipy._lib.messagestream.MessageStream size changed") 48 | 49 | from models import Transformer_class, Transformer_cluster 50 | 51 | MODEl_DICT = {"Transformer_class": Transformer_class} 52 | 53 | db_name = 0 54 | 55 | 56 | def update_evaluation_value(file_path, colume, row, value): 57 | try: 58 | df = pd.read_excel(file_path) 59 | 60 | df[colume][row] = value 61 | 62 | df.to_excel(file_path, sheet_name='data', index=False) 63 | 64 | return 65 | except: 66 | print("Error when saving results! 
Save again!") 67 | time.sleep(3) 68 | 69 | 70 | def run_node_classification(args, index_excel, ds_name, output_path, file_name, data_all, num_features, out_size, 71 | num_classes, g, adj_org, M, trans_logM, sim, phi, B, degree, k_transition, device,device_2, 72 | num_epochs, current_epoch, aug_check,sim_check,phi_check,test_node_degree): 73 | print("running run_node_classification") 74 | 75 | index_excel = index_excel 76 | 77 | if True: 78 | dataset = args.dataset 79 | lr = args.lr 80 | dims =args.dims 81 | out_size = args.out_size 82 | num_layers = args.num_layers 83 | k_transition = args.k_transition 84 | alfa =args.alfa 85 | beta = args.beta 86 | 87 | print(f"Node class process - {index_excel}") 88 | 89 | cp_filename = output_path + f'{dataset}_{lr}_{dims}_{num_layers}_{k_transition}_{alfa}_{beta}.pt' 90 | 91 | if not os.path.exists(cp_filename): 92 | print(f"run_node_classification: no file: {cp_filename}") 93 | print(f"run_node_classification: no file: {cp_filename}") 94 | return None 95 | 96 | runs_acc = [] 97 | for i in tqdm(range(1)): 98 | print(f'run_node_classification, run time: {i}') 99 | acc = run_epoch_node_classification(i, data_all, num_features, out_size, num_classes, 100 | g, adj_org, M, trans_logM, sim, phi, B, degree, k_transition, 101 | cp_filename, dims, 102 | num_layers, lr, device, device_2,num_epochs, current_epoch, aug_check,sim_check,phi_check,test_node_degree) 103 | runs_acc.append(acc) 104 | 105 | runs_acc = np.array(runs_acc) * 100 106 | 107 | 108 | final_msg = "Node Classification: Mean %0.4f, Std %0.4f" % (runs_acc.mean(), runs_acc.std()) 109 | print(final_msg) 110 | 111 | 112 | def run_epoch_node_classification(i, data, num_features, out_size, num_classes, 113 | g, adj_org, M, trans_logM, sim, phi, B, degree, k_transition, cp_filename, dims, 114 | num_layers, lr, device,device_2, num_epochs, current_epoch,aug_check,sim_check,phi_check,test_node_degree): 115 | graph_name = "" 116 | best_val_acc = 0 117 | best_model = None 118 | pat = 20 119 | best_epoch = 0 120 | 121 | # fine tuning & testing 122 | print('fine tuning ...') 123 | 124 | model = Transformer_class(num_features, out_size, num_classes, hidden_dim=dims, 125 | num_layers=num_layers, num_heads=4, graph_name=graph_name, 126 | cp_filename=cp_filename, aug_check = aug_check,sim_check = sim_check,phi_check = sim_check) #.to(device) 127 | 128 | dataset_1 = ['cora', 'citeseer', 'Photo','WikiCS'] 129 | for ds in dataset_1: 130 | if ds in cp_filename: 131 | model.to(device) 132 | else: 133 | if torch.cuda.device_count() > 1: 134 | 135 | id_2 = int(str(device_2).split(":")[1]) 136 | model = torch.nn.DataParallel(model, device_ids=[id_2]) 137 | 138 | best_model = model 139 | 140 | optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5) 141 | 142 | print("creating random mask") 143 | data = make_masks(data, val_test_ratio=0.2) 144 | train_mask = data.train_mask 145 | val_mask = data.val_mask 146 | test_mask = data.test_mask 147 | 148 | # save dataload 149 | g.ndata["train_mask"] = train_mask 150 | g.ndata["val_mask"] = val_mask 151 | g.ndata["test_mask"] = test_mask 152 | 153 | test_check = 0 154 | for epoch in range(1, num_epochs): 155 | current_epoch = epoch 156 | train_loss, train_acc = train_finetuning_class(model, data, train_mask, optimizer, device,device_2, g=g, adj_org=adj_org, 157 | M=M, trans_logM=trans_logM, sim=sim, phi=phi, B=B, 158 | k_transition=k_transition, 159 | pre_train=0, current_epoch=current_epoch) 160 | 161 | if epoch % 1 == 0: 162 | valid_acc, valid_f1 = test(model, 
data, val_mask, device,device_2, g=g, adj_org=adj_org, trans_logM=trans_logM, 163 | sim=sim, phi=phi, B=B, degree=degree, 164 | k_transition=k_transition, current_epoch=current_epoch, test_node_degree=0) 165 | 166 | if valid_acc > best_val_acc: 167 | best_val_acc = valid_acc 168 | best_model = model 169 | best_epoch = epoch 170 | pat = (pat + 1) if (pat < 5) else pat 171 | else: 172 | pat -= 1 173 | 174 | if epoch % 1 == 0: 175 | print( 176 | 'Epoch: {:02d}, best_epoch: {:02d}, Train Loss: {:0.4f}, Train Acc: {:0.4f}, Val Acc: {:0.4f} '.format( 177 | epoch, best_epoch, train_loss, train_acc, valid_acc)) 178 | 179 | if epoch - best_epoch > 100: 180 | print("1 validation patience reached ... finish training") 181 | break 182 | 183 | # Testing 184 | test_check = 1 185 | test_acc, test_f1 = test(best_model, data, test_mask, device,device_2, g=g, adj_org=adj_org, trans_logM=trans_logM, sim=sim, 186 | phi=phi, B=B, degree=degree, 187 | k_transition=k_transition, current_epoch=current_epoch, test_node_degree=test_node_degree) 188 | print('Best Val Epoch: {:03d}, Best Val Acc: {:0.4f}, Test Acc: {:0.4f}, F1_test: {:0.4f}'.format( 189 | best_epoch, best_val_acc, test_acc, test_f1)) 190 | return test_acc, test_f1 191 | def run_node_clustering(args, index_excel, ds_name, output_path, file_name, data_all, num_features, out_size, 192 | num_classes, g, M, logM, sim, phi, B, k_transition, device,device_2, num_epochs, adj, d, n_edges, 193 | current_epoch, aug_check,sim_check,phi_check): 194 | index_excel = index_excel 195 | 196 | if True: 197 | dataset = args.dataset 198 | lr = args.lr 199 | dims =args.dims 200 | out_size = args.out_size 201 | num_layers = args.num_layers 202 | k_transition = args.k_transition 203 | alfa =args.alfa 204 | beta = args.beta 205 | 206 | print(f"Node clustering process - {index_excel}") 207 | 208 | cp_filename = output_path + f'{dataset}_{lr}_{dims}_{num_layers}_{k_transition}_{alfa}_{beta}.pt' 209 | 210 | if os.path.isfile(cp_filename) == False: 211 | print(f"run_node_clustering: no file {cp_filename}") 212 | return None 213 | 214 | runs_acc = [] 215 | for i in tqdm(range(1)): 216 | print(f'run time: {i}') 217 | _, _, _, _, q, c = run_epoch_node_clustering(i, data_all, num_features, out_size, 218 | num_classes, 219 | g, M, logM, sim, phi, B, k_transition, 220 | cp_filename, dims, num_layers, lr, device,device_2, 221 | num_epochs, adj, d, n_edges, current_epoch, aug_check,sim_check,phi_check) 222 | #runs_acc.append(acc) 223 | time.sleep(1) 224 | print('Node clustering results: \n Q: {:0.4f}, , C: {:0.4f}'.format(q,c)) 225 | 226 | def run_epoch_node_clustering(i, data, num_features, out_size, num_classes, g, M, logM, sim, phi, B, k_transition, 227 | cp_filename, dims, num_layers, lr, device,device_2, num_epochs, adj, d, n_edges, current_epoch,aug_check,sim_check,phi_check): 228 | graph_name = "" 229 | best_val_acc = 0 230 | best_model = None 231 | pat = 20 232 | best_epoch = 0 233 | 234 | # fine tuning & testing 235 | print('fine tuning run_epoch_node_clustering...') 236 | 237 | model = Transformer_cluster(num_features, out_size, num_classes, hidden_dim=dims, 238 | num_layers=num_layers, num_heads=4, graph_name=graph_name, 239 | cp_filename=cp_filename, aug_check=aug_check,sim_check=sim_check,phi_check=phi_check) #.to(device) 240 | 241 | dataset_1 = ['cora', 'citeseer', 'Photo','WikiCS'] 242 | for ds in dataset_1: 243 | if ds in cp_filename: 244 | model.to(device) 245 | else: 246 | if torch.cuda.device_count() > 1: 247 | id_2 = int(str(device_2).split(":")[1]) 248 | model = 
torch.nn.DataParallel(model, device_ids=[id_2]) 249 | 250 | best_model = model 251 | 252 | optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5) 253 | 254 | print("creating random mask") 255 | data = make_masks(data, val_test_ratio=0.0) 256 | train_mask = data.train_mask 257 | val_mask = data.val_mask 258 | test_mask = data.test_mask 259 | 260 | # save dataload. 261 | g.ndata["train_mask"] = train_mask 262 | g.ndata["val_mask"] = val_mask 263 | g.ndata["test_mask"] = test_mask 264 | adj = torch.FloatTensor(adj).to(device) 265 | for epoch in range(1, num_epochs): 266 | current_epoch = epoch 267 | train_loss = train_finetuning_cluster(model, data, train_mask, optimizer, device, device_2,g=g, 268 | M=M, trans_logM=logM, sim=sim, phi=phi, B=B, k_transition=k_transition, 269 | pre_train=0, adj=adj, d=d, n_edges=n_edges, current_epoch=current_epoch) 270 | 271 | if epoch % 1 == 0: 272 | print('Epoch: {:02d}, Train Loss: {:0.4f}'.format(epoch, train_loss)) 273 | 274 | # Testing 275 | 276 | acc, precision, recall, nmi, q, c = test_cluster(best_model, data, train_mask, optimizer, device,device_2, g=g, 277 | M=M, trans_logM=logM, sim=sim, phi=phi, B=B, 278 | k_transition=k_transition, pre_train=0, adj=adj, d=d, 279 | n_edges=n_edges, current_epoch=current_epoch) 280 | 281 | 282 | return acc, precision, recall, nmi, q, c 283 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.nn import Sequential, Linear, ReLU 4 | from torch_geometric.nn import GCNConv, GINConv, SAGEConv 5 | 6 | import torch.nn as nn 7 | import dgl 8 | import dgl.nn as dglnn 9 | import dgl.sparse as dglsp 10 | import torch 11 | import dgl.function as fn 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | import os 15 | import os.path 16 | import torch.optim as optim 17 | import numpy as np 18 | from dgl.data import AsGraphPredDataset 19 | from dgl.dataloading import GraphDataLoader 20 | from ogb.graphproppred import collate_dgl, DglGraphPropPredDataset, Evaluator 21 | from ogb.graphproppred.mol_encoder import AtomEncoder 22 | from tqdm import tqdm 23 | import dgl.function as fn 24 | import pyro 25 | 26 | 27 | class Transformer(nn.Module): 28 | def __init__(self, in_dim, out_dim, n_classes, hidden_dim, num_layers, num_heads, k_transition, aug_check, 29 | sim_check, 30 | phi_check, alfa, beta): 31 | super().__init__() 32 | self.h = None 33 | self.embedding_h = nn.Linear(in_dim, hidden_dim, bias=False) 34 | self.in_dim = in_dim 35 | self.hidden_dim = hidden_dim 36 | self.k_transition = k_transition 37 | 38 | self.aug_check = aug_check 39 | self.sim_check = sim_check 40 | self.phi_check = phi_check 41 | self.afla = alfa 42 | self.beta = beta 43 | 44 | 45 | self.gcn = GCN(in_dim, hidden_dim, self.afla, self.beta) 46 | 47 | self.layers = nn.ModuleList( 48 | [GraphTransformerLayer(hidden_dim, hidden_dim, num_heads, sim_check, phi_check) for _ in range(num_layers)]) 49 | 50 | self.layers.append(GraphTransformerLayer(hidden_dim, out_dim, num_heads, sim_check, phi_check)) 51 | 52 | self.MLP_layer_x = Reconstruct_X(out_dim, in_dim) 53 | 54 | self.embedding_phi = nn.Linear(1, hidden_dim) 55 | self.embedding_sim = nn.Linear(k_transition, hidden_dim) 56 | 57 | def extract_features(self, g, adj_org, X, current_epoch, edge_index, sim, phi, B, k_transition, device, device_2): 58 | 59 | adj_sampled = None 60 | 61 | if 
self.aug_check == 1: 62 | edge_index_sampled, x_gcn, adj_sampled, check_nan = self.gcn(g, adj_org, X, B, edge_index, current_epoch,device_2) 63 | g = dgl_renew(g, X, edge_index_sampled, sim, phi, k_transition, device,device_2) 64 | 65 | h = self.embedding_h(X) 66 | 67 | phi = g.edata['phi'] 68 | sim = g.edata['sim'] 69 | phi = self.embedding_phi(phi.float()) 70 | sim = self.embedding_sim(sim.float()) 71 | 72 | 73 | for layer in self.layers: 74 | h = layer(h, g, phi, sim, current_epoch) 75 | 76 | return h, adj_sampled 77 | 78 | def forward(self, g, adj_org, sim, phi, B, k_transition, current_epoch, device, device_2): 79 | 80 | X = g.ndata['x'].to(device_2) 81 | 82 | edge_index = torch.stack([g.edges()[0], g.edges()[1]]).to(device_2) 83 | 84 | h, adj_sampled = self.extract_features(g, adj_org, X, current_epoch, edge_index, sim, phi, B, k_transition, device, device_2) 85 | 86 | x_hat = self.MLP_layer_x(h) 87 | 88 | self.h = h 89 | 90 | return h, x_hat, adj_sampled 91 | 92 | 93 | def dgl_renew(g, x0, edge_index_sampled, sim, phi, k_transition, device,device_2): 94 | g = dgl.graph((edge_index_sampled[0], edge_index_sampled[1])) 95 | 96 | 97 | g.edata['sim'] = sim[edge_index_sampled[0], edge_index_sampled[1]] 98 | g.edata['phi'] = phi[edge_index_sampled[0], edge_index_sampled[1]] 99 | 100 | return g 101 | 102 | class GraphTransformerLayer(nn.Module): 103 | """Graph Transformer Layer""" 104 | 105 | def __init__(self, in_dim, out_dim, num_heads, sim_check, phi_check): 106 | super().__init__() 107 | 108 | self.sim_check = sim_check 109 | self.phi_check = phi_check 110 | 111 | self.in_channels = in_dim 112 | self.out_channels = out_dim 113 | self.num_heads = num_heads 114 | 115 | self.attention = MultiHeadAttentionLayer(in_dim, out_dim // num_heads, num_heads, sim_check, phi_check) 116 | 117 | self.O = nn.Linear(out_dim, out_dim) 118 | 119 | self.batchnorm1 = nn.BatchNorm1d(out_dim) 120 | self.batchnorm2 = nn.BatchNorm1d(out_dim) 121 | self.layer_norm1 = nn.LayerNorm(out_dim) 122 | self.layer_norm2 = nn.LayerNorm(out_dim) 123 | 124 | self.FFN_layer1 = nn.Linear(out_dim, out_dim * 2) 125 | self.FFN_layer2 = nn.Linear(out_dim * 2, out_dim) 126 | 127 | def forward(self, h, g, phi, sim, current_epoch): 128 | h_in1 = h # for first residual connection 129 | 130 | attn_out = self.attention(h, g, phi, sim, current_epoch) 131 | 132 | h = attn_out.view(-1, self.out_channels) 133 | 134 | h = self.O(h) 135 | 136 | h = h_in1 + h # residual connection 137 | 138 | h = self.layer_norm1(h) 139 | 140 | h_in2 = h # for second residual connection 141 | 142 | # FFN 143 | h = self.FFN_layer1(h) 144 | h = F.relu(h) 145 | 146 | h = F.dropout(h, 0.5, training=self.training) 147 | h = self.FFN_layer2(h) 148 | h = h_in2 + h # residual connection 149 | h = self.layer_norm2(h) 150 | 151 | return h 152 | 153 | 154 | class MultiHeadAttentionLayer(nn.Module): 155 | # in_dim, out_dim, num_heads 156 | def __init__(self, in_dim, out_dim, num_heads, sim_check, phi_check): 157 | super().__init__() 158 | 159 | self.sim_check = sim_check 160 | self.phi_check = phi_check 161 | 162 | self.out_dim = out_dim 163 | self.num_heads = num_heads 164 | self.Q = nn.Linear(in_dim, out_dim * num_heads, bias=True) 165 | self.K = nn.Linear(in_dim, out_dim * num_heads, bias=True) 166 | self.V = nn.Linear(in_dim, out_dim * num_heads, bias=True) 167 | 168 | self.hidden_size = in_dim # 80 169 | self.num_heads = num_heads # 8 170 | self.head_dim = out_dim // num_heads # 10 171 | 172 | self.scaling = self.head_dim ** -0.5 173 | 174 | self.q_proj = 
nn.Linear(in_dim, in_dim) 175 | self.k_proj = nn.Linear(in_dim, in_dim) 176 | self.v_proj = nn.Linear(in_dim, in_dim) 177 | 178 | self.proj_phi = nn.Linear(in_dim, out_dim * num_heads, bias=True) 179 | 180 | self.sim = nn.Linear(in_dim, out_dim * num_heads, bias=True) 181 | 182 | def propagate_attention(self, g): 183 | # Compute attention score 184 | if self.sim_check == 1: 185 | g.apply_edges(src_dot_dst_sim('K_h', 'Q_h', 'sim_h', 'score')) 186 | else: 187 | g.apply_edges(src_dot_dst('K_h', 'Q_h', 'score')) 188 | 189 | g.apply_edges(scaling('score', np.sqrt(self.out_dim))) 190 | 191 | if self.phi_check == 1: 192 | g.apply_edges(imp_add_attn('score', 'proj_phi')) 193 | 194 | # softmax 195 | g.apply_edges(exp('score')) 196 | 197 | eids = g.edges() 198 | g.send_and_recv(eids, dgl.function.u_mul_e('V_h', 'score', 'V_h'), fn.sum('V_h', 'wV')) # src_mul_edge 199 | g.send_and_recv(eids, dgl.function.copy_e('score', 'score'), fn.sum('score', 'z')) # copy_edge 200 | 201 | 202 | def forward(self, h, g, phi, sim, current_epoch): 203 | Q_h = self.Q(h) 204 | K_h = self.K(h) 205 | V_h = self.V(h) 206 | 207 | sim_h = self.sim(sim) 208 | 209 | proj_phi = self.proj_phi(phi) 210 | # proj_sim = self.proj_sim(sim) 211 | 212 | g.ndata['Q_h'] = Q_h.view(-1, self.num_heads, self.out_dim) 213 | g.ndata['K_h'] = K_h.view(-1, self.num_heads, self.out_dim) 214 | g.ndata['V_h'] = V_h.view(-1, self.num_heads, self.out_dim) 215 | g.edata['sim_h'] = sim_h.view(-1, self.num_heads, self.out_dim) 216 | 217 | g.edata['proj_phi'] = proj_phi.view(-1, self.num_heads, self.out_dim) 218 | 219 | self.propagate_attention(g) 220 | 221 | h_out = g.ndata['wV'] / (g.ndata['z'] + torch.full_like(g.ndata['z'], 1e-6)) # adding eps to all values here 222 | 223 | return h_out 224 | 225 | 226 | class GCN(torch.nn.Module): 227 | # g,adj_org, X, B, edge_index, current_epoch, self.afla, self.beta) 228 | def __init__(self, num_features, hidden_dim=64, alfa=0.1, beta=0.95): 229 | super(GCN, self).__init__() 230 | self.conv1 = GCNConv(num_features, hidden_dim * 2) 231 | self.conv2 = GCNConv(hidden_dim * 2, hidden_dim) 232 | self.anfa = alfa 233 | self.beta = beta 234 | 235 | self.MLPA = torch.nn.Sequential( 236 | torch.nn.Linear(hidden_dim, hidden_dim), 237 | torch.nn.ReLU(), 238 | torch.nn.Linear(hidden_dim, hidden_dim)) 239 | 240 | def forward(self, g, adj_org, x, B, edge_index, current_epoch,device_2): 241 | 242 | edge_probs = self.anfa * B + self.beta * adj_org 243 | 244 | edge_probs[edge_probs > 1] = 1 245 | 246 | edge_probs = edge_probs.cuda() 247 | check_nan = False 248 | while True: 249 | # try: 250 | adj_sampled = pyro.distributions.RelaxedBernoulliStraightThrough(temperature=1, probs=edge_probs).rsample() 251 | 252 | adj_sampled = adj_sampled.triu(1) 253 | adj_sampled = adj_sampled + adj_sampled.T 254 | 255 | edge_index_sampled = adj_sampled.to_sparse()._indices() 256 | 257 | g_new = dgl.graph((edge_index_sampled[0], edge_index_sampled[1])) 258 | 259 | check_nan = True 260 | 261 | if g.num_nodes() == g_new.num_nodes(): 262 | if current_epoch % 50 == 0: 263 | print(f'adj_sampled size: {edge_index_sampled.size()}') 264 | break 265 | else: 266 | print("----------------------------------------") 267 | edge_index_sampled = edge_index_sampled.to(device_2) 268 | adj_sampled = adj_sampled.to(device_2) 269 | 270 | return edge_index_sampled, x, adj_sampled, check_nan 271 | 272 | 273 | class Reconstruct_X(torch.nn.Module): 274 | def __init__(self, inp, outp, dims=128): 275 | super().__init__() 276 | 277 | 278 | self.mlp = torch.nn.Sequential( 279 
| torch.nn.Linear(inp, dims * 2), 280 | torch.nn.SELU(), 281 | torch.nn.Linear(dims * 2, outp)) 282 | 283 | def forward(self, x): 284 | x = self.mlp(x) 285 | return x 286 | 287 | 288 | class MLPA(torch.nn.Module): 289 | 290 | def __init__(self, in_feats, dim_h, dim_z): 291 | super(MLPA, self).__init__() 292 | 293 | self.gcn_mean = torch.nn.Sequential( 294 | torch.nn.Linear(in_feats, dim_h), 295 | torch.nn.ReLU(), 296 | torch.nn.Linear(dim_h, dim_z) 297 | ) 298 | 299 | def forward(self, hidden): 300 | Z = self.gcn_mean(hidden) 301 | 302 | adj_logits = Z @ Z.T 303 | return adj_logits 304 | 305 | 306 | class MLP(torch.nn.Module): 307 | 308 | def __init__(self, num_features, num_classes, dims=16): 309 | super(MLP, self).__init__() 310 | self.mlp = torch.nn.Sequential( 311 | torch.nn.Linear(num_features, dims), torch.nn.ReLU(), 312 | torch.nn.Linear(dims, num_classes)) 313 | 314 | def forward(self, x): 315 | x = self.mlp(x) 316 | return x 317 | 318 | 319 | class Transformer_class(nn.Module): 320 | def __init__(self, in_dim, out_dim, n_classes, hidden_dim, num_layers, num_heads, graph_name, cp_filename, aug_check ,sim_check ,phi_check ): 321 | super().__init__() 322 | 323 | print(f'Loading Transformer_class {cp_filename}') 324 | self.model = torch.load(cp_filename) 325 | if isinstance(self.model, torch.nn.DataParallel): 326 | self.model = self.model.module 327 | 328 | self.model.aug_check = aug_check 329 | self.model.sim_check = sim_check 330 | self.model.phi_check = phi_check 331 | 332 | for p in self.model.parameters(): 333 | p.requires_grad = True 334 | 335 | 336 | self.MLP = MLP(out_dim, n_classes) 337 | 338 | 339 | def forward(self, g, adj_org, sim, phi, B, k_transition, current_epoch, device, device_2): 340 | 341 | X = g.ndata['x'].to(device_2) 342 | edge_index = torch.stack([g.edges()[0], g.edges()[1]]) 343 | 344 | h, _ = self.model.extract_features(g, adj_org, X, current_epoch, edge_index, sim, phi, B, k_transition, device, device_2) 345 | 346 | h = self.MLP(h) 347 | h = F.softmax(h, dim=1) 348 | 349 | return h 350 | 351 | 352 | class MLPReadout(nn.Module): 353 | 354 | def __init__(self, input_dim, output_dim, L=2): # L = nb_hidden_layers 355 | super().__init__() 356 | list_FC_layers = [nn.Linear(input_dim // 2 ** l, input_dim // 2 ** (l + 1), bias=True) for l in range(L)] 357 | list_FC_layers.append(nn.Linear(input_dim // 2 ** L, output_dim, bias=True)) 358 | 359 | self.FC_layers = nn.ModuleList(list_FC_layers) 360 | self.L = L 361 | 362 | def forward(self, x): 363 | y = x 364 | for l in range(self.L): 365 | y = self.FC_layers[l](y) 366 | y = F.relu(y) 367 | y = self.FC_layers[self.L](y) 368 | 369 | return y 370 | 371 | 372 | class Transformer_cluster(nn.Module): 373 | def __init__(self, in_dim, out_dim, n_classes, hidden_dim, num_layers, num_heads, graph_name, cp_filename, aug_check,sim_check,phi_check): 374 | super().__init__() 375 | 376 | print(f'Loading Transformer_class {cp_filename}') 377 | self.model = torch.load(cp_filename) 378 | if isinstance(self.model, torch.nn.DataParallel): 379 | self.model = self.model.module 380 | self.model.aug_check = aug_check 381 | self.model.sim_check = sim_check 382 | self.model.phi_check = phi_check 383 | 384 | for p in self.model.parameters(): 385 | p.requires_grad = True 386 | 387 | self.MLP = MLPReadout(out_dim, n_classes) 388 | 389 | 390 | 391 | def forward(self, g, adj_org, sim, phi, B, k_transition, current_epoch, device, device_2): 392 | X = g.ndata['x'].to(device_2) 393 | edge_index = torch.stack([g.edges()[0], g.edges()[1]]) 394 | 395 | 
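        # As in Transformer_class above, the pretrained encoder is reused here via
        # extract_features; the MLPReadout head plus the softmax below turn the node
        # embeddings into a soft cluster-assignment matrix of shape
        # [num_nodes, n_classes], which the clustering objective consumes as C.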
396 | h, _ = self.model.extract_features(g, adj_org, X, current_epoch, edge_index, sim, phi, B, k_transition, device, device_2) 397 | 398 | h = self.MLP(h) 399 | h = F.softmax(h, dim=1) 400 | 401 | return h 402 | 403 | 404 | """ 405 | Util functions 406 | """ 407 | 408 | 409 | def exp(field): 410 | def func(edges): 411 | # clamp for softmax numerical stability 412 | return {field: torch.exp((edges.data[field].sum(-1, keepdim=True)).clamp(-8, 8))} 413 | 414 | return func 415 | 416 | 417 | def src_dot_dst(src_field, dst_field, out_field): 418 | def func(edges): 419 | return {out_field: (edges.src[src_field] * edges.dst[dst_field]).sum(-1, keepdim=True)} 420 | 421 | return func 422 | 423 | 424 | def src_dot_dst_sim(src_field, dst_field, edge_field, out_field): 425 | def func(edges): 426 | return {out_field: ((edges.src[src_field] + edges.data[edge_field]) * ( 427 | edges.dst[dst_field] + edges.data[edge_field])).sum(-1, keepdim=True)} 428 | 429 | return func 430 | 431 | 432 | # Improving implicit attention scores with explicit edge features, if available 433 | def scaling(field, scale_constant): 434 | def func(edges): 435 | return {field: (((edges.data[field])) / scale_constant)} 436 | 437 | return func 438 | 439 | 440 | def imp_exp_attn(implicit_attn, explicit_edge): 441 | """ 442 | implicit_attn: the output of K Q 443 | explicit_edge: the explicit edge features 444 | """ 445 | 446 | def func(edges): 447 | return {implicit_attn: (edges.data[implicit_attn] * edges.data[explicit_edge])} 448 | 449 | return func 450 | 451 | 452 | def imp_add_attn(implicit_attn, explicit_edge): 453 | """ 454 | implicit_attn: the output of K Q 455 | explicit_edge: the explicit edge features 456 | """ 457 | 458 | def func(edges): 459 | return {implicit_attn: (edges.data[implicit_attn] + edges.data[explicit_edge])} 460 | 461 | return func 462 | 463 | 464 | # To copy edge features to be passed to FFN_e 465 | def out_edge_features(edge_feat): 466 | def func(edges): 467 | return {'e_out': edges.data[edge_feat]} 468 | 469 | return func -------------------------------------------------------------------------------- /gnnutils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import networkx as nx 4 | import numpy as np 5 | import scipy.sparse as sparse 6 | import torch 7 | import torch.nn.functional as F 8 | from networkx.utils import dict_to_numpy_array 9 | from sklearn.metrics import f1_score 10 | from sklearn.model_selection import train_test_split 11 | from torch_geometric.datasets import Planetoid, Airports 12 | from torch_geometric.utils import add_remaining_self_loops 13 | from tqdm import tqdm 14 | import torch_geometric as tg 15 | # from build_multigraph import build_pyg_struc_multigraph 16 | from datasets import WebKB, FilmNetwork, BGP 17 | from torch_geometric.datasets import Planetoid, Amazon, WikipediaNetwork, Coauthor, WikiCS, SNAPDataset, CitationFull 18 | 19 | from dataset import WikipediaNetwork_crocodile 20 | import warnings 21 | warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated') 22 | 23 | def DigitizeY(pyg_data): 24 | print("Checking ....") 25 | transform_y = np.log10 26 | bins = [0.2, 0.4, 0.6, 0.8, 1] 27 | y = transform_y(pyg_data.y).numpy() 28 | 29 | digitized_y = np.digitize(y, bins) 30 | 31 | # print(digitized_y) 32 | # raise SystemExit() 33 | # file1 = open("test.txt","a") 34 | # with open('test.txt', 'w') as f: 35 | # f.write() 36 | 37 | pyg_data.y = torch.from_numpy(digitized_y) 38 | 39 | # 
print(pyg_data.y ) 40 | 41 | return pyg_data 42 | 43 | 44 | def load_crocodile(dataset): 45 | assert dataset in ["crocodile"] 46 | print("checking1: ") 47 | 48 | og = WikipediaNetwork_crocodile(root="original_datasets/wiki", name=dataset, geom_gcn_preprocess=False)[0] 49 | 50 | print(og.y) 51 | return og 52 | 53 | 54 | def load_Cora_ML(dataset): 55 | assert dataset in ["Cora_ML"] 56 | og = CitationFull(root="original_datasets/Cora_ML", name=dataset)[0] 57 | 58 | return og, None 59 | 60 | 61 | def load_WikiCS(dataset): 62 | assert dataset in ["WikiCS"] 63 | og = WikiCS(root="original_datasets/WikiCS")[0] 64 | 65 | return og, None 66 | 67 | 68 | def load_coauthor(dataset): 69 | assert dataset in ["CS", "Physics"] 70 | og = Coauthor(root="original_datasets/coauthor", name=dataset)[0] 71 | 72 | return og, None 73 | 74 | 75 | def load_amazon(dataset): 76 | assert dataset in ["Computers", "Photo"] 77 | og = Amazon(root="original_datasets/amazon", name=dataset)[0] 78 | # print("DONE og load_amazon") 79 | # st = Amazon(root="datasets_py_geom_format_10/amazon", name=dataset,pre_transform=build_pyg_struc_multigraph)[0] 80 | 81 | # print("DONE load_planetoid") 82 | # print(f'Number of edges og : {og.num_edges}') 83 | # print(f'Number of edges st : {st.num_edges}') 84 | return og, None 85 | 86 | 87 | def load_airports(dataset): 88 | assert dataset in ["brazil", "europe", "usa"] 89 | og = Airports("original_datasets/airports_dataset/" + dataset, name=dataset)[0] 90 | # st = Airports(root="datasets_py_geom_format/airports_dataset/"+dataset, dataset_name=dataset,pre_transform=build_pyg_struc_multigraph)[0] 91 | return og, None 92 | 93 | 94 | def load_bgp(dataset): 95 | assert dataset in ["bgp"] 96 | og = BGP(root="original_datasets/bgp_dataset")[0] 97 | # st = BGP(root="datasets_py_geom_format/bgp_dataset", pre_transform=build_pyg_struc_multigraph)[0] 98 | return og, None 99 | 100 | 101 | def load_film(dataset): 102 | assert dataset in ["film"] 103 | og = FilmNetwork(root="original_datasets/film", name=dataset)[0] 104 | # st = FilmNetwork(root="datasets_py_geom_format/film", name=dataset, 105 | # pre_transform=build_pyg_struc_multigraph)[0] 106 | return og, None 107 | 108 | 109 | def load_wiki(dataset): 110 | assert dataset in ["chameleon", "squirrel", "crocodile"] 111 | og = WikipediaNetwork(root="original_datasets/wiki", geom_gcn_preprocess=True, name=dataset)[0] 112 | # st = WikipediaNetwork(root="datasets_py_geom_format/wiki", name=dataset, 113 | # pre_transform=build_pyg_struc_multigraph)[0] 114 | 115 | return og, None 116 | 117 | 118 | def load_webkb(dataset): 119 | assert dataset in ["cornell", "texas", "wisconsin"] 120 | og = WebKB(root="original_datasets/webkb", name=dataset)[0] 121 | # st = WebKB(root="datasets_py_geom_format/webkb", name=dataset,pre_transform=build_pyg_struc_multigraph)[0] 122 | 123 | return og, None 124 | 125 | 126 | def load_planetoid(dataset): 127 | assert dataset in ["cora", "citeseer", "pubmed"] 128 | og = Planetoid(root="original_datasets/planetoid", name=dataset, split="public")[0] 129 | # st = Planetoid(root="datasets_py_geom_format/planetoid", name=dataset, split="public", pre_transform=build_pyg_struc_multigraph)[0] 130 | 131 | return og, None 132 | 133 | 134 | # def structure_edge_weight_threshold(data, threshold): 135 | # data = copy.deepcopy(data) 136 | # mask = data.edge_weight > threshold 137 | # data.edge_weight = data.edge_weight[mask] 138 | # data.edge_index = data.edge_index[:, mask] 139 | # data.edge_color = data.edge_color[mask] 140 | # return data 141 | 
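# Illustrative usage sketch (the helper name below is hypothetical and not part
# of the original file). The load_* helpers above typically return a PyG ``Data``
# object together with ``None`` as a placeholder, and ``make_masks`` (defined
# further down in this file) attaches random train/val/test masks to it.
# Assumes the Planetoid files can be downloaded under original_datasets/.
def _example_load_and_split():
    og_data, _ = load_planetoid("cora")             # Data(x, edge_index, y, ...)
    data = make_masks(og_data, val_test_ratio=0.2)  # roughly 60/20/20 random split
    return int(data.train_mask.sum()), int(data.val_mask.sum()), int(data.test_mask.sum())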
142 | 143 | def add_original_graph(og_data, st_data, weight=1.0): 144 | st_data = copy.deepcopy(st_data) 145 | e_i = torch.cat((og_data.edge_index, st_data.edge_index), dim=1) 146 | st_data.edge_color = st_data.edge_color + 1 147 | e_c = torch.cat((torch.zeros(og_data.edge_index.shape[1], dtype=torch.long), st_data.edge_color), dim=0) 148 | e_w = torch.cat((torch.ones(og_data.edge_index.shape[1], dtype=torch.float) * weight, st_data.edge_weight), dim=0) 149 | st_data.edge_index = e_i 150 | st_data.edge_color = e_c 151 | st_data.edge_weight = e_w 152 | return st_data 153 | 154 | 155 | def NormalizeTensor(data): 156 | return (data - torch.min(data)) / (torch.max(data) - torch.min(data)) 157 | 158 | 159 | def cosinSim(x_hat): 160 | x_norm = torch.norm(x_hat, p=2, dim=1) 161 | nume = torch.mm(x_hat, x_hat.t()) 162 | deno = torch.ger(x_norm, x_norm) 163 | cosine_similarity = nume / deno 164 | return cosine_similarity 165 | 166 | 167 | def train_finetuning_class(model, train_data, mask, optimizer, device,device_2, g, adj_org, M, trans_logM, sim, phi, B, 168 | k_transition, pre_train, current_epoch): 169 | # pre_train = 0 170 | model.train() 171 | optimizer.zero_grad() 172 | mask = mask 173 | true_label = train_data.y 174 | 175 | out = model(g.to(device), adj_org.cuda(), sim.cuda(), phi.cuda(), B.cuda(), k_transition, 176 | current_epoch, device,device_2) 177 | 178 | # loss = F.nll_loss(out[mask], true_label[mask]) 179 | criterion = torch.nn.CrossEntropyLoss() 180 | # print(out[mask]) 181 | # print( true_label[mask]) 182 | # raise SystemExit() 183 | loss = criterion(out[mask].cpu(), true_label[mask]) 184 | # print(out) 185 | # print(true_label) 186 | # raise SystemExit() 187 | pred = out.max(1)[1][mask].cpu() 188 | acc = pred.eq(train_data.y[mask]).sum().item() / len(train_data.y[mask]) 189 | loss.backward() 190 | optimizer.step() 191 | # print('training loss') 192 | return loss.item(), acc 193 | 194 | 195 | def train(model, train_data, mask, optimizer, device, device_2, g, adj_org, trans_logM, sim, phi, B, k_transition, current_epoch, 196 | alpha_1, alpha_2, alpha_3): 197 | model.train() 198 | optimizer.zero_grad() 199 | optimizer.zero_grad() 200 | mask = mask.to(device) 201 | if torch.cuda.device_count() > 1: 202 | h, x_hat, adj_sampled = model(g.to(device), adj_org.cuda(), sim.cuda(), phi.cuda(), B.cuda(), 203 | k_transition, current_epoch, device, device_2) 204 | 205 | loss_M = 0 206 | h = cosinSim(h).cpu() 207 | for i in range(k_transition): 208 | loss_M += torch.sum(((h - (torch.FloatTensor(trans_logM[i]))) ** 2)) 209 | 210 | row_num, col_num = (torch.FloatTensor(trans_logM[i])).size() 211 | loss_M = loss_M / (k_transition * row_num * col_num) 212 | 213 | row_num, col_num = train_data.x.size() 214 | 215 | loss_X = F.mse_loss(x_hat.cpu(), train_data.x.cpu()) 216 | 217 | # BCE: 218 | if adj_sampled == None: 219 | adj_loss = 0 220 | else: 221 | adj_loss = F.binary_cross_entropy_with_logits(adj_sampled.cpu(), adj_org.cpu()) 222 | 223 | loss_all = alpha_1 * loss_M + alpha_2 * loss_X + alpha_3 * adj_loss 224 | loss_all = loss_all.to(device_2) 225 | loss_all.backward() 226 | optimizer.step() 227 | 228 | return loss_M.item() 229 | 230 | 231 | import collections 232 | from collections import defaultdict 233 | 234 | 235 | 236 | 237 | import matplotlib.pylab as plt 238 | 239 | 240 | 241 | 242 | @torch.no_grad() 243 | def test(model, test_data, mask, device,device_2, g, adj_org, trans_logM, sim, phi, B, degree, k_transition, current_epoch, 244 | test_node_degree): 245 | 246 | model.eval() 247 | mask = 
mask.to(device) 248 | # n x 4 size 249 | logits = model(g.to(device), adj_org.to(device), sim.to(device), phi.to(device), B.to(device), k_transition, 250 | current_epoch, device,device_2) 251 | 252 | 253 | pred = torch.argmax(logits, dim=1) 254 | pred_before = pred 255 | pred = pred.to(device)[mask] 256 | 257 | acc = pred.eq(test_data.y.to(device)[mask]).sum().item() / len(test_data.y.to(device)[mask]) 258 | 259 | pred = pred.cpu().numpy() 260 | y_true = test_data.y.to(device)[mask].cpu().numpy() 261 | f1 = f1_score(y_true, pred, average='micro') 262 | 263 | return acc, f1 264 | 265 | 266 | def filter_relations(data, num_relations, rel_last): 267 | if rel_last: 268 | l = data.edge_color.unique(sorted=True).tolist() 269 | mask_l = l[-num_relations:] 270 | mask = data.edge_color == mask_l[0] 271 | for c in mask_l[1:]: 272 | mask = mask | (data.edge_color == c) 273 | else: 274 | mask = data.edge_color < (num_relations + 1) 275 | 276 | data.edge_index = data.edge_index[:, mask] 277 | data.edge_weight = data.edge_weight[mask] 278 | data.edge_color = data.edge_color[mask] 279 | return data 280 | 281 | 282 | def make_masks(data, val_test_ratio=0.2, stratify=False): 283 | data = copy.deepcopy(data) 284 | n_nodes = data.x.shape[0] 285 | all_nodes_idx = np.arange(n_nodes) 286 | all_y = data.y.numpy() 287 | if stratify: 288 | train, test_idx, y_train, _ = train_test_split( 289 | all_nodes_idx, all_y, test_size=0.2, stratify=all_y) 290 | 291 | train_idx, val_idx, _, _ = train_test_split( 292 | train, y_train, test_size=0.25, stratify=y_train) 293 | 294 | else: 295 | val_test_num = 2 * int(val_test_ratio * data.x.shape[0]) 296 | val_test_idx = np.random.choice(n_nodes, (val_test_num,), replace=False) 297 | val_idx = val_test_idx[:int(val_test_num / 2)] 298 | test_idx = val_test_idx[int(val_test_num / 2):] 299 | 300 | val_mask = np.zeros(n_nodes) 301 | val_mask[val_idx] = 1 302 | test_mask = np.zeros(n_nodes) 303 | test_mask[test_idx] = 1 304 | val_mask = torch.tensor(val_mask, dtype=torch.bool) 305 | test_mask = torch.tensor(test_mask, dtype=torch.bool) 306 | val_test_mask = val_mask | test_mask 307 | train_mask = ~val_test_mask 308 | data.train_mask = train_mask 309 | data.val_mask = val_mask 310 | data.test_mask = test_mask 311 | return data 312 | 313 | 314 | def create_self_loops(data): 315 | orig_relations = len(data.edge_color.unique()) 316 | data.edge_index, data.edge_weight = add_remaining_self_loops(data.edge_index, data.edge_weight, 317 | fill_value=1.0) 318 | row, col = data.edge_index[0], data.edge_index[1] 319 | mask = row == col 320 | tmp = torch.full(mask.nonzero().shape, orig_relations + 1, dtype=torch.long).squeeze() 321 | data.edge_color = torch.cat([data.edge_color, tmp], dim=0) 322 | return data 323 | 324 | 325 | # create adjacency matrix and degree sequence 326 | def createA(E, n, m, undir=True): 327 | if undir: 328 | G = nx.Graph() 329 | G.add_nodes_from(range(n)) 330 | G.add_edges_from(list(E)) 331 | A = nx.to_scipy_sparse_matrix(G) 332 | else: 333 | A = sparse.coo_matrix((np.ones(m), (E[:, 0], E[:, 1])), 334 | shape=(n, n)).tocsc() 335 | degree = np.array(A.sum(1)).flatten() 336 | 337 | return A, degree 338 | 339 | 340 | def calculateRWRrange(A, degree, i, prs, n, trans=True, maxIter=1000): 341 | pr = prs[-1] 342 | D = sparse.diags(1. 
/ degree, 0, format='csc') 343 | W = D * A 344 | diff = 1 345 | it = 1 346 | 347 | F = np.zeros(n) 348 | Fall = np.zeros((n, len(prs))) 349 | F[i] = 1 350 | Fall[i, :] = 1 351 | Fold = F.copy() 352 | T = F.copy() 353 | 354 | if trans: 355 | W = W.T 356 | 357 | oneminuspr = 1 - pr 358 | 359 | while diff > 1e-9: 360 | F = pr * W.dot(F) 361 | F[i] += oneminuspr 362 | Fall += np.outer((F - Fold), (prs / pr) ** it) 363 | T += (F - Fold) / ((it + 1) * (pr ** it)) 364 | 365 | diff = np.sum((F - Fold) ** 2) 366 | it += 1 367 | if it > maxIter: 368 | print(i, "max iterations exceeded") 369 | diff = 0 370 | Fold = F.copy() 371 | 372 | return Fall, T, it 373 | 374 | 375 | def localAssortF(G, M, pr=np.arange(0., 1., 0.1), undir=True, missingValue=-1, edge_attribute=None): 376 | n = len(M) 377 | ncomp = (M != missingValue).sum() 378 | # m = len(E) 379 | m = G.number_of_edges() 380 | 381 | if edge_attribute is None: 382 | A = nx.to_scipy_sparse_matrix(G, weight=None) 383 | else: 384 | A = nx.to_scipy_sparse_matrix(G, weight=edge_attribute) 385 | 386 | degree = np.array(A.sum(1)).flatten() 387 | 388 | D = sparse.diags(1. / degree, 0, format='csc') 389 | W = D.dot(A) 390 | c = len(np.unique(M)) 391 | if ncomp < n: 392 | c -= 1 393 | 394 | 395 | Z = np.zeros(n) 396 | Z[M == missingValue] = 1. 397 | Z = W.dot(Z) / degree 398 | 399 | values = np.ones(ncomp) 400 | yi = (M != missingValue).nonzero()[0] 401 | yj = M[M != missingValue] 402 | Y = sparse.coo_matrix((values, (yi, yj)), shape=(n, c)).tocsc() 403 | 404 | assortM = np.empty((n, len(pr))) 405 | assortT = np.empty(n) 406 | 407 | eij_glob = np.array(Y.T.dot(A.dot(Y)).todense()) 408 | eij_glob /= np.sum(eij_glob) 409 | ab_glob = np.sum(eij_glob.sum(1) * eij_glob.sum(0)) 410 | 411 | WY = W.dot(Y).tocsc() 412 | 413 | print("start iteration") 414 | 415 | for i in tqdm(range(n)): 416 | pis, ti, it = calculateRWRrange(A, degree, i, pr, n) 417 | 418 | YPI = sparse.coo_matrix((ti[M != missingValue], (M[M != missingValue], 419 | np.arange(n)[M != missingValue])), 420 | shape=(c, n)).tocsr() 421 | e_gh = YPI.dot(WY).toarray() 422 | Z[i] = np.sum(e_gh) 423 | e_gh /= np.sum(e_gh) 424 | trace_e = np.trace(e_gh) 425 | assortT[i] = trace_e 426 | 427 | # assortM -= ab_glob 428 | # assortM /= (1. - ab_glob + 1e-200) 429 | 430 | assortT -= ab_glob 431 | assortT /= (1. 
- ab_glob + 1e-200) 432 | 433 | return assortM, assortT, Z 434 | 435 | 436 | def mixing_dict(xy, normalized=True): 437 | d = {} 438 | psum = 0.0 439 | for x, y, w in xy: 440 | if x not in d: 441 | d[x] = {} 442 | if y not in d: 443 | d[y] = {} 444 | v = d[x].get(y, 0) 445 | d[x][y] = v + w 446 | psum += w 447 | 448 | if normalized: 449 | for k, jdict in d.items(): 450 | for j in jdict: 451 | jdict[j] /= psum 452 | return d 453 | 454 | 455 | def node_attribute_xy(G, attribute, edge_attribute=None, nodes=None): 456 | if nodes is None: 457 | nodes = set(G) 458 | else: 459 | nodes = set(nodes) 460 | Gnodes = G.nodes 461 | for u, nbrsdict in G.adjacency(): 462 | if u not in nodes: 463 | continue 464 | uattr = Gnodes[u].get(attribute, None) 465 | if G.is_multigraph(): 466 | raise NotImplementedError 467 | else: 468 | for v, eattr in nbrsdict.items(): 469 | vattr = Gnodes[v].get(attribute, None) 470 | if edge_attribute is None: 471 | yield (uattr, vattr, 1) 472 | else: 473 | edge_data = G.get_edge_data(u, v) 474 | yield (uattr, vattr, edge_data[edge_attribute]) 475 | 476 | 477 | def global_assortativity(networkx_graph, labels, weights=None): 478 | attr_dict = {} 479 | for i in networkx_graph.nodes(): 480 | attr_dict[i] = labels[i] 481 | 482 | nx.set_node_attributes(networkx_graph, attr_dict, "label") 483 | if weights is None: 484 | xy_iter = node_attribute_xy(networkx_graph, "label", edge_attribute=None) 485 | d = mixing_dict(xy_iter) 486 | M = dict_to_numpy_array(d, mapping=None) 487 | s = (M @ M).sum() 488 | t = M.trace() 489 | r = (t - s) / (1 - s) 490 | else: 491 | edge_attr = {} 492 | for i, e in enumerate(networkx_graph.edges()): 493 | edge_attr[e] = weights[i] 494 | nx.set_edge_attributes(networkx_graph, edge_attr, "weight") 495 | 496 | xy_iter = node_attribute_xy(networkx_graph, "label", edge_attribute="weight") 497 | d = mixing_dict(xy_iter) 498 | M = dict_to_numpy_array(d, mapping=None) 499 | s = (M @ M).sum() 500 | t = M.trace() 501 | r = (t - s) / (1 - s) 502 | 503 | return r, M 504 | 505 | 506 | def local_assortativity(networkx_graph, labels, weights=None): 507 | if weights is None: 508 | assort_m, assort_t, z = localAssortF(networkx_graph, np.array(labels)) 509 | else: 510 | edge_attr = {} 511 | for i, e in enumerate(networkx_graph.edges()): 512 | edge_attr[e] = weights[i] 513 | nx.set_edge_attributes(networkx_graph, edge_attr, "weight") 514 | assort_m, assort_t, z = localAssortF(networkx_graph, np.array(labels), edge_attribute="weight") 515 | 516 | return assort_m, assort_t, z 517 | 518 | 519 | # (model, data, train_mask, optimizer, device, g=g, M = M, trans_logM = logM, pre_train=0, adj = adj, n_edges= n_edges) 520 | def train_finetuning_cluster(model, train_data, mask, optimizer, device,device_2, g, M, trans_logM, sim, phi, B, k_transition, 521 | pre_train, adj, d, n_edges, current_epoch): 522 | model.train() 523 | optimizer.zero_grad() 524 | mask = mask.to(device) 525 | k = torch.numel(train_data.y.unique()) 526 | 527 | C = model(g.to(device), adj.cuda(), sim.cuda(), phi.cuda(), B.cuda(), k_transition, current_epoch, device,device_2) 528 | 529 | n = g.number_of_nodes() 530 | # adj = torch.FloatTensor(adj) 531 | C = C.cpu() 532 | d = torch.FloatTensor(d).unsqueeze(1) 533 | C_t = C.t() 534 | # Computes the size [k, k] pooled graph as S^T*A*S in two multiplications. 535 | graph_pooled = torch.mm(C_t, adj.cpu()) 536 | graph_pooled = torch.mm(graph_pooled, C) 537 | 538 | # Left part is [k, 1] tensor. 539 | normalizer_left = torch.mm(C_t, d) 540 | 541 | # Right part is [1, k] tensor. 
542 | normalizer_right = torch.mm(d.t(), C) 543 | 544 | normalizer = torch.mm(normalizer_left, normalizer_right) / 2 / n_edges 545 | 546 | spectral_loss = - torch.trace(graph_pooled - normalizer) / 2 / n_edges 547 | 548 | cluster_sizes = torch.sum(C, axis=0) # Size [k] 549 | collapse_loss = (k / (2 * n)) * (torch.norm(cluster_sizes)) 550 | 551 | loss = spectral_loss + collapse_loss # + wrt 552 | 553 | loss.backward() 554 | optimizer.step() 555 | return loss.item() 556 | 557 | 558 | from sklearn.metrics import f1_score 559 | from torch_geometric.utils import to_networkx 560 | from sklearn.metrics.cluster import normalized_mutual_info_score 561 | from metrics import accuracy_score, precision, modularity, conductance, recall 562 | 563 | 564 | # (best_model, data, train_mask, device, g=g, M=M) 565 | @torch.no_grad() 566 | def test_cluster(model, test_data, mask, optimizer, device,device_2, g, 567 | M, trans_logM, sim, phi, B, k_transition, pre_train, adj, d, n_edges, current_epoch): 568 | # print('testing code') 569 | model.eval() 570 | mask = mask.to(device) 571 | # n x 4 size 572 | 573 | logits = model(g.to(device), adj, sim.to(device), phi.to(device), B.to(device), k_transition, current_epoch, device,device_2) 574 | 575 | pred = torch.argmax(logits, dim=1) 576 | pred = pred.to(device)[mask] 577 | 578 | 579 | pred = pred.cpu().numpy() 580 | y_true = test_data.y.to(device)[mask].cpu().numpy() 581 | f1 = f1_score(y_true, pred, average='micro') 582 | 583 | G_st = to_networkx(test_data, to_undirected=True) 584 | Adj = nx.adjacency_matrix(G_st) 585 | Adj = Adj.todense() 586 | 587 | acc = accuracy_score(y_true, pred) 588 | p = precision(y_true, pred) 589 | r = recall(y_true, pred) 590 | nmi = normalized_mutual_info_score(y_true, pred) 591 | q = modularity(Adj, pred) 592 | c = conductance(Adj, pred) 593 | 594 | return acc, p, r, nmi, q, c 595 | --------------------------------------------------------------------------------
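A minimal, self-contained sketch of the clustering objective assembled in train_finetuning_cluster above (hedged: the toy graph and the random soft assignment C are illustrative; only the loss formulas mirror the function):

    import torch

    # Toy undirected graph: two triangles joined by a single edge.
    n, k = 6, 2
    A = torch.zeros(n, n)
    for i, j in [(0, 1), (1, 2), (0, 2), (3, 4), (4, 5), (3, 5), (2, 3)]:
        A[i, j] = A[j, i] = 1.0
    d = A.sum(dim=1, keepdim=True)                # degree vector, shape [n, 1]
    n_edges = int(A.sum().item() / 2)
    C = torch.softmax(torch.randn(n, k), dim=1)   # soft cluster assignments, [n, k]

    graph_pooled = C.t() @ A @ C                  # S^T * A * S, shape [k, k]
    normalizer = (C.t() @ d) @ (d.t() @ C) / 2 / n_edges
    spectral_loss = -torch.trace(graph_pooled - normalizer) / 2 / n_edges
    collapse_loss = (k / (2 * n)) * torch.norm(C.sum(dim=0))
    loss = spectral_loss + collapse_loss
    print(float(spectral_loss), float(collapse_loss), float(loss))

The spectral term rewards assignments whose pooled graph S^T A S carries more within-cluster weight than the degree-based null model, while the collapse term penalizes putting every node into a single cluster. exp.py is the training entry point; its argument parser (above) lists every flag and default, and the outputs/ and pts/ directories are expected to exist before a run, since model checkpoints and the cached M/logM/B/sim/phi tensors are written there.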