├── .gitignore ├── LICENSE ├── README.md ├── model.py ├── train.py ├── train.sh └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | __pycache__ 3 | checkpoint 4 | dataset -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Mozilla Public License Version 2.0 2 | ================================== 3 | 4 | 1. Definitions 5 | -------------- 6 | 7 | 1.1. "Contributor" 8 | means each individual or legal entity that creates, contributes to 9 | the creation of, or owns Covered Software. 10 | 11 | 1.2. "Contributor Version" 12 | means the combination of the Contributions of others (if any) used 13 | by a Contributor and that particular Contributor's Contribution. 14 | 15 | 1.3. "Contribution" 16 | means Covered Software of a particular Contributor. 17 | 18 | 1.4. "Covered Software" 19 | means Source Code Form to which the initial Contributor has attached 20 | the notice in Exhibit A, the Executable Form of such Source Code 21 | Form, and Modifications of such Source Code Form, in each case 22 | including portions thereof. 23 | 24 | 1.5. "Incompatible With Secondary Licenses" 25 | means 26 | 27 | (a) that the initial Contributor has attached the notice described 28 | in Exhibit B to the Covered Software; or 29 | 30 | (b) that the Covered Software was made available under the terms of 31 | version 1.1 or earlier of the License, but not also under the 32 | terms of a Secondary License. 33 | 34 | 1.6. "Executable Form" 35 | means any form of the work other than Source Code Form. 36 | 37 | 1.7. "Larger Work" 38 | means a work that combines Covered Software with other material, in 39 | a separate file or files, that is not Covered Software. 40 | 41 | 1.8. "License" 42 | means this document. 43 | 44 | 1.9. 
"Licensable" 45 | means having the right to grant, to the maximum extent possible, 46 | whether at the time of the initial grant or subsequently, any and 47 | all of the rights conveyed by this License. 48 | 49 | 1.10. "Modifications" 50 | means any of the following: 51 | 52 | (a) any file in Source Code Form that results from an addition to, 53 | deletion from, or modification of the contents of Covered 54 | Software; or 55 | 56 | (b) any new file in Source Code Form that contains any Covered 57 | Software. 58 | 59 | 1.11. "Patent Claims" of a Contributor 60 | means any patent claim(s), including without limitation, method, 61 | process, and apparatus claims, in any patent Licensable by such 62 | Contributor that would be infringed, but for the grant of the 63 | License, by the making, using, selling, offering for sale, having 64 | made, import, or transfer of either its Contributions or its 65 | Contributor Version. 66 | 67 | 1.12. "Secondary License" 68 | means either the GNU General Public License, Version 2.0, the GNU 69 | Lesser General Public License, Version 2.1, the GNU Affero General 70 | Public License, Version 3.0, or any later versions of those 71 | licenses. 72 | 73 | 1.13. "Source Code Form" 74 | means the form of the work preferred for making modifications. 75 | 76 | 1.14. "You" (or "Your") 77 | means an individual or a legal entity exercising rights under this 78 | License. For legal entities, "You" includes any entity that 79 | controls, is controlled by, or is under common control with You. For 80 | purposes of this definition, "control" means (a) the power, direct 81 | or indirect, to cause the direction or management of such entity, 82 | whether by contract or otherwise, or (b) ownership of more than 83 | fifty percent (50%) of the outstanding shares or beneficial 84 | ownership of such entity. 85 | 86 | 2. License Grants and Conditions 87 | -------------------------------- 88 | 89 | 2.1. 
Grants 90 | 91 | Each Contributor hereby grants You a world-wide, royalty-free, 92 | non-exclusive license: 93 | 94 | (a) under intellectual property rights (other than patent or trademark) 95 | Licensable by such Contributor to use, reproduce, make available, 96 | modify, display, perform, distribute, and otherwise exploit its 97 | Contributions, either on an unmodified basis, with Modifications, or 98 | as part of a Larger Work; and 99 | 100 | (b) under Patent Claims of such Contributor to make, use, sell, offer 101 | for sale, have made, import, and otherwise transfer either its 102 | Contributions or its Contributor Version. 103 | 104 | 2.2. Effective Date 105 | 106 | The licenses granted in Section 2.1 with respect to any Contribution 107 | become effective for each Contribution on the date the Contributor first 108 | distributes such Contribution. 109 | 110 | 2.3. Limitations on Grant Scope 111 | 112 | The licenses granted in this Section 2 are the only rights granted under 113 | this License. No additional rights or licenses will be implied from the 114 | distribution or licensing of Covered Software under this License. 115 | Notwithstanding Section 2.1(b) above, no patent license is granted by a 116 | Contributor: 117 | 118 | (a) for any code that a Contributor has removed from Covered Software; 119 | or 120 | 121 | (b) for infringements caused by: (i) Your and any other third party's 122 | modifications of Covered Software, or (ii) the combination of its 123 | Contributions with other software (except as part of its Contributor 124 | Version); or 125 | 126 | (c) under Patent Claims infringed by Covered Software in the absence of 127 | its Contributions. 128 | 129 | This License does not grant any rights in the trademarks, service marks, 130 | or logos of any Contributor (except as may be necessary to comply with 131 | the notice requirements in Section 3.4). 132 | 133 | 2.4. 
Subsequent Licenses 134 | 135 | No Contributor makes additional grants as a result of Your choice to 136 | distribute the Covered Software under a subsequent version of this 137 | License (see Section 10.2) or under the terms of a Secondary License (if 138 | permitted under the terms of Section 3.3). 139 | 140 | 2.5. Representation 141 | 142 | Each Contributor represents that the Contributor believes its 143 | Contributions are its original creation(s) or it has sufficient rights 144 | to grant the rights to its Contributions conveyed by this License. 145 | 146 | 2.6. Fair Use 147 | 148 | This License is not intended to limit any rights You have under 149 | applicable copyright doctrines of fair use, fair dealing, or other 150 | equivalents. 151 | 152 | 2.7. Conditions 153 | 154 | Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted 155 | in Section 2.1. 156 | 157 | 3. Responsibilities 158 | ------------------- 159 | 160 | 3.1. Distribution of Source Form 161 | 162 | All distribution of Covered Software in Source Code Form, including any 163 | Modifications that You create or to which You contribute, must be under 164 | the terms of this License. You must inform recipients that the Source 165 | Code Form of the Covered Software is governed by the terms of this 166 | License, and how they can obtain a copy of this License. You may not 167 | attempt to alter or restrict the recipients' rights in the Source Code 168 | Form. 169 | 170 | 3.2. 
Distribution of Executable Form 171 | 172 | If You distribute Covered Software in Executable Form then: 173 | 174 | (a) such Covered Software must also be made available in Source Code 175 | Form, as described in Section 3.1, and You must inform recipients of 176 | the Executable Form how they can obtain a copy of such Source Code 177 | Form by reasonable means in a timely manner, at a charge no more 178 | than the cost of distribution to the recipient; and 179 | 180 | (b) You may distribute such Executable Form under the terms of this 181 | License, or sublicense it under different terms, provided that the 182 | license for the Executable Form does not attempt to limit or alter 183 | the recipients' rights in the Source Code Form under this License. 184 | 185 | 3.3. Distribution of a Larger Work 186 | 187 | You may create and distribute a Larger Work under terms of Your choice, 188 | provided that You also comply with the requirements of this License for 189 | the Covered Software. If the Larger Work is a combination of Covered 190 | Software with a work governed by one or more Secondary Licenses, and the 191 | Covered Software is not Incompatible With Secondary Licenses, this 192 | License permits You to additionally distribute such Covered Software 193 | under the terms of such Secondary License(s), so that the recipient of 194 | the Larger Work may, at their option, further distribute the Covered 195 | Software under the terms of either this License or such Secondary 196 | License(s). 197 | 198 | 3.4. Notices 199 | 200 | You may not remove or alter the substance of any license notices 201 | (including copyright notices, patent notices, disclaimers of warranty, 202 | or limitations of liability) contained within the Source Code Form of 203 | the Covered Software, except that You may alter any license notices to 204 | the extent required to remedy known factual inaccuracies. 205 | 206 | 3.5. 
Application of Additional Terms 207 | 208 | You may choose to offer, and to charge a fee for, warranty, support, 209 | indemnity or liability obligations to one or more recipients of Covered 210 | Software. However, You may do so only on Your own behalf, and not on 211 | behalf of any Contributor. You must make it absolutely clear that any 212 | such warranty, support, indemnity, or liability obligation is offered by 213 | You alone, and You hereby agree to indemnify every Contributor for any 214 | liability incurred by such Contributor as a result of warranty, support, 215 | indemnity or liability terms You offer. You may include additional 216 | disclaimers of warranty and limitations of liability specific to any 217 | jurisdiction. 218 | 219 | 4. Inability to Comply Due to Statute or Regulation 220 | --------------------------------------------------- 221 | 222 | If it is impossible for You to comply with any of the terms of this 223 | License with respect to some or all of the Covered Software due to 224 | statute, judicial order, or regulation then You must: (a) comply with 225 | the terms of this License to the maximum extent possible; and (b) 226 | describe the limitations and the code they affect. Such description must 227 | be placed in a text file included with all distributions of the Covered 228 | Software under this License. Except to the extent prohibited by statute 229 | or regulation, such description must be sufficiently detailed for a 230 | recipient of ordinary skill to be able to understand it. 231 | 232 | 5. Termination 233 | -------------- 234 | 235 | 5.1. The rights granted under this License will terminate automatically 236 | if You fail to comply with any of its terms. 
However, if You become 237 | compliant, then the rights granted under this License from a particular 238 | Contributor are reinstated (a) provisionally, unless and until such 239 | Contributor explicitly and finally terminates Your grants, and (b) on an 240 | ongoing basis, if such Contributor fails to notify You of the 241 | non-compliance by some reasonable means prior to 60 days after You have 242 | come back into compliance. Moreover, Your grants from a particular 243 | Contributor are reinstated on an ongoing basis if such Contributor 244 | notifies You of the non-compliance by some reasonable means, this is the 245 | first time You have received notice of non-compliance with this License 246 | from such Contributor, and You become compliant prior to 30 days after 247 | Your receipt of the notice. 248 | 249 | 5.2. If You initiate litigation against any entity by asserting a patent 250 | infringement claim (excluding declaratory judgment actions, 251 | counter-claims, and cross-claims) alleging that a Contributor Version 252 | directly or indirectly infringes any patent, then the rights granted to 253 | You by any and all Contributors for the Covered Software under Section 254 | 2.1 of this License shall terminate. 255 | 256 | 5.3. In the event of termination under Sections 5.1 or 5.2 above, all 257 | end user license agreements (excluding distributors and resellers) which 258 | have been validly granted by You or Your distributors under this License 259 | prior to termination shall survive termination. 260 | 261 | ************************************************************************ 262 | * * 263 | * 6. 
Disclaimer of Warranty * 264 | * ------------------------- * 265 | * * 266 | * Covered Software is provided under this License on an "as is" * 267 | * basis, without warranty of any kind, either expressed, implied, or * 268 | * statutory, including, without limitation, warranties that the * 269 | * Covered Software is free of defects, merchantable, fit for a * 270 | * particular purpose or non-infringing. The entire risk as to the * 271 | * quality and performance of the Covered Software is with You. * 272 | * Should any Covered Software prove defective in any respect, You * 273 | * (not any Contributor) assume the cost of any necessary servicing, * 274 | * repair, or correction. This disclaimer of warranty constitutes an * 275 | * essential part of this License. No use of any Covered Software is * 276 | * authorized under this License except under this disclaimer. * 277 | * * 278 | ************************************************************************ 279 | 280 | ************************************************************************ 281 | * * 282 | * 7. Limitation of Liability * 283 | * -------------------------- * 284 | * * 285 | * Under no circumstances and under no legal theory, whether tort * 286 | * (including negligence), contract, or otherwise, shall any * 287 | * Contributor, or anyone who distributes Covered Software as * 288 | * permitted above, be liable to You for any direct, indirect, * 289 | * special, incidental, or consequential damages of any character * 290 | * including, without limitation, damages for lost profits, loss of * 291 | * goodwill, work stoppage, computer failure or malfunction, or any * 292 | * and all other commercial damages or losses, even if such party * 293 | * shall have been informed of the possibility of such damages. 
This * 294 | * limitation of liability shall not apply to liability for death or * 295 | * personal injury resulting from such party's negligence to the * 296 | * extent applicable law prohibits such limitation. Some * 297 | * jurisdictions do not allow the exclusion or limitation of * 298 | * incidental or consequential damages, so this exclusion and * 299 | * limitation may not apply to You. * 300 | * * 301 | ************************************************************************ 302 | 303 | 8. Litigation 304 | ------------- 305 | 306 | Any litigation relating to this License may be brought only in the 307 | courts of a jurisdiction where the defendant maintains its principal 308 | place of business and such litigation shall be governed by laws of that 309 | jurisdiction, without reference to its conflict-of-law provisions. 310 | Nothing in this Section shall prevent a party's ability to bring 311 | cross-claims or counter-claims. 312 | 313 | 9. Miscellaneous 314 | ---------------- 315 | 316 | This License represents the complete agreement concerning the subject 317 | matter hereof. If any provision of this License is held to be 318 | unenforceable, such provision shall be reformed only to the extent 319 | necessary to make it enforceable. Any law or regulation which provides 320 | that the language of a contract shall be construed against the drafter 321 | shall not be used to construe this License against a Contributor. 322 | 323 | 10. Versions of the License 324 | --------------------------- 325 | 326 | 10.1. New Versions 327 | 328 | Mozilla Foundation is the license steward. Except as provided in Section 329 | 10.3, no one other than the license steward has the right to modify or 330 | publish new versions of this License. Each version will be given a 331 | distinguishing version number. 332 | 333 | 10.2. 
Effect of New Versions 334 | 335 | You may distribute the Covered Software under the terms of the version 336 | of the License under which You originally received the Covered Software, 337 | or under the terms of any subsequent version published by the license 338 | steward. 339 | 340 | 10.3. Modified Versions 341 | 342 | If you create software not governed by this License, and you want to 343 | create a new license for such software, you may create and use a 344 | modified version of this License if you rename the license and remove 345 | any references to the name of the license steward (except to note that 346 | such modified license differs from this License). 347 | 348 | 10.4. Distributing Source Code Form that is Incompatible With Secondary 349 | Licenses 350 | 351 | If You choose to distribute Source Code Form that is Incompatible With 352 | Secondary Licenses under the terms of this version of the License, the 353 | notice described in Exhibit B of this License must be attached. 354 | 355 | Exhibit A - Source Code Form License Notice 356 | ------------------------------------------- 357 | 358 | This Source Code Form is subject to the terms of the Mozilla Public 359 | License, v. 2.0. If a copy of the MPL was not distributed with this 360 | file, You can obtain one at http://mozilla.org/MPL/2.0/. 361 | 362 | If it is not possible or desirable to put the notice in a particular 363 | file, then You may include the notice in a location (such as a LICENSE 364 | file in a relevant directory) where a recipient would be likely to look 365 | for such a notice. 366 | 367 | You may add additional accurate notices of copyright ownership. 368 | 369 | Exhibit B - "Incompatible With Secondary Licenses" Notice 370 | --------------------------------------------------------- 371 | 372 | This Source Code Form is "Incompatible With Secondary Licenses", as 373 | defined by the Mozilla Public License, v. 2.0. 
374 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # H2GCN-Pytorch 2 | 3 | This repo is a PyTorch implementation of H2GCN, the model proposed in the 4 | paper ["Beyond Homophily in Graph Neural Networks: Current Limitations and Effective Designs"](https://arxiv.org/abs/2006.11468) 5 | . The original TensorFlow implementation can be found [here](https://github.com/GemsLab/H2GCN). 6 | 7 | ## Requirements 8 | 9 | This project should run without any modification once the following packages are installed. 10 | 11 | ``` 12 | pytorch 13 | networkx 14 | torch-sparse 15 | torch-geometric 16 | ``` 17 | 18 | ## Tutorial 19 | 20 | ### Run train.py 21 | 22 | ``` 23 | usage: train.py [-h] [--seed SEED] [--without-relu] [--epochs EPOCHS] [--lr LR] [--k K] 24 | [--wd WD] [--hidden HIDDEN] [--dropout DROPOUT] 25 | [--patience PATIENCE] [--dataset DATASET] [--gpu GPU] 26 | [--split-id SPLIT_ID] 27 | 28 | optional arguments: 29 | -h, --help show this help message and exit 30 | --seed SEED seed 31 | --without-relu disable relu for all H2GCN layer 32 | --epochs EPOCHS number of epochs to train 33 | --lr LR learning rate 34 | --k K number of embedding rounds 35 | --wd WD weight decay value 36 | --hidden HIDDEN embedding output dim 37 | --dropout DROPOUT dropout rate 38 | --patience PATIENCE patience for early stop 39 | --dataset DATASET dateset name 40 | --gpu GPU gpu id to use while training, set -1 to use cpu 41 | --split-id SPLIT_ID the data split to use 42 | ``` 43 | 44 | ### Custom dataset 45 | 46 | All datasets used in this repo were forked from the geom-gcn repo. 
class H2GCN(nn.Module):
    """PyTorch implementation of the H2GCN node classifier.

    The model embeds the node features once, then performs ``k`` rounds of
    propagation over two fixed, symmetrically normalized matrices (``a1`` and
    ``a2``, derived from the adjacency matrix on the first forward pass),
    concatenates every intermediate representation and classifies the result
    with a single linear layer followed by a softmax.

    Args:
        feat_dim: Number of input features per node.
        hidden_dim: Width of the initial feature embedding.
        class_dim: Number of output classes.
        k: Number of propagation rounds.
        dropout: Dropout probability applied to the concatenated representation.
        use_relu: Apply ReLU after the embedding and after every round when
            True; use the identity otherwise.
    """

    def __init__(
            self,
            feat_dim: int,
            hidden_dim: int,
            class_dim: int,
            k: int = 2,
            dropout: float = 0.5,
            use_relu: bool = True
    ):
        super(H2GCN, self).__init__()
        self.dropout = dropout
        self.k = k
        self.act = F.relu if use_relu else lambda x: x
        self.use_relu = use_relu
        self.w_embed = nn.Parameter(
            torch.zeros(size=(feat_dim, hidden_dim)),
            requires_grad=True
        )
        # Every round concatenates two propagated copies of the previous
        # representation, so widths are h, 2h, 4h, ...; their sum over k + 1
        # stages is (2 ** (k + 1) - 1) * h.
        self.w_classify = nn.Parameter(
            torch.zeros(size=((2 ** (self.k + 1) - 1) * hidden_dim, class_dim)),
            requires_grad=True
        )
        self.params = [self.w_embed, self.w_classify]
        # Propagation matrices are built lazily on the first forward() call.
        self.initialized = False
        self.a1 = None
        self.a2 = None
        self.reset_parameter()

    def reset_parameter(self):
        """Re-initialize both weight matrices with Xavier-uniform values."""
        nn.init.xavier_uniform_(self.w_embed)
        nn.init.xavier_uniform_(self.w_classify)

    @staticmethod
    def _indicator(sp_tensor: torch.sparse.Tensor) -> torch.sparse.Tensor:
        """Return a float sparse tensor with 1.0 where the input is > 0, else 0.0.

        The sparsity pattern of the coalesced input is kept, so non-positive
        entries stay stored as explicit zeros.
        """
        csp = sp_tensor.coalesce()
        return torch.sparse_coo_tensor(
            indices=csp.indices(),
            values=(csp.values() > 0).to(torch.float),
            size=csp.size(),
            dtype=torch.float
        )

    @staticmethod
    def _spspmm(sp1: torch.sparse.Tensor, sp2: torch.sparse.Tensor) -> torch.sparse.Tensor:
        """Multiply two sparse COO matrices and return a sparse float result."""
        assert sp1.shape[1] == sp2.shape[0], 'Cannot multiply size %s with %s' % (sp1.shape, sp2.shape)
        sp1, sp2 = sp1.coalesce(), sp2.coalesce()
        index1, value1 = sp1.indices(), sp1.values()
        index2, value2 = sp2.indices(), sp2.values()
        m, n, k = sp1.shape[0], sp1.shape[1], sp2.shape[1]
        indices, values = torch_sparse.spspmm(index1, value1, index2, value2, m, n, k)
        return torch.sparse_coo_tensor(
            indices=indices,
            values=values,
            size=(m, k),
            dtype=torch.float
        )

    @classmethod
    def _adj_norm(cls, adj: torch.sparse.Tensor) -> torch.sparse.Tensor:
        """Return the symmetric normalization ``D^-1/2 · adj · D^-1/2``.

        ``D`` is the diagonal matrix of row sums of ``adj``. Rows whose sum is
        zero (including rows with no stored entry at all) get a scale of 0
        instead of infinity, so they simply stay empty.

        Args:
            adj: Square sparse COO matrix.

        Returns:
            Sparse COO matrix with the same sparsity pattern as the coalesced
            input and rescaled values.
        """
        adj = adj.coalesce()
        n = adj.size(0)
        row, col = adj.indices()
        values = adj.values()
        # Accumulate degrees densely: the values of a sparse row-sum only
        # cover rows that have stored entries and would be misaligned with
        # the node index whenever a row is empty.
        degree = torch.zeros(n, dtype=values.dtype, device=values.device)
        degree.index_add_(0, row, values)
        d_inv_sqrt = torch.pow(degree, -0.5)
        d_inv_sqrt = torch.where(torch.isinf(d_inv_sqrt), torch.zeros_like(d_inv_sqrt), d_inv_sqrt)
        # Multiplying by diagonal matrices on both sides is an element-wise
        # rescale of the stored values, so no sparse-sparse product is needed.
        return torch.sparse_coo_tensor(
            indices=adj.indices(),
            values=d_inv_sqrt[row] * values * d_inv_sqrt[col],
            size=adj.size()
        )

    def _prepare_prop(self, adj):
        """Build and cache the normalized propagation matrices ``a1`` and ``a2``.

        ``a1`` marks the positive entries of ``adj - I`` and ``a2`` the positive
        entries of ``adj·adj - adj - I``; both are then symmetrically
        normalized. The result is cached on the instance, so it is computed
        from the first adjacency matrix the model sees.
        """
        n = adj.size(0)
        device = adj.device
        diag = torch.arange(n, device=device)
        sp_eye = torch.sparse_coo_tensor(
            indices=torch.stack([diag, diag]),
            values=torch.ones(n, dtype=torch.float, device=device),
            size=(n, n),
            dtype=torch.float
        )
        # initialize A1, A2
        a1 = self._indicator(adj - sp_eye)
        a2 = self._indicator(self._spspmm(adj, adj) - adj - sp_eye)
        # norm A1 A2
        self.a1 = self._adj_norm(a1)
        self.a2 = self._adj_norm(a2)
        # Flag only after both matrices exist, so a failure above does not
        # leave the model marked as ready with a1/a2 still None.
        self.initialized = True

    def forward(self, adj: torch.sparse.Tensor, x: FloatTensor) -> FloatTensor:
        """Classify every node.

        Args:
            adj: Sparse adjacency matrix. Only read on the first call; later
                calls reuse the cached propagation matrices.
            x: Dense node-feature matrix with ``feat_dim`` columns.

        Returns:
            Dense matrix of class probabilities (each row sums to 1).
            NOTE: these are probabilities, not log-probabilities; apply
            ``torch.log`` before a loss such as ``F.nll_loss``.
        """
        if not self.initialized:
            self._prepare_prop(adj)
        # H2GCN propagation
        rs = [self.act(torch.mm(x, self.w_embed))]
        for _ in range(self.k):
            r_last = rs[-1]
            r1 = torch.spmm(self.a1, r_last)
            r2 = torch.spmm(self.a2, r_last)
            rs.append(self.act(torch.cat([r1, r2], dim=1)))
        r_final = torch.cat(rs, dim=1)
        r_final = F.dropout(r_final, self.dropout, training=self.training)
        return torch.softmax(torch.mm(r_final, self.w_classify), dim=1)
