├── .gitignore ├── LICENSE ├── README.md ├── SGDD_agent.py ├── configs.py ├── graph.py ├── images └── yang2023does.png ├── models ├── IGNR.py ├── gat.py ├── gcn.py ├── myappnp.py ├── myappnp1.py ├── mycheby.py ├── mygatconv.py ├── mygraphsage.py ├── parametrized_adj.py ├── sgc.py └── sgc_multi.py ├── modules.py ├── requirements.txt ├── train_SGDD.py ├── utils.py ├── utils_copt.py └── utils_graphsaint.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | 140 | 141 | data/ 142 | wandb/ 143 | .vscode/ 144 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SGDD: Does Graph Distillation See Like Vision Dataset Counterpart? 2 | 3 | Official implementation of "[Does Graph Distillation See Like Vision Dataset Counterpart](http://arxiv.org/abs/2310.09192)", published as a conference paper at NeurIPS 2023. 
4 | 5 | The authors of this paper are: Beining Yang*, Kai Wang*, Qingyun Sun, Cheng Ji, Xingcheng Fu, Hao Tang, Yang You, Jianxin Li 6 | ![Does Graph Distillation See Like Vision Dataset Counterpart?](./images/yang2023does.png) 7 | 8 | # Abstract 9 | Training on large-scale graphs has achieved remarkable results in graph representation learning, but its cost and storage have attracted increasing concerns. Existing graph condensation methods primarily focus on optimizing the feature matrices of condensed graphs while overlooking the impact of the structure information from the original graphs. To investigate the impact of the structure information, we conduct an analysis in the spectral domain and empirically identify substantial Laplacian Energy Distribution (LED) shifts in previous works. Such shifts lead to poor performance in cross-architecture generalization and specific tasks, including anomaly detection and link prediction. In this paper, we propose a novel Structure-broadcasting Graph Dataset Distillation (**SGDD**) scheme that broadcasts the original structure information to the generation of the synthetic graph, which explicitly prevents the original structure information from being overlooked. 10 | Theoretically, the synthetic graphs generated by SGDD are expected to have smaller LED shifts than previous works, leading to superior performance in both cross-architecture settings and specific tasks. 11 | We validate the proposed SGDD across 9 datasets and achieve state-of-the-art results on all of them: for example, on the YelpChi dataset, our approach maintains 98.6% of the test accuracy obtained by training on the original graph dataset, with a 1,000x reduction in graph scale. Moreover, we empirically observe 17.6% to 31.4% reductions in LED shift across the 9 datasets. Extensive experiments and analysis verify the effectiveness and necessity of the proposed designs. 12 | 13 | # OS Requirements 14 | * Linux OS 15 | * Python 3.7 16 | 17 | # Requirements 18 | ```code 19 | torch==1.7.0 20 | torch_geometric==1.6.3 21 | scipy==1.6.2 22 | numpy==1.19.2 23 | ogb==1.3.0 24 | tqdm==4.59.0 25 | torch_sparse==0.6.9 26 | torchvision==0.8.0 27 | configs==3.0.3 28 | deeprobust==0.2.4 29 | scikit_learn==1.0.2 30 | ``` 31 | 32 | # Download Datasets 33 | Cora, Citeseer: [Pyg](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.datasets.Planetoid.html#torch_geometric.datasets.Planetoid) 34 | Reddit, Ogbn-arxiv, Flickr: [GraphSAINT](https://github.com/GraphSAINT/GraphSAINT) [GCond](https://github.com/ChandlerBang/GCond) 35 | YelpChi: [DGL](https://docs.dgl.ai/en/latest/generated/dgl.data.FraudYelpDataset.html#dgl.data.FraudYelpDataset) 36 | Amazon: [DGL](https://docs.dgl.ai/en/latest/generated/dgl.data.FraudAmazonDataset.html#dgl.data.FraudAmazonDataset) 37 | DBLP, Citeseer: [Pyg](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.datasets.DBLP.html#torch_geometric.datasets.DBLP) 38 | 39 | 40 | # Getting started 41 | * Clone this repo 42 | ``` 43 | git clone ...
44 | cd SGDD/ 45 | ``` 46 | * Install the required packages 47 | ``` 48 | pip install -r ./requirements.txt 49 | ``` 50 | * Download the datasets from the above links and put them in the `./data` folder 51 | 52 | * Train the model (set `${dataset}` to your dataset name) 53 | ``` 54 | python train_SGDD.py --dataset ${dataset} --nlayers=2 -beta 0.1 --r=0.5 --gpu_id=0 55 | ``` 56 | 57 | 58 | 59 | -------------------------------------------------------------------------------- /SGDD_agent.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | from torch.nn import Parameter 6 | import torch.nn.functional as F 7 | from utils import match_loss, regularization, row_normalize_tensor 8 | import deeprobust.graph.utils as utils 9 | from copy import deepcopy 10 | import numpy as np 11 | from tqdm import tqdm 12 | from models.gcn import GCN 13 | from models.sgc import SGC 14 | from models.sgc_multi import SGC as SGC1 15 | 16 | from models.IGNR import GraphonLearner as IGNR 17 | import scipy.sparse as sp 18 | from torch_sparse import SparseTensor 19 | from tqdm import trange 20 | 21 | 22 | class SGDD: 23 | 24 | def __init__(self, data, args, device='cuda', **kwargs): 25 | self.data = data 26 | self.args = args 27 | self.device = device 28 | 29 | 30 | 31 | n = int(data.feat_train.shape[0] * args.reduction_rate) 32 | 33 | 34 | d = data.feat_train.shape[1] 35 | self.nnodes_syn = n 36 | self.feat_syn = nn.Parameter(torch.FloatTensor(n, d).to(device)) 37 | 38 | 39 | self.IGNR = IGNR(node_feature=d, nfeat=128, nnodes=n, device=device, args=args).to(device) 40 | self.graphon = 1 41 | 42 | self.labels_syn = torch.LongTensor(self.generate_labels_syn(data)).to(device) 43 | 44 | self.reset_parameters() 45 | self.optimizer_feat = torch.optim.Adam([self.feat_syn], lr=args.lr_feat) 46 | 47 | 48 | self.optimizer_IGNR = torch.optim.Adam(self.IGNR.parameters(), lr=args.lr_adj) 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | print('adj_syn:', (n,n), 'feat_syn:', self.feat_syn.shape) 62 | 63 | def reset_parameters(self): 64 | self.feat_syn.data.copy_(torch.randn(self.feat_syn.size())) 65 | 66 | def generate_labels_syn(self, data): 67 | from collections import Counter 68 | counter = Counter(data.labels_train) 69 | num_class_dict = {} 70 | n = len(data.labels_train) 71 | 72 | sorted_counter = sorted(counter.items(), key=lambda x:x[1]) 73 | sum_ = 0 74 | labels_syn = [] 75 | self.syn_class_indices = {} 76 | for ix, (c, num) in enumerate(sorted_counter): 77 | if ix == len(sorted_counter) - 1: 78 | num_class_dict[c] = int(n * self.args.reduction_rate) - sum_ 79 | self.syn_class_indices[c] = [len(labels_syn), len(labels_syn) + num_class_dict[c]] 80 | labels_syn += [c] * num_class_dict[c] 81 | else: 82 | num_class_dict[c] = max(int(num * self.args.reduction_rate), 1) 83 | sum_ += num_class_dict[c] 84 | self.syn_class_indices[c] = [len(labels_syn), len(labels_syn) + num_class_dict[c]] 85 | labels_syn += [c] * num_class_dict[c] 86 | 87 | self.num_class_dict = num_class_dict 88 | return labels_syn 89 | 90 | 91 | def test_with_val(self, verbose=True): 92 | res = [] 93 | 94 | data, device = self.data, self.device 95 | feat_syn, IGNR, labels_syn = self.feat_syn.detach(), \ 96 | self.IGNR, self.labels_syn 97 | 98 | 99 | model = GCN(nfeat=feat_syn.shape[1], nhid=self.args.hidden, dropout=0.5, 100 | weight_decay=5e-4, nlayers=2, 101 | nclass=data.nclass, device=device).to(device) 102 | 103 |
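# Summary of the evaluation protocol below: a fresh GCN is trained from scratch on the
# condensed graph (features feat_syn, structure adj_syn = IGNR.inference(feat_syn), labels
# labels_syn) and then evaluated on the original graph's train set and test split.
# For ogbn-arxiv the model defined above is re-created with with_bn=False and zero weight decay.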
if self.args.dataset in ['ogbn-arxiv']: 104 | model = GCN(nfeat=feat_syn.shape[1], nhid=self.args.hidden, dropout=0.5, 105 | weight_decay=0e-4, nlayers=2, with_bn=False, 106 | nclass=data.nclass, device=device).to(device) 107 | 108 | adj_syn = IGNR.inference(feat_syn) 109 | args = self.args 110 | 111 | import os 112 | if not os.path.exists('saved_ours'): 113 | os.makedirs('saved_ours') 114 | if self.args.save: 115 | torch.save(adj_syn, f'saved_ours/adj_{args.dataset}_{args.reduction_rate}_{args.seed}.pt') 116 | torch.save(feat_syn, f'saved_ours/feat_{args.dataset}_{args.reduction_rate}_{args.seed}.pt') 117 | torch.save(labels_syn, f'saved_ours/label_{args.dataset}_{args.reduction_rate}_{args.seed}.pt') 118 | 119 | if self.args.lr_adj == 0: 120 | n = len(labels_syn) 121 | adj_syn = torch.zeros((n, n)) 122 | 123 | 124 | 125 | model.fit_with_val(feat_syn, adj_syn, labels_syn, data, 126 | train_iters=2000, normalize=False, verbose=True) 127 | 128 | model.eval() 129 | labels_test = torch.LongTensor(data.labels_test).cuda() 130 | 131 | labels_train = torch.LongTensor(data.labels_train).cuda() 132 | output = model.predict(data.feat_train, data.adj_train) 133 | loss_train = F.nll_loss(output, labels_train) 134 | acc_train = utils.accuracy(output, labels_train) 135 | if verbose: 136 | print("Train set results:", 137 | "loss= {:.4f}".format(loss_train.item()), 138 | "accuracy= {:.4f}".format(acc_train.item())) 139 | res.append(acc_train.item()) 140 | 141 | 142 | output = model.predict(data.feat_full, data.adj_full) 143 | loss_test = F.nll_loss(output[data.idx_test], labels_test) 144 | acc_test = utils.accuracy(output[data.idx_test], labels_test) 145 | res.append(acc_test.item()) 146 | if verbose: 147 | print("Test set results:", 148 | "loss= {:.4f}".format(loss_test.item()), 149 | "accuracy= {:.4f}".format(acc_test.item())) 150 | return res 151 | 152 | def train(self, verbose=True): 153 | args = self.args 154 | data = self.data 155 | feat_syn, IGNR, labels_syn = self.feat_syn, self.IGNR, self.labels_syn 156 | features, adj, labels = data.feat_full, data.adj_full, data.labels_full 157 | idx_train = data.idx_train 158 | 159 | syn_class_indices = self.syn_class_indices 160 | 161 | features, adj, labels = utils.to_tensor(features, adj, labels, device=self.device) 162 | 163 | feat_sub, adj_sub = self.get_sub_adj_feat(features) 164 | self.feat_syn.data.copy_(feat_sub) 165 | 166 | if utils.is_sparse_tensor(adj): 167 | adj_norm = utils.normalize_adj_tensor(adj, sparse=True) 168 | else: 169 | adj_norm = utils.normalize_adj_tensor(adj) 170 | 171 | adj = adj_norm 172 | adj = SparseTensor(row=adj._indices()[0], col=adj._indices()[1], 173 | value=adj._values(), sparse_sizes=adj.size()).t() 174 | 175 | 176 | outer_loop, inner_loop = get_loops(args) 177 | loss_avg = 0 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | for it in trange(args.epochs+1): 188 | if args.dataset in ['ogbn-arxiv']: 189 | model = SGC1(nfeat=feat_syn.shape[1], nhid=self.args.hidden, 190 | dropout=0.0, with_bn=False, 191 | weight_decay=0e-4, nlayers=2, 192 | nclass=data.nclass, 193 | device=self.device).to(self.device) 194 | else: 195 | if args.sgc == 1: 196 | model = SGC(nfeat=data.feat_train.shape[1], nhid=args.hidden, 197 | nclass=data.nclass, dropout=args.dropout, 198 | nlayers=args.nlayers, with_bn=False, 199 | device=self.device).to(self.device) 200 | else: 201 | model = GCN(nfeat=data.feat_train.shape[1], nhid=args.hidden, 202 | nclass=data.nclass, dropout=args.dropout, nlayers=args.nlayers, 203 | 
device=self.device).to(self.device) 204 | 205 | 206 | model.initialize() 207 | 208 | model_parameters = list(model.parameters()) 209 | 210 | optimizer_model = torch.optim.Adam(model_parameters, lr=args.lr_model) 211 | model.train() 212 | 213 | 214 | for ol in range(outer_loop): 215 | 216 | if adj.size(0) > 5000: 217 | random_nodes = np.random.choice(list(range(adj.size(0))), 5000, replace=False) 218 | adj_syn, opt_loss = IGNR(self.feat_syn, Lx=adj[random_nodes].to_dense()[:, random_nodes]) 219 | else: 220 | adj_syn, opt_loss = IGNR(self.feat_syn, Lx=adj) 221 | 222 | 223 | adj_syn_norm = utils.normalize_adj_tensor(adj_syn, sparse=False) 224 | feat_syn_norm = feat_syn 225 | 226 | 227 | BN_flag = False 228 | for module in model.modules(): 229 | if 'BatchNorm' in module._get_name(): 230 | BN_flag = True 231 | if BN_flag: 232 | model.train() 233 | output_real = model.forward(features, adj_norm) 234 | for module in model.modules(): 235 | if 'BatchNorm' in module._get_name(): 236 | module.eval() 237 | 238 | loss = torch.tensor(0.0).to(self.device) 239 | for c in range(data.nclass): 240 | batch_size, n_id, adjs = data.retrieve_class_sampler( 241 | c, adj, transductive=True, args=args) 242 | if args.nlayers == 1: 243 | adjs = [adjs] 244 | 245 | adjs = [adj.to(self.device) for adj in adjs] 246 | output = model.forward_sampler(features[n_id], adjs) 247 | loss_real = F.nll_loss(output, labels[n_id[:batch_size]]) 248 | 249 | gw_real = torch.autograd.grad(loss_real, model_parameters) 250 | gw_real = list((_.detach().clone() for _ in gw_real)) 251 | output_syn = model.forward(feat_syn, adj_syn_norm) 252 | 253 | ind = syn_class_indices[c] 254 | loss_syn = F.nll_loss( 255 | output_syn[ind[0]: ind[1]], 256 | labels_syn[ind[0]: ind[1]]) 257 | gw_syn = torch.autograd.grad(loss_syn, model_parameters, create_graph=True) 258 | coeff = self.num_class_dict[c] / max(self.num_class_dict.values()) 259 | loss += coeff * match_loss(gw_syn, gw_real, args, device=self.device) 260 | 261 | loss_avg += loss.item() 262 | 263 | if args.beta > 0: 264 | loss_reg = args.beta * regularization(adj_syn, utils.tensor2onehot(labels_syn).to(adj_syn.device)) 265 | else: 266 | loss_reg = torch.tensor(0) 267 | 268 | loss = loss + loss_reg 269 | 270 | 271 | self.optimizer_feat.zero_grad() 272 | self.optimizer_IGNR.zero_grad() 273 | if it % 50 < 10: 274 | 275 | 276 | loss = loss + args.opt_scale* opt_loss 277 | 278 | 279 | loss.backward() 280 | self.optimizer_IGNR.step() 281 | else: 282 | loss.backward() 283 | self.optimizer_feat.step() 284 | 285 | if args.debug and ol % 5 ==0: 286 | print('Gradient matching loss:', loss.item()) 287 | 288 | if ol == outer_loop - 1: 289 | 290 | 291 | break 292 | 293 | feat_syn_inner = feat_syn.detach() 294 | adj_syn_inner = IGNR.inference(feat_syn_inner) 295 | adj_syn_inner_norm = utils.normalize_adj_tensor(adj_syn_inner, sparse=False) 296 | feat_syn_inner_norm = feat_syn_inner 297 | for j in range(inner_loop): 298 | optimizer_model.zero_grad() 299 | output_syn_inner = model.forward(feat_syn_inner_norm, adj_syn_inner_norm) 300 | loss_syn_inner = F.nll_loss(output_syn_inner, labels_syn) 301 | loss_syn_inner.backward() 302 | 303 | optimizer_model.step() 304 | 305 | 306 | loss_avg /= (data.nclass*outer_loop) 307 | if it % 50 == 0: 308 | print('Epoch {}, loss_avg: {}'.format(it, loss_avg)) 309 | 310 | 311 | 312 | eval_epochs = list(range(0, 5000, 50)) 313 | if verbose and it in eval_epochs: 314 | 315 | res = [] 316 | runs = 1 if args.dataset in ['ogbn-arxiv'] else 3 317 | for i in range(runs): 318 | if 
args.dataset in ['ogbn-arxiv']: 319 | res.append(self.test_with_val()) 320 | else: 321 | res.append(self.test_with_val()) 322 | 323 | res = np.array(res) 324 | print('Train/Test Mean Accuracy:', 325 | repr([res.mean(0), res.std(0)])) 326 | 327 | def get_sub_adj_feat(self, features): 328 | data = self.data 329 | args = self.args 330 | idx_selected = [] 331 | 332 | from collections import Counter; 333 | counter = Counter(self.labels_syn.cpu().numpy()) 334 | 335 | for c in range(data.nclass): 336 | tmp = data.retrieve_class(c, num=counter[c]) 337 | tmp = list(tmp) 338 | idx_selected = idx_selected + tmp 339 | idx_selected = np.array(idx_selected).reshape(-1) 340 | features = features[self.data.idx_train][idx_selected] 341 | 342 | 343 | from sklearn.metrics.pairwise import cosine_similarity 344 | 345 | k = 2 346 | sims = cosine_similarity(features.cpu().numpy()) 347 | sims[(np.arange(len(sims)), np.arange(len(sims)))] = 0 348 | for i in range(len(sims)): 349 | indices_argsort = np.argsort(sims[i]) 350 | sims[i, indices_argsort[: -k]] = 0 351 | adj_knn = torch.FloatTensor(sims).to(self.device) 352 | return features, adj_knn 353 | 354 | 355 | def get_loops(args): 356 | if args.one_step: 357 | if args.dataset =='ogbn-arxiv': 358 | return 5, 0 359 | return 1, 0 360 | if args.dataset in ['ogbn-arxiv']: 361 | return args.outer, args.inner 362 | if args.dataset in ['cora']: 363 | return 20, 15 364 | if args.dataset in ['citeseer']: 365 | return 20, 15 366 | if args.dataset in ['physics']: 367 | return 20, 10 368 | else: 369 | return 20, 10 370 | 371 | -------------------------------------------------------------------------------- /configs.py: -------------------------------------------------------------------------------- 1 | '''Configuration''' 2 | 3 | def load_config(args): 4 | dataset = args.dataset 5 | if dataset in ['flickr']: 6 | args.nlayers = 2 7 | args.hidden = 256 8 | args.weight_decay = 5e-3 9 | args.dropout = 0.0 10 | 11 | if dataset in ['reddit']: 12 | args.nlayers = 2 13 | args.hidden = 256 14 | args.weight_decay = 0e-4 15 | args.dropout = 0 16 | 17 | if dataset in ['ogbn-arxiv']: 18 | args.hidden = 256 19 | args.weight_decay = 0 20 | args.dropout = 0 21 | 22 | return args 23 | 24 | 25 | -------------------------------------------------------------------------------- /graph.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | Class for core routine in coordinated optimal transport, can be applied to graph sketching, graph comparison, etc. 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | import torch_sparse 9 | import utils_copt as utils 10 | from torch.nn.parameter import Parameter 11 | import sys 12 | import time 13 | 14 | import pdb 15 | 16 | device = utils.device 17 | 18 | class GraphDist(nn.Module): 19 | # def __init__(self, Lx, m, n, args, Ly=None, take_ly_exp=True): 20 | def __init__(self, Lx, m, n, args, Ly=None, take_ly_exp=False): 21 | """ 22 | Input: Lx and Ly are graph Laplacians, Lx being the input, origin graph, 23 | Ly the Laplacian of the target graph to be compared with. 24 | If sketching graph X, Ly should be None. 25 | Lx and Ly are *log* of upper triangular part of Laplacian of X and Y. 
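P, initialized below as an m x n matrix, is the coupling (transport plan) between the nodes of X and Y; it is rescaled by a few Sinkhorn iterations so that its row and column marginals are balanced before being wrapped as a learnable Parameter.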
26 | """ 27 | super(GraphDist, self).__init__() 28 | 29 | #P is m x n prob matrix, m is number of nodes in X, n is # of nodes in Y 30 | 31 | if args.fix_seed: 32 | torch.manual_seed(0) 33 | self.P = torch.empty(m, n).uniform_(1,2) 34 | 35 | #scale with sinkhorn iterations 36 | for _ in range(args.sinkhorn_iter): 37 | self.P /= (self.P.sum(1, keepdim=True)/n) #*m 38 | self.P /= (self.P.sum(0, keepdim=True)/m) #*n 39 | self.P = Parameter(self.P) 40 | #n x n mx, symmetric, off diag < 0, row & col sums are 0 41 | self.Lx_inv = mx_inv(Lx) 42 | 43 | #upper triangular part of Ly 44 | if Ly is None: 45 | if args.fix_seed: 46 | torch.manual_seed(0) 47 | 48 | self.Ly = Parameter(torch.randn(n*(n-1)//2)) 49 | self.optim = optim.Adam([self.P, self.Ly], lr=.4) 50 | self.fix_ly = False 51 | else: 52 | Ly = Ly.clone() 53 | if take_ly_exp: 54 | assert len(Ly.shape) == 1 55 | self.Ly = realize_upper(Ly, args.n, take_ly_exp=take_ly_exp) #Parameter(Ly) 56 | else: 57 | self.Ly = Ly 58 | self.Ly = self.Ly.to(device=device) 59 | self.Ly_inv_rt, self.Ly_inv = mx_inv_sqrt(self.Ly) 60 | 61 | self.optim = optim.Adam([self.P], lr=.35) #.4 62 | self.fix_ly = True 63 | 64 | #this is replaced by built-in scheduler in forward pass 65 | #milestones = [100*i for i in range(1, 4)] #[100, 200, 300] 66 | #self.scheduler = optim.lr_scheduler.MultiStepLR(self.optim, milestones=milestones, gamma=0.51) 67 | 68 | self.Lx = Lx 69 | self.m, self.n = m, n 70 | 71 | self.args = args 72 | self.y_labels = None 73 | 74 | def compute_graph_dist(self): 75 | Lx = self.Lx 76 | loss0 = sys.maxsize 77 | delta_l = [] 78 | 79 | for i in range(self.args.n_epochs): 80 | loss, P, Ly = self.ot_dist(Lx, epoch=i) 81 | self.optim.zero_grad() 82 | loss.backward() 83 | 84 | if self.args.verbose and i % 20 == 0: 85 | cur_lr = self.optim.param_groups[0]['lr'] 86 | sys.stdout.write('{} lr {}'.format(str(loss.cpu().item()), str(cur_lr)) + ', ') 87 | #torch.nn.utils.clip_grad_norm_(self.parameters(), 5.) 88 | self.optim.step() 89 | 90 | if (i % 101 == 0 and i < 2000): 91 | cur_lr = -1 92 | 93 | for p_group in self.optim.param_groups: 94 | p_group['lr'] *= .7 95 | cur_lr = p_group['lr'] 96 | print_lr = False #True 97 | if print_lr: 98 | print(i, cur_lr, i, loss.item()) 99 | if self.args.lr_hike and loss0 - loss < .002 and i < self.args.n_epochs-200: 100 | delta_l.append(loss) 101 | 102 | if len(delta_l) > self.args.hike_interval: 103 | # can enable lr hiking to help escape local minima, described in supplement. 104 | cur_lr = -1 105 | for p_group in self.optim.param_groups: 106 | p_group['lr'] = min(p_group['lr']*5, 4) 107 | cur_lr = p_group['lr'] 108 | 109 | #print('hike ', i) 110 | #print(i, p_group['lr'], i, loss.item()) 111 | delta_l = [] 112 | if self.args.early_stopping: 113 | break #can break here for early stopping 114 | 115 | loss0 = loss 116 | 117 | ''' 118 | avg_dur = total_dur / total_iter 119 | print('Avg iter timing {}'.format(avg_dur )) 120 | with open('time.txt', 'a') as f: 121 | f.write('{}\n'.format(avg_dur)) 122 | ''' 123 | P = P.clone().detach() 124 | if self.args.plot: 125 | labels = torch.topk(P, k=2, dim=0)[1].t().cpu().numpy() 126 | #self.y_labels = {k:str(label[0])+' '+str(label[1])+' '+str(label[2]) for k, label in enumerate(labels)} 127 | self.y_labels = {k:str(label[0])+' '+str(label[1]) for k, label in enumerate(labels)} 128 | 129 | return loss, P, Ly.clone().detach() 130 | 131 | def ot_dist(self, Lx, epoch=0): 132 | """ 133 | Distance between two graphs. Evolve Ly and P simultaneouly 134 | 1/|X| tr(L_X). 
self.Ly is the log of the actual laplacian. 135 | """ 136 | if not self.fix_ly: 137 | ones = torch.ones(self.n, self.n, dtype=torch.uint8, device=device) 138 | Ly = torch.zeros(self.n, self.n, device=device) 139 | Ly[torch.triu(ones, diagonal=1)] = -self.Ly**2 140 | ''' 141 | #can also use Huber function to enforce positivity. 142 | Ly_val = torch.abs(self.Ly.clone()) 143 | Ly_val[Ly_val < 1] = -Ly_val[Ly_val < 1]**2/2 144 | Ly_val[torch.abs(self.Ly) >= 1] = -(torch.abs(self.Ly[torch.abs(self.Ly) >= 1])-.5) 145 | Ly[torch.triu(ones, diagonal=1)] = Ly_val 146 | ''' 147 | #Ly[torch.tril(ones, diagonal=-1)] = Ly[torch.triu(ones, diagonal=1)].t() 148 | #ensure laplacian 149 | Ly += Ly.clone().t() 150 | Ly[torch.eye(self.n, dtype=torch.uint8, device=device)] = -Ly.sum(0) 151 | 152 | Ly_inv_rt, Ly_inv = mx_inv_sqrt(Ly) 153 | #regularization 154 | #Ly += .1*torch.eye(self.n) 155 | else: 156 | Ly = self.Ly 157 | Ly_inv_rt, Ly_inv = self.Ly_inv_rt, self.Ly_inv 158 | 159 | ''' 160 | Ly = utils.symmetrize(Ly, inplace=False) 161 | Ly_diag = Ly.diag() 162 | #Ly[(1-torch.eye(n)) > 0] = min(-Ly[(1-torch.eye(n)) > 0], Ly[(1-torch.eye(n)) > 0]) #off diag terms neg 163 | #Ly = torch.min(Ly, -Ly) 164 | Ly *= -1 165 | Ly[torch.eye(self.n) > 0] = Ly_diag 166 | ''' 167 | #Ly[torch.eye(self.n) > 0] = 0 168 | 169 | P = self.P.abs() 170 | for _ in range(self.args.sinkhorn_iter): 171 | P = P / (P.sum(1, keepdim=True)/self.n) #*m 172 | P = P / (P.sum(0, keepdim=True)/self.m) #*n 173 | 174 | #approximate Ly^{-1/2} 175 | #when transport plan becomes uniform, mixed term goes to 0 176 | use_symeig = True 177 | if use_symeig: 178 | sqrt = torch.symeig(Ly_inv_rt @ P.t() @ self.Lx_inv @ P @ Ly_inv_rt, eigenvectors=True ) 179 | #.clamp(min=0) here due to inconsistency between pytorch svd and symeig, ie svd gives PSD but symeig gives neg eval. 
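# The uncommented assignment to self.w_dist below evaluates the coupled (COPT-style) graph distance
#   tr(Lx^+) * n + tr(Ly^+) * m - 2 * sum_i sqrt(lambda_i),
# where lambda_i are the eigenvalues of Ly^{+1/2} P^T Lx^+ P Ly^{+1/2};
# clamp(min=2e-20) keeps the sqrt well-defined when symeig returns tiny negative eigenvalues.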
180 | ##self.w_dist = mx_tr(mx_inv(Lx))/self.m + mx_tr(mx_inv(Ly))/self.n - 2*torch.sqrt(sqrt[0].clamp(min=0)).sum() # 181 | self.w_dist = mx_tr(self.Lx_inv)*self.n + mx_tr(Ly_inv)*self.m - 2*torch.sqrt(sqrt[0].clamp(min=2e-20)).sum() # 182 | else: 183 | sqrt = Ly_inv_rt @ P.t() @ self.Lx_inv @ P @ Ly_inv_rt 184 | self.w_dist = mx_tr(self.Lx_inv)**2*self.n**2 + mx_tr(Ly_inv)**2*self.m**2 - 2*mx_tr(sqrt) 185 | loss = self.w_dist 186 | 187 | #conditions on Ly, symmetric, off diag non-positive, row & col sums 0 188 | #row and col sums, can force to be 0 189 | ##ly_loss = torch.abs(Ly.sum(dim=1)).sum() + torch.abs(Ly.sum(dim=0)).sum() 190 | #Ly_off_diag = Ly[1-torch.eye(self.n) > 0] 191 | #off diag should be non-positive 192 | #ly_loss += Ly_off_diag[Ly_off_diag > 0].sum() 193 | #ly_loss += torch.abs(torch.triu(Ly, diagonal=1) - torch.tril(Ly, diagonal=-1).t()).sum() #### #symmetry 194 | 195 | return loss.clamp(min=0), P, Ly 196 | 197 | def realize_upper(upper, sz, take_ly_exp=False): 198 | ones = torch.ones(sz, sz, dtype=torch.uint8) 199 | Ly = torch.zeros(sz, sz) 200 | if take_ly_exp: 201 | Ly[torch.triu(ones, diagonal=1)] = -torch.exp(upper) 202 | else: 203 | Ly[torch.triu(ones, diagonal=1)] = upper 204 | Ly += Ly.t() 205 | Ly[torch.eye(sz, dtype=torch.uint8)] = -Ly.sum(0) 206 | return Ly 207 | 208 | def mx_inv(mx): 209 | if isinstance(mx, torch_sparse.tensor.SparseTensor): 210 | mx = mx.to_dense() 211 | U, D, V = torch.svd(mx) 212 | eps = 0.009 213 | D_min = torch.min(D) 214 | if D_min < eps: 215 | D_1 = torch.zeros_like(D) 216 | D_1[D>D_min] = 1/D[D>D_min] 217 | else: 218 | D_1 = 1/D 219 | #D_1 = 1 / D #.clamp(min=0.005) 220 | 221 | return U @ D_1.diag() @ V.t() 222 | 223 | def mx_inv_sqrt(mx): 224 | # singular values need to be distinct for backprop 225 | U, D, V = torch.svd(mx) 226 | D_min = torch.min(D) 227 | eps = 0.009 228 | if D_min < eps: 229 | D_1 = torch.zeros_like(D) 230 | D_1[D>D_min] = 1/D[D>D_min] #.sqrt() 231 | else: 232 | D_1 = 1/D #.sqrt() 233 | #D_1 = 1 / D.clamp(min=0.005).sqrt() 234 | return U @ D_1.sqrt().diag() @ V.t(), U @ D_1.diag() @ V.t() 235 | 236 | def mx_tr(mx): 237 | return mx.diag().sum() 238 | 239 | def mx_svd(mx, topk): 240 | U, D, V = torch.svd(mx) 241 | #topk x m 242 | evecs = U.t()[-topk-1:-1] 243 | row1 = evecs[-1, :] 244 | idx = torch.argsort(row1) 245 | evecs = evecs[:, idx] 246 | return evecs 247 | 248 | def graph_dist(args, plot=True, Ly=None, take_ly_exp=True): 249 | args.Lx = args.Lx.to(device) 250 | args.plot = plot 251 | model = GraphDist(args.Lx, args.m, args.n, args, Ly=Ly, take_ly_exp=take_ly_exp) 252 | model = model.to(device) 253 | loss, P, Ly = model.compute_graph_dist() 254 | Ly = Ly.cpu() 255 | #view graphs 256 | if plot: 257 | utils.view_graph(args.Lx, soft_edge=True, name='x') 258 | #Ly = realize_upper(model.Ly.detach(), args.n) 259 | utils.view_graph(Ly, soft_edge=True, name='y', labels=model.y_labels) 260 | pdb.set_trace() 261 | return loss, P, Ly 262 | 263 | def visualize_graph(graph_type, args): 264 | g = utils.create_graph(args.m, graph_type) 265 | args.Lx = utils.graph_to_lap(g).to(device) 266 | args.m = len(args.Lx) 267 | 268 | model = GraphDist(args.Lx, args.m, args.n, args) 269 | model = model.to(device) 270 | loss, P, Ly = model.compute_graph_dist() 271 | Ly = Ly.cpu() 272 | #view graphs 273 | 274 | utils.view_graph(args.Lx, soft_edge=True, name='{}x'.format(graph_type)) 275 | #Ly = realize_upper(model.Ly.detach(), args.n) 276 | utils.view_graph(Ly, soft_edge=True, name='{}y'.format(graph_type), labels=model.y_labels) 277 | 
pdb.set_trace() 278 | return loss, P, Ly 279 | 280 | if __name__ == '__main__': 281 | args = utils.parse_args() 282 | args.m = 11 #5 6 5 12 (barbell) 283 | args.n = 9 #5 3 6 6 284 | if args.fix_seed: 285 | torch.manual_seed(0) 286 | test = torch.load('test.pt') 287 | args.Lx = test['q5'] 288 | args.Ly = test['data19'] 289 | args.m = len(args.Lx) 290 | args.n = len(args.Ly) 291 | graph_dist(args, plot=False, Ly=args.Ly, take_ly_exp=False) 292 | pdb.set_trace() 293 | args.Lx = torch.randn(args.m*(args.m-1)//2) #torch.FloatTensor([[1, -1], [-1, 2]]) 294 | args.Lx = realize_upper(args.Lx, args.m) 295 | args.n_epochs = 300 #1200 #100 296 | args.plot = False 297 | if False:#True: #False: #True: 298 | #graphs, labels = utils.load_data('data/graphs50.pkl') 299 | #args.Lx = utils.graph_to_lap(graphs[2]) 300 | params = {'n_blocks':2} 301 | g2 = utils.create_graph(1000, gtype='block', params=params) 302 | args.Lx = utils.graph_to_lap(g2) 303 | args.m = len(args.Lx) 304 | ''' 305 | params = {'n_blocks':3} 306 | g3 = utils.create_graph(50, gtype='block', params=params) 307 | args.Ly = utils.graph_to_lap(g3) 308 | args.Ly = args.Lx.clone() 309 | args.n = len(args.Ly) 310 | args.n = 15 311 | ''' 312 | args.n = 200 313 | ##graph_dist(args, Ly=args.Ly, plot=False, take_ly_exp=False) 314 | #50 -> 7 => loss 12.3 315 | 316 | #pdb.set_trace() 317 | check_graph = False 318 | if check_graph: 319 | try: 320 | graphs, labels = utils.load_data('./data/graphs50.pkl') 321 | except FileNotFoundError: 322 | graphs, labels = utils.load_data('./copt/data/graphs.pkl') 323 | args.Lx = utils.graph_to_lap(graphs[2]) 324 | ##g = utils.create_graph(12) 325 | ##args.Lx = utils.graph_to_lap(g) 326 | #args.Lx[torch.eye(len(args.Lx), dtype=torch.uint8)] = args.Lx.sum(0) 327 | args.m = args.n = len(args.Lx) #for 30 to 10 the graphs can get to 9 loss for 280 epochs, .4 lr with schedule 328 | #50 -> 7 => loss 12.3 329 | args.n = 7 330 | vis = True #False #True 331 | if vis: 332 | graph_type = 'wheel' #'lollipop' #'barbell' # # #'wheel' #'cycle' #'ladder' #'hypercube' #'pappus' #'grid' #'hypercube'#'grid'#'ladder' #'barbell' 333 | visualize_graph(graph_type, args) 334 | 335 | pdb.set_trace() 336 | #args.Lx = torch.exp(torch.FloatTensor([[2, -2], [-2, 1]])) #good initializations?! 
checks & stability 337 | print('args ', args) 338 | #graph_dist(args, plot=False, Ly=args.Lx, take_ly_exp=False) 339 | graph_dist(args, plot=False, take_ly_exp=False) 340 | 341 | 342 | -------------------------------------------------------------------------------- /images/yang2023does.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RingBDStack/SGDD/5ac77c1dc4e725e468b41f8f09a6c530adc3af58/images/yang2023does.png -------------------------------------------------------------------------------- /models/IGNR.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | from itertools import product 6 | from graph import mx_inv, mx_inv_sqrt, mx_tr 7 | 8 | def get_mgrid(sidelen, dim=2): 9 | if isinstance(sidelen, int): 10 | sidelen = dim * (sidelen,) 11 | 12 | if dim == 2: 13 | pixel_coords = np.stack(np.mgrid[:sidelen[0], :sidelen[1]], axis=-1)[None, ...].astype(np.float32) 14 | pixel_coords[0, :, :, 0] = pixel_coords[0, :, :, 0] / (sidelen[0] - 1) 15 | pixel_coords[0, :, :, 1] = pixel_coords[0, :, :, 1] / (sidelen[1] - 1) 16 | elif dim == 3: 17 | pixel_coords = np.stack(np.mgrid[:sidelen[0], :sidelen[1], :sidelen[2]], axis=-1)[None, ...].astype(np.float32) 18 | pixel_coords[..., 0] = pixel_coords[..., 0] / max(sidelen[0] - 1, 1) 19 | pixel_coords[..., 1] = pixel_coords[..., 1] / (sidelen[1] - 1) 20 | pixel_coords[..., 2] = pixel_coords[..., 2] / (sidelen[2] - 1) 21 | else: 22 | raise NotImplementedError('Not implemented for dim=%d' % dim) 23 | 24 | pixel_coords -= 0.5 25 | pixel_coords *= 2. 26 | pixel_coords = torch.Tensor(pixel_coords).view(-1, dim) 27 | return pixel_coords.to(torch.float32) 28 | 29 | class Sine(nn.Module): 30 | def __init__(self): 31 | super().__init__() 32 | 33 | def forward(self, input): 34 | return torch.sin(30 * input) 35 | 36 | class EdgeBlock(nn.Module): 37 | def __init__(self, in_, out_, dtype=torch.float32) -> None: 38 | super(EdgeBlock, self).__init__() 39 | self.net = nn.Sequential( 40 | nn.Linear(in_, out_, dtype=dtype), 41 | nn.BatchNorm1d(out_, dtype=dtype), 42 | nn.ReLU()) 43 | 44 | def forward(self, x): 45 | return self.net(x) 46 | 47 | 48 | class GraphonLearner(nn.Module): 49 | def __init__(self, node_feature, nfeat=256, nnodes=50, device="cuda", args={}, num_hidden_layers=3, **kwargs): 50 | super().__init__() 51 | self.num_hidden_layers = num_hidden_layers 52 | self.step_size = nnodes 53 | self.ep_ratio = args.ep_ratio 54 | self.sinkhorn_iter = args.sinkhorn_iter 55 | self.mx_size = args.mx_size 56 | 57 | self.edge_index = np.array(list(product(range(self.step_size), range(self.step_size)))).T 58 | 59 | self.net0 = nn.ModuleList([ 60 | EdgeBlock(node_feature*2, nfeat), 61 | EdgeBlock(nfeat, nfeat), 62 | nn.Linear(nfeat, 1, dtype=torch.float32) 63 | ]) 64 | 65 | self.net1 = nn.ModuleList([ 66 | EdgeBlock(2, nfeat), 67 | EdgeBlock(nfeat, nfeat), 68 | nn.Linear(nfeat, 1, dtype=torch.float32) 69 | ]) 70 | 71 | self.P = nn.Parameter(torch.Tensor(self.mx_size, self.step_size).to(torch.float32).uniform_(0, 1)) # transport plan 72 | self.Lx_inv = None 73 | 74 | self.output = nn.Linear(nfeat, 1) 75 | self.act = F.relu 76 | self.device = device 77 | self.reset_parameters() 78 | 79 | def reset_parameters(self): 80 | def weight_reset(m): 81 | if isinstance(m, nn.Linear): 82 | m.reset_parameters() 83 | if isinstance(m, nn.BatchNorm1d): 84 | m.reset_parameters() 85 | 
self.apply(weight_reset) 86 | 87 | def forward(self, c, inference=False, Lx=None): 88 | if inference == True: 89 | self.eval() 90 | else: 91 | self.train() 92 | x0 = get_mgrid(c.shape[0]).to(self.device) 93 | c = torch.cat([c[self.edge_index[0]], 94 | c[self.edge_index[1]]], axis=1) 95 | for layer in range(len(self.net0)): 96 | c = self.net0[layer](c) 97 | if layer == 0: 98 | x = self.net1[layer](x0) 99 | else: 100 | x = self.net1[layer](x) 101 | 102 | if layer != (len(self.net0) - 1): 103 | # use node feature to guide the graphon generating process 104 | x = x*c 105 | else: 106 | x = (1 - self.ep_ratio) * x + self.ep_ratio * c 107 | 108 | # x = self.output(x) 109 | # adj = self.output(x).reshape(self.step_size, self.step_size) 110 | adj = x.reshape(self.step_size, self.step_size) 111 | 112 | adj = (adj + adj.T)/2 113 | adj = torch.sigmoid(adj) 114 | adj = adj - torch.diag(torch.diag(adj, 0)) 115 | 116 | if inference == True: 117 | return adj 118 | if Lx is not None and self.Lx_inv is None: 119 | self.Lx_inv = mx_inv(Lx) 120 | try: 121 | opt_loss = self.opt_loss(adj) 122 | except: 123 | opt_loss = torch.tensor(0).to(self.device) 124 | return adj, opt_loss 125 | 126 | 127 | def opt_loss(self, adj): 128 | Ly_inv_rt, Ly_inv = mx_inv_sqrt(adj) 129 | m = self.step_size 130 | P = self.P.abs() 131 | 132 | for _ in range(self.sinkhorn_iter): 133 | P = P / P.sum(dim=1, keepdim=True) 134 | P = P / P.sum(dim=0, keepdim=True) 135 | 136 | # if self.args.use_symeig: 137 | sqrt = torch.symeig(Ly_inv_rt @ self.P.t() @ self.Lx_inv @ self.P @ Ly_inv_rt, eigenvectors=True) 138 | loss = torch.abs(mx_tr(Ly_inv)*m - 2*torch.sqrt(sqrt[0].clamp(min=2e-20)).sum()) 139 | return loss 140 | 141 | @torch.no_grad() 142 | def inference(self, c): 143 | return self.forward(c, inference=True) 144 | -------------------------------------------------------------------------------- /models/gat.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import math 4 | import torch 5 | import torch.optim as optim 6 | from torch.nn.parameter import Parameter 7 | from torch.nn.modules.module import Module 8 | from deeprobust.graph import utils 9 | from copy import deepcopy 10 | from torch_geometric.nn import SGConv 11 | from torch_geometric.nn import APPNP as ModuleAPPNP 12 | # from torch_geometric.nn import GATConv 13 | from .mygatconv import GATConv 14 | import numpy as np 15 | import scipy.sparse as sp 16 | 17 | from torch.nn import Linear 18 | from itertools import repeat 19 | 20 | 21 | class GAT(torch.nn.Module): 22 | 23 | def __init__(self, nfeat, nhid, nclass, heads=8, output_heads=1, dropout=0.5, lr=0.01, 24 | weight_decay=5e-4, with_bias=True, device=None, **kwargs): 25 | 26 | super(GAT, self).__init__() 27 | 28 | assert device is not None, "Please specify 'device'!" 
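# Dataset-specific dropout and weight-decay overrides are applied in the "'dataset' in kwargs"
# block below; datasets not listed there fall back to dropout=0.7.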
29 | self.device = device 30 | self.dropout = dropout 31 | self.lr = lr 32 | self.weight_decay = weight_decay 33 | 34 | if 'dataset' in kwargs: 35 | if kwargs['dataset'] in ['ogbn-arxiv']: 36 | dropout = 0.7 # arxiv 37 | elif kwargs['dataset'] in ['reddit']: 38 | dropout = 0.05; self.dropout = 0.1; self.weight_decay = 5e-4 39 | # self.weight_decay = 5e-2; dropout=0.05; self.dropout=0.1 40 | elif kwargs['dataset'] in ['citeseer']: 41 | dropout = 0.7 42 | self.weight_decay = 5e-4 43 | elif kwargs['dataset'] in ['flickr']: 44 | dropout = 0.8 45 | # nhid=8; heads=8 46 | # self.dropout=0.1 47 | else: 48 | dropout = 0.7 # cora, citeseer, reddit 49 | else: 50 | dropout = 0.7 51 | self.conv1 = GATConv( 52 | nfeat, 53 | nhid, 54 | heads=heads, 55 | dropout=dropout, 56 | bias=with_bias) 57 | 58 | self.conv2 = GATConv( 59 | nhid * heads, 60 | nclass, 61 | heads=output_heads, 62 | concat=False, 63 | dropout=dropout, 64 | bias=with_bias) 65 | 66 | self.output = None 67 | self.best_model = None 68 | self.best_output = None 69 | 70 | # def forward(self, data): 71 | # x, edge_index = data.x, data.edge_index 72 | # x = F.dropout(x, p=self.dropout, training=self.training) 73 | # x = F.elu(self.conv1(x, edge_index)) 74 | # x = F.dropout(x, p=self.dropout, training=self.training) 75 | # x = self.conv2(x, edge_index) 76 | # return F.log_softmax(x, dim=1) 77 | 78 | def forward(self, data): 79 | # x, edge_index = data.x, data.edge_index 80 | x, edge_index, edge_weight = data.x, data.edge_index, data.edge_weight 81 | x = F.dropout(x, p=self.dropout, training=self.training) 82 | x = F.elu(self.conv1(x, edge_index, edge_weight=edge_weight)) 83 | # print(self.conv1.att_l.sum()) 84 | x = F.dropout(x, p=self.dropout, training=self.training) 85 | x = self.conv2(x, edge_index, edge_weight=edge_weight) 86 | return F.log_softmax(x, dim=1) 87 | 88 | 89 | def initialize(self): 90 | self.conv1.reset_parameters() 91 | self.conv2.reset_parameters() 92 | 93 | 94 | def fit(self, feat, adj, labels, idx, data=None, train_iters=600, initialize=True, verbose=False, patience=None, noval=False, **kwargs): 95 | 96 | data_train = GraphData(feat, adj, labels) 97 | data_train = Dpr2Pyg(data_train)[0] 98 | 99 | data_test = Dpr2Pyg(GraphData(data.feat_test, data.adj_test, None))[0] 100 | 101 | if noval: 102 | data_val = GraphData(data.feat_val, data.adj_val, None) 103 | data_val = Dpr2Pyg(data_val)[0] 104 | else: 105 | data_val = GraphData(data.feat_full, data.adj_full, None) 106 | data_val = Dpr2Pyg(data_val)[0] 107 | 108 | labels_val = torch.LongTensor(data.labels_val).to(self.device) 109 | 110 | if initialize: 111 | self.initialize() 112 | 113 | if len(data_train.y.shape) > 1: 114 | self.multi_label = True 115 | self.loss = torch.nn.BCELoss() 116 | else: 117 | self.multi_label = False 118 | self.loss = F.nll_loss 119 | 120 | 121 | data_train.y = data_train.y.float() if self.multi_label else data_train.y 122 | # data_val.y = data_val.y.float() if self.multi_label else data_val.y 123 | 124 | if verbose: 125 | print('=== training gat model ===') 126 | 127 | optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay) 128 | best_acc_val = 0 129 | best_loss_val = 100 130 | for i in range(train_iters): 131 | # if i == train_iters // 2: 132 | if i in [1500]: 133 | lr = self.lr*0.1 134 | optimizer = optim.Adam(self.parameters(), lr=lr, weight_decay=self.weight_decay) 135 | 136 | self.train() 137 | optimizer.zero_grad() 138 | output = self.forward(data_train) 139 | loss_train = self.loss(output, data_train.y) 140 | 
loss_train.backward() 141 | optimizer.step() 142 | 143 | with torch.no_grad(): 144 | self.eval() 145 | 146 | output = self.forward(data_val) 147 | if noval: 148 | loss_val = F.nll_loss(output, labels_val) 149 | acc_val = utils.accuracy(output, labels_val) 150 | else: 151 | loss_val = F.nll_loss(output[data.idx_val], labels_val) 152 | acc_val = utils.accuracy(output[data.idx_val], labels_val) 153 | 154 | 155 | if loss_val < best_loss_val: 156 | best_loss_val = loss_val 157 | self.output = output 158 | weights = deepcopy(self.state_dict()) 159 | 160 | if acc_val > best_acc_val: 161 | best_acc_val = acc_val 162 | self.output = output 163 | weights = deepcopy(self.state_dict()) 164 | # print(best_acc_val) 165 | # output = self.forward(data_test) 166 | # labels_test = torch.LongTensor(data.labels_test).to(self.device) 167 | # loss_test = F.nll_loss(output, labels_test) 168 | # acc_test = utils.accuracy(output, labels_test) 169 | # print('acc_test:', acc_test.item()) 170 | 171 | 172 | 173 | if verbose and i % 100 == 0: 174 | print('Epoch {}, training loss: {}'.format(i, loss_train.item())) 175 | 176 | if verbose: 177 | print('=== picking the best model according to the performance on validation ===') 178 | self.load_state_dict(weights) 179 | 180 | def test(self, data_test): 181 | self.eval() 182 | with torch.no_grad(): 183 | output = self.forward(data_test) 184 | evaluate(output, data_test.y, self.args) 185 | 186 | # @torch.no_grad() 187 | # def predict(self, data): 188 | # self.eval() 189 | # return self.forward(data) 190 | @torch.no_grad() 191 | def predict(self, feat, adj): 192 | self.eval() 193 | data = GraphData(feat, adj, None) 194 | data = Dpr2Pyg(data)[0] 195 | return self.forward(data) 196 | 197 | @torch.no_grad() 198 | def predict_unnorm(self, feat, adj): 199 | self.eval() 200 | data = GraphData(feat, adj, None) 201 | data = Dpr2Pyg(data)[0] 202 | 203 | return self.forward(data) 204 | 205 | 206 | class GraphData: 207 | 208 | def __init__(self, features, adj, labels, idx_train=None, idx_val=None, idx_test=None): 209 | self.adj = adj 210 | self.features = features 211 | self.labels = labels 212 | self.idx_train = idx_train 213 | self.idx_val = idx_val 214 | self.idx_test = idx_test 215 | 216 | 217 | from torch_geometric.data import InMemoryDataset, Data 218 | import scipy.sparse as sp 219 | 220 | class Dpr2Pyg(InMemoryDataset): 221 | 222 | def __init__(self, dpr_data, transform=None, **kwargs): 223 | root = 'data/' # dummy root; does not mean anything 224 | self.dpr_data = dpr_data 225 | super(Dpr2Pyg, self).__init__(root, transform) 226 | pyg_data = self.process() 227 | self.data, self.slices = self.collate([pyg_data]) 228 | self.transform = transform 229 | 230 | def process____(self): 231 | dpr_data = self.dpr_data 232 | try: 233 | edge_index = torch.LongTensor(dpr_data.adj.nonzero().cpu()).cuda().T 234 | except: 235 | edge_index = torch.LongTensor(dpr_data.adj.nonzero()).cuda() 236 | # by default, the features in pyg data is dense 237 | try: 238 | x = torch.FloatTensor(dpr_data.features.cpu()).float().cuda() 239 | except: 240 | x = torch.FloatTensor(dpr_data.features).float().cuda() 241 | try: 242 | y = torch.LongTensor(dpr_data.labels.cpu()).cuda() 243 | except: 244 | y = dpr_data.labels 245 | 246 | data = Data(x=x, edge_index=edge_index, y=y) 247 | data.train_mask = None 248 | data.val_mask = None 249 | data.test_mask = None 250 | return data 251 | 252 | def process(self): 253 | dpr_data = self.dpr_data 254 | if type(dpr_data.adj) == torch.Tensor: 255 | adj_selfloop = dpr_data.adj + 
torch.eye(dpr_data.adj.shape[0]).cuda() 256 | edge_index_selfloop = adj_selfloop.nonzero().T 257 | edge_index = edge_index_selfloop 258 | edge_weight = adj_selfloop[edge_index_selfloop[0], edge_index_selfloop[1]] 259 | else: 260 | adj_selfloop = dpr_data.adj + sp.eye(dpr_data.adj.shape[0]) 261 | edge_index = torch.LongTensor(adj_selfloop.nonzero()).cuda() 262 | edge_weight = torch.FloatTensor(adj_selfloop[adj_selfloop.nonzero()]).cuda() 263 | 264 | # by default, the features in pyg data is dense 265 | try: 266 | x = torch.FloatTensor(dpr_data.features.cpu()).float().cuda() 267 | except: 268 | x = torch.FloatTensor(dpr_data.features).float().cuda() 269 | try: 270 | y = torch.LongTensor(dpr_data.labels.cpu()).cuda() 271 | except: 272 | y = dpr_data.labels 273 | 274 | 275 | data = Data(x=x, edge_index=edge_index, y=y, edge_weight=edge_weight) 276 | data.train_mask = None 277 | data.val_mask = None 278 | data.test_mask = None 279 | return data 280 | 281 | def get(self, idx): 282 | data = self.data.__class__() 283 | 284 | if hasattr(self.data, '__num_nodes__'): 285 | data.num_nodes = self.data.__num_nodes__[idx] 286 | 287 | for key in self.data.keys: 288 | item, slices = self.data[key], self.slices[key] 289 | s = list(repeat(slice(None), item.dim())) 290 | s[self.data.__cat_dim__(key, item)] = slice(slices[idx], 291 | slices[idx + 1]) 292 | data[key] = item[s] 293 | return data 294 | 295 | @property 296 | def raw_file_names(self): 297 | return ['some_file_1', 'some_file_2', ...] 298 | 299 | @property 300 | def processed_file_names(self): 301 | return ['data.pt'] 302 | 303 | def _download(self): 304 | pass 305 | 306 | -------------------------------------------------------------------------------- /models/gcn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import math 4 | import torch 5 | import torch.optim as optim 6 | from torch.nn.parameter import Parameter 7 | from torch.nn.modules.module import Module 8 | from deeprobust.graph import utils 9 | from copy import deepcopy 10 | from sklearn.metrics import f1_score 11 | from torch.nn import init 12 | import torch_sparse 13 | 14 | 15 | class GraphConvolution(Module): 16 | 17 | def __init__(self, in_features, out_features, with_bias=True): 18 | super(GraphConvolution, self).__init__() 19 | self.in_features = in_features 20 | self.out_features = out_features 21 | self.weight = Parameter(torch.FloatTensor(in_features, out_features)) 22 | self.bias = Parameter(torch.FloatTensor(out_features)) 23 | self.reset_parameters() 24 | 25 | def reset_parameters(self): 26 | stdv = 1. 
/ math.sqrt(self.weight.T.size(1)) 27 | self.weight.data.uniform_(-stdv, stdv) 28 | if self.bias is not None: 29 | self.bias.data.uniform_(-stdv, stdv) 30 | 31 | def forward(self, input, adj): 32 | if input.data.is_sparse: 33 | support = torch.spmm(input, self.weight) 34 | else: 35 | support = torch.mm(input, self.weight) 36 | if isinstance(adj, torch_sparse.SparseTensor): 37 | output = torch_sparse.matmul(adj, support) 38 | else: 39 | output = torch.spmm(adj, support) 40 | if self.bias is not None: 41 | return output + self.bias 42 | else: 43 | return output 44 | 45 | def __repr__(self): 46 | return self.__class__.__name__ + ' (' \ 47 | + str(self.in_features) + ' -> ' \ 48 | + str(self.out_features) + ')' 49 | 50 | 51 | class GCN(nn.Module): 52 | 53 | def __init__(self, nfeat, nhid, nclass, nlayers=2, dropout=0.5, lr=0.01, weight_decay=5e-4, 54 | with_relu=True, with_bias=True, with_bn=False, device=None): 55 | 56 | super(GCN, self).__init__() 57 | 58 | assert device is not None, "Please specify 'device'!" 59 | self.device = device 60 | self.nfeat = nfeat 61 | self.nclass = nclass 62 | 63 | self.layers = nn.ModuleList([]) 64 | 65 | if nlayers == 1: 66 | self.layers.append(GraphConvolution(nfeat, nclass, with_bias=with_bias)) 67 | else: 68 | if with_bn: 69 | self.bns = torch.nn.ModuleList() 70 | self.bns.append(nn.BatchNorm1d(nhid)) 71 | self.layers.append(GraphConvolution(nfeat, nhid, with_bias=with_bias)) 72 | for i in range(nlayers-2): 73 | self.layers.append(GraphConvolution(nhid, nhid, with_bias=with_bias)) 74 | if with_bn: 75 | self.bns.append(nn.BatchNorm1d(nhid)) 76 | self.layers.append(GraphConvolution(nhid, nclass, with_bias=with_bias)) 77 | 78 | self.dropout = dropout 79 | self.lr = lr 80 | if not with_relu: 81 | self.weight_decay = 0 82 | else: 83 | self.weight_decay = weight_decay 84 | self.with_relu = with_relu 85 | self.with_bn = with_bn 86 | self.with_bias = with_bias 87 | self.output = None 88 | self.best_model = None 89 | self.best_output = None 90 | self.adj_norm = None 91 | self.features = None 92 | self.multi_label = None 93 | 94 | def forward(self, x, adj): 95 | for ix, layer in enumerate(self.layers): 96 | x = layer(x, adj) 97 | if ix != len(self.layers) - 1: 98 | x = self.bns[ix](x) if self.with_bn else x 99 | if self.with_relu: 100 | x = F.relu(x) 101 | x = F.dropout(x, self.dropout, training=self.training) 102 | 103 | if self.multi_label: 104 | return torch.sigmoid(x) 105 | else: 106 | return F.log_softmax(x, dim=1) 107 | 108 | def forward_sampler(self, x, adjs): 109 | # for ix, layer in enumerate(self.layers): 110 | for ix, (adj, _, size) in enumerate(adjs): 111 | x = self.layers[ix](x, adj) 112 | if ix != len(self.layers) - 1: 113 | x = self.bns[ix](x) if self.with_bn else x 114 | if self.with_relu: 115 | x = F.relu(x) 116 | x = F.dropout(x, self.dropout, training=self.training) 117 | 118 | if self.multi_label: 119 | return torch.sigmoid(x) 120 | else: 121 | return F.log_softmax(x, dim=1) 122 | 123 | def forward_sampler_syn(self, x, adjs): 124 | for ix, (adj) in enumerate(adjs): 125 | x = self.layers[ix](x, adj) 126 | if ix != len(self.layers) - 1: 127 | x = self.bns[ix](x) if self.with_bn else x 128 | if self.with_relu: 129 | x = F.relu(x) 130 | x = F.dropout(x, self.dropout, training=self.training) 131 | 132 | if self.multi_label: 133 | return torch.sigmoid(x) 134 | else: 135 | return F.log_softmax(x, dim=1) 136 | 137 | 138 | def initialize(self): 139 | for layer in self.layers: 140 | layer.reset_parameters() 141 | if self.with_bn: 142 | for bn in self.bns: 
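            # also clear BatchNorm affine parameters and running statistics, so a
            # re-initialized model does not carry over normalization state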
143 | bn.reset_parameters() 144 | 145 | def fit_with_val(self, features, adj, labels, data, train_iters=200, initialize=True, verbose=False, normalize=True, patience=None, noval=False, **kwargs): 146 | if initialize: 147 | self.initialize() 148 | 149 | if type(adj) is not torch.Tensor: 150 | features, adj, labels = utils.to_tensor(features, adj, labels, device=self.device) 151 | else: 152 | features = features.to(self.device) 153 | adj = adj.to(self.device) 154 | labels = labels.to(self.device) 155 | 156 | if normalize: 157 | if utils.is_sparse_tensor(adj): 158 | adj_norm = utils.normalize_adj_tensor(adj, sparse=True) 159 | else: 160 | adj_norm = utils.normalize_adj_tensor(adj) 161 | else: 162 | adj_norm = adj 163 | 164 | if 'feat_norm' in kwargs and kwargs['feat_norm']: 165 | from utils import row_normalize_tensor 166 | features = row_normalize_tensor(features-features.min()) 167 | 168 | self.adj_norm = adj_norm 169 | self.features = features 170 | 171 | if len(labels.shape) > 1: 172 | self.multi_label = True 173 | self.loss = torch.nn.BCELoss() 174 | else: 175 | self.multi_label = False 176 | self.loss = F.nll_loss 177 | 178 | labels = labels.float() if self.multi_label else labels 179 | self.labels = labels 180 | 181 | if noval: 182 | self._train_with_val(labels, data, train_iters, verbose, adj_val=True) 183 | else: 184 | self._train_with_val(labels, data, train_iters, verbose) 185 | 186 | def _train_with_val(self, labels, data, train_iters, verbose, adj_val=False): 187 | if adj_val: 188 | feat_full, adj_full = data.feat_val, data.adj_val 189 | else: 190 | feat_full, adj_full = data.feat_full, data.adj_full 191 | feat_full, adj_full = utils.to_tensor(feat_full, adj_full, device=self.device) 192 | adj_full_norm = utils.normalize_adj_tensor(adj_full, sparse=True) 193 | labels_val = torch.LongTensor(data.labels_val).to(self.device) 194 | 195 | if verbose: 196 | print('=== training gcn model ===') 197 | optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay) 198 | 199 | best_acc_val = 0 200 | 201 | for i in range(train_iters): 202 | if i == train_iters // 2: 203 | lr = self.lr*0.1 204 | optimizer = optim.Adam(self.parameters(), lr=lr, weight_decay=self.weight_decay) 205 | 206 | self.train() 207 | optimizer.zero_grad() 208 | output = self.forward(self.features, self.adj_norm) 209 | loss_train = self.loss(output, labels) 210 | loss_train.backward() 211 | optimizer.step() 212 | 213 | if verbose and i % 100 == 0: 214 | print('Epoch {}, training loss: {}'.format(i, loss_train.item())) 215 | 216 | with torch.no_grad(): 217 | self.eval() 218 | output = self.forward(feat_full, adj_full_norm) 219 | 220 | if adj_val: 221 | loss_val = F.nll_loss(output, labels_val) 222 | acc_val = utils.accuracy(output, labels_val) 223 | else: 224 | loss_val = F.nll_loss(output[data.idx_val], labels_val) 225 | acc_val = utils.accuracy(output[data.idx_val], labels_val) 226 | 227 | if acc_val > best_acc_val: 228 | best_acc_val = acc_val 229 | self.output = output 230 | weights = deepcopy(self.state_dict()) 231 | 232 | if verbose: 233 | print('=== picking the best model according to the performance on validation ===') 234 | self.load_state_dict(weights) 235 | 236 | 237 | def test(self, idx_test): 238 | self.eval() 239 | output = self.predict() 240 | # output = self.output 241 | loss_test = F.nll_loss(output[idx_test], self.labels[idx_test]) 242 | acc_test = utils.accuracy(output[idx_test], self.labels[idx_test]) 243 | print("Test set results:", 244 | "loss= {:.4f}".format(loss_test.item()), 
245 | "accuracy= {:.4f}".format(acc_test.item())) 246 | return acc_test.item() 247 | 248 | 249 | @torch.no_grad() 250 | def predict(self, features=None, adj=None): 251 | 252 | self.eval() 253 | if features is None and adj is None: 254 | return self.forward(self.features, self.adj_norm) 255 | else: 256 | if type(adj) is not torch.Tensor: 257 | features, adj = utils.to_tensor(features, adj, device=self.device) 258 | 259 | self.features = features 260 | if utils.is_sparse_tensor(adj): 261 | self.adj_norm = utils.normalize_adj_tensor(adj, sparse=True) 262 | else: 263 | self.adj_norm = utils.normalize_adj_tensor(adj) 264 | return self.forward(self.features, self.adj_norm) 265 | 266 | @torch.no_grad() 267 | def predict_unnorm(self, features=None, adj=None): 268 | self.eval() 269 | if features is None and adj is None: 270 | return self.forward(self.features, self.adj_norm) 271 | else: 272 | if type(adj) is not torch.Tensor: 273 | features, adj = utils.to_tensor(features, adj, device=self.device) 274 | 275 | self.features = features 276 | self.adj_norm = adj 277 | return self.forward(self.features, self.adj_norm) 278 | 279 | 280 | def _train_with_val2(self, labels, idx_train, idx_val, train_iters, verbose): 281 | if verbose: 282 | print('=== training gcn model ===') 283 | optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay) 284 | 285 | best_loss_val = 100 286 | best_acc_val = 0 287 | 288 | for i in range(train_iters): 289 | if i == train_iters // 2: 290 | lr = self.lr*0.1 291 | optimizer = optim.Adam(self.parameters(), lr=lr, weight_decay=self.weight_decay) 292 | 293 | self.train() 294 | optimizer.zero_grad() 295 | output = self.forward(self.features, self.adj_norm) 296 | loss_train = F.nll_loss(output[idx_train], labels[idx_train]) 297 | loss_train.backward() 298 | optimizer.step() 299 | 300 | if verbose and i % 10 == 0: 301 | print('Epoch {}, training loss: {}'.format(i, loss_train.item())) 302 | 303 | self.eval() 304 | output = self.forward(self.features, self.adj_norm) 305 | loss_val = F.nll_loss(output[idx_val], labels[idx_val]) 306 | acc_val = utils.accuracy(output[idx_val], labels[idx_val]) 307 | 308 | if acc_val > best_acc_val: 309 | best_acc_val = acc_val 310 | self.output = output 311 | weights = deepcopy(self.state_dict()) 312 | 313 | if verbose: 314 | print('=== picking the best model according to the performance on validation ===') 315 | self.load_state_dict(weights) 316 | -------------------------------------------------------------------------------- /models/myappnp.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import math 4 | import torch 5 | import torch.optim as optim 6 | from torch.nn.parameter import Parameter 7 | from torch.nn.modules.module import Module 8 | from deeprobust.graph import utils 9 | from copy import deepcopy 10 | from sklearn.metrics import f1_score 11 | from torch.nn import init 12 | import torch_sparse 13 | 14 | 15 | class APPNP(nn.Module): 16 | 17 | def __init__(self, nfeat, nhid, nclass, nlayers=2, dropout=0.5, lr=0.01, weight_decay=5e-4, 18 | ntrans=1, with_bias=True, with_bn=False, device=None): 19 | 20 | super(APPNP, self).__init__() 21 | 22 | assert device is not None, "Please specify 'device'!" 
23 | self.device = device 24 | self.nfeat = nfeat 25 | self.nclass = nclass 26 | self.alpha = 0.1 27 | 28 | with_bn = False 29 | 30 | self.layers = nn.ModuleList([]) 31 | if ntrans == 1: 32 | self.layers.append(MyLinear(nfeat, nclass)) 33 | else: 34 | self.layers.append(MyLinear(nfeat, nhid)) 35 | if with_bn: 36 | self.bns = torch.nn.ModuleList() 37 | self.bns.append(nn.BatchNorm1d(nhid)) 38 | for i in range(ntrans-2): 39 | if with_bn: 40 | self.bns.append(nn.BatchNorm1d(nhid)) 41 | self.layers.append(MyLinear(nhid, nhid)) 42 | self.layers.append(MyLinear(nhid, nclass)) 43 | 44 | self.nlayers = nlayers 45 | self.weight_decay = weight_decay 46 | self.dropout = dropout 47 | self.lr = lr 48 | self.with_bn = with_bn 49 | self.with_bias = with_bias 50 | self.output = None 51 | self.best_model = None 52 | self.best_output = None 53 | self.adj_norm = None 54 | self.features = None 55 | self.multi_label = None 56 | self.sparse_dropout = SparseDropout(dprob=0) 57 | 58 | def forward(self, x, adj): 59 | for ix, layer in enumerate(self.layers): 60 | x = layer(x) 61 | if ix != len(self.layers) - 1: 62 | x = self.bns[ix](x) if self.with_bn else x 63 | x = F.relu(x) 64 | x = F.dropout(x, self.dropout, training=self.training) 65 | 66 | h = x 67 | # here nlayers means K 68 | for i in range(self.nlayers): 69 | # adj_drop = self.sparse_dropout(adj, training=self.training) 70 | adj_drop = adj 71 | x = torch.spmm(adj_drop, x) 72 | x = x * (1 - self.alpha) 73 | x = x + self.alpha * h 74 | 75 | 76 | if self.multi_label: 77 | return torch.sigmoid(x) 78 | else: 79 | return F.log_softmax(x, dim=1) 80 | 81 | def forward_sampler(self, x, adjs): 82 | for ix, layer in enumerate(self.layers): 83 | x = layer(x) 84 | if ix != len(self.layers) - 1: 85 | x = self.bns[ix](x) if self.with_bn else x 86 | x = F.relu(x) 87 | x = F.dropout(x, self.dropout, training=self.training) 88 | 89 | h = x 90 | for ix, (adj, _, size) in enumerate(adjs): 91 | # x_target = x[: size[1]] 92 | # x = self.layers[ix]((x, x_target), edge_index) 93 | # adj = adj.to(self.device) 94 | # adj_drop = F.dropout(adj, p=self.dropout) 95 | adj_drop = adj 96 | h = h[: size[1]] 97 | x = torch_sparse.matmul(adj_drop, x) 98 | x = x * (1 - self.alpha) 99 | x = x + self.alpha * h 100 | 101 | if self.multi_label: 102 | return torch.sigmoid(x) 103 | else: 104 | return F.log_softmax(x, dim=1) 105 | 106 | def forward_sampler_syn(self, x, adjs): 107 | for ix, layer in enumerate(self.layers): 108 | x = layer(x) 109 | if ix != len(self.layers) - 1: 110 | x = self.bns[ix](x) if self.with_bn else x 111 | x = F.relu(x) 112 | x = F.dropout(x, self.dropout, training=self.training) 113 | 114 | for ix, (adj) in enumerate(adjs): 115 | # x_target = x[: size[1]] 116 | # x = self.layers[ix]((x, x_target), edge_index) 117 | # adj = adj.to(self.device) 118 | x = torch_sparse.matmul(adj, x) 119 | 120 | if self.multi_label: 121 | return torch.sigmoid(x) 122 | else: 123 | return F.log_softmax(x, dim=1) 124 | 125 | 126 | def initialize(self): 127 | for layer in self.layers: 128 | layer.reset_parameters() 129 | if self.with_bn: 130 | for bn in self.bns: 131 | bn.reset_parameters() 132 | 133 | def fit_with_val(self, features, adj, labels, data, train_iters=200, initialize=True, verbose=False, normalize=True, patience=None, noval=False, **kwargs): 134 | '''data: full data class''' 135 | if initialize: 136 | self.initialize() 137 | 138 | # features, adj, labels = data.feat_train, data.adj_train, data.labels_train 139 | if type(adj) is not torch.Tensor: 140 | features, adj, labels = 
utils.to_tensor(features, adj, labels, device=self.device) 141 | else: 142 | features = features.to(self.device) 143 | adj = adj.to(self.device) 144 | labels = labels.to(self.device) 145 | 146 | if normalize: 147 | if utils.is_sparse_tensor(adj): 148 | adj_norm = utils.normalize_adj_tensor(adj, sparse=True) 149 | else: 150 | adj_norm = utils.normalize_adj_tensor(adj) 151 | else: 152 | adj_norm = adj 153 | 154 | if 'feat_norm' in kwargs and kwargs['feat_norm']: 155 | from utils import row_normalize_tensor 156 | features = row_normalize_tensor(features-features.min()) 157 | 158 | self.adj_norm = adj_norm 159 | self.features = features 160 | 161 | if len(labels.shape) > 1: 162 | self.multi_label = True 163 | self.loss = torch.nn.BCELoss() 164 | else: 165 | self.multi_label = False 166 | self.loss = F.nll_loss 167 | 168 | labels = labels.float() if self.multi_label else labels 169 | self.labels = labels 170 | 171 | if noval: 172 | # self._train_without_val(labels, data, train_iters, verbose) 173 | # self._train_without_val(labels, data, train_iters, verbose) 174 | self._train_with_val(labels, data, train_iters, verbose, adj_val=True) 175 | else: 176 | self._train_with_val(labels, data, train_iters, verbose) 177 | 178 | def _train_with_val(self, labels, data, train_iters, verbose, adj_val=False): 179 | if adj_val: 180 | feat_full, adj_full = data.feat_val, data.adj_val 181 | else: 182 | feat_full, adj_full = data.feat_full, data.adj_full 183 | 184 | feat_full, adj_full = utils.to_tensor(feat_full, adj_full, device=self.device) 185 | adj_full_norm = utils.normalize_adj_tensor(adj_full, sparse=True) 186 | labels_val = torch.LongTensor(data.labels_val).to(self.device) 187 | 188 | if verbose: 189 | print('=== training gcn model ===') 190 | optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay) 191 | 192 | best_acc_val = 0 193 | 194 | for i in range(train_iters): 195 | if i == train_iters // 2: 196 | lr = self.lr*0.1 197 | optimizer = optim.Adam(self.parameters(), lr=lr, weight_decay=self.weight_decay) 198 | 199 | self.train() 200 | optimizer.zero_grad() 201 | output = self.forward(self.features, self.adj_norm) 202 | loss_train = self.loss(output, labels) 203 | loss_train.backward() 204 | optimizer.step() 205 | 206 | if verbose and i % 100 == 0: 207 | print('Epoch {}, training loss: {}'.format(i, loss_train.item())) 208 | 209 | with torch.no_grad(): 210 | self.eval() 211 | output = self.forward(feat_full, adj_full_norm) 212 | if adj_val: 213 | loss_val = F.nll_loss(output, labels_val) 214 | acc_val = utils.accuracy(output, labels_val) 215 | else: 216 | loss_val = F.nll_loss(output[data.idx_val], labels_val) 217 | acc_val = utils.accuracy(output[data.idx_val], labels_val) 218 | 219 | if acc_val > best_acc_val: 220 | best_acc_val = acc_val 221 | self.output = output 222 | weights = deepcopy(self.state_dict()) 223 | 224 | if verbose: 225 | print('=== picking the best model according to the performance on validation ===') 226 | self.load_state_dict(weights) 227 | 228 | 229 | def test(self, idx_test): 230 | self.eval() 231 | output = self.predict() 232 | # output = self.output 233 | loss_test = F.nll_loss(output[idx_test], self.labels[idx_test]) 234 | acc_test = utils.accuracy(output[idx_test], self.labels[idx_test]) 235 | print("Test set results:", 236 | "loss= {:.4f}".format(loss_test.item()), 237 | "accuracy= {:.4f}".format(acc_test.item())) 238 | return acc_test.item() 239 | 240 | 241 | @torch.no_grad() 242 | def predict(self, features=None, adj=None): 243 | 244 | 
self.eval() 245 | if features is None and adj is None: 246 | return self.forward(self.features, self.adj_norm) 247 | else: 248 | if type(adj) is not torch.Tensor: 249 | features, adj = utils.to_tensor(features, adj, device=self.device) 250 | 251 | self.features = features 252 | if utils.is_sparse_tensor(adj): 253 | self.adj_norm = utils.normalize_adj_tensor(adj, sparse=True) 254 | else: 255 | self.adj_norm = utils.normalize_adj_tensor(adj) 256 | return self.forward(self.features, self.adj_norm) 257 | 258 | @torch.no_grad() 259 | def predict_unnorm(self, features=None, adj=None): 260 | self.eval() 261 | if features is None and adj is None: 262 | return self.forward(self.features, self.adj_norm) 263 | else: 264 | if type(adj) is not torch.Tensor: 265 | features, adj = utils.to_tensor(features, adj, device=self.device) 266 | 267 | self.features = features 268 | self.adj_norm = adj 269 | return self.forward(self.features, self.adj_norm) 270 | 271 | 272 | 273 | class MyLinear(Module): 274 | 275 | def __init__(self, in_features, out_features, with_bias=True): 276 | super(MyLinear, self).__init__() 277 | self.in_features = in_features 278 | self.out_features = out_features 279 | self.weight = Parameter(torch.FloatTensor(in_features, out_features)) 280 | if with_bias: 281 | self.bias = Parameter(torch.FloatTensor(out_features)) 282 | else: 283 | self.register_parameter('bias', None) 284 | self.reset_parameters() 285 | 286 | def reset_parameters(self): 287 | # stdv = 1. / math.sqrt(self.weight.size(1)) 288 | stdv = 1. / math.sqrt(self.weight.T.size(1)) 289 | self.weight.data.uniform_(-stdv, stdv) 290 | if self.bias is not None: 291 | self.bias.data.uniform_(-stdv, stdv) 292 | 293 | def forward(self, input): 294 | if input.data.is_sparse: 295 | support = torch.spmm(input, self.weight) 296 | else: 297 | support = torch.mm(input, self.weight) 298 | output = support 299 | if self.bias is not None: 300 | return output + self.bias 301 | else: 302 | return output 303 | 304 | def __repr__(self): 305 | return self.__class__.__name__ + ' (' \ 306 | + str(self.in_features) + ' -> ' \ 307 | + str(self.out_features) + ')' 308 | 309 | class SparseDropout(torch.nn.Module): 310 | def __init__(self, dprob=0.5): 311 | super(SparseDropout, self).__init__() 312 | self.kprob=1-dprob 313 | 314 | def forward(self, x, training): 315 | if training: 316 | mask=((torch.rand(x._values().size())+(self.kprob)).floor()).type(torch.bool) 317 | rc=x._indices()[:,mask] 318 | val=x._values()[mask]*(1.0/self.kprob) 319 | return torch.sparse.FloatTensor(rc, val, x.size()) 320 | else: 321 | return x 322 | -------------------------------------------------------------------------------- /models/myappnp1.py: -------------------------------------------------------------------------------- 1 | """multiple transformaiton and multiple propagation""" 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | import torch 6 | import torch.optim as optim 7 | from torch.nn.parameter import Parameter 8 | from torch.nn.modules.module import Module 9 | from deeprobust.graph import utils 10 | from copy import deepcopy 11 | from sklearn.metrics import f1_score 12 | from torch.nn import init 13 | import torch_sparse 14 | 15 | 16 | class APPNP1(nn.Module): 17 | 18 | def __init__(self, nfeat, nhid, nclass, nlayers=2, dropout=0.5, lr=0.01, weight_decay=5e-4, 19 | with_relu=True, with_bias=True, with_bn=False, device=None): 20 | 21 | super(APPNP1, self).__init__() 22 | 23 | assert device is not None, "Please specify 'device'!" 
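        # Variant of APPNP above: the transformation depth is fixed to two MyLinear
        # layers (nfeat -> nhid -> nclass), while nlayers only controls the number
        # K of (1 - alpha)/alpha propagation steps applied in forward().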
24 | self.device = device 25 | self.nfeat = nfeat 26 | self.nclass = nclass 27 | self.alpha = 0.1 28 | 29 | if with_bn: 30 | self.bns = torch.nn.ModuleList() 31 | self.bns.append(nn.BatchNorm1d(nhid)) 32 | 33 | self.layers = nn.ModuleList([]) 34 | # self.layers.append(MyLinear(nfeat, nclass)) 35 | self.layers.append(MyLinear(nfeat, nhid)) 36 | # self.layers.append(MyLinear(nhid, nhid)) 37 | self.layers.append(MyLinear(nhid, nclass)) 38 | 39 | # if nlayers == 1: 40 | # self.layers.append(nn.Linear(nfeat, nclass)) 41 | # else: 42 | # self.layers.append(nn.Linear(nfeat, nhid)) 43 | # for i in range(nlayers-2): 44 | # self.layers.append(nn.Linear(nhid, nhid)) 45 | # self.layers.append(nn.Linear(nhid, nclass)) 46 | 47 | self.nlayers = nlayers 48 | self.dropout = dropout 49 | self.lr = lr 50 | if not with_relu: 51 | self.weight_decay = 0 52 | else: 53 | self.weight_decay = weight_decay 54 | self.with_relu = with_relu 55 | self.with_bn = with_bn 56 | self.with_bias = with_bias 57 | self.output = None 58 | self.best_model = None 59 | self.best_output = None 60 | self.adj_norm = None 61 | self.features = None 62 | self.multi_label = None 63 | self.sparse_dropout = SparseDropout(dprob=0) 64 | 65 | def forward(self, x, adj): 66 | for ix, layer in enumerate(self.layers): 67 | x = layer(x) 68 | if ix != len(self.layers) - 1: 69 | x = self.bns[ix](x) if self.with_bn else x 70 | x = F.relu(x) 71 | x = F.dropout(x, self.dropout, training=self.training) 72 | 73 | h = x 74 | # here nlayers means K 75 | for i in range(self.nlayers): 76 | # adj_drop = self.sparse_dropout(adj, training=self.training) 77 | adj_drop = adj 78 | x = torch.spmm(adj_drop, x) 79 | x = x * (1 - self.alpha) 80 | x = x + self.alpha * h 81 | 82 | 83 | if self.multi_label: 84 | return torch.sigmoid(x) 85 | else: 86 | return F.log_softmax(x, dim=1) 87 | 88 | def forward_sampler(self, x, adjs): 89 | for ix, layer in enumerate(self.layers): 90 | x = layer(x) 91 | if ix != len(self.layers) - 1: 92 | x = self.bns[ix](x) if self.with_bn else x 93 | x = F.relu(x) 94 | x = F.dropout(x, self.dropout, training=self.training) 95 | 96 | h = x 97 | for ix, (adj, _, size) in enumerate(adjs): 98 | # x_target = x[: size[1]] 99 | # x = self.layers[ix]((x, x_target), edge_index) 100 | # adj = adj.to(self.device) 101 | # adj_drop = F.dropout(adj, p=self.dropout) 102 | adj_drop = adj 103 | h = h[: size[1]] 104 | x = torch_sparse.matmul(adj_drop, x) 105 | x = x * (1 - self.alpha) 106 | x = x + self.alpha * h 107 | 108 | if self.multi_label: 109 | return torch.sigmoid(x) 110 | else: 111 | return F.log_softmax(x, dim=1) 112 | 113 | def forward_sampler_syn(self, x, adjs): 114 | for ix, layer in enumerate(self.layers): 115 | x = layer(x) 116 | if ix != len(self.layers) - 1: 117 | x = self.bns[ix](x) if self.with_bn else x 118 | x = F.relu(x) 119 | x = F.dropout(x, self.dropout, training=self.training) 120 | 121 | for ix, (adj) in enumerate(adjs): 122 | # x_target = x[: size[1]] 123 | # x = self.layers[ix]((x, x_target), edge_index) 124 | # adj = adj.to(self.device) 125 | x = torch_sparse.matmul(adj, x) 126 | 127 | if self.multi_label: 128 | return torch.sigmoid(x) 129 | else: 130 | return F.log_softmax(x, dim=1) 131 | 132 | 133 | def initialize(self): 134 | """Initialize parameters of GCN. 
135 | """ 136 | for layer in self.layers: 137 | layer.reset_parameters() 138 | if self.with_bn: 139 | for bn in self.bns: 140 | bn.reset_parameters() 141 | 142 | def fit_with_val(self, features, adj, labels, data, train_iters=200, initialize=True, verbose=False, normalize=True, patience=None, noval=False, **kwargs): 143 | '''data: full data class''' 144 | if initialize: 145 | self.initialize() 146 | 147 | # features, adj, labels = data.feat_train, data.adj_train, data.labels_train 148 | if type(adj) is not torch.Tensor: 149 | features, adj, labels = utils.to_tensor(features, adj, labels, device=self.device) 150 | else: 151 | features = features.to(self.device) 152 | adj = adj.to(self.device) 153 | labels = labels.to(self.device) 154 | 155 | if normalize: 156 | if utils.is_sparse_tensor(adj): 157 | adj_norm = utils.normalize_adj_tensor(adj, sparse=True) 158 | else: 159 | adj_norm = utils.normalize_adj_tensor(adj) 160 | else: 161 | adj_norm = adj 162 | 163 | if 'feat_norm' in kwargs and kwargs['feat_norm']: 164 | from utils import row_normalize_tensor 165 | features = row_normalize_tensor(features-features.min()) 166 | 167 | self.adj_norm = adj_norm 168 | self.features = features 169 | 170 | if len(labels.shape) > 1: 171 | self.multi_label = True 172 | self.loss = torch.nn.BCELoss() 173 | else: 174 | self.multi_label = False 175 | self.loss = F.nll_loss 176 | 177 | labels = labels.float() if self.multi_label else labels 178 | self.labels = labels 179 | 180 | 181 | if noval: 182 | self._train_with_val(labels, data, train_iters, verbose, adj_val=True) 183 | else: 184 | self._train_with_val(labels, data, train_iters, verbose) 185 | 186 | def _train_with_val(self, labels, data, train_iters, verbose, adj_val=False): 187 | if adj_val: 188 | feat_full, adj_full = data.feat_val, data.adj_val 189 | else: 190 | feat_full, adj_full = data.feat_full, data.adj_full 191 | feat_full, adj_full = utils.to_tensor(feat_full, adj_full, device=self.device) 192 | adj_full_norm = utils.normalize_adj_tensor(adj_full, sparse=True) 193 | labels_val = torch.LongTensor(data.labels_val).to(self.device) 194 | 195 | if verbose: 196 | print('=== training gcn model ===') 197 | optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay) 198 | 199 | best_acc_val = 0 200 | 201 | for i in range(train_iters): 202 | if i == train_iters // 2: 203 | lr = self.lr*0.1 204 | optimizer = optim.Adam(self.parameters(), lr=lr, weight_decay=self.weight_decay) 205 | 206 | self.train() 207 | optimizer.zero_grad() 208 | output = self.forward(self.features, self.adj_norm) 209 | loss_train = self.loss(output, labels) 210 | loss_train.backward() 211 | optimizer.step() 212 | 213 | if verbose and i % 100 == 0: 214 | print('Epoch {}, training loss: {}'.format(i, loss_train.item())) 215 | 216 | with torch.no_grad(): 217 | self.eval() 218 | output = self.forward(feat_full, adj_full_norm) 219 | if adj_val: 220 | loss_val = F.nll_loss(output, labels_val) 221 | acc_val = utils.accuracy(output, labels_val) 222 | else: 223 | loss_val = F.nll_loss(output[data.idx_val], labels_val) 224 | acc_val = utils.accuracy(output[data.idx_val], labels_val) 225 | 226 | if acc_val > best_acc_val: 227 | best_acc_val = acc_val 228 | self.output = output 229 | weights = deepcopy(self.state_dict()) 230 | 231 | if verbose: 232 | print('=== picking the best model according to the performance on validation ===') 233 | self.load_state_dict(weights) 234 | 235 | 236 | def test(self, idx_test): 237 | """Evaluate GCN performance on test set. 
238 | Parameters 239 | ---------- 240 | idx_test : 241 | node testing indices 242 | """ 243 | self.eval() 244 | output = self.predict() 245 | # output = self.output 246 | loss_test = F.nll_loss(output[idx_test], self.labels[idx_test]) 247 | acc_test = utils.accuracy(output[idx_test], self.labels[idx_test]) 248 | print("Test set results:", 249 | "loss= {:.4f}".format(loss_test.item()), 250 | "accuracy= {:.4f}".format(acc_test.item())) 251 | return acc_test.item() 252 | 253 | 254 | @torch.no_grad() 255 | def predict(self, features=None, adj=None): 256 | """By default, the inputs should be unnormalized adjacency 257 | Parameters 258 | ---------- 259 | features : 260 | node features. If `features` and `adj` are not given, this function will use previous stored `features` and `adj` from training to make predictions. 261 | adj : 262 | adjcency matrix. If `features` and `adj` are not given, this function will use previous stored `features` and `adj` from training to make predictions. 263 | Returns 264 | ------- 265 | torch.FloatTensor 266 | output (log probabilities) of GCN 267 | """ 268 | 269 | self.eval() 270 | if features is None and adj is None: 271 | return self.forward(self.features, self.adj_norm) 272 | else: 273 | if type(adj) is not torch.Tensor: 274 | features, adj = utils.to_tensor(features, adj, device=self.device) 275 | 276 | self.features = features 277 | if utils.is_sparse_tensor(adj): 278 | self.adj_norm = utils.normalize_adj_tensor(adj, sparse=True) 279 | else: 280 | self.adj_norm = utils.normalize_adj_tensor(adj) 281 | return self.forward(self.features, self.adj_norm) 282 | 283 | @torch.no_grad() 284 | def predict_unnorm(self, features=None, adj=None): 285 | self.eval() 286 | if features is None and adj is None: 287 | return self.forward(self.features, self.adj_norm) 288 | else: 289 | if type(adj) is not torch.Tensor: 290 | features, adj = utils.to_tensor(features, adj, device=self.device) 291 | 292 | self.features = features 293 | self.adj_norm = adj 294 | return self.forward(self.features, self.adj_norm) 295 | 296 | 297 | 298 | class MyLinear(Module): 299 | """Simple Linear layer, modified from https://github.com/tkipf/pygcn 300 | """ 301 | 302 | def __init__(self, in_features, out_features, with_bias=True): 303 | super(MyLinear, self).__init__() 304 | self.in_features = in_features 305 | self.out_features = out_features 306 | self.weight = Parameter(torch.FloatTensor(in_features, out_features)) 307 | if with_bias: 308 | self.bias = Parameter(torch.FloatTensor(out_features)) 309 | else: 310 | self.register_parameter('bias', None) 311 | self.reset_parameters() 312 | 313 | def reset_parameters(self): 314 | # stdv = 1. / math.sqrt(self.weight.size(1)) 315 | stdv = 1. 
/ math.sqrt(self.weight.T.size(1)) 316 | self.weight.data.uniform_(-stdv, stdv) 317 | if self.bias is not None: 318 | self.bias.data.uniform_(-stdv, stdv) 319 | 320 | def forward(self, input): 321 | if input.data.is_sparse: 322 | support = torch.spmm(input, self.weight) 323 | else: 324 | support = torch.mm(input, self.weight) 325 | output = support 326 | if self.bias is not None: 327 | return output + self.bias 328 | else: 329 | return output 330 | 331 | def __repr__(self): 332 | return self.__class__.__name__ + ' (' \ 333 | + str(self.in_features) + ' -> ' \ 334 | + str(self.out_features) + ')' 335 | 336 | class SparseDropout(torch.nn.Module): 337 | def __init__(self, dprob=0.5): 338 | super(SparseDropout, self).__init__() 339 | self.kprob=1-dprob 340 | 341 | def forward(self, x, training): 342 | if training: 343 | mask=((torch.rand(x._values().size())+(self.kprob)).floor()).type(torch.bool) 344 | rc=x._indices()[:,mask] 345 | val=x._values()[mask]*(1.0/self.kprob) 346 | return torch.sparse.FloatTensor(rc, val, x.size()) 347 | else: 348 | return x 349 | -------------------------------------------------------------------------------- /models/mycheby.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import math 4 | import torch 5 | import torch.optim as optim 6 | from torch.nn.parameter import Parameter 7 | from torch.nn.modules.module import Module 8 | from deeprobust.graph import utils 9 | from copy import deepcopy 10 | from sklearn.metrics import f1_score 11 | from torch.nn import init 12 | import torch_sparse 13 | from torch_geometric.nn.inits import zeros 14 | import scipy.sparse as sp 15 | import numpy as np 16 | 17 | 18 | class ChebConvolution(Module): 19 | 20 | def __init__(self, in_features, out_features, with_bias=True, single_param=True, K=2): 21 | super(ChebConvolution, self).__init__() 22 | self.in_features = in_features 23 | self.out_features = out_features 24 | self.lins = torch.nn.ModuleList([ 25 | MyLinear(in_features, out_features, with_bias=False) for _ in range(K)]) 26 | # self.lins = torch.nn.ModuleList([ 27 | # MyLinear(in_features, out_features, with_bias=True) for _ in range(K)]) 28 | if with_bias: 29 | self.bias = Parameter(torch.Tensor(out_features)) 30 | else: 31 | self.register_parameter('bias', None) 32 | self.single_param = single_param 33 | self.reset_parameters() 34 | 35 | def reset_parameters(self): 36 | for lin in self.lins: 37 | lin.reset_parameters() 38 | zeros(self.bias) 39 | 40 | def forward(self, input, adj, size=None): 41 | # support = torch.mm(input, self.weight_l) 42 | x = input 43 | Tx_0 = x[:size[1]] if size is not None else x 44 | Tx_1 = x # dummy 45 | output = self.lins[0](Tx_0) 46 | 47 | if len(self.lins) > 1: 48 | if isinstance(adj, torch_sparse.SparseTensor): 49 | Tx_1 = torch_sparse.matmul(adj, x) 50 | else: 51 | Tx_1 = torch.spmm(adj, x) 52 | 53 | if self.single_param: 54 | output = output + self.lins[0](Tx_1) 55 | else: 56 | output = output + self.lins[1](Tx_1) 57 | 58 | for lin in self.lins[2:]: 59 | if self.single_param: 60 | lin = self.lins[0] 61 | if isinstance(adj, torch_sparse.SparseTensor): 62 | Tx_2 = torch_sparse.matmul(adj, Tx_1) 63 | else: 64 | Tx_2 = torch.spmm(adj, Tx_1) 65 | Tx_2 = 2. 
* Tx_2 - Tx_0 66 | output = output + lin.forward(Tx_2) 67 | Tx_0, Tx_1 = Tx_1, Tx_2 68 | 69 | if self.bias is not None: 70 | return output + self.bias 71 | else: 72 | return output 73 | 74 | def __repr__(self): 75 | return self.__class__.__name__ + ' (' \ 76 | + str(self.in_features) + ' -> ' \ 77 | + str(self.out_features) + ')' 78 | 79 | 80 | class Cheby(nn.Module): 81 | 82 | def __init__(self, nfeat, nhid, nclass, nlayers=2, dropout=0.5, lr=0.01, weight_decay=5e-4, 83 | with_relu=True, with_bias=True, with_bn=False, device=None): 84 | 85 | super(Cheby, self).__init__() 86 | 87 | assert device is not None, "Please specify 'device'!" 88 | self.device = device 89 | self.nfeat = nfeat 90 | self.nclass = nclass 91 | 92 | self.layers = nn.ModuleList([]) 93 | 94 | if nlayers == 1: 95 | self.layers.append(ChebConvolution(nfeat, nclass, with_bias=with_bias)) 96 | else: 97 | if with_bn: 98 | self.bns = torch.nn.ModuleList() 99 | self.bns.append(nn.BatchNorm1d(nhid)) 100 | self.layers.append(ChebConvolution(nfeat, nhid, with_bias=with_bias)) 101 | for i in range(nlayers-2): 102 | self.layers.append(ChebConvolution(nhid, nhid, with_bias=with_bias)) 103 | if with_bn: 104 | self.bns.append(nn.BatchNorm1d(nhid)) 105 | self.layers.append(ChebConvolution(nhid, nclass, with_bias=with_bias)) 106 | 107 | # self.lin = MyLinear(nhid, nclass, with_bias=True) 108 | 109 | # dropout = 0.5 110 | self.dropout = dropout 111 | self.lr = lr 112 | self.weight_decay = weight_decay 113 | self.with_relu = with_relu 114 | self.with_bn = with_bn 115 | self.with_bias = with_bias 116 | self.output = None 117 | self.best_model = None 118 | self.best_output = None 119 | self.adj_norm = None 120 | self.features = None 121 | self.multi_label = None 122 | 123 | def forward(self, x, adj): 124 | for ix, layer in enumerate(self.layers): 125 | # x = F.dropout(x, 0.2, training=self.training) 126 | x = layer(x, adj) 127 | if ix != len(self.layers) - 1: 128 | x = self.bns[ix](x) if self.with_bn else x 129 | if self.with_relu: 130 | x = F.relu(x) 131 | x = F.dropout(x, self.dropout, training=self.training) 132 | # x = F.dropout(x, 0.5, training=self.training) 133 | 134 | if self.multi_label: 135 | return torch.sigmoid(x) 136 | else: 137 | return F.log_softmax(x, dim=1) 138 | 139 | def forward_sampler(self, x, adjs): 140 | # TODO: do we need normalization? 
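        # `adjs` is assumed to be the list of (adj, e_id, size) triples yielded by
        # torch_geometric's NeighborSampler, one bipartite block per layer; `size`
        # lets ChebConvolution slice the target nodes via x[:size[1]] before the
        # Chebyshev recurrence Tx_2 = 2 * A_norm @ Tx_1 - Tx_0 defined above.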
141 | # for ix, layer in enumerate(self.layers): 142 | for ix, (adj, _, size) in enumerate(adjs): 143 | # x_target = x[: size[1]] 144 | # x = self.layers[ix]((x, x_target), edge_index) 145 | # adj = adj.to(self.device) 146 | x = self.layers[ix](x, adj, size=size) 147 | if ix != len(self.layers) - 1: 148 | x = self.bns[ix](x) if self.with_bn else x 149 | if self.with_relu: 150 | x = F.relu(x) 151 | x = F.dropout(x, self.dropout, training=self.training) 152 | 153 | if self.multi_label: 154 | return torch.sigmoid(x) 155 | else: 156 | return F.log_softmax(x, dim=1) 157 | 158 | def forward_sampler_syn(self, x, adjs): 159 | for ix, (adj) in enumerate(adjs): 160 | x = self.layers[ix](x, adj) 161 | if ix != len(self.layers) - 1: 162 | x = self.bns[ix](x) if self.with_bn else x 163 | if self.with_relu: 164 | x = F.relu(x) 165 | x = F.dropout(x, self.dropout, training=self.training) 166 | 167 | if self.multi_label: 168 | return torch.sigmoid(x) 169 | else: 170 | return F.log_softmax(x, dim=1) 171 | 172 | 173 | def initialize(self): 174 | for layer in self.layers: 175 | layer.reset_parameters() 176 | if self.with_bn: 177 | for bn in self.bns: 178 | bn.reset_parameters() 179 | 180 | def fit_with_val(self, features, adj, labels, data, train_iters=200, initialize=True, verbose=False, normalize=True, patience=None, noval=False, **kwargs): 181 | if initialize: 182 | self.initialize() 183 | 184 | # features, adj, labels = data.feat_train, data.adj_train, data.labels_train 185 | 186 | if type(adj) is not torch.Tensor: 187 | features, adj, labels = utils.to_tensor(features, adj, labels, device=self.device) 188 | else: 189 | features = features.to(self.device) 190 | adj = adj.to(self.device) 191 | labels = labels.to(self.device) 192 | 193 | adj = adj - torch.eye(adj.shape[0]).to(self.device) # cheby 194 | if normalize: 195 | adj_norm = utils.normalize_adj_tensor(adj) 196 | else: 197 | adj_norm = adj 198 | 199 | 200 | if 'feat_norm' in kwargs and kwargs['feat_norm']: 201 | from utils import row_normalize_tensor 202 | features = row_normalize_tensor(features-features.min()) 203 | 204 | self.adj_norm = adj_norm 205 | self.features = features 206 | 207 | if len(labels.shape) > 1: 208 | self.multi_label = True 209 | self.loss = torch.nn.BCELoss() 210 | else: 211 | self.multi_label = False 212 | self.loss = F.nll_loss 213 | 214 | labels = labels.float() if self.multi_label else labels 215 | self.labels = labels 216 | 217 | if noval: 218 | self._train_with_val(labels, data, train_iters, verbose, adj_val=True) 219 | else: 220 | self._train_with_val(labels, data, train_iters, verbose) 221 | 222 | def _train_with_val(self, labels, data, train_iters, verbose, adj_val=False): 223 | if adj_val: 224 | feat_full, adj_full = data.feat_val, data.adj_val 225 | else: 226 | feat_full, adj_full = data.feat_full, data.adj_full 227 | # adj_full = adj_full - sp.eye(adj_full.shape[0]) 228 | feat_full, adj_full = utils.to_tensor(feat_full, adj_full, device=self.device) 229 | adj_full_norm = utils.normalize_adj_tensor(adj_full, sparse=True) 230 | labels_val = torch.LongTensor(data.labels_val).to(self.device) 231 | 232 | if verbose: 233 | print('=== training gcn model ===') 234 | optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay) 235 | 236 | best_acc_val = 0 237 | best_loss_val = 100 238 | 239 | for i in range(train_iters): 240 | if i == train_iters // 2: 241 | lr = self.lr*0.1 242 | optimizer = optim.Adam(self.parameters(), lr=lr, weight_decay=self.weight_decay) 243 | 244 | self.train() 245 | 
optimizer.zero_grad() 246 | output = self.forward(self.features, self.adj_norm) 247 | loss_train = self.loss(output, labels) 248 | loss_train.backward() 249 | optimizer.step() 250 | 251 | if verbose and i % 100 == 0: 252 | print('Epoch {}, training loss: {}'.format(i, loss_train.item())) 253 | 254 | with torch.no_grad(): 255 | self.eval() 256 | output = self.forward(feat_full, adj_full_norm) 257 | if adj_val: 258 | loss_val = F.nll_loss(output, labels_val) 259 | acc_val = utils.accuracy(output, labels_val) 260 | else: 261 | loss_val = F.nll_loss(output[data.idx_val], labels_val) 262 | acc_val = utils.accuracy(output[data.idx_val], labels_val) 263 | 264 | # if loss_val < best_loss_val: 265 | # best_loss_val = loss_val 266 | # self.output = output 267 | # weights = deepcopy(self.state_dict()) 268 | # print(best_loss_val) 269 | 270 | if acc_val > best_acc_val: 271 | best_acc_val = acc_val 272 | self.output = output 273 | weights = deepcopy(self.state_dict()) 274 | # print(best_acc_val) 275 | 276 | if verbose: 277 | print('=== picking the best model according to the performance on validation ===') 278 | self.load_state_dict(weights) 279 | 280 | def test(self, idx_test): 281 | self.eval() 282 | output = self.predict() 283 | # output = self.output 284 | loss_test = F.nll_loss(output[idx_test], self.labels[idx_test]) 285 | acc_test = utils.accuracy(output[idx_test], self.labels[idx_test]) 286 | print("Test set results:", 287 | "loss= {:.4f}".format(loss_test.item()), 288 | "accuracy= {:.4f}".format(acc_test.item())) 289 | return acc_test.item() 290 | 291 | 292 | @torch.no_grad() 293 | def predict(self, features=None, adj=None): 294 | 295 | self.eval() 296 | if features is None and adj is None: 297 | return self.forward(self.features, self.adj_norm) 298 | else: 299 | # adj = adj-sp.eye(adj.shape[0]) 300 | # adj[0,0]=0 301 | 302 | if type(adj) is not torch.Tensor: 303 | features, adj = utils.to_tensor(features, adj, device=self.device) 304 | 305 | self.features = features 306 | self.adj_norm = utils.normalize_adj_tensor(adj, sparse=True) 307 | adj = utils.to_scipy(adj) 308 | 309 | adj = adj-sp.eye(adj.shape[0]) 310 | mx = normalize_adj(adj) 311 | adj = utils.sparse_mx_to_torch_sparse_tensor(mx).to(self.device) 312 | return self.forward(self.features, self.adj_norm) 313 | 314 | @torch.no_grad() 315 | def predict_unnorm(self, features=None, adj=None): 316 | self.eval() 317 | if features is None and adj is None: 318 | return self.forward(self.features, self.adj_norm) 319 | else: 320 | if type(adj) is not torch.Tensor: 321 | features, adj = utils.to_tensor(features, adj, device=self.device) 322 | 323 | self.features = features 324 | self.adj_norm = adj 325 | return self.forward(self.features, self.adj_norm) 326 | 327 | class MyLinear(Module): 328 | 329 | def __init__(self, in_features, out_features, with_bias=True): 330 | super(MyLinear, self).__init__() 331 | self.in_features = in_features 332 | self.out_features = out_features 333 | self.weight = Parameter(torch.FloatTensor(in_features, out_features)) 334 | if with_bias: 335 | self.bias = Parameter(torch.FloatTensor(out_features)) 336 | else: 337 | self.register_parameter('bias', None) 338 | self.reset_parameters() 339 | 340 | def reset_parameters(self): 341 | # stdv = 1. / math.sqrt(self.weight.size(1)) 342 | stdv = 1. 
/ math.sqrt(self.weight.T.size(1)) 343 | self.weight.data.uniform_(-stdv, stdv) 344 | if self.bias is not None: 345 | self.bias.data.uniform_(-stdv, stdv) 346 | 347 | def forward(self, input): 348 | if input.data.is_sparse: 349 | support = torch.spmm(input, self.weight) 350 | else: 351 | support = torch.mm(input, self.weight) 352 | output = support 353 | if self.bias is not None: 354 | return output + self.bias 355 | else: 356 | return output 357 | 358 | def __repr__(self): 359 | return self.__class__.__name__ + ' (' \ 360 | + str(self.in_features) + ' -> ' \ 361 | + str(self.out_features) + ')' 362 | 363 | 364 | 365 | def normalize_adj(mx): 366 | if type(mx) is not sp.lil.lil_matrix: 367 | mx = mx.tolil() 368 | mx = mx + sp.eye(mx.shape[0]) 369 | rowsum = np.array(mx.sum(1)) 370 | r_inv = np.power(rowsum, -1/2).flatten() 371 | r_inv[np.isinf(r_inv)] = 0. 372 | r_mat_inv = sp.diags(r_inv) 373 | mx = r_mat_inv.dot(mx) 374 | mx = mx.dot(r_mat_inv) 375 | return mx 376 | -------------------------------------------------------------------------------- /models/mygatconv.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Tuple, Optional 2 | from torch_geometric.typing import (OptPairTensor, Adj, Size, NoneType, 3 | OptTensor) 4 | 5 | import torch 6 | from torch import Tensor 7 | import torch.nn.functional as F 8 | from torch.nn import Parameter, Linear 9 | from torch_sparse import SparseTensor, set_diag 10 | from torch_geometric.nn.conv import MessagePassing 11 | from torch_geometric.utils import remove_self_loops, add_self_loops, softmax 12 | 13 | from torch_geometric.nn.inits import glorot, zeros 14 | 15 | 16 | class GATConv(MessagePassing): 17 | 18 | def __init__(self, in_channels: Union[int, Tuple[int, int]], 19 | out_channels: int, heads: int = 1, concat: bool = True, 20 | negative_slope: float = 0.2, dropout: float = 0.0, 21 | add_self_loops: bool = True, bias: bool = True, **kwargs): 22 | kwargs.setdefault('aggr', 'add') 23 | super(GATConv, self).__init__(node_dim=0, **kwargs) 24 | 25 | self.in_channels = in_channels 26 | self.out_channels = out_channels 27 | self.heads = heads 28 | self.concat = concat 29 | self.negative_slope = negative_slope 30 | self.dropout = dropout 31 | self.add_self_loops = add_self_loops 32 | 33 | if isinstance(in_channels, int): 34 | self.lin_l = Linear(in_channels, heads * out_channels, bias=False) 35 | self.lin_r = self.lin_l 36 | else: 37 | self.lin_l = Linear(in_channels[0], heads * out_channels, False) 38 | self.lin_r = Linear(in_channels[1], heads * out_channels, False) 39 | 40 | self.att_l = Parameter(torch.Tensor(1, heads, out_channels)) 41 | self.att_r = Parameter(torch.Tensor(1, heads, out_channels)) 42 | 43 | if bias and concat: 44 | self.bias = Parameter(torch.Tensor(heads * out_channels)) 45 | elif bias and not concat: 46 | self.bias = Parameter(torch.Tensor(out_channels)) 47 | else: 48 | self.register_parameter('bias', None) 49 | 50 | self._alpha = None 51 | 52 | self.reset_parameters() 53 | self.edge_weight = None 54 | 55 | def reset_parameters(self): 56 | glorot(self.lin_l.weight) 57 | glorot(self.lin_r.weight) 58 | glorot(self.att_l) 59 | glorot(self.att_r) 60 | zeros(self.bias) 61 | 62 | def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj, 63 | size: Size = None, return_attention_weights=None, edge_weight=None): 64 | H, C = self.heads, self.out_channels 65 | 66 | x_l: OptTensor = None 67 | x_r: OptTensor = None 68 | alpha_l: OptTensor = None 69 | alpha_r: OptTensor = 
None 70 | if isinstance(x, Tensor): 71 | assert x.dim() == 2, 'Static graphs not supported in `GATConv`.' 72 | x_l = x_r = self.lin_l(x).view(-1, H, C) 73 | alpha_l = (x_l * self.att_l).sum(dim=-1) 74 | alpha_r = (x_r * self.att_r).sum(dim=-1) 75 | else: 76 | x_l, x_r = x[0], x[1] 77 | assert x[0].dim() == 2, 'Static graphs not supported in `GATConv`.' 78 | x_l = self.lin_l(x_l).view(-1, H, C) 79 | alpha_l = (x_l * self.att_l).sum(dim=-1) 80 | if x_r is not None: 81 | x_r = self.lin_r(x_r).view(-1, H, C) 82 | alpha_r = (x_r * self.att_r).sum(dim=-1) 83 | 84 | 85 | assert x_l is not None 86 | assert alpha_l is not None 87 | 88 | if self.add_self_loops: 89 | if isinstance(edge_index, Tensor): 90 | num_nodes = x_l.size(0) 91 | if x_r is not None: 92 | num_nodes = min(num_nodes, x_r.size(0)) 93 | if size is not None: 94 | num_nodes = min(size[0], size[1]) 95 | edge_index, _ = remove_self_loops(edge_index) 96 | edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes) 97 | 98 | if edge_weight is not None: 99 | if self.edge_weight is None: 100 | self.edge_weight = edge_weight 101 | 102 | if edge_index.size(1) != self.edge_weight.shape[0]: 103 | self.edge_weight = edge_weight 104 | 105 | elif isinstance(edge_index, SparseTensor): 106 | edge_index = set_diag(edge_index) 107 | 108 | # propagate_type: (x: OptPairTensor, alpha: OptPairTensor) 109 | 110 | out = self.propagate(edge_index, x=(x_l, x_r), 111 | alpha=(alpha_l, alpha_r), size=size) 112 | 113 | alpha = self._alpha 114 | self._alpha = None 115 | 116 | if self.concat: 117 | out = out.view(-1, self.heads * self.out_channels) 118 | else: 119 | out = out.mean(dim=1) 120 | 121 | if self.bias is not None: 122 | out += self.bias 123 | 124 | if isinstance(return_attention_weights, bool): 125 | assert alpha is not None 126 | if isinstance(edge_index, Tensor): 127 | return out, (edge_index, alpha) 128 | elif isinstance(edge_index, SparseTensor): 129 | return out, edge_index.set_value(alpha, layout='coo') 130 | else: 131 | return out 132 | 133 | def message(self, x_j: Tensor, alpha_j: Tensor, alpha_i: OptTensor, 134 | index: Tensor, ptr: OptTensor, 135 | size_i: Optional[int]) -> Tensor: 136 | alpha = alpha_j if alpha_i is None else alpha_j + alpha_i 137 | alpha = F.leaky_relu(alpha, self.negative_slope) 138 | alpha = softmax(alpha, index, ptr, size_i) 139 | self._alpha = alpha 140 | alpha = F.dropout(alpha, p=self.dropout, training=self.training) 141 | 142 | if self.edge_weight is not None: 143 | x_j = self.edge_weight.view(-1, 1, 1) * x_j 144 | return x_j * alpha.unsqueeze(-1) 145 | 146 | def __repr__(self): 147 | return '{}({}, {}, heads={})'.format(self.__class__.__name__, 148 | self.in_channels, 149 | self.out_channels, self.heads) 150 | 151 | -------------------------------------------------------------------------------- /models/mygraphsage.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import math 4 | import torch 5 | import torch.optim as optim 6 | from torch.nn.parameter import Parameter 7 | from torch.nn.modules.module import Module 8 | from deeprobust.graph import utils 9 | from copy import deepcopy 10 | from sklearn.metrics import f1_score 11 | from torch.nn import init 12 | import torch_sparse 13 | from torch_geometric.data import NeighborSampler 14 | from torch_sparse import SparseTensor 15 | 16 | 17 | class SageConvolution(Module): 18 | def __init__(self, in_features, out_features, with_bias=True, root_weight=False): 19 | 
super(SageConvolution, self).__init__() 20 | self.in_features = in_features 21 | self.out_features = out_features 22 | self.weight_l = Parameter(torch.FloatTensor(in_features, out_features)) 23 | self.bias_l = Parameter(torch.FloatTensor(out_features)) 24 | self.weight_r = Parameter(torch.FloatTensor(in_features, out_features)) 25 | self.bias_r = Parameter(torch.FloatTensor(out_features)) 26 | self.reset_parameters() 27 | self.root_weight = root_weight 28 | # self.weight = Parameter(torch.FloatTensor(out_features, in_features)) 29 | # self.linear = torch.nn.Linear(self.in_features, self.out_features) 30 | 31 | def reset_parameters(self): 32 | # stdv = 1. / math.sqrt(self.weight.size(1)) 33 | stdv = 1. / math.sqrt(self.weight_l.T.size(1)) 34 | self.weight_l.data.uniform_(-stdv, stdv) 35 | self.bias_l.data.uniform_(-stdv, stdv) 36 | 37 | stdv = 1. / math.sqrt(self.weight_r.T.size(1)) 38 | self.weight_r.data.uniform_(-stdv, stdv) 39 | self.bias_r.data.uniform_(-stdv, stdv) 40 | 41 | def forward(self, input, adj, size=None): 42 | if input.data.is_sparse: 43 | support = torch.spmm(input, self.weight_l) 44 | else: 45 | support = torch.mm(input, self.weight_l) 46 | if isinstance(adj, torch_sparse.SparseTensor): 47 | output = torch_sparse.matmul(adj, support) 48 | else: 49 | output = torch.spmm(adj, support) 50 | output = output + self.bias_l 51 | 52 | if self.root_weight: 53 | if size is not None: 54 | output = output + input[:size[1]] @ self.weight_r + self.bias_r 55 | else: 56 | output = output + input @ self.weight_r + self.bias_r 57 | else: 58 | output = output 59 | 60 | return output 61 | 62 | def __repr__(self): 63 | return self.__class__.__name__ + ' (' \ 64 | + str(self.in_features) + ' -> ' \ 65 | + str(self.out_features) + ')' 66 | 67 | 68 | class GraphSage(nn.Module): 69 | 70 | def __init__(self, nfeat, nhid, nclass, nlayers=2, dropout=0.5, lr=0.01, weight_decay=5e-4, 71 | with_relu=True, with_bias=True, with_bn=False, device=None): 72 | 73 | super(GraphSage, self).__init__() 74 | 75 | assert device is not None, "Please specify 'device'!" 
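        # Each SageConvolution above aggregates neighbor features through the
        # (normalized) adjacency: out = A_norm @ (X @ W_l) + b_l, plus the root term
        # X @ W_r + b_r when root_weight=True. The GraphSage stack built below mirrors
        # the GCN wrapper but trains with NeighborSampler mini-batches in _train_with_val().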
76 | self.device = device 77 | self.nfeat = nfeat 78 | self.nclass = nclass 79 | 80 | self.layers = nn.ModuleList([]) 81 | 82 | if nlayers == 1: 83 | self.layers.append(SageConvolution(nfeat, nclass, with_bias=with_bias)) 84 | else: 85 | if with_bn: 86 | self.bns = torch.nn.ModuleList() 87 | self.bns.append(nn.BatchNorm1d(nhid)) 88 | self.layers.append(SageConvolution(nfeat, nhid, with_bias=with_bias)) 89 | for i in range(nlayers-2): 90 | self.layers.append(SageConvolution(nhid, nhid, with_bias=with_bias)) 91 | if with_bn: 92 | self.bns.append(nn.BatchNorm1d(nhid)) 93 | self.layers.append(SageConvolution(nhid, nclass, with_bias=with_bias)) 94 | 95 | self.dropout = dropout 96 | self.lr = lr 97 | if not with_relu: 98 | self.weight_decay = 0 99 | else: 100 | self.weight_decay = weight_decay 101 | self.with_relu = with_relu 102 | self.with_bn = with_bn 103 | self.with_bias = with_bias 104 | self.output = None 105 | self.best_model = None 106 | self.best_output = None 107 | self.adj_norm = None 108 | self.features = None 109 | self.multi_label = None 110 | 111 | def forward(self, x, adj): 112 | for ix, layer in enumerate(self.layers): 113 | x = layer(x, adj) 114 | if ix != len(self.layers) - 1: 115 | x = self.bns[ix](x) if self.with_bn else x 116 | if self.with_relu: 117 | x = F.relu(x) 118 | x = F.dropout(x, self.dropout, training=self.training) 119 | 120 | if self.multi_label: 121 | return torch.sigmoid(x) 122 | else: 123 | return F.log_softmax(x, dim=1) 124 | 125 | def forward_sampler(self, x, adjs): 126 | # TODO: do we need normalization? 127 | # for ix, layer in enumerate(self.layers): 128 | for ix, (adj, _, size) in enumerate(adjs): 129 | # x_target = x[: size[1]] 130 | # x = self.layers[ix]((x, x_target), edge_index) 131 | # adj = adj.to(self.device) 132 | x = self.layers[ix](x, adj, size=size) 133 | if ix != len(self.layers) - 1: 134 | x = self.bns[ix](x) if self.with_bn else x 135 | if self.with_relu: 136 | x = F.relu(x) 137 | x = F.dropout(x, self.dropout, training=self.training) 138 | 139 | if self.multi_label: 140 | return torch.sigmoid(x) 141 | else: 142 | return F.log_softmax(x, dim=1) 143 | 144 | def forward_sampler_syn(self, x, adjs): 145 | for ix, (adj) in enumerate(adjs): 146 | x = self.layers[ix](x, adj) 147 | if ix != len(self.layers) - 1: 148 | x = self.bns[ix](x) if self.with_bn else x 149 | if self.with_relu: 150 | x = F.relu(x) 151 | x = F.dropout(x, self.dropout, training=self.training) 152 | 153 | if self.multi_label: 154 | return torch.sigmoid(x) 155 | else: 156 | return F.log_softmax(x, dim=1) 157 | 158 | 159 | def initialize(self): 160 | for layer in self.layers: 161 | layer.reset_parameters() 162 | if self.with_bn: 163 | for bn in self.bns: 164 | bn.reset_parameters() 165 | 166 | def fit_with_val(self, features, adj, labels, data, train_iters=200, initialize=True, verbose=False, normalize=True, patience=None, noval=False, **kwargs): 167 | '''data: full data class''' 168 | if initialize: 169 | self.initialize() 170 | 171 | # features, adj, labels = data.feat_train, data.adj_train, data.labels_train 172 | if type(adj) is not torch.Tensor: 173 | features, adj, labels = utils.to_tensor(features, adj, labels, device=self.device) 174 | else: 175 | features = features.to(self.device) 176 | adj = adj.to(self.device) 177 | labels = labels.to(self.device) 178 | 179 | if normalize: 180 | if utils.is_sparse_tensor(adj): 181 | adj_norm = utils.normalize_adj_tensor(adj, sparse=True) 182 | else: 183 | adj_norm = utils.normalize_adj_tensor(adj) 184 | else: 185 | adj_norm = adj 186 
| 187 | if 'feat_norm' in kwargs and kwargs['feat_norm']: 188 | from utils import row_normalize_tensor 189 | features = row_normalize_tensor(features-features.min()) 190 | 191 | self.adj_norm = adj_norm 192 | self.features = features 193 | 194 | if len(labels.shape) > 1: 195 | self.multi_label = True 196 | self.loss = torch.nn.BCELoss() 197 | else: 198 | self.multi_label = False 199 | self.loss = F.nll_loss 200 | 201 | labels = labels.float() if self.multi_label else labels 202 | self.labels = labels 203 | 204 | if noval: 205 | self._train_with_val(labels, data, train_iters, verbose, adj_val=True) 206 | else: 207 | self._train_with_val(labels, data, train_iters, verbose) 208 | 209 | def _train_with_val(self, labels, data, train_iters, verbose, adj_val=False): 210 | if adj_val: 211 | feat_full, adj_full = data.feat_val, data.adj_val 212 | else: 213 | feat_full, adj_full = data.feat_full, data.adj_full 214 | feat_full, adj_full = utils.to_tensor(feat_full, adj_full, device=self.device) 215 | adj_full_norm = utils.normalize_adj_tensor(adj_full, sparse=True) 216 | labels_val = torch.LongTensor(data.labels_val).to(self.device) 217 | 218 | if verbose: 219 | print('=== training gcn model ===') 220 | optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay) 221 | 222 | adj_norm = self.adj_norm 223 | node_idx = torch.arange(adj_norm.shape[0]).long() 224 | 225 | edge_index = adj_norm.nonzero().T 226 | adj_norm = SparseTensor(row=edge_index[0], col=edge_index[1], 227 | value=adj_norm[edge_index[0], edge_index[1]], sparse_sizes=adj_norm.size()).t() 228 | # edge_index = adj_norm._indices() 229 | # adj_norm = SparseTensor(row=edge_index[0], col=edge_index[1], 230 | # value=adj_norm._values(), sparse_sizes=adj_norm.size()).t() 231 | 232 | if adj_norm.density() > 0.5: # if the weighted graph is too dense, we need a larger neighborhood size 233 | sizes = [30, 20] 234 | else: 235 | sizes = [5, 5] 236 | train_loader = NeighborSampler(adj_norm, 237 | node_idx=node_idx, 238 | sizes=sizes, batch_size=len(node_idx), 239 | num_workers=0, return_e_id=False, 240 | num_nodes=adj_norm.size(0), 241 | shuffle=True) 242 | 243 | best_acc_val = 0 244 | for i in range(train_iters): 245 | if i == train_iters // 2: 246 | lr = self.lr*0.1 247 | optimizer = optim.Adam(self.parameters(), lr=lr, weight_decay=self.weight_decay) 248 | 249 | self.train() 250 | # optimizer.zero_grad() 251 | # output = self.forward(self.features, self.adj_norm) 252 | # loss_train = self.loss(output, labels) 253 | # loss_train.backward() 254 | # optimizer.step() 255 | 256 | for batch_size, n_id, adjs in train_loader: 257 | adjs = [adj.to(self.device) for adj in adjs] 258 | optimizer.zero_grad() 259 | out = self.forward_sampler(self.features[n_id], adjs) 260 | loss_train = F.nll_loss(out, labels[n_id[:batch_size]]) 261 | loss_train.backward() 262 | optimizer.step() 263 | 264 | 265 | if verbose and i % 100 == 0: 266 | print('Epoch {}, training loss: {}'.format(i, loss_train.item())) 267 | 268 | with torch.no_grad(): 269 | self.eval() 270 | output = self.forward(feat_full, adj_full_norm) 271 | if adj_val: 272 | loss_val = F.nll_loss(output, labels_val) 273 | acc_val = utils.accuracy(output, labels_val) 274 | else: 275 | loss_val = F.nll_loss(output[data.idx_val], labels_val) 276 | acc_val = utils.accuracy(output[data.idx_val], labels_val) 277 | 278 | if acc_val > best_acc_val: 279 | best_acc_val = acc_val 280 | self.output = output 281 | weights = deepcopy(self.state_dict()) 282 | 283 | if verbose: 284 | print('=== picking the 
best model according to the performance on validation ===') 285 | self.load_state_dict(weights) 286 | 287 | def test(self, idx_test): 288 | self.eval() 289 | output = self.predict() 290 | # output = self.output 291 | loss_test = F.nll_loss(output[idx_test], self.labels[idx_test]) 292 | acc_test = utils.accuracy(output[idx_test], self.labels[idx_test]) 293 | print("Test set results:", 294 | "loss= {:.4f}".format(loss_test.item()), 295 | "accuracy= {:.4f}".format(acc_test.item())) 296 | return acc_test.item() 297 | 298 | 299 | @torch.no_grad() 300 | def predict(self, features=None, adj=None): 301 | 302 | self.eval() 303 | if features is None and adj is None: 304 | return self.forward(self.features, self.adj_norm) 305 | else: 306 | if type(adj) is not torch.Tensor: 307 | features, adj = utils.to_tensor(features, adj, device=self.device) 308 | 309 | self.features = features 310 | if utils.is_sparse_tensor(adj): 311 | self.adj_norm = utils.normalize_adj_tensor(adj, sparse=True) 312 | else: 313 | self.adj_norm = utils.normalize_adj_tensor(adj) 314 | return self.forward(self.features, self.adj_norm) 315 | 316 | @torch.no_grad() 317 | def predict_unnorm(self, features=None, adj=None): 318 | self.eval() 319 | if features is None and adj is None: 320 | return self.forward(self.features, self.adj_norm) 321 | else: 322 | if type(adj) is not torch.Tensor: 323 | features, adj = utils.to_tensor(features, adj, device=self.device) 324 | 325 | self.features = features 326 | self.adj_norm = adj 327 | return self.forward(self.features, self.adj_norm) 328 | 329 | -------------------------------------------------------------------------------- /models/parametrized_adj.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import math 4 | import torch 5 | import torch.optim as optim 6 | from torch.nn.parameter import Parameter 7 | from torch.nn.modules.module import Module 8 | from itertools import product 9 | import numpy as np 10 | 11 | class PGE(nn.Module): 12 | 13 | def __init__(self, nfeat, nnodes, nhid=128, nlayers=3, device=None, args=None): 14 | super(PGE, self).__init__() 15 | if args.dataset in ['ogbn-arxiv', 'arxiv', 'flickr']: 16 | nhid = 256 17 | if args.dataset in ['reddit']: 18 | nhid = 256 19 | if args.reduction_rate==0.01: 20 | nhid = 128 21 | nlayers = 3 22 | # nhid = 128 23 | 24 | self.layers = nn.ModuleList([]) 25 | self.layers.append(nn.Linear(nfeat*2, nhid)) 26 | self.bns = torch.nn.ModuleList() 27 | self.bns.append(nn.BatchNorm1d(nhid)) 28 | for i in range(nlayers-2): 29 | self.layers.append(nn.Linear(nhid, nhid)) 30 | self.bns.append(nn.BatchNorm1d(nhid)) 31 | self.layers.append(nn.Linear(nhid, 1)) 32 | 33 | edge_index = np.array(list(product(range(nnodes), range(nnodes)))) 34 | self.edge_index = edge_index.T 35 | self.nnodes = nnodes 36 | self.device = device 37 | self.reset_parameters() 38 | self.cnt = 0 39 | self.args = args 40 | self.nnodes = nnodes 41 | 42 | def forward(self, x, inference=False): 43 | if self.args.dataset == 'reddit' and self.args.reduction_rate >= 0.01: 44 | edge_index = self.edge_index 45 | n_part = 5 46 | splits = np.array_split(np.arange(edge_index.shape[1]), n_part) 47 | edge_embed = [] 48 | for idx in splits: 49 | tmp_edge_embed = torch.cat([x[edge_index[0][idx]], 50 | x[edge_index[1][idx]]], axis=1) 51 | for ix, layer in enumerate(self.layers): 52 | tmp_edge_embed = layer(tmp_edge_embed) 53 | if ix != len(self.layers) - 1: 54 | tmp_edge_embed = 
self.bns[ix](tmp_edge_embed) 55 | tmp_edge_embed = F.relu(tmp_edge_embed) 56 | edge_embed.append(tmp_edge_embed) 57 | edge_embed = torch.cat(edge_embed) 58 | else: 59 | edge_index = self.edge_index 60 | edge_embed = torch.cat([x[edge_index[0]], 61 | x[edge_index[1]]], axis=1) 62 | for ix, layer in enumerate(self.layers): 63 | edge_embed = layer(edge_embed) 64 | if ix != len(self.layers) - 1: 65 | edge_embed = self.bns[ix](edge_embed) 66 | edge_embed = F.relu(edge_embed) 67 | 68 | adj = edge_embed.reshape(self.nnodes, self.nnodes) 69 | 70 | adj = (adj + adj.T)/2 71 | adj = torch.sigmoid(adj) 72 | adj = adj - torch.diag(torch.diag(adj, 0)) 73 | return adj 74 | 75 | @torch.no_grad() 76 | def inference(self, x): 77 | # self.eval() 78 | adj_syn = self.forward(x, inference=True) 79 | return adj_syn 80 | 81 | def reset_parameters(self): 82 | def weight_reset(m): 83 | if isinstance(m, nn.Linear): 84 | m.reset_parameters() 85 | if isinstance(m, nn.BatchNorm1d): 86 | m.reset_parameters() 87 | self.apply(weight_reset) 88 | 89 | -------------------------------------------------------------------------------- /models/sgc.py: -------------------------------------------------------------------------------- 1 | '''one transformation with multiple propagation''' 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | import torch 6 | import torch.optim as optim 7 | from torch.nn.parameter import Parameter 8 | from torch.nn.modules.module import Module 9 | from deeprobust.graph import utils 10 | from copy import deepcopy 11 | from sklearn.metrics import f1_score 12 | from torch.nn import init 13 | import torch_sparse 14 | 15 | class GraphConvolution(Module): 16 | 17 | def __init__(self, in_features, out_features, with_bias=True): 18 | super(GraphConvolution, self).__init__() 19 | self.in_features = in_features 20 | self.out_features = out_features 21 | self.weight = Parameter(torch.FloatTensor(in_features, out_features)) 22 | self.bias = Parameter(torch.FloatTensor(out_features)) 23 | self.reset_parameters() 24 | 25 | def reset_parameters(self): 26 | # stdv = 1. / math.sqrt(self.weight.size(1)) 27 | stdv = 1. / math.sqrt(self.weight.T.size(1)) 28 | self.weight.data.uniform_(-stdv, stdv) 29 | if self.bias is not None: 30 | self.bias.data.uniform_(-stdv, stdv) 31 | 32 | def forward(self, input, adj): 33 | if input.data.is_sparse: 34 | support = torch.spmm(input, self.weight) 35 | else: 36 | support = torch.mm(input, self.weight) 37 | if isinstance(adj, torch_sparse.SparseTensor): 38 | output = torch_sparse.matmul(adj, support) 39 | else: 40 | output = torch.spmm(adj, support) 41 | if self.bias is not None: 42 | return output + self.bias 43 | else: 44 | return output 45 | 46 | def __repr__(self): 47 | return self.__class__.__name__ + ' (' \ 48 | + str(self.in_features) + ' -> ' \ 49 | + str(self.out_features) + ')' 50 | 51 | 52 | class SGC(nn.Module): 53 | 54 | def __init__(self, nfeat, nhid, nclass, nlayers=2, dropout=0.5, lr=0.01, weight_decay=5e-4, 55 | with_relu=True, with_bias=True, with_bn=False, device=None): 56 | 57 | super(SGC, self).__init__() 58 | 59 | assert device is not None, "Please specify 'device'!" 
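        # Descriptive note: this SGC variant keeps a single linear transformation
        # (self.conv below) and applies the normalized adjacency `nlayers` times in
        # forward(), i.e. roughly log_softmax(A_hat^nlayers @ (X W) + b) for the
        # single-label case (sigmoid for multi-label data), as implemented by
        # forward() and forward_sampler() further down.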
60 | self.device = device 61 | self.nfeat = nfeat 62 | self.nclass = nclass 63 | 64 | self.conv = GraphConvolution(nfeat, nclass, with_bias=with_bias) 65 | 66 | self.nlayers = nlayers 67 | self.dropout = dropout 68 | self.lr = lr 69 | if not with_relu: 70 | self.weight_decay = 0 71 | else: 72 | self.weight_decay = weight_decay 73 | self.with_relu = with_relu 74 | if with_bn: 75 | print('Warning: SGC does not have bn!!!') 76 | self.with_bn = False 77 | self.with_bias = with_bias 78 | self.output = None 79 | self.best_model = None 80 | self.best_output = None 81 | self.adj_norm = None 82 | self.features = None 83 | self.multi_label = None 84 | 85 | def forward(self, x, adj): 86 | weight = self.conv.weight 87 | bias = self.conv.bias 88 | x = torch.mm(x, weight) 89 | for i in range(self.nlayers): 90 | x = torch.spmm(adj, x) 91 | x = x + bias 92 | if self.multi_label: 93 | return torch.sigmoid(x) 94 | else: 95 | return F.log_softmax(x, dim=1) 96 | 97 | def forward_sampler(self, x, adjs): 98 | weight = self.conv.weight 99 | bias = self.conv.bias 100 | x = torch.mm(x, weight) 101 | for ix, (adj, _, size) in enumerate(adjs): 102 | x = torch_sparse.matmul(adj, x) 103 | x = x + bias 104 | if self.multi_label: 105 | return torch.sigmoid(x) 106 | else: 107 | return F.log_softmax(x, dim=1) 108 | 109 | def forward_sampler_syn(self, x, adjs): 110 | weight = self.conv.weight 111 | bias = self.conv.bias 112 | x = torch.mm(x, weight) 113 | for ix, (adj) in enumerate(adjs): 114 | if type(adj) == torch.Tensor: 115 | x = adj @ x 116 | else: 117 | x = torch_sparse.matmul(adj, x) 118 | x = x + bias 119 | if self.multi_label: 120 | return torch.sigmoid(x) 121 | else: 122 | return F.log_softmax(x, dim=1) 123 | 124 | def initialize(self): 125 | self.conv.reset_parameters() 126 | if self.with_bn: 127 | for bn in self.bns: 128 | bn.reset_parameters() 129 | 130 | def fit_with_val(self, features, adj, labels, data, train_iters=200, initialize=True, verbose=False, normalize=True, patience=None, noval=False, **kwargs): 131 | if initialize: 132 | self.initialize() 133 | 134 | # features, adj, labels = data.feat_train, data.adj_train, data.labels_train 135 | if type(adj) is not torch.Tensor: 136 | features, adj, labels = utils.to_tensor(features, adj, labels, device=self.device) 137 | else: 138 | features = features.to(self.device) 139 | adj = adj.to(self.device) 140 | labels = labels.to(self.device) 141 | 142 | if normalize: 143 | if utils.is_sparse_tensor(adj): 144 | adj_norm = utils.normalize_adj_tensor(adj, sparse=True) 145 | else: 146 | adj_norm = utils.normalize_adj_tensor(adj) 147 | else: 148 | adj_norm = adj 149 | 150 | if 'feat_norm' in kwargs and kwargs['feat_norm']: 151 | from utils import row_normalize_tensor 152 | features = row_normalize_tensor(features-features.min()) 153 | 154 | self.adj_norm = adj_norm 155 | self.features = features 156 | 157 | if len(labels.shape) > 1: 158 | self.multi_label = True 159 | self.loss = torch.nn.BCELoss() 160 | else: 161 | self.multi_label = False 162 | self.loss = F.nll_loss 163 | 164 | labels = labels.float() if self.multi_label else labels 165 | self.labels = labels 166 | 167 | if noval: 168 | self._train_with_val(labels, data, train_iters, verbose, adj_val=True) 169 | else: 170 | self._train_with_val(labels, data, train_iters, verbose) 171 | 172 | def _train_with_val(self, labels, data, train_iters, verbose, adj_val=False): 173 | if adj_val: 174 | feat_full, adj_full = data.feat_val, data.adj_val 175 | else: 176 | feat_full, adj_full = data.feat_full, data.adj_full 177 | 
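        # Descriptive note: when adj_val=True (the branch taken if fit_with_val is
        # called with noval=True), validation runs directly on the validation
        # subgraph, so the model output below is already aligned with labels_val;
        # otherwise the full graph is used and predictions are indexed with
        # data.idx_val before computing the validation loss and accuracy.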
178 | feat_full, adj_full = utils.to_tensor(feat_full, adj_full, device=self.device) 179 | adj_full_norm = utils.normalize_adj_tensor(adj_full, sparse=True) 180 | labels_val = torch.LongTensor(data.labels_val).to(self.device) 181 | 182 | if verbose: 183 | print('=== training gcn model ===') 184 | optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay) 185 | 186 | best_acc_val = 0 187 | 188 | for i in range(train_iters): 189 | if i == train_iters // 2: 190 | lr = self.lr*0.1 191 | optimizer = optim.Adam(self.parameters(), lr=lr, weight_decay=self.weight_decay) 192 | 193 | self.train() 194 | optimizer.zero_grad() 195 | output = self.forward(self.features, self.adj_norm) 196 | loss_train = self.loss(output, labels) 197 | loss_train.backward() 198 | optimizer.step() 199 | 200 | if verbose and i % 100 == 0: 201 | print('Epoch {}, training loss: {}'.format(i, loss_train.item())) 202 | 203 | with torch.no_grad(): 204 | self.eval() 205 | output = self.forward(feat_full, adj_full_norm) 206 | if adj_val: 207 | loss_val = F.nll_loss(output, labels_val) 208 | acc_val = utils.accuracy(output, labels_val) 209 | else: 210 | loss_val = F.nll_loss(output[data.idx_val], labels_val) 211 | acc_val = utils.accuracy(output[data.idx_val], labels_val) 212 | 213 | if acc_val > best_acc_val: 214 | best_acc_val = acc_val 215 | self.output = output 216 | weights = deepcopy(self.state_dict()) 217 | 218 | if verbose: 219 | print('=== picking the best model according to the performance on validation ===') 220 | self.load_state_dict(weights) 221 | 222 | 223 | def test(self, idx_test): 224 | self.eval() 225 | output = self.predict() 226 | # output = self.output 227 | loss_test = F.nll_loss(output[idx_test], self.labels[idx_test]) 228 | acc_test = utils.accuracy(output[idx_test], self.labels[idx_test]) 229 | print("Test set results:", 230 | "loss= {:.4f}".format(loss_test.item()), 231 | "accuracy= {:.4f}".format(acc_test.item())) 232 | return acc_test.item() 233 | 234 | 235 | @torch.no_grad() 236 | def predict(self, features=None, adj=None): 237 | 238 | self.eval() 239 | if features is None and adj is None: 240 | return self.forward(self.features, self.adj_norm) 241 | else: 242 | if type(adj) is not torch.Tensor: 243 | features, adj = utils.to_tensor(features, adj, device=self.device) 244 | 245 | self.features = features 246 | if utils.is_sparse_tensor(adj): 247 | self.adj_norm = utils.normalize_adj_tensor(adj, sparse=True) 248 | else: 249 | self.adj_norm = utils.normalize_adj_tensor(adj) 250 | return self.forward(self.features, self.adj_norm) 251 | 252 | @torch.no_grad() 253 | def predict_unnorm(self, features=None, adj=None): 254 | self.eval() 255 | if features is None and adj is None: 256 | return self.forward(self.features, self.adj_norm) 257 | else: 258 | if type(adj) is not torch.Tensor: 259 | features, adj = utils.to_tensor(features, adj, device=self.device) 260 | 261 | self.features = features 262 | self.adj_norm = adj 263 | return self.forward(self.features, self.adj_norm) 264 | 265 | 266 | -------------------------------------------------------------------------------- /models/sgc_multi.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import math 4 | import torch 5 | import torch.optim as optim 6 | from torch.nn.parameter import Parameter 7 | from torch.nn.modules.module import Module 8 | from deeprobust.graph import utils 9 | from copy import deepcopy 10 | from sklearn.metrics 
import f1_score 11 | from torch.nn import init 12 | import torch_sparse 13 | 14 | 15 | class SGC(nn.Module): 16 | 17 | def __init__(self, nfeat, nhid, nclass, nlayers=2, dropout=0.5, lr=0.01, weight_decay=5e-4, 18 | ntrans=2, with_bias=True, with_bn=False, device=None): 19 | 20 | super(SGC, self).__init__() 21 | 22 | assert device is not None, "Please specify 'device'!" 23 | self.device = device 24 | self.nfeat = nfeat 25 | self.nclass = nclass 26 | 27 | 28 | self.layers = nn.ModuleList([]) 29 | if ntrans == 1: 30 | self.layers.append(MyLinear(nfeat, nclass)) 31 | else: 32 | self.layers.append(MyLinear(nfeat, nhid)) 33 | if with_bn: 34 | self.bns = torch.nn.ModuleList() 35 | self.bns.append(nn.BatchNorm1d(nhid)) 36 | for i in range(ntrans-2): 37 | if with_bn: 38 | self.bns.append(nn.BatchNorm1d(nhid)) 39 | self.layers.append(MyLinear(nhid, nhid)) 40 | self.layers.append(MyLinear(nhid, nclass)) 41 | 42 | self.nlayers = nlayers 43 | self.dropout = dropout 44 | self.lr = lr 45 | self.with_bn = with_bn 46 | self.with_bias = with_bias 47 | self.weight_decay = weight_decay 48 | self.output = None 49 | self.best_model = None 50 | self.best_output = None 51 | self.adj_norm = None 52 | self.features = None 53 | self.multi_label = None 54 | 55 | def forward(self, x, adj): 56 | for ix, layer in enumerate(self.layers): 57 | x = layer(x) 58 | if ix != len(self.layers) - 1: 59 | x = self.bns[ix](x) if self.with_bn else x 60 | x = F.relu(x) 61 | x = F.dropout(x, self.dropout, training=self.training) 62 | 63 | for i in range(self.nlayers): 64 | x = torch.spmm(adj, x) 65 | 66 | if self.multi_label: 67 | return torch.sigmoid(x) 68 | else: 69 | return F.log_softmax(x, dim=1) 70 | 71 | def forward_sampler(self, x, adjs): 72 | for ix, layer in enumerate(self.layers): 73 | x = layer(x) 74 | if ix != len(self.layers) - 1: 75 | x = self.bns[ix](x) if self.with_bn else x 76 | x = F.relu(x) 77 | x = F.dropout(x, self.dropout, training=self.training) 78 | 79 | for ix, (adj, _, size) in enumerate(adjs): 80 | # x_target = x[: size[1]] 81 | # x = self.layers[ix]((x, x_target), edge_index) 82 | # adj = adj.to(self.device) 83 | x = torch_sparse.matmul(adj, x) 84 | 85 | if self.multi_label: 86 | return torch.sigmoid(x) 87 | else: 88 | return F.log_softmax(x, dim=1) 89 | 90 | def forward_sampler_syn(self, x, adjs): 91 | for ix, layer in enumerate(self.layers): 92 | x = layer(x) 93 | if ix != len(self.layers) - 1: 94 | x = self.bns[ix](x) if self.with_bn else x 95 | x = F.relu(x) 96 | x = F.dropout(x, self.dropout, training=self.training) 97 | 98 | for ix, (adj) in enumerate(adjs): 99 | if type(adj) == torch.Tensor: 100 | x = adj @ x 101 | else: 102 | x = torch_sparse.matmul(adj, x) 103 | 104 | if self.multi_label: 105 | return torch.sigmoid(x) 106 | else: 107 | return F.log_softmax(x, dim=1) 108 | 109 | 110 | def initialize(self): 111 | for layer in self.layers: 112 | layer.reset_parameters() 113 | if self.with_bn: 114 | for bn in self.bns: 115 | bn.reset_parameters() 116 | 117 | def fit_with_val(self, features, adj, labels, data, train_iters=200, initialize=True, verbose=False, normalize=True, patience=None, noval=False, **kwargs): 118 | if initialize: 119 | self.initialize() 120 | 121 | # features, adj, labels = data.feat_train, data.adj_train, data.labels_train 122 | if type(adj) is not torch.Tensor: 123 | features, adj, labels = utils.to_tensor(features, adj, labels, device=self.device) 124 | else: 125 | features = features.to(self.device) 126 | adj = adj.to(self.device) 127 | labels = labels.to(self.device) 128 | 129 
| if normalize: 130 | if utils.is_sparse_tensor(adj): 131 | adj_norm = utils.normalize_adj_tensor(adj, sparse=True) 132 | else: 133 | adj_norm = utils.normalize_adj_tensor(adj) 134 | else: 135 | adj_norm = adj 136 | 137 | if 'feat_norm' in kwargs and kwargs['feat_norm']: 138 | from utils import row_normalize_tensor 139 | features = row_normalize_tensor(features-features.min()) 140 | 141 | self.adj_norm = adj_norm 142 | self.features = features 143 | 144 | if len(labels.shape) > 1: 145 | self.multi_label = True 146 | self.loss = torch.nn.BCELoss() 147 | else: 148 | self.multi_label = False 149 | self.loss = F.nll_loss 150 | 151 | labels = labels.float() if self.multi_label else labels 152 | self.labels = labels 153 | 154 | if noval: 155 | self._train_with_val(labels, data, train_iters, verbose, adj_val=True) 156 | else: 157 | self._train_with_val(labels, data, train_iters, verbose) 158 | 159 | def _train_with_val(self, labels, data, train_iters, verbose, adj_val=False): 160 | if adj_val: 161 | feat_full, adj_full = data.feat_val, data.adj_val 162 | else: 163 | feat_full, adj_full = data.feat_full, data.adj_full 164 | 165 | feat_full, adj_full = utils.to_tensor(feat_full, adj_full, device=self.device) 166 | adj_full_norm = utils.normalize_adj_tensor(adj_full, sparse=True) 167 | labels_val = torch.LongTensor(data.labels_val).to(self.device) 168 | 169 | if verbose: 170 | print('=== training gcn model ===') 171 | optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay) 172 | 173 | best_acc_val = 0 174 | 175 | for i in range(train_iters): 176 | if i == train_iters // 2: 177 | lr = self.lr*0.1 178 | optimizer = optim.Adam(self.parameters(), lr=lr, weight_decay=self.weight_decay) 179 | 180 | self.train() 181 | optimizer.zero_grad() 182 | output = self.forward(self.features, self.adj_norm) 183 | loss_train = self.loss(output, labels) 184 | loss_train.backward() 185 | optimizer.step() 186 | 187 | if verbose and i % 100 == 0: 188 | print('Epoch {}, training loss: {}'.format(i, loss_train.item())) 189 | 190 | with torch.no_grad(): 191 | self.eval() 192 | output = self.forward(feat_full, adj_full_norm) 193 | if adj_val: 194 | loss_val = F.nll_loss(output, labels_val) 195 | acc_val = utils.accuracy(output, labels_val) 196 | else: 197 | loss_val = F.nll_loss(output[data.idx_val], labels_val) 198 | acc_val = utils.accuracy(output[data.idx_val], labels_val) 199 | 200 | if acc_val > best_acc_val: 201 | best_acc_val = acc_val 202 | self.output = output 203 | weights = deepcopy(self.state_dict()) 204 | 205 | if verbose: 206 | print('=== picking the best model according to the performance on validation ===') 207 | self.load_state_dict(weights) 208 | 209 | 210 | def test(self, idx_test): 211 | self.eval() 212 | output = self.predict() 213 | # output = self.output 214 | loss_test = F.nll_loss(output[idx_test], self.labels[idx_test]) 215 | acc_test = utils.accuracy(output[idx_test], self.labels[idx_test]) 216 | print("Test set results:", 217 | "loss= {:.4f}".format(loss_test.item()), 218 | "accuracy= {:.4f}".format(acc_test.item())) 219 | return acc_test.item() 220 | 221 | 222 | @torch.no_grad() 223 | def predict(self, features=None, adj=None): 224 | 225 | self.eval() 226 | if features is None and adj is None: 227 | return self.forward(self.features, self.adj_norm) 228 | else: 229 | if type(adj) is not torch.Tensor: 230 | features, adj = utils.to_tensor(features, adj, device=self.device) 231 | 232 | self.features = features 233 | if utils.is_sparse_tensor(adj): 234 | self.adj_norm = 
utils.normalize_adj_tensor(adj, sparse=True) 235 | else: 236 | self.adj_norm = utils.normalize_adj_tensor(adj) 237 | return self.forward(self.features, self.adj_norm) 238 | 239 | @torch.no_grad() 240 | def predict_unnorm(self, features=None, adj=None): 241 | self.eval() 242 | if features is None and adj is None: 243 | return self.forward(self.features, self.adj_norm) 244 | else: 245 | if type(adj) is not torch.Tensor: 246 | features, adj = utils.to_tensor(features, adj, device=self.device) 247 | 248 | self.features = features 249 | self.adj_norm = adj 250 | return self.forward(self.features, self.adj_norm) 251 | 252 | 253 | 254 | class MyLinear(Module): 255 | 256 | def __init__(self, in_features, out_features, with_bias=True): 257 | super(MyLinear, self).__init__() 258 | self.in_features = in_features 259 | self.out_features = out_features 260 | self.weight = Parameter(torch.FloatTensor(in_features, out_features)) 261 | if with_bias: 262 | self.bias = Parameter(torch.FloatTensor(out_features)) 263 | else: 264 | self.register_parameter('bias', None) 265 | self.reset_parameters() 266 | 267 | def reset_parameters(self): 268 | # stdv = 1. / math.sqrt(self.weight.size(1)) 269 | stdv = 1. / math.sqrt(self.weight.T.size(1)) 270 | self.weight.data.uniform_(-stdv, stdv) 271 | if self.bias is not None: 272 | self.bias.data.uniform_(-stdv, stdv) 273 | 274 | def forward(self, input): 275 | if input.data.is_sparse: 276 | support = torch.spmm(input, self.weight) 277 | else: 278 | support = torch.mm(input, self.weight) 279 | output = support 280 | if self.bias is not None: 281 | return output + self.bias 282 | else: 283 | return output 284 | 285 | def __repr__(self): 286 | return self.__class__.__name__ + ' (' \ 287 | + str(self.in_features) + ' -> ' \ 288 | + str(self.out_features) + ')' 289 | 290 | 291 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.7.0 2 | torch_geometric==1.6.3 3 | scipy==1.6.2 4 | numpy==1.19.2 5 | ogb==1.3.0 6 | tqdm==4.59.0 7 | torch_sparse==0.6.9 8 | torchvision==0.8.0 9 | configs==3.0.3 10 | deeprobust==0.2.4 11 | scikit_learn==1.0.2 12 | -------------------------------------------------------------------------------- /train_SGDD.py: -------------------------------------------------------------------------------- 1 | from deeprobust.graph.data import Dataset 2 | import numpy as np 3 | import random 4 | import time 5 | import argparse 6 | import torch 7 | from utils import * 8 | import torch.nn.functional as F 9 | from SGDD_agent import SGDD 10 | from utils_graphsaint import DataGraphSAINT 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--gpu_id', type=int, default=7, help='gpu id') 14 | # parser.add_argument('--dataset', type=str, default='cora') 15 | # parser.add_argument('--dataset', type=str, default='citeseer') 16 | # parser.add_argument('--dataset', type=str, default='flickr') 17 | parser.add_argument('--dataset', type=str, default='ogbn-arxiv') 18 | # parser.add_argument('--dataset', type=str, default='yelpchi') 19 | # parser.add_argument('--dataset', type=str, default='sbm') 20 | parser.add_argument('--dis_metric', type=str, default='ours') 21 | parser.add_argument('--epochs', type=int, default=2000) 22 | # parser.add_argument('--nlayers', type=int, default=3) 23 | parser.add_argument('--nlayers', type=int, default=2) 24 | parser.add_argument('--hidden', type=int, default=256) 25 | 
parser.add_argument('--lr_adj', type=float, default=1e-4) 26 | parser.add_argument('--lr_feat', type=float, default=1e-4) 27 | # parser.add_argument('--lr_adj', type=float, default=0.01) 28 | # parser.add_argument('--lr_feat', type=float, default=0.01) 29 | parser.add_argument('--lr_model', type=float, default=0.01) 30 | parser.add_argument('--weight_decay', type=float, default=0.0) 31 | parser.add_argument('--dropout', type=float, default=0.0) 32 | parser.add_argument('--normalize_features', type=bool, default=True) 33 | parser.add_argument('--keep_ratio', type=float, default=1.0) 34 | parser.add_argument('--reduction_rate', type=float, default=0.1) 35 | parser.add_argument('--seed', type=int, default=15, help='Random seed.') # [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20] 36 | parser.add_argument('--beta', type=float, default=0.5, help='regularization term.') 37 | parser.add_argument('--ep_ratio', type=float, default=0.5, help='control the ratio of direct \ 38 | edges predict term in the graph.') 39 | parser.add_argument('--sinkhorn_iter', type=int, default=5, help='use sinkhorn iteration to \ 40 | warm-up the transport plan.') 41 | parser.add_argument('--opt_scale', type=float, default=1e-10, help='control the scale of the opt loss') 42 | parser.add_argument("--ignr_epochs", type=int, default=400, help="use the few epochs to warm-up structure learning") 43 | # parser.add_argument('--mx_size', type=int, default=2708, help='max size of the matrix to') 44 | parser.add_argument('--debug', type=int, default=0) 45 | parser.add_argument('--option', type=int, default=0) 46 | parser.add_argument('--sgc', type=int, default=1) 47 | parser.add_argument('--inner', type=int, default=0) 48 | parser.add_argument('--outer', type=int, default=20) 49 | parser.add_argument('--save', type=int, default=1) 50 | parser.add_argument('--one_step', type=int, default=0) 51 | parser.add_argument('--mode', type=str, default='disabled', help='whether to use the wandb') 52 | args = parser.parse_args() 53 | 54 | torch.cuda.set_device(args.gpu_id) 55 | 56 | # random seed setting 57 | random.seed(args.seed) 58 | np.random.seed(args.seed) 59 | torch.manual_seed(args.seed) 60 | torch.cuda.manual_seed(args.seed) 61 | 62 | print(args) 63 | 64 | data_graphsaint = ['flickr', 'reddit', 'ogbn-arxiv'] 65 | if args.dataset in data_graphsaint: 66 | data = DataGraphSAINT(args.dataset) 67 | data_full = data.data_full 68 | else: 69 | data_full = get_dataset(args.dataset, args.normalize_features) 70 | data = Transd2Ind(data_full, keep_ratio=args.keep_ratio) 71 | 72 | if data_full.adj.shape[0] < 5000: 73 | args.mx_size = data_full.adj.shape[0] 74 | else: 75 | args.mx_size = 5000 76 | data_full.adj_mx = data_full.adj[:args.mx_size, :args.mx_size] 77 | agent = SGDD(data, args, device='cuda') 78 | 79 | agent.train() 80 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import numpy as np 3 | import scipy.sparse as sp 4 | import torch 5 | import torch_geometric.transforms as T 6 | from ogb.nodeproppred import PygNodePropPredDataset, Evaluator 7 | from deeprobust.graph.data import Dataset 8 | from deeprobust.graph.utils import get_train_val_test 9 | from torch_geometric.utils import train_test_split_edges 10 | from sklearn.model_selection import train_test_split 11 | from sklearn import metrics 12 | import numpy as np 13 | import torch.nn.functional as F 14 | from sklearn.preprocessing 
import StandardScaler 15 | from deeprobust.graph.utils import * 16 | from torch_geometric.data import NeighborSampler 17 | from torch_geometric.utils import add_remaining_self_loops, to_undirected, remove_self_loops 18 | from torch_geometric.datasets import Planetoid, StochasticBlockModelDataset 19 | 20 | from typing import Any 21 | from dgl.data import FraudDataset 22 | 23 | try: 24 | from oolongTool.PostMessage import Wechat 25 | P = Wechat.P 26 | __builtins__["P"] = P 27 | except: 28 | P = "P" 29 | 30 | 31 | def get_dataset(name: str, normalize_features=False, transform=None, if_dpr=True): 32 | path = osp.join(osp.dirname(osp.realpath(__file__)), 'data', name) 33 | if name in ['cora', 'citeseer', 'pubmed']: 34 | dataset = Planetoid(path, name) 35 | elif name in ['ogbn-arxiv']: 36 | dataset = PygNodePropPredDataset(name='ogbn-arxiv') 37 | elif name in ['yelpchi', 'amazon']: 38 | name = 'yelp' if name == 'yelpchi' else 'amazon' 39 | dataset = FraudDataset(name, raw_dir=path) 40 | dataset = from_dgl(dataset[0], name=name, hetero=False) 41 | elif name.lower() == "sbm": 42 | num_nodes = [2000] * 3 43 | edge_probs = [[0.1, 0.05, 0.02], 44 | [0.05, 0.1, 0.02], 45 | [0.02, 0.02, 0.1]] 46 | dataset = StochasticBlockModelDataset(path, num_nodes, edge_probs, num_channels=32) 47 | dataset.name = "SBM" 48 | else: 49 | raise NotImplementedError 50 | 51 | if transform is not None and normalize_features: 52 | dataset.transform = T.Compose([T.NormalizeFeatures(), transform]) 53 | elif normalize_features: 54 | dataset.transform = T.NormalizeFeatures() 55 | elif transform is not None: 56 | dataset.transform = transform 57 | 58 | dpr_data = Pyg2Dpr(dataset) 59 | if name in ['ogbn-arxiv', 'sbm']: 60 | 61 | 62 | feat, idx_train = dpr_data.features, dpr_data.idx_train 63 | feat_train = feat[idx_train] 64 | scaler = StandardScaler() 65 | scaler.fit(feat_train) 66 | feat = scaler.transform(feat) 67 | dpr_data.features = feat 68 | 69 | return dpr_data 70 | 71 | 72 | class Pyg2Dpr(Dataset): 73 | def __init__(self, pyg_data, **kwargs): 74 | try: 75 | splits = pyg_data.get_idx_split() 76 | except: 77 | pass 78 | 79 | dataset_name = pyg_data.name 80 | try: 81 | pyg_data = pyg_data[0] 82 | except TypeError: 83 | pyg_data = pyg_data 84 | n = pyg_data.num_nodes 85 | 86 | if not n: 87 | n = pyg_data.x.shape[0] 88 | 89 | if dataset_name == 'ogbn-arxiv': 90 | pyg_data.edge_index = to_undirected(pyg_data.edge_index, pyg_data.num_nodes) 91 | from torch_geometric.data import HeteroData 92 | 93 | self.adj = sp.csr_matrix((np.ones(pyg_data.edge_index.shape[1]), 94 | (pyg_data.edge_index[0], pyg_data.edge_index[1])), shape=(n, n)) 95 | 96 | self.features = pyg_data.x.numpy() 97 | self.labels = pyg_data.y.numpy() 98 | 99 | if len(self.labels.shape) == 2 and self.labels.shape[1] == 1: 100 | self.labels = self.labels.reshape(-1) 101 | 102 | if hasattr(pyg_data, 'train_mask'): 103 | 104 | self.idx_train = mask_to_index(pyg_data.train_mask, n) 105 | self.idx_val = mask_to_index(pyg_data.val_mask, n) 106 | self.idx_test = mask_to_index(pyg_data.test_mask, n) 107 | self.name = 'Pyg2Dpr' 108 | else: 109 | try: 110 | 111 | self.idx_train = splits['train'] 112 | self.idx_val = splits['valid'] 113 | self.idx_test = splits['test'] 114 | self.name = 'Pyg2Dpr' 115 | except: 116 | 117 | self.idx_train, self.idx_val, self.idx_test = get_train_val_test( 118 | nnodes=n, val_size=0.1, test_size=0.8, stratify=self.labels) 119 | 120 | 121 | def mask_to_index(index, size): 122 | all_idx = np.arange(size) 123 | return all_idx[index] 124 | 125 | def 
index_to_mask(index, size): 126 | mask = torch.zeros((size, ), dtype=torch.bool) 127 | mask[index] = 1 128 | return mask 129 | 130 | 131 | 132 | class Transd2Ind: 133 | 134 | 135 | def __init__(self, dpr_data, keep_ratio): 136 | idx_train, idx_val, idx_test = dpr_data.idx_train, dpr_data.idx_val, dpr_data.idx_test 137 | adj, features, labels = dpr_data.adj, dpr_data.features, dpr_data.labels 138 | self.nclass = labels.max()+1 139 | self.adj_full, self.feat_full, self.labels_full = adj, features, labels 140 | self.idx_train = np.array(idx_train) 141 | self.idx_val = np.array(idx_val) 142 | self.idx_test = np.array(idx_test) 143 | 144 | if keep_ratio < 1: 145 | idx_train, _ = train_test_split(idx_train, 146 | random_state=None, 147 | train_size=keep_ratio, 148 | test_size=1-keep_ratio, 149 | stratify=labels[idx_train]) 150 | 151 | self.adj_train = adj[np.ix_(idx_train, idx_train)] 152 | self.adj_val = adj[np.ix_(idx_val, idx_val)] 153 | self.adj_test = adj[np.ix_(idx_test, idx_test)] 154 | print('size of adj_train:', self.adj_train.shape) 155 | print('#edges in adj_train:', self.adj_train.sum()) 156 | 157 | self.labels_train = labels[idx_train] 158 | self.labels_val = labels[idx_val] 159 | self.labels_test = labels[idx_test] 160 | 161 | self.feat_train = features[idx_train] 162 | self.feat_val = features[idx_val] 163 | self.feat_test = features[idx_test] 164 | 165 | self.class_dict = None 166 | self.samplers = None 167 | self.class_dict2 = None 168 | 169 | def retrieve_class(self, c, num=256): 170 | if self.class_dict is None: 171 | self.class_dict = {} 172 | for i in range(self.nclass): 173 | self.class_dict['class_%s'%i] = (self.labels_train == i) 174 | idx = np.arange(len(self.labels_train)) 175 | idx = idx[self.class_dict['class_%s'%c]] 176 | return np.random.permutation(idx)[:num] 177 | 178 | def retrieve_class_sampler(self, c, adj, transductive, num=256, args=None): 179 | if self.class_dict2 is None: 180 | self.class_dict2 = {} 181 | for i in range(self.nclass): 182 | if transductive: 183 | idx = self.idx_train[self.labels_train == i] 184 | else: 185 | idx = np.arange(len(self.labels_train))[self.labels_train==i] 186 | self.class_dict2[i] = idx 187 | 188 | if args.nlayers == 1: 189 | sizes = [15] 190 | if args.nlayers == 2: 191 | sizes = [10, 5] 192 | 193 | if args.nlayers == 3: 194 | sizes = [15, 10, 5] 195 | if args.nlayers == 4: 196 | sizes = [15, 10, 5, 5] 197 | if args.nlayers == 5: 198 | sizes = [15, 10, 5, 5, 5] 199 | 200 | 201 | if self.samplers is None: 202 | self.samplers = [] 203 | for i in range(self.nclass): 204 | node_idx = torch.LongTensor(self.class_dict2[i]) 205 | self.samplers.append(NeighborSampler(adj, 206 | node_idx=node_idx, 207 | sizes=sizes, batch_size=num, 208 | num_workers=12, return_e_id=False, 209 | num_nodes=adj.size(0), 210 | shuffle=True)) 211 | batch = np.random.permutation(self.class_dict2[c])[:num] 212 | out = self.samplers[c].sample(batch) 213 | return out 214 | 215 | def retrieve_class_multi_sampler(self, c, adj, transductive, num=256, args=None): 216 | if self.class_dict2 is None: 217 | self.class_dict2 = {} 218 | for i in range(self.nclass): 219 | if transductive: 220 | idx = self.idx_train[self.labels_train == i] 221 | else: 222 | idx = np.arange(len(self.labels_train))[self.labels_train==i] 223 | self.class_dict2[i] = idx 224 | 225 | 226 | if self.samplers is None: 227 | self.samplers = [] 228 | for l in range(2): 229 | layer_samplers = [] 230 | sizes = [15] if l == 0 else [10, 5] 231 | for i in range(self.nclass): 232 | node_idx = 
torch.LongTensor(self.class_dict2[i]) 233 | layer_samplers.append(NeighborSampler(adj, 234 | node_idx=node_idx, 235 | sizes=sizes, batch_size=num, 236 | num_workers=12, return_e_id=False, 237 | num_nodes=adj.size(0), 238 | shuffle=True)) 239 | self.samplers.append(layer_samplers) 240 | batch = np.random.permutation(self.class_dict2[c])[:num] 241 | out = self.samplers[args.nlayers-1][c].sample(batch) 242 | return out 243 | 244 | 245 | 246 | def match_loss(gw_syn, gw_real, args, device): 247 | dis = torch.tensor(0.0).to(device) 248 | 249 | if args.dis_metric == 'ours': 250 | 251 | for ig in range(len(gw_real)): 252 | gwr = gw_real[ig] 253 | gws = gw_syn[ig] 254 | dis += distance_wb(gwr, gws) 255 | 256 | elif args.dis_metric == 'mse': 257 | gw_real_vec = [] 258 | gw_syn_vec = [] 259 | for ig in range(len(gw_real)): 260 | gw_real_vec.append(gw_real[ig].reshape((-1))) 261 | gw_syn_vec.append(gw_syn[ig].reshape((-1))) 262 | gw_real_vec = torch.cat(gw_real_vec, dim=0) 263 | gw_syn_vec = torch.cat(gw_syn_vec, dim=0) 264 | dis = torch.sum((gw_syn_vec - gw_real_vec)**2) 265 | 266 | elif args.dis_metric == 'cos': 267 | gw_real_vec = [] 268 | gw_syn_vec = [] 269 | for ig in range(len(gw_real)): 270 | gw_real_vec.append(gw_real[ig].reshape((-1))) 271 | gw_syn_vec.append(gw_syn[ig].reshape((-1))) 272 | gw_real_vec = torch.cat(gw_real_vec, dim=0) 273 | gw_syn_vec = torch.cat(gw_syn_vec, dim=0) 274 | dis = 1 - torch.sum(gw_real_vec * gw_syn_vec, dim=-1) / (torch.norm(gw_real_vec, dim=-1) * torch.norm(gw_syn_vec, dim=-1) + 0.000001) 275 | 276 | else: 277 | exit('DC error: unknown distance function') 278 | 279 | return dis 280 | 281 | def distance_wb(gwr, gws): 282 | shape = gwr.shape 283 | 284 | 285 | if len(gwr.shape) == 2: 286 | gwr = gwr.T 287 | gws = gws.T 288 | 289 | if len(shape) == 4: 290 | gwr = gwr.reshape(shape[0], shape[1] * shape[2] * shape[3]) 291 | gws = gws.reshape(shape[0], shape[1] * shape[2] * shape[3]) 292 | elif len(shape) == 3: 293 | gwr = gwr.reshape(shape[0], shape[1] * shape[2]) 294 | gws = gws.reshape(shape[0], shape[1] * shape[2]) 295 | elif len(shape) == 2: 296 | tmp = 'do nothing' 297 | elif len(shape) == 1: 298 | gwr = gwr.reshape(1, shape[0]) 299 | gws = gws.reshape(1, shape[0]) 300 | return 0 301 | 302 | dis_weight = torch.sum(1 - torch.sum(gwr * gws, dim=-1) / (torch.norm(gwr, dim=-1) * torch.norm(gws, dim=-1) + 0.000001)) 303 | dis = dis_weight 304 | return dis 305 | 306 | 307 | 308 | def calc_f1(y_true, y_pred,is_sigmoid): 309 | if not is_sigmoid: 310 | y_pred = np.argmax(y_pred, axis=1) 311 | else: 312 | y_pred[y_pred > 0.5] = 1 313 | y_pred[y_pred <= 0.5] = 0 314 | return metrics.f1_score(y_true, y_pred, average="micro"), metrics.f1_score(y_true, y_pred, average="macro") 315 | 316 | def evaluate(output, labels, args): 317 | data_graphsaint = ['yelp', 'ppi', 'ppi-large', 'flickr', 'reddit', 'amazon'] 318 | if args.dataset in data_graphsaint: 319 | labels = labels.cpu().numpy() 320 | output = output.cpu().numpy() 321 | if len(labels.shape) > 1: 322 | micro, macro = calc_f1(labels, output, is_sigmoid=True) 323 | else: 324 | micro, macro = calc_f1(labels, output, is_sigmoid=False) 325 | print("Test set results:", "F1-micro= {:.4f}".format(micro), 326 | "F1-macro= {:.4f}".format(macro)) 327 | else: 328 | loss_test = F.nll_loss(output, labels) 329 | acc_test = accuracy(output, labels) 330 | print("Test set results:", 331 | "loss= {:.4f}".format(loss_test.item()), 332 | "accuracy= {:.4f}".format(acc_test.item())) 333 | return 334 | 335 | 336 | from torchvision import datasets, 
transforms 337 | def get_mnist(data_path): 338 | channel = 1 339 | im_size = (28, 28) 340 | num_classes = 10 341 | mean = [0.1307] 342 | std = [0.3081] 343 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]) 344 | dst_train = datasets.MNIST(data_path, train=True, download=True, transform=transform) 345 | dst_test = datasets.MNIST(data_path, train=False, download=True, transform=transform) 346 | class_names = [str(c) for c in range(num_classes)] 347 | 348 | labels = [] 349 | feat = [] 350 | for x, y in dst_train: 351 | feat.append(x.view(1, -1)) 352 | labels.append(y) 353 | feat = torch.cat(feat, axis=0).numpy() 354 | from utils_graphsaint import GraphData 355 | adj = sp.eye(len(feat)) 356 | idx = np.arange(len(feat)) 357 | dpr_data = GraphData(adj-adj, feat, labels, idx, idx, idx) 358 | from deeprobust.graph.data import Dpr2Pyg 359 | return Dpr2Pyg(dpr_data) 360 | 361 | def regularization(adj, x, eig_real=None): 362 | 363 | loss = 0 364 | 365 | loss += feature_smoothing(adj, x) 366 | return loss 367 | 368 | def maxdegree(adj): 369 | n = adj.shape[0] 370 | return F.relu(max(adj.sum(1))/n - 0.5) 371 | 372 | def sparsity2(adj): 373 | n = adj.shape[0] 374 | loss_degree = - torch.log(adj.sum(1)).sum() / n 375 | loss_fro = torch.norm(adj) / n 376 | return 0 * loss_degree + loss_fro 377 | 378 | def sparsity(adj): 379 | n = adj.shape[0] 380 | thresh = n * n * 0.01 381 | return F.relu(adj.sum()-thresh) 382 | 383 | 384 | def feature_smoothing(adj, X): 385 | adj = (adj.t() + adj)/2 386 | rowsum = adj.sum(1) 387 | r_inv = rowsum.flatten() 388 | D = torch.diag(r_inv) 389 | L = D - adj 390 | 391 | r_inv = r_inv + 1e-8 392 | r_inv = r_inv.pow(-1/2).flatten() 393 | r_inv[torch.isinf(r_inv)] = 0. 394 | r_mat_inv = torch.diag(r_inv) 395 | 396 | L = r_mat_inv @ L @ r_mat_inv 397 | 398 | XLXT = torch.matmul(torch.matmul(X.t(), L), X) 399 | loss_smooth_feat = torch.trace(XLXT) 400 | 401 | return loss_smooth_feat 402 | 403 | def row_normalize_tensor(mx): 404 | rowsum = mx.sum(1) 405 | r_inv = rowsum.pow(-1).flatten() 406 | 407 | r_mat_inv = torch.diag(r_inv) 408 | mx = r_mat_inv @ mx 409 | return mx 410 | 411 | 412 | def from_dgl(g: Any, name: str, hetero=True): 413 | import dgl, torch 414 | 415 | from torch_geometric.data import Data, HeteroData 416 | 417 | 418 | 419 | 420 | if g.is_homogeneous: 421 | data = Data() 422 | data.edge_index = torch.stack(g.edges(), dim=0) 423 | 424 | for attr, value in g.ndata.items(): 425 | data[attr] = value 426 | for attr, value in g.edata.items(): 427 | data[attr] = value 428 | 429 | return data 430 | 431 | data = HeteroData() 432 | data.name = name 433 | data.num_nodes = g.number_of_nodes() 434 | 435 | for node_type in g.ntypes: 436 | for attr, value in g.nodes[node_type].data.items(): 437 | data[node_type][attr] = value 438 | 439 | for edge_type in g.canonical_etypes: 440 | row, col = g.edges(form="uv", etype=edge_type) 441 | data[edge_type].edge_index = torch.stack([row, col], dim=0) 442 | for attr, value in g.edge_attr_schemes(edge_type).items(): 443 | data[edge_type][attr] = value 444 | 445 | if not hetero: 446 | edge_index_list = [] 447 | for edge_type in g.canonical_etypes: 448 | edge_index_list.append(data[edge_type].edge_index) 449 | data.edge_index = remove_self_loops(torch.cat(edge_index_list, dim=1))[0] 450 | 451 | data.x = data.node_stores[0]['feature'] 452 | 453 | data.y = data.node_stores[0]['label'] 454 | 455 | return data 456 | 457 | -------------------------------------------------------------------------------- 
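match_loss/distance_wb above implement the gradient-matching distance used during
condensation: for each parameter, the real and synthetic gradients are compared with a
summed one-minus-cosine distance ('ours'), or alternatively with MSE or a single global
cosine distance; 1-D (bias) gradients contribute zero in distance_wb. A minimal,
hypothetical sketch of calling it directly — the SimpleNamespace args and the random
gradient tensors are made up for illustration, and utils.py's own dependencies (e.g.
dgl) must be installed for the import to succeed:

import torch
from types import SimpleNamespace
from utils import match_loss

args = SimpleNamespace(dis_metric='ours')            # 'ours' | 'mse' | 'cos'
# pretend per-parameter gradients of a small two-layer model (weight matrices only)
gw_real = [torch.randn(16, 32), torch.randn(32, 7)]
gw_syn = [g + 0.1 * torch.randn_like(g) for g in gw_real]

dis = match_loss(gw_syn, gw_real, args, device='cpu')
print(dis)  # scalar tensor; smaller means the synthetic gradients track the real ones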
/utils_copt.py: -------------------------------------------------------------------------------- 1 | 2 | import pandas as pd 3 | import numpy as np 4 | import os 5 | import argparse 6 | import torch 7 | import math 8 | import numpy.linalg as linalg 9 | import matplotlib.pyplot as plt 10 | import networkx as nx 11 | import pickle 12 | import warnings 13 | import sklearn.metrics 14 | warnings.filterwarnings('ignore') 15 | 16 | import pdb 17 | 18 | 19 | 20 | 21 | res_dir = 'results' 22 | data_dir = 'data' 23 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 24 | 25 | 26 | def parse_args(): 27 | parser = argparse.ArgumentParser() 28 | 29 | parser.add_argument("--verbose", dest='verbose', action='store_const', default=False, const=True, help='Print out verbose info during optimization') 30 | parser.add_argument("--seed", dest='fix_seed', action='store_const', default=False, const=True, help='Fix seed for reproducibility and fair comparison.') 31 | parser.add_argument("--n_epochs", dest='n_epochs', type=int, default=300, help='Number of COPT iterations during training.') 32 | parser.add_argument("--hike", dest='lr_hike', action='store_const', default=False, const=True, help='Use learning rate hiking. This is recommended for most applications.') 33 | parser.add_argument("--hike_interval", dest='hike_interval', type=int, default=15, help='Number of iterations having *low loss* between lr hikes.') 34 | parser.add_argument("--fast", dest='fast', action='store_const', default=False, const=True, help='Use fast kernel, in particular multiscale laplacian') 35 | parser.add_argument("--early_stopping", action='store_true', default=False, help='stop training early if loss repeatedly reaches below some threshold, useful for sketching.') 36 | 37 | parser.add_argument("--dataset_type", type=str, default='synthetic', help='dataset type, can be "synthetic" or "real"') 38 | parser.add_argument("--dataset_name", type=str, default='PROTEINS', help='dataset name, will be ignored if dataset_type is synthetic') 39 | parser.add_argument("--m", default=30, type=int, help='Number of vertices in graph X, e.g. query graph, useful especially when testing') 40 | parser.add_argument("--n", default=30, type=int, help='Number of vertices in graph Y, e.g. dataset graph, useful especially when testing') 41 | 42 | parser.add_argument("--sinkhorn_iter", type=int, default=10, help='Number of Sinkhorn scaling iterations during optimization') 43 | parser.add_argument("--grid_search", dest='grid_search', action='store_const', default=False, const=True, help='grid search for learning SVC classifier') 44 | parser.add_argument("--compress_fac", default=-1, type=int, help='Factor of compression, e.g. 
2 means reduce to half as many vertices') 45 | parser.add_argument("--got_it", dest='st_it', type=int, default=5, help='Number of Sinkhorn iterations for GOT') 46 | parser.add_argument("--got_tau", dest='st_tau', type=float, default=1, help='Number of Sinkhorn iterations for GOT') 47 | parser.add_argument("--got_n_sample", dest='st_n_samples', type=int, default=10, help='Number of samples in stochastic sampling for GOT') 48 | parser.add_argument("--got_n_epochs", dest='st_epochs', type=int, default=1000, help='Number of Sinkhorn iterations for GOT') 49 | parser.add_argument("--got_lr", dest='st_lr', type=float, default=.5, help='Number of Sinkhorn iterations for GOT') 50 | parser.add_argument("--gw_alpha", dest='gw_alpha', type=float, default=.8, help='alpha parameter for GW') 51 | parser.add_argument("--gw_metric", dest='gw_metric', type=str, default='sqeuclidean', help='features metric for GW') 52 | 53 | opt = parser.parse_args() 54 | return opt 55 | 56 | def create_graph_lap(n): 57 | """ 58 | Create graph laplacian of given size. 59 | """ 60 | g = nx.random_geometric_graph(n, .5) 61 | 62 | 63 | Lx = nx.laplacian_matrix(g, range(n)) 64 | Lx = np.array(Lx.todense()) 65 | Lx = np.array([[0, -.3, -.9],[-.3, 0, 0],[-.9, 0,0]]) 66 | return Lx 67 | 68 | def graph_to_lap(g): 69 | """ 70 | Get Laplacian from nx graph g 71 | """ 72 | if not isinstance(g, nx.Graph): 73 | g = nx.from_numpy_array(g) 74 | Lx = nx.laplacian_matrix(g).todense() 75 | Lx = torch.from_numpy(Lx).to(dtype=torch_dtype) 76 | return Lx 77 | 78 | def lap_to_graph(L): 79 | 80 | if isinstance(L, torch.Tensor): 81 | L = L.cpu().numpy() 82 | L = L.copy() 83 | np.fill_diagonal(L, 0) 84 | return nx.from_numpy_array(-L) 85 | 86 | def canonicalize_mx(mx): 87 | mx1 = mx.clone() 88 | diag = mx.diag() 89 | idx = diag.argsort(dim=0) 90 | n_mx = len(mx) 91 | for i in range(n_mx): 92 | mx[i] = mx1[idx[i]] 93 | mx1 = mx.clone() 94 | for i in range(n_mx): 95 | mx[:, i] = mx1[:, idx[i]] 96 | 97 | return mx 98 | 99 | def create_graph(n, gtype=None, seed=None, params={}): 100 | total_iter = 100 101 | cnt = 0 102 | while True: 103 | if gtype == 'block': 104 | 105 | if params['n_blocks'] == 3: 106 | m = n // 3 107 | g = nx.stochastic_block_model([m, m, n-2*m],[[0.98,0.01,.01],[0.01,0.98,.01],[0.01,.01,.98]], seed=seed) 108 | elif params['n_blocks'] == 4: 109 | m = n // 4 110 | g = nx.stochastic_block_model([m, m, m, n-3*m],[[.97,0.01,0.01,.01],[.01,0.97,0.01,.01],[.01,0.01,0.97,.01],[.01,0.01,0.01,.97] ], seed=seed) 111 | else: 112 | m = n // 2 113 | g = nx.stochastic_block_model([m, n-m],[[0.99,0.01],[0.01,0.99]], seed=seed) 114 | elif gtype == 'strogatz': 115 | g = nx.connected_watts_strogatz_graph(n, max(n//4, 3), p=.05, seed=seed) 116 | elif gtype == 'random_regular': 117 | 118 | d = max(n//8, 2) 119 | if n*d % 2 == 1: 120 | n += 1 121 | g = nx.random_regular_graph(d, n, seed=seed) 122 | elif gtype == 'binomial': 123 | 124 | prob = params['prob'] 125 | g = nx.binomial_graph(n, prob, seed=seed) 126 | elif gtype == 'barabasi': 127 | d = max(4, n//6) 128 | g = nx.barabasi_albert_graph(n, d, seed=seed) 129 | elif gtype == 'powerlaw_tree': 130 | g = nx.random_powerlaw_tree(n, gamma=3, tries=1300, seed=seed) 131 | elif gtype == 'caveman': 132 | n_cliques = params['n_cliques'] 133 | clique_sz = params['clique_sz'] 134 | assert n_cliques * clique_sz == n 135 | g = nx.connected_caveman_graph(n_cliques, clique_sz) 136 | 137 | 138 | 139 | elif gtype == 'random_geometric': 140 | radius = params['radius'] 141 | g = nx.random_geometric_graph(n, radius, 
seed=seed) 142 | elif gtype == 'barbell': 143 | 144 | g = nx.barbell_graph(n//2, 1) 145 | elif gtype == 'ladder': 146 | g = nx.ladder_graph(n) 147 | elif gtype == 'grid': 148 | g = nx.grid_graph([n,n]) 149 | elif gtype == 'hypercube': 150 | g = nx.hypercube_graph(n) 151 | elif gtype == 'pappus': 152 | g = nx.pappus_graph() 153 | elif gtype == 'star': 154 | g = nx.star_graph(n) 155 | elif gtype == 'cycle': 156 | g = nx.cycle_graph(n) 157 | elif gtype == 'wheel': 158 | g = nx.wheel_graph(n) 159 | elif gtype == 'lollipop': 160 | g = nx.lollipop_graph(n//2, 1) 161 | else: 162 | raise Exception('graph type not supported ', gtype) 163 | 164 | remove_isolates(g) 165 | cnt += 1 166 | if nx.is_connected(g) or cnt > total_iter: 167 | if cnt > total_iter: 168 | g = g.subgraph(sorted(nx.connected_components(g), key=len)[-1]).copy() 169 | break 170 | return g 171 | 172 | def fetch_data(dataset_name): 173 | 174 | 175 | data = torch.load('{}_lap.pt'.format(dataset_name)) 176 | return data['lap'], data['labels'], data['target'] 177 | 178 | def fetch_data_graphs(dataset_name): 179 | 180 | try: 181 | data = torch.load('data/{}_lap.pt'.format(dataset_name)) 182 | except Exception: 183 | raise Exception('Dataset {} graph data not created yet. More data can be created using the generateData.py script as in README.'.format(dataset_name)) 184 | graphs = [] 185 | for g in data['lap']: 186 | graph = lap_to_graph(g) 187 | ''' 188 | for i in range(len(g)): 189 | graph.add_node(i) 190 | pdb.set_trace() 191 | ''' 192 | graphs.append(graph) 193 | 194 | return graphs, data['labels'], np.array(data['target']) 195 | 196 | def view_graph(L, soft_edge=False, labels=None, name=''): 197 | plt.clf() 198 | if isinstance(L, torch.Tensor): 199 | L = L.cpu().numpy() 200 | L = L.copy() 201 | 202 | np.fill_diagonal(L, 0) 203 | L *= -1 204 | 205 | g = nx.from_numpy_array(L) 206 | fig = plt.figure() 207 | plt.axis('off') 208 | ax = plt.gca() 209 | 210 | 211 | layout = nx.spring_layout(g) 212 | nx.draw_networkx_nodes(g, layout, node_size=500, alpha=0.5, cmap=plt.cm.RdYlGn, node_color='r', ax=ax) 213 | if labels is None: 214 | nx.draw_networkx_labels(g, layout, font_color='w', font_weight='bold', font_size=15, ax=ax) 215 | else: 216 | 217 | nx.draw_networkx_labels(g, layout, labels=labels, font_color='k', font_size=12, ax=ax) 218 | if soft_edge: 219 | 220 | 221 | 222 | elarge = [(u, v) for (u, v, d) in g.edges(data=True) if d['weight'] > .5] 223 | esmall = [(u, v) for (u, v, d) in g.edges(data=True) if d['weight'] <= .5 and d['weight'] > .19] 224 | nx.draw_networkx_edges(g, layout, edgelist=elarge, width=3.5, ax=ax) 225 | nx.draw_networkx_edges(g, layout, edgelist=esmall, wifth=3, ax=ax) 226 | else: 227 | nx.draw_networkx_edges(g, layout, ax=ax) 228 | 229 | fig.savefig('data/view_graph_{}.jpg'.format(name)) 230 | print('plot saved to {}'.format('data/view_graph_{}.jpg'.format(name))) 231 | plt.show() 232 | 233 | def plot_confusion(tgt, pred, labels=None, name=''): 234 | """ 235 | Input: 236 | tgt, pred: target and predicted classes. 237 | labels: node labels. 
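        name: one of 'ot_cls', 'gw_cls' or 'combine_cls'; selects the plot title and
            is used as the filename suffix of the figure saved under data/.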
238 | """ 239 | plt.clf() 240 | fig = plt.figure() 241 | 242 | 243 | ax = plt.gca() 244 | mx = sklearn.metrics.confusion_matrix(tgt, pred) 245 | 246 | img = plt.matshow(mx) 247 | path = 'data/confusion_mx_{}.jpg'.format(name) 248 | 249 | ax.legend() 250 | name2label = {'gw_cls':'GW', 'ot_cls':'COPT', 'combine_cls':'[COPT + GW]'} 251 | plt.title('{} Predictions'.format(name2label[name]), fontsize=20) 252 | plt.savefig(path) 253 | print('fig saved to ', path) 254 | 255 | 256 | def plot_search_acc(): 257 | plt.clf() 258 | x_l = [1, 3, 5, 10, 15] 259 | ot_acc = [0.9721962,0.9866296,0.9941481333,0.998,0.998] 260 | svd_acc = [0.814787,0.894257037,0.9349997,0.9774814073,0.9888888777] 261 | ot_std = [0.005516745247,0.003174584773,0.0002565744596,0.003464101615,0.003464101615] 262 | svd_std = [0.01152346983,0.02409426815,0.02220130725,0.009109368462,0.009622514205] 263 | 264 | fig = plt.figure() 265 | 266 | 267 | plt.errorbar(x_l, ot_acc, yerr=ot_std, marker='o', label='COPT sketches') 268 | 269 | plt.errorbar(x_l, svd_acc, yerr=svd_std, marker='+', label='Spectral projections') 270 | 271 | plt.title('Classification acc of [COPT, GW] vs [spectral projections, GW] pipelines') 272 | plt.legend() 273 | plt.xlabel('Number of candidates allowed to 2nd stage') 274 | plt.ylabel('Classification accuracy') 275 | path = 'data/search_acc.jpg' 276 | fig.savefig(path) 277 | print('fig saved to ', path) 278 | 279 | 280 | 281 | def create_dir(path): 282 | if not os.path.exists(path): 283 | os.mkdir(path) 284 | 285 | 286 | 287 | def normalizedMI(ar1, ar2): 288 | score = sklearn.metrics.normalized_mutual_info_score(ar1, ar2) 289 | return score 290 | 291 | 292 | def remove_edges(L, n_remove=1, seed=None): 293 | rng = np.random.RandomState(seed) 294 | 295 | G = lap_to_graph(L) 296 | edges = np.triu(L, k=1).nonzero() 297 | 298 | removed = 0 299 | for idx in rng.permutation(edges[0].size): 300 | u, v = edges[0][idx], edges[1][idx] 301 | 302 | G.remove_edge(u,v) 303 | if nx.is_connected(G): 304 | removed += 1 305 | else: 306 | G.add_edge(u,v) 307 | if removed == n_remove: 308 | break 309 | 310 | return graph_to_lap(G) 311 | 312 | def permute_nodes(l1, seed=None): 313 | ''' 314 | Adapted from GOT. 
315 | ''' 316 | np.random.seed(seed) 317 | n = len(l1) 318 | idx = np.random.permutation(n) 319 | P_true = np.eye(n) 320 | P_true = P_true[idx] 321 | l2 = np.array(P_true @ l1 @ P_true.T) 322 | 323 | return np.double(l2), idx 324 | 325 | def symmetrize(mx, inplace=True): 326 | m, n = mx.size() 327 | assert m == n 328 | mask = torch.ones(m, m) 329 | mask = torch.tril(mask, diagonal=-1) 330 | 331 | if not inplace: 332 | mx = mx.clone() 333 | mx[mask > 0] = torch.triu(mx, diagonal=1).t()[mask > 0] 334 | return mx 335 | 336 | def remove_isolates(g): 337 | g.remove_nodes_from(list(nx.isolates(g))) 338 | 339 | def load_data(fname): 340 | with open(fname, 'rb') as f: 341 | data = pickle.load(f) 342 | graphs = data['graphs'] 343 | labels = data['labels'] 344 | return graphs, labels 345 | 346 | def read_lines(path): 347 | with open(path, 'r') as file: 348 | return file.readlines() 349 | 350 | def parse_cls(st): 351 | ar = st.split('., ') 352 | return [int(i) for i in ar] 353 | 354 | def plot_confusions(): 355 | 356 | ot_cls = '0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 4., 1., 1., 4., 1., 1., 4., 1., 4., 4., 4., 4., 4., 4., 4., 4., 4., 4., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 6., 6., 6., 6., 6., 6., 6., 6., 6., 6., 6., 9., 6., 9., 9., 9., 6., 6., 6., 9' 357 | gw_cls = '0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 6., 1., 1., 1., 1., 6., 6., 1., 6., 4., 4., 4., 4., 4., 4., 4., 4., 4., 4., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 6., 5., 6., 6., 6., 6., 5., 6., 6., 6., 9., 9., 9., 9., 9., 9., 9., 9., 9., 9' 358 | combine_cls = '0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 4., 1., 1., 4., 1., 1., 1., 1., 4., 4., 4., 4., 4., 4., 4., 4., 4., 4., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 6., 5., 6., 6., 6., 6., 5., 6., 6., 6., 9., 9., 9., 9., 9., 9., 9., 9., 9., 9' 359 | ot_cls = parse_cls(ot_cls) 360 | gw_cls = parse_cls(gw_cls) 361 | combine_cls = parse_cls(combine_cls) 362 | 363 | tgt_cls = [] 364 | for i in [0, 1, 4, 5, 6, 9]: 365 | tgt_cls.extend([i]*10) 366 | names = ['ot_cls', 'gw_cls', 'combine_cls'] 367 | for i, pred in enumerate([ot_cls, gw_cls, combine_cls]): 368 | plot_confusion(tgt_cls, pred, name=names[i]) 369 | 370 | def plot_convergence(): 371 | ''' 372 | plot Convergence 373 | ''' 374 | conv = '99.77512221820409, 47.0570797352225, 45.454712183248795, 43.46666639511483, 40.9052209134619, 34.46097733377226, 24.431260224317242, 23.51762172579995, 22.485459983772515, 22.106446259828203, 21.534230884767375, 21.463743674621867' 375 | ar = conv.split(', ') 376 | conv_ar = [float(f) for f in ar] 377 | x_l = [20*i for i in list(range(len(conv_ar)))] 378 | fig = plt.figure() 379 | 380 | plt.plot(x_l, conv_ar, '-o', label='COPT distance') 381 | 382 | 383 | plt.title('COPT distance convergence sketching a 50-node graph to 15 nodes') 384 | plt.legend() 385 | plt.xlabel('Number of iterations') 386 | plt.ylabel('COPT distance') 387 | path = 'data/convergence.jpg' 388 | fig.savefig(path) 389 | print('fig saved to ', path) 390 | 391 | if __name__ == '__main__': 392 | """ 393 | For testing utils functions. 
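    Note: the plot_cls_acc() call below refers to a function that is not defined in
    this module; the plotting entry points defined above are plot_confusions(),
    plot_convergence() and plot_search_acc().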
394 | """ 395 | ''' 396 | n = 2 397 | L = create_graph_lap(n) 398 | view_graph(L, soft_edge=True) 399 | ''' 400 | 401 | plot_cls_acc() 402 | 403 | 404 | 405 | -------------------------------------------------------------------------------- /utils_graphsaint.py: -------------------------------------------------------------------------------- 1 | import scipy.sparse as sp 2 | import numpy as np 3 | import sys 4 | import json 5 | import os 6 | from sklearn.preprocessing import StandardScaler 7 | from torch_geometric.data import InMemoryDataset, Data 8 | import torch 9 | from itertools import repeat 10 | from torch_geometric.data import NeighborSampler 11 | 12 | class DataGraphSAINT: 13 | 14 | def __init__(self, dataset, **kwargs): 15 | dataset_str='data/'+dataset+'/' 16 | adj_full = sp.load_npz(dataset_str+'adj_full.npz') 17 | self.nnodes = adj_full.shape[0] 18 | if dataset == 'ogbn-arxiv': 19 | adj_full = adj_full + adj_full.T 20 | adj_full[adj_full > 1] = 1 21 | 22 | role = json.load(open(dataset_str+'role.json','r')) 23 | idx_train = role['tr'] 24 | idx_test = role['te'] 25 | idx_val = role['va'] 26 | 27 | if 'label_rate' in kwargs: 28 | label_rate = kwargs['label_rate'] 29 | if label_rate < 1: 30 | idx_train = idx_train[:int(label_rate*len(idx_train))] 31 | 32 | self.adj_train = adj_full[np.ix_(idx_train, idx_train)] 33 | self.adj_val = adj_full[np.ix_(idx_val, idx_val)] 34 | self.adj_test = adj_full[np.ix_(idx_test, idx_test)] 35 | 36 | feat = np.load(dataset_str+'feats.npy') 37 | 38 | feat_train = feat[idx_train] 39 | scaler = StandardScaler() 40 | scaler.fit(feat_train) 41 | feat = scaler.transform(feat) 42 | 43 | self.feat_train = feat[idx_train] 44 | self.feat_val = feat[idx_val] 45 | self.feat_test = feat[idx_test] 46 | 47 | class_map = json.load(open(dataset_str + 'class_map.json','r')) 48 | labels = self.process_labels(class_map) 49 | 50 | self.labels_train = labels[idx_train] 51 | self.labels_val = labels[idx_val] 52 | self.labels_test = labels[idx_test] 53 | 54 | self.data_full = GraphData(adj_full, feat, labels, idx_train, idx_val, idx_test) 55 | self.class_dict = None 56 | self.class_dict2 = None 57 | 58 | self.adj_full = adj_full 59 | self.feat_full = feat 60 | self.labels_full = labels 61 | self.idx_train = np.array(idx_train) 62 | self.idx_val = np.array(idx_val) 63 | self.idx_test = np.array(idx_test) 64 | self.samplers = None 65 | 66 | def process_labels(self, class_map): 67 | """ 68 | setup vertex property map for output classests 69 | """ 70 | num_vertices = self.nnodes 71 | if isinstance(list(class_map.values())[0], list): 72 | num_classes = len(list(class_map.values())[0]) 73 | self.nclass = num_classes 74 | class_arr = np.zeros((num_vertices, num_classes)) 75 | for k,v in class_map.items(): 76 | class_arr[int(k)] = v 77 | else: 78 | class_arr = np.zeros(num_vertices, dtype=np.int) 79 | for k, v in class_map.items(): 80 | class_arr[int(k)] = v 81 | class_arr = class_arr - class_arr.min() 82 | self.nclass = max(class_arr) + 1 83 | return class_arr 84 | 85 | def retrieve_class(self, c, num=256): 86 | if self.class_dict is None: 87 | self.class_dict = {} 88 | for i in range(self.nclass): 89 | self.class_dict['class_%s'%i] = (self.labels_train == i) 90 | idx = np.arange(len(self.labels_train)) 91 | idx = idx[self.class_dict['class_%s'%c]] 92 | return np.random.permutation(idx)[:num] 93 | 94 | def retrieve_class_sampler(self, c, adj, transductive, num=256, args=None): 95 | if args.nlayers == 1: 96 | sizes = [30] 97 | if args.nlayers == 2: 98 | if args.dataset in ['reddit', 
'flickr']: 99 | if args.option == 0: 100 | sizes = [15, 8] 101 | if args.option == 1: 102 | sizes = [20, 10] 103 | if args.option == 2: 104 | sizes = [25, 10] 105 | else: 106 | sizes = [10, 5] 107 | 108 | if self.class_dict2 is None: 109 | print(sizes) 110 | self.class_dict2 = {} 111 | for i in range(self.nclass): 112 | if transductive: 113 | idx_train = np.array(self.idx_train) 114 | idx = idx_train[self.labels_train == i] 115 | else: 116 | idx = np.arange(len(self.labels_train))[self.labels_train==i] 117 | self.class_dict2[i] = idx 118 | 119 | if self.samplers is None: 120 | self.samplers = [] 121 | for i in range(self.nclass): 122 | node_idx = torch.LongTensor(self.class_dict2[i]) 123 | if len(node_idx) == 0: 124 | continue 125 | 126 | self.samplers.append(NeighborSampler(adj, 127 | node_idx=node_idx, 128 | sizes=sizes, batch_size=num, 129 | num_workers=8, return_e_id=False, 130 | num_nodes=adj.size(0), 131 | shuffle=True)) 132 | batch = np.random.permutation(self.class_dict2[c])[:num] 133 | out = self.samplers[c].sample(batch) 134 | return out 135 | 136 | 137 | class GraphData: 138 | 139 | def __init__(self, adj, features, labels, idx_train, idx_val, idx_test): 140 | self.adj = adj 141 | self.features = features 142 | self.labels = labels 143 | self.idx_train = idx_train 144 | self.idx_val = idx_val 145 | self.idx_test = idx_test 146 | 147 | 148 | class Data2Pyg: 149 | 150 | def __init__(self, data, device='cuda', transform=None, **kwargs): 151 | self.data_train = Dpr2Pyg(data.data_train, transform=transform)[0].to(device) 152 | self.data_val = Dpr2Pyg(data.data_val, transform=transform)[0].to(device) 153 | self.data_test = Dpr2Pyg(data.data_test, transform=transform)[0].to(device) 154 | self.nclass = data.nclass 155 | self.nfeat = data.nfeat 156 | self.class_dict = None 157 | 158 | def retrieve_class(self, c, num=256): 159 | if self.class_dict is None: 160 | self.class_dict = {} 161 | for i in range(self.nclass): 162 | self.class_dict['class_%s'%i] = (self.data_train.y == i).cpu().numpy() 163 | idx = np.arange(len(self.data_train.y)) 164 | idx = idx[self.class_dict['class_%s'%c]] 165 | return np.random.permutation(idx)[:num] 166 | 167 | 168 | class Dpr2Pyg(InMemoryDataset): 169 | 170 | def __init__(self, dpr_data, transform=None, **kwargs): 171 | root = 'data/' 172 | self.dpr_data = dpr_data 173 | super(Dpr2Pyg, self).__init__(root, transform) 174 | pyg_data = self.process() 175 | self.data, self.slices = self.collate([pyg_data]) 176 | self.transform = transform 177 | 178 | def process(self): 179 | dpr_data = self.dpr_data 180 | edge_index = torch.LongTensor(dpr_data.adj.nonzero()) 181 | 182 | if sp.issparse(dpr_data.features): 183 | x = torch.FloatTensor(dpr_data.features.todense()).float() 184 | else: 185 | x = torch.FloatTensor(dpr_data.features).float() 186 | y = torch.LongTensor(dpr_data.labels) 187 | data = Data(x=x, edge_index=edge_index, y=y) 188 | data.train_mask = None 189 | data.val_mask = None 190 | data.test_mask = None 191 | return data 192 | 193 | 194 | def get(self, idx): 195 | data = self.data.__class__() 196 | 197 | if hasattr(self.data, '__num_nodes__'): 198 | data.num_nodes = self.data.__num_nodes__[idx] 199 | 200 | for key in self.data.keys: 201 | item, slices = self.data[key], self.slices[key] 202 | s = list(repeat(slice(None), item.dim())) 203 | s[self.data.__cat_dim__(key, item)] = slice(slices[idx], 204 | slices[idx + 1]) 205 | data[key] = item[s] 206 | return data 207 | 208 | @property 209 | def raw_file_names(self): 210 | return ['some_file_1', 
'some_file_2', ...] 211 | 212 | @property 213 | def processed_file_names(self): 214 | return ['data.pt'] 215 | 216 | def _download(self): 217 | pass 218 | 219 | 220 | --------------------------------------------------------------------------------
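
As a quick orientation, the sketch below shows one way DataGraphSAINT might be exercised on its own. It is illustrative only: it assumes a data/<dataset>/ folder laid out as __init__ expects (adj_full.npz, feats.npy, role.json, class_map.json), that the adjacency handed to retrieve_class_sampler is a torch_sparse.SparseTensor (the NeighborSampler call uses adj.size(0), which fits that type), and an args namespace carrying the dataset / nlayers / option fields read inside the method; the concrete values here are made up, not taken from configs.py.

    from types import SimpleNamespace

    from torch_sparse import SparseTensor
    from utils_graphsaint import DataGraphSAINT

    # Load a GraphSAINT-format dataset from data/flickr/ (assumed to exist locally).
    data = DataGraphSAINT('flickr')
    print(data.nnodes, data.nclass, data.feat_full.shape)

    # Plain index sampling: up to 256 training nodes of class 0, no neighborhood expansion.
    idx_c0 = data.retrieve_class(c=0, num=256)

    # Per-class neighborhood sampling over the training subgraph (inductive setting).
    # Illustrative argument namespace; with dataset='flickr', nlayers=2, option=0 the
    # fan-out picked inside retrieve_class_sampler is sizes=[15, 8].
    args = SimpleNamespace(dataset='flickr', nlayers=2, option=0)
    adj_train = SparseTensor.from_scipy(data.adj_train)
    batch_size, n_id, adjs = data.retrieve_class_sampler(
        c=0, adj=adj_train, transductive=False, num=256, args=args)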
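
Dpr2Pyg wraps a DeepRobust-style object (anything exposing .adj, .features and .labels) into a single-graph PyG InMemoryDataset. Below is a minimal, self-contained sketch using the GraphData container defined in this file; the toy matrices are invented for illustration, and indexing the dataset assumes the PyG version pinned in requirements.txt (the get() override iterates the older data.keys attribute).

    import numpy as np
    import scipy.sparse as sp

    from utils_graphsaint import Dpr2Pyg, GraphData

    # Toy 3-node path graph in the DeepRobust-style layout that Dpr2Pyg.process() reads.
    adj = sp.csr_matrix(np.array([[0, 1, 0],
                                  [1, 0, 1],
                                  [0, 1, 0]], dtype=np.float32))
    features = np.eye(3, dtype=np.float32)
    labels = np.array([0, 1, 0])
    dpr = GraphData(adj, features, labels, idx_train=[0], idx_val=[1], idx_test=[2])

    # root='data/' is hard-coded inside Dpr2Pyg, so PyG may create data/processed/ as a side effect.
    dataset = Dpr2Pyg(dpr)
    graph = dataset[0]   # a torch_geometric.data.Data with x, edge_index and y
    print(graph)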