├── README.md
├── main.py
├── models.py
└── utils.py

/README.md:
--------------------------------------------------------------------------------
# GraphU-Net

A PyTorch Geometric implementation of Graph U-Net.
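
## Usage

Dependencies: PyTorch, PyTorch Geometric (plus torch-sparse), scikit-learn, and tensorboardX. A typical run looks like the following (the data directory is a placeholder; adjust to your setup):

```bash
python main.py --dataset DD --datadir /path/to/data --gpu 0
```

Per-fold best accuracies are appended to `results/ACC_0904_<dataset>.txt`, and TensorBoard logs are written under a directory named after the dataset.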
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import random
import math
import os
import os.path as osp
import argparse

import torch
import numpy as np
import torch.nn.functional as F
import torch.optim.lr_scheduler as lr_scheduler
from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader

import utils
from models import Classifier

from tensorboardX import SummaryWriter

print('[INFO] Using torch', torch.__version__)


def set_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=1,
                        help='random seed')
    parser.add_argument('--lr_decay_steps', type=int, default=50)
    parser.add_argument('--batch_size', type=int, default=20,
                        help='batch size')
    parser.add_argument('--lr', type=float, default=0.0005,
                        help='learning rate')
    parser.add_argument('--weight_decay', type=float, default=0.0008,
                        help='weight decay')
    parser.add_argument('--momentum', type=int, default=0.9)
    parser.add_argument('--nn_hid', type=int, default=48,
                        help='hidden size of the Graph U-Net')
    parser.add_argument('--nn_out', type=int, default=97,
                        help='output size of the Graph U-Net')
    parser.add_argument('--hidden', type=int, default=128,
                        help='hidden size of the MLP classifier')
    parser.add_argument('--final_pool', type=float, default=0.6,
                        help='percentile used to pick k for sort pooling')

    parser.add_argument('--datadir', type=str, default='/home/zwq/data/')
    parser.add_argument('--dataset', type=str, default='DD',
                        help='DD/PROTEINS/NCI1/NCI109/Mutagenicity')
    parser.add_argument('--epochs', type=int, default=200,
                        help='maximum number of epochs')

    parser.add_argument('--gpu', type=int, default=1,
                        help='gpu device id')
    args = parser.parse_args()
    return args


args = set_parser()
writer = SummaryWriter(args.dataset)
##################################### Setting: Dataset ######################################
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)

path = osp.join(args.datadir, args.dataset)
dataset = TUDataset(path, name=args.dataset, use_node_attr=True, transform=utils.Indegree())

num_feat = dataset.num_features
num_classes = dataset.num_classes

# Pick k for sort pooling so that a `final_pool` fraction of the graphs have
# at least k nodes, with a floor of 10.
num_node_list = sorted([data.num_nodes for data in dataset])
s_k = num_node_list[int(math.ceil(args.final_pool * len(num_node_list))) - 1]
sk = max(10, s_k)

print("##################################################")
print("[INFO] Dataset name:", args.dataset)
print("[INFO] k used for sort pooling:", sk)
print("[INFO] Dataset contains", len(num_node_list), "graphs")
print("[INFO] Max nodes:", max(num_node_list))
print("[INFO] Min nodes:", min(num_node_list))
print("[INFO] num_feat | num_classes:", num_feat, num_classes)

torch.cuda.set_device(args.gpu)
device = torch.device('cuda:' + str(args.gpu))

############################### train & val ####################################
def train(model, train_loader, epoch, optimizer):
    model.train()
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.batch)
        loss = F.nll_loss(out, data.y)
        loss.backward()
        optimizer.step()


def evaluate(model, loader):
    model.eval()
    total_correct = 0.
    total_loss = 0.
    total = len(loader.dataset)
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            out = model(data.x, data.edge_index, data.batch)
            pred = out.max(dim=1)[1]
            total_correct += pred.eq(data.y).sum().item()
            # nll_loss averages over the batch by default, so scale back up
            # before dividing by the dataset size.
            loss = F.nll_loss(out, data.y)
            total_loss += loss.item() * len(data.y)
    loss, acc = total_loss / total, total_correct / total
    model.train()
    return loss, acc


def main():
    os.makedirs('results', exist_ok=True)
    for fold in range(10):
        print("=" * 50)
        print("[INFO] Fold:", fold)
        print("=" * 50)
        train_idx, val_idx = utils.sep_data(dataset, seed=args.seed, fold_idx=fold)
        train_idx = torch.LongTensor(train_idx)
        val_idx = torch.LongTensor(val_idx)
        train_dataset = dataset[train_idx]
        val_dataset = dataset[val_idx]
        print("[INFO] Train:", len(train_idx), "| Val:", len(val_idx))
        train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)

        model = Classifier(
            node_feat=num_feat,
            nn_hid=args.nn_hid,
            nn_out=args.nn_out,
            k=sk,
            hidden=args.hidden,
            num_class=num_classes
        ).to(device)

        print(model)

        # Per-module learning rates; the base lr applies to the feature extractor.
        optimizer = torch.optim.SGD([
            {'params': model.feature_extractor.parameters()},
            {'params': model.mlp.parameters(), 'lr': 0.001},
            {'params': model.readout.parameters(), 'lr': 0.05}
        ], lr=args.lr, momentum=args.momentum)
        scheduler = lr_scheduler.MultiStepLR(optimizer, [50, 100, 150], gamma=0.1)
        best = 0.
        ######################### training #####################
        for epoch in range(args.epochs):
            train(model, train_loader, epoch, optimizer)
            scheduler.step()
            train_loss, train_acc = evaluate(model, train_loader)
            val_loss, val_acc = evaluate(model, val_loader)
            writer.add_scalars('%s/loss' % args.dataset,
                               {'train_loss': train_loss, 'val_loss': val_loss}, epoch)
            writer.add_scalars('%s/acc' % args.dataset,
                               {'train_acc': train_acc, 'val_acc': val_acc}, epoch)
            print("Epoch:{:05d} | TrainLoss:{:.4f} | TrainAcc:{:.4f} |"
                  " ValLoss:{:.4f} ValAcc:{:.4f}".format(
                      epoch, train_loss, train_acc, val_loss, val_acc))
            best = max(best, val_acc)
        print("[INFO] Best Acc:", best)
        with open('results/ACC_0904_%s.txt' % args.dataset, 'a+') as f:
            f.write(str(best) + '\n')


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_sparse import spspmm
from torch_geometric.nn import global_sort_pool
from torch_geometric.nn import TopKPooling, GCNConv, GraphConv
from torch_geometric.utils import (add_self_loops, sort_edge_index,
                                   remove_self_loops)

import utils


class Conv1dReadout(nn.Module):
    """DGCNN-style readout: global sort pooling followed by two 1-D convolutions."""

    def __init__(
            self,
            input_dim,
            k,
            conv1d_channels=[16, 32],
            conv1d_kws=[0, 5]
    ):
        super(Conv1dReadout, self).__init__()
        self.k = k

        # Copy before mutating, so the shared default list is not modified.
        # The first kernel spans one node's full feature vector.
        conv1d_kws = list(conv1d_kws)
        conv1d_kws[0] = input_dim
        self.input_dim = input_dim

        self.conv1d_p1 = nn.Conv1d(
            1, conv1d_channels[0], conv1d_kws[0], conv1d_kws[0])
        self.pool = nn.MaxPool1d(2, 2)
        self.conv1d_p2 = nn.Conv1d(
            conv1d_channels[0], conv1d_channels[1], conv1d_kws[1], 1)

        dense_dim = int((k - 2) / 2 + 1)
        self.dense_dim = (dense_dim - conv1d_kws[1] + 1) * conv1d_channels[1]

    def forward(self, x, batch):
        batch_size = int(batch.max()) + 1
        # (batch_size, k * input_dim): k nodes per graph, sorted by their last
        # feature channel; graphs with fewer than k nodes are zero-padded.
        re = global_sort_pool(x, batch, self.k)
        re = re.unsqueeze(2).transpose(2, 1)
        conv1 = F.relu(self.conv1d_p1(re))
        conv1 = self.pool(conv1)
        conv2 = F.relu(self.conv1d_p2(conv1))
        to_dense = conv2.view(batch_size, -1)
        out = F.relu(to_dense)
        return out
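
# Shape walkthrough for Conv1dReadout (a sketch; B graphs, input_dim = d):
#   x: (num_nodes, d)  --global_sort_pool-->  (B, k * d)
#   unsqueeze/transpose                  -->  (B, 1, k * d)
#   conv1d_p1 (kernel = stride = d)      -->  (B, 16, k)
#   MaxPool1d(2, 2)                      -->  (B, 16, k // 2)
#   conv1d_p2 (kernel 5, stride 1)       -->  (B, 32, k // 2 - 4)
#   flatten                              -->  (B, 32 * (k // 2 - 4)) == (B, dense_dim)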


class GraphUNet(nn.Module):
    def __init__(self, in_channels, hid_channels, out_channels, ks,
                 sum_res=True, dropout=0.5):
        super(GraphUNet, self).__init__()
        self.in_channels = in_channels
        self.hid_channels = hid_channels
        self.out_channels = out_channels
        self.pool_ratios = ks
        self.depth = len(ks)
        self.act = F.relu
        self.sum_res = sum_res
        self.down_convs = torch.nn.ModuleList()
        self.pools = torch.nn.ModuleList()
        # Batch-norm layers are built here but currently disabled in forward().
        self.bn_layers = torch.nn.ModuleList()
        self.down_convs.append(GCNConv(in_channels, self.hid_channels))
        self.bn_layers.append(nn.BatchNorm1d(self.hid_channels))

        for i in range(len(ks)):
            self.pools.append(TopKPooling(self.hid_channels, self.pool_ratios[i]))
            self.down_convs.append(GCNConv(self.hid_channels, self.hid_channels))
            self.bn_layers.append(nn.BatchNorm1d(self.hid_channels))

        in_channels = self.hid_channels if sum_res else 2 * self.hid_channels
        self.up_convs = torch.nn.ModuleList()
        for i in range(len(ks)):
            self.up_convs.append(GCNConv(in_channels, self.hid_channels))
            self.bn_layers.append(nn.BatchNorm1d(self.hid_channels))
        # The last up-conv sees [x, org_X] concatenated, hence 2 * hid_channels.
        self.up_convs.append(GCNConv(2 * self.hid_channels, out_channels))
        self.drop = torch.nn.Dropout(p=dropout if dropout else 0.)
        self.reset_parameters()

    def reset_parameters(self):
        for conv in self.down_convs:
            conv.reset_parameters()
        for pool in self.pools:
            pool.reset_parameters()
        for conv in self.up_convs:
            conv.reset_parameters()

    def forward(self, x, edge_index, batch=None):
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))
        edge_weight = x.new_ones(edge_index.size(1))
        x = self.down_convs[0](x, edge_index, edge_weight)
        x = self.act(x)
        x = self.drop(x)
        org_X = x

        xs = [x]
        edge_indices = [edge_index]
        edge_weights = [edge_weight]
        perms = []

        # Encoder: pool, then convolve, at each depth.
        for i in range(1, self.depth + 1):
            x, edge_index, edge_weight, batch, perm, _ = \
                self.pools[i - 1](x, edge_index, edge_weight, batch)
            x = self.down_convs[i](x, edge_index, edge_weight)
            x = self.act(x)

            if i < self.depth:
                xs += [x]
                edge_indices += [edge_index]
                edge_weights += [edge_weight]
            # Every perm is needed in the decoder, so collect it unconditionally
            # (collecting it only when i < depth leaves perms one entry short
            # and makes perms[j] fail on the first decoder step).
            perms += [perm]

        # Decoder: unpool by scattering back to the saved positions, add skips.
        for i in range(self.depth):
            j = self.depth - 1 - i

            res = xs[j]
            edge_index = edge_indices[j]
            edge_weight = edge_weights[j]
            perm = perms[j]

            up = torch.zeros_like(res)
            up[perm] = x
            x = res + up if self.sum_res else torch.cat((res, up), dim=-1)

            x = self.up_convs[i](x, edge_index, edge_weight)
            x = self.act(x) if i < self.depth - 1 else x
            x = self.drop(x) if i < self.depth - 1 else x
        x = torch.cat([x, org_X], 1)
        x = self.up_convs[-1](x, edge_index, edge_weight)
        return x

    def augment_adj(self, edge_index, edge_weight, num_nodes):
        # A^2 with self-loops, as in the original Graph U-Net (currently unused).
        edge_index, edge_weight = add_self_loops(edge_index, edge_weight,
                                                 num_nodes=num_nodes)
        edge_index, edge_weight = sort_edge_index(edge_index, edge_weight,
                                                  num_nodes)
        edge_index, edge_weight = spspmm(edge_index, edge_weight, edge_index,
                                         edge_weight, num_nodes, num_nodes,
                                         num_nodes)
        edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
        return edge_index, edge_weight
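
# Unpooling in the decoder, sketched on a toy case: if a 4-node graph was
# pooled with perm = [2, 0] (keeping nodes 2 and 0), then for coarse features
# x = [[a], [b]] the scatter `up[perm] = x` yields
# up = [[b], [0], [a], [0]], i.e. pooled nodes return to their original slots
# and dropped nodes are zero-filled before the skip connection is applied.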


class MLPClassifier(nn.Module):
    def __init__(self, input_size, hidden_size, num_class, with_dropout=False):
        super(MLPClassifier, self).__init__()

        self.h1_weights = nn.Linear(input_size, hidden_size)
        self.h2_weights = nn.Linear(hidden_size, num_class)
        self.with_dropout = with_dropout
        utils.weights_init(self)

    def forward(self, x):
        h1 = self.h1_weights(x)
        h1 = F.relu(h1)
        if self.with_dropout:
            h1 = F.dropout(h1, training=self.training)

        logits = self.h2_weights(h1)
        logits = F.log_softmax(logits, dim=1)
        return logits


class Classifier(nn.Module):
    def __init__(self, node_feat, nn_hid, nn_out, k, hidden, num_class):
        super(Classifier, self).__init__()
        self.feature_extractor = GraphUNet(
            in_channels=node_feat,
            hid_channels=nn_hid,
            out_channels=nn_out,
            ks=[0.9, 0.7, 0.6, 0.5],
            sum_res=True,
            dropout=0.3
        )
        self.readout = Conv1dReadout(input_dim=nn_out, k=k)
        self.mlp = MLPClassifier(
            input_size=self.readout.dense_dim, hidden_size=hidden,
            num_class=num_class, with_dropout=True
        )

    def forward(self, x, edge_index, batch):
        gfeat = self.feature_extractor(x, edge_index, batch)
        vfeat = self.readout(gfeat, batch)
        return self.mlp(vfeat)


class GraphUNets(nn.Module):
    """Variant of GraphUNet that uses GraphConv in place of GCNConv."""

    def __init__(self, in_channels, hid_channels, out_channels, ks,
                 sum_res=True, dropout=0.5):
        super(GraphUNets, self).__init__()
        self.in_channels = in_channels
        self.hid_channels = hid_channels
        self.out_channels = out_channels
        self.pool_ratios = ks
        self.depth = len(ks)
        self.act = F.relu
        self.sum_res = sum_res
        self.down_convs = torch.nn.ModuleList()
        self.pools = torch.nn.ModuleList()
        self.bn_layers = torch.nn.ModuleList()
        self.down_convs.append(GraphConv(in_channels, self.hid_channels))
        self.bn_layers.append(nn.BatchNorm1d(self.hid_channels))

        for i in range(len(ks)):
            self.pools.append(TopKPooling(self.hid_channels, self.pool_ratios[i]))
            self.down_convs.append(GraphConv(self.hid_channels, self.hid_channels))
            self.bn_layers.append(nn.BatchNorm1d(self.hid_channels))

        in_channels = self.hid_channels if sum_res else 2 * self.hid_channels
        self.up_convs = torch.nn.ModuleList()
        for i in range(len(ks)):
            self.up_convs.append(GraphConv(in_channels, self.hid_channels))
            self.bn_layers.append(nn.BatchNorm1d(self.hid_channels))
        self.up_convs.append(GraphConv(2 * self.hid_channels, out_channels))
        self.drop = torch.nn.Dropout(p=dropout if dropout else 0.)
        self.reset_parameters()

    def reset_parameters(self):
        for conv in self.down_convs:
            conv.reset_parameters()
        for pool in self.pools:
            pool.reset_parameters()
        for conv in self.up_convs:
            conv.reset_parameters()

    def forward(self, x, edge_index, batch=None):
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))
        edge_weight = x.new_ones(edge_index.size(1))
        x = self.down_convs[0](x, edge_index, edge_weight)
        x = self.act(x)
        x = self.drop(x)
        org_X = x

        xs = [x]
        edge_indices = [edge_index]
        edge_weights = [edge_weight]
        perms = []

        # Encoder: pool, then convolve, at each depth.
        for i in range(1, self.depth + 1):
            x, edge_index, edge_weight, batch, perm, _ = \
                self.pools[i - 1](x, edge_index, edge_weight, batch)
            x = self.down_convs[i](x, edge_index, edge_weight)
            x = self.act(x)

            if i < self.depth:
                xs += [x]
                edge_indices += [edge_index]
                edge_weights += [edge_weight]
            # As in GraphUNet above, every perm is needed in the decoder.
            perms += [perm]

        # Decoder: unpool by scattering back to the saved positions, add skips.
        for i in range(self.depth):
            j = self.depth - 1 - i

            res = xs[j]
            edge_index = edge_indices[j]
            edge_weight = edge_weights[j]
            perm = perms[j]

            up = torch.zeros_like(res)
            up[perm] = x
            x = res + up if self.sum_res else torch.cat((res, up), dim=-1)

            x = self.up_convs[i](x, edge_index, edge_weight)
            x = self.act(x) if i < self.depth - 1 else x
            x = self.drop(x) if i < self.depth - 1 else x
        x = torch.cat([x, org_X], 1)
        x = self.up_convs[-1](x, edge_index, edge_weight)
        return x

    def augment_adj(self, edge_index, edge_weight, num_nodes):
        edge_index, edge_weight = add_self_loops(edge_index, edge_weight,
                                                 num_nodes=num_nodes)
        edge_index, edge_weight = sort_edge_index(edge_index, edge_weight,
                                                  num_nodes)
        edge_index, edge_weight = spspmm(edge_index, edge_weight, edge_index,
                                         edge_weight, num_nodes, num_nodes,
                                         num_nodes)
        edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
        return edge_index, edge_weight
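

if __name__ == "__main__":
    # Minimal smoke test, not part of the original training pipeline: run one
    # random graph through the Classifier and check the output shape. All
    # sizes below are arbitrary illustrative values.
    torch.manual_seed(0)
    num_nodes, num_feat, num_class = 30, 8, 2
    x = torch.randn(num_nodes, num_feat)
    edge_index = torch.randint(0, num_nodes, (2, 60))
    batch = torch.zeros(num_nodes, dtype=torch.long)

    model = Classifier(node_feat=num_feat, nn_hid=16, nn_out=32,
                       k=10, hidden=64, num_class=num_class)
    out = model(x, edge_index, batch)
    print(out.shape)  # expected: torch.Size([1, 2])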
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import numpy as np
from sklearn.model_selection import StratifiedKFold
import torch.nn as nn
from torch.nn.parameter import Parameter
from torch_geometric.utils import degree


def sep_data(dataset, seed=0, fold_idx=0):
    """Return (train_idx, val_idx) for one fold of a stratified 10-fold split."""
    skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=seed)
    labels = [data.y.item() for data in dataset]
    idx_list = []
    for idx in skf.split(np.zeros(len(labels)), labels):
        idx_list.append(idx)
    return idx_list[fold_idx]


def glorot_uniform(t):
    if len(t.size()) == 2:
        fan_in, fan_out = t.size()
    elif len(t.size()) == 3:
        # (out_channels, in_channels, kernel_size) for Conv1d
        fan_in = t.size()[1] * t.size()[2]
        fan_out = t.size()[0] * t.size()[2]
    else:
        fan_in = np.prod(t.size())
        fan_out = np.prod(t.size())

    limit = np.sqrt(6.0 / (fan_in + fan_out))
    t.uniform_(-limit, limit)


def _param_init(m):
    if isinstance(m, Parameter):
        glorot_uniform(m.data)
    elif isinstance(m, nn.Linear):
        m.bias.data.zero_()
        glorot_uniform(m.weight.data)


def weights_init(m):
    for p in m.modules():
        if isinstance(p, nn.ParameterList):
            for pp in p:
                _param_init(pp)
        else:
            _param_init(p)

    for name, p in m.named_parameters():
        if '.' not in name:  # top-level parameters
            _param_init(p)


class Indegree(object):
    r"""Uses the (optionally normalized) in-degree as the node feature when a
    graph has no node features; graphs that already have features are left
    unchanged.

    Args:
        norm (bool, optional): If set to :obj:`True`, the degree is normalized
            by its maximum (or by :obj:`max_value`). (default: :obj:`False`)
        max_value (float, optional): Normalization constant used when
            :obj:`norm` is :obj:`True`. (default: :obj:`None`)
    """

    def __init__(self, norm=False, max_value=None):
        self.norm = norm
        self.max = max_value

    def __call__(self, data):
        col, x = data.edge_index[1], data.x
        deg = degree(col, data.num_nodes)
        if self.norm:
            deg = deg / (deg.max() if self.max is None else self.max)
        deg = deg.view(-1, 1)
        if x is None:
            data.x = deg
        return data

    def __repr__(self):
        return '{}(norm={}, max_value={})'.format(
            self.__class__.__name__, self.norm, self.max)
--------------------------------------------------------------------------------