labels[idx_train].to(device)) 34 | loss_train.backward() 35 | optimizer.step() 36 | return loss_train.item(), acc_train.item() 37 | 38 | 39 | def validate(): 40 | model.eval() 41 | with torch.no_grad(): 42 | output = model(adj, features) 43 | loss_val = F.nll_loss(output[idx_val], labels[idx_val].to(device)) 44 | acc_val = accuracy(output[idx_val], labels[idx_val].to(device)) 45 | return loss_val.item(), acc_val.item() 46 | 47 | 48 | def test(): 49 | model.load_state_dict(torch.load(checkpoint_path)) 50 | model.eval() 51 | with torch.no_grad(): 52 | output = model(adj, features) 53 | loss_test = F.nll_loss(output[idx_test], labels[idx_test].to(device)) 54 | acc_test = accuracy(output[idx_test], labels[idx_test].to(device)) 55 | return loss_test.item(), acc_test.item() 56 | 57 | 58 | def main(): 59 | begin_time = time.time() 60 | tolerate = 0 61 | best_loss = 1000 62 | for epoch in range(args.epochs): 63 | loss_train, acc_train = train() 64 | loss_validate, acc_validate = validate() 65 | if (epoch + 1) % 1 == 0: 66 | print( 67 | 'Epoch {:03d}'.format(epoch + 1), 68 | '|| train', 69 | 'loss : {:.3f}'.format(loss_train), 70 | ', accuracy : {:.2f}%'.format(acc_train * 100), 71 | '|| val', 72 | 'loss : {:.3f}'.format(loss_validate), 73 | ', accuracy : {:.2f}%'.format(acc_validate * 100) 74 | ) 75 | if loss_validate < best_loss: 76 | best_loss = loss_validate 77 | torch.save(model.state_dict(), checkpoint_path) 78 | tolerate = 0 79 | else: 80 | tolerate += 1 81 | if tolerate == args.patience: 82 | break 83 | print("Train cost : {:.2f}s".format(time.time() - begin_time)) 84 | print("Test accuracy : {:.2f}%".format(test()[1] * 100), "on dataset", args.dataset) 85 | 86 | 87 | if __name__ == '__main__': 88 | set_seed(args.seed) 89 | device = torch.device('cpu' if args.gpu == -1 else "cuda:%s" % args.gpu) 90 | features, labels, feat_dim, class_dim, adj, train_mask, val_mask, test_mask = load_dataset( 91 | args.dataset, 92 | device 93 | ) 94 | checkpoint_path = utils.root + 
'/checkpoint/%s.pt' % args.dataset 95 | idx_train, idx_val, idx_test = select_mask(args.split_id, train_mask, val_mask, test_mask) 96 | if not os.path.exists(utils.root + '/checkpoint'): 97 | os.makedirs(utils.root + '/checkpoint') 98 | model = H2GCN( 99 | feat_dim=feat_dim, 100 | hidden_dim=args.hidden, 101 | class_dim=class_dim, 102 | use_relu=not args.without_relu 103 | ).to(device) 104 | optimizer = optim.Adam([{'params': model.params, 'weight_decay': args.wd}], lr=args.lr) 105 | main() 106 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | echo "Running on dataset Chameleon" 2 | python train.py --dataset chameleon --wd 0 --dropout 0.1 --lr 0.1 --without-relu 3 | echo "Running on dataset Squirrel" 4 | python train.py --dataset squirrel --wd 5e-4 --dropout 0.1 --lr 0.1 --without-relu 5 | echo "Running on dataset Texas" 6 | python train.py --dataset texas --wd 0 --dropout 0 --lr 0.1 --without-relu 7 | echo "Running on dataset Cornell" 8 | python train.py --dataset cornell --wd 5e-4 --dropout 0.5 --lr 0.01 --without-relu 9 | echo "Running on dataset Wisconsin" 10 | python train.py --dataset wisconsin --wd 5e-4 --dropout 0.5 --lr 0.05 --without-relu 11 | echo "Running on dataset Actor" 12 | python train.py --dataset actor --wd 5e-4 --dropout 0.5 --lr 0.1 --hidden 128 --without-relu -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | import torch 5 | from torch_geometric.datasets import Planetoid, WikipediaNetwork, WebKB, Actor 6 | 7 | root = os.path.split(__file__)[0] 8 | 9 | 10 | def set_seed(seed): 11 | random.seed(seed) 12 | np.random.seed(seed) 13 | torch.manual_seed(seed) 14 | torch.cuda.manual_seed(seed) 15 | 16 | 17 | def accuracy(output, labels): 18 | 
def accuracy(output, labels):
    """Return the fraction of rows of ``output`` whose arg-max equals ``labels``.

    The result is a zero-dimensional double tensor.
    """
    preds = output.max(1)[1].type_as(labels)
    correct = preds.eq(labels).double().sum()
    return correct / len(labels)


