├── src
│   ├── config.json
│   ├── layers.py
│   ├── models.py
│   ├── utils.py
│   ├── main.py
│   └── datasets
│       └── node_classification.py
├── README.md
├── LICENSE
└── .gitignore

/src/config.json:
--------------------------------------------------------------------------------
1 | {
2 |     "stats_per_batch" : 3,
3 |     "dataset" : "cora",
4 |     "dataset_path" : "/Users/raunak/Documents/Datasets/Cora",
5 |     "mode" : "train",
6 |     "task" : "node_classification",
7 |     "cuda" : true,
8 |     "hidden_dims" : [8],
9 |     "num_heads" : [8, 1],
10 |     "dropout" : 0.6,
11 |     "batch_size" : 140,
12 |     "epochs" : 200,
13 |     "lr" : 5e-2,
14 |     "weight_decay" : 5e-4,
15 |     "transductive" : true,
16 |     "self_loop" : true
17 | }
18 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Graph Attention Networks
2 | This is a PyTorch implementation of Graph Attention Networks (GAT) from the paper [Graph Attention Networks](https://arxiv.org/pdf/1710.10903.pdf).
3 | 
4 | ## Usage
5 | 
6 | In the `src` directory, edit the `config.json` file to specify arguments and
7 | flags. Then run `python main.py`.
8 | 
9 | ## Limitations
10 | * Currently, only the Cora dataset is supported. However, adding a new dataset should be fairly straightforward: write a Dataset class similar to `datasets.node_classification.Cora`.
11 | 
12 | ## References
13 | * [Graph Attention Networks](https://arxiv.org/pdf/1710.10903.pdf), Velickovic et al., ICLR 2018.
14 | * [Collective Classification in Network Data](https://www.aaai.org/ojs/index.php/aimagazine/article/view/2157), Sen et al., AI Magazine 2008.
15 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2019 Raunak Kumar
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /src/layers.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | 7 | class GraphAttention(nn.Module): 8 | 9 | def __init__(self, input_dim, output_dim, num_heads, dropout=0.5): 10 | """ 11 | Parameters 12 | ---------- 13 | input_dim : int 14 | Dimension of input node features. 15 | output_dim : int 16 | Dimension of output features after each attention head. 17 | num_heads : int 18 | Number of attention heads. 19 | dropout : float 20 | Dropout rate. Default: 0.5. 21 | """ 22 | super().__init__() 23 | 24 | self.input_dim = input_dim 25 | self.output_dim = output_dim 26 | self.num_heads = num_heads 27 | 28 | self.fcs = nn.ModuleList([nn.Linear(input_dim, output_dim) for _ in range(num_heads)]) 29 | self.a = nn.ModuleList([nn.Linear(2*output_dim, 1) for _ in range(num_heads)]) 30 | 31 | self.dropout = nn.Dropout(dropout) 32 | self.softmax = nn.Softmax(dim=0) 33 | self.leakyrelu = nn.LeakyReLU() 34 | 35 | def forward(self, features, nodes, mapping, rows): 36 | """ 37 | Parameters 38 | ---------- 39 | features : torch.Tensor 40 | An (n' x input_dim) tensor of input node features. 41 | nodes : numpy array 42 | nodes is a numpy array of nodes in the current layer of the computation graph. 43 | mapping : dict 44 | mapping is a dictionary mapping node v (labelled 0 to |V|-1) to 45 | its position in the layer of nodes in the computation graph 46 | before nodes. For example, if the layer before nodes is [2,5], 47 | then mapping[2] = 0 and mapping[5] = 1. 
48 |         rows : numpy array
49 |             rows[i] is an array of the neighbors of nodes[i]; each neighbor must appear in the layer before nodes (i.e., be a key of mapping).
50 | 
51 |         Returns
52 |         -------
53 |         out : list of torch.Tensor
54 |             A list of num_heads tensors, each of shape (len(nodes) x output_dim), of output node features.
55 |         """
56 | 
57 |         nprime = features.shape[0]
58 |         rows = [np.array([mapping[v] for v in row], dtype=np.int64) for row in rows]
59 |         sum_degs = np.hstack(([0], np.cumsum([len(row) for row in rows])))
60 |         mapped_nodes = [mapping[v] for v in nodes]
61 |         indices = torch.LongTensor([[v, c] for (v, row) in zip(mapped_nodes, rows) for c in row]).t()
62 |         # indices = torch.LongTensor([[mapping[nodes[i]], c] for i in range(len(rows)) for c in rows[i]]).t()
63 | 
64 |         out = []
65 |         for k in range(self.num_heads):
66 |             h = self.fcs[k](features)
67 | 
68 |             nbr_h = torch.cat(tuple([h[row] for row in rows]), dim=0)
69 |             self_h = torch.cat(tuple([h[mapping[nodes[i]]].repeat(len(row), 1) for (i, row) in enumerate(rows)]), dim=0)
70 |             cat_h = torch.cat((self_h, nbr_h), dim=1)
71 | 
72 |             e = self.leakyrelu(self.a[k](cat_h))
73 | 
74 |             alpha = [self.softmax(e[lo : hi]) for (lo, hi) in zip(sum_degs, sum_degs[1:])]
75 |             alpha = torch.cat(tuple(alpha), dim=0)
76 |             alpha = alpha.squeeze(1)
77 |             alpha = self.dropout(alpha)
78 | 
79 |             adj = torch.sparse.FloatTensor(indices, alpha, torch.Size([nprime, nprime]))
80 |             out.append(torch.sparse.mm(adj, h)[mapped_nodes])
81 | 
82 |         return out
--------------------------------------------------------------------------------
/src/models.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | 
5 | import layers
6 | 
7 | class GAT(nn.Module):
8 | 
9 |     def __init__(self, input_dim, hidden_dims, output_dim, num_heads,
10 |                  dropout=0.5, device='cpu'):
11 |         """
12 |         Parameters
13 |         ----------
14 |         input_dim : int
15 |             Dimension of input node features.
16 |         hidden_dims : list of ints
17 |             Dimension of hidden layers. Must be non-empty.
18 |         output_dim : int
19 |             Dimension of output node features.
20 |         num_heads : list of ints
21 |             Number of attention heads in each hidden layer and the output layer. Must be non-empty. Note that len(num_heads) = len(hidden_dims)+1.
22 |         dropout : float
23 |             Dropout rate. Default: 0.5.
24 |         device : str
25 |             'cpu' or 'cuda:0'. Default: 'cpu'.
26 |         """
27 |         super().__init__()
28 | 
29 |         self.input_dim = input_dim
30 |         self.hidden_dims = hidden_dims
31 |         self.output_dim = output_dim
32 |         self.num_heads = num_heads
33 |         self.device = device
34 |         self.num_layers = len(hidden_dims) + 1
35 | 
36 |         dims = [input_dim] + [d*nh for (d, nh) in zip(hidden_dims, num_heads[:-1])] + [output_dim*num_heads[-1]]
37 |         in_dims = dims[:-1]
38 |         out_dims = [d // nh for (d, nh) in zip(dims[1:], num_heads)]
39 | 
40 |         self.attn = nn.ModuleList([layers.GraphAttention(i, o, nh, dropout) for (i, o, nh) in zip(in_dims, out_dims, num_heads)])
41 | 
42 |         self.bns = nn.ModuleList([nn.BatchNorm1d(dim) for dim in dims[1:-1]])
43 | 
44 |         self.dropout = nn.Dropout(dropout)
45 |         self.elu = nn.ELU()
46 | 
47 |     def forward(self, features, node_layers, mappings, rows):
48 |         """
49 |         Parameters
50 |         ----------
51 |         features : torch.Tensor
52 |             An (n' x input_dim) tensor of input node features.
53 |         node_layers : list of numpy array
54 |             node_layers[i] is an array of the nodes in the ith layer of the
55 |             computation graph.
56 | mappings : list of dictionary 57 | mappings[i] is a dictionary mapping node v (labelled 0 to |V|-1) 58 | in node_layers[i] to its position in node_layers[i]. For example, 59 | if node_layers[i] = [2,5], then mappings[i][2] = 0 and 60 | mappings[i][5] = 1. 61 | rows : numpy array 62 | rows[i] is an array of neighbors of node i. 63 | 64 | Returns 65 | ------- 66 | out : torch.Tensor 67 | An (len(node_layers[-1]) x output_dim) tensor of output node features. 68 | """ 69 | out = features 70 | for k in range(self.num_layers): 71 | nodes = node_layers[k+1] 72 | mapping = mappings[k] 73 | init_mapped_nodes = np.array([mappings[0][v] for v in nodes], dtype=np.int64) 74 | cur_rows = rows[init_mapped_nodes] 75 | out = self.dropout(out) 76 | out = self.attn[k](out, nodes, mapping, cur_rows) 77 | if k+1 < self.num_layers: 78 | out = [self.elu(o) for o in out] 79 | out = torch.cat(tuple(out), dim=1) 80 | out = self.bns[k](out) 81 | else: 82 | out = torch.cat(tuple([x.flatten().unsqueeze(0) for x in out]), dim=0) 83 | out = out.mean(dim=0).reshape(len(nodes), self.output_dim) 84 | 85 | return out 86 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import sys 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from datasets import node_classification 9 | 10 | def get_criterion(task): 11 | """ 12 | Parameters 13 | ---------- 14 | task : str 15 | Name of the task. 16 | 17 | Returns 18 | ------- 19 | criterion : torch.nn.modules._Loss 20 | Loss function for the task. 21 | """ 22 | if task == 'node_classification': 23 | criterion = nn.CrossEntropyLoss() 24 | 25 | return criterion 26 | 27 | def get_dataset(args): 28 | """ 29 | Parameters 30 | ---------- 31 | args : tuple 32 | Tuple of task, dataset name and other arguments required by the dataset constructor. 33 | 34 | Returns 35 | ------- 36 | dataset : torch.utils.data.Dataset 37 | The dataset. 38 | """ 39 | task, dataset_name, *dataset_args = args 40 | if task == 'node_classification': 41 | if dataset_name == 'cora': 42 | dataset = node_classification.Cora(*dataset_args) 43 | 44 | return dataset 45 | 46 | def get_fname(config): 47 | """ 48 | Parameters 49 | ---------- 50 | config : dict 51 | A dictionary with all the arguments and flags. 52 | 53 | Returns 54 | ------- 55 | fname : str 56 | The filename for the saved model. 57 | """ 58 | hidden_dims_str = '_'.join([str(x) for x in config['hidden_dims']]) 59 | num_heads_str = '_'.join([str(x) for x in config['num_heads']]) 60 | batch_size = config['batch_size'] 61 | epochs = config['epochs'] 62 | lr = config['lr'] 63 | weight_decay = config['weight_decay'] 64 | dropout = config['dropout'] 65 | transductive = str(config['transductive']) 66 | fname = 'gat_hidden_dims_{}_num_heads_{}_batch_size_{}_epochs_{}_lr_{}_weight_decay_{}_dropout_{}_transductive_{}.pth'.format( 67 | hidden_dims_str, num_heads_str, batch_size, epochs, lr, 68 | weight_decay, dropout, transductive) 69 | 70 | return fname 71 | 72 | def parse_args(): 73 | """ 74 | Returns 75 | ------- 76 | config : dict 77 | A dictionary with the required arguments and flags. 
78 | """ 79 | parser = argparse.ArgumentParser() 80 | 81 | parser.add_argument('--json', type=str, default='config.json', 82 | help='path to json file with arguments, default: config.json') 83 | 84 | parser.add_argument('--print_every', type=int, default=16, 85 | help='print loss and accuracy after how many batches, default: 16') 86 | 87 | parser.add_argument('--dataset', type=str, choices=['cora'], default='cora', 88 | help='name of the dataset, default=cora') 89 | parser.add_argument('--dataset_path', type=str, 90 | default='/Users/raunak/Documents/Datasets/Cora', 91 | help='path to dataset') 92 | parser.add_argument('--self_loop', action='store_true', 93 | help='whether to add self loops to adjacency matrix, default=False') 94 | parser.add_argument('--normalize_adj', action='store_true', 95 | help='whether to normalize adj like in gcn, default=False') 96 | parser.add_argument('--transductive', action='store_true', 97 | help='whether to use all nodes while training, default=False') 98 | 99 | parser.add_argument('--task', type=str, 100 | choices=['unsupervised', 'node_classification'], 101 | default='node_classification', 102 | help='type of task, default=node_classification') 103 | 104 | parser.add_argument('--dropout', type=float, default=0.5, 105 | help='dropout parameter, default=0.5.') 106 | parser.add_argument('--cuda', action='store_true', 107 | help='whether to use GPU, default: False') 108 | parser.add_argument('--hidden_dims', type=int, nargs="*", 109 | help='dimensions of hidden layers, specify through config.json') 110 | parser.add_argument('--num_heads', type=int, nargs="*", 111 | help='number of attention heads in each layer, length should be equal to len(hidden_dims)+1, specify through config.json') 112 | 113 | parser.add_argument('--batch_size', type=int, default=8, 114 | help='training batch size, default=8') 115 | parser.add_argument('--epochs', type=int, default=10, 116 | help='number of training epochs, default=10') 117 | parser.add_argument('--lr', type=float, default=1e-3, 118 | help='learning rate, default=1e-3') 119 | parser.add_argument('--weight_decay', type=float, default=5e-4, 120 | help='weight decay, default=5e-4') 121 | 122 | parser.add_argument('--save', action='store_true', 123 | help='whether to save model in trained_models/ directory, default: False') 124 | parser.add_argument('--load', action='store_true', 125 | help='whether to load model in trained_models/ directory') 126 | 127 | args = parser.parse_args() 128 | config = vars(args) 129 | if config['json']: 130 | with open(config['json']) as f: 131 | json_dict = json.load(f) 132 | config.update(json_dict) 133 | 134 | config['num_layers'] = len(config['hidden_dims']) + 1 135 | 136 | print('--------------------------------') 137 | print('Config:') 138 | for (k, v) in config.items(): 139 | print(" '{}': '{}'".format(k, v)) 140 | print('--------------------------------') 141 | 142 | return config -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | import os 3 | import sys 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | from torch.utils.data import DataLoader 9 | 10 | from datasets import node_classification 11 | import models 12 | import utils 13 | 14 | def main(): 15 | config = utils.parse_args() 16 | 17 | if config['cuda'] and torch.cuda.is_available(): 18 | device = 'cuda:0' 19 | else: 20 | device = 'cpu' 21 | 22 | 
dataset_args = (config['task'], config['dataset'], config['dataset_path'], 23 | 'train', config['num_layers'], config['self_loop'], 24 | config['normalize_adj'], config['transductive']) 25 | dataset = utils.get_dataset(dataset_args) 26 | loader = DataLoader(dataset=dataset, batch_size=config['batch_size'], 27 | shuffle=True, collate_fn=dataset.collate_wrapper) 28 | input_dim, output_dim = dataset.get_dims() 29 | 30 | model = models.GAT(input_dim, config['hidden_dims'], output_dim, 31 | config['num_heads'], config['dropout'], device) 32 | model.to(device) 33 | 34 | if not config['load']: 35 | criterion = utils.get_criterion(config['task']) 36 | optimizer = optim.Adam(model.parameters(), lr=config['lr'], 37 | weight_decay=config['weight_decay']) 38 | epochs = config['epochs'] 39 | stats_per_batch = config['stats_per_batch'] 40 | num_batches = int(ceil(len(dataset) / config['batch_size'])) 41 | model.train() 42 | print('--------------------------------') 43 | print('Training.') 44 | for epoch in range(epochs): 45 | print('Epoch {} / {}'.format(epoch+1, epochs)) 46 | running_loss = 0.0 47 | num_correct, num_examples = 0, 0 48 | for (idx, batch) in enumerate(loader): 49 | features, node_layers, mappings, rows, labels = batch 50 | features, labels = features.to(device), labels.to(device) 51 | optimizer.zero_grad() 52 | out = model(features, node_layers, mappings, rows) 53 | loss = criterion(out, labels) 54 | loss.backward() 55 | optimizer.step() 56 | with torch.no_grad(): 57 | running_loss += loss.item() 58 | predictions = torch.max(out, dim=1)[1] 59 | num_correct += torch.sum(predictions == labels).item() 60 | num_examples += len(labels) 61 | if (idx + 1) % stats_per_batch == 0: 62 | running_loss /= stats_per_batch 63 | accuracy = num_correct / num_examples 64 | print(' Batch {} / {}: loss {}, accuracy {}'.format( 65 | idx+1, num_batches, running_loss, accuracy)) 66 | running_loss = 0.0 67 | num_correct, num_examples = 0, 0 68 | print('Finished training.') 69 | print('--------------------------------') 70 | 71 | if config['save']: 72 | print('--------------------------------') 73 | directory = os.path.join(os.path.dirname(os.getcwd()), 74 | 'trained_models') 75 | if not os.path.exists(directory): 76 | os.makedirs(directory) 77 | fname = utils.get_fname(config) 78 | path = os.path.join(directory, fname) 79 | print('Saving model at {}'.format(path)) 80 | torch.save(model.state_dict(), path) 81 | print('Finished saving model.') 82 | print('--------------------------------') 83 | 84 | if config['load']: 85 | directory = os.path.join(os.path.dirname(os.getcwd()), 86 | 'trained_models') 87 | fname = utils.get_fname(config) 88 | path = os.path.join(directory, fname) 89 | model.load_state_dict(torch.load(path)) 90 | dataset_args = (config['task'], config['dataset'], config['dataset_path'], 91 | 'test', config['num_layers'], config['self_loop'], 92 | config['normalize_adj'], config['transductive']) 93 | dataset = utils.get_dataset(dataset_args) 94 | loader = DataLoader(dataset=dataset, batch_size=config['batch_size'], 95 | shuffle=False, collate_fn=dataset.collate_wrapper) 96 | criterion = utils.get_criterion(config['task']) 97 | stats_per_batch = config['stats_per_batch'] 98 | num_batches = int(ceil(len(dataset) / config['batch_size'])) 99 | model.eval() 100 | print('--------------------------------') 101 | print('Testing.') 102 | running_loss, total_loss = 0.0, 0.0 103 | num_correct, num_examples = 0, 0 104 | total_correct, total_examples = 0, 0 105 | for (idx, batch) in enumerate(loader): 106 | 
features, node_layers, mappings, rows, labels = batch 107 | features, labels = features.to(device), labels.to(device) 108 | out = model(features, node_layers, mappings, rows) 109 | loss = criterion(out, labels) 110 | running_loss += loss.item() 111 | total_loss += loss.item() 112 | predictions = torch.max(out, dim=1)[1] 113 | num_correct += torch.sum(predictions == labels).item() 114 | total_correct += torch.sum(predictions == labels).item() 115 | num_examples += len(labels) 116 | total_examples += len(labels) 117 | if (idx + 1) % stats_per_batch == 0: 118 | running_loss /= stats_per_batch 119 | accuracy = num_correct / num_examples 120 | print(' Batch {} / {}: loss {}, accuracy {}'.format( 121 | idx+1, num_batches, running_loss, accuracy)) 122 | running_loss = 0.0 123 | num_correct, num_examples = 0, 0 124 | total_loss /= num_batches 125 | total_accuracy = total_correct / total_examples 126 | print('Loss {}, accuracy {}'.format(total_loss, total_accuracy)) 127 | print('Finished testing.') 128 | print('--------------------------------') 129 | 130 | if __name__ == '__main__': 131 | main() -------------------------------------------------------------------------------- /src/datasets/node_classification.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import scipy.sparse as sp 5 | import torch 6 | from torch.utils.data import DataLoader, Dataset 7 | 8 | class Cora(Dataset): 9 | 10 | def __init__(self, path, mode, num_layers, 11 | self_loop=False, normalize_adj=False, transductive=False): 12 | """ 13 | Parameters 14 | ---------- 15 | path : str 16 | Path to the cora dataset with cora.cites and cora.content files. 17 | mode : str 18 | train / val / test. 19 | num_layers : int 20 | Depth of the model. 21 | self_loop : Boolean 22 | Whether to add self loops, default: False. 23 | normalize_adj : Boolean 24 | Whether to use symmetric normalization on the adjacency matrix, default: False. 25 | transductive : Boolean 26 | Whether to use all node features while training, as in a transductive setting, default: False. 
27 | """ 28 | super(Cora, self).__init__() 29 | 30 | self.path = path 31 | self.mode = mode 32 | self.num_layers = num_layers 33 | self.self_loop = self_loop 34 | self.normalize_adj = normalize_adj 35 | self.transductive = transductive 36 | self.idx = { 37 | 'train' : np.array(range(140)), 38 | 'val' : np.array(range(200, 500)), 39 | 'test' : np.array(range(500, 1500)) 40 | } 41 | 42 | print('--------------------------------') 43 | print('Reading cora dataset from {}'.format(path)) 44 | citations = np.loadtxt(os.path.join(path, 'cora.cites'), dtype=np.int64) 45 | content = np.loadtxt(os.path.join(path, 'cora.content'), dtype=str) 46 | print('Finished reading data.') 47 | 48 | print('Setting up data structures.') 49 | if transductive: 50 | idx = np.arange(content.shape[0]) 51 | else: 52 | if mode == 'train': 53 | idx = self.idx['train'] 54 | elif mode == 'val': 55 | idx = np.hstack((self.idx['train'], self.idx['val'])) 56 | elif mode == 'test': 57 | idx = np.hstack((self.idx['train'], self.idx['test'])) 58 | features, labels = content[idx, 1:-1].astype(np.float32), content[idx, -1] 59 | d = {j : i for (i,j) in enumerate(sorted(set(labels)))} 60 | labels = np.array([d[l] for l in labels]) 61 | 62 | vertices = np.array(content[idx, 0], dtype=np.int64) 63 | d = {j : i for (i,j) in enumerate(vertices)} 64 | edges = np.array([e for e in citations if e[0] in d.keys() and e[1] in d.keys()]) 65 | edges = np.array([d[v] for v in edges.flatten()]).reshape(edges.shape) 66 | n, m = labels.shape[0], edges.shape[0] 67 | u, v = edges[:, 0], edges[:, 1] 68 | adj = sp.coo_matrix((np.ones(m), (u, v)), 69 | shape=(n, n), 70 | dtype=np.float32) 71 | adj += adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj) 72 | if self_loop: 73 | adj += sp.eye(n) 74 | if normalize_adj: 75 | degrees = np.power(np.array(np.sum(adj, axis=1)), -0.5).flatten() 76 | degrees = sp.diags(degrees) 77 | adj = (degrees.dot(adj.dot(degrees))) 78 | print('Finished setting up data structures.') 79 | print('--------------------------------') 80 | 81 | self.features = features 82 | self.labels = labels 83 | self.adj = adj.tolil() 84 | 85 | def __len__(self): 86 | return len(self.idx[self.mode]) 87 | 88 | def __getitem__(self, idx): 89 | if self.transductive: 90 | idx += int(self.idx[self.mode][0]) 91 | else: 92 | if self.mode != 'train': 93 | idx += len(self.idx['train']) 94 | node_layers, mappings = self._form_computation_graph(idx) 95 | rows = self.adj.rows[node_layers[0]] 96 | features = self.features[node_layers[0], :] 97 | labels = self.labels[node_layers[-1]] 98 | features = torch.FloatTensor(features) 99 | labels = torch.LongTensor(labels) 100 | 101 | return features, node_layers, mappings, rows, labels 102 | 103 | def collate_wrapper(self, batch): 104 | """ 105 | Parameters 106 | ---------- 107 | batch : list 108 | A list of examples from this dataset. 109 | 110 | Returns 111 | ------- 112 | features : torch.FloatTensor 113 | An (n' x input_dim) tensor of input node features. 114 | node_layers : list of numpy array 115 | node_layers[i] is an array of the nodes in the ith layer of the 116 | computation graph. 117 | mappings : list of dictionary 118 | mappings[i] is a dictionary mapping node v (labelled 0 to |V|-1) 119 | in node_layers[i] to its position in node_layers[i]. For example, 120 | if node_layers[i] = [2,5], then mappings[i][2] = 0 and 121 | mappings[i][5] = 1. 122 | rows : list 123 | labels : torch.LongTensor 124 | An (n') length tensor of node labels. 
125 | """ 126 | idx = [node_layers[-1][0] for node_layers in [sample[1] for sample in batch]] 127 | 128 | node_layers, mappings = self._form_computation_graph(idx) 129 | rows = self.adj.rows[node_layers[0]] 130 | features = self.features[node_layers[0], :] 131 | labels = self.labels[node_layers[-1]] 132 | features = torch.FloatTensor(features) 133 | labels = torch.LongTensor(labels) 134 | 135 | return features, node_layers, mappings, rows, labels 136 | 137 | def get_dims(self): 138 | """ 139 | Returns 140 | ------- 141 | dimension of input features, dimension of output features 142 | """ 143 | return self.features.shape[1], len(set(self.labels)) 144 | 145 | def _form_computation_graph(self, idx): 146 | """ 147 | Parameters 148 | ---------- 149 | idx : int 150 | Index of the node for which the forward pass needs to be computed. 151 | 152 | Returns 153 | ------- 154 | node_layers : list of numpy array 155 | node_layers[i] is an array of the nodes in the ith layer of the 156 | computation graph. 157 | mappings : list of dictionary 158 | mappings[i] is a dictionary mapping node v (labelled 0 to |V|-1) 159 | in node_layers[i] to its position in node_layers[i]. For example, 160 | if node_layers[i] = [2,5], then mappings[i][2] = 0 and 161 | mappings[i][5] = 1. 162 | """ 163 | _list, _set = list, set 164 | rows = self.adj.rows 165 | if type(idx) is int: 166 | node_layers = [np.array([idx], dtype=np.int64)] 167 | elif type(idx) is list: 168 | node_layers = [np.array(idx, dtype=np.int64)] 169 | for _ in range(self.num_layers): 170 | prev = node_layers[-1] 171 | arr = [node for node in prev] 172 | arr.extend([v for node in arr for v in rows[node]]) 173 | arr = np.array(_list(_set(arr)), dtype=np.int64) 174 | node_layers.append(arr) 175 | node_layers.reverse() 176 | 177 | mappings = [{j : i for (i,j) in enumerate(arr)} for arr in node_layers] 178 | 179 | return node_layers, mappings --------------------------------------------------------------------------------
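
The call convention of layers.GraphAttention is easiest to see on a toy graph. The following is a minimal sketch, not a file from this repository: it builds the features, nodes, mapping and rows arguments by hand for a three-node path graph and runs one layer. It assumes it is executed from the src directory so that `layers` imports directly, and the toy sizes (input_dim=5, output_dim=4, two heads) are arbitrary choices for illustration.

import numpy as np
import torch

from layers import GraphAttention

# Toy undirected path graph 0 - 1 - 2, with self-loops included in the
# neighbour lists (the Cora dataset adds these when self_loop is set).
features = torch.randn(3, 5)            # n' = 3 nodes, input_dim = 5
nodes = np.array([0, 1, 2])             # nodes in the current layer
mapping = {0: 0, 1: 1, 2: 2}            # previous layer is also [0, 1, 2]
rows = [np.array([0, 1]),               # neighbours of node 0
        np.array([0, 1, 2]),            # neighbours of node 1
        np.array([1, 2])]               # neighbours of node 2

attn = GraphAttention(input_dim=5, output_dim=4, num_heads=2, dropout=0.0)
out = attn(features, nodes, mapping, rows)

# One (len(nodes) x output_dim) tensor per attention head.
print([o.shape for o in out])           # [torch.Size([3, 4]), torch.Size([3, 4])]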
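
The same kind of hand-built inputs also show how models.GAT strings the attention layers together. Again this is only a sketch under the same assumptions (run from src, toy dimensions chosen arbitrarily); in the real pipeline node_layers, mappings and rows come from Cora._form_computation_graph and collate_wrapper in src/datasets/node_classification.py.

import numpy as np
import torch

from models import GAT

# Toy path graph 0 - 1 - 2 again. For a model with one hidden layer
# (num_layers = 2) the computation graph has three node layers; on a graph
# this small every layer simply contains all the nodes.
node_layers = [np.array([0, 1, 2]), np.array([0, 1, 2]), np.array([0, 1, 2])]
mappings = [{0: 0, 1: 1, 2: 2}] * 3
rows = np.empty(3, dtype=object)        # object array so GAT.forward can fancy-index it
rows[:] = [np.array([0, 1]), np.array([0, 1, 2]), np.array([1, 2])]

features = torch.randn(3, 5)
model = GAT(input_dim=5, hidden_dims=[8], output_dim=7,
            num_heads=[8, 1], dropout=0.0)
out = model(features, node_layers, mappings, rows)

print(out.shape)                        # torch.Size([3, 7]) -- one score vector per node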