├── .gitignore ├── LICENSE ├── PlanarSATPairsDataset.py ├── README.md ├── batch.py ├── dataloader.py ├── dataset_pyg.py ├── distance.py ├── kernel ├── README.md ├── __init__.py ├── datasets.py ├── diff_pool.py ├── gat.py ├── gcn.py ├── gin.py ├── global_attention.py ├── graclus.py ├── graph_sage.py ├── set2set.py ├── sort_pool.py ├── statistics.py ├── top_k.py ├── train_eval.py └── tu_dataset.py ├── modules ├── gine_operations.py ├── ppgn_layers.py └── ppgn_modules.py ├── ogb_mol_gnn.py ├── qm9.py ├── qm9_models.py ├── run_all_targets_qm9.sh ├── run_exp.py ├── run_ogb_mol.py ├── run_qm9.py ├── run_simulation.py ├── run_tu.py ├── software └── k-gnn-master │ ├── .gitignore │ ├── README.md │ ├── cpu │ ├── adjacency.h │ ├── assignment.h │ ├── connect.h │ ├── graph.cpp │ ├── isomorphism.h │ ├── iterate.h │ └── utils.h │ ├── examples │ ├── 1-2-3-imdb.py │ ├── 1-2-3-mutag.py │ ├── 1-2-3-proteins.py │ ├── 1-2-3-qm9.py │ ├── 1-2-3-qm9_all_targets.py │ ├── 1-2-qm9.py │ ├── 1-2-qm9_all_targets.py │ ├── 1-3-qm9.py │ ├── 1-3-qm9_all_targets.py │ ├── 1-NCI1.py │ ├── 1-imdb.py │ ├── 1-mutag.py │ ├── 1-proteins.py │ ├── 1-qm9.py │ ├── 1-reddit.py │ └── nci_perm.pt │ ├── k_gnn │ ├── __init__.py │ ├── complete.py │ ├── dataloader.py │ ├── graph_conv.py │ ├── pool.py │ └── transform.py │ ├── setup.cfg │ └── setup.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | !data/EXP/ 3 | !data/CEXP/ 4 | backup/ 5 | results/ 6 | software/ 7 | !software/k-gnn/ 8 | *results.txt 9 | *.swp 10 | *.pyc 11 | screenlog.* 12 | tmp_vis.png 13 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Muhan Zhang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /PlanarSATPairsDataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.data import InMemoryDataset 3 | import pickle 4 | import os 5 | from torch_geometric.utils import to_networkx 6 | NAME = "GRAPHSAT" 7 | 8 | 9 | class PlanarSATPairsDataset(InMemoryDataset): 10 | def __init__(self, root, transform=None, pre_transform=None, pre_filter=None): 11 | super(PlanarSATPairsDataset, self).__init__(root, transform, pre_transform, pre_filter) 12 | self.data, self.slices = torch.load(self.processed_paths[0]) 13 | 14 | @property 15 | def raw_file_names(self): 16 | return [NAME+".pkl"] 17 | 18 | @property 19 | def processed_file_names(self): 20 | return 'data.pt' 21 | 22 | def download(self): 23 | pass 24 | 25 | def process(self): 26 | # Read data into huge `Data` list. 27 | data_list = pickle.load(open(os.path.join(self.root, "raw/"+NAME+".pkl"), "rb")) 28 | 29 | if self.pre_filter is not None: 30 | data_list = [data for data in data_list if self.pre_filter(data)] 31 | 32 | if self.pre_transform is not None: 33 | data_list = [self.pre_transform(data) for data in data_list] 34 | 35 | data, slices = self.collate(data_list) 36 | torch.save((data, slices), self.processed_paths[0]) 37 | 38 | 39 | if __name__ == "__main__": 40 | test_path = "Data/EXP/" 41 | dataset = PlanarSATPairsDataset(test_path) 42 | print(dataset[0]) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Nested Graph Neural Networks 2 | ============================ 3 | 4 | About 5 | ----- 6 | Nested Graph Neural Network (NGNN) is a general framework to improve a base GNN's expressive power and performance. It consists of a base GNN (usually a weak message-passing GNN) and an outer GNN. In NGNN, we extract a rooted subgraph around each node and let the base GNN learn a subgraph representation from the rooted subgraph, which is used as the root node's representation. Then, the outer GNN further learns a graph representation from these root node representations returned by the base GNN (in this paper, we simply let the outer GNN be a global pooling layer without graph convolution). NGNN is provably more powerful than 1-WL, being able to discriminate almost all r-regular graphs where 1-WL always fails. In contrast to other high-order GNNs, NGNN incurs only a constant-factor higher time complexity than its base GNN (given that the rooted subgraph size is bounded). NGNN often shows immediate performance gains on real-world datasets when applied to a weak base GNN. 7 | 8 | For more details, please refer to our paper: 9 | > M. Zhang and P. Li, Nested Graph Neural Networks, Advances in Neural Information Processing Systems (NeurIPS-21), 2021.
[\[PDF\]](https://arxiv.org/pdf/2110.13197.pdf) 10 | 11 | Requirements 12 | ------------ 13 | Stable: Python 3.8 + PyTorch 1.8.1 + PyTorch\_Geometric 1.7.0 + OGB 1.3.1 14 | 15 | Latest: Python 3.8 + PyTorch 1.9.0 + PyTorch\_Geometric 1.7.2 + OGB 1.3.1 16 | 17 | Install [PyTorch](https://pytorch.org/) 18 | 19 | Install [PyTorch\_Geometric](https://rusty1s.github.io/pytorch_geometric/build/html/notes/installation.html) 20 | 21 | Install [OGB](https://ogb.stanford.edu/docs/home/) 22 | 23 | Install rdkit by 24 | 25 | conda install -c conda-forge rdkit 26 | 27 | To run 1-GNN, 1-2-GNN, 1-3-GNN, 1-2-3-GNN and their nested versions on QM9, install k-gnn by executing 28 | 29 | python setup.py install 30 | 31 | under "software/k-gnn-master/". 32 | 33 | Other required python libraries include: numpy, scipy, tqdm etc. 34 | 35 | Usages 36 | ------ 37 | 38 | ### TU dataset 39 | 40 | To run Nested GCN on MUTAG (with subgraph height=3 and base GCN #layers=4), type: 41 | 42 | python run_tu.py --model NestedGCN --h 3 --layers 4 --node_label spd --use_rd --data MUTAG 43 | 44 | To compare it with a base GCN model only, type: 45 | 46 | python run_tu.py --model GCN --layers 4 --data MUTAG 47 | 48 | To reproduce the GCN and Nested GCN results in Table 4 with hyperparameter searching, type: 49 | 50 | python run_tu.py --model GCN --search --data MUTAG 51 | 52 | python run_tu.py --model NestedGCN --h 0 --search --node_label spd --use_rd --data MUTAG 53 | 54 | Replace with "--data all" and "--model all" to run all models (NestedGCN, NestedGraphSAGE, NestedGIN, NestedGAT) on all datasets. 55 | 56 | 57 | ### QM9 58 | 59 | We include the commands for reproducing the QM9 experiments in "run_all_targets_qm9.sh". Uncomment the corresponding command in this file, and then run 60 | 61 | ./run_all_targets_qm9.sh 0 11 62 | 63 | to execute this command repeatedly for all 12 targets. 64 | 65 | ### OGB molecular datasets 66 | 67 | To reproduce the ogb-molhiv experiment, run 68 | 69 | python run_ogb_mol.py --h 4 --num_layer 6 --save_appendix _h4_l6_spd_rd --dataset ogbg-molhiv --node_label spd --use_rd --drop_ratio 0.65 --runs 10 70 | 71 | When finished, to get the ensemble test result, run 72 | 73 | python run_ogb_mol.py --h 4 --num_layer 6 --save_appendix _h4_l6_spd_rd --dataset ogbg-molhiv --node_label spd --use_rd --drop_ratio 0.65 --runs 10 --continue_from 100 --ensemble 74 | 75 | To reproduce the ogb-molpcba experiment, run 76 | 77 | python run_ogb_mol.py --h 3 --num_layer 4 --save_appendix _h3_l4_spd_rd --dataset ogbg-molpcba --subgraph_pooling center --node_label spd --use_rd --drop_ratio 0.35 --epochs 150 --runs 10 78 | 79 | When finished, to get the ensemble test result, run 80 | 81 | python run_ogb_mol.py --h 3 --num_layer 4 --save_appendix _h3_l4_spd_rd --dataset ogbg-molpcba --subgraph_pooling center --node_label spd --use_rd --drop_ratio 0.35 --epochs 150 --runs 10 --continue_from 150 --ensemble --ensemble_lookback 140 82 | 83 | ### Simulation on r-regular graphs 84 | 85 | To reproduce Appendix D Figure 3, run the following commands: 86 | 87 | python run_simulation.py --n 10 20 40 80 160 320 640 1280 --save_appendix _node --N 10 --h 10 88 | 89 | python run_simulation.py --n 10 20 40 80 160 320 640 1280 --save_appendix _graph --N 100 --h 10 --graph 90 | 91 | The results will be saved in "results/simulation\_node/" and "results/simulation\_graph/". 
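For intuition, the NGNN pipeline described in the About section can be sketched in a few lines of PyTorch Geometric. This is a hypothetical, simplified illustration only: the class name `TinyNGNN` is made up, and it loops over nodes for clarity. It is not this repository's implementation; see `utils.create_subgraphs` and the `Nested*` models under `kernel/` for the real, batched code.

    import torch
    import torch.nn.functional as F
    from torch_geometric.nn import GCNConv
    from torch_geometric.utils import k_hop_subgraph

    class TinyNGNN(torch.nn.Module):  # made-up name, for illustration only
        def __init__(self, in_dim, hidden, num_classes, h=3):
            super().__init__()
            self.h = h                            # rooted-subgraph height
            self.conv1 = GCNConv(in_dim, hidden)  # base GNN, applied per subgraph
            self.conv2 = GCNConv(hidden, hidden)
            self.lin = torch.nn.Linear(hidden, num_classes)

        def forward(self, x, edge_index):
            root_reprs = []
            for v in range(x.size(0)):
                # extract the height-h rooted subgraph around node v
                nodes, sub_ei, _, _ = k_hop_subgraph(v, self.h, edge_index,
                                                     relabel_nodes=True)
                hv = F.relu(self.conv1(x[nodes], sub_ei))
                hv = F.relu(self.conv2(hv, sub_ei))
                root_reprs.append(hv.mean(dim=0))    # subgraph pooling -> root repr.
            g = torch.stack(root_reprs).mean(dim=0)  # outer "GNN" = global pooling
            return self.lin(g)

The actual code instead precomputes all rooted subgraphs once as a preprocessing transform (`create_subgraphs`) and pushes their disjoint union through the base GNN in parallel.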
92 | 93 | ### EXP dataset 94 | 95 | To reproduce the Nested GIN result in Table 2, run the following command: 96 | 97 | python run_exp.py --dataset EXP --h 3 --learnRate 0.0001 98 | 99 | Reference 100 | --------- 101 | 102 | If you find the code useful, please cite our paper: 103 | 104 | @article{zhang2021nested, 105 | title={Nested Graph Neural Networks}, 106 | author={Zhang, Muhan and Li, Pan}, 107 | journal={arXiv preprint arXiv:2110.13197}, 108 | year={2021} 109 | } 110 | 111 | Muhan Zhang\ 112 | Peking University\ 113 | muhan@pku.edu.cn\ 114 | 10/30/2021 115 | 116 | -------------------------------------------------------------------------------- /batch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch_geometric 3 | from torch_geometric.data import Data 4 | import pdb 5 | 6 | # This is a copy of torch_geometric/data/batch.py, 7 | # modified to support batch assignment at the subgraph level 8 | 9 | class Batch(Data): 10 | r"""A plain old python object modeling a batch of graphs as one big 11 | (disconnected) graph. With :class:`torch_geometric.data.Data` being the 12 | base class, all its methods can also be used here. 13 | In addition, single graphs can be reconstructed via the assignment vector 14 | :obj:`batch`, which maps each node to its respective graph identifier. 15 | """ 16 | def __init__(self, batch=None, **kwargs): 17 | super(Batch, self).__init__(**kwargs) 18 | 19 | self.batch = batch 20 | self.__data_class__ = Data 21 | self.__slices__ = None 22 | 23 | @staticmethod 24 | def from_data_list(data_list, follow_batch=[]): 25 | r"""Constructs a batch object from a python list holding 26 | :class:`torch_geometric.data.Data` objects. 27 | The assignment vector :obj:`batch` is created on the fly.
28 | Additionally, creates assignment batch vectors for each key in 29 | :obj:`follow_batch`.""" 30 | 31 | keys = [set(data.keys) for data in data_list] 32 | keys = list(set.union(*keys)) 33 | assert 'batch' not in keys 34 | 35 | batch = Batch() 36 | batch.__data_class__ = data_list[0].__class__ 37 | batch.__slices__ = {key: [0] for key in keys} 38 | 39 | for key in keys: 40 | batch[key] = [] 41 | 42 | for key in follow_batch: 43 | batch['{}_batch'.format(key)] = [] 44 | 45 | cumsum = {key: 0 for key in keys} 46 | if 'assignment_index_2' in keys: 47 | cumsum['assignment_index_2'] = torch.LongTensor([[0], [0]]) 48 | if 'assignment_index_3' in keys: 49 | cumsum['assignment_index_3'] = torch.LongTensor([[0], [0]]) 50 | batch.batch = [] 51 | for i, data in enumerate(data_list): 52 | for key in data.keys: 53 | item = data[key] 54 | if torch.is_tensor(item) and item.dtype != torch.bool: 55 | item = item + cumsum[key] 56 | if torch.is_tensor(item): 57 | size = item.size(data.__cat_dim__(key, data[key])) 58 | else: 59 | size = 1 60 | batch.__slices__[key].append(size + batch.__slices__[key][-1]) 61 | if key == 'node_to_subgraph': 62 | cumsum[key] = cumsum[key] + data.num_subgraphs 63 | elif key == 'subgraph_to_graph': 64 | cumsum[key] = cumsum[key] + 1 65 | elif key == 'original_edge_index': 66 | cumsum[key] = cumsum[key] + data.num_subgraphs 67 | elif key == 'tree_edge_index': 68 | cumsum[key] = cumsum[key] + data.num_cliques 69 | elif key == 'atom2clique_index': 70 | cumsum[key] = cumsum[key] + torch.tensor([[data.num_atoms], [data.num_cliques]]) 71 | elif key == 'edge_index_2': 72 | cumsum[key] = cumsum[key] + data.iso_type_2.shape[0] 73 | elif key == 'edge_index_3': 74 | cumsum[key] = cumsum[key] + data.iso_type_3.shape[0] 75 | elif key == 'batch_2': 76 | cumsum[key] = cumsum[key] + 1 77 | elif key == 'batch_3': 78 | cumsum[key] = cumsum[key] + 1 79 | elif key == 'assignment2_to_subgraph': 80 | cumsum[key] = cumsum[key] + data.num_subgraphs 81 | elif key == 'assignment3_to_subgraph': 82 | cumsum[key] = cumsum[key] + data.num_subgraphs 83 | elif key == 'assignment_index_2': 84 | cumsum[key] = cumsum[key] + torch.LongTensor([[data.num_nodes], [data.iso_type_2.shape[0]]]) 85 | elif key == 'assignment_index_3': 86 | inc = data.iso_type_2.shape[0] if 'assignment_index_2' in data else data.num_nodes 87 | cumsum[key] = cumsum[key] + torch.LongTensor([[inc], [data.iso_type_3.shape[0]]]) 88 | else: 89 | cumsum[key] = cumsum[key] + data.__inc__(key, item) 90 | batch[key].append(item) 91 | 92 | if key in follow_batch: 93 | item = torch.full((size, ), i, dtype=torch.long) 94 | batch['{}_batch'.format(key)].append(item) 95 | 96 | num_nodes = data.num_nodes 97 | if num_nodes is not None: 98 | item = torch.full((num_nodes, ), i, dtype=torch.long) 99 | batch.batch.append(item) 100 | 101 | if num_nodes is None: 102 | batch.batch = None 103 | 104 | for key in batch.keys: 105 | item = batch[key][0] 106 | if torch.is_tensor(item): 107 | batch[key] = torch.cat(batch[key], 108 | dim=data_list[0].__cat_dim__(key, item)) 109 | elif isinstance(item, int) or isinstance(item, float): 110 | batch[key] = torch.tensor(batch[key]) 111 | 112 | # Copy custom data functions to batch (does not work yet): 113 | # if data_list.__class__ != Data: 114 | # org_funcs = set(Data.__dict__.keys()) 115 | # funcs = set(data_list[0].__class__.__dict__.keys()) 116 | # batch.__custom_funcs__ = funcs.difference(org_funcs) 117 | # for func in funcs.difference(org_funcs): 118 | # setattr(batch, func, getattr(data_list[0], func)) 119 | 120 | if 
torch_geometric.is_debug_enabled(): 121 | batch.debug() 122 | 123 | return batch.contiguous() 124 | 125 | def to_data_list(self): 126 | r"""Reconstructs the list of :class:`torch_geometric.data.Data` objects 127 | from the batch object. 128 | The batch object must have been created via :meth:`from_data_list` in 129 | order to be able to reconstruct the initial objects.""" 130 | 131 | if self.__slices__ is None: 132 | raise RuntimeError( 133 | ('Cannot reconstruct data list from batch because the batch ' 134 | 'object was not created using Batch.from_data_list()')) 135 | 136 | keys = [key for key in self.keys if key[-5:] != 'batch'] 137 | cumsum = {key: 0 for key in keys} 138 | if 'assignment_index_2' in keys: 139 | cumsum['assignment_index_2'] = torch.LongTensor([[0], [0]]) 140 | if 'assignment_index_3' in keys: 141 | cumsum['assignment_index_3'] = torch.LongTensor([[0], [0]]) 142 | data_list = [] 143 | for i in range(len(self.__slices__[keys[0]]) - 1): 144 | data = self.__data_class__() 145 | for key in keys: 146 | if torch.is_tensor(self[key]): 147 | data[key] = self[key].narrow( 148 | data.__cat_dim__(key, 149 | self[key]), self.__slices__[key][i], 150 | self.__slices__[key][i + 1] - self.__slices__[key][i]) 151 | if self[key].dtype != torch.bool: 152 | data[key] = data[key] - cumsum[key] 153 | else: 154 | data[key] = self[key][self.__slices__[key][i]:self. 155 | __slices__[key][i + 1]] 156 | if key == 'node_to_subgraph': 157 | cumsum[key] = cumsum[key] + data.num_subgraphs 158 | elif key == 'subgraph_to_graph': 159 | cumsum[key] = cumsum[key] + 1 160 | elif key == 'original_edge_index': 161 | cumsum[key] = cumsum[key] + data.num_subgraphs 162 | elif key == 'tree_edge_index': 163 | cumsum[key] = cumsum[key] + data.num_cliques 164 | elif key == 'atom2clique_index': 165 | cumsum[key] = cumsum[key] + torch.tensor([[data.num_atoms], [data.num_cliques]]) 166 | elif key == 'edge_index_2': 167 | cumsum[key] = cumsum[key] + data.iso_type_2.shape[0] 168 | elif key == 'edge_index_3': 169 | cumsum[key] = cumsum[key] + data.iso_type_3.shape[0] 170 | elif key == 'batch_2': 171 | cumsum[key] = cumsum[key] + 1 172 | elif key == 'batch_3': 173 | cumsum[key] = cumsum[key] + 1 174 | elif key == 'assignment2_to_subgraph': 175 | cumsum[key] = cumsum[key] + data.num_subgraphs 176 | elif key == 'assignment3_to_subgraph': 177 | cumsum[key] = cumsum[key] + data.num_subgraphs 178 | elif key == 'assignment_index_2': 179 | cumsum[key] = cumsum[key] + torch.LongTensor([[data.num_nodes], [data.iso_type_2.shape[0]]]) 180 | elif key == 'assignment_index_3': 181 | cumsum[key] = cumsum[key] + torch.LongTensor([[data.iso_type_2.shape[0]], [data.iso_type_3.shape[0]]]) 182 | else: 183 | cumsum[key] = cumsum[key] + data.__inc__(key, data[key]) 184 | data_list.append(data) 185 | 186 | return data_list 187 | 188 | @property 189 | def num_graphs(self): 190 | """Returns the number of graphs in the batch.""" 191 | return self.batch[-1].item() + 1 192 | -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | from torch.utils.data.dataloader import default_collate 3 | 4 | from torch_geometric.data import Data 5 | from batch import Batch # replace with custom Batch to handle subgraphs 6 | import collections.abc as container_abcs 7 | int_classes = int 8 | string_classes = (str, bytes) 9 | 10 | 11 | class DataLoader(torch.utils.data.DataLoader): 12 | r"""Data loader which merges data
objects from a 13 | :class:`torch_geometric.data.dataset` to a mini-batch. 14 | 15 | Args: 16 | dataset (Dataset): The dataset from which to load the data. 17 | batch_size (int, optional): How many samples per batch to load. 18 | (default: :obj:`1`) 19 | shuffle (bool, optional): If set to :obj:`True`, the data will be 20 | reshuffled at every epoch. (default: :obj:`False`) 21 | follow_batch (list or tuple, optional): Creates assignment batch 22 | vectors for each key in the list. (default: :obj:`[]`) 23 | """ 24 | def __init__(self, dataset, batch_size=1, shuffle=False, follow_batch=[], 25 | **kwargs): 26 | def collate(batch): 27 | elem = batch[0] 28 | if isinstance(elem, Data): 29 | return Batch.from_data_list(batch, follow_batch) 30 | elif isinstance(elem, float): 31 | return torch.tensor(batch, dtype=torch.float) 32 | elif isinstance(elem, int_classes): 33 | return torch.tensor(batch) 34 | elif isinstance(elem, string_classes): 35 | return batch 36 | elif isinstance(elem, container_abcs.Mapping): 37 | return {key: collate([d[key] for d in batch]) for key in elem} 38 | elif isinstance(elem, tuple) and hasattr(elem, '_fields'): 39 | return type(elem)(*(collate(s) for s in zip(*batch))) 40 | elif isinstance(elem, container_abcs.Sequence): 41 | return [collate(s) for s in zip(*batch)] 42 | 43 | raise TypeError('DataLoader found invalid type: {}'.format( 44 | type(elem))) 45 | 46 | super(DataLoader, 47 | self).__init__(dataset, batch_size, shuffle, 48 | collate_fn=lambda batch: collate(batch), **kwargs) 49 | 50 | 51 | class DataListLoader(torch.utils.data.DataLoader): 52 | r"""Data loader which merges data objects from a 53 | :class:`torch_geometric.data.dataset` to a python list. 54 | 55 | .. note:: 56 | 57 | This data loader should be used for multi-gpu support via 58 | :class:`torch_geometric.nn.DataParallel`. 59 | 60 | Args: 61 | dataset (Dataset): The dataset from which to load the data. 62 | batch_size (int, optional): How many samples per batch to load. 63 | (default: :obj:`1`) 64 | shuffle (bool, optional): If set to :obj:`True`, the data will be 65 | reshuffled at every epoch (default: :obj:`False`) 66 | """ 67 | def __init__(self, dataset, batch_size=1, shuffle=False, **kwargs): 68 | super(DataListLoader, 69 | self).__init__(dataset, batch_size, shuffle, 70 | collate_fn=lambda data_list: data_list, **kwargs) 71 | 72 | 73 | class DenseDataLoader(torch.utils.data.DataLoader): 74 | r"""Data loader which merges data objects from a 75 | :class:`torch_geometric.data.dataset` to a mini-batch. 76 | 77 | .. note:: 78 | 79 | To make use of this data loader, all graphs in the dataset need to 80 | have the same shape for each of their attributes. 81 | Therefore, this data loader should only be used when working with 82 | *dense* adjacency matrices. 83 | 84 | Args: 85 | dataset (Dataset): The dataset from which to load the data. 86 | batch_size (int, optional): How many samples per batch to load.
87 | (default: :obj:`1`) 88 | shuffle (bool, optional): If set to :obj:`True`, the data will be 89 | reshuffled at every epoch (default: :obj:`False`) 90 | """ 91 | def __init__(self, dataset, batch_size=1, shuffle=False, **kwargs): 92 | def dense_collate(data_list): 93 | batch = Batch() 94 | for key in data_list[0].keys: 95 | batch[key] = default_collate([d[key] for d in data_list]) 96 | return batch 97 | 98 | super(DenseDataLoader, 99 | self).__init__(dataset, batch_size, shuffle, 100 | collate_fn=dense_collate, **kwargs) 101 | -------------------------------------------------------------------------------- /distance.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pdb 3 | 4 | 5 | class Distance(object): 6 | r"""Saves the Euclidean distance of linked nodes in its edge attributes. 7 | 8 | Args: 9 | norm (bool, optional): If set to :obj:`False`, the output will not be 10 | normalized to the interval :math:`[0, 1]`. (default: :obj:`True`) 11 | max_value (float, optional): If set and :obj:`norm=True`, normalization 12 | will be performed based on this value instead of the maximum value 13 | found in the data. (default: :obj:`None`) 14 | cat (bool, optional): If set to :obj:`False`, all existing edge 15 | attributes will be replaced. (default: :obj:`True`) 16 | """ 17 | def __init__(self, norm=True, max_value=None, cat=True, relative_pos=False, 18 | squared=False): 19 | self.norm = norm 20 | self.max = max_value 21 | self.cat = cat 22 | self.relative_pos = relative_pos 23 | self.squared = squared 24 | 25 | def __call__(self, data): 26 | if type(data) == dict: 27 | return {key: self.__call__(data_) for key, data_ in data.items()} 28 | 29 | (row, col), pos, pseudo = data.edge_index, data.pos, data.edge_attr 30 | 31 | if self.squared: 32 | dist = ((pos[col] - pos[row]) ** 2).sum(1).view(-1, 1) 33 | else: 34 | dist = torch.norm(pos[col] - pos[row], p=2, dim=-1).view(-1, 1) 35 | 36 | if self.norm and dist.numel() > 0: 37 | dist = dist / (dist.max() if self.max is None else self.max) 38 | 39 | if pseudo is not None and self.cat: 40 | pseudo = pseudo.view(-1, 1) if pseudo.dim() == 1 else pseudo 41 | data.edge_attr = torch.cat([pseudo, dist.type_as(pseudo)], dim=-1) 42 | else: 43 | data.edge_attr = dist 44 | 45 | if self.relative_pos: 46 | relative_pos = pos[col] - pos[row] 47 | data.edge_attr = torch.cat([data.edge_attr, relative_pos], dim=-1) 48 | 49 | if "original_edge_index" in data: 50 | (row, col), pos, pseudo = ( 51 | data.original_edge_index, data.original_pos, data.original_edge_attr 52 | ) 53 | 54 | dist = torch.norm(pos[col] - pos[row], p=2, dim=-1).view(-1, 1) 55 | 56 | if self.norm and dist.numel() > 0: 57 | dist = dist / (dist.max() if self.max is None else self.max) 58 | 59 | if pseudo is not None and self.cat: 60 | pseudo = pseudo.view(-1, 1) if pseudo.dim() == 1 else pseudo 61 | data.original_edge_attr = torch.cat([pseudo, dist.type_as(pseudo)], dim=-1) 62 | else: 63 | data.original_edge_attr = dist 64 | 65 | return data 66 | 67 | def __repr__(self): 68 | return '{}(norm={}, max_value={})'.format(self.__class__.__name__, 69 | self.norm, self.max) 70 | -------------------------------------------------------------------------------- /kernel/README.md: -------------------------------------------------------------------------------- 1 | # Graph Classification 2 | 3 | Evaluation script for various methods on [common benchmark datasets](http://graphkernels.cs.tu-dortmund.de) via 10-fold cross validation, where a training fold 
is randomly sampled to serve as a validation set. 4 | Hyperparameter selection is performed for the number of hidden units and the number of layers with respect to the validation set: 5 | 6 | * **[GCN](https://github.com/rusty1s/pytorch_geometric/blob/master/benchmark/kernel/gcn.py)** 7 | * **[GraphSAGE](https://github.com/rusty1s/pytorch_geometric/blob/master/benchmark/kernel/graph_sage.py)** 8 | * **[GIN](https://github.com/rusty1s/pytorch_geometric/blob/master/benchmark/kernel/gin.py)** 9 | * **[Graclus](https://github.com/rusty1s/pytorch_geometric/blob/master/benchmark/kernel/graclus.py)** 10 | * **[Top-K-Pooling](https://github.com/rusty1s/pytorch_geometric/blob/master/benchmark/kernel/top_k.py)** 11 | * **[DiffPool](https://github.com/rusty1s/pytorch_geometric/blob/master/benchmark/kernel/diff_pool.py)** 12 | * **[GlobalAttention](https://github.com/rusty1s/pytorch_geometric/blob/master/benchmark/kernel/global_attention.py)** 13 | * **[Set2Set](https://github.com/rusty1s/pytorch_geometric/blob/master/benchmark/kernel/set2set.py)** 14 | * **[SortPool](https://github.com/rusty1s/pytorch_geometric/blob/master/benchmark/kernel/sort_pool.py)** 15 | 16 | Run (or modify) the whole test suite via 17 | 18 | ``` 19 | $ python main.py 20 | ``` 21 | -------------------------------------------------------------------------------- /kernel/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import get_dataset 2 | from .train_eval import cross_validation_with_val_set 3 | 4 | __all__ = [ 5 | 'get_dataset', 6 | 'cross_validation_with_val_set', 7 | ] 8 | -------------------------------------------------------------------------------- /kernel/datasets.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import sys, os 3 | from shutil import rmtree 4 | import torch 5 | #from torch_geometric.datasets import TUDataset 6 | from torch_geometric.utils import degree 7 | import torch_geometric.transforms as T 8 | sys.path.append('%s/../' % os.path.dirname(os.path.realpath(__file__))) 9 | sys.path.append('%s/' % os.path.dirname(os.path.realpath(__file__))) 10 | from utils import create_subgraphs, return_prob 11 | from tu_dataset import TUDataset 12 | import pdb 13 | 14 | 15 | class NormalizedDegree(object): 16 | def __init__(self, mean, std): 17 | self.mean = mean 18 | self.std = std 19 | 20 | def __call__(self, data): 21 | deg = degree(data.edge_index[0], dtype=torch.float) 22 | deg = (deg - self.mean) / self.std 23 | data.x = deg.view(-1, 1) 24 | return data 25 | 26 | 27 | def get_dataset(name, sparse=True, h=None, node_label='hop', use_rd=False, 28 | use_rp=None, reprocess=False, clean=False, max_nodes_per_hop=None): 29 | path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data') 30 | pre_transform = None 31 | if h is not None: 32 | path += '/ngnn_h' + str(h) 33 | path += '_' + node_label 34 | if use_rd: 35 | path += '_rd' 36 | if max_nodes_per_hop is not None: 37 | path += '_mnph{}'.format(max_nodes_per_hop) 38 | 39 | pre_transform = lambda x: create_subgraphs(x, h, 1.0, max_nodes_per_hop, node_label, use_rd) 40 | 41 | if use_rp is not None: # use RW return probability as additional features 42 | path += f'_rp{use_rp}' 43 | if pre_transform is None: 44 | pre_transform = return_prob(use_rp) 45 | else: 46 | pre_transform = T.Compose([return_prob(use_rp), pre_transform]) 47 | 48 | if reprocess and os.path.isdir(path): 49 | rmtree(path) 50 | 51 | print(path) 52 | dataset = 
TUDataset(path, name, pre_transform=pre_transform, cleaned=clean) 53 | dataset.data.edge_attr = None 54 | 55 | if dataset.data.x is None: 56 | max_degree = 0 57 | degs = [] 58 | for data in dataset: 59 | degs += [degree(data.edge_index[0], dtype=torch.long)] 60 | max_degree = max(max_degree, degs[-1].max().item()) 61 | 62 | if max_degree < 1000: 63 | dataset.transform = T.OneHotDegree(max_degree) 64 | else: 65 | deg = torch.cat(degs, dim=0).to(torch.float) 66 | mean, std = deg.mean().item(), deg.std().item() 67 | dataset.transform = NormalizedDegree(mean, std) 68 | 69 | if not sparse: 70 | num_nodes = max_num_nodes = 0 71 | for data in dataset: 72 | num_nodes += data.num_nodes 73 | max_num_nodes = max(data.num_nodes, max_num_nodes) 74 | if name == 'REDDIT-BINARY': 75 | num_nodes = min(int(num_nodes / len(dataset) * 1.5), max_num_nodes) 76 | else: 77 | num_nodes = min(int(num_nodes / len(dataset) * 5), max_num_nodes) 78 | 79 | indices = [] 80 | for i, data in enumerate(dataset): 81 | if data.num_nodes <= num_nodes: 82 | indices.append(i) 83 | dataset = dataset[torch.tensor(indices)] 84 | 85 | if dataset.transform is None: 86 | dataset.transform = T.ToDense(num_nodes) 87 | else: 88 | dataset.transform = T.Compose( 89 | [dataset.transform, T.ToDense(num_nodes)]) 90 | 91 | return dataset 92 | -------------------------------------------------------------------------------- /kernel/diff_pool.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.nn import Linear 6 | from torch_geometric.nn import GCNConv, global_mean_pool 7 | from torch_geometric.nn import DenseSAGEConv, dense_diff_pool 8 | import torch_geometric.transforms as T 9 | from torch_geometric.data import Data 10 | from torch_geometric.utils import to_dense_batch, to_dense_adj 11 | import pdb 12 | 13 | 14 | class NestedDiffPool(torch.nn.Module): 15 | def __init__(self, dataset, num_layers, hidden, use_z=False, use_rd=False): 16 | super(NestedDiffPool, self).__init__() 17 | self.use_rd = use_rd 18 | self.use_z = use_z 19 | if self.use_rd: 20 | self.rd_projection = torch.nn.Linear(1, 8) 21 | if self.use_z: 22 | self.z_embedding = torch.nn.Embedding(1000, 8) 23 | input_dim = dataset.num_features 24 | if self.use_z or self.use_rd: 25 | input_dim += 8 26 | 27 | self.conv1 = GCNConv(input_dim, hidden) 28 | self.convs = torch.nn.ModuleList() 29 | for i in range(num_layers - 1): 30 | self.convs.append(GCNConv(hidden, hidden)) 31 | 32 | num_nodes = ceil(0.25 * dataset[0].num_subgraphs) 33 | self.embed_block1 = Block(hidden, hidden, hidden) 34 | self.pool_block1 = Block(hidden, hidden, num_nodes) 35 | self.embed_blocks = torch.nn.ModuleList() 36 | self.pool_blocks = torch.nn.ModuleList() 37 | #for i in range((num_layers // 2) - 1): 38 | for i in range(0): 39 | num_nodes = ceil(0.25 * num_nodes) 40 | self.embed_blocks.append(Block(hidden, hidden, hidden)) 41 | self.pool_blocks.append(Block(hidden, hidden, num_nodes)) 42 | 43 | self.lin1 = Linear((len(self.embed_blocks) + 1) * hidden, hidden) 44 | self.lin2 = Linear(hidden, dataset.num_classes) 45 | 46 | def reset_parameters(self): 47 | if self.use_rd: 48 | self.rd_projection.reset_parameters() 49 | if self.use_z: 50 | self.z_embedding.reset_parameters() 51 | self.conv1.reset_parameters() 52 | for conv in self.convs: 53 | conv.reset_parameters() 54 | self.lin1.reset_parameters() 55 | self.lin2.reset_parameters() 56 | self.embed_block1.reset_parameters() 57 | 
self.pool_block1.reset_parameters() 58 | for block1, block2 in zip(self.embed_blocks, self.pool_blocks): 59 | block1.reset_parameters() 60 | block2.reset_parameters() 61 | self.lin1.reset_parameters() 62 | self.lin2.reset_parameters() 63 | 64 | def forward(self, data): 65 | x, edge_index, batch = data.x, data.edge_index, data.batch 66 | 67 | # node label embedding 68 | z_emb = 0 69 | if self.use_z and 'z' in data: 70 | ### computing input node embedding 71 | z_emb = self.z_embedding(data.z) 72 | if z_emb.ndim == 3: 73 | z_emb = z_emb.sum(dim=1) 74 | 75 | if self.use_rd and 'rd' in data: 76 | rd_proj = self.rd_projection(data.rd) 77 | z_emb += rd_proj 78 | 79 | if self.use_rd or self.use_z: 80 | x = torch.cat([z_emb, x], -1) 81 | 82 | x = F.relu(self.conv1(x, edge_index)) 83 | xs = [x] 84 | for conv in self.convs: 85 | x = F.relu(conv(x, edge_index)) 86 | xs += [x] 87 | x = global_mean_pool(xs[-1], data.node_to_subgraph) 88 | 89 | x, mask = to_dense_batch(x, data.subgraph_to_graph) 90 | adj = to_dense_adj(data.original_edge_index, data.subgraph_to_graph) 91 | 92 | s = self.pool_block1(x, adj, mask, add_loop=True) 93 | x = F.relu(self.embed_block1(x, adj, mask, add_loop=True)) 94 | xs = [x.mean(dim=1)] 95 | x, adj, _, _ = dense_diff_pool(x, adj, s, mask) 96 | 97 | for embed, pool in zip(self.embed_blocks, self.pool_blocks): 98 | s = pool(x, adj) 99 | x = F.relu(embed(x, adj)) 100 | xs.append(x.mean(dim=1)) 101 | x, adj, _, _ = dense_diff_pool(x, adj, s) 102 | 103 | x = torch.cat(xs, dim=1) 104 | x = F.relu(self.lin1(x)) 105 | x = F.dropout(x, p=0.5, training=self.training) 106 | x = self.lin2(x) 107 | 108 | return F.log_softmax(x, dim=-1) 109 | 110 | def __repr__(self): 111 | return self.__class__.__name__ 112 | 113 | 114 | class Block(torch.nn.Module): 115 | def __init__(self, in_channels, hidden_channels, out_channels): 116 | super(Block, self).__init__() 117 | 118 | self.conv1 = DenseSAGEConv(in_channels, hidden_channels) 119 | self.conv2 = DenseSAGEConv(hidden_channels, out_channels) 120 | 121 | self.lin = torch.nn.Linear(hidden_channels + out_channels, 122 | out_channels) 123 | 124 | def reset_parameters(self): 125 | self.conv1.reset_parameters() 126 | self.conv2.reset_parameters() 127 | self.lin.reset_parameters() 128 | 129 | def forward(self, x, adj, mask=None, add_loop=True): 130 | x1 = F.relu(self.conv1(x, adj, mask)) 131 | x2 = F.relu(self.conv2(x1, adj, mask)) 132 | return self.lin(torch.cat([x1, x2], dim=-1)) 133 | 134 | 135 | class DiffPool(torch.nn.Module): 136 | def __init__(self, dataset, num_layers, hidden, *kwargs): 137 | super(DiffPool, self).__init__() 138 | 139 | num_nodes = ceil(0.25 * dataset[0].num_nodes) 140 | self.embed_block1 = Block(dataset.num_features, hidden, hidden) 141 | self.pool_block1 = Block(dataset.num_features, hidden, num_nodes) 142 | 143 | self.embed_blocks = torch.nn.ModuleList() 144 | self.pool_blocks = torch.nn.ModuleList() 145 | for i in range((num_layers // 2) - 1): 146 | num_nodes = ceil(0.25 * num_nodes) 147 | self.embed_blocks.append(Block(hidden, hidden, hidden)) 148 | self.pool_blocks.append(Block(hidden, hidden, num_nodes)) 149 | 150 | self.lin1 = Linear((len(self.embed_blocks) + 1) * hidden, hidden) 151 | self.lin2 = Linear(hidden, dataset.num_classes) 152 | 153 | def reset_parameters(self): 154 | self.embed_block1.reset_parameters() 155 | self.pool_block1.reset_parameters() 156 | for block1, block2 in zip(self.embed_blocks, self.pool_blocks): 157 | block1.reset_parameters() 158 | block2.reset_parameters() 159 | 
self.lin1.reset_parameters() 160 | self.lin2.reset_parameters() 161 | 162 | def forward(self, data): 163 | x, adj, mask = data.x, data.adj, data.mask 164 | 165 | s = self.pool_block1(x, adj, mask, add_loop=True) 166 | x = F.relu(self.embed_block1(x, adj, mask, add_loop=True)) 167 | xs = [x.mean(dim=1)] 168 | x, adj, _, _ = dense_diff_pool(x, adj, s, mask) 169 | 170 | for embed, pool in zip(self.embed_blocks, self.pool_blocks): 171 | s = pool(x, adj) 172 | x = F.relu(embed(x, adj)) 173 | xs.append(x.mean(dim=1)) 174 | x, adj, _, _ = dense_diff_pool(x, adj, s) 175 | 176 | x = torch.cat(xs, dim=1) 177 | x = F.relu(self.lin1(x)) 178 | x = F.dropout(x, p=0.5, training=self.training) 179 | x = self.lin2(x) 180 | return F.log_softmax(x, dim=-1) 181 | 182 | def __repr__(self): 183 | return self.__class__.__name__ 184 | -------------------------------------------------------------------------------- /kernel/gat.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.nn import Linear 4 | from torch_geometric.nn import GATConv, global_mean_pool 5 | import pdb 6 | 7 | 8 | class NestedGAT(torch.nn.Module): 9 | def __init__(self, dataset, num_layers, hidden, use_z=False, use_rd=False): 10 | super(NestedGAT, self).__init__() 11 | self.use_rd = use_rd 12 | self.use_z = use_z 13 | if self.use_rd: 14 | self.rd_projection = torch.nn.Linear(1, 8) 15 | if self.use_z: 16 | self.z_embedding = torch.nn.Embedding(1000, 8) 17 | input_dim = dataset.num_features 18 | if self.use_z or self.use_rd: 19 | input_dim += 8 20 | 21 | self.conv1 = GATConv(input_dim, hidden) 22 | self.convs = torch.nn.ModuleList() 23 | for i in range(num_layers - 1): 24 | self.convs.append(GATConv(hidden, hidden)) 25 | self.lin1 = torch.nn.Linear(num_layers * hidden, hidden) 26 | self.lin2 = Linear(hidden, dataset.num_classes) 27 | 28 | def reset_parameters(self): 29 | if self.use_rd: 30 | self.rd_projection.reset_parameters() 31 | if self.use_z: 32 | self.z_embedding.reset_parameters() 33 | self.conv1.reset_parameters() 34 | for conv in self.convs: 35 | conv.reset_parameters() 36 | self.lin1.reset_parameters() 37 | self.lin2.reset_parameters() 38 | 39 | def forward(self, data): 40 | x, edge_index, batch = data.x, data.edge_index, data.batch 41 | 42 | # node label embedding 43 | z_emb = 0 44 | if self.use_z and 'z' in data: 45 | ### computing input node embedding 46 | z_emb = self.z_embedding(data.z) 47 | if z_emb.ndim == 3: 48 | z_emb = z_emb.sum(dim=1) 49 | 50 | if self.use_rd and 'rd' in data: 51 | rd_proj = self.rd_projection(data.rd) 52 | z_emb += rd_proj 53 | 54 | if self.use_rd or self.use_z: 55 | x = torch.cat([z_emb, x], -1) 56 | 57 | x = F.relu(self.conv1(x, edge_index)) 58 | xs = [x] 59 | for conv in self.convs: 60 | x = F.relu(conv(x, edge_index)) 61 | xs += [x] 62 | x = global_mean_pool(torch.cat(xs, dim=1), data.node_to_subgraph) 63 | x = global_mean_pool(x, data.subgraph_to_graph) 64 | x = F.relu(self.lin1(x)) 65 | x = F.dropout(x, p=0.5, training=self.training) 66 | x = self.lin2(x) 67 | return F.log_softmax(x, dim=-1) 68 | 69 | def __repr__(self): 70 | return self.__class__.__name__ 71 | 72 | 73 | class GAT(torch.nn.Module): 74 | def __init__(self, dataset, num_layers, hidden, *args, **kwargs): 75 | super(GAT, self).__init__() 76 | self.conv1 = GATConv(dataset.num_features, hidden) 77 | self.convs = torch.nn.ModuleList() 78 | for i in range(num_layers - 1): 79 | self.convs.append(GATConv(hidden, hidden)) 80 | self.lin1 = 
torch.nn.Linear(num_layers * hidden, hidden) 81 | self.lin2 = Linear(hidden, dataset.num_classes) 82 | 83 | def reset_parameters(self): 84 | self.conv1.reset_parameters() 85 | for conv in self.convs: 86 | conv.reset_parameters() 87 | self.lin1.reset_parameters() 88 | self.lin2.reset_parameters() 89 | 90 | def forward(self, data): 91 | x, edge_index, batch = data.x, data.edge_index, data.batch 92 | x = F.relu(self.conv1(x, edge_index)) 93 | xs = [x] 94 | for conv in self.convs: 95 | x = F.relu(conv(x, edge_index)) 96 | xs += [x] 97 | x = global_mean_pool(torch.cat(xs, dim=1), batch) 98 | x = F.relu(self.lin1(x)) 99 | x = F.dropout(x, p=0.5, training=self.training) 100 | x = self.lin2(x) 101 | return F.log_softmax(x, dim=-1) 102 | 103 | def __repr__(self): 104 | return self.__class__.__name__ 105 | -------------------------------------------------------------------------------- /kernel/gcn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.nn import Linear 4 | from torch_geometric.nn import GCNConv, global_mean_pool 5 | import pdb 6 | 7 | 8 | class NestedGCN(torch.nn.Module): 9 | def __init__(self, dataset, num_layers, hidden, use_z=False, use_rd=False): 10 | super(NestedGCN, self).__init__() 11 | self.use_rd = use_rd 12 | self.use_z = use_z 13 | if self.use_rd: 14 | self.rd_projection = torch.nn.Linear(1, 8) 15 | if self.use_z: 16 | self.z_embedding = torch.nn.Embedding(1000, 8) 17 | input_dim = dataset.num_features 18 | if self.use_z or self.use_rd: 19 | input_dim += 8 20 | 21 | self.conv1 = GCNConv(input_dim, hidden) 22 | self.convs = torch.nn.ModuleList() 23 | for i in range(num_layers - 1): 24 | self.convs.append(GCNConv(hidden, hidden)) 25 | self.lin1 = torch.nn.Linear(num_layers * hidden, hidden) 26 | self.lin2 = Linear(hidden, dataset.num_classes) 27 | 28 | def reset_parameters(self): 29 | if self.use_rd: 30 | self.rd_projection.reset_parameters() 31 | if self.use_z: 32 | self.z_embedding.reset_parameters() 33 | self.conv1.reset_parameters() 34 | for conv in self.convs: 35 | conv.reset_parameters() 36 | self.lin1.reset_parameters() 37 | self.lin2.reset_parameters() 38 | 39 | def forward(self, data): 40 | x, edge_index, batch = data.x, data.edge_index, data.batch 41 | 42 | # node label embedding 43 | z_emb = 0 44 | if self.use_z and 'z' in data: 45 | ### computing input node embedding 46 | z_emb = self.z_embedding(data.z) 47 | if z_emb.ndim == 3: 48 | z_emb = z_emb.sum(dim=1) 49 | 50 | if self.use_rd and 'rd' in data: 51 | rd_proj = self.rd_projection(data.rd) 52 | z_emb += rd_proj 53 | 54 | if self.use_rd or self.use_z: 55 | x = torch.cat([z_emb, x], -1) 56 | 57 | x = F.relu(self.conv1(x, edge_index)) 58 | xs = [x] 59 | for conv in self.convs: 60 | x = F.relu(conv(x, edge_index)) 61 | xs += [x] 62 | x = global_mean_pool(torch.cat(xs, dim=1), data.node_to_subgraph) 63 | x = global_mean_pool(x, data.subgraph_to_graph) 64 | x = F.relu(self.lin1(x)) 65 | x = F.dropout(x, p=0.5, training=self.training) 66 | x = self.lin2(x) 67 | return F.log_softmax(x, dim=-1) 68 | 69 | def __repr__(self): 70 | return self.__class__.__name__ 71 | 72 | 73 | class GCN(torch.nn.Module): 74 | def __init__(self, dataset, num_layers, hidden, *args, **kwargs): 75 | super(GCN, self).__init__() 76 | self.conv1 = GCNConv(dataset.num_features, hidden) 77 | self.convs = torch.nn.ModuleList() 78 | for i in range(num_layers - 1): 79 | self.convs.append(GCNConv(hidden, hidden)) 80 | self.lin1 = 
torch.nn.Linear(num_layers * hidden, hidden) 81 | self.lin2 = Linear(hidden, dataset.num_classes) 82 | 83 | def reset_parameters(self): 84 | self.conv1.reset_parameters() 85 | for conv in self.convs: 86 | conv.reset_parameters() 87 | self.lin1.reset_parameters() 88 | self.lin2.reset_parameters() 89 | 90 | def forward(self, data): 91 | x, edge_index, batch = data.x, data.edge_index, data.batch 92 | x = F.relu(self.conv1(x, edge_index)) 93 | xs = [x] 94 | for conv in self.convs: 95 | x = F.relu(conv(x, edge_index)) 96 | xs += [x] 97 | x = global_mean_pool(torch.cat(xs, dim=1), batch) 98 | x = F.relu(self.lin1(x)) 99 | x = F.dropout(x, p=0.5, training=self.training) 100 | x = self.lin2(x) 101 | return F.log_softmax(x, dim=-1) 102 | 103 | def __repr__(self): 104 | return self.__class__.__name__ 105 | -------------------------------------------------------------------------------- /kernel/gin.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.nn import Linear, Sequential, ReLU, BatchNorm1d as BN 4 | from torch_geometric.nn import GINConv, global_mean_pool, global_add_pool 5 | 6 | 7 | class NestedGIN(torch.nn.Module): 8 | def __init__(self, dataset, num_layers, hidden, use_z=False, use_rd=False): 9 | super(NestedGIN, self).__init__() 10 | self.use_rd = use_rd 11 | self.use_z = use_z 12 | if self.use_rd: 13 | self.rd_projection = torch.nn.Linear(1, 8) 14 | if self.use_z: 15 | self.z_embedding = torch.nn.Embedding(1000, 8) 16 | input_dim = dataset.num_features 17 | if self.use_z or self.use_rd: 18 | input_dim += 8 19 | 20 | self.conv1 = GINConv( 21 | Sequential( 22 | Linear(input_dim, hidden), 23 | BN(hidden), 24 | ReLU(), 25 | Linear(hidden, hidden), 26 | BN(hidden), 27 | ReLU(), 28 | ), 29 | train_eps=True) 30 | self.convs = torch.nn.ModuleList() 31 | for i in range(num_layers - 1): 32 | self.convs.append( 33 | GINConv( 34 | Sequential( 35 | Linear(hidden, hidden), 36 | BN(hidden), 37 | ReLU(), 38 | Linear(hidden, hidden), 39 | BN(hidden), 40 | ReLU(), 41 | ), 42 | train_eps=True)) 43 | self.lin1 = torch.nn.Linear(num_layers * hidden, hidden) 44 | self.lin2 = Linear(hidden, dataset.num_classes) 45 | 46 | def reset_parameters(self): 47 | if self.use_rd: 48 | self.rd_projection.reset_parameters() 49 | if self.use_z: 50 | self.z_embedding.reset_parameters() 51 | self.conv1.reset_parameters() 52 | for conv in self.convs: 53 | conv.reset_parameters() 54 | self.lin1.reset_parameters() 55 | self.lin2.reset_parameters() 56 | 57 | def forward(self, data): 58 | x, edge_index, batch = data.x, data.edge_index, data.batch 59 | 60 | # node label embedding 61 | z_emb = 0 62 | if self.use_z and 'z' in data: 63 | ### computing input node embedding 64 | z_emb = self.z_embedding(data.z) 65 | if z_emb.ndim == 3: 66 | z_emb = z_emb.sum(dim=1) 67 | 68 | if self.use_rd and 'rd' in data: 69 | rd_proj = self.rd_projection(data.rd) 70 | z_emb += rd_proj 71 | 72 | if self.use_rd or self.use_z: 73 | x = torch.cat([z_emb, x], -1) 74 | 75 | x = self.conv1(x, edge_index) 76 | xs = [x] 77 | for conv in self.convs: 78 | x = conv(x, edge_index) 79 | xs += [x] 80 | 81 | x = global_mean_pool(torch.cat(xs, dim=1), data.node_to_subgraph) 82 | #x = global_add_pool(x, data.subgraph_to_graph) 83 | x = global_mean_pool(x, data.subgraph_to_graph) 84 | x = F.relu(self.lin1(x)) 85 | x = F.dropout(x, p=0.5, training=self.training) 86 | x = self.lin2(x) 87 | 88 | return F.log_softmax(x, dim=-1) 89 | 90 | def __repr__(self): 91 | return 
self.__class__.__name__ 92 | 93 | 94 | class GIN0(torch.nn.Module): 95 | def __init__(self, dataset, num_layers, hidden, subconv=False): 96 | super(GIN0, self).__init__() 97 | self.subconv = subconv 98 | self.conv1 = GINConv( 99 | Sequential( 100 | Linear(dataset.num_features, hidden), 101 | BN(hidden), 102 | ReLU(), 103 | Linear(hidden, hidden), 104 | BN(hidden), 105 | ReLU(), 106 | ), 107 | train_eps=False) 108 | self.convs = torch.nn.ModuleList() 109 | for i in range(num_layers - 1): 110 | self.convs.append( 111 | GINConv( 112 | Sequential( 113 | Linear(hidden, hidden), 114 | BN(hidden), 115 | ReLU(), 116 | Linear(hidden, hidden), 117 | BN(hidden), 118 | ReLU(), 119 | ), 120 | train_eps=False)) 121 | self.lin1 = torch.nn.Linear(num_layers * hidden, hidden) 122 | self.lin2 = Linear(hidden, dataset.num_classes) 123 | 124 | def reset_parameters(self): 125 | self.conv1.reset_parameters() 126 | for conv in self.convs: 127 | conv.reset_parameters() 128 | self.lin1.reset_parameters() 129 | self.lin2.reset_parameters() 130 | 131 | def forward(self, data): 132 | x, edge_index, batch = data.x, data.edge_index, data.batch 133 | x = self.conv1(x, edge_index) 134 | xs = [x] 135 | for conv in self.convs: 136 | x = conv(x, edge_index) 137 | xs += [x] 138 | if True: 139 | if self.subconv: 140 | x = global_mean_pool(torch.cat(xs, dim=1), data.node_to_subgraph) 141 | x = global_add_pool(x, data.subgraph_to_graph) 142 | x = F.relu(self.lin1(x)) 143 | x = F.dropout(x, p=0.5, training=self.training) 144 | x = self.lin2(x) 145 | else: 146 | x = global_add_pool(torch.cat(xs, dim=1), batch) 147 | x = F.relu(self.lin1(x)) 148 | x = F.dropout(x, p=0.5, training=self.training) 149 | x = self.lin2(x) 150 | else: # GIN pooling in the paper 151 | xs = [global_add_pool(x, batch) for x in xs] 152 | xs = [F.dropout(self.lin2(x), p=0.5, training=self.training) for x in xs] 153 | x = 0 154 | for x_ in xs: 155 | x += x_ 156 | 157 | return F.log_softmax(x, dim=-1) 158 | 159 | def __repr__(self): 160 | return self.__class__.__name__ 161 | 162 | 163 | class GIN(torch.nn.Module): 164 | def __init__(self, dataset, num_layers, hidden, *args, **kwargs): 165 | super(GIN, self).__init__() 166 | self.conv1 = GINConv( 167 | Sequential( 168 | Linear(dataset.num_features, hidden), 169 | ReLU(), 170 | Linear(hidden, hidden), 171 | ReLU(), 172 | BN(hidden), 173 | ), 174 | train_eps=True) 175 | self.convs = torch.nn.ModuleList() 176 | for i in range(num_layers - 1): 177 | self.convs.append( 178 | GINConv( 179 | Sequential( 180 | Linear(hidden, hidden), 181 | ReLU(), 182 | Linear(hidden, hidden), 183 | ReLU(), 184 | BN(hidden), 185 | ), 186 | train_eps=True)) 187 | self.lin1 = torch.nn.Linear(num_layers * hidden, hidden) 188 | self.lin2 = Linear(hidden, dataset.num_classes) 189 | 190 | def reset_parameters(self): 191 | self.conv1.reset_parameters() 192 | for conv in self.convs: 193 | conv.reset_parameters() 194 | self.lin1.reset_parameters() 195 | self.lin2.reset_parameters() 196 | 197 | def forward(self, data): 198 | x, edge_index, batch = data.x, data.edge_index, data.batch 199 | x = self.conv1(x, edge_index) 200 | xs = [x] 201 | for conv in self.convs: 202 | x = conv(x, edge_index) 203 | xs += [x] 204 | x = global_mean_pool(torch.cat(xs, dim=1), batch) 205 | x = F.relu(self.lin1(x)) 206 | x = F.dropout(x, p=0.5, training=self.training) 207 | x = self.lin2(x) 208 | return F.log_softmax(x, dim=-1) 209 | 210 | def __repr__(self): 211 | return self.__class__.__name__ 212 | 
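All `Nested*` models above share the same two-level pooling interface: the base GNN message-passes over the disjoint union of all rooted subgraphs, `global_mean_pool(..., node_to_subgraph)` collapses each subgraph into its root node's representation, and a second pool over `subgraph_to_graph` yields the graph representation. A hypothetical toy invocation of `NestedGIN` might look as follows; the stub dataset object and the hand-built index vectors are made up purely for illustration, and running it assumes the repository root is on `PYTHONPATH`:

```python
# Hypothetical usage sketch -- the tensors and the dataset stub are made up.
import types
import torch
from torch_geometric.data import Data
from kernel.gin import NestedGIN

dataset = types.SimpleNamespace(num_features=3, num_classes=2)  # stub
model = NestedGIN(dataset, num_layers=2, hidden=16)
model.eval()  # freeze BatchNorm/dropout for a deterministic forward pass

# Two rooted subgraphs of 3 nodes each, both belonging to graph 0.
data = Data(
    x=torch.randn(6, 3),
    edge_index=torch.tensor([[0, 1, 3, 4],
                             [1, 2, 4, 5]]),
    batch=torch.zeros(6, dtype=torch.long),
    node_to_subgraph=torch.tensor([0, 0, 0, 1, 1, 1]),
    subgraph_to_graph=torch.tensor([0, 0]),
)
print(model(data).shape)  # torch.Size([1, 2]): log-probabilities for 1 graph
```

In normal use these index vectors are not built by hand: the custom `dataloader.DataLoader`/`batch.Batch` pair constructs `node_to_subgraph` and `subgraph_to_graph` automatically when collating graphs preprocessed by `create_subgraphs`.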
-------------------------------------------------------------------------------- /kernel/global_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.nn import Linear 4 | from torch_geometric.nn import SAGEConv, GlobalAttention 5 | 6 | 7 | class GlobalAttentionNet(torch.nn.Module): 8 | def __init__(self, dataset, num_layers, hidden): 9 | super(GlobalAttentionNet, self).__init__() 10 | self.conv1 = SAGEConv(dataset.num_features, hidden) 11 | self.convs = torch.nn.ModuleList() 12 | for i in range(num_layers - 1): 13 | self.convs.append(SAGEConv(hidden, hidden)) 14 | self.att = GlobalAttention(Linear(hidden, 1)) 15 | self.lin1 = Linear(hidden, hidden) 16 | self.lin2 = Linear(hidden, dataset.num_classes) 17 | 18 | def reset_parameters(self): 19 | self.conv1.reset_parameters() 20 | for conv in self.convs: 21 | conv.reset_parameters() 22 | self.att.reset_parameters() 23 | self.lin1.reset_parameters() 24 | self.lin2.reset_parameters() 25 | 26 | def forward(self, data): 27 | x, edge_index, batch = data.x, data.edge_index, data.batch 28 | x = F.relu(self.conv1(x, edge_index)) 29 | for conv in self.convs: 30 | x = F.relu(conv(x, edge_index)) 31 | x = self.att(x, batch) 32 | x = F.relu(self.lin1(x)) 33 | x = F.dropout(x, p=0.5, training=self.training) 34 | x = self.lin2(x) 35 | return F.log_softmax(x, dim=-1) 36 | 37 | def __repr__(self): 38 | return self.__class__.__name__ 39 | -------------------------------------------------------------------------------- /kernel/graclus.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.nn import Linear 4 | from torch_geometric.nn import GraphConv, global_mean_pool, graclus, max_pool 5 | from torch_geometric.data import Batch 6 | 7 | 8 | class Graclus(torch.nn.Module): 9 | def __init__(self, dataset, num_layers, hidden): 10 | super(Graclus, self).__init__() 11 | self.conv1 = GraphConv(dataset.num_features, hidden, aggr='mean') 12 | self.convs = torch.nn.ModuleList() 13 | for i in range(num_layers - 1): 14 | self.convs.append(GraphConv(hidden, hidden, aggr='mean')) 15 | self.lin1 = Linear(num_layers * hidden, hidden) 16 | self.lin2 = Linear(hidden, dataset.num_classes) 17 | 18 | def reset_parameters(self): 19 | self.conv1.reset_parameters() 20 | for conv in self.convs: 21 | conv.reset_parameters() 22 | self.lin1.reset_parameters() 23 | self.lin2.reset_parameters() 24 | 25 | def forward(self, data): 26 | x, edge_index, batch = data.x, data.edge_index, data.batch 27 | x = F.relu(self.conv1(x, edge_index)) 28 | xs = [global_mean_pool(x, batch)] 29 | for i, conv in enumerate(self.convs): 30 | x = F.relu(conv(x, edge_index)) 31 | xs += [global_mean_pool(x, batch)] 32 | if i % 2 == 0: 33 | cluster = graclus(edge_index, num_nodes=x.size(0)) 34 | data = Batch(x=x, edge_index=edge_index, batch=batch) 35 | data = max_pool(cluster, data) 36 | x, edge_index, batch = data.x, data.edge_index, data.batch 37 | x = torch.cat(xs, dim=1) 38 | x = F.relu(self.lin1(x)) 39 | x = F.dropout(x, p=0.5, training=self.training) 40 | x = self.lin2(x) 41 | return F.log_softmax(x, dim=-1) 42 | 43 | def __repr__(self): 44 | return self.__class__.__name__ 45 | -------------------------------------------------------------------------------- /kernel/graph_sage.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional 
as F 3 | from torch.nn import Linear 4 | from torch_geometric.nn import SAGEConv, global_mean_pool 5 | 6 | 7 | class NestedGraphSAGE(torch.nn.Module): 8 | def __init__(self, dataset, num_layers, hidden, use_z=False, use_rd=False): 9 | super(NestedGraphSAGE, self).__init__() 10 | self.use_rd = use_rd 11 | self.use_z = use_z 12 | if self.use_rd: 13 | self.rd_projection = torch.nn.Linear(1, 8) 14 | if self.use_z: 15 | self.z_embedding = torch.nn.Embedding(1000, 8) 16 | input_dim = dataset.num_features 17 | if self.use_z or self.use_rd: 18 | input_dim += 8 19 | 20 | self.conv1 = SAGEConv(input_dim, hidden) 21 | self.convs = torch.nn.ModuleList() 22 | for i in range(num_layers - 1): 23 | self.convs.append(SAGEConv(hidden, hidden)) 24 | self.lin1 = torch.nn.Linear(num_layers * hidden, hidden) 25 | self.lin2 = Linear(hidden, dataset.num_classes) 26 | 27 | def reset_parameters(self): 28 | if self.use_rd: 29 | self.rd_projection.reset_parameters() 30 | if self.use_z: 31 | self.z_embedding.reset_parameters() 32 | self.conv1.reset_parameters() 33 | for conv in self.convs: 34 | conv.reset_parameters() 35 | self.lin1.reset_parameters() 36 | self.lin2.reset_parameters() 37 | 38 | def forward(self, data): 39 | x, edge_index, batch = data.x, data.edge_index, data.batch 40 | 41 | # node label embedding 42 | z_emb = 0 43 | if self.use_z and 'z' in data: 44 | ### computing input node embedding 45 | z_emb = self.z_embedding(data.z) 46 | if z_emb.ndim == 3: 47 | z_emb = z_emb.sum(dim=1) 48 | 49 | if self.use_rd and 'rd' in data: 50 | rd_proj = self.rd_projection(data.rd) 51 | z_emb += rd_proj 52 | 53 | if self.use_rd or self.use_z: 54 | x = torch.cat([z_emb, x], -1) 55 | 56 | x = F.relu(self.conv1(x, edge_index)) 57 | xs = [x] 58 | for conv in self.convs: 59 | x = F.relu(conv(x, edge_index)) 60 | xs += [x] 61 | x = global_mean_pool(torch.cat(xs, dim=1), data.node_to_subgraph) 62 | x = global_mean_pool(x, data.subgraph_to_graph) 63 | x = F.relu(self.lin1(x)) 64 | x = F.dropout(x, p=0.5, training=self.training) 65 | x = self.lin2(x) 66 | return F.log_softmax(x, dim=-1) 67 | 68 | def __repr__(self): 69 | return self.__class__.__name__ 70 | 71 | 72 | class GraphSAGE(torch.nn.Module): 73 | def __init__(self, dataset, num_layers, hidden, *args, **kwargs): 74 | super(GraphSAGE, self).__init__() 75 | self.conv1 = SAGEConv(dataset.num_features, hidden) 76 | self.convs = torch.nn.ModuleList() 77 | for i in range(num_layers - 1): 78 | self.convs.append(SAGEConv(hidden, hidden)) 79 | self.lin1 = torch.nn.Linear(num_layers * hidden, hidden) 80 | self.lin2 = Linear(hidden, dataset.num_classes) 81 | 82 | def reset_parameters(self): 83 | self.conv1.reset_parameters() 84 | for conv in self.convs: 85 | conv.reset_parameters() 86 | self.lin1.reset_parameters() 87 | self.lin2.reset_parameters() 88 | 89 | def forward(self, data): 90 | x, edge_index, batch = data.x, data.edge_index, data.batch 91 | x = F.relu(self.conv1(x, edge_index)) 92 | xs = [x] 93 | for conv in self.convs: 94 | x = F.relu(conv(x, edge_index)) 95 | xs += [x] 96 | x = global_mean_pool(torch.cat(xs, dim=1), batch) 97 | x = F.relu(self.lin1(x)) 98 | x = F.dropout(x, p=0.5, training=self.training) 99 | x = self.lin2(x) 100 | return F.log_softmax(x, dim=-1) 101 | 102 | def __repr__(self): 103 | return self.__class__.__name__ 104 | 105 | 106 | class GraphSAGEWithoutJK(torch.nn.Module): 107 | def __init__(self, dataset, num_layers, hidden): 108 | super(GraphSAGEWithoutJK, self).__init__() 109 | self.conv1 = SAGEConv(dataset.num_features, hidden) 110 | self.convs = 
torch.nn.ModuleList() 111 | for i in range(num_layers - 1): 112 | self.convs.append(SAGEConv(hidden, hidden)) 113 | self.lin1 = Linear(hidden, hidden) 114 | self.lin2 = Linear(hidden, dataset.num_classes) 115 | 116 | def reset_parameters(self): 117 | self.conv1.reset_parameters() 118 | for conv in self.convs: 119 | conv.reset_parameters() 120 | self.lin1.reset_parameters() 121 | self.lin2.reset_parameters() 122 | 123 | def forward(self, data): 124 | x, edge_index, batch = data.x, data.edge_index, data.batch 125 | x = F.relu(self.conv1(x, edge_index)) 126 | for conv in self.convs: 127 | x = F.relu(conv(x, edge_index)) 128 | x = global_mean_pool(x, batch) 129 | x = F.relu(self.lin1(x)) 130 | x = F.dropout(x, p=0.5, training=self.training) 131 | x = self.lin2(x) 132 | return F.log_softmax(x, dim=-1) 133 | 134 | def __repr__(self): 135 | return self.__class__.__name__ 136 | -------------------------------------------------------------------------------- /kernel/set2set.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.nn import Linear 4 | from torch_geometric.nn import SAGEConv, Set2Set 5 | 6 | 7 | class Set2SetNet(torch.nn.Module): 8 | def __init__(self, dataset, num_layers, hidden): 9 | super(Set2SetNet, self).__init__() 10 | self.conv1 = SAGEConv(dataset.num_features, hidden) 11 | self.convs = torch.nn.ModuleList() 12 | for i in range(num_layers - 1): 13 | self.convs.append(SAGEConv(hidden, hidden)) 14 | self.set2set = Set2Set(hidden, processing_steps=4) 15 | self.lin1 = Linear(2 * hidden, hidden) 16 | self.lin2 = Linear(hidden, dataset.num_classes) 17 | 18 | def reset_parameters(self): 19 | self.conv1.reset_parameters() 20 | for conv in self.convs: 21 | conv.reset_parameters() 22 | self.set2set.reset_parameters() 23 | self.lin1.reset_parameters() 24 | self.lin2.reset_parameters() 25 | 26 | def forward(self, data): 27 | x, edge_index, batch = data.x, data.edge_index, data.batch 28 | x = F.relu(self.conv1(x, edge_index)) 29 | for conv in self.convs: 30 | x = F.relu(conv(x, edge_index)) 31 | x = self.set2set(x, batch) 32 | x = F.relu(self.lin1(x)) 33 | x = F.dropout(x, p=0.5, training=self.training) 34 | x = self.lin2(x) 35 | return F.log_softmax(x, dim=-1) 36 | 37 | def __repr__(self): 38 | return self.__class__.__name__ 39 | -------------------------------------------------------------------------------- /kernel/sort_pool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import Linear, Conv1d 5 | from torch_geometric.nn import SAGEConv, GCNConv, global_sort_pool 6 | import math, pdb 7 | 8 | original = False # whether to use the original model setting in this script 9 | 10 | class SortPool(torch.nn.Module): 11 | def __init__(self, dataset, num_layers, hidden, k=30): 12 | super(SortPool, self).__init__() 13 | self.conv1 = GCNConv(dataset.num_features, hidden) 14 | self.convs = torch.nn.ModuleList() 15 | for i in range(num_layers - 1): 16 | self.convs.append(GCNConv(hidden, hidden)) 17 | if k < 1: # transform percentile to number 18 | node_nums = sorted([g.num_nodes for g in dataset]) 19 | k = node_nums[int(math.ceil(k * len(node_nums)))-1] 20 | k = max(10, k) # no smaller than 10 21 | self.k = int(k) 22 | print('k used in sortpooling is:', self.k) 23 | if original: 24 | self.k = 10 25 | self.lin1 = Linear(self.k * hidden, hidden) 26 | self.lin2 = Linear(hidden, dataset.num_classes) 27 |
else: 28 | self.k = 30 29 | conv1d_output_channels = 32 30 | conv1d_kernel_size = 5 31 | self.conv1d = Conv1d(hidden, conv1d_output_channels, conv1d_kernel_size) 32 | self.lin1 = Linear(conv1d_output_channels * (self.k - conv1d_kernel_size + 1), hidden) 33 | self.lin2 = Linear(hidden, dataset.num_classes) 34 | 35 | ''' 36 | conv1d_channels = [16, 32] 37 | conv1d_activation = nn.ReLU() 38 | self.total_latent_dim = sum(latent_dim) 39 | conv1d_kws = [self.total_latent_dim, 5] 40 | self.conv1d_params1 = Conv1d(1, conv1d_channels[0], conv1d_kws[0], conv1d_kws[0]) 41 | self.maxpool1d = nn.MaxPool1d(2, 2) 42 | self.conv1d_params2 = Conv1d(conv1d_channels[0], conv1d_channels[1], conv1d_kws[1], 1) 43 | dense_dim = int((k - 2) / 2 + 1) 44 | self.dense_dim = (dense_dim - conv1d_kws[1] + 1) * conv1d_channels[1] 45 | self.lin1 = Linear(self.dense_dim, 128) 46 | ''' 47 | 48 | 49 | 50 | 51 | def reset_parameters(self): 52 | self.conv1.reset_parameters() 53 | for conv in self.convs: 54 | conv.reset_parameters() 55 | self.lin1.reset_parameters() 56 | self.lin2.reset_parameters() 57 | 58 | def forward(self, data): 59 | x, edge_index, batch = data.x, data.edge_index, data.batch 60 | x = F.relu(self.conv1(x, edge_index)) 61 | for conv in self.convs: 62 | x = F.relu(conv(x, edge_index)) 63 | x = global_sort_pool(x, batch, self.k) # batch * (k*hidden) 64 | if original: 65 | x = F.relu(self.lin1(x)) 66 | else: 67 | x = x.view(len(x), self.k, -1).permute(0, 2, 1) # batch * hidden * k 68 | x = F.relu(self.conv1d(x)) # batch * output_channels * (k-kernel_size+1) 69 | x = x.view(len(x), -1) 70 | x = F.relu(self.lin1(x)) # batch * hidden 71 | x = F.dropout(x, p=0.5, training=self.training) 72 | x = self.lin2(x) 73 | return F.log_softmax(x, dim=-1) 74 | 75 | def __repr__(self): 76 | return self.__class__.__name__ 77 | -------------------------------------------------------------------------------- /kernel/statistics.py: -------------------------------------------------------------------------------- 1 | from kernel.datasets import get_dataset 2 | 3 | 4 | def print_dataset(dataset): 5 | num_nodes = num_edges = 0 6 | for data in dataset: 7 | num_nodes += data.num_nodes 8 | num_edges += data.num_edges 9 | 10 | print('Name', dataset) 11 | print('Graphs', len(dataset)) 12 | print('Nodes', num_nodes / len(dataset)) 13 | print('Edges', (num_edges // 2) / len(dataset)) 14 | print('Features', dataset.num_features) 15 | print('Classes', dataset.num_classes) 16 | print() 17 | 18 | 19 | for name in ['MUTAG', 'PROTEINS', 'COLLAB', 'IMDB-BINARY', 'REDDIT-BINARY']: 20 | print_dataset(get_dataset(name)) 21 | -------------------------------------------------------------------------------- /kernel/top_k.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.nn import Linear 4 | from torch_geometric.nn import GraphConv, global_mean_pool, TopKPooling 5 | 6 | 7 | class TopK(torch.nn.Module): 8 | def __init__(self, dataset, num_layers, hidden): 9 | super(TopK, self).__init__() 10 | self.conv1 = GraphConv(dataset.num_features, hidden, aggr='mean') 11 | self.convs = torch.nn.ModuleList() 12 | self.pools = torch.nn.ModuleList() 13 | for i in range(num_layers - 1): 14 | self.convs.append(GraphConv(hidden, hidden, aggr='mean')) 15 | self.pools.append(TopKPooling(hidden, ratio=0.8)) 16 | self.lin1 = Linear(num_layers * hidden, hidden) 17 | self.lin2 = Linear(hidden, dataset.num_classes) 18 | 19 | def reset_parameters(self): 20 | 
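Backing up to `SortPool.forward` above: the `Linear` input size `conv1d_output_channels * (self.k - conv1d_kernel_size + 1)` follows from the 1-D convolution over the k sorted node embeddings. A small shape walk-through, with arbitrary example sizes:

```
import torch
from torch.nn import Conv1d

batch, k, hidden = 8, 30, 32        # example values only
x = torch.randn(batch, k * hidden)  # shape of global_sort_pool's output
x = x.view(batch, k, hidden).permute(0, 2, 1)  # batch x hidden x k
y = Conv1d(hidden, 32, 5)(x)        # batch x 32 x (k - 5 + 1)
assert y.shape == (batch, 32, k - 5 + 1)
assert y.view(batch, -1).size(1) == 32 * (k - 5 + 1)  # matches lin1's input
```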
self.conv1.reset_parameters() 21 | for conv, pool in zip(self.convs, self.pools): 22 | conv.reset_parameters() 23 | pool.reset_parameters() 24 | self.lin1.reset_parameters() 25 | self.lin2.reset_parameters() 26 | 27 | def forward(self, data): 28 | x, edge_index, batch = data.x, data.edge_index, data.batch 29 | x = F.relu(self.conv1(x, edge_index)) 30 | xs = [global_mean_pool(x, batch)] 31 | for i, (conv, pool) in enumerate(zip(self.convs, self.pools)): 32 | x = F.relu(conv(x, edge_index)) 33 | xs += [global_mean_pool(x, batch)] 34 | if i % 2 == 0: 35 | x, edge_index, _, batch, _ = pool(x, edge_index, batch=batch) 36 | x = torch.cat(xs, dim=1) 37 | x = F.relu(self.lin1(x)) 38 | x = F.dropout(x, p=0.5, training=self.training) 39 | x = self.lin2(x) 40 | return F.log_softmax(x, dim=-1) 41 | 42 | def __repr__(self): 43 | return self.__class__.__name__ 44 | -------------------------------------------------------------------------------- /kernel/tu_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import shutil 4 | from tqdm import tqdm 5 | 6 | import torch 7 | from torch_geometric.data import InMemoryDataset, download_url, extract_zip 8 | from torch_geometric.io import read_tu_data 9 | 10 | 11 | class TUDataset(InMemoryDataset): 12 | r"""A variety of graph kernel benchmark datasets, *e.g.*, "IMDB-BINARY", 13 | "REDDIT-BINARY" or "PROTEINS", collected from the `TU Dortmund University 14 | `_. 15 | In addition, this dataset wrapper provides `cleaned dataset versions 16 | `_ as motivated by the 17 | `"Understanding Isomorphism Bias in Graph Data Sets" 18 | `_ paper, containing only non-isomorphic 19 | graphs. 20 | 21 | .. note:: 22 | Some datasets may not come with any node labels. 23 | You can then either make use of the argument :obj:`use_node_attr` 24 | to load additional continuous node attributes (if present) or provide 25 | synthetic node features using transforms such as 26 | :class:`torch_geometric.transforms.Constant` or 27 | :class:`torch_geometric.transforms.OneHotDegree`. 28 | 29 | Args: 30 | root (string): Root directory where the dataset should be saved. 31 | name (string): The `name 32 | `_ of the 33 | dataset. 34 | transform (callable, optional): A function/transform that takes in an 35 | :obj:`torch_geometric.data.Data` object and returns a transformed 36 | version. The data object will be transformed before every access. 37 | (default: :obj:`None`) 38 | pre_transform (callable, optional): A function/transform that takes in 39 | an :obj:`torch_geometric.data.Data` object and returns a 40 | transformed version. The data object will be transformed before 41 | being saved to disk. (default: :obj:`None`) 42 | pre_filter (callable, optional): A function that takes in an 43 | :obj:`torch_geometric.data.Data` object and returns a boolean 44 | value, indicating whether the data object should be included in the 45 | final dataset. (default: :obj:`None`) 46 | use_node_attr (bool, optional): If :obj:`True`, the dataset will 47 | contain additional continuous node attributes (if present). 48 | (default: :obj:`False`) 49 | use_edge_attr (bool, optional): If :obj:`True`, the dataset will 50 | contain additional continuous edge attributes (if present). 51 | (default: :obj:`False`) 52 | cleaned (bool, optional): If :obj:`True`, the dataset will 53 | contain only non-isomorphic graphs.
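The `TopK.forward` above collects one graph-level mean-pool per layer (a jumping-knowledge-style readout), which is why `lin1` takes `num_layers * hidden` inputs. A minimal shape sketch with arbitrary example sizes:

```
import torch

num_layers, hidden, num_graphs = 4, 32, 16  # example values only
xs = [torch.randn(num_graphs, hidden) for _ in range(num_layers)]
x = torch.cat(xs, dim=1)
assert x.shape == (num_graphs, num_layers * hidden)  # lin1's input size
```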
(default: :obj:`False`) 54 | """ 55 | 56 | url = 'https://www.chrsmrrs.com/graphkerneldatasets' 57 | cleaned_url = ('https://raw.githubusercontent.com/nd7141/' 58 | 'graph_datasets/master/datasets') 59 | 60 | def __init__(self, root, name, transform=None, pre_transform=None, 61 | pre_filter=None, use_node_attr=False, use_edge_attr=False, 62 | cleaned=False): 63 | self.name = name 64 | self.cleaned = cleaned 65 | super(TUDataset, self).__init__(root, transform, pre_transform, 66 | pre_filter) 67 | self.data, self.slices = torch.load(self.processed_paths[0]) 68 | if self.data.x is not None and not use_node_attr: 69 | num_node_attributes = self.num_node_attributes 70 | self.data.x = self.data.x[:, num_node_attributes:] 71 | if self.data.edge_attr is not None and not use_edge_attr: 72 | num_edge_attributes = self.num_edge_attributes 73 | self.data.edge_attr = self.data.edge_attr[:, num_edge_attributes:] 74 | 75 | @property 76 | def raw_dir(self): 77 | name = 'raw{}'.format('_cleaned' if self.cleaned else '') 78 | return osp.join(self.root, self.name, name) 79 | 80 | @property 81 | def processed_dir(self): 82 | name = 'processed{}'.format('_cleaned' if self.cleaned else '') 83 | return osp.join(self.root, self.name, name) 84 | 85 | @property 86 | def num_node_labels(self): 87 | if self.data.x is None: 88 | return 0 89 | for i in range(self.data.x.size(1)): 90 | x = self.data.x[:, i:] 91 | if ((x == 0) | (x == 1)).all() and (x.sum(dim=1) == 1).all(): 92 | return self.data.x.size(1) - i 93 | return 0 94 | 95 | @property 96 | def num_node_attributes(self): 97 | if self.data.x is None: 98 | return 0 99 | return self.data.x.size(1) - self.num_node_labels 100 | 101 | @property 102 | def num_edge_labels(self): 103 | if self.data.edge_attr is None: 104 | return 0 105 | for i in range(self.data.edge_attr.size(1)): 106 | if self.data.edge_attr[:, i:].sum() == self.data.edge_attr.size(0): 107 | return self.data.edge_attr.size(1) - i 108 | return 0 109 | 110 | @property 111 | def num_edge_attributes(self): 112 | if self.data.edge_attr is None: 113 | return 0 114 | return self.data.edge_attr.size(1) - self.num_edge_labels 115 | 116 | @property 117 | def raw_file_names(self): 118 | names = ['A', 'graph_indicator'] 119 | return ['{}_{}.txt'.format(self.name, name) for name in names] 120 | 121 | @property 122 | def processed_file_names(self): 123 | return 'data.pt' 124 | 125 | def download(self): 126 | url = self.cleaned_url if self.cleaned else self.url 127 | folder = osp.join(self.root, self.name) 128 | path = download_url('{}/{}.zip'.format(url, self.name), folder) 129 | extract_zip(path, folder) 130 | os.unlink(path) 131 | shutil.rmtree(self.raw_dir) 132 | os.rename(osp.join(folder, self.name), self.raw_dir) 133 | 134 | def process(self): 135 | self.data, self.slices = read_tu_data(self.raw_dir, self.name) 136 | 137 | if self.pre_filter is not None: 138 | data_list = [self.get(idx) for idx in range(len(self))] 139 | data_list = [data for data in data_list if self.pre_filter(data)] 140 | self.data, self.slices = self.collate(data_list) 141 | 142 | if self.pre_transform is not None: 143 | data_list = [self.get(idx) for idx in range(len(self))] 144 | #data_list = [self.pre_transform(data) for data in data_list] 145 | new_data_list = [] 146 | for data in tqdm(data_list): 147 | new_data_list.append(self.pre_transform(data)) 148 | data_list = new_data_list 149 | self.data, self.slices = self.collate(data_list) 150 | 151 | torch.save((self.data, self.slices), self.processed_paths[0]) 152 | 153 | def 
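`num_node_labels` above infers how many trailing columns of `x` form a one-hot label block by scanning column suffixes. A toy check of that rule (the feature matrix is made up: two continuous attributes followed by a 3-class one-hot block):

```
import torch

x = torch.tensor([[0.5, 1.2, 1., 0., 0.],
                  [0.1, 3.4, 0., 1., 0.],
                  [2.2, 0.7, 0., 0., 1.]])

def num_node_labels(x):
    for i in range(x.size(1)):
        suffix = x[:, i:]
        # A suffix counts as labels if it is 0/1-valued with exactly one 1 per row.
        if ((suffix == 0) | (suffix == 1)).all() and (suffix.sum(dim=1) == 1).all():
            return x.size(1) - i
    return 0

print(num_node_labels(x))  # 3, so num_node_attributes = 5 - 3 = 2
```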
__repr__(self): 154 | return '{}({})'.format(self.name, len(self)) 155 | -------------------------------------------------------------------------------- /modules/ppgn_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pdb 3 | 4 | 5 | def diag_offdiag_maxpool(input): 6 | N = input.shape[-1] 7 | 8 | max_diag = torch.max(torch.diagonal(input, dim1=-2, dim2=-1), dim=2)[0] # BxS 9 | 10 | # with torch.no_grad(): 11 | max_val = torch.max(max_diag) 12 | min_val = torch.max(-1 * input) 13 | val = torch.abs(torch.add(max_val, min_val)) 14 | 15 | min_mat = torch.mul(val, torch.eye(N, device=input.device)).view(1, 1, N, N) 16 | 17 | max_offdiag = torch.max(torch.max(input - min_mat, dim=3)[0], dim=2)[0] # BxS 18 | 19 | return torch.cat((max_diag, max_offdiag), dim=1) # output Bx2S 20 | 21 | 22 | def diag_offdiag_meanpool(input): 23 | N = input.shape[-1] 24 | 25 | mean_diag = torch.mean(torch.diagonal(input, dim1=-2, dim2=-1), dim=2) # BxS 26 | mean_offdiag = (torch.sum(input, dim=[-1, -2]) - mean_diag * N) / (N * N - N) 27 | 28 | return torch.cat((mean_diag, mean_offdiag), dim=1) # output Bx2S 29 | -------------------------------------------------------------------------------- /modules/ppgn_modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class RegularBlock(nn.Module): 6 | """ 7 | Inputs: N x input_depth x m x m 8 | Take the input through two parallel MLP routes, multiply the result, and add a skip-connection at the end. 9 | At the skip-connection, reduce the dimension back to output_depth. 10 | """ 11 | def __init__(self, depth_of_mlp, in_features, out_features): 12 | super().__init__() 13 | 14 | self.mlp1 = MlpBlock(in_features, out_features, depth_of_mlp) 15 | self.mlp2 = MlpBlock(in_features, out_features, depth_of_mlp) 16 | 17 | self.skip = SkipConnection(in_features+out_features, out_features) 18 | 19 | def forward(self, inputs): 20 | mlp1 = self.mlp1(inputs) 21 | mlp2 = self.mlp2(inputs) 22 | 23 | mult = torch.matmul(mlp1, mlp2) 24 | 25 | out = self.skip(in1=inputs, in2=mult) 26 | return out 27 | 28 | 29 | class MlpBlock(nn.Module): 30 | """ 31 | Block of MLP layers with activation function after each (1x1 conv layers).
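A quick numerical check of `diag_offdiag_meanpool` from `ppgn_layers.py` above: the off-diagonal mean is recovered by subtracting the diagonal contribution from the total sum and dividing by the N*N - N off-diagonal entries (toy tensor with B = S = 1):

```
import torch

input = torch.tensor([[[[1., 2., 3.],
                        [4., 5., 6.],
                        [7., 8., 9.]]]])  # B x S x N x N with B = S = 1
N = input.shape[-1]
mean_diag = torch.mean(torch.diagonal(input, dim1=-2, dim2=-1), dim=2)  # (1+5+9)/3
mean_offdiag = (torch.sum(input, dim=[-1, -2]) - mean_diag * N) / (N * N - N)
print(mean_diag.item(), mean_offdiag.item())  # 5.0 5.0 (offdiag: (45-15)/6)
```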
32 | """ 33 | def __init__(self, in_features, out_features, depth_of_mlp, activation_fn=nn.functional.relu): 34 | super().__init__() 35 | self.activation = activation_fn 36 | self.convs = nn.ModuleList() 37 | for i in range(depth_of_mlp): 38 | self.convs.append(nn.Conv2d(in_features, out_features, kernel_size=1, padding=0, bias=True)) 39 | _init_weights(self.convs[-1]) 40 | in_features = out_features 41 | 42 | def forward(self, inputs): 43 | out = inputs 44 | for conv_layer in self.convs: 45 | out = self.activation(conv_layer(out)) 46 | 47 | return out 48 | 49 | 50 | class SkipConnection(nn.Module): 51 | """ 52 | Connects the two given inputs with concatenation 53 | :param in1: earlier input tensor of shape N x d1 x m x m 54 | :param in2: later input tensor of shape N x d2 x m x m 55 | :param in_features: d1+d2 56 | :param out_features: output num of features 57 | :return: Tensor of shape N x output_depth x m x m 58 | """ 59 | def __init__(self, in_features, out_features): 60 | super().__init__() 61 | self.conv = nn.Conv2d(in_features, out_features, kernel_size=1, padding=0, bias=True) 62 | _init_weights(self.conv) 63 | 64 | def forward(self, in1, in2): 65 | # in1: N x d1 x m x m 66 | # in2: N x d2 x m x m 67 | out = torch.cat((in1, in2), dim=1) 68 | out = self.conv(out) 69 | return out 70 | 71 | 72 | class FullyConnected(nn.Module): 73 | def __init__(self, in_features, out_features, activation_fn=nn.functional.relu): 74 | super().__init__() 75 | 76 | self.fc = nn.Linear(in_features, out_features) 77 | _init_weights(self.fc) 78 | 79 | self.activation = activation_fn 80 | 81 | def forward(self, input): 82 | out = self.fc(input) 83 | if self.activation is not None: 84 | out = self.activation(out) 85 | 86 | return out 87 | 88 | 89 | def _init_weights(layer): 90 | """ 91 | Init weights of the layer 92 | :param layer: 93 | :return: 94 | """ 95 | nn.init.xavier_uniform_(layer.weight) 96 | # nn.init.xavier_normal_(layer.weight) 97 | if layer.bias is not None: 98 | nn.init.zeros_(layer.bias) 99 | -------------------------------------------------------------------------------- /run_all_targets_qm9.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | # Run all targets on QM9. Usage: 4 | # ./run_all_targets_qm9.sh 2 10 5 | # to run qm9 from targets 2 to 10 6 | # When you finish all the runs, type 7 | # for i in `seq 0 11`; do tail -1 QM9_${i}_k1_h3_spd_rd/log.txt; done 8 | # for example to summarize the results 9 | 10 | 11 | T0=${1} 12 | T1=${2} 13 | for target in $(seq ${T0} ${T1}) 14 | do 15 | # The following 4 commands reproduce the NGNN results in Table 4. 16 | python run_qm9.py --h 3 --model Nested_k1_GNN --save_appendix _k1_h3_spd_rd --use_rd --target ${target} 17 | #python run_qm9.py --h 3 --model Nested_k12_GNN --save_appendix _k12_h3_spd_rd --use_rd --target ${target} 18 | #python run_qm9.py --h 3 --model Nested_k13_GNN --save_appendix _k13_h3_spd_rd --use_rd --target ${target} 19 | #python run_qm9.py --h 3 --model Nested_k123_GNN --save_appendix _k123_h3_spd_rd --use_rd --target ${target} 20 | 21 | # The following 4 commands reproduce the NGNN (no DE features) in Table 5 of Appendix E. 
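The shell loop in this script simply invokes `run_qm9.py` once per target index; a rough Python equivalent, mirroring only the first (uncommented) command above, might look like this:

```
import subprocess
import sys

t0, t1 = int(sys.argv[1]), int(sys.argv[2])
for target in range(t0, t1 + 1):
    subprocess.run(['python', 'run_qm9.py', '--h', '3',
                    '--model', 'Nested_k1_GNN',
                    '--save_appendix', '_k1_h3_spd_rd',
                    '--use_rd', '--target', str(target)],
                   check=True)
```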
22 | #python run_qm9.py --h 3 --model Nested_k1_GNN --save_appendix _k1_h3_no --node_label no --target ${target} 23 | #python run_qm9.py --h 3 --model Nested_k12_GNN --save_appendix _k12_h3_no --node_label no --target ${target} 24 | #python run_qm9.py --h 3 --model Nested_k13_GNN --save_appendix _k13_h3_no --node_label no --target ${target} 25 | #python run_qm9.py --h 3 --model Nested_k123_GNN --save_appendix _k123_h3_no --node_label no --target ${target} 26 | done 27 | -------------------------------------------------------------------------------- /run_simulation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time, os, sys 3 | from shutil import copy 4 | import matplotlib.pyplot as plt 5 | import logging 6 | from math import ceil 7 | import numpy as np 8 | import networkx as nx 9 | import torch 10 | import torch.nn.functional as F 11 | from torch.nn import Linear, Sequential, ReLU, BatchNorm1d as BN 12 | from torch_geometric.nn import GINConv, global_mean_pool, global_add_pool 13 | from torch_geometric.data import Data, DataLoader 14 | from torch_geometric.utils import from_networkx 15 | logging.getLogger('matplotlib.font_manager').disabled = True 16 | logging.getLogger('matplotlib.ticker').disabled = True 17 | from utils import create_subgraphs 18 | import pdb 19 | 20 | 21 | def simulate(args, device): 22 | results = {} 23 | for n in args.n: 24 | print('n = {}'.format(n)) 25 | graphs = generate_many_k_regular_graphs(args.k, n, args.N) 26 | for h in range(1, args.h+1): 27 | G = [pre_transform(g, h) for g in graphs] 28 | loader = DataLoader(G, batch_size=1) 29 | model = NestedGIN(args.layers, 32) 30 | model.to(device) 31 | output = run_simulation(model, loader, device) # output shape [G.number_of_nodes(), feat_dim] 32 | collision_rate = compute_simulation_collisions(output, ratio=True) 33 | results[(n, h)] = collision_rate 34 | torch.cuda.empty_cache() 35 | print('h = {}: {}'.format(h, collision_rate)) 36 | print('#'*30) 37 | return results 38 | 39 | 40 | def generate_many_k_regular_graphs(k, n, N, seed=0): 41 | graphs = [generate_k_regular(k, n, s) for s in range(seed, seed+N)] 42 | graphs = [from_networkx(g) for g in graphs] 43 | return graphs 44 | 45 | 46 | def generate_k_regular(k, n, seed=0): 47 | G = nx.random_regular_graph(d=k, n=n, seed=seed) 48 | return G 49 | 50 | 51 | def run_simulation(model, loader, device): 52 | model.eval() 53 | with torch.no_grad(): 54 | output = [] 55 | for data in loader: 56 | data = data.to(device) 57 | output.append(model(data)) 58 | output = torch.cat(output, 0) 59 | return output 60 | 61 | 62 | def save_simulation_result(results, res_dir, pic_format='pdf'): 63 | n_l, h_l, r_l = [], [], [] 64 | for (n, h), r in results.items(): 65 | n_l.append(n) 66 | h_l.append(h) 67 | r_l.append(r) 68 | main = plt.scatter(n_l, h_l, c=r_l, cmap="Greys", edgecolors='k', linewidths=0.2) 69 | plt.colorbar(main, ticks=[0.0, 0.2, 0.4, 0.6, 0.8, 1.0]) 70 | n_min, n_max = min(n_l), max(n_l) 71 | lbound = plt.plot([n_min, n_max], 72 | [np.log(n_min)/np.log(2)/2, np.log(n_max)/np.log(2)/2], 73 | 'r--', label='0.5 log(n) / log(r-1)') 74 | ubound = plt.plot([n_min, n_max], 75 | [np.log(n_min)/np.log(2), np.log(n_max)/np.log(2)], 76 | 'b--', label='log(n) / log(r-1)') 77 | plt.xscale('log') 78 | plt.xlabel('number of nodes (n)') 79 | plt.ylabel('height of rooted subgraphs (h)') 80 | plt.legend(loc = 'upper left') 81 | plt.savefig('{}/simulation_results.{}'.format(res_dir, pic_format), dpi=300) 82 | 83 | 84 | 
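The two dashed reference lines drawn in `save_simulation_result` above correspond to the bounds 0.5*log(n)/log(r-1) and log(n)/log(r-1) on the subgraph height h; for the default degree k = 3 (so r - 1 = 2) they are easy to tabulate (a small sketch, all values exact):

```
import numpy as np

r = 3  # node degree of the simulated regular graphs
for n in [16, 64, 256, 1024]:
    lower = 0.5 * np.log(n) / np.log(r - 1)
    upper = np.log(n) / np.log(r - 1)
    print('n={:5d}  lower={:4.1f}  upper={:4.1f}'.format(n, lower, upper))
# n=16 gives 2.0 .. 4.0; n=1024 gives 5.0 .. 10.0
```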
def compute_simulation_collisions(outputs, ratio=True): 85 | epsilon = 1e-10 86 | N = outputs.size(0) 87 | with torch.no_grad(): 88 | a = outputs.unsqueeze(-1) 89 | b = outputs.t().unsqueeze(0) 90 | diff = a-b 91 | diff = (diff**2).sum(dim=1) 92 | n_collision = int(((diff < epsilon).sum().item()-N)/2) 93 | r = n_collision / (N*(N-1)/2) 94 | if ratio: 95 | return r 96 | else: 97 | return n_collision 98 | 99 | 100 | class NestedGIN(torch.nn.Module): 101 | def __init__(self, num_layers, hidden): 102 | super(NestedGIN, self).__init__() 103 | self.conv1 = GINConv( 104 | Sequential( 105 | Linear(1, hidden), 106 | ReLU(), 107 | Linear(hidden, hidden), 108 | ReLU(), 109 | ), 110 | train_eps=False) 111 | self.convs = torch.nn.ModuleList() 112 | for i in range(num_layers - 1): 113 | self.convs.append( 114 | GINConv( 115 | Sequential( 116 | Linear(hidden, hidden), 117 | ReLU(), 118 | Linear(hidden, hidden), 119 | ReLU(), 120 | ), 121 | train_eps=False)) 122 | self.lin1 = torch.nn.Linear(hidden, hidden) 123 | self.lin2 = Linear(hidden, hidden) 124 | 125 | def reset_parameters(self): 126 | self.conv1.reset_parameters() 127 | for conv in self.convs: 128 | conv.reset_parameters() 129 | self.lin1.reset_parameters() 130 | self.lin2.reset_parameters() 131 | 132 | def forward(self, data): 133 | edge_index, batch = data.edge_index, data.batch 134 | if 'x' in data: 135 | x = data.x 136 | else: 137 | x = torch.ones([data.num_nodes, 1]).to(edge_index.device) 138 | x = self.conv1(x, edge_index) 139 | for conv in self.convs: 140 | x = conv(x, edge_index) 141 | 142 | x = global_add_pool(x, data.node_to_subgraph) 143 | if args.graph: 144 | x = global_add_pool(x, data.subgraph_to_graph) 145 | 146 | return x 147 | 148 | def __repr__(self): 149 | return self.__class__.__name__ 150 | 151 | 152 | parser = argparse.ArgumentParser(description='Nested GNN Simulation Experiment') 153 | parser.add_argument('--k', type=int, default=3, 154 | help='node degree (k) of the synthetic k-regular graphs') 155 | parser.add_argument('--n', nargs='*', 156 | help='a list of graph sizes (number of nodes per connected k-regular graph)') 157 | parser.add_argument('--N', type=int, default=100, 158 | help='number of graphs in simulation') 159 | parser.add_argument('--h', type=int, default=6, 160 | help='largest height of rooted subgraphs to simulate') 161 | parser.add_argument('--graph', action='store_true', default=False, 162 | help='if True, compute whole-graph collision rate; otherwise node-level') 163 | parser.add_argument('--layers', type=int, default=1, help='# message passing layers') 164 | parser.add_argument('--save_appendix', default='', 165 | help='what to append to save-names when saving results') 166 | args = parser.parse_args() 167 | args.n = [int(n) for n in args.n] 168 | 169 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 170 | 171 | if args.save_appendix == '': 172 | args.save_appendix = '_' + time.strftime("%Y%m%d%H%M%S") 173 | args.res_dir = 'results/simulation{}'.format(args.save_appendix) 174 | print('Results will be saved in ' + args.res_dir) 175 | if not os.path.exists(args.res_dir): 176 | os.makedirs(args.res_dir) 177 | # Backup python files. 178 | copy('run_simulation.py', args.res_dir) 179 | copy('utils.py', args.res_dir) 180 | # Save command line input.
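The broadcasting in `compute_simulation_collisions` above forms all pairwise squared distances in one shot; a tiny verification with one planted duplicate row:

```
import torch

outputs = torch.randn(5, 8)
outputs[3] = outputs[0]           # plant one exact collision
epsilon, N = 1e-10, outputs.size(0)

a = outputs.unsqueeze(-1)         # N x d x 1
b = outputs.t().unsqueeze(0)      # 1 x d x N
diff = ((a - b) ** 2).sum(dim=1)  # N x N matrix of squared distances
n_collision = int(((diff < epsilon).sum().item() - N) / 2)  # drop diagonal, halve
print(n_collision)                # 1
```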
181 | cmd_input = 'python ' + ' '.join(sys.argv) + '\n' 182 | with open(os.path.join(args.res_dir, 'cmd_input.txt'), 'a') as f: 183 | f.write(cmd_input) 184 | print('Command line input: ' + cmd_input + ' is saved.') 185 | 186 | path = 'data/simulation' 187 | pre_transform = None 188 | if args.h is not None: 189 | if type(args.h) == int: 190 | path += '/ngnn_h' + str(args.h) 191 | def pre_transform(g, h): 192 | return create_subgraphs(g, h, node_label='no', use_rd=False, 193 | subgraph_pretransform=None) 194 | 195 | # Plot visualization figure 196 | results = simulate(args, device) 197 | save_simulation_result(results, args.res_dir) 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | -------------------------------------------------------------------------------- /run_tu.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import os, sys 3 | import time 4 | from shutil import copy, rmtree 5 | from itertools import product 6 | import pdb 7 | import argparse 8 | import random 9 | import torch 10 | import numpy as np 11 | from kernel.datasets import get_dataset 12 | from kernel.train_eval import cross_validation_with_val_set 13 | from kernel.train_eval import cross_validation_without_val_set 14 | from kernel.gcn import * 15 | from kernel.graph_sage import * 16 | from kernel.gin import * 17 | from kernel.gat import * 18 | from kernel.graclus import Graclus 19 | from kernel.top_k import TopK 20 | from kernel.diff_pool import * 21 | from kernel.global_attention import GlobalAttentionNet 22 | from kernel.set2set import Set2SetNet 23 | from kernel.sort_pool import SortPool 24 | 25 | 26 | # used to traceback which code cause warnings, can delete 27 | import traceback 28 | import warnings 29 | import sys 30 | def warn_with_traceback(message, category, filename, lineno, file=None, line=None): 31 | 32 | log = file if hasattr(file,'write') else sys.stderr 33 | traceback.print_stack(file=log) 34 | log.write(warnings.formatwarning(message, category, filename, lineno, line)) 35 | 36 | warnings.showwarning = warn_with_traceback 37 | 38 | 39 | # General settings. 40 | parser = argparse.ArgumentParser(description='Nested GNN for TU graphs') 41 | parser.add_argument('--data', type=str, default='MUTAG') 42 | parser.add_argument('--clean', action='store_true', default=False, 43 | help='use a cleaned version of dataset by removing isomorphism') 44 | parser.add_argument('--no_val', action='store_true', default=False, 45 | help='if True, do not use validation set, but directly report best\ 46 | test performance.') 47 | 48 | # GNN settings. 49 | parser.add_argument('--model', type=str, default='NestedGCN', 50 | help='GCN, GraphSAGE, GIN, GAT, NestedGCN, NestedGraphSAGE, \ 51 | NestedGIN, and NestedGAT') 52 | parser.add_argument('--layers', type=int, default=4) 53 | parser.add_argument('--hiddens', type=int, default=32) 54 | parser.add_argument('--h', type=int, default=None, help='the height of rooted subgraph \ 55 | for NGNN models') 56 | parser.add_argument('--node_label', type=str, default='spd', 57 | help='apply distance encoding to nodes within each subgraph, use node\ 58 | labels as additional node features; support "hop", "drnl", "spd", \ 59 | "spd5", etc. 
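The `warn_with_traceback` hook installed near the top of `run_tu.py` above makes every warning print the stack that raised it, which is handy for locating deprecation warnings deep inside a model. A minimal standalone demo (the warning text is made up):

```
import sys
import traceback
import warnings

def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
    log = file if hasattr(file, 'write') else sys.stderr
    traceback.print_stack(file=log)
    log.write(warnings.formatwarning(message, category, filename, lineno, line))

warnings.showwarning = warn_with_traceback
warnings.warn('where did this come from?')  # now prints the full call stack
```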
Default "spd"=="spd2".') 60 | parser.add_argument('--use_rd', action='store_true', default=False, 61 | help='use resistance distance as additional node labels') 62 | parser.add_argument('--use_rp', type=int, default=None, 63 | help='use RW return probability as additional node features,\ 64 | specify num of RW steps here') 65 | parser.add_argument('--max_nodes_per_hop', type=int, default=None) 66 | 67 | # Training settings. 68 | parser.add_argument('--epochs', type=int, default=100) 69 | parser.add_argument('--batch_size', type=int, default=128) 70 | parser.add_argument('--lr', type=float, default=1E-2) 71 | parser.add_argument('--lr_decay_factor', type=float, default=0.5) 72 | parser.add_argument('--lr_decay_step_size', type=int, default=50) 73 | 74 | # Other settings. 75 | parser.add_argument('--seed', type=int, default=1) 76 | parser.add_argument('--search', action='store_true', default=False, 77 | help='search hyperparameters (layers, hiddens)') 78 | parser.add_argument('--save_appendix', default='', 79 | help='what to append to save-names when saving results') 80 | parser.add_argument('--keep_old', action='store_true', default=False, 81 | help='if True, do not overwrite old .py files in the result folder') 82 | parser.add_argument('--reprocess', action='store_true', default=False, 83 | help='if True, reprocess data') 84 | parser.add_argument('--cpu', action='store_true', default=False, help='use cpu') 85 | args = parser.parse_args() 86 | 87 | torch.manual_seed(args.seed) 88 | if torch.cuda.is_available(): 89 | torch.cuda.manual_seed(args.seed) 90 | random.seed(args.seed) 91 | np.random.seed(args.seed) 92 | 93 | file_dir = os.path.dirname(os.path.realpath('__file__')) 94 | if args.save_appendix == '': 95 | args.save_appendix = '_' + time.strftime("%Y%m%d%H%M%S") 96 | args.res_dir = os.path.join(file_dir, 'results/TU{}'.format(args.save_appendix)) 97 | print('Results will be saved in ' + args.res_dir) 98 | if not os.path.exists(args.res_dir): 99 | os.makedirs(args.res_dir) 100 | if not args.keep_old: 101 | # backup current main.py, model.py files 102 | copy('run_tu.py', args.res_dir) 103 | copy('utils.py', args.res_dir) 104 | copy('kernel/train_eval.py', args.res_dir) 105 | copy('kernel/datasets.py', args.res_dir) 106 | copy('kernel/gcn.py', args.res_dir) 107 | copy('kernel/gat.py', args.res_dir) 108 | copy('kernel/graph_sage.py', args.res_dir) 109 | copy('kernel/gin.py', args.res_dir) 110 | # save command line input 111 | cmd_input = 'python ' + ' '.join(sys.argv) + '\n' 112 | with open(os.path.join(args.res_dir, 'cmd_input.txt'), 'a') as f: 113 | f.write(cmd_input) 114 | print('Command line input: ' + cmd_input + ' is saved.') 115 | 116 | if args.data == 'all': 117 | datasets = [ 'DD', 'MUTAG', 'PROTEINS', 'PTC_MR', 'ENZYMES'] 118 | else: 119 | datasets = [args.data] 120 | 121 | if args.search: 122 | if args.h is None: 123 | layers = [2, 3, 4, 5] 124 | hiddens = [32] 125 | hs = [None] 126 | else: 127 | layers = [3, 4, 5, 6] 128 | hiddens = [32, 32, 32, 32] 129 | hs = [2, 3, 4, 5] 130 | else: 131 | layers = [args.layers] 132 | hiddens = [args.hiddens] 133 | hs = [args.h] 134 | 135 | if args.model == 'all': 136 | #nets = [GCN, GraphSAGE, GIN, GAT] 137 | nets = [NestedGCN, NestedGraphSAGE, NestedGIN, NestedGAT] 138 | else: 139 | nets = [eval(args.model)] 140 | 141 | def logger(info): 142 | f = open(os.path.join(args.res_dir, 'log.txt'), 'a') 143 | print(info, file=f) 144 | 145 | device = torch.device( 146 | 'cuda' if torch.cuda.is_available() and not args.cpu else 'cpu' 147 | ) 148 | 149 
| if args.no_val: 150 | cross_val_method = cross_validation_without_val_set 151 | else: 152 | cross_val_method = cross_validation_with_val_set 153 | 154 | results = [] 155 | for dataset_name, Net in product(datasets, nets): 156 | best_result = (float('inf'), 0, 0) 157 | log = '-----\n{} - {}'.format(dataset_name, Net.__name__) 158 | print(log) 159 | logger(log) 160 | if args.h is not None: 161 | combinations = zip(layers, hiddens, hs) 162 | else: 163 | combinations = product(layers, hiddens, hs) 164 | for num_layers, hidden, h in combinations: 165 | if dataset_name == 'DD' and Net.__name__ == 'NestedGAT' and h >= 5: 166 | print('NestedGAT on DD will OOM for h >= 5. Skipped.') 167 | continue 168 | log = "Using {} layers, {} hidden units, h = {}".format(num_layers, hidden, h) 169 | print(log) 170 | logger(log) 171 | dataset = get_dataset( 172 | dataset_name, 173 | Net != DiffPool, 174 | h, 175 | args.node_label, 176 | args.use_rd, 177 | args.use_rp, 178 | args.reprocess, 179 | args.clean, 180 | args.max_nodes_per_hop, 181 | ) 182 | model = Net(dataset, num_layers, hidden, args.node_label!='no', args.use_rd) 183 | loss, acc, std = cross_val_method( 184 | dataset, 185 | model, 186 | folds=10, 187 | epochs=args.epochs, 188 | batch_size=args.batch_size, 189 | lr=args.lr, 190 | lr_decay_factor=args.lr_decay_factor, 191 | lr_decay_step_size=args.lr_decay_step_size, 192 | weight_decay=0, 193 | device=device, 194 | logger=logger) 195 | if loss < best_result[0]: 196 | best_result = (loss, acc, std) 197 | best_hyper = (num_layers, hidden, h) 198 | 199 | desc = '{:.3f} ± {:.3f}'.format( 200 | best_result[1], best_result[2] 201 | ) 202 | log = 'Best result - {}, with {} layers and {} hidden units and h = {}'.format( 203 | desc, best_hyper[0], best_hyper[1], best_hyper[2] 204 | ) 205 | print(log) 206 | logger(log) 207 | results += ['{} - {}: {}'.format(dataset_name, model.__class__.__name__, desc)] 208 | 209 | log = '-----\n{}'.format('\n'.join(results)) 210 | print(cmd_input[:-1]) 211 | print(log) 212 | logger(log) 213 | -------------------------------------------------------------------------------- /software/k-gnn-master/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | build/ 3 | dist/ 4 | data/ 5 | .cache/ 6 | .eggs/ 7 | *.egg-info/ 8 | .coverage 9 | *.so 10 | -------------------------------------------------------------------------------- /software/k-gnn-master/README.md: -------------------------------------------------------------------------------- 1 | # Weisfeiler and Leman Go Neural: Higher-order Graph Neural Networks 2 | 3 | This is the source code for the AAAI 2019 paper **Weisfeiler and Leman Go Neural: Higher-order Graph Neural Networks** (**[Preprint](https://arxiv.org/abs/1810.02244)**). 4 | 5 | ## Installation 6 | 7 | The code is built upon the [PyTorch Geometric package](https://github.com/rusty1s/pytorch_geometric), which needs to be installed before running the examples. 8 | Please follow its [installation instructions](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html). 9 | 10 | Finally, run 11 | 12 | ``` 13 | python setup.py install 14 | ``` 15 | 16 | in the root directory of this repository. Tested with PyTorch Geometric 1.4. 
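One easily missed detail of the hyperparameter search in `run_tu.py` above: with `--h` set, `layers`, `hiddens` and `hs` are zipped into paired configurations, while without it the lists are crossed with `product`. A small sketch using the script's own lists:

```
from itertools import product

layers, hiddens, hs = [3, 4, 5, 6], [32, 32, 32, 32], [2, 3, 4, 5]
print(list(zip(layers, hiddens, hs)))
# [(3, 32, 2), (4, 32, 3), (5, 32, 4), (6, 32, 5)]  -> 4 paired runs

layers, hiddens, hs = [2, 3, 4, 5], [32], [None]
print(list(product(layers, hiddens, hs)))
# [(2, 32, None), (3, 32, None), (4, 32, None), (5, 32, None)]  -> full cross
```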
17 | 18 | ## Running examples 19 | 20 | ``` 21 | cd examples 22 | python 1-2-3-proteins.py 23 | ``` 24 | -------------------------------------------------------------------------------- /software/k-gnn-master/cpu/adjacency.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include "iterate.h" 6 | 7 | using namespace at; 8 | 9 | inline int64_t is_adjacent(int64_t u, int64_t v, Tensor row, Tensor col) { 10 | ITERATE_NEIGHBORS(u, w, row.data(), col.data(), { 11 | if (v == w) 12 | return 1; 13 | }); 14 | return 0; 15 | } 16 | -------------------------------------------------------------------------------- /software/k-gnn-master/cpu/assignment.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include "isomorphism.h" 6 | #include "iterate.h" 7 | #include "utils.h" 8 | 9 | using namespace at; 10 | using namespace std; 11 | 12 | typedef tuple, int64_t>, Tensor> AssignmentType; 13 | 14 | template struct Assignment; 15 | 16 | template <> struct Assignment<2> { 17 | static AssignmentType unconnected(Tensor row, Tensor col, Tensor x, 18 | int64_t num_nodes) { 19 | map, int64_t> set_to_id; 20 | vector iso_type; 21 | auto num_labels = x.size(1); 22 | x = convert(x); 23 | 24 | int64_t i = 0; 25 | ITERATE_NODES(0, u, num_nodes, { 26 | ITERATE_NODES(u + 1, v, num_nodes, { 27 | set_to_id.insert({{u, v}, i}); 28 | iso_type.push_back( 29 | Isomorphism<2, false>::type({u, v}, row, col, x, num_labels)); 30 | i++; 31 | }); 32 | }); 33 | 34 | return make_tuple(set_to_id, from_vector(iso_type)); 35 | } 36 | 37 | static AssignmentType connected(Tensor row, Tensor col, Tensor x, 38 | int64_t num_nodes) { 39 | auto row_data = row.data(), col_data = col.data(); 40 | map, int64_t> set_to_id; 41 | vector iso_type; 42 | auto num_labels = x.size(1); 43 | x = convert(x); 44 | 45 | int64_t i = 0; 46 | ITERATE_NODES(0, u, num_nodes, { 47 | ITERATE_NEIGHBORS(u, v, row_data, col_data, { 48 | if (u >= v) 49 | continue; 50 | set_to_id.insert({{u, v}, i}); 51 | iso_type.push_back( 52 | Isomorphism<2, true>::type({u, v}, row, col, x, num_labels)); 53 | i++; 54 | }); 55 | }); 56 | 57 | return make_tuple(set_to_id, from_vector(iso_type)); 58 | } 59 | }; 60 | 61 | template <> struct Assignment<3> { 62 | static AssignmentType unconnected(Tensor row, Tensor col, Tensor x, 63 | int64_t num_nodes) { 64 | map, int64_t> set_to_id; 65 | vector iso_type; 66 | auto num_labels = x.size(1); 67 | x = convert(x); 68 | 69 | int64_t i = 0; 70 | ITERATE_NODES(0, u, num_nodes, { 71 | ITERATE_NODES(u + 1, v, num_nodes, { 72 | ITERATE_NODES(v + 1, w, num_nodes, { 73 | set_to_id.insert({{u, v, w}, i}); 74 | iso_type.push_back( 75 | Isomorphism<3, false>::type({u, v, w}, row, col, x, num_labels)); 76 | i++; 77 | }); 78 | }); 79 | }); 80 | 81 | return make_tuple(set_to_id, from_vector(iso_type)); 82 | } 83 | 84 | static AssignmentType connected(Tensor row, Tensor col, Tensor x, 85 | int64_t num_nodes) { 86 | auto row_data = row.data(), col_data = col.data(); 87 | map, int64_t> set_to_id; 88 | vector iso_type; 89 | auto num_labels = x.size(1); 90 | x = convert(x); 91 | 92 | int64_t i = 0; 93 | ITERATE_NODES(0, u, num_nodes, { 94 | ITERATE_NEIGHBORS(u, v, row_data, col_data, { 95 | ITERATE_NEIGHBORS(v, w, row_data, col_data, { 96 | if (w == u) 97 | continue; 98 | vector set = {u, v, w}; 99 | sort(set.begin(), set.end()); 100 | auto iter = set_to_id.find(set); 101 | if (iter == set_to_id.end()) { 102 | 
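The C++ helpers above all traverse a CSR structure in which `row` stores per-node offsets and `col` the concatenated neighbor lists. A Python mirror of `is_adjacent` and of the connected 2-set enumeration in `Assignment<2>::connected` (a sketch for intuition, not the extension's API):

```
# CSR for the undirected path graph 0 - 1 - 2 (both edge directions stored).
row = [0, 1, 3, 4]   # node u's neighbors live in col[row[u]:row[u + 1]]
col = [1, 0, 2, 1]

def is_adjacent(u, v):
    return v in col[row[u]:row[u + 1]]

set_to_id = {}
for u in range(len(row) - 1):          # ITERATE_NODES
    for v in col[row[u]:row[u + 1]]:   # ITERATE_NEIGHBORS
        if u < v:                      # keep each {u, v} once
            set_to_id[(u, v)] = len(set_to_id)

print(is_adjacent(0, 2))  # False
print(set_to_id)          # {(0, 1): 0, (1, 2): 1}
```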
set_to_id.insert({set, i}); 103 | iso_type.push_back( 104 | Isomorphism<3, true>::type(set, row, col, x, num_labels)); 105 | i++; 106 | } 107 | }); 108 | }); 109 | }); 110 | 111 | return make_tuple(set_to_id, from_vector(iso_type)); 112 | } 113 | }; 114 | -------------------------------------------------------------------------------- /software/k-gnn-master/cpu/connect.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include "iterate.h" 6 | #include "utils.h" 7 | 8 | using namespace at; 9 | using namespace std; 10 | 11 | #define ADD_SET(ID, SET) \ 12 | [&] { \ 13 | sort(SET.begin(), SET.end()); \ 14 | auto item2 = set_to_id.find(SET); \ 15 | if (item2 != set_to_id.end()) { \ 16 | rows.push_back(ID); \ 17 | cols.push_back(item2->second); \ 18 | rows.push_back(item2->second); \ 19 | cols.push_back(ID); \ 20 | } \ 21 | }() 22 | 23 | template struct Connect; 24 | 25 | template <> struct Connect<2> { 26 | static Tensor local(Tensor row, Tensor col, 27 | map, int64_t> set_to_id) { 28 | auto row_data = row.data(), col_data = col.data(); 29 | vector rows, cols; 30 | 31 | for (auto item : set_to_id) { 32 | ITERATE_NEIGHBORS(item.first[0], x, row_data, col_data, { 33 | vector set1 = {item.first[0], x}; 34 | ADD_SET(item.second, set1); 35 | vector set2 = {item.first[1], x}; 36 | ADD_SET(item.second, set2); 37 | }); 38 | 39 | ITERATE_NEIGHBORS(item.first[1], x, row_data, col_data, { 40 | vector set1 = {item.first[0], x}; 41 | ADD_SET(item.second, set1); 42 | vector set2 = {item.first[1], x}; 43 | ADD_SET(item.second, set2); 44 | }); 45 | } 46 | 47 | if (rows.size() == 0) { 48 | return torch::empty(0, row.options()); 49 | } 50 | 51 | auto index = stack({from_vector(rows), from_vector(cols)}, 0); 52 | return coalesce(remove_self_loops(index), (int64_t)set_to_id.size()); 53 | } 54 | 55 | static Tensor malkin(Tensor row, Tensor col, 56 | map, int64_t> set_to_id) { 57 | auto row_data = row.data(), col_data = col.data(); 58 | vector rows, cols; 59 | 60 | for (auto item : set_to_id) { 61 | ITERATE_NEIGHBORS(item.first[0], x, row_data, col_data, { 62 | vector set = {item.first[1], x}; 63 | ADD_SET(item.second, set); 64 | }); 65 | 66 | ITERATE_NEIGHBORS(item.first[1], x, row_data, col_data, { 67 | vector set = {item.first[0], x}; 68 | ADD_SET(item.second, set); 69 | }); 70 | } 71 | 72 | if (rows.size() == 0) { 73 | return torch::empty(0, row.options()); 74 | } 75 | 76 | auto index = stack({from_vector(rows), from_vector(cols)}, 0); 77 | return coalesce(remove_self_loops(index), (int64_t)set_to_id.size()); 78 | } 79 | }; 80 | 81 | template <> struct Connect<3> { 82 | static Tensor local(Tensor row, Tensor col, 83 | map, int64_t> set_to_id) { 84 | auto row_data = row.data(), col_data = col.data(); 85 | vector rows, cols; 86 | 87 | for (auto item : set_to_id) { 88 | ITERATE_NEIGHBORS(item.first[0], x, row_data, col_data, { 89 | vector set1 = {item.first[0], item.first[1], x}; 90 | ADD_SET(item.second, set1); 91 | vector set2 = {item.first[0], item.first[2], x}; 92 | ADD_SET(item.second, set2); 93 | vector set3 = {item.first[1], item.first[2], x}; 94 | ADD_SET(item.second, set3); 95 | }); 96 | 97 | ITERATE_NEIGHBORS(item.first[1], x, row_data, col_data, { 98 | vector set1 = {item.first[0], item.first[1], x}; 99 | ADD_SET(item.second, set1); 100 | vector set2 = {item.first[0], item.first[2], x}; 101 | ADD_SET(item.second, set2); 102 | vector set3 = {item.first[1], item.first[2], x}; 103 | ADD_SET(item.second, set3); 104 | }); 105 | 106 | 
ITERATE_NEIGHBORS(item.first[2], x, row_data, col_data, { 107 | vector set1 = {item.first[0], item.first[1], x}; 108 | ADD_SET(item.second, set1); 109 | vector set2 = {item.first[0], item.first[2], x}; 110 | ADD_SET(item.second, set2); 111 | vector set3 = {item.first[1], item.first[2], x}; 112 | ADD_SET(item.second, set3); 113 | }); 114 | } 115 | 116 | if (rows.size() == 0) { 117 | return torch::empty(0, row.options()); 118 | } 119 | 120 | auto index = stack({from_vector(rows), from_vector(cols)}, 0); 121 | return coalesce(remove_self_loops(index), (int64_t)set_to_id.size()); 122 | } 123 | 124 | static Tensor malkin(Tensor row, Tensor col, 125 | map, int64_t> set_to_id) { 126 | auto row_data = row.data(), col_data = col.data(); 127 | vector rows, cols; 128 | 129 | for (auto item : set_to_id) { 130 | ITERATE_NEIGHBORS(item.first[0], x, row_data, col_data, { 131 | vector set = {item.first[1], item.first[2], x}; 132 | ADD_SET(item.second, set); 133 | }); 134 | 135 | ITERATE_NEIGHBORS(item.first[1], x, row_data, col_data, { 136 | vector set = {item.first[0], item.first[2], x}; 137 | ADD_SET(item.second, set); 138 | }); 139 | 140 | ITERATE_NEIGHBORS(item.first[2], x, row_data, col_data, { 141 | vector set = {item.first[0], item.first[1], x}; 142 | ADD_SET(item.second, set); 143 | }); 144 | } 145 | 146 | if (rows.size() == 0) { 147 | return torch::empty(0, row.options()); 148 | } 149 | 150 | auto index = stack({from_vector(rows), from_vector(cols)}, 0); 151 | return coalesce(remove_self_loops(index), (int64_t)set_to_id.size()); 152 | } 153 | }; 154 | -------------------------------------------------------------------------------- /software/k-gnn-master/cpu/graph.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "assignment.h" 5 | #include "connect.h" 6 | #include "utils.h" 7 | 8 | using namespace at; 9 | using namespace std; 10 | 11 | template 12 | vector local(Tensor index, Tensor x, int64_t num_nodes) { 13 | Tensor row, col; 14 | tie(row, col) = to_csr(index, num_nodes); 15 | Tensor assignment, iso_type; 16 | map, int64_t> set_to_id; 17 | tie(set_to_id, iso_type) = Assignment::unconnected(row, col, x, num_nodes); 18 | index = Connect::local(row, col, set_to_id); 19 | assignment = MapToTensor::get(set_to_id); 20 | return {index, assignment, iso_type}; 21 | } 22 | 23 | template 24 | vector connected_local(Tensor index, Tensor x, int64_t num_nodes) { 25 | Tensor row, col; 26 | tie(row, col) = to_csr(index, num_nodes); 27 | Tensor assignment, iso_type; 28 | map, int64_t> set_to_id; 29 | tie(set_to_id, iso_type) = Assignment::connected(row, col, x, num_nodes); 30 | index = Connect::local(row, col, set_to_id); 31 | assignment = MapToTensor::get(set_to_id); 32 | return {index, assignment, iso_type}; 33 | } 34 | 35 | template 36 | vector malkin(Tensor index, Tensor x, int64_t num_nodes) { 37 | Tensor row, col; 38 | tie(row, col) = to_csr(index, num_nodes); 39 | Tensor assignment, iso_type; 40 | map, int64_t> set_to_id; 41 | tie(set_to_id, iso_type) = Assignment::unconnected(row, col, x, num_nodes); 42 | index = Connect::malkin(row, col, set_to_id); 43 | assignment = MapToTensor::get(set_to_id); 44 | return {index, assignment, iso_type}; 45 | } 46 | 47 | template 48 | vector connected_malkin(Tensor index, Tensor x, int64_t num_nodes) { 49 | Tensor row, col; 50 | tie(row, col) = to_csr(index, num_nodes); 51 | Tensor assignment, iso_type; 52 | map, int64_t> set_to_id; 53 | tie(set_to_id, iso_type) = Assignment::connected(row, 
col, x, num_nodes); 54 | index = Connect::malkin(row, col, set_to_id); 55 | assignment = MapToTensor::get(set_to_id); 56 | return {index, assignment, iso_type}; 57 | } 58 | 59 | Tensor assignment_2to3(Tensor index, int64_t num_nodes) { 60 | Tensor row, col; 61 | tie(row, col) = to_csr(index, num_nodes); 62 | auto one = ones({num_nodes, 1}, index.options()); 63 | map, int64_t> set2_to_id = 64 | get<0>(Assignment<2>::unconnected(row, col, one, num_nodes)); 65 | map, int64_t> set3_to_id = 66 | get<0>(Assignment<3>::connected(row, col, one, num_nodes)); 67 | 68 | vector rows, cols; 69 | for (auto item3 : set3_to_id) { 70 | int64_t u = item3.first[0], v = item3.first[1], w = item3.first[2]; 71 | 72 | auto item2 = set2_to_id.find({u, v}); 73 | rows.push_back(item2->second); 74 | cols.push_back(item3.second); 75 | 76 | item2 = set2_to_id.find({u, w}); 77 | rows.push_back(item2->second); 78 | cols.push_back(item3.second); 79 | 80 | item2 = set2_to_id.find({v, w}); 81 | rows.push_back(item2->second); 82 | cols.push_back(item3.second); 83 | } 84 | 85 | return stack({from_vector(rows), from_vector(cols)}, 0); 86 | } 87 | 88 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 89 | m.def("two_local", &local<2>, "2-Local"); 90 | m.def("connected_two_local", &connected_local<2>, "Connected 2-Local"); 91 | m.def("two_malkin", &malkin<2>, "2-Malkin"); 92 | m.def("connected_two_malkin", &connected_malkin<2>, "Connected 2-Malkin"); 93 | m.def("three_local", &local<3>, "3-Local"); 94 | m.def("connected_three_local", &connected_local<3>, "Connected 3-Local"); 95 | m.def("three_malkin", &malkin<3>, "3-Malkin"); 96 | m.def("connected_three_malkin", &connected_malkin<3>, "Connected 3-Malkin"); 97 | m.def("assignment_2to3", &assignment_2to3, "Assignment Two To Three Graph"); 98 | } 99 | -------------------------------------------------------------------------------- /software/k-gnn-master/cpu/isomorphism.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #include "adjacency.h" 7 | #include "iterate.h" 8 | 9 | using namespace at; 10 | using namespace std; 11 | 12 | inline int64_t pair(int64_t u, int64_t v) { 13 | return u >= v ? 
u * u + u + v : u + v * v; 14 | } 15 | 16 | inline Tensor convert(Tensor x) { 17 | auto range = torch::empty(x.size(1), x.options()); 18 | arange_out(range, x.size(1)); 19 | x = x * range.view({1, -1}); 20 | return x.sum(1).toType(kLong); 21 | } 22 | 23 | template struct Isomorphism; 24 | 25 | template <> struct Isomorphism<2, true> { 26 | static int64_t type(vector set, Tensor row, Tensor col, Tensor x, 27 | int64_t num_labels) { 28 | auto x_data = x.data(); 29 | vector labels = {x_data[set[0]], x_data[set[1]]}; 30 | sort(labels.begin(), labels.end()); 31 | 32 | return labels[0] * num_labels + labels[1]; 33 | } 34 | }; 35 | 36 | template <> struct Isomorphism<2, false> { 37 | static int64_t type(vector set, Tensor row, Tensor col, Tensor x, 38 | int64_t num_labels) { 39 | auto x_data = x.data(); 40 | vector labels = {x_data[set[0]], x_data[set[1]]}; 41 | sort(labels.begin(), labels.end()); 42 | 43 | return num_labels * num_labels * is_adjacent(set[0], set[1], row, col) + 44 | labels[0] * num_labels + labels[1]; 45 | } 46 | }; 47 | 48 | template <> struct Isomorphism<3, true> { 49 | static int64_t type(vector set, Tensor row, Tensor col, Tensor x, 50 | int64_t num_labels) { 51 | auto x_data = x.data(); 52 | vector labels = {x_data[set[0]], x_data[set[1]], x_data[set[2]]}; 53 | sort(labels.begin(), labels.end()); 54 | 55 | return num_labels * num_labels * num_labels * 56 | is_adjacent(set[2], set[0], row, col) + 57 | labels[0] * num_labels * num_labels + labels[1] * num_labels + 58 | labels[2]; 59 | } 60 | }; 61 | 62 | template <> struct Isomorphism<3, false> { 63 | static int64_t type(vector set, Tensor row, Tensor col, Tensor x, 64 | int64_t num_labels) { 65 | // TODO 66 | printf("Not yet implemented.\n"); 67 | return -1; 68 | } 69 | }; 70 | -------------------------------------------------------------------------------- /software/k-gnn-master/cpu/iterate.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #define ITERATE_NODES(START, NAME, END, ...) \ 4 | { \ 5 | for (int64_t NAME = START; NAME < END; NAME++) { \ 6 | __VA_ARGS__; \ 7 | } \ 8 | } 9 | 10 | #define ITERATE_NEIGHBORS(NODE, NAME, ROW, COL, ...) 
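Two pieces of arithmetic from `isomorphism.h` above, rewritten in Python for intuition: the order-insensitive 2-set label encoding (with an adjacency bit in the unconnected case) and `pair`, which is Szudzik's elegant pairing function mapping each ordered pair of non-negative integers to a unique integer (a sketch only):

```
def iso_type_2(label_u, label_v, adjacent, num_labels):
    # Sort so {u, v} and {v, u} receive the same type; prepend an adjacency bit.
    lo, hi = sorted((label_u, label_v))
    return num_labels * num_labels * int(adjacent) + lo * num_labels + hi

def pair(u, v):
    return u * u + u + v if u >= v else u + v * v

print(iso_type_2(2, 1, True, 3))  # 9 + 1*3 + 2 = 14
print(pair(4, 2), pair(2, 4))     # 22, 18: ordered pairs stay distinguishable
```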
\ 11 | { \ 12 | for (int64_t NAME##_i = ROW[NODE]; NAME##_i < ROW[NODE + 1]; NAME##_i++) { \ 13 | auto NAME = COL[NAME##_i]; \ 14 | __VA_ARGS__; \ 15 | } \ 16 | } 17 | -------------------------------------------------------------------------------- /software/k-gnn-master/cpu/utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | using namespace at; 7 | using namespace std; 8 | 9 | inline Tensor remove_self_loops(Tensor index) { 10 | auto row = index[0], col = index[1]; 11 | auto mask = row != col; 12 | row = row.masked_select(mask), col = col.masked_select(mask); 13 | return stack({row, col}, 0); 14 | } 15 | 16 | inline Tensor coalesce(Tensor index, int64_t num_nodes) { 17 | Tensor row = index[0], col = index[1], unique, inv, perm; 18 | tie(unique, inv) = _unique(num_nodes * row + col, true, true); 19 | 20 | perm = torch::empty(inv.size(0), index.options()); 21 | arange_out(perm, inv.size(0)); 22 | perm = torch::empty(unique.size(0), index.options()).scatter_(0, inv, perm); 23 | 24 | row = row.index_select(0, perm); 25 | col = col.index_select(0, perm); 26 | 27 | return stack({row, col}, 0); 28 | } 29 | 30 | inline Tensor sort_by_row(Tensor index) { 31 | Tensor row = index[0], col = index[1], perm; 32 | tie(row, perm) = row.sort(); 33 | col = col.index_select(0, perm); 34 | return stack({row, col}, 0); 35 | } 36 | 37 | inline Tensor degree(Tensor row, int64_t num_nodes) { 38 | auto zero = torch::zeros(num_nodes, row.options()); 39 | auto one = torch::ones(row.size(0), row.options()); 40 | return zero.scatter_add_(0, row, one); 41 | } 42 | 43 | inline tuple to_csr(Tensor index, int64_t num_nodes) { 44 | index = sort_by_row(index); 45 | auto row = degree(index[0], num_nodes).cumsum(0); 46 | row = cat({torch::zeros(1, row.options()), row}, 0); // Prepend zero. 
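`to_csr` above builds the offset array in three steps: count degrees, take a cumulative sum, and prepend a zero. The same steps in Python on a toy edge list:

```
import torch

# Edges of the path 0 - 1 - 2, with both directions stored.
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
num_nodes = 3

row = edge_index[0]
deg = torch.zeros(num_nodes, dtype=torch.long).scatter_add_(
    0, row, torch.ones(row.size(0), dtype=torch.long))
rowptr = torch.cat([torch.zeros(1, dtype=torch.long), deg.cumsum(0)])
print(rowptr.tolist())  # [0, 1, 3, 4]; neighbors of u are col[rowptr[u]:rowptr[u+1]]
```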
47 | return make_tuple(row, index[1]); 48 | } 49 | 50 | inline Tensor from_vector(vector src) { 51 | auto out = torch::empty((size_t)src.size(), torch::CPU(at::kLong)); 52 | auto out_data = out.data(); 53 | for (ptrdiff_t i = 0; i < out.size(0); i++) { 54 | out_data[i] = src[i]; 55 | } 56 | return out; 57 | } 58 | 59 | template struct MapToTensor; 60 | 61 | template <> struct MapToTensor<2> { 62 | static Tensor get(map, int64_t> set_to_id) { 63 | int64_t size = (int64_t)set_to_id.size(); 64 | Tensor set = torch::empty(2 * size, torch::CPU(at::kLong)); 65 | Tensor id = torch::empty(2 * size, torch::CPU(at::kLong)); 66 | auto set_data = set.data(), id_data = id.data(); 67 | 68 | int64_t i = 0; 69 | for (auto item : set_to_id) { 70 | set_data[2 * i] = item.first[0]; 71 | set_data[2 * i + 1] = item.first[1]; 72 | id_data[2 * i] = item.second; 73 | id_data[2 * i + 1] = item.second; 74 | i++; 75 | } 76 | 77 | return stack({set, id}, 0); 78 | } 79 | }; 80 | 81 | template <> struct MapToTensor<3> { 82 | static Tensor get(map, int64_t> set_to_id) { 83 | int64_t size = (int64_t)set_to_id.size(); 84 | Tensor set = torch::empty(3 * size, torch::CPU(at::kLong)); 85 | Tensor id = torch::empty(3 * size, torch::CPU(at::kLong)); 86 | auto set_data = set.data(), id_data = id.data(); 87 | 88 | int64_t i = 0; 89 | for (auto item : set_to_id) { 90 | set_data[3 * i] = item.first[0]; 91 | set_data[3 * i + 1] = item.first[1]; 92 | set_data[3 * i + 2] = item.first[2]; 93 | id_data[3 * i] = item.second; 94 | id_data[3 * i + 1] = item.second; 95 | id_data[3 * i + 2] = item.second; 96 | i++; 97 | } 98 | 99 | return stack({set, id}, 0); 100 | } 101 | }; 102 | -------------------------------------------------------------------------------- /software/k-gnn-master/examples/1-2-3-imdb.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import argparse 4 | import torch 5 | import torch.nn.functional as F 6 | from torch_scatter import scatter_mean 7 | from torch_geometric.datasets import TUDataset 8 | from torch_geometric.utils import degree 9 | import torch_geometric.transforms as T 10 | from k_gnn import DataLoader, GraphConv, avg_pool 11 | from k_gnn import TwoMalkin, ConnectedThreeMalkin 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--no-train', default=False) 15 | args = parser.parse_args() 16 | 17 | 18 | class MyFilter(object): 19 | def __call__(self, data): 20 | return data.num_nodes <= 70 21 | 22 | 23 | class MyPreTransform(object): 24 | def __call__(self, data): 25 | data.x = torch.zeros((data.num_nodes, 1), dtype=torch.float) 26 | data = TwoMalkin()(data) 27 | data = ConnectedThreeMalkin()(data) 28 | data.x = degree(data.edge_index[0], data.num_nodes, dtype=torch.long) 29 | data.x = F.one_hot(data.x, num_classes=136).to(torch.float) 30 | return data 31 | 32 | 33 | BATCH = 20 34 | path = osp.join( 35 | osp.dirname(osp.realpath(__file__)), '..', 'data', '1-2-3-IMDB-BINARY') 36 | dataset = TUDataset( 37 | path, 38 | name='IMDB-BINARY', 39 | pre_transform=T.Compose([MyPreTransform()]), 40 | pre_filter=MyFilter()) 41 | 42 | perm = torch.randperm(len(dataset), dtype=torch.long) 43 | dataset = dataset[perm] 44 | 45 | dataset.data.iso_type_2 = torch.unique(dataset.data.iso_type_2, True, True)[1] 46 | num_i_2 = dataset.data.iso_type_2.max().item() + 1 47 | dataset.data.iso_type_2 = F.one_hot( 48 | dataset.data.iso_type_2, num_classes=num_i_2).to(torch.float) 49 | 50 | dataset.data.iso_type_3 = torch.unique(dataset.data.iso_type_3, 
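The `torch.unique(..., True, True)[1]` idiom used in these examples compresses the raw isomorphism-type codes into contiguous ids that can be one-hot encoded. A toy demonstration:

```
import torch
import torch.nn.functional as F

iso_type = torch.tensor([14, 3, 97, 3, 14])       # raw, sparse type codes
iso_type = torch.unique(iso_type, True, True)[1]  # sorted=True, return_inverse=True
print(iso_type.tolist())                          # [1, 0, 2, 0, 1]
num_i = iso_type.max().item() + 1                 # 3 distinct types
one_hot = F.one_hot(iso_type, num_classes=num_i).to(torch.float)
print(one_hot.shape)                              # torch.Size([5, 3])
```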
True, True)[1] 51 | num_i_3 = dataset.data.iso_type_3.max().item() + 1 52 | dataset.data.iso_type_3 = F.one_hot( 53 | dataset.data.iso_type_3, num_classes=num_i_3).to(torch.float) 54 | 55 | 56 | class Net(torch.nn.Module): 57 | def __init__(self): 58 | super(Net, self).__init__() 59 | self.conv1 = GraphConv(dataset.num_features, 32) 60 | self.conv2 = GraphConv(32, 64) 61 | self.conv3 = GraphConv(64, 64) 62 | self.conv4 = GraphConv(64 + num_i_2, 64) 63 | self.conv5 = GraphConv(64, 64) 64 | self.conv6 = GraphConv(64 + num_i_3, 64) 65 | self.conv7 = GraphConv(64, 64) 66 | self.fc1 = torch.nn.Linear(3 * 64, 64) 67 | self.fc2 = torch.nn.Linear(64, 32) 68 | self.fc3 = torch.nn.Linear(32, dataset.num_classes) 69 | 70 | def reset_parameters(self): 71 | for (name, module) in self._modules.items(): 72 | module.reset_parameters() 73 | 74 | def forward(self, data): 75 | data.x = F.elu(self.conv1(data.x, data.edge_index)) 76 | data.x = F.elu(self.conv2(data.x, data.edge_index)) 77 | data.x = F.elu(self.conv3(data.x, data.edge_index)) 78 | x = data.x 79 | x_1 = scatter_mean(data.x, data.batch, dim=0) 80 | 81 | data.x = avg_pool(x, data.assignment_index_2) 82 | data.x = torch.cat([data.x, data.iso_type_2], dim=1) 83 | 84 | data.x = F.elu(self.conv4(data.x, data.edge_index_2)) 85 | data.x = F.elu(self.conv5(data.x, data.edge_index_2)) 86 | x_2 = scatter_mean(data.x, data.batch_2, dim=0) 87 | 88 | data.x = avg_pool(x, data.assignment_index_3) 89 | data.x = torch.cat([data.x, data.iso_type_3], dim=1) 90 | 91 | data.x = F.elu(self.conv6(data.x, data.edge_index_3)) 92 | data.x = F.elu(self.conv7(data.x, data.edge_index_3)) 93 | x_3 = scatter_mean(data.x, data.batch_3, dim=0) 94 | 95 | x = torch.cat([x_1, x_2, x_3], dim=1) 96 | 97 | if args.no_train: 98 | x = x.detach() 99 | 100 | x = F.elu(self.fc1(x)) 101 | x = F.dropout(x, p=0.5, training=self.training) 102 | x = F.elu(self.fc2(x)) 103 | x = self.fc3(x) 104 | return F.log_softmax(x, dim=1) 105 | 106 | 107 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 108 | model = Net().to(device) 109 | 110 | 111 | def train(epoch, loader, optimizer): 112 | model.train() 113 | loss_all = 0 114 | 115 | for data in loader: 116 | data = data.to(device) 117 | optimizer.zero_grad() 118 | loss = F.nll_loss(model(data), data.y) 119 | loss.backward() 120 | loss_all += data.num_graphs * loss.item() 121 | optimizer.step() 122 | return loss_all / len(loader.dataset) 123 | 124 | 125 | def val(loader): 126 | model.eval() 127 | loss_all = 0 128 | 129 | for data in loader: 130 | data = data.to(device) 131 | loss_all += F.nll_loss(model(data), data.y, reduction='sum').item() 132 | return loss_all / len(loader.dataset) 133 | 134 | 135 | def test(loader): 136 | model.eval() 137 | correct = 0 138 | 139 | for data in loader: 140 | data = data.to(device) 141 | pred = model(data).max(1)[1] 142 | correct += pred.eq(data.y).sum().item() 143 | return correct / len(loader.dataset) 144 | 145 | 146 | acc = [] 147 | for i in range(10): 148 | model.reset_parameters() 149 | optimizer = torch.optim.Adam(model.parameters(), lr=0.001) 150 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 151 | optimizer, mode='min', factor=0.7, patience=5, min_lr=0.00001) 152 | 153 | test_mask = torch.zeros(len(dataset), dtype=torch.uint8) 154 | n = len(dataset) // 10 155 | test_mask[i * n:(i + 1) * n] = 1 156 | test_dataset = dataset[test_mask] 157 | train_dataset = dataset[1 - test_mask] 158 | 159 | n = len(train_dataset) // 10 160 | val_mask = torch.zeros(len(train_dataset), 
dtype=torch.uint8) 161 | val_mask[i * n:(i + 1) * n] = 1 162 | val_dataset = train_dataset[val_mask] 163 | train_dataset = train_dataset[1 - val_mask] 164 | 165 | val_loader = DataLoader(val_dataset, batch_size=BATCH) 166 | test_loader = DataLoader(test_dataset, batch_size=BATCH) 167 | train_loader = DataLoader(train_dataset, batch_size=BATCH, shuffle=True) 168 | 169 | print('---------------- Split {} ----------------'.format(i)) 170 | 171 | best_val_loss, test_acc = 100, 0 172 | for epoch in range(1, 101): 173 | lr = scheduler.optimizer.param_groups[0]['lr'] 174 | train_loss = train(epoch, train_loader, optimizer) 175 | val_loss = val(val_loader) 176 | scheduler.step(val_loss) 177 | if best_val_loss >= val_loss: 178 | test_acc = test(test_loader) 179 | best_val_loss = val_loss 180 | print('Epoch: {:03d}, LR: {:7f}, Train Loss: {:.7f}, ' 181 | 'Val Loss: {:.7f}, Test Acc: {:.7f}'.format( 182 | epoch, lr, train_loss, val_loss, test_acc)) 183 | acc.append(test_acc) 184 | acc = torch.tensor(acc) 185 | print('---------------- Final Result ----------------') 186 | print('Mean: {:7f}, Std: {:7f}'.format(acc.mean(), acc.std())) 187 | -------------------------------------------------------------------------------- /software/k-gnn-master/examples/1-2-3-mutag.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import argparse 4 | import torch 5 | import torch.nn.functional as F 6 | from torch_scatter import scatter_mean, scatter_add 7 | from torch_geometric.datasets import TUDataset 8 | import torch_geometric.transforms as T 9 | from k_gnn import DataLoader, GraphConv, avg_pool 10 | from k_gnn import TwoMalkin, ConnectedThreeLocal 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--no-train', default=False) 14 | args = parser.parse_args() 15 | 16 | 17 | class MyFilter(object): 18 | def __call__(self, data): 19 | return True 20 | 21 | 22 | class MyPreTransform(object): 23 | def __call__(self, data): 24 | return data 25 | 26 | 27 | BATCH = 32 28 | path = osp.join( 29 | osp.dirname(osp.realpath(__file__)), '..', 'data', '1-2-3-MUTAG') 30 | dataset = TUDataset( 31 | path, 32 | name='MUTAG', 33 | pre_transform=T.Compose( 34 | [MyPreTransform(), 35 | TwoMalkin(), ConnectedThreeLocal()]), 36 | pre_filter=MyFilter()) 37 | 38 | perm = torch.randperm(len(dataset), dtype=torch.long) 39 | torch.save(perm, 'mutag_perm.pt') 40 | perm = torch.load('mutag_perm.pt') 41 | dataset = dataset[perm] 42 | 43 | dataset.data.iso_type_2 = torch.unique(dataset.data.iso_type_2, True, True)[1] 44 | num_i_2 = dataset.data.iso_type_2.max().item() + 1 45 | dataset.data.iso_type_2 = F.one_hot( 46 | dataset.data.iso_type_2, num_classes=num_i_2).to(torch.float) 47 | 48 | dataset.data.iso_type_3 = torch.unique(dataset.data.iso_type_3, True, True)[1] 49 | num_i_3 = dataset.data.iso_type_3.max().item() + 1 50 | dataset.data.iso_type_3 = F.one_hot( 51 | dataset.data.iso_type_3, num_classes=num_i_3).to(torch.float) 52 | 53 | 54 | class Net(torch.nn.Module): 55 | def __init__(self): 56 | super(Net, self).__init__() 57 | self.conv1 = GraphConv(dataset.num_features, 64) 58 | self.conv2 = GraphConv(64, 64) 59 | self.conv3 = GraphConv(64, 64) 60 | self.conv4 = GraphConv(64 + num_i_2, 64) 61 | self.conv5 = GraphConv(64, 64) 62 | self.conv6 = GraphConv(64 + num_i_3, 64) 63 | self.conv7 = GraphConv(64, 64) 64 | self.fc1 = torch.nn.Linear(3 * 64, 64) 65 | self.fc2 = torch.nn.Linear(64, 32) 66 | self.fc3 = torch.nn.Linear(32, dataset.num_classes) 67 | 68 | def 
reset_parameters(self): 69 | for (name, module) in self._modules.items(): 70 | module.reset_parameters() 71 | 72 | def forward(self, data): 73 | data.x = F.elu(self.conv1(data.x, data.edge_index)) 74 | data.x = F.elu(self.conv2(data.x, data.edge_index)) 75 | data.x = F.elu(self.conv3(data.x, data.edge_index)) 76 | x = data.x 77 | x_1 = scatter_add(data.x, data.batch, dim=0) 78 | 79 | data.x = avg_pool(x, data.assignment_index_2) 80 | data.x = torch.cat([data.x, data.iso_type_2], dim=1) 81 | 82 | data.x = F.elu(self.conv4(data.x, data.edge_index_2)) 83 | data.x = F.elu(self.conv5(data.x, data.edge_index_2)) 84 | x_2 = scatter_mean(data.x, data.batch_2, dim=0) 85 | 86 | data.x = avg_pool(x, data.assignment_index_3) 87 | data.x = torch.cat([data.x, data.iso_type_3], dim=1) 88 | 89 | data.x = F.elu(self.conv6(data.x, data.edge_index_3)) 90 | data.x = F.elu(self.conv7(data.x, data.edge_index_3)) 91 | x_3 = scatter_mean(data.x, data.batch_3, dim=0) 92 | 93 | x = torch.cat([x_1, x_2, x_3], dim=1) 94 | 95 | if args.no_train: 96 | x = x.detach() 97 | 98 | x = F.elu(self.fc1(x)) 99 | x = F.dropout(x, p=0.5, training=self.training) 100 | x = F.elu(self.fc2(x)) 101 | x = self.fc3(x) 102 | return F.log_softmax(x, dim=1) 103 | 104 | 105 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 106 | model = Net().to(device) 107 | 108 | 109 | def train(epoch, loader, optimizer): 110 | model.train() 111 | loss_all = 0 112 | 113 | for data in loader: 114 | data = data.to(device) 115 | optimizer.zero_grad() 116 | loss = F.nll_loss(model(data), data.y) 117 | loss.backward() 118 | loss_all += data.num_graphs * loss.item() 119 | optimizer.step() 120 | return loss_all / len(loader.dataset) 121 | 122 | 123 | def val(loader): 124 | model.eval() 125 | loss_all = 0 126 | 127 | for data in loader: 128 | data = data.to(device) 129 | loss_all += F.nll_loss(model(data), data.y, reduction='sum').item() 130 | return loss_all / len(loader.dataset) 131 | 132 | 133 | def test(loader): 134 | model.eval() 135 | correct = 0 136 | 137 | for data in loader: 138 | data = data.to(device) 139 | pred = model(data).max(1)[1] 140 | correct += pred.eq(data.y).sum().item() 141 | return correct / len(loader.dataset) 142 | 143 | 144 | acc = [] 145 | for i in range(10): 146 | model.reset_parameters() 147 | optimizer = torch.optim.Adam(model.parameters(), lr=0.001) 148 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 149 | optimizer, mode='min', factor=0.1, patience=15, min_lr=0.00001) 150 | 151 | test_mask = torch.zeros(len(dataset), dtype=torch.uint8) 152 | n = len(dataset) // 10 153 | test_mask[i * n:(i + 1) * n] = 1 154 | test_dataset = dataset[test_mask] 155 | train_dataset = dataset[1 - test_mask] 156 | 157 | n = len(train_dataset) // 10 158 | val_mask = torch.zeros(len(train_dataset), dtype=torch.uint8) 159 | val_mask[i * n:(i + 1) * n] = 1 160 | val_dataset = train_dataset[val_mask] 161 | train_dataset = train_dataset[1 - val_mask] 162 | 163 | val_loader = DataLoader(val_dataset, batch_size=BATCH) 164 | test_loader = DataLoader(test_dataset, batch_size=BATCH) 165 | train_loader = DataLoader(train_dataset, batch_size=BATCH, shuffle=True) 166 | 167 | print('---------------- Split {} ----------------'.format(i)) 168 | 169 | best_val_loss, test_acc = 100, 0 170 | for epoch in range(1, 51): 171 | lr = scheduler.optimizer.param_groups[0]['lr'] 172 | train_loss = train(epoch, train_loader, optimizer) 173 | val_loss = val(val_loader) 174 | # scheduler.step(val_loss) 175 | if best_val_loss >= val_loss: 176 | test_acc = 
test(test_loader) 177 | best_val_loss = val_loss 178 | if epoch % 5 == 0: 179 | print('Epoch: {:03d}, LR: {:7f}, Train Loss: {:.7f}, ' 180 | 'Val Loss: {:.7f}, Test Acc: {:.7f}'.format( 181 | epoch, lr, train_loss, val_loss, test_acc)) 182 | acc.append(test_acc) 183 | acc = torch.tensor(acc) 184 | print('---------------- Final Result ----------------') 185 | print('Mean: {:7f}, Std: {:7f}'.format(acc.mean(), acc.std())) 186 | -------------------------------------------------------------------------------- /software/k-gnn-master/examples/1-2-3-proteins.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import argparse 4 | import torch 5 | import torch.nn.functional as F 6 | from torch_scatter import scatter_mean 7 | from torch_geometric.datasets import TUDataset 8 | import torch_geometric.transforms as T 9 | from k_gnn import DataLoader, GraphConv, avg_pool 10 | from k_gnn import TwoMalkin, ConnectedThreeMalkin 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--no-train', default=False) 14 | args = parser.parse_args() 15 | 16 | 17 | class MyFilter(object): 18 | def __call__(self, data): 19 | return not (data.num_nodes == 7 and data.num_edges == 12) and \ 20 | data.num_nodes < 450 21 | 22 | 23 | BATCH = 20 24 | path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 25 | '1-2-3-PROTEINS') 26 | dataset = TUDataset( 27 | path, name='PROTEINS', 28 | pre_transform=T.Compose([TwoMalkin(), 29 | ConnectedThreeMalkin()]), pre_filter=MyFilter()) 30 | 31 | perm = torch.randperm(len(dataset), dtype=torch.long) 32 | dataset = dataset[perm] 33 | 34 | dataset.data.iso_type_2 = torch.unique(dataset.data.iso_type_2, True, True)[1] 35 | num_i_2 = dataset.data.iso_type_2.max().item() + 1 36 | dataset.data.iso_type_2 = F.one_hot(dataset.data.iso_type_2, 37 | num_classes=num_i_2).to(torch.float) 38 | 39 | dataset.data.iso_type_3 = torch.unique(dataset.data.iso_type_3, True, True)[1] 40 | num_i_3 = dataset.data.iso_type_3.max().item() + 1 41 | dataset.data.iso_type_3 = F.one_hot(dataset.data.iso_type_3, 42 | num_classes=num_i_3).to(torch.float) 43 | 44 | 45 | class Net(torch.nn.Module): 46 | def __init__(self): 47 | super(Net, self).__init__() 48 | self.conv1 = GraphConv(dataset.num_features, 32) 49 | self.conv2 = GraphConv(32, 64) 50 | self.conv3 = GraphConv(64, 64) 51 | self.conv4 = GraphConv(64 + num_i_2, 64) 52 | self.conv5 = GraphConv(64, 64) 53 | self.conv6 = GraphConv(64 + num_i_3, 64) 54 | self.conv7 = GraphConv(64, 64) 55 | self.fc1 = torch.nn.Linear(3 * 64, 64) 56 | self.fc2 = torch.nn.Linear(64, 32) 57 | self.fc3 = torch.nn.Linear(32, dataset.num_classes) 58 | 59 | def reset_parameters(self): 60 | for (name, module) in self._modules.items(): 61 | module.reset_parameters() 62 | 63 | def forward(self, data): 64 | data.x = F.elu(self.conv1(data.x, data.edge_index)) 65 | data.x = F.elu(self.conv2(data.x, data.edge_index)) 66 | data.x = F.elu(self.conv3(data.x, data.edge_index)) 67 | x = data.x 68 | x_1 = scatter_mean(data.x, data.batch, dim=0) 69 | 70 | data.x = avg_pool(x, data.assignment_index_2) 71 | data.x = torch.cat([data.x, data.iso_type_2], dim=1) 72 | 73 | data.x = F.elu(self.conv4(data.x, data.edge_index_2)) 74 | data.x = F.elu(self.conv5(data.x, data.edge_index_2)) 75 | x_2 = scatter_mean(data.x, data.batch_2, dim=0) 76 | 77 | data.x = avg_pool(x, data.assignment_index_3) 78 | data.x = torch.cat([data.x, data.iso_type_3], dim=1) 79 | 80 | data.x = F.elu(self.conv6(data.x, data.edge_index_3)) 81 
| data.x = F.elu(self.conv7(data.x, data.edge_index_3)) 82 | x_3 = scatter_mean(data.x, data.batch_3, dim=0) 83 | 84 | x = torch.cat([x_1, x_2, x_3], dim=1) 85 | 86 | if args.no_train: 87 | x = x.detach() 88 | 89 | x = F.elu(self.fc1(x)) 90 | x = F.dropout(x, p=0.5, training=self.training) 91 | x = F.elu(self.fc2(x)) 92 | x = self.fc3(x) 93 | return F.log_softmax(x, dim=1) 94 | 95 | 96 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 97 | model = Net().to(device) 98 | 99 | 100 | def train(epoch, loader, optimizer): 101 | model.train() 102 | loss_all = 0 103 | 104 | for data in loader: 105 | data = data.to(device) 106 | optimizer.zero_grad() 107 | loss = F.nll_loss(model(data), data.y) 108 | loss.backward() 109 | loss_all += data.num_graphs * loss.item() 110 | optimizer.step() 111 | return loss_all / len(loader.dataset) 112 | 113 | 114 | def val(loader): 115 | model.eval() 116 | loss_all = 0 117 | 118 | for data in loader: 119 | data = data.to(device) 120 | loss_all += F.nll_loss(model(data), data.y, reduction='sum').item() 121 | return loss_all / len(loader.dataset) 122 | 123 | 124 | def test(loader): 125 | model.eval() 126 | correct = 0 127 | 128 | for data in loader: 129 | data = data.to(device) 130 | pred = model(data).max(1)[1] 131 | correct += pred.eq(data.y).sum().item() 132 | return correct / len(loader.dataset) 133 | 134 | 135 | acc = [] 136 | for i in range(10): 137 | model.reset_parameters() 138 | optimizer = torch.optim.Adam(model.parameters(), lr=0.001) 139 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 140 | optimizer, mode='min', factor=0.7, patience=5, min_lr=0.00001) 141 | 142 | test_mask = torch.zeros(len(dataset), dtype=torch.uint8) 143 | n = len(dataset) // 10 144 | test_mask[i * n:(i + 1) * n] = 1 145 | test_dataset = dataset[test_mask] 146 | train_dataset = dataset[1 - test_mask] 147 | 148 | n = len(train_dataset) // 10 149 | val_mask = torch.zeros(len(train_dataset), dtype=torch.uint8) 150 | val_mask[i * n:(i + 1) * n] = 1 151 | val_dataset = train_dataset[val_mask] 152 | train_dataset = train_dataset[1 - val_mask] 153 | 154 | val_loader = DataLoader(val_dataset, batch_size=BATCH) 155 | test_loader = DataLoader(test_dataset, batch_size=BATCH) 156 | train_loader = DataLoader(train_dataset, batch_size=BATCH, shuffle=True) 157 | 158 | print('---------------- Split {} ----------------'.format(i)) 159 | 160 | best_val_loss, test_acc = 100, 0 161 | for epoch in range(1, 101): 162 | lr = scheduler.optimizer.param_groups[0]['lr'] 163 | train_loss = train(epoch, train_loader, optimizer) 164 | val_loss = val(val_loader) 165 | scheduler.step(val_loss) 166 | if best_val_loss >= val_loss: 167 | test_acc = test(test_loader) 168 | best_val_loss = val_loss 169 | print('Epoch: {:03d}, LR: {:7f}, Train Loss: {:.7f}, ' 170 | 'Val Loss: {:.7f}, Test Acc: {:.7f}'.format( 171 | epoch, lr, train_loss, val_loss, test_acc)) 172 | acc.append(test_acc) 173 | acc = torch.tensor(acc) 174 | print('---------------- Final Result ----------------') 175 | print('Mean: {:7f}, Std: {:7f}'.format(acc.mean(), acc.std())) 176 | -------------------------------------------------------------------------------- /software/k-gnn-master/examples/1-2-3-qm9.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import argparse 4 | import torch 5 | import torch.nn.functional as F 6 | from torch.nn import Sequential, Linear, ReLU 7 | from torch_scatter import scatter_mean 8 | from torch_geometric.datasets import QM9 9 | 
import torch_geometric.transforms as T 10 | from torch_geometric.nn import NNConv 11 | from k_gnn import GraphConv, DataLoader, avg_pool 12 | from k_gnn import TwoMalkin, ConnectedThreeMalkin 13 | 14 | 15 | class MyFilter(object): 16 | def __call__(self, data): 17 | return data.num_nodes > 6 # Keep only graphs with more than 6 nodes. 18 | 19 | 20 | class MyPreTransform(object): 21 | def __call__(self, data): 22 | x = data.x 23 | data.x = data.x[:, :5] 24 | data = TwoMalkin()(data) 25 | data = ConnectedThreeMalkin()(data) 26 | data.x = x 27 | return data 28 | 29 | 30 | class MyTransform(object): 31 | def __call__(self, data): 32 | data.y = data.y[:, int(args.target)] # Specify target: 0 = mu 33 | return data 34 | 35 | 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument('--target', default=0) 38 | args = parser.parse_args() 39 | target = int(args.target) 40 | 41 | print('---- Target: {} ----'.format(target)) 42 | 43 | path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', '1-2-3-QM9') 44 | dataset = QM9( 45 | path, 46 | transform=T.Compose([MyTransform(), T.Distance()]), 47 | pre_transform=MyPreTransform(), 48 | pre_filter=MyFilter()) 49 | 50 | dataset.data.iso_type_2 = torch.unique(dataset.data.iso_type_2, True, True)[1] 51 | num_i_2 = dataset.data.iso_type_2.max().item() + 1 52 | dataset.data.iso_type_2 = F.one_hot( 53 | dataset.data.iso_type_2, num_classes=num_i_2).to(torch.float) 54 | 55 | dataset.data.iso_type_3 = torch.unique(dataset.data.iso_type_3, True, True)[1] 56 | num_i_3 = dataset.data.iso_type_3.max().item() + 1 57 | dataset.data.iso_type_3 = F.one_hot( 58 | dataset.data.iso_type_3, num_classes=num_i_3).to(torch.float) 59 | 60 | dataset = dataset.shuffle() 61 | 62 | # Normalize targets to mean = 0 and std = 1. 63 | tenpercent = int(len(dataset) * 0.1) 64 | mean = dataset.data.y[tenpercent:].mean(dim=0) 65 | std = dataset.data.y[tenpercent:].std(dim=0) 66 | dataset.data.y = (dataset.data.y - mean) / std 67 | 68 | test_dataset = dataset[:tenpercent] 69 | val_dataset = dataset[tenpercent:2 * tenpercent] 70 | train_dataset = dataset[2 * tenpercent:] 71 | test_loader = DataLoader(test_dataset, batch_size=64) 72 | val_loader = DataLoader(val_dataset, batch_size=64) 73 | train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True) 74 | 75 | 76 | class Net(torch.nn.Module): 77 | def __init__(self): 78 | super(Net, self).__init__() 79 | M_in, M_out = dataset.num_features, 32 80 | nn1 = Sequential(Linear(5, 128), ReLU(), Linear(128, M_in * M_out)) 81 | self.conv1 = NNConv(M_in, M_out, nn1) 82 | 83 | M_in, M_out = M_out, 64 84 | nn2 = Sequential(Linear(5, 128), ReLU(), Linear(128, M_in * M_out)) 85 | self.conv2 = NNConv(M_in, M_out, nn2) 86 | 87 | M_in, M_out = M_out, 64 88 | nn3 = Sequential(Linear(5, 128), ReLU(), Linear(128, M_in * M_out)) 89 | self.conv3 = NNConv(M_in, M_out, nn3) 90 | 91 | self.conv4 = GraphConv(64 + num_i_2, 64) 92 | self.conv5 = GraphConv(64, 64) 93 | 94 | self.conv6 = GraphConv(64 + num_i_3, 64) 95 | self.conv7 = GraphConv(64, 64) 96 | 97 | self.fc1 = torch.nn.Linear(3 * 64, 64) 98 | self.fc2 = torch.nn.Linear(64, 32) 99 | self.fc3 = torch.nn.Linear(32, 1) 100 | 101 | def forward(self, data): 102 | data.x = F.elu(self.conv1(data.x, data.edge_index, data.edge_attr)) 103 | data.x = F.elu(self.conv2(data.x, data.edge_index, data.edge_attr)) 104 | data.x = F.elu(self.conv3(data.x, data.edge_index, data.edge_attr)) 105 | x = data.x 106 | x_1 = scatter_mean(data.x, data.batch, dim=0) 107 | 108 | data.x = avg_pool(x,
data.assignment_index_2) 109 | data.x = torch.cat([data.x, data.iso_type_2], dim=1) 110 | 111 | data.x = F.elu(self.conv4(data.x, data.edge_index_2)) 112 | data.x = F.elu(self.conv5(data.x, data.edge_index_2)) 113 | x_2 = scatter_mean(data.x, data.batch_2, dim=0) 114 | 115 | data.x = avg_pool(x, data.assignment_index_3) 116 | data.x = torch.cat([data.x, data.iso_type_3], dim=1) 117 | 118 | data.x = F.elu(self.conv6(data.x, data.edge_index_3)) 119 | data.x = F.elu(self.conv7(data.x, data.edge_index_3)) 120 | x_3 = scatter_mean(data.x, data.batch_3, dim=0) 121 | 122 | x = torch.cat([x_1, x_2, x_3], dim=1) 123 | 124 | x = F.elu(self.fc1(x)) 125 | x = F.elu(self.fc2(x)) 126 | x = self.fc3(x) 127 | return x.view(-1) 128 | 129 | 130 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 131 | model = Net().to(device) 132 | optimizer = torch.optim.Adam(model.parameters(), lr=0.001) 133 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 134 | optimizer, factor=0.7, patience=5, min_lr=0.00001) 135 | 136 | 137 | def train(epoch): 138 | model.train() 139 | loss_all = 0 140 | 141 | for data in train_loader: 142 | data = data.to(device) 143 | optimizer.zero_grad() 144 | loss = F.mse_loss(model(data), data.y) 145 | loss.backward() 146 | loss_all += loss.item() * data.num_graphs 147 | optimizer.step() 148 | return loss_all / len(train_loader.dataset) 149 | 150 | 151 | def test(loader): 152 | model.eval() 153 | error = 0 154 | 155 | for data in loader: 156 | data = data.to(device) 157 | error += ((model(data) * std[target].to(device)) - 158 | (data.y * std[target].to(device))).abs().sum().item() # MAE 159 | return error / len(loader.dataset) 160 | 161 | 162 | best_val_error = None 163 | for epoch in range(1, 201): 164 | lr = scheduler.optimizer.param_groups[0]['lr'] 165 | loss = train(epoch) 166 | val_error = test(val_loader) 167 | scheduler.step(val_error) 168 | 169 | if best_val_error is None: 170 | best_val_error = val_error 171 | if val_error <= best_val_error: 172 | test_error = test(test_loader) 173 | best_val_error = val_error 174 | print( 175 | 'Epoch: {:03d}, LR: {:7f}, Loss: {:.7f}, Validation MAE: {:.7f}, ' 176 | 'Test MAE: {:.7f}, ' 177 | 'Test MAE norm: {:.7f}'.format(epoch, lr, loss, val_error, 178 | test_error, 179 | test_error / std[target].to(device))) 180 | else: 181 | print('Epoch: {:03d}'.format(epoch)) 182 | -------------------------------------------------------------------------------- /software/k-gnn-master/examples/1-2-3-qm9_all_targets.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import argparse 4 | import torch 5 | import torch.nn.functional as F 6 | from torch.nn import Sequential, Linear, ReLU 7 | from torch_scatter import scatter_mean 8 | from torch_geometric.datasets import QM9 9 | import torch_geometric.transforms as T 10 | from torch_geometric.nn import NNConv 11 | from k_gnn import GraphConv, DataLoader, avg_pool 12 | from k_gnn import TwoMalkin, ConnectedThreeMalkin 13 | import numpy as np 14 | 15 | class MyFilter(object): 16 | def __call__(self, data): 17 | return data.num_nodes > 6 # Keep only graphs with more than 6 nodes.
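# [Editor's note -- illustrative sketch, not part of the original k-gnn file.]
# MyPreTransform below hands only data.x[:, :5] to TwoMalkin() and
# ConnectedThreeMalkin(), then restores the full feature matrix: the set
# isomorphism types are computed from discrete atom types alone (in this
# version of torch_geometric, the first five QM9 node-feature columns are
# assumed to be the atom-type one-hot), while the GNN still sees all
# features. A self-contained toy version of the swap; `_demo_*` names are
# hypothetical:
import torch
_demo_x = torch.randn(4, 13)        # e.g. 4 atoms with 13 QM9 node features
_demo_atom_types = _demo_x[:, :5]   # the slice the set-type transforms see
assert _demo_atom_types.shape == (4, 5)
# ... TwoMalkin()/ConnectedThreeMalkin() would run on the reduced x here ...
_demo_x_full = _demo_x              # the full features are put back afterwards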
18 | 19 | 20 | class MyPreTransform(object): 21 | def __call__(self, data): 22 | x = data.x 23 | data.x = data.x[:, :5] 24 | data = TwoMalkin()(data) 25 | data = ConnectedThreeMalkin()(data) 26 | data.x = x 27 | return data 28 | 29 | 30 | path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', '1-2-3-QM9') 31 | dataset = QM9( 32 | path, 33 | transform=T.Compose([T.Distance(norm=False)]), 34 | pre_transform=MyPreTransform(), 35 | pre_filter=MyFilter()) 36 | dataset.data.y = dataset.data.y[:,0:12] 37 | 38 | dataset.data.iso_type_2 = torch.unique(dataset.data.iso_type_2, True, True)[1] 39 | num_i_2 = dataset.data.iso_type_2.max().item() + 1 40 | dataset.data.iso_type_2 = F.one_hot( 41 | dataset.data.iso_type_2, num_classes=num_i_2).to(torch.float) 42 | 43 | dataset.data.iso_type_3 = torch.unique(dataset.data.iso_type_3, True, True)[1] 44 | num_i_3 = dataset.data.iso_type_3.max().item() + 1 45 | dataset.data.iso_type_3 = F.one_hot( 46 | dataset.data.iso_type_3, num_classes=num_i_3).to(torch.float) 47 | 48 | 49 | 50 | class Net(torch.nn.Module): 51 | def __init__(self): 52 | super(Net, self).__init__() 53 | M_in, M_out = dataset.num_features, 32 54 | nn1 = Sequential(Linear(6, 128), ReLU(), Linear(128, M_in * M_out)) 55 | self.conv1 = NNConv(M_in, M_out, nn1) 56 | 57 | M_in, M_out = M_out, 64 58 | nn2 = Sequential(Linear(6, 128), ReLU(), Linear(128, M_in * M_out)) 59 | self.conv2 = NNConv(M_in, M_out, nn2) 60 | 61 | M_in, M_out = M_out, 64 62 | nn3 = Sequential(Linear(6, 128), ReLU(), Linear(128, M_in * M_out)) 63 | self.conv3 = NNConv(M_in, M_out, nn3) 64 | 65 | self.conv4 = GraphConv(64 + num_i_2, 64) 66 | self.conv5 = GraphConv(64, 64) 67 | 68 | self.conv6 = GraphConv(64 + num_i_3, 64) 69 | self.conv7 = GraphConv(64, 64) 70 | 71 | 72 | self.fc1 = torch.nn.Linear(3 * 64, 64) 73 | self.fc2 = torch.nn.Linear(64, 32) 74 | self.fc3 = torch.nn.Linear(32, 12) 75 | 76 | def forward(self, data): 77 | data.x = F.elu(self.conv1(data.x, data.edge_index, data.edge_attr)) 78 | data.x = F.elu(self.conv2(data.x, data.edge_index, data.edge_attr)) 79 | data.x = F.elu(self.conv3(data.x, data.edge_index, data.edge_attr)) 80 | x = data.x 81 | x_1 = scatter_mean(data.x, data.batch, dim=0) 82 | 83 | data.x = avg_pool(x, data.assignment_index_2) 84 | data.x = torch.cat([data.x, data.iso_type_2], dim=1) 85 | 86 | data.x = F.elu(self.conv4(data.x, data.edge_index_2)) 87 | data.x = F.elu(self.conv5(data.x, data.edge_index_2)) 88 | x_2 = scatter_mean(data.x, data.batch_2, dim=0) 89 | 90 | data.x = avg_pool(x, data.assignment_index_3) 91 | data.x = torch.cat([data.x, data.iso_type_3], dim=1) 92 | 93 | data.x = F.elu(self.conv6(data.x, data.edge_index_3)) 94 | data.x = F.elu(self.conv7(data.x, data.edge_index_3)) 95 | x_3 = scatter_mean(data.x, data.batch_3, dim=0) 96 | 97 | x = torch.cat([x_1, x_2, x_3], dim=1) 98 | 99 | x = F.elu(self.fc1(x)) 100 | x = F.elu(self.fc2(x)) 101 | x = self.fc3(x) 102 | return x 103 | 104 | 105 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 106 | 107 | 108 | 109 | results = [] 110 | results_log = [] 111 | for _ in range(5): 112 | 113 | dataset = dataset.shuffle() 114 | 115 | tenpercent = int(len(dataset) * 0.1) 116 | print("###") 117 | mean = dataset.data.y.mean(dim=0, keepdim=True) 118 | # mean_abs = dataset.data.y.abs().mean(dim=0, keepdim=True).to(device) # .view(-1) 119 | std = dataset.data.y.std(dim=0, keepdim=True) 120 | dataset.data.y = (dataset.data.y - mean) / std 121 | mean, std = mean.to(device), std.to(device) 122 | 123 | print("###") 124 |
test_dataset = dataset[:tenpercent].shuffle() 125 | val_dataset = dataset[tenpercent:2 * tenpercent].shuffle() 126 | train_dataset = dataset[2 * tenpercent:].shuffle() 127 | 128 | print(len(train_dataset), len(val_dataset), len(test_dataset)) 129 | 130 | batch_size = 64 131 | train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) 132 | val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True) 133 | test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True) 134 | 135 | # TODO 136 | model = Net().to(device) 137 | optimizer = torch.optim.Adam(model.parameters(), lr=0.001) 138 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', 139 | factor=0.5, patience=5, 140 | min_lr=0.0000001) 141 | 142 | def train(): 143 | model.train() 144 | loss_all = 0 145 | 146 | lf = torch.nn.L1Loss() 147 | 148 | for data in train_loader: 149 | data = data.to(device) 150 | optimizer.zero_grad() 151 | loss = lf(model(data), data.y) 152 | 153 | loss.backward() 154 | loss_all += loss.item() * data.num_graphs 155 | optimizer.step() 156 | return (loss_all / len(train_loader.dataset)) 157 | 158 | 159 | @torch.no_grad() 160 | def test(loader): 161 | model.eval() 162 | error = torch.zeros([1, 12]).to(device) 163 | 164 | for data in loader: 165 | data = data.to(device) 166 | error += ((data.y * std - model(data) * std).abs() / std).sum(dim=0) 167 | 168 | error = error / len(loader.dataset) 169 | error_log = torch.log(error) 170 | 171 | return error.mean().item(), error_log.mean().item() 172 | 173 | test_error = None 174 | log_test_error = None 175 | best_val_error = None 176 | for epoch in range(1, 1001): 177 | lr = scheduler.optimizer.param_groups[0]['lr'] 178 | loss = train() 179 | val_error, _ = test(val_loader) 180 | scheduler.step(val_error) 181 | 182 | if best_val_error is None or val_error <= best_val_error: 183 | test_error, log_test_error = test(test_loader) 184 | best_val_error = val_error 185 | 186 | print('Epoch: {:03d}, LR: {:.7f}, Loss: {:.7f}, Validation MAE: {:.7f}, ' 187 | 'Test MAE: {:.7f}, Test log-MAE: {:.7f}'.format(epoch, lr, loss, val_error, test_error, log_test_error)) 188 | 189 | if lr < 0.000001: 190 | print("Converged.") 191 | break 192 | 193 | results.append(test_error) 194 | results_log.append(log_test_error) 195 | 196 | print("########################") 197 | print(results) 198 | results = np.array(results) 199 | print(results.mean(), results.std()) 200 | 201 | print(results_log) 202 | results_log = np.array(results_log) 203 | print(results_log.mean(), results_log.std()) 204 | -------------------------------------------------------------------------------- /software/k-gnn-master/examples/1-2-qm9.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import argparse 4 | import torch 5 | import torch.nn.functional as F 6 | from torch.nn import Sequential, Linear, ReLU 7 | from torch_scatter import scatter_mean 8 | from torch_geometric.datasets import QM9 9 | import torch_geometric.transforms as T 10 | from torch_geometric.nn import NNConv 11 | from k_gnn import GraphConv, DataLoader, avg_pool 12 | from k_gnn import TwoMalkin 13 | 14 | 15 | class MyFilter(object): 16 | def __call__(self, data): 17 | return data.num_nodes > 6 # Keep only graphs with more than 6 nodes.
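# [Editor's note -- illustrative sketch, not part of the original k-gnn file.]
# Below, `torch.unique(t, True, True)[1]` (sorted=True, return_inverse=True)
# re-labels the raw 2-set isomorphism type ids as contiguous integers 0..K-1
# so that they can be one-hot encoded. A minimal, self-contained
# demonstration of the idiom; `_demo_*` names are hypothetical:
import torch
import torch.nn.functional as F
_demo_raw_types = torch.tensor([9, 3, 9, 7])             # arbitrary type ids
_demo_compact = torch.unique(_demo_raw_types, True, True)[1]
# _demo_compact == tensor([2, 0, 2, 1]); the ids now run over 0..2
_demo_onehot = F.one_hot(_demo_compact, num_classes=3).to(torch.float)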
18 | 19 | 20 | class MyPreTransform(object): 21 | def __call__(self, data): 22 | x = data.x 23 | data.x = data.x[:, :5] 24 | data = TwoMalkin()(data) 25 | data.x = x 26 | return data 27 | 28 | 29 | class MyTransform(object): 30 | def __call__(self, data): 31 | data.y = data.y[:, int(args.target)] # Specify target: 0 = mu 32 | return data 33 | 34 | 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument('--target', default=0) 37 | args = parser.parse_args() 38 | target = int(args.target) 39 | 40 | print('---- Target: {} ----'.format(target)) 41 | 42 | path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', '1-2-QM9') 43 | dataset = QM9( 44 | path, 45 | transform=T.Compose([MyTransform(), T.Distance()]), 46 | pre_transform=MyPreTransform(), 47 | pre_filter=MyFilter()) 48 | 49 | dataset.data.iso_type_2 = torch.unique(dataset.data.iso_type_2, True, True)[1] 50 | num_i_2 = dataset.data.iso_type_2.max().item() + 1 51 | dataset.data.iso_type_2 = F.one_hot( 52 | dataset.data.iso_type_2, num_classes=num_i_2).to(torch.float) 53 | 54 | dataset = dataset.shuffle() 55 | 56 | # Normalize targets to mean = 0 and std = 1. 57 | tenpercent = int(len(dataset) * 0.1) 58 | mean = dataset.data.y[tenpercent:].mean(dim=0) 59 | std = dataset.data.y[tenpercent:].std(dim=0) 60 | dataset.data.y = (dataset.data.y - mean) / std 61 | 62 | test_dataset = dataset[:tenpercent] 63 | val_dataset = dataset[tenpercent:2 * tenpercent] 64 | train_dataset = dataset[2 * tenpercent:] 65 | test_loader = DataLoader(test_dataset, batch_size=64) 66 | val_loader = DataLoader(val_dataset, batch_size=64) 67 | train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True) 68 | 69 | 70 | class Net(torch.nn.Module): 71 | def __init__(self): 72 | super(Net, self).__init__() 73 | M_in, M_out = dataset.num_features, 32 74 | nn1 = Sequential(Linear(5, 128), ReLU(), Linear(128, M_in * M_out)) 75 | self.conv1 = NNConv(M_in, M_out, nn1) 76 | 77 | M_in, M_out = M_out, 64 78 | nn2 = Sequential(Linear(5, 128), ReLU(), Linear(128, M_in * M_out)) 79 | self.conv2 = NNConv(M_in, M_out, nn2) 80 | 81 | M_in, M_out = M_out, 64 82 | nn3 = Sequential(Linear(5, 128), ReLU(), Linear(128, M_in * M_out)) 83 | self.conv3 = NNConv(M_in, M_out, nn3) 84 | 85 | self.conv4 = GraphConv(64 + num_i_2, 64) 86 | self.conv5 = GraphConv(64, 64) 87 | 88 | self.fc1 = torch.nn.Linear(2 * 64, 64) 89 | self.fc2 = torch.nn.Linear(64, 32) 90 | self.fc3 = torch.nn.Linear(32, 1) 91 | 92 | def forward(self, data): 93 | data.x = F.elu(self.conv1(data.x, data.edge_index, data.edge_attr)) 94 | data.x = F.elu(self.conv2(data.x, data.edge_index, data.edge_attr)) 95 | data.x = F.elu(self.conv3(data.x, data.edge_index, data.edge_attr)) 96 | x_1 = scatter_mean(data.x, data.batch, dim=0) 97 | 98 | data.x = avg_pool(data.x, data.assignment_index_2) 99 | data.x = torch.cat([data.x, data.iso_type_2], dim=1) 100 | 101 | data.x = F.elu(self.conv4(data.x, data.edge_index_2)) 102 | data.x = F.elu(self.conv5(data.x, data.edge_index_2)) 103 | x_2 = scatter_mean(data.x, data.batch_2, dim=0) 104 | 105 | x = torch.cat([x_1, x_2], dim=1) 106 | 107 | x = F.elu(self.fc1(x)) 108 | x = F.elu(self.fc2(x)) 109 | x = self.fc3(x) 110 | return x.view(-1) 111 | 112 | 113 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 114 | model = Net().to(device) 115 | optimizer = torch.optim.Adam(model.parameters(), lr=0.001) 116 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 117 | optimizer, factor=0.7, patience=5, min_lr=0.00001) 118 | 119 | 120 | def train(epoch): 121 
| model.train() 122 | loss_all = 0 123 | 124 | for data in train_loader: 125 | data = data.to(device) 126 | optimizer.zero_grad() 127 | loss = F.mse_loss(model(data), data.y) 128 | loss.backward() 129 | loss_all += loss.item() * data.num_graphs 130 | optimizer.step() 131 | return loss_all / len(train_loader.dataset) 132 | 133 | 134 | def test(loader): 135 | model.eval() 136 | error = 0 137 | 138 | for data in loader: 139 | data = data.to(device) 140 | error += ((model(data) * std[target].to(device)) - 141 | (data.y * std[target].to(device))).abs().sum().item() # MAE 142 | return error / len(loader.dataset) 143 | 144 | 145 | best_val_error = None 146 | for epoch in range(1, 201): 147 | lr = scheduler.optimizer.param_groups[0]['lr'] 148 | loss = train(epoch) 149 | val_error = test(val_loader) 150 | scheduler.step(val_error) 151 | 152 | if best_val_error is None: 153 | best_val_error = val_error 154 | if val_error <= best_val_error: 155 | test_error = test(test_loader) 156 | best_val_error = val_error 157 | print( 158 | 'Epoch: {:03d}, LR: {:7f}, Loss: {:.7f}, Validation MAE: {:.7f}, ' 159 | 'Test MAE: {:.7f}, ' 160 | 'Test MAE norm: {:.7f}'.format(epoch, lr, loss, val_error, 161 | test_error, 162 | test_error / std[target].to(device))) 163 | else: 164 | print('Epoch: {:03d}'.format(epoch)) 165 | -------------------------------------------------------------------------------- /software/k-gnn-master/examples/1-2-qm9_all_targets.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import argparse 4 | import torch 5 | import torch.nn.functional as F 6 | from torch.nn import Sequential, Linear, ReLU 7 | from torch_scatter import scatter_mean 8 | from torch_geometric.datasets import QM9 9 | import torch_geometric.transforms as T 10 | from torch_geometric.nn import NNConv 11 | from k_gnn import GraphConv, DataLoader, avg_pool 12 | from k_gnn import TwoMalkin 13 | import numpy as np 14 | 15 | class MyFilter(object): 16 | def __call__(self, data): 17 | return data.num_nodes > 6 # Keep only graphs with more than 6 nodes.
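# [Editor's note -- illustrative sketch, not part of the original k-gnn file.]
# This script standardizes all 12 targets (y := (y - mean) / std), and its
# test() computes ((y * std - pred * std).abs() / std), which algebraically
# reduces to (y - pred).abs(): the reported MAE is therefore in standardized
# units rather than in the original physical units. A self-contained check;
# `_demo_*` names are hypothetical:
import torch
_demo_y = torch.tensor([[0.5, -1.0]])
_demo_pred = torch.tensor([[0.0, 1.0]])
_demo_std = torch.tensor([[2.0, 4.0]])
_demo_a = ((_demo_y * _demo_std - _demo_pred * _demo_std).abs() / _demo_std)
_demo_b = (_demo_y - _demo_pred).abs()
assert torch.allclose(_demo_a, _demo_b)  # both equal tensor([[0.5, 2.0]])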
18 | 19 | 20 | class MyPreTransform(object): 21 | def __call__(self, data): 22 | x = data.x 23 | data.x = data.x[:, :5] 24 | data = TwoMalkin()(data) 25 | data.x = x 26 | return data 27 | 28 | path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', '1-2-QM9') 29 | dataset = QM9( 30 | path, 31 | transform=T.Compose([T.Distance()]), 32 | pre_transform=MyPreTransform(), 33 | pre_filter=MyFilter()) 34 | 35 | dataset.data.y = dataset.data.y[:,0:12] 36 | 37 | dataset.data.iso_type_2 = torch.unique(dataset.data.iso_type_2, True, True)[1] 38 | num_i_2 = dataset.data.iso_type_2.max().item() + 1 39 | dataset.data.iso_type_2 = F.one_hot( 40 | dataset.data.iso_type_2, num_classes=num_i_2).to(torch.float) 41 | 42 | 43 | 44 | 45 | class Net(torch.nn.Module): 46 | def __init__(self): 47 | super(Net, self).__init__() 48 | M_in, M_out = dataset.num_features, 32 49 | nn1 = Sequential(Linear(6, 128), ReLU(), Linear(128, M_in * M_out)) 50 | self.conv1 = NNConv(M_in, M_out, nn1) 51 | 52 | M_in, M_out = M_out, 64 53 | nn2 = Sequential(Linear(6, 128), ReLU(), Linear(128, M_in * M_out)) 54 | self.conv2 = NNConv(M_in, M_out, nn2) 55 | 56 | M_in, M_out = M_out, 64 57 | nn3 = Sequential(Linear(6, 128), ReLU(), Linear(128, M_in * M_out)) 58 | self.conv3 = NNConv(M_in, M_out, nn3) 59 | 60 | self.conv4 = GraphConv(64 + num_i_2, 64) 61 | self.conv5 = GraphConv(64, 64) 62 | 63 | self.fc1 = torch.nn.Linear(2 * 64, 64) 64 | self.fc2 = torch.nn.Linear(64, 32) 65 | self.fc3 = torch.nn.Linear(32, 12) 66 | 67 | def forward(self, data): 68 | data.x = F.elu(self.conv1(data.x, data.edge_index, data.edge_attr)) 69 | data.x = F.elu(self.conv2(data.x, data.edge_index, data.edge_attr)) 70 | data.x = F.elu(self.conv3(data.x, data.edge_index, data.edge_attr)) 71 | x_1 = scatter_mean(data.x, data.batch, dim=0) 72 | 73 | data.x = avg_pool(data.x, data.assignment_index_2) 74 | data.x = torch.cat([data.x, data.iso_type_2], dim=1) 75 | 76 | data.x = F.elu(self.conv4(data.x, data.edge_index_2)) 77 | data.x = F.elu(self.conv5(data.x, data.edge_index_2)) 78 | x_2 = scatter_mean(data.x, data.batch_2, dim=0) 79 | 80 | x = torch.cat([x_1, x_2], dim=1) 81 | 82 | x = F.elu(self.fc1(x)) 83 | x = F.elu(self.fc2(x)) 84 | x = self.fc3(x) 85 | return x 86 | 87 | 88 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 89 | 90 | results = [] 91 | results_log = [] 92 | for _ in range(5): 93 | 94 | dataset = dataset.shuffle() 95 | 96 | tenpercent = int(len(dataset) * 0.1) 97 | print("###") 98 | mean = dataset.data.y.mean(dim=0, keepdim=True) 99 | # mean_abs = dataset.data.y.abs().mean(dim=0, keepdim=True).to(device) # .view(-1) 100 | std = dataset.data.y.std(dim=0, keepdim=True) 101 | dataset.data.y = (dataset.data.y - mean) / std 102 | mean, std = mean.to(device), std.to(device) 103 | 104 | print("###") 105 | test_dataset = dataset[:tenpercent].shuffle() 106 | val_dataset = dataset[tenpercent:2 * tenpercent].shuffle() 107 | train_dataset = dataset[2 * tenpercent:].shuffle() 108 | 109 | print(len(train_dataset), len(val_dataset), len(test_dataset)) 110 | 111 | batch_size = 64 112 | train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) 113 | val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True) 114 | test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True) 115 | 116 | # TODO 117 | model = Net().to(device) 118 | optimizer = torch.optim.Adam(model.parameters(), lr=0.001) 119 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', 120 | 
factor=0.5, patience=5, 121 | min_lr=0.0000001) 122 | 123 | def train(): 124 | model.train() 125 | loss_all = 0 126 | 127 | lf = torch.nn.L1Loss() 128 | 129 | for data in train_loader: 130 | data = data.to(device) 131 | optimizer.zero_grad() 132 | loss = lf(model(data), data.y) 133 | 134 | loss.backward() 135 | loss_all += loss.item() * data.num_graphs 136 | optimizer.step() 137 | return (loss_all / len(train_loader.dataset)) 138 | 139 | 140 | @torch.no_grad() 141 | def test(loader): 142 | model.eval() 143 | error = torch.zeros([1, 12]).to(device) 144 | 145 | for data in loader: 146 | data = data.to(device) 147 | error += ((data.y * std - model(data) * std).abs() / std).sum(dim=0) 148 | 149 | error = error / len(loader.dataset) 150 | error_log = torch.log(error) 151 | 152 | return error.mean().item(), error_log.mean().item() 153 | 154 | test_error = None 155 | log_test_error = None 156 | best_val_error = None 157 | for epoch in range(1, 1001): 158 | lr = scheduler.optimizer.param_groups[0]['lr'] 159 | loss = train() 160 | val_error, _ = test(val_loader) 161 | scheduler.step(val_error) 162 | 163 | if best_val_error is None or val_error <= best_val_error: 164 | test_error, log_test_error = test(test_loader) 165 | best_val_error = val_error 166 | 167 | print('Epoch: {:03d}, LR: {:.7f}, Loss: {:.7f}, Validation MAE: {:.7f}, ' 168 | 'Test MAE: {:.7f}, Test log-MAE: {:.7f}'.format(epoch, lr, loss, val_error, test_error, log_test_error)) 169 | 170 | if lr < 0.000001: 171 | print("Converged.") 172 | break 173 | 174 | results.append(test_error) 175 | results_log.append(log_test_error) 176 | 177 | print("########################") 178 | print(results) 179 | results = np.array(results) 180 | print(results.mean(), results.std()) 181 | 182 | print(results_log) 183 | results_log = np.array(results_log) 184 | print(results_log.mean(), results_log.std()) 185 | -------------------------------------------------------------------------------- /software/k-gnn-master/examples/1-3-qm9.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import argparse 4 | import torch 5 | import torch.nn.functional as F 6 | from torch.nn import Sequential, Linear, ReLU 7 | from torch_scatter import scatter_mean 8 | from torch_geometric.datasets import QM9 9 | import torch_geometric.transforms as T 10 | from torch_geometric.nn import NNConv 11 | from k_gnn import GraphConv, DataLoader, avg_pool 12 | from k_gnn import ConnectedThreeMalkin 13 | 14 | 15 | class MyFilter(object): 16 | def __call__(self, data): 17 | return data.num_nodes > 6 # Keep only graphs with more than 6 nodes.
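# [Editor's note -- illustrative sketch, not part of the original k-gnn file.]
# The graph-level readouts in Net.forward below (x_1, x_3) use
# torch_scatter.scatter_mean: entry i of the batch vector names the graph
# that row i of the feature matrix belongs to, and features are averaged per
# graph. A minimal, self-contained example; `_demo_*` names are hypothetical:
import torch
from torch_scatter import scatter_mean
_demo_feat = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
_demo_batch = torch.tensor([0, 0, 1, 1])  # rows 0-1: graph 0; rows 2-3: graph 1
_demo_readout = scatter_mean(_demo_feat, _demo_batch, dim=0)
# _demo_readout == tensor([[2., 3.], [6., 7.]]) -- one row per graph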
18 | 19 | 20 | class MyPreTransform(object): 21 | def __call__(self, data): 22 | x = data.x 23 | data.x = data.x[:, :5] 24 | data = ConnectedThreeMalkin()(data) 25 | data.x = x 26 | return data 27 | 28 | 29 | class MyTransform(object): 30 | def __call__(self, data): 31 | data.y = data.y[:, int(args.target)] # Specify target: 0 = mu 32 | return data 33 | 34 | 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument('--target', default=0) 37 | args = parser.parse_args() 38 | target = int(args.target) 39 | 40 | print('---- Target: {} ----'.format(target)) 41 | 42 | path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', '1-3-QM9') 43 | dataset = QM9( 44 | path, 45 | transform=T.Compose([MyTransform(), T.Distance()]), 46 | pre_transform=MyPreTransform(), 47 | pre_filter=MyFilter()) 48 | 49 | dataset.data.iso_type_3 = torch.unique(dataset.data.iso_type_3, True, True)[1] 50 | num_i_3 = dataset.data.iso_type_3.max().item() + 1 51 | dataset.data.iso_type_3 = F.one_hot( 52 | dataset.data.iso_type_3, num_classes=num_i_3).to(torch.float) 53 | 54 | dataset = dataset.shuffle() 55 | 56 | # Normalize targets to mean = 0 and std = 1. 57 | tenpercent = int(len(dataset) * 0.1) 58 | mean = dataset.data.y[tenpercent:].mean(dim=0) 59 | std = dataset.data.y[tenpercent:].std(dim=0) 60 | dataset.data.y = (dataset.data.y - mean) / std 61 | 62 | test_dataset = dataset[:tenpercent] 63 | val_dataset = dataset[tenpercent:2 * tenpercent] 64 | train_dataset = dataset[2 * tenpercent:] 65 | test_loader = DataLoader(test_dataset, batch_size=64) 66 | val_loader = DataLoader(val_dataset, batch_size=64) 67 | train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True) 68 | 69 | 70 | class Net(torch.nn.Module): 71 | def __init__(self): 72 | super(Net, self).__init__() 73 | M_in, M_out = dataset.num_features, 32 74 | nn1 = Sequential(Linear(5, 128), ReLU(), Linear(128, M_in * M_out)) 75 | self.conv1 = NNConv(M_in, M_out, nn1) 76 | 77 | M_in, M_out = M_out, 64 78 | nn2 = Sequential(Linear(5, 128), ReLU(), Linear(128, M_in * M_out)) 79 | self.conv2 = NNConv(M_in, M_out, nn2) 80 | 81 | M_in, M_out = M_out, 64 82 | nn3 = Sequential(Linear(5, 128), ReLU(), Linear(128, M_in * M_out)) 83 | self.conv3 = NNConv(M_in, M_out, nn3) 84 | 85 | self.conv6 = GraphConv(64 + num_i_3, 64) 86 | self.conv7 = GraphConv(64, 64) 87 | 88 | self.fc1 = torch.nn.Linear(2 * 64, 64) 89 | self.fc2 = torch.nn.Linear(64, 32) 90 | self.fc3 = torch.nn.Linear(32, 1) 91 | 92 | def forward(self, data): 93 | data.x = F.elu(self.conv1(data.x, data.edge_index, data.edge_attr)) 94 | data.x = F.elu(self.conv2(data.x, data.edge_index, data.edge_attr)) 95 | data.x = F.elu(self.conv3(data.x, data.edge_index, data.edge_attr)) 96 | x_1 = scatter_mean(data.x, data.batch, dim=0) 97 | 98 | data.x = avg_pool(data.x, data.assignment_index_3) 99 | data.x = torch.cat([data.x, data.iso_type_3], dim=1) 100 | 101 | data.x = F.elu(self.conv6(data.x, data.edge_index_3)) 102 | data.x = F.elu(self.conv7(data.x, data.edge_index_3)) 103 | x_3 = scatter_mean(data.x, data.batch_3, dim=0) 104 | 105 | x = torch.cat([x_1, x_3], dim=1) 106 | 107 | x = F.elu(self.fc1(x)) 108 | x = F.elu(self.fc2(x)) 109 | x = self.fc3(x) 110 | return x.view(-1) 111 | 112 | 113 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 114 | model = Net().to(device) 115 | optimizer = torch.optim.Adam(model.parameters(), lr=0.001) 116 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 117 | optimizer, factor=0.7, patience=5, min_lr=0.00001) 118 | 119 | 120 | def 
train(epoch): 121 | model.train() 122 | loss_all = 0 123 | 124 | for data in train_loader: 125 | data = data.to(device) 126 | optimizer.zero_grad() 127 | loss = F.mse_loss(model(data), data.y) 128 | loss.backward() 129 | loss_all += loss.item() * data.num_graphs 130 | optimizer.step() 131 | return loss_all / len(train_loader.dataset) 132 | 133 | 134 | def test(loader): 135 | model.eval() 136 | error = 0 137 | 138 | for data in loader: 139 | data = data.to(device) 140 | error += ((model(data) * std[target].to(device)) - 141 | (data.y * std[target].to(device))).abs().sum().item() # MAE 142 | return error / len(loader.dataset) 143 | 144 | 145 | best_val_error = None 146 | for epoch in range(1, 201): 147 | lr = scheduler.optimizer.param_groups[0]['lr'] 148 | loss = train(epoch) 149 | val_error = test(val_loader) 150 | scheduler.step(val_error) 151 | 152 | if best_val_error is None: 153 | best_val_error = val_error 154 | if val_error <= best_val_error: 155 | test_error = test(test_loader) 156 | best_val_error = val_error 157 | print( 158 | 'Epoch: {:03d}, LR: {:7f}, Loss: {:.7f}, Validation MAE: {:.7f}, ' 159 | 'Test MAE: {:.7f}, ' 160 | 'Test MAE norm: {:.7f}'.format(epoch, lr, loss, val_error, 161 | test_error, 162 | test_error / std[target].to(device))) 163 | else: 164 | print('Epoch: {:03d}'.format(epoch)) 165 | -------------------------------------------------------------------------------- /software/k-gnn-master/examples/1-3-qm9_all_targets.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import argparse 4 | import torch 5 | import torch.nn.functional as F 6 | from torch.nn import Sequential, Linear, ReLU 7 | from torch_scatter import scatter_mean 8 | from torch_geometric.datasets import QM9 9 | import torch_geometric.transforms as T 10 | from torch_geometric.nn import NNConv 11 | from k_gnn import GraphConv, DataLoader, avg_pool 12 | from k_gnn import ConnectedThreeMalkin 13 | import numpy as np 14 | 15 | class MyFilter(object): 16 | def __call__(self, data): 17 | return data.num_nodes > 6 # Keep only graphs with more than 6 nodes.
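# [Editor's note -- illustrative sketch, not part of the original k-gnn file.]
# Net.forward below lifts node features to connected 3-sets with
# avg_pool(data.x, data.assignment_index_3). Assuming -- from its usage in
# these examples, not verified against the k_gnn source -- that
# assignment_index_3 is a 2 x M index whose row 0 holds node indices and
# row 1 holds 3-set indices, the pooling amounts to a scatter-mean over set
# membership. A hedged, self-contained re-implementation of that reading;
# `_demo_*` names are hypothetical and the real avg_pool comes from k_gnn:
import torch
from torch_scatter import scatter_mean
_demo_x = torch.tensor([[1.0], [3.0], [5.0]])  # three node feature rows
_demo_assign = torch.tensor([[0, 1, 1, 2],     # row 0: node index
                             [0, 0, 1, 1]])    # row 1: 3-set index
_demo_sets = scatter_mean(_demo_x[_demo_assign[0]], _demo_assign[1], dim=0)
# _demo_sets == tensor([[2.], [4.]]): set 0 averages nodes {0, 1}, set 1 {1, 2}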
18 | 19 | 20 | class MyPreTransform(object): 21 | def __call__(self, data): 22 | x = data.x 23 | data.x = data.x[:, :5] 24 | data = ConnectedThreeMalkin()(data) 25 | data.x = x 26 | return data 27 | 28 | 29 | path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', '1-3-QM9') 30 | dataset = QM9( 31 | path, 32 | transform=T.Compose([T.Distance(norm=False)]), 33 | pre_transform=MyPreTransform(), 34 | pre_filter=MyFilter()) 35 | dataset.data.y = dataset.data.y[:,0:12] 36 | 37 | dataset.data.iso_type_3 = torch.unique(dataset.data.iso_type_3, True, True)[1] 38 | num_i_3 = dataset.data.iso_type_3.max().item() + 1 39 | dataset.data.iso_type_3 = F.one_hot( 40 | dataset.data.iso_type_3, num_classes=num_i_3).to(torch.float) 41 | 42 | 43 | 44 | 45 | 46 | class Net(torch.nn.Module): 47 | def __init__(self): 48 | super(Net, self).__init__() 49 | M_in, M_out = dataset.num_features, 32 50 | nn1 = Sequential(Linear(6, 128), ReLU(), Linear(128, M_in * M_out)) 51 | self.conv1 = NNConv(M_in, M_out, nn1) 52 | 53 | M_in, M_out = M_out, 64 54 | nn2 = Sequential(Linear(6, 128), ReLU(), Linear(128, M_in * M_out)) 55 | self.conv2 = NNConv(M_in, M_out, nn2) 56 | 57 | M_in, M_out = M_out, 64 58 | nn3 = Sequential(Linear(6, 128), ReLU(), Linear(128, M_in * M_out)) 59 | self.conv3 = NNConv(M_in, M_out, nn3) 60 | 61 | self.conv6 = GraphConv(64 + num_i_3, 64) 62 | self.conv7 = GraphConv(64, 64) 63 | 64 | self.fc1 = torch.nn.Linear(2 * 64, 64) 65 | self.fc2 = torch.nn.Linear(64, 32) 66 | self.fc3 = torch.nn.Linear(32, 12) 67 | 68 | def forward(self, data): 69 | data.x = F.elu(self.conv1(data.x, data.edge_index, data.edge_attr)) 70 | data.x = F.elu(self.conv2(data.x, data.edge_index, data.edge_attr)) 71 | data.x = F.elu(self.conv3(data.x, data.edge_index, data.edge_attr)) 72 | x_1 = scatter_mean(data.x, data.batch, dim=0) 73 | 74 | data.x = avg_pool(data.x, data.assignment_index_3) 75 | data.x = torch.cat([data.x, data.iso_type_3], dim=1) 76 | 77 | data.x = F.elu(self.conv6(data.x, data.edge_index_3)) 78 | data.x = F.elu(self.conv7(data.x, data.edge_index_3)) 79 | x_3 = scatter_mean(data.x, data.batch_3, dim=0) 80 | 81 | x = torch.cat([x_1, x_3], dim=1) 82 | 83 | x = F.elu(self.fc1(x)) 84 | x = F.elu(self.fc2(x)) 85 | x = self.fc3(x) 86 | 87 | return x 88 | 89 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 90 | # model = Net().to(device) 91 | # optimizer = torch.optim.Adam(model.parameters(), lr=0.001) 92 | # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 93 | # optimizer, factor=0.7, patience=5, min_lr=0.00001) 94 | 95 | 96 | 97 | 98 | 99 | results = [] 100 | results_log = [] 101 | for _ in range(5): 102 | 103 | dataset = dataset.shuffle() 104 | 105 | tenpercent = int(len(dataset) * 0.1) 106 | print("###") 107 | mean = dataset.data.y.mean(dim=0, keepdim=True) 108 | # mean_abs = dataset.data.y.abs().mean(dim=0, keepdim=True).to(device) # .view(-1) 109 | std = dataset.data.y.std(dim=0, keepdim=True) 110 | dataset.data.y = (dataset.data.y - mean) / std 111 | mean, std = mean.to(device), std.to(device) 112 | 113 | print("###") 114 | test_dataset = dataset[:tenpercent].shuffle() 115 | val_dataset = dataset[tenpercent:2 * tenpercent].shuffle() 116 | train_dataset = dataset[2 * tenpercent:].shuffle() 117 | 118 | print(len(train_dataset), len(val_dataset), len(test_dataset)) 119 | 120 | batch_size = 64 121 | train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) 122 | val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True) 123 |
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True) 124 | 125 | # TODO 126 | model = Net().to(device) 127 | optimizer = torch.optim.Adam(model.parameters(), lr=0.001) 128 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', 129 | factor=0.5, patience=5, 130 | min_lr=0.0000001) 131 | 132 | def train(): 133 | model.train() 134 | loss_all = 0 135 | 136 | lf = torch.nn.L1Loss() 137 | 138 | for data in train_loader: 139 | data = data.to(device) 140 | optimizer.zero_grad() 141 | loss = lf(model(data), data.y) 142 | 143 | loss.backward() 144 | loss_all += loss.item() * data.num_graphs 145 | optimizer.step() 146 | return (loss_all / len(train_loader.dataset)) 147 | 148 | 149 | @torch.no_grad() 150 | def test(loader): 151 | model.eval() 152 | error = torch.zeros([1, 12]).to(device) 153 | 154 | for data in loader: 155 | data = data.to(device) 156 | error += ((data.y * std - model(data) * std).abs() / std).sum(dim=0) 157 | 158 | error = error / len(loader.dataset) 159 | error_log = torch.log(error) 160 | 161 | return error.mean().item(), error_log.mean().item() 162 | 163 | test_error = None 164 | log_test_error = None 165 | best_val_error = None 166 | for epoch in range(1, 1001): 167 | lr = scheduler.optimizer.param_groups[0]['lr'] 168 | loss = train() 169 | val_error, _ = test(val_loader) 170 | scheduler.step(val_error) 171 | 172 | if best_val_error is None or val_error <= best_val_error: 173 | test_error, log_test_error = test(test_loader) 174 | best_val_error = val_error 175 | 176 | print('Epoch: {:03d}, LR: {:.7f}, Loss: {:.7f}, Validation MAE: {:.7f}, ' 177 | 'Test MAE: {:.7f}, Test log-MAE: {:.7f}'.format(epoch, lr, loss, val_error, test_error, log_test_error)) 178 | 179 | if lr < 0.000001: 180 | print("Converged.") 181 | break 182 | 183 | results.append(test_error) 184 | results_log.append(log_test_error) 185 | 186 | print("########################") 187 | print(results) 188 | results = np.array(results) 189 | print(results.mean(), results.std()) 190 | 191 | print(results_log) 192 | results_log = np.array(results_log) 193 | print(results_log.mean(), results_log.std()) 194 | -------------------------------------------------------------------------------- /software/k-gnn-master/examples/1-NCI1.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import argparse 4 | import torch 5 | import torch.nn.functional as F 6 | from torch_scatter import scatter_max 7 | from torch_geometric.datasets import TUDataset 8 | from k_gnn import GraphConv 9 | from torch_geometric.data import DataLoader 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--no-train', default=False) 13 | args = parser.parse_args() 14 | 15 | 16 | class MyFilter(object): 17 | def __call__(self, data): 18 | return True  # Filter disabled; the original condition is kept below. 19 | # return data.num_nodes >= 5 20 | 21 | 22 | BATCH = 32 23 | path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', '1-NCI') 24 | dataset = TUDataset(path, name='NCI1', pre_filter=MyFilter()) 25 | 26 | perm = torch.randperm(len(dataset), dtype=torch.long) 27 | torch.save(perm, 'nci_perm.pt') 28 | perm = torch.load('nci_perm.pt') 29 | dataset = dataset[perm] 30 | 31 | 32 | class Net(torch.nn.Module): 33 | def __init__(self): 34 | super(Net, self).__init__() 35 | self.conv1 = GraphConv(dataset.num_features, 32) 36 | self.conv2 = GraphConv(32, 64) 37 | self.conv3 = GraphConv(64, 64) 38 | self.fc1 = torch.nn.Linear(64, 64) 39 | self.fc2 = torch.nn.Linear(64, 32) 40 | self.fc3 =
torch.nn.Linear(32, dataset.num_classes) 41 | 42 | def reset_parameters(self): 43 | for (name, module) in self._modules.items(): 44 | module.reset_parameters() 45 | 46 | def forward(self, data): 47 | data.x = F.elu(self.conv1(data.x, data.edge_index)) 48 | data.x = F.elu(self.conv2(data.x, data.edge_index)) 49 | data.x = F.elu(self.conv3(data.x, data.edge_index)) 50 | x_1 = scatter_max(data.x, data.batch, dim=0)[0] 51 | x = x_1 52 | 53 | if args.no_train: 54 | x = x.detach() 55 | 56 | x = F.elu(self.fc1(x)) 57 | x = F.dropout(x, p=0.5, training=self.training) 58 | x = F.elu(self.fc2(x)) 59 | x = self.fc3(x) 60 | return F.log_softmax(x, dim=1) 61 | 62 | 63 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 64 | model = Net().to(device) 65 | 66 | 67 | def train(epoch, loader, optimizer): 68 | model.train() 69 | loss_all = 0 70 | 71 | for data in loader: 72 | data = data.to(device) 73 | optimizer.zero_grad() 74 | loss = F.nll_loss(model(data), data.y) 75 | loss.backward() 76 | loss_all += data.num_graphs * loss.item() 77 | optimizer.step() 78 | return loss_all / len(loader.dataset) 79 | 80 | 81 | def val(loader): 82 | model.eval() 83 | loss_all = 0 84 | 85 | for data in loader: 86 | data = data.to(device) 87 | loss_all += F.nll_loss(model(data), data.y, reduction='sum').item() 88 | return loss_all / len(loader.dataset) 89 | 90 | 91 | def test(loader): 92 | model.eval() 93 | correct = 0 94 | 95 | for data in loader: 96 | data = data.to(device) 97 | pred = model(data).max(1)[1] 98 | correct += pred.eq(data.y).sum().item() 99 | return correct / len(loader.dataset) 100 | 101 | 102 | acc = [] 103 | for i in range(10): 104 | model.reset_parameters() 105 | optimizer = torch.optim.Adam(model.parameters(), lr=0.001) 106 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 107 | optimizer, mode='min', factor=0.7, patience=5, min_lr=0.00001) 108 | 109 | test_mask = torch.zeros(len(dataset), dtype=torch.uint8) 110 | n = len(dataset) // 10 111 | test_mask[i * n:(i + 1) * n] = 1 112 | test_dataset = dataset[test_mask] 113 | train_dataset = dataset[1 - test_mask] 114 | 115 | n = len(train_dataset) // 10 116 | val_mask = torch.zeros(len(train_dataset), dtype=torch.uint8) 117 | val_mask[i * n:(i + 1) * n] = 1 118 | val_dataset = train_dataset[val_mask] 119 | train_dataset = train_dataset[1 - val_mask] 120 | 121 | val_loader = DataLoader(val_dataset, batch_size=BATCH) 122 | test_loader = DataLoader(test_dataset, batch_size=BATCH) 123 | train_loader = DataLoader(train_dataset, batch_size=BATCH, shuffle=True) 124 | 125 | print('---------------- Split {} ----------------'.format(i)) 126 | 127 | best_val_loss, test_acc = 100, 0 128 | for epoch in range(1, 101): 129 | lr = scheduler.optimizer.param_groups[0]['lr'] 130 | train_loss = train(epoch, train_loader, optimizer) 131 | val_loss = val(val_loader) 132 | scheduler.step(val_loss) 133 | if best_val_loss >= val_loss: 134 | test_acc = test(test_loader) 135 | best_val_loss = val_loss 136 | print('Epoch: {:03d}, LR: {:7f}, Train Loss: {:.7f}, ' 137 | 'Val Loss: {:.7f}, Test Acc: {:.7f}'.format( 138 | epoch, lr, train_loss, val_loss, test_acc)) 139 | acc.append(test_acc) 140 | acc = torch.tensor(acc) 141 | print('---------------- Final Result ----------------') 142 | print('Mean: {:7f}, Std: {:7f}'.format(acc.mean(), acc.std())) 143 | -------------------------------------------------------------------------------- /software/k-gnn-master/examples/1-imdb.py: -------------------------------------------------------------------------------- 1 | 
import os.path as osp 2 | 3 | import argparse 4 | import torch 5 | import torch.nn.functional as F 6 | from torch_scatter import scatter_mean 7 | from torch_geometric.datasets import TUDataset 8 | from torch_geometric.utils import degree 9 | from torch_geometric.data import DataLoader 10 | from k_gnn import GraphConv 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--no-train', default=False) 14 | args = parser.parse_args() 15 | 16 | 17 | class MyFilter(object): 18 | def __call__(self, data): 19 | return data.num_nodes <= 70 20 | 21 | 22 | class MyPreTransform(object): 23 | def __call__(self, data): 24 | data.x = degree(data.edge_index[0], data.num_nodes, dtype=torch.long) 25 | data.x = F.one_hot(data.x, num_classes=136).to(torch.float) 26 | return data 27 | 28 | 29 | BATCH = 32 30 | path = osp.join( 31 | osp.dirname(osp.realpath(__file__)), '..', 'data', '1-IMDB-BINARY') 32 | dataset = TUDataset( 33 | path, 34 | name='IMDB-BINARY', 35 | pre_transform=MyPreTransform(), 36 | pre_filter=MyFilter()) 37 | 38 | perm = torch.randperm(len(dataset), dtype=torch.long) 39 | dataset = dataset[perm] 40 | 41 | 42 | class Net(torch.nn.Module): 43 | def __init__(self): 44 | super(Net, self).__init__() 45 | self.conv1 = GraphConv(dataset.num_features, 32) 46 | self.conv2 = GraphConv(32, 64) 47 | self.conv3 = GraphConv(64, 64) 48 | self.fc1 = torch.nn.Linear(64, 64) 49 | self.fc2 = torch.nn.Linear(64, 32) 50 | self.fc3 = torch.nn.Linear(32, dataset.num_classes) 51 | 52 | def reset_parameters(self): 53 | for (name, module) in self._modules.items(): 54 | module.reset_parameters() 55 | 56 | def forward(self, data): 57 | data.x = F.elu(self.conv1(data.x, data.edge_index)) 58 | data.x = F.elu(self.conv2(data.x, data.edge_index)) 59 | data.x = F.elu(self.conv3(data.x, data.edge_index)) 60 | x_1 = scatter_mean(data.x, data.batch, dim=0) 61 | x = x_1 62 | 63 | if args.no_train: 64 | x = x.detach() 65 | 66 | x = F.elu(self.fc1(x)) 67 | x = F.dropout(x, p=0.5, training=self.training) 68 | x = F.elu(self.fc2(x)) 69 | x = self.fc3(x) 70 | return F.log_softmax(x, dim=1) 71 | 72 | 73 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 74 | model = Net().to(device) 75 | 76 | 77 | def train(epoch, loader, optimizer): 78 | model.train() 79 | loss_all = 0 80 | 81 | for data in loader: 82 | data = data.to(device) 83 | optimizer.zero_grad() 84 | loss = F.nll_loss(model(data), data.y) 85 | loss.backward() 86 | loss_all += data.num_graphs * loss.item() 87 | optimizer.step() 88 | return loss_all / len(loader.dataset) 89 | 90 | 91 | def val(loader): 92 | model.eval() 93 | loss_all = 0 94 | 95 | for data in loader: 96 | data = data.to(device) 97 | loss_all += F.nll_loss(model(data), data.y, reduction='sum').item() 98 | return loss_all / len(loader.dataset) 99 | 100 | 101 | def test(loader): 102 | model.eval() 103 | correct = 0 104 | 105 | for data in loader: 106 | data = data.to(device) 107 | pred = model(data).max(1)[1] 108 | correct += pred.eq(data.y).sum().item() 109 | return correct / len(loader.dataset) 110 | 111 | 112 | acc = [] 113 | for i in range(10): 114 | model.reset_parameters() 115 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01) 116 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 117 | optimizer, mode='min', factor=0.7, patience=5, min_lr=0.00001) 118 | 119 | test_mask = torch.zeros(len(dataset), dtype=torch.uint8) 120 | n = len(dataset) // 10 121 | test_mask[i * n:(i + 1) * n] = 1 122 | test_dataset = dataset[test_mask] 123 | train_dataset = dataset[1 - 
test_mask] 124 | 125 | n = len(train_dataset) // 10 126 | val_mask = torch.zeros(len(train_dataset), dtype=torch.bool) 127 | val_mask[i * n:(i + 1) * n] = True 128 | val_dataset = train_dataset[val_mask] 129 | train_dataset = train_dataset[~val_mask] 130 | 131 | val_loader = DataLoader(val_dataset, batch_size=BATCH) 132 | test_loader = DataLoader(test_dataset, batch_size=BATCH) 133 | train_loader = DataLoader(train_dataset, batch_size=BATCH, shuffle=True) 134 | 135 | print('---------------- Split {} ----------------'.format(i)) 136 | 137 | best_val_loss, test_acc = 100, 0 138 | for epoch in range(1, 101): 139 | lr = scheduler.optimizer.param_groups[0]['lr'] 140 | train_loss = train(epoch, train_loader, optimizer) 141 | val_loss = val(val_loader) 142 | scheduler.step(val_loss) 143 | if best_val_loss >= val_loss: 144 | test_acc = test(test_loader) 145 | best_val_loss = val_loss 146 | print('Epoch: {:03d}, LR: {:7f}, Train Loss: {:.7f}, ' 147 | 'Val Loss: {:.7f}, Test Acc: {:.7f}'.format( 148 | epoch, lr, train_loss, val_loss, test_acc)) 149 | acc.append(test_acc) 150 | acc = torch.tensor(acc) 151 | print('---------------- Final Result ----------------') 152 | print('Mean: {:7f}, Std: {:7f}'.format(acc.mean(), acc.std())) 153 | -------------------------------------------------------------------------------- /software/k-gnn-master/examples/1-mutag.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import argparse 4 | import torch 5 | import torch.nn.functional as F 6 | from torch_scatter import scatter_add 7 | from torch_geometric.datasets import TUDataset 8 | from k_gnn import GraphConv 9 | from torch_geometric.data import DataLoader 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--no-train', action='store_true') 13 | args = parser.parse_args() 14 | 15 | 16 | class MyFilter(object): 17 | def __call__(self, data): 18 | return True 19 | 20 | 21 | BATCH = 32 22 | path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'MUTAG') 23 | dataset = TUDataset(path, name='MUTAG', pre_filter=MyFilter()) 24 | 25 | perm = torch.randperm(len(dataset), dtype=torch.long) 26 | torch.save(perm, 'mutag.pt') 27 | perm = torch.load('mutag.pt') 28 | dataset = dataset[perm] 29 | 30 | 31 | class Net(torch.nn.Module): 32 | def __init__(self): 33 | super(Net, self).__init__() 34 | self.conv1 = GraphConv(dataset.num_features, 32) 35 | self.conv2 = GraphConv(32, 64) 36 | self.conv3 = GraphConv(64, 64) 37 | self.fc1 = torch.nn.Linear(64, 64) 38 | self.fc2 = torch.nn.Linear(64, 32) 39 | self.fc3 = torch.nn.Linear(32, dataset.num_classes) 40 | 41 | def reset_parameters(self): 42 | for (name, module) in self._modules.items(): 43 | module.reset_parameters() 44 | 45 | def forward(self, data): 46 | data.x = F.elu(self.conv1(data.x, data.edge_index)) 47 | data.x = F.elu(self.conv2(data.x, data.edge_index)) 48 | data.x = F.elu(self.conv3(data.x, data.edge_index)) 49 | x_1 = scatter_add(data.x, data.batch, dim=0) 50 | x = x_1 51 | 52 | if args.no_train: 53 | x = x.detach() 54 | 55 | x = F.elu(self.fc1(x)) 56 | x = F.elu(self.fc2(x)) 57 | x = self.fc3(x) 58 | return F.log_softmax(x, dim=1) 59 | 60 | 61 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 62 | model = Net().to(device) 63 | 64 | 65 | def train(epoch, loader, optimizer): 66 | model.train() 67 | loss_all = 0 68 | 69 | for data in loader: 70 | data = data.to(device) 71 | optimizer.zero_grad() 72 | loss = F.nll_loss(model(data), data.y) 73 | loss.backward() 
74 | loss_all += data.num_graphs * loss.item() 75 | optimizer.step() 76 | return loss_all / len(loader.dataset) 77 | 78 | 79 | def val(loader): 80 | model.eval() 81 | loss_all = 0 82 | 83 | for data in loader: 84 | data = data.to(device) 85 | loss_all += F.nll_loss(model(data), data.y, reduction='sum').item() 86 | return loss_all / len(loader.dataset) 87 | 88 | 89 | def test(loader): 90 | model.eval() 91 | correct = 0 92 | 93 | for data in loader: 94 | data = data.to(device) 95 | pred = model(data).max(1)[1] 96 | correct += pred.eq(data.y).sum().item() 97 | return correct / len(loader.dataset) 98 | 99 | 100 | acc = [] 101 | for i in range(10): 102 | model.reset_parameters() 103 | optimizer = torch.optim.Adam(model.parameters(), lr=0.001) 104 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 105 | optimizer, mode='min', factor=0.7, patience=5, min_lr=0.00001) 106 | 107 | test_mask = torch.zeros(len(dataset), dtype=torch.bool) 108 | n = len(dataset) // 10 109 | test_mask[i * n:(i + 1) * n] = True 110 | test_dataset = dataset[test_mask] 111 | train_dataset = dataset[~test_mask] 112 | 113 | n = len(train_dataset) // 10 114 | val_mask = torch.zeros(len(train_dataset), dtype=torch.bool) 115 | val_mask[i * n:(i + 1) * n] = True 116 | val_dataset = train_dataset[val_mask] 117 | train_dataset = train_dataset[~val_mask] 118 | 119 | val_loader = DataLoader(val_dataset, batch_size=BATCH) 120 | test_loader = DataLoader(test_dataset, batch_size=BATCH) 121 | train_loader = DataLoader(train_dataset, batch_size=BATCH, shuffle=True) 122 | 123 | print('---------------- Split {} ----------------'.format(i)) 124 | 125 | best_val_loss, test_acc = 100, 0 126 | for epoch in range(1, 101): 127 | lr = scheduler.optimizer.param_groups[0]['lr'] 128 | train_loss = train(epoch, train_loader, optimizer) 129 | val_loss = val(val_loader) 130 | scheduler.step(val_loss) 131 | if best_val_loss >= val_loss: 132 | test_acc = test(test_loader) 133 | best_val_loss = val_loss 134 | print('Epoch: {:03d}, LR: {:7f}, Train Loss: {:.7f}, ' 135 | 'Val Loss: {:.7f}, Test Acc: {:.7f}'.format( 136 | epoch, lr, train_loss, val_loss, test_acc)) 137 | acc.append(test_acc) 138 | acc = torch.tensor(acc) 139 | print('---------------- Final Result ----------------') 140 | print('Mean: {:7f}, Std: {:7f}'.format(acc.mean(), acc.std())) 141 | -------------------------------------------------------------------------------- /software/k-gnn-master/examples/1-proteins.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import argparse 4 | import torch 5 | import torch.nn.functional as F 6 | from torch_scatter import scatter_mean 7 | from torch_geometric.datasets import TUDataset 8 | from torch_geometric.data import DataLoader 9 | from k_gnn import GraphConv 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--no-train', action='store_true') 13 | args = parser.parse_args() 14 | 15 | 16 | class MyFilter(object): 17 | def __call__(self, data): 18 | return not (data.num_nodes == 7 and data.num_edges == 12) and \ 19 | data.num_nodes < 450 20 | 21 | 22 | BATCH = 32 23 | path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 24 | '1-PROTEINS') 25 | dataset = TUDataset(path, name='PROTEINS', pre_filter=MyFilter()) 26 | 27 | perm = torch.randperm(len(dataset), dtype=torch.long) 28 | dataset = dataset[perm] 29 | 30 | 31 | class Net(torch.nn.Module): 32 | def __init__(self): 33 | super(Net, self).__init__() 34 | self.conv1 = GraphConv(dataset.num_features, 32) 
35 | self.conv2 = GraphConv(32, 64) 36 | self.conv3 = GraphConv(64, 64) 37 | self.fc1 = torch.nn.Linear(64, 64) 38 | self.fc2 = torch.nn.Linear(64, 32) 39 | self.fc3 = torch.nn.Linear(32, dataset.num_classes) 40 | 41 | def reset_parameters(self): 42 | for (name, module) in self._modules.items(): 43 | module.reset_parameters() 44 | 45 | def forward(self, data): 46 | data.x = F.elu(self.conv1(data.x, data.edge_index)) 47 | data.x = F.elu(self.conv2(data.x, data.edge_index)) 48 | data.x = F.elu(self.conv3(data.x, data.edge_index)) 49 | x_1 = scatter_mean(data.x, data.batch, dim=0) 50 | x = x_1 51 | 52 | if args.no_train: 53 | x = x.detach() 54 | 55 | x = F.elu(self.fc1(x)) 56 | x = F.dropout(x, p=0.5, training=self.training) 57 | x = F.elu(self.fc2(x)) 58 | x = self.fc3(x) 59 | return F.log_softmax(x, dim=1) 60 | 61 | 62 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 63 | model = Net().to(device) 64 | 65 | 66 | def train(epoch, loader, optimizer): 67 | model.train() 68 | loss_all = 0 69 | 70 | for data in loader: 71 | data = data.to(device) 72 | optimizer.zero_grad() 73 | loss = F.nll_loss(model(data), data.y) 74 | loss.backward() 75 | loss_all += data.num_graphs * loss.item() 76 | optimizer.step() 77 | return loss_all / len(loader.dataset) 78 | 79 | 80 | def val(loader): 81 | model.eval() 82 | loss_all = 0 83 | 84 | for data in loader: 85 | data = data.to(device) 86 | loss_all += F.nll_loss(model(data), data.y, reduction='sum').item() 87 | return loss_all / len(loader.dataset) 88 | 89 | 90 | def test(loader): 91 | model.eval() 92 | correct = 0 93 | 94 | for data in loader: 95 | data = data.to(device) 96 | pred = model(data).max(1)[1] 97 | correct += pred.eq(data.y).sum().item() 98 | return correct / len(loader.dataset) 99 | 100 | 101 | acc = [] 102 | for i in range(10): 103 | model.reset_parameters() 104 | optimizer = torch.optim.Adam(model.parameters(), lr=0.001) 105 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 106 | optimizer, mode='min', factor=0.7, patience=5, min_lr=0.00001) 107 | 108 | test_mask = torch.zeros(len(dataset), dtype=torch.bool) 109 | n = len(dataset) // 10 110 | test_mask[i * n:(i + 1) * n] = True 111 | test_dataset = dataset[test_mask] 112 | train_dataset = dataset[~test_mask] 113 | 114 | n = len(train_dataset) // 10 115 | val_mask = torch.zeros(len(train_dataset), dtype=torch.bool) 116 | val_mask[i * n:(i + 1) * n] = True 117 | val_dataset = train_dataset[val_mask] 118 | train_dataset = train_dataset[~val_mask] 119 | 120 | val_loader = DataLoader(val_dataset, batch_size=BATCH) 121 | test_loader = DataLoader(test_dataset, batch_size=BATCH) 122 | train_loader = DataLoader(train_dataset, batch_size=BATCH, shuffle=True) 123 | 124 | print('---------------- Split {} ----------------'.format(i)) 125 | 126 | best_val_loss, test_acc = 100, 0 127 | for epoch in range(1, 101): 128 | lr = scheduler.optimizer.param_groups[0]['lr'] 129 | train_loss = train(epoch, train_loader, optimizer) 130 | val_loss = val(val_loader) 131 | scheduler.step(val_loss) 132 | if best_val_loss >= val_loss: 133 | test_acc = test(test_loader) 134 | best_val_loss = val_loss 135 | print('Epoch: {:03d}, LR: {:7f}, Train Loss: {:.7f}, ' 136 | 'Val Loss: {:.7f}, Test Acc: {:.7f}'.format( 137 | epoch, lr, train_loss, val_loss, test_acc)) 138 | acc.append(test_acc) 139 | acc = torch.tensor(acc) 140 | print('---------------- Final Result ----------------') 141 | print('Mean: {:7f}, Std: {:7f}'.format(acc.mean(), acc.std())) 142 | 
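Note on the evaluation protocol: the three TU-dataset scripts above (IMDB-BINARY, MUTAG, PROTEINS) share one scheme. The dataset is shuffled once; fold i of ten is held out as the test set; a further tenth of the remaining graphs becomes the validation set; and the reported number is the test accuracy recorded at the epoch with the lowest validation loss. The following is a minimal, self-contained sketch of that splitting step, assuming only a dataset object that supports len() and boolean-mask indexing (as PyG's TUDataset does); the name fold_split is illustrative, not part of the repository:

import torch

def fold_split(dataset, fold, num_folds=10):
    # Hold out fold `fold` as the test set.
    n = len(dataset) // num_folds
    test_mask = torch.zeros(len(dataset), dtype=torch.bool)
    test_mask[fold * n:(fold + 1) * n] = True
    test_set = dataset[test_mask]
    rest = dataset[~test_mask]
    # Carve a validation fold out of the remaining graphs.
    m = len(rest) // num_folds
    val_mask = torch.zeros(len(rest), dtype=torch.bool)
    val_mask[fold * m:(fold + 1) * m] = True
    return rest[~val_mask], rest[val_mask], test_set  # train, val, test

Because the shuffle happens once before the fold loop, the ten test folds are disjoint and together cover (almost) the whole dataset, while the model and optimizer are re-initialized for each fold via reset_parameters().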
-------------------------------------------------------------------------------- /software/k-gnn-master/examples/1-qm9.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import argparse 4 | import torch 5 | from torch.nn import Sequential, Linear, ReLU 6 | import torch.nn.functional as F 7 | from torch_scatter import scatter_mean 8 | from torch_geometric.datasets import QM9 9 | import torch_geometric.transforms as T 10 | from torch_geometric.nn import NNConv 11 | from torch_geometric.data import DataLoader 12 | 13 | 14 | class MyFilter(object): 15 | def __call__(self, data): 16 | return data.num_nodes > 6 # Remove graphs with 6 or fewer nodes. 17 | 18 | 19 | class MyTransform(object): 20 | def __call__(self, data): 21 | data.y = data.y[:, int(args.target)] # Specify target: 0 = mu 22 | return data 23 | 24 | 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument('--target', default=0) 27 | args = parser.parse_args() 28 | target = int(args.target) 29 | 30 | print('---- Target: {} ----'.format(target)) 31 | 32 | path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', '1-QM9') 33 | dataset = QM9(path, transform=T.Compose([MyTransform(), T.Distance()])) 34 | 35 | dataset = dataset.shuffle() 36 | 37 | # Normalize targets to mean = 0 and std = 1. 38 | tenpercent = int(len(dataset) * 0.1) 39 | mean = dataset.data.y[tenpercent:].mean(dim=0) 40 | std = dataset.data.y[tenpercent:].std(dim=0) 41 | dataset.data.y = (dataset.data.y - mean) / std 42 | 43 | test_dataset = dataset[:tenpercent] 44 | val_dataset = dataset[tenpercent:2 * tenpercent] 45 | train_dataset = dataset[2 * tenpercent:] 46 | test_loader = DataLoader(test_dataset, batch_size=64) 47 | val_loader = DataLoader(val_dataset, batch_size=64) 48 | train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True) 49 | 50 | 51 | class Net(torch.nn.Module): 52 | def __init__(self): 53 | super(Net, self).__init__() 54 | M_in, M_out = dataset.num_features, 32 55 | nn1 = Sequential(Linear(5, 128), ReLU(), Linear(128, M_in * M_out)) 56 | self.conv1 = NNConv(M_in, M_out, nn1) 57 | 58 | M_in, M_out = M_out, 64 59 | nn2 = Sequential(Linear(5, 128), ReLU(), Linear(128, M_in * M_out)) 60 | self.conv2 = NNConv(M_in, M_out, nn2) 61 | 62 | M_in, M_out = M_out, 64 63 | nn3 = Sequential(Linear(5, 128), ReLU(), Linear(128, M_in * M_out)) 64 | self.conv3 = NNConv(M_in, M_out, nn3) 65 | 66 | self.fc1 = torch.nn.Linear(64, 32) 67 | self.fc2 = torch.nn.Linear(32, 16) 68 | self.fc3 = torch.nn.Linear(16, 1) 69 | 70 | def forward(self, data): 71 | x = data.x 72 | x = F.elu(self.conv1(x, data.edge_index, data.edge_attr)) 73 | x = F.elu(self.conv2(x, data.edge_index, data.edge_attr)) 74 | x = F.elu(self.conv3(x, data.edge_index, data.edge_attr)) 75 | 76 | x = scatter_mean(x, data.batch, dim=0) 77 | 78 | x = F.elu(self.fc1(x)) 79 | x = F.elu(self.fc2(x)) 80 | x = self.fc3(x) 81 | return x.view(-1) 82 | 83 | 84 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 85 | model = Net().to(device) 86 | optimizer = torch.optim.Adam(model.parameters(), lr=0.001) 87 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 88 | optimizer, factor=0.7, patience=5, min_lr=0.00001) 89 | 90 | 91 | def train(epoch): 92 | model.train() 93 | loss_all = 0 94 | 95 | for data in train_loader: 96 | data = data.to(device) 97 | optimizer.zero_grad() 98 | loss = F.mse_loss(model(data), data.y) 99 | loss.backward() 100 | loss_all += loss.item() * data.num_graphs 101 | optimizer.step() 102 | 
return loss_all / len(train_loader.dataset) 103 | 104 | 105 | def test(loader): 106 | model.eval() 107 | error = 0 108 | 109 | for data in loader: 110 | data = data.to(device) 111 | error += ((model(data) * std[target].to(device)) - 112 | (data.y * std[target].to(device))).abs().sum().item() # MAE 113 | return error / len(loader.dataset) 114 | 115 | 116 | best_val_error = None 117 | for epoch in range(1, 301): 118 | lr = scheduler.optimizer.param_groups[0]['lr'] 119 | loss = train(epoch) 120 | val_error = test(val_loader) 121 | scheduler.step(val_error) 122 | 123 | if best_val_error is None: 124 | best_val_error = val_error 125 | if val_error <= best_val_error: 126 | test_error = test(test_loader) 127 | best_val_error = val_error 128 | print( 129 | 'Epoch: {:03d}, LR: {:7f}, Loss: {:.7f}, Validation MAE: {:.7f}, ' 130 | 'Test MAE: {:.7f}, ' 131 | 'Test MAE norm: {:.7f}'.format(epoch, lr, loss, val_error, 132 | test_error, 133 | test_error / std[target].item())) 134 | else: 135 | print('Epoch: {:03d}'.format(epoch)) 136 | -------------------------------------------------------------------------------- /software/k-gnn-master/examples/1-reddit.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import argparse 4 | import torch 5 | import torch.nn.functional as F 6 | from torch_scatter import scatter_mean 7 | from torch_geometric.datasets import TUDataset 8 | from torch_geometric.utils import degree 9 | from torch_geometric.data import DataLoader 10 | from k_gnn import GraphConv 11 | 12 | import torch_geometric.transforms as T 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--no-train', action='store_true') 16 | args = parser.parse_args() 17 | 18 | 19 | class NormalizedDegree(object): 20 | def __init__(self, mean, std): 21 | self.mean = mean 22 | self.std = std 23 | 24 | def __call__(self, data): 25 | deg = degree(data.edge_index[0], dtype=torch.float) 26 | deg = (deg - self.mean) / self.std 27 | data.x = deg.view(-1, 1) 28 | return data 29 | 30 | def get_dataset(name, sparse=True, cleaned=False): 31 | path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', name) 32 | dataset = TUDataset(path, name) 33 | dataset.data.edge_attr = None 34 | 35 | if dataset.data.x is None: 36 | max_degree = 0 37 | degs = [] 38 | for data in dataset: 39 | degs += [degree(data.edge_index[0], dtype=torch.long)] 40 | max_degree = max(max_degree, degs[-1].max().item()) 41 | 42 | deg = torch.cat(degs, dim=0).to(torch.float) 43 | mean, std = deg.mean().item(), deg.std().item() 44 | dataset.transform = NormalizedDegree(mean, std) 45 | 46 | if not sparse: 47 | num_nodes = max_num_nodes = 0 48 | for data in dataset: 49 | num_nodes += data.num_nodes 50 | max_num_nodes = max(data.num_nodes, max_num_nodes) 51 | 52 | # Filter out a few really large graphs in order to apply DiffPool. 
53 | if name == 'REDDIT-BINARY': 54 | num_nodes = min(int(num_nodes / len(dataset) * 1.5), max_num_nodes) 55 | else: 56 | num_nodes = min(int(num_nodes / len(dataset) * 5), max_num_nodes) 57 | 58 | indices = [] 59 | for i, data in enumerate(dataset): 60 | if data.num_nodes <= num_nodes: 61 | indices.append(i) 62 | dataset = dataset[torch.tensor(indices)] 63 | 64 | if dataset.transform is None: 65 | dataset.transform = T.ToDense(num_nodes) 66 | else: 67 | dataset.transform = T.Compose( 68 | [dataset.transform, T.ToDense(num_nodes)]) 69 | 70 | return dataset 71 | 72 | class MyFilter(object): 73 | def __call__(self, data): 74 | return data.num_nodes <= 100000 75 | 76 | class MyPreTransform(object): 77 | def __call__(self, data): 78 | data.x = degree(data.edge_index[0], data.num_nodes, dtype=torch.long) 79 | data.x = F.one_hot(data.x).to(torch.float) 80 | return data 81 | 82 | BATCH = 32 83 | dataset = get_dataset('REDDIT-BINARY') 84 | 85 | perm = torch.randperm(len(dataset), dtype=torch.long) 86 | dataset = dataset[perm] 87 | 88 | print(len(dataset)) 89 | 90 | 91 | 92 | 93 | class Net(torch.nn.Module): 94 | def __init__(self): 95 | super(Net, self).__init__() 96 | self.conv1 = GraphConv(dataset.num_features, 32) 97 | self.conv2 = GraphConv(32, 64) 98 | self.conv3 = GraphConv(64, 64) 99 | self.fc1 = torch.nn.Linear(64, 64) 100 | self.fc2 = torch.nn.Linear(64, 32) 101 | self.fc3 = torch.nn.Linear(32, dataset.num_classes) 102 | 103 | def reset_parameters(self): 104 | for (name, module) in self._modules.items(): 105 | module.reset_parameters() 106 | 107 | def forward(self, data): 108 | data.x = F.elu(self.conv1(data.x, data.edge_index)) 109 | data.x = F.elu(self.conv2(data.x, data.edge_index)) 110 | data.x = F.elu(self.conv3(data.x, data.edge_index)) 111 | x_1 = scatter_mean(data.x, data.batch, dim=0) 112 | x = x_1 113 | 114 | if args.no_train: 115 | x = x.detach() 116 | 117 | x = F.elu(self.fc1(x)) 118 | x = F.dropout(x, p=0.5, training=self.training) 119 | x = F.elu(self.fc2(x)) 120 | x = self.fc3(x) 121 | return F.log_softmax(x, dim=1) 122 | 123 | 124 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 125 | model = Net().to(device) 126 | 127 | 128 | def train(epoch, loader, optimizer): 129 | model.train() 130 | loss_all = 0 131 | 132 | for data in loader: 133 | data = data.to(device) 134 | optimizer.zero_grad() 135 | loss = F.nll_loss(model(data), data.y) 136 | loss.backward() 137 | loss_all += data.num_graphs * loss.item() 138 | optimizer.step() 139 | return loss_all / len(loader.dataset) 140 | 141 | 142 | def val(loader): 143 | model.eval() 144 | loss_all = 0 145 | 146 | for data in loader: 147 | data = data.to(device) 148 | loss_all += F.nll_loss(model(data), data.y, reduction='sum').item() 149 | return loss_all / len(loader.dataset) 150 | 151 | 152 | def test(loader): 153 | model.eval() 154 | correct = 0 155 | 156 | for data in loader: 157 | data = data.to(device) 158 | pred = model(data).max(1)[1] 159 | correct += pred.eq(data.y).sum().item() 160 | return correct / len(loader.dataset) 161 | 162 | 163 | acc = [] 164 | for i in range(10): 165 | model.reset_parameters() 166 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01) 167 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 168 | optimizer, mode='min', factor=0.7, patience=5, min_lr=0.00001) 169 | 170 | test_mask = torch.zeros(len(dataset), dtype=torch.bool) 171 | n = len(dataset) // 10 172 | test_mask[i * n:(i + 1) * n] = True 173 | test_dataset = dataset[test_mask] 174 | train_dataset = dataset[~
test_mask] 175 | 176 | n = len(train_dataset) // 10 177 | val_mask = torch.zeros(len(train_dataset), dtype=torch.bool) 178 | val_mask[i * n:(i + 1) * n] = True 179 | val_dataset = train_dataset[val_mask] 180 | train_dataset = train_dataset[~val_mask] 181 | 182 | val_loader = DataLoader(val_dataset, batch_size=BATCH) 183 | test_loader = DataLoader(test_dataset, batch_size=BATCH) 184 | train_loader = DataLoader(train_dataset, batch_size=BATCH, shuffle=True) 185 | 186 | print('---------------- Split {} ----------------'.format(i)) 187 | 188 | best_val_loss, test_acc = 100, 0 189 | for epoch in range(1, 101): 190 | lr = scheduler.optimizer.param_groups[0]['lr'] 191 | train_loss = train(epoch, train_loader, optimizer) 192 | val_loss = val(val_loader) 193 | scheduler.step(val_loss) 194 | if best_val_loss >= val_loss: 195 | test_acc = test(test_loader) 196 | best_val_loss = val_loss 197 | print('Epoch: {:03d}, LR: {:7f}, Train Loss: {:.7f}, ' 198 | 'Val Loss: {:.7f}, Test Acc: {:.7f}'.format( 199 | epoch, lr, train_loss, val_loss, test_acc)) 200 | acc.append(test_acc) 201 | acc = torch.tensor(acc) 202 | print('---------------- Final Result ----------------') 203 | print('Mean: {:7f}, Std: {:7f}'.format(acc.mean(), acc.std())) 204 | -------------------------------------------------------------------------------- /software/k-gnn-master/examples/nci_perm.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/muhanzhang/NestedGNN/a5adccf62d397ad7f83bc73be34eba3765df73fa/software/k-gnn-master/examples/nci_perm.pt -------------------------------------------------------------------------------- /software/k-gnn-master/k_gnn/__init__.py: -------------------------------------------------------------------------------- 1 | from graph_cpu import two_local, connected_two_local 2 | from graph_cpu import two_malkin, connected_two_malkin 3 | from graph_cpu import three_local, connected_three_local 4 | from graph_cpu import three_malkin, connected_three_malkin 5 | from graph_cpu import assignment_2to3 6 | from .transform import TwoLocal, ConnectedTwoLocal 7 | from .transform import TwoMalkin, ConnectedTwoMalkin 8 | from .transform import ThreeLocal, ConnectedThreeLocal 9 | from .transform import ThreeMalkin, ConnectedThreeMalkin 10 | from .transform import Assignment2To3 11 | from .dataloader import DataLoader 12 | from .graph_conv import GraphConv 13 | from .pool import add_pool, max_pool, avg_pool 14 | from .complete import Complete 15 | 16 | __all__ = [ 17 | 'two_local', 18 | 'connected_two_local', 19 | 'two_malkin', 20 | 'connected_two_malkin', 21 | 'three_local', 22 | 'connected_three_local', 23 | 'three_malkin', 24 | 'connected_three_malkin', 25 | 'assignment_2to3', 26 | 'TwoLocal', 27 | 'ConnectedTwoLocal', 28 | 'TwoMalkin', 29 | 'ConnectedTwoMalkin', 30 | 'ThreeLocal', 31 | 'ConnectedThreeLocal', 32 | 'ThreeMalkin', 33 | 'ConnectedThreeMalkin', 34 | 'Assignment2To3', 35 | 'DataLoader', 36 | 'GraphConv', 37 | 'add_pool', 38 | 'max_pool', 39 | 'avg_pool', 40 | 'Complete', 41 | ] 42 | -------------------------------------------------------------------------------- /software/k-gnn-master/k_gnn/complete.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.utils import remove_self_loops 3 | 4 | 5 | class Complete(object): 6 | def __call__(self, data): 7 | device = data.edge_index.device 8 | 9 | row = torch.arange(data.num_nodes, dtype=torch.long, device=device) 10 | col = 
torch.arange(data.num_nodes, dtype=torch.long, device=device) 11 | 12 | row = row.view(-1, 1).repeat(1, data.num_nodes).view(-1) 13 | col = col.repeat(data.num_nodes) 14 | edge_index = torch.stack([row, col], dim=0) 15 | 16 | edge_attr = None 17 | if data.edge_attr is not None: 18 | idx = data.edge_index[0] * data.num_nodes + data.edge_index[1] 19 | size = list(data.edge_attr.size()) 20 | size[0] = data.num_nodes * data.num_nodes 21 | edge_attr = data.edge_attr.new_zeros(size) 22 | edge_attr[idx] = data.edge_attr 23 | 24 | edge_index, edge_attr = remove_self_loops(edge_index, edge_attr) 25 | data.edge_attr = edge_attr 26 | data.edge_index = edge_index 27 | 28 | return data 29 | -------------------------------------------------------------------------------- /software/k-gnn-master/k_gnn/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | from torch_geometric.data import Batch 3 | 4 | 5 | def collate(data_list): 6 | keys = data_list[0].keys 7 | assert 'batch' not in keys 8 | 9 | batch = Batch() 10 | for key in keys: 11 | batch[key] = [] 12 | batch.batch = [] 13 | if 'edge_index_2' in keys: 14 | batch.batch_2 = [] 15 | if 'edge_index_3' in keys: 16 | batch.batch_3 = [] 17 | 18 | keys.remove('edge_index') 19 | props = [ 20 | 'edge_index_2', 'assignment_index_2', 'edge_index_3', 21 | 'assignment_index_3', 'assignment_index_2to3' 22 | ] 23 | keys = [x for x in keys if x not in props] 24 | 25 | cumsum_1 = N_1 = cumsum_2 = N_2 = cumsum_3 = N_3 = 0 26 | 27 | for i, data in enumerate(data_list): 28 | for key in keys: 29 | batch[key].append(data[key]) 30 | 31 | N_1 = data.num_nodes 32 | batch.edge_index.append(data.edge_index + cumsum_1) 33 | batch.batch.append(torch.full((N_1, ), i, dtype=torch.long)) 34 | 35 | if 'edge_index_2' in data: 36 | N_2 = data.assignment_index_2[1].max().item() + 1 37 | batch.edge_index_2.append(data.edge_index_2 + cumsum_2) 38 | batch.assignment_index_2.append( 39 | data.assignment_index_2 + 40 | torch.tensor([[cumsum_1], [cumsum_2]])) 41 | batch.batch_2.append(torch.full((N_2, ), i, dtype=torch.long)) 42 | 43 | if 'edge_index_3' in data: 44 | N_3 = data.assignment_index_3[1].max().item() + 1 45 | batch.edge_index_3.append(data.edge_index_3 + cumsum_3) 46 | batch.assignment_index_3.append( 47 | data.assignment_index_3 + 48 | torch.tensor([[cumsum_1], [cumsum_3]])) 49 | batch.batch_3.append(torch.full((N_3, ), i, dtype=torch.long)) 50 | 51 | if 'assignment_index_2to3' in data: 52 | assert 'edge_index_2' in data and 'edge_index_3' in data 53 | batch.assignment_index_2to3.append( 54 | data.assignment_index_2to3 + 55 | torch.tensor([[cumsum_2], [cumsum_3]])) 56 | 57 | cumsum_1 += N_1 58 | cumsum_2 += N_2 59 | cumsum_3 += N_3 60 | 61 | keys = [x for x in batch.keys if x not in ['batch', 'batch_2', 'batch_3']] 62 | for key in keys: 63 | if torch.is_tensor(batch[key][0]): 64 | batch[key] = torch.cat( 65 | batch[key], dim=data_list[0].__cat_dim__(key, batch[key][0])) 66 | 67 | batch.batch = torch.cat(batch.batch, dim=-1) 68 | 69 | if 'batch_2' in batch: 70 | batch.batch_2 = torch.cat(batch.batch_2, dim=-1) 71 | 72 | if 'batch_3' in batch: 73 | batch.batch_3 = torch.cat(batch.batch_3, dim=-1) 74 | 75 | return batch.contiguous() 76 | 77 | 78 | class DataLoader(torch.utils.data.DataLoader): 79 | def __init__(self, dataset, **kwargs): 80 | super(DataLoader, self).__init__(dataset, collate_fn=collate, **kwargs) 81 | -------------------------------------------------------------------------------- 
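Note on the collate routine in dataloader.py above: it batches hierarchical graphs by concatenation plus index shifting. Node-level tensors are concatenated as-is, while every index tensor (edge_index, edge_index_2/3, and the assignment indices) is shifted by the running node counts cumsum_1, cumsum_2, and cumsum_3 so that the indices stay valid in the merged batch graph. A toy illustration of the shifting idea with plain tensors follows; the two example graphs are made up and this is not the library's Batch class:

import torch

# Two toy graphs: a 3-node triangle and a 2-node edge.
edge_index_a = torch.tensor([[0, 1, 2], [1, 2, 0]])
edge_index_b = torch.tensor([[0, 1], [1, 0]])

cumsum = 3  # nodes already placed in the batch (graph a has 3 nodes)
edge_index = torch.cat([edge_index_a, edge_index_b + cumsum], dim=1)
batch = torch.cat([torch.full((3,), 0, dtype=torch.long),
                   torch.full((2,), 1, dtype=torch.long)])
print(edge_index)  # graph b's nodes now appear as 3 and 4
print(batch)       # tensor([0, 0, 0, 1, 1]) maps each node to its graph id

The same offset trick is applied independently at each hierarchy level, which is why collate keeps three separate counters and emits batch, batch_2, and batch_3 vectors for the pooling layers.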
/software/k-gnn-master/k_gnn/graph_conv.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.nn import Parameter 6 | from torch_scatter import scatter_add 7 | 8 | 9 | class GraphConv(torch.nn.Module): 10 | def __init__(self, 11 | in_channels, 12 | out_channels, 13 | norm=True, 14 | bias=True, 15 | dropout=0): 16 | super(GraphConv, self).__init__() 17 | 18 | self.in_channels = in_channels 19 | self.out_channels = out_channels 20 | self.norm = norm 21 | self.dropout = dropout 22 | self.weight = Parameter(torch.Tensor(in_channels, out_channels)) 23 | self.root_weight = Parameter(torch.Tensor(in_channels, out_channels)) 24 | 25 | if bias: 26 | self.bias = Parameter(torch.Tensor(out_channels)) 27 | else: 28 | self.register_parameter('bias', None) 29 | 30 | self.reset_parameters() 31 | 32 | def reset_parameters(self): 33 | stdv = 1.0 / math.sqrt(self.in_channels) 34 | self.weight.data.uniform_(-stdv, stdv) 35 | self.root_weight.data.uniform_(-stdv, stdv) 36 | if self.bias is not None: 37 | self.bias.data.uniform_(-stdv, stdv) 38 | 39 | def forward(self, x, edge_index): 40 | if edge_index.numel() > 0: 41 | row, col = edge_index 42 | 43 | out = torch.mm(x, self.weight) 44 | out_col = out[col] 45 | 46 | # mask = out_col.new_ones(out_col.size(0)) 47 | # mask = F.dropout(mask, self.dropout, training=self.training) 48 | # out_col = mask.view(-1, 1) * out_col 49 | 50 | out_col = F.dropout(out_col, self.dropout, training=self.training) 51 | 52 | out = scatter_add(out_col, row, dim=0, dim_size=x.size(0)) 53 | 54 | # Normalize output by node degree. 55 | if self.norm: 56 | deg = scatter_add( 57 | x.new_ones((row.size())), row, dim=0, dim_size=x.size(0)) 58 | out = out / deg.unsqueeze(-1).clamp(min=1) 59 | 60 | # Weight root node separately. 61 | out = out + torch.mm(x, self.root_weight) 62 | else: 63 | out = torch.mm(x, self.root_weight) 64 | 65 | # Add bias (if wished). 
66 | if self.bias is not None: 67 | out = out + self.bias 68 | 69 | return out 70 | 71 | def __repr__(self): 72 | return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, 73 | self.out_channels) 74 | -------------------------------------------------------------------------------- /software/k-gnn-master/k_gnn/pool.py: -------------------------------------------------------------------------------- 1 | from torch_scatter import scatter_add, scatter_max, scatter_mean 2 | 3 | 4 | def add_pool(x, assignment): 5 | row, col = assignment 6 | return scatter_add(x[row], col, dim=0) 7 | 8 | 9 | def max_pool(x, assignment): 10 | row, col = assignment 11 | return scatter_max(x[row], col, dim=0)[0] 12 | 13 | 14 | def avg_pool(x, assignment): 15 | row, col = assignment 16 | return scatter_mean(x[row], col, dim=0) 17 | -------------------------------------------------------------------------------- /software/k-gnn-master/k_gnn/transform.py: -------------------------------------------------------------------------------- 1 | import graph_cpu 2 | 3 | 4 | class TwoLocal(object): 5 | def __call__(self, data): 6 | out = graph_cpu.two_local(data.edge_index, data.x, data.num_nodes) 7 | data.edge_index_2, data.assignment_index_2, data.iso_type_2 = out 8 | return data 9 | 10 | def __repr__(self): 11 | return '{}()'.format(self.__class__.__name__) 12 | 13 | 14 | class ConnectedTwoLocal(object): 15 | def __call__(self, data): 16 | out = graph_cpu.connected_two_local(data.edge_index, data.x, 17 | data.num_nodes) 18 | data.edge_index_2, data.assignment_index_2, data.iso_type_2 = out 19 | return data 20 | 21 | def __repr__(self): 22 | return '{}()'.format(self.__class__.__name__) 23 | 24 | 25 | class TwoMalkin(object): 26 | def __call__(self, data): 27 | out = graph_cpu.two_malkin(data.edge_index, data.x, data.num_nodes) 28 | data.edge_index_2, data.assignment_index_2, data.iso_type_2 = out 29 | return data 30 | 31 | def __repr__(self): 32 | return '{}()'.format(self.__class__.__name__) 33 | 34 | 35 | class ConnectedTwoMalkin(object): 36 | def __call__(self, data): 37 | out = graph_cpu.connected_two_malkin(data.edge_index, data.x, 38 | data.num_nodes) 39 | data.edge_index_2, data.assignment_index_2, data.iso_type_2 = out 40 | return data 41 | 42 | def __repr__(self): 43 | return '{}()'.format(self.__class__.__name__) 44 | 45 | 46 | class ThreeLocal(object): 47 | def __call__(self, data): 48 | out = graph_cpu.three_local(data.edge_index, data.x, data.num_nodes) 49 | data.edge_index_3, data.assignment_index_3, data.iso_type_3 = out 50 | return data 51 | 52 | def __repr__(self): 53 | return '{}()'.format(self.__class__.__name__) 54 | 55 | 56 | class ConnectedThreeLocal(object): 57 | def __call__(self, data): 58 | out = graph_cpu.connected_three_local(data.edge_index, data.x, 59 | data.num_nodes) 60 | data.edge_index_3, data.assignment_index_3, data.iso_type_3 = out 61 | return data 62 | 63 | def __repr__(self): 64 | return '{}()'.format(self.__class__.__name__) 65 | 66 | 67 | class ThreeMalkin(object): 68 | def __call__(self, data): 69 | out = graph_cpu.three_malkin(data.edge_index, data.x, data.num_nodes) 70 | data.edge_index_3, data.assignment_index_3, data.iso_type_3 = out 71 | return data 72 | 73 | def __repr__(self): 74 | return '{}()'.format(self.__class__.__name__) 75 | 76 | 77 | class ConnectedThreeMalkin(object): 78 | def __call__(self, data): 79 | out = graph_cpu.connected_three_malkin(data.edge_index, data.x, 80 | data.num_nodes) 81 | data.edge_index_3, data.assignment_index_3, data.iso_type_3 = 
out 82 | return data 83 | 84 | def __repr__(self): 85 | return '{}()'.format(self.__class__.__name__) 86 | 87 | 88 | class Assignment2To3(object): 89 | def __call__(self, data): 90 | out = graph_cpu.assignment_2to3(data.edge_index, data.num_nodes) 91 | data.assignment_index_2to3 = out 92 | return data 93 | 94 | def __repr__(self): 95 | return '{}()'.format(self.__class__.__name__) 96 | -------------------------------------------------------------------------------- /software/k-gnn-master/setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.md 3 | 4 | [aliases] 5 | test = pytest 6 | 7 | [tool:pytest] 8 | addopts = --capture=no 9 | -------------------------------------------------------------------------------- /software/k-gnn-master/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | from torch.utils.cpp_extension import BuildExtension, CppExtension 3 | 4 | __version__ = '0.0.0' 5 | url = 'https://github.com/k-gnn/k-gnn' 6 | 7 | install_requires = [] 8 | setup_requires = ['pytest-runner'] 9 | tests_require = ['pytest', 'pytest-cov'] 10 | 11 | ext_modules = [CppExtension('graph_cpu', ['cpu/graph.cpp'])] 12 | cmdclass = {'build_ext': BuildExtension} 13 | 14 | setup( 15 | name='k_gnn', 16 | version=__version__, 17 | description='', 18 | author='XXX', 19 | author_email='XXX', 20 | url=url, 21 | download_url='{}/archive/{}.tar.gz'.format(url, __version__), 22 | keywords=[], 23 | install_requires=install_requires, 24 | setup_requires=setup_requires, 25 | tests_require=tests_require, 26 | ext_modules=ext_modules, 27 | cmdclass=cmdclass, 28 | packages=find_packages(), 29 | ) 30 | --------------------------------------------------------------------------------
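Two closing notes on the package above. First, setup.py compiles cpu/graph.cpp into the graph_cpu extension via BuildExtension, and k_gnn/__init__.py imports from graph_cpu at the top, so the package is only importable after that build step has run. Second, GraphConv in k_gnn/graph_conv.py implements degree-normalized neighbor aggregation plus a separate root-weight term, roughly out = D^{-1} A X W + X W_root + b when norm=True. The sketch below re-expresses that forward pass with plain torch_scatter calls; it is an illustration of the computation (dropout omitted), not code from the repository, and the function name is invented:

import torch
from torch_scatter import scatter_add

def graph_conv_forward(x, edge_index, weight, root_weight, bias=None):
    # Messages flow from col to row, mirroring GraphConv.forward above.
    row, col = edge_index
    msg = (x @ weight)[col]
    out = scatter_add(msg, row, dim=0, dim_size=x.size(0))
    # Degree normalization (the norm=True branch), clamped to avoid
    # division by zero for isolated nodes.
    deg = scatter_add(torch.ones(row.size(0)), row, dim=0, dim_size=x.size(0))
    out = out / deg.clamp(min=1).unsqueeze(-1)
    out = out + x @ root_weight  # the root node is weighted separately
    return out + bias if bias is not None else out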