def load_dataset(name: str, device=None):
    """Load a torch_geometric dataset by (case-insensitive) name.

    Args:
        name: One of cora, pubmed, citeseer, chameleon, squirrel, cornell,
            texas, wisconsin or actor.
        device: Device the graph data is moved to; CPU when None.

    Returns:
        Tuple ``(x, y, nfeat, nclass, adj, train_mask, val_mask, test_mask)``
        where ``adj`` is the sparse adjacency built from the edge index.

    Raises:
        NotImplementedError: If ``name`` is not a supported dataset.
    """
    if device is None:
        device = torch.device('cpu')
    name = name.lower()
    if name in ["cora", "pubmed", "citeseer"]:
        dataset = Planetoid(root=root + "/dataset/Planetoid", name=name)
    elif name in ["chameleon", "squirrel"]:
        dataset = WikipediaNetwork(root=root + "/dataset/WikipediaNetwork", name=name)
    elif name in ["cornell", "texas", "wisconsin"]:
        dataset = WebKB(root=root + "/dataset/WebKB", name=name)
    elif name in ["actor"]:
        dataset = Actor(root=root + "/dataset/Actor")
    else:
        # Raising a bare string is itself a TypeError in Python 3; raise a
        # real exception so the message actually reaches the user.
        raise NotImplementedError("Please implement support for this dataset in function load_dataset().")
    data = dataset[0].to(device)
    x, y = data.x, data.y
    n = len(x)
    edge_index = data.edge_index
    nfeat = data.num_node_features
    nclass = len(torch.unique(y))
    return x, y, nfeat, nclass, eidx_to_sp(n, edge_index), data.train_mask, data.val_mask, data.test_mask


