├── README.md
├── W&B Chart 6_16_2020, 5_13_19 PM.png
├── W&B Chart 6_16_2020, 5_13_29 PM.png
├── amazon2M.sh
├── layers.py
├── main.py
├── models.py
├── requirements.txt
├── top3runs.csv
└── utils.py
/README.md:
--------------------------------------------------------------------------------
# Cluster-GCN in PyTorch
[![Arxiv](https://img.shields.io/badge/ArXiv-1905.07953-orange.svg?color=blue&style=plastic)](https://arxiv.org/abs/1905.07953)
[![Download](https://img.shields.io/badge/Download-amazon2M-brightgreen.svg?color=black&style=plastic)](https://drive.google.com/drive/folders/1Tfn-yABlW5JheyYItyRyrMGtmQdYN7wm?usp=sharing)

> Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks
> Wei-Lin Chiang, Xuanqing Liu, Si Si, Yang Li, Samy Bengio, Cho-Jui Hsieh.
> KDD, 2019.
> [[Paper]](https://arxiv.org/abs/1905.07953)

The raw files used to curate the amazon2M dataset are available at http://manikvarma.org/downloads/XC/XMLRepository.html. The processed data files used by this implementation can be downloaded via the **Download amazon2M** badge above.

## GraphConv Layer module usage
```
from layers import GraphConv

gconv = GraphConv(in_features, out_features)
out = gconv(A, X)  # out.shape == (N, out_features)
```
where:
* N = number of nodes in the graph
* F = number of features per node
* in_features = number of input features in X
* out_features = number of output features
* X = a dense FloatTensor of shape (N, F) containing the input node features
* A = a sparse FloatTensor representing the graph adjacency matrix, created as follows:
```
i = torch.LongTensor(indices)
v = torch.FloatTensor(values)
A = torch.sparse.FloatTensor(i.t(), v, shape)
```
A complete toy example is sketched in the *Minimal example* section at the end of this README.

## Requirements:
* Install the METIS clustering toolkit and the other required Python packages:
```
1) Download metis-5.1.0.tar.gz from http://glaros.dtc.umn.edu/gkhome/metis/metis/download and unpack it
2) cd metis-5.1.0
3) make config shared=1 prefix=~/.local/
4) make install
5) export METIS_DLL=~/.local/lib/libmetis.so
6) pip install -r requirements.txt
```
## Usage:
* Run the shell script below to reproduce the experiments on the amazon2M dataset:
```
./amazon2M.sh
tensorboard --logdir=runs --bind_all  (optional, to visualize training)

NOTE: The script assumes that the data files are stored in the following structure:
./datasets/amazon2M/amazon2M-{G.json, feats.npy, class_map.json, id_map.json}
```
## Results:
* Test F1 **0.8880** (vs. 0.9041 reported in the Cluster-GCN paper)
* Training Loss **0.3096**
* Training Accuracy **0.9021**

*(Training loss and accuracy curves are included as the two "W&B Chart" PNG screenshots in this repository.)*
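## Minimal example
A minimal, self-contained sketch of how the pieces fit together on a toy graph (hypothetical data and shapes, CPU only; not part of the training pipeline):
```
import torch
from models import GCN

# Toy undirected graph with N = 3 nodes (edges 0-1 and 1-2, plus self-loops),
# 4 input features per node and 3 classes. All values here are illustrative.
indices = torch.LongTensor([[0, 1], [1, 0], [1, 2], [2, 1], [0, 0], [1, 1], [2, 2]])
values = torch.ones(indices.shape[0])
A = torch.sparse.FloatTensor(indices.t(), values, torch.Size([3, 3]))
X = torch.randn(3, 4)

# A 2-layer GCN: 4 -> 8 -> 3; the last layer returns raw logits.
model = GCN(fan_in=[4, 8], fan_out=[8, 3], layers=2, dropout=0.2, normalize=True, bias=False)
logits = model(A, X)
print(logits.shape)  # torch.Size([3, 3])
```
For the real experiments, `main.py` additionally normalizes the adjacency matrix (see `utils.normalize_adj_diag_enhance`) and trains on METIS cluster batches.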
--------------------------------------------------------------------------------
/W&B Chart 6_16_2020, 5_13_19 PM.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pyyush/GraphML/73e237e3d02346ccef948814c9eab7c649888530/W&B Chart 6_16_2020, 5_13_19 PM.png
--------------------------------------------------------------------------------
/W&B Chart 6_16_2020, 5_13_29 PM.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pyyush/GraphML/73e237e3d02346ccef948814c9eab7c649888530/W&B Chart 6_16_2020, 5_13_29 PM.png
--------------------------------------------------------------------------------
/amazon2M.sh:
--------------------------------------------------------------------------------
#!/bin/bash

export METIS_DLL=~/.local/lib/libmetis.so

python main.py --dataset amazon2M --exp_num 1 --batch_size 20 --num_clusters_train 10000 --num_clusters_test 1 --layers 4 --epochs 200 --lr 0.01 --hidden 2048 --dropout 0.5 --lr_scheduler -1 --test 1

--------------------------------------------------------------------------------
/layers.py:
--------------------------------------------------------------------------------
# Author: Piyush Vyas
import math
import torch
from torch.nn.parameter import Parameter


class GraphConv(torch.nn.Module):
    r""" Applies the graph convolution operation to the incoming data: :math:`X' = \hat{A}XW`

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        dropout: dropout probability
            Default: 0.2
        bias: If set to ``True``, the layer will learn an additive bias.
            Default: ``False``
        normalize: If set to ``False``, the layer will not apply Layer Normalization to the features.
            Default: ``True``
        last: If set to ``True``, the layer will act as the final/classification layer and return logits.
            Default: ``False``

    Shape:
        - Input: :math:`(N, H_{in})` where :math:`H_{in} = \text{in\_features}`
        - Output: :math:`(N, H_{out})` where :math:`H_{out} = \text{out\_features}`.

    Attributes:
        weight: the learnable weights of the module of shape
            :math:`(\text{out\_features}, \text{in\_features})`. The values are
            initialized from :math:`\mathcal{U}(-\text{bound}, \text{bound})` where
            :math:`\text{bound} = \sqrt{\frac{6}{\text{in\_features} + \text{out\_features}}}`
        bias: the learnable bias of the module of shape
            :math:`(\text{out\_features})`.
            If :attr:`bias` is ``True``, the values are initialized with the scalar value `0`.
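
    Examples::

        >>> # Toy 3-node graph with self-loops only (illustrative values)
        >>> i = torch.LongTensor([[0, 0], [1, 1], [2, 2]])
        >>> v = torch.FloatTensor([1., 1., 1.])
        >>> A = torch.sparse.FloatTensor(i.t(), v, torch.Size([3, 3]))
        >>> X = torch.randn(3, 16)
        >>> gconv = GraphConv(16, 32)
        >>> out = gconv(A, X)
        >>> out.shape
        torch.Size([3, 32])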

    """

    __constants__ = ['in_features', 'out_features']

    def __init__(self, in_features, out_features, dropout=0.2, bias=False, normalize=True, last=False):
        super(GraphConv, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.normalize = normalize
        self.p = dropout
        self.last = last
        self.layer_norm = torch.nn.LayerNorm(normalized_shape=out_features)
        self.weight = Parameter(torch.Tensor(self.out_features, self.in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(self.out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        # Xavier/Glorot uniform
        bound = math.sqrt(6.0/float(self.out_features + self.in_features))
        self.weight.data.uniform_(-bound, bound)

        # Kaiming/He uniform (alternative initialization)
        # torch.nn.init.kaiming_uniform_(self.weight, a=0.01, mode='fan_in', nonlinearity='leaky_relu')

        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)

    def forward(self, A, X):

        input = torch.sparse.mm(A, X)  # (N, N) x (N, F) -> (N, F)
        # Equivalent manual implementation:
        # output = input.matmul(self.weight.t())  # (N, F) x (F, H) -> (N, H)
        # if self.bias is not None:
        #     output += self.bias
        output = torch.nn.functional.linear(input, self.weight, self.bias)

        if self.last:
            return output

        if self.normalize:
            output = self.layer_norm(output)
        output = torch.nn.functional.leaky_relu(output)
        return torch.nn.functional.dropout(output, p=self.p, training=self.training)

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(self.in_features, self.out_features, self.bias is not None)
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
# Author: Piyush Vyas
import time
import torch
import argparse
import numpy as np
import utils as utils
import sklearn.metrics
from models import GCN
from torch.utils.tensorboard import SummaryWriter


# Training settings
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='amazon2M', help='options - amazon2M')
parser.add_argument('--exp_num', type=str, help='experiment number for tensorboard')
parser.add_argument('--test', type=int, default=1, help='run evaluation after each epoch if 1, else skip it')
parser.add_argument('--batch_size', type=int, default=20)
parser.add_argument('--num_clusters_train', type=int, default=10000)
parser.add_argument('--num_clusters_test', type=int, default=1)
parser.add_argument('--layers', type=int, default=4, help='Number of layers in the network.')
parser.add_argument('--epochs', type=int, default=2048, help='Number of epochs to train.')
parser.add_argument('--lr', type=float, default=0.01, help='Initial learning rate.')
parser.add_argument('--lr_scheduler', type=int, default=-1, help='use the OneCycleLR schedule if 1, else keep the learning rate fixed')
parser.add_argument('--hidden', type=int, default=200, help='Number of hidden units.')
parser.add_argument('--dropout', type=float, default=0.5, help='Dropout rate (ratio of units to drop).')
args = parser.parse_args()
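
# Example invocation (matches amazon2M.sh):
#   python main.py --dataset amazon2M --exp_num 1 --batch_size 20 --num_clusters_train 10000 \
#       --num_clusters_test 1 --layers 4 --epochs 200 --lr 0.01 --hidden 2048 --dropout 0.5 \
#       --lr_scheduler -1 --test 1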

# Reproducibility
np.random.seed(0)
torch.manual_seed(0)

# Tensorboard Writer
writer = SummaryWriter('runs/' + args.dataset + '/' + args.exp_num)

# Architecture settings based on Table 4 (Experiments section) of the Cluster-GCN paper;
# amazon2M has 100-dimensional node features and 47 target classes, hence the hard-coded 100 and 47.
if args.dataset == 'amazon2M' and args.layers == 4:
    _in = [100, args.hidden, args.hidden, args.hidden]
    _out = [args.hidden, args.hidden, args.hidden, 47]


def train(model, criterion, optimizer, features, adj, labels, dataset):
    optimizer.zero_grad()

    features = torch.from_numpy(features).cuda()
    labels = torch.LongTensor(labels).cuda()

    # Adj -> Torch Sparse Tensor
    i = torch.LongTensor(adj[0])  # indices
    v = torch.FloatTensor(adj[1])  # values
    adj = torch.sparse.FloatTensor(i.t(), v, adj[2]).cuda()

    output = model(adj, features)
    loss = criterion(output, torch.max(labels, 1)[1])

    loss.backward()
    optimizer.step()

    return loss


@torch.no_grad()
def test(model, features, adj, labels, mask, device):
    model.eval()

    features = torch.FloatTensor(features).to(device)
    labels = torch.LongTensor(labels).to(device)

    if device == 'cpu':
        adj = adj[0]
        features = features[0]
        labels = labels[0]
        mask = mask[0]

    # Adj -> Torch Sparse Tensor
    i = torch.LongTensor(adj[0])  # indices
    v = torch.FloatTensor(adj[1])  # values
    adj = torch.sparse.FloatTensor(i.t(), v, adj[2]).to(device)

    output = model(adj, features)

    pred = output[mask].argmax(dim=1, keepdim=True)
    labels = torch.max(labels[mask], 1)[1]
    return sklearn.metrics.f1_score(labels.cpu().numpy(), pred.cpu().numpy(), average='micro')
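
# Pipeline below: load the graph, METIS-partition the training subgraph into clusters,
# group clusters into batch-subgraphs (preprocess_multicluster), train on the GPU one
# batch-subgraph at a time, and evaluate micro-F1 on the full graph on the CPU when --test 1.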

def main():

    # Load data
    start = time.time()
    N, _adj, _feats, _labels, train_adj, train_feats, train_nodes, val_nodes, test_nodes, y_train, y_val, y_test, val_mask, test_mask = utils.load_data(args.dataset)
    print('Loaded data in {:.2f} seconds!'.format(time.time() - start))

    # Prepare Train Data
    start = time.time()
    _, parts = utils.partition_graph(train_adj, train_nodes, args.num_clusters_train)
    parts = [np.array(pt) for pt in parts]
    train_features, train_support, y_train = utils.preprocess_multicluster(train_adj, parts, train_feats, y_train, args.num_clusters_train, args.batch_size)
    print('Train Data pre-processed in {:.2f} seconds!'.format(time.time() - start))

    # Prepare Test Data
    if args.test == 1:
        y_test, test_mask = y_val, val_mask
        start = time.time()
        _, test_features, test_support, y_test, test_mask = utils.preprocess(_adj, _feats, y_test, np.arange(N), args.num_clusters_test, test_mask)
        print('Test Data pre-processed in {:.2f} seconds!'.format(time.time() - start))

    # Shuffle Batches
    batch_idxs = list(range(len(train_features)))

    # Model
    model = GCN(fan_in=_in, fan_out=_out, layers=args.layers, dropout=args.dropout, normalize=True, bias=False).float()
    model.cuda()

    # Loss Function
    criterion = torch.nn.CrossEntropyLoss()

    # Optimization Algorithm
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Learning Rate Schedule
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=args.lr, steps_per_epoch=int(args.num_clusters_train/args.batch_size), epochs=args.epochs+1, anneal_strategy='linear')
    model.train()


    # Train
    for epoch in range(args.epochs + 1):
        np.random.shuffle(batch_idxs)
        avg_loss = 0
        start = time.time()
        for batch in batch_idxs:
            loss = train(model.cuda(), criterion, optimizer, train_features[batch], train_support[batch], y_train[batch], dataset=args.dataset)
            if args.lr_scheduler == 1:
                scheduler.step()
            avg_loss += loss.item()

        # Write Train stats to tensorboard
        writer.add_scalar('time/train', time.time() - start, epoch)
        writer.add_scalar('loss/train', avg_loss/len(train_features), epoch)

        if args.test == 1:
            # Test on cpu
            f1 = test(model.cpu(), test_features, test_support, y_test, test_mask, device='cpu')
            print('f1: {:.4f}'.format(f1))


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
# Author: Piyush Vyas
import torch
from layers import GraphConv


class GCN(torch.nn.Module):
    def __init__(self, fan_in, fan_out, layers, dropout=0.2, normalize=True, bias=False):
        super(GCN, self).__init__()
        self.fan_in = fan_in
        self.fan_out = fan_out
        self.num_layers = layers
        self.dropout = dropout
        self.normalize = normalize
        self.bias = bias
        self.network = torch.nn.ModuleList([GraphConv(in_features=fan_in[i], out_features=fan_out[i], dropout=self.dropout, bias=self.bias, normalize=self.normalize) for i in range(self.num_layers - 1)])
        self.network.append(GraphConv(in_features=fan_in[-1], out_features=fan_out[-1], bias=self.bias, last=True))

    def forward(self, sparse_adj, feats):
        for idx, layer in enumerate(self.network):
            feats = layer(sparse_adj, feats)
        return feats
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
metis
torch==1.4.0
tensorflow==2.0.1
tensorboard==2.0.2
networkx==2.4
scipy>=1.4.1
scikit-learn>=0.22.2
--------------------------------------------------------------------------------
/top3runs.csv:
--------------------------------------------------------------------------------
Name,State,User,Created,act,batch_size,dropout,epochs,hidden,learning_rate,num_clusters_train,accuracy,test-F1,loss
logical-sweep-31,finished,pyyush,2020-06-16T12:58:55.000Z,leaky_relu,20,0.5,200,2048,0.01,10000,0.902107909,0.888054784,0.30960577
sunny-haze-1,finished,pyyush,2020-06-14T03:13:01.000Z,leaky_relu,20,0.2,200,2048,0.01,10000,0.920543903,0.886562422,0.244526007
good-sweep-8,finished,pyyush,2020-06-15T08:44:33.000Z,leaky_relu,10,0.2,200,2048,0.01,10000,0.918845557,0.886371632,0.252288322
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import json
import torch
import time
import metis
import numpy as np
import sklearn.metrics
import scipy.sparse as sp
from networkx.readwrite import json_graph
import sklearn.preprocessing
import tensorflow as tf

np.random.seed(0)
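
# load_data() returns, in order: the number of nodes, the full adjacency (CSR), the
# scaled features, the one-hot labels, the train-only adjacency and features, the
# train/val/test node ids, the y_train/y_val/y_test label matrices, and the boolean
# val/test masks.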

def load_data(dataset, path='./datasets'):

    # Load data files
    feats = np.load(tf.io.gfile.GFile('{}/{}/{}-feats.npy'.format(path, dataset, dataset), 'rb')).astype(np.float32)
    G = json_graph.node_link_graph(json.load(tf.io.gfile.GFile('{}/{}/{}-G.json'.format(path, dataset, dataset))))

    id_map = json.load(tf.io.gfile.GFile('{}/{}/{}-id_map.json'.format(path, dataset, dataset)))
    is_digit = list(id_map.keys())[0].isdigit()
    id_map = {(int(k) if is_digit else k): int(v) for k, v in id_map.items()}

    class_map = json.load(tf.io.gfile.GFile('{}/{}/{}-class_map.json'.format(path, dataset, dataset)))
    is_instance = isinstance(list(class_map.values())[0], list)
    class_map = {(int(k) if is_digit else k): (v if is_instance else int(v)) for k, v in class_map.items()}

    # Generate edge list
    edges = []
    for edge in G.edges():
        if edge[0] in id_map and edge[1] in id_map:
            edges.append((id_map[edge[0]], id_map[edge[1]]))

    # Total number of nodes in the graph
    _nodes = len(id_map)

    # Separate train, val, and test nodes
    val_nodes = np.array([id_map[n] for n in G.nodes() if G.nodes[n]['val']], dtype=np.int32)
    test_nodes = np.array([id_map[n] for n in G.nodes() if G.nodes[n]['test']], dtype=np.int32)
    is_train = np.ones((_nodes), dtype=np.bool)
    is_train[test_nodes] = False
    is_train[val_nodes] = False
    train_nodes = np.array([n for n in range(_nodes) if is_train[n]], dtype=np.int32)

    # Train edges
    train_edges = [(e[0], e[1]) for e in edges if is_train[e[0]] and is_train[e[1]]]
    train_edges = np.array(train_edges, dtype=np.int32)

    # All edges in the graph
    _edges = np.array(edges, dtype=np.int32)

    # Generate labels
    if isinstance(list(class_map.values())[0], list):
        num_classes = len(list(class_map.values())[0])
        _labels = np.zeros((_nodes, num_classes), dtype=np.float32)
        for k in class_map.keys():
            _labels[id_map[k], :] = np.array(class_map[k])
    else:
        num_classes = len(set(class_map.values()))
        _labels = np.zeros((_nodes, num_classes), dtype=np.float32)
        for k in class_map.keys():
            _labels[id_map[k], class_map[k]] = 1

    train_ids = np.array([id_map[n] for n in G.nodes() if not G.nodes[n]['val'] and not G.nodes[n]['test']])

    # Standard-scale the features using statistics of the training nodes only
    train_feats = feats[train_ids]
    scaler = sklearn.preprocessing.StandardScaler()
    scaler.fit(train_feats)
    _feats = scaler.transform(feats)

    def _construct_adj(e, shape):
        adj = sp.csr_matrix((np.ones((e.shape[0]), dtype=np.float32), (e[:, 0], e[:, 1])), shape=shape)
        adj += adj.transpose()
        return adj

    train_adj = _construct_adj(train_edges, (len(train_nodes), len(train_nodes)))
    _adj = _construct_adj(_edges, (_nodes, _nodes))

    train_feats = _feats[train_nodes]

    # Generate labels for each split
    y_train = _labels[train_nodes]
    y_val = np.zeros(_labels.shape)
    y_test = np.zeros(_labels.shape)
    y_val[val_nodes, :] = _labels[val_nodes, :]
    y_test[test_nodes, :] = _labels[test_nodes, :]

    # Generate masks for validation & testing data
    val_mask = sample_mask(val_nodes, _labels.shape[0])
    test_mask = sample_mask(test_nodes, _labels.shape[0])

    return _nodes, _adj, _feats, _labels, train_adj, train_feats, train_nodes, val_nodes, test_nodes, y_train, y_val, y_test, val_mask, test_mask
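
# Example (as used in main.py): partition the training subgraph into clusters;
# parts[i] holds the original node ids assigned to cluster i.
#   part_adj, parts = partition_graph(train_adj, train_nodes, num_clusters=10000)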

def partition_graph(adj, idx_nodes, num_clusters):

    num_nodes = len(idx_nodes)
    num_all_nodes = adj.shape[0]

    neighbor_intervals = []
    neighbors = []
    edge_cnt = 0
    neighbor_intervals.append(0)
    train_adj_lil = adj[idx_nodes, :][:, idx_nodes].tolil()
    train_ord_map = dict()
    train_adj_lists = [[] for _ in range(num_nodes)]

    for i in range(num_nodes):
        rows = train_adj_lil[i].rows[0]
        # self-edges need to be removed for a valid METIS input format
        if i in rows:
            rows.remove(i)
        train_adj_lists[i] = rows
        neighbors += rows
        edge_cnt += len(rows)
        neighbor_intervals.append(edge_cnt)
        train_ord_map[idx_nodes[i]] = i

    if num_clusters > 1:
        _, groups = metis.part_graph(train_adj_lists, num_clusters, seed=1)
    else:
        groups = [0] * num_nodes

    part_row = []
    part_col = []
    part_data = []
    parts = [[] for _ in range(num_clusters)]

    for nd_idx in range(num_nodes):
        gp_idx = groups[nd_idx]
        nd_orig_idx = idx_nodes[nd_idx]
        parts[gp_idx].append(nd_orig_idx)

        # keep only intra-cluster edges
        for nb_orig_idx in adj[nd_orig_idx].indices:
            nb_idx = train_ord_map[nb_orig_idx]
            if groups[nb_idx] == gp_idx:
                part_data.append(1)
                part_row.append(nd_orig_idx)
                part_col.append(nb_orig_idx)
    part_data.append(0)
    part_row.append(num_all_nodes - 1)
    part_col.append(num_all_nodes - 1)
    part_adj = sp.coo_matrix((part_data, (part_row, part_col))).tocsr()

    return part_adj, parts


def preprocess(adj, features, y_train, visible_data, num_clusters, train_mask=None):

    # graph partitioning
    part_adj, parts = partition_graph(adj, visible_data, num_clusters)
    part_adj = normalize_adj_diag_enhance(part_adj)
    parts = [np.array(pt) for pt in parts]

    features_batches = []
    support_batches = []
    y_train_batches = []
    train_mask_batches = []
    total_nnz = 0

    for pt in parts:
        features_batches.append(features[pt, :])
        now_part = part_adj[pt, :][:, pt]
        total_nnz += now_part.count_nonzero()
        support_batches.append(sparse_to_tuple(now_part))
        y_train_batches.append(y_train[pt, :])

        if train_mask is not None:
            train_pt = []
            for newidx, idx in enumerate(pt):
                if train_mask[idx]:
                    train_pt.append(newidx)
            train_mask_batches.append(sample_mask(train_pt, len(pt)))

    if train_mask is not None:
        return parts, features_batches, support_batches, y_train_batches, train_mask_batches
    else:
        return parts, features_batches, support_batches, y_train_batches, train_mask


def preprocess_multicluster(adj, parts, features, y_train, num_clusters, block_size):
    """ Generate batches of block_size clusters each."""
    features_batches = []
    support_batches = []
    y_train_batches = []
    total_nnz = 0
    np.random.shuffle(parts)

    for _, st in enumerate(range(0, num_clusters, block_size)):
        pt = parts[st]
        for pt_idx in range(st + 1, min(st + block_size, num_clusters)):
            pt = np.concatenate((pt, parts[pt_idx]), axis=0)
        features_batches.append(features[pt, :])
        y_train_batches.append(y_train[pt, :])
        support_now = adj[pt, :][:, pt]
        support_batches.append(sparse_to_tuple(normalize_adj_diag_enhance(support_now, diag_lambda=1)))
        total_nnz += support_now.count_nonzero()

    return features_batches, support_batches, y_train_batches
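
# Worked toy example for normalize_adj_diag_enhance (illustrative, 2-node graph with a
# single edge, diag_lambda = 1):
#   A = [[0, 1], [1, 0]]       ->  A + I = [[1, 1], [1, 1]],  D + I = diag(2, 2)
#   (D + I)^{-1}(A + I)        =   [[0.5, 0.5], [0.5, 0.5]]
#   + diag_lambda * diag(...)  ->  [[1.0, 0.5], [0.5, 1.0]]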

def normalize_adj_diag_enhance(adj, diag_lambda=1):
    """ A' = (D + I)^{-1}(A + I), then A' = A' + diag_lambda * diag(A') """
    adj = adj + sp.eye(adj.shape[0])
    rowsum = np.array(adj.sum(1)).flatten()
    d_inv = 1.0 / (rowsum + 1e-20)
    d_mat_inv = sp.diags(d_inv, 0)
    adj = d_mat_inv.dot(adj)
    adj = adj + diag_lambda * sp.diags(adj.diagonal(), 0)
    return adj


def sparse_to_tuple(sparse_mx):

    def to_tuple(mx):
        if not sp.isspmatrix_coo(mx):
            mx = mx.tocoo()
        coords = np.vstack((mx.row, mx.col)).transpose()
        values = mx.data
        shape = mx.shape
        return coords, values, shape

    if isinstance(sparse_mx, list):
        for i in range(len(sparse_mx)):
            sparse_mx[i] = to_tuple(sparse_mx[i])
    else:
        sparse_mx = to_tuple(sparse_mx)

    return sparse_mx


def sample_mask(idx, length):
    mask = np.zeros(length)
    mask[idx] = 1
    return np.array(mask, dtype=np.bool)
--------------------------------------------------------------------------------