def eidx_to_sp(n: int, edge_index: torch.Tensor, device=None) -> torch.sparse.Tensor:
    """Turn a ``2 x E`` edge index into an ``n x n`` sparse COO matrix of ones.

    Args:
        n: Number of nodes.
        edge_index: Row 0 holds source indices, row 1 destination indices.
        device: Target device; defaults to the device of ``edge_index``.
    """
    values = torch.ones(edge_index.size(1), dtype=torch.float, device=edge_index.device)
    coo = torch.sparse_coo_tensor(indices=edge_index, values=values, size=[n, n])
    if device is None:
        device = edge_index.device
    return coo.to(device)


def select_mask(i: int, train: torch.Tensor, val: torch.Tensor, test: torch.Tensor) -> tuple:
    """Pick the ``i``-th data split from the train/val/test masks.

    One-dimensional masks describe a single split and are returned unchanged;
    otherwise column ``i`` of each mask is returned as a flat tensor.

    Returns:
        Tuple ``(train_mask, val_mask, test_mask)``.
    """
    if train.dim() == 1:
        return train, val, test
    indices = torch.tensor([i]).to(train.device)
    train_idx, val_idx, test_idx = (
        torch.index_select(mask, 1, indices).reshape(-1) for mask in (train, val, test)
    )
    return train_idx, val_idx, test_idx
--------------------------------------------------------------------------------