├── Fig ├── fig3.png └── fig5.png ├── .gitattributes ├── LICENSE ├── graph_conv.py ├── README.md └── mixup.py /Fig/fig3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vanoracai/MixupForGraph/HEAD/Fig/fig3.png -------------------------------------------------------------------------------- /Fig/fig5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vanoracai/MixupForGraph/HEAD/Fig/fig5.png -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 vanoracai 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /graph_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Parameter 3 | from torch_geometric.nn.conv import MessagePassing 4 | 5 | from torch_geometric.nn.inits import uniform 6 | import pdb 7 | 8 | 9 | class GraphConv(MessagePassing): 10 | def __init__(self, in_channels, out_channels, aggr='mean', bias=True, 11 | **kwargs): 12 | super(GraphConv, self).__init__(aggr=aggr, **kwargs) 13 | 14 | self.in_channels = in_channels 15 | self.out_channels = out_channels 16 | 17 | self.weight = Parameter(torch.Tensor(in_channels, out_channels)) 18 | self.lin = torch.nn.Linear(in_channels, out_channels, bias=bias) 19 | 20 | self.reset_parameters() 21 | 22 | def reset_parameters(self): 23 | uniform(self.in_channels, self.weight) 24 | self.lin.reset_parameters() 25 | 26 | def forward(self, x, edge_index, x_cen): 27 | h = torch.matmul(x, self.weight) 28 | aggr_out = self.propagate(edge_index, size=None, h=h, edge_weight=None) 29 | return aggr_out + self.lin(x_cen) 30 | 31 | def message(self, h_j): 32 | return h_j 33 | 34 | def __repr__(self): 35 | return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, 36 | self.out_channels) 37 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Mixup for Node and Graph Classification [[draft](https://wangywust.github.io/Paper/2021mix.pdf)] 2 | 3 | Mixup for Node and Graph Classification| 4 | :-------------------------:| 5 | ![](Fig/fig5.png) | 6 | 7 | 8 | 9 | 10 | ## Introduction 11 | 12 | This is an implementation for the paper titled 'Mixup for Node and Graph Classification' based on torch_geometric. Mixup is an advanced data augmentation method for training neural network based image classifiers, which interpolates both features and labels of a pair of images to produce synthetic samples. However, devising the Mixup methods for graph learning is challenging due to the irregularity and connectivity of graph data. In this work, we propose the Mixup methods for two fundamental tasks in graph learning: node and graph classification. To interpolate the irregular graph topology, we propose the two-branch graph convolution to mix the receptive field subgraphs for the paired nodes. Mixup on different node pairs can interfere with the mixed features for each other due to the connectivity between nodes. To block this interference, we propose the two-stage Mixup framework, which uses each node's neighbors' representations before Mixup for graph convolutions. For graph classification, we interpolate complex and diverse graphs in the semantic space. Qualitatively, our Mixup methods enable GNNs to learn more discriminative features and reduce over-fitting. Quantitative results show that our method yields consistent gains in terms of test accuracy and F1-micro scores on standard datasets, for both node and graph classification. Overall, our method effectively regularizes popular graph neural networks for better generalization without increasing their time complexity. 13 | 14 | #### Paper link: [Mixup for Node and Graph Classification](https://wangywust.github.io/Paper/2021mix.pdf) 15 | 16 | Two-stage Mixup| 17 | :-------------------------:| 18 | ![](Fig/fig3.png) | 19 | 20 | ## Running the experiments 21 | 22 | ### Requirements 23 | 24 | Dependencies (with python == 3.8.2): 25 | 26 | ```{bash} 27 | torch==1.8.1+cu102 28 | torch_geometric==1.7.0 29 | numpy==1.18.1 30 | ``` 31 | 32 | ### Data Preparation 33 | Just create a folder named 'data' on the same level with this repo's folder. Then the left things will be done by the code automatically. 34 | 35 | ### Model Training 36 | ```{bash} 37 | # node classification with mixup 38 | python mixup.py --mixup 39 | 40 | # node classification without mixup 41 | python mixup.py 42 | ``` 43 | 44 | ## License 45 | MIT -------------------------------------------------------------------------------- /mixup.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import os 3 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 4 | 5 | import argparse 6 | import torch 7 | import torch.nn.functional as F 8 | from torch_geometric.datasets import Planetoid, Coauthor 9 | from torch_geometric.data import Data 10 | from graph_conv import GraphConv 11 | from torch_geometric.utils import degree 12 | from torch_sparse import SparseTensor 13 | import torch_geometric.transforms as T 14 | 15 | import pdb 16 | import numpy as np 17 | import random 18 | import copy 19 | import argparse 20 | 21 | parser = argparse.ArgumentParser('Mixup') 22 | parser.add_argument('--mixup', action='store_true', help='Whether to have Mixup') 23 | args = parser.parse_args() 24 | 25 | def idNode(data, id_new_value_old): 26 | data = copy.deepcopy(data) 27 | data.x = None 28 | data.y[data.val_id] = -1 29 | data.y[data.test_id] = -1 30 | data.y = data.y[id_new_value_old] 31 | 32 | data.train_id = None 33 | data.test_id = None 34 | data.val_id = None 35 | 36 | id_old_value_new = torch.zeros(id_new_value_old.shape[0], dtype = torch.long) 37 | id_old_value_new[id_new_value_old] = torch.arange(0, id_new_value_old.shape[0], dtype = torch.long) 38 | row = data.edge_index[0] 39 | col = data.edge_index[1] 40 | row = id_old_value_new[row] 41 | col = id_old_value_new[col] 42 | data.edge_index = torch.stack([row, col], dim=0) 43 | 44 | return data 45 | 46 | def shuffleData(data): 47 | data = copy.deepcopy(data) 48 | id_new_value_old = np.arange(data.num_nodes) 49 | train_id_shuffle = copy.deepcopy(data.train_id) 50 | np.random.shuffle(train_id_shuffle) 51 | id_new_value_old[data.train_id] = train_id_shuffle 52 | data = idNode(data, id_new_value_old) 53 | 54 | return data, id_new_value_old 55 | 56 | 57 | class Net(torch.nn.Module): 58 | def __init__(self, hidden_channels, in_channel, out_channel): 59 | super(Net, self).__init__() 60 | self.conv1 = GraphConv(in_channel, hidden_channels) 61 | self.conv2 = GraphConv(hidden_channels, hidden_channels) 62 | self.conv3 = GraphConv(hidden_channels, hidden_channels) 63 | self.lin = torch.nn.Linear(1 * hidden_channels, out_channel) 64 | 65 | def forward(self, x0, edge_index, edge_index_b, lam, id_new_value_old): 66 | 67 | x1 = self.conv1(x0, edge_index, x0) 68 | x1 = F.relu(x1) 69 | x1 = F.dropout(x1, p=0.4, training=self.training) 70 | 71 | x2 = self.conv2(x1, edge_index, x1) 72 | x2 = F.relu(x2) 73 | x2 = F.dropout(x2, p=0.4, training=self.training) 74 | 75 | x0_b = x0[id_new_value_old] 76 | x1_b = x1[id_new_value_old] 77 | x2_b = x2[id_new_value_old] 78 | 79 | x0_mix = x0 * lam + x0_b * (1 - lam) 80 | 81 | new_x1 = self.conv1(x0, edge_index, x0_mix) 82 | new_x1_b = self.conv1(x0_b, edge_index_b, x0_mix) 83 | new_x1 = F.relu(new_x1) 84 | new_x1_b = F.relu(new_x1_b) 85 | 86 | x1_mix = new_x1 * lam + new_x1_b * (1 - lam) 87 | x1_mix = F.dropout(x1_mix, p=0.4, training=self.training) 88 | 89 | new_x2 = self.conv2(x1, edge_index, x1_mix) 90 | new_x2_b = self.conv2(x1_b, edge_index_b, x1_mix) 91 | new_x2 = F.relu(new_x2) 92 | new_x2_b = F.relu(new_x2_b) 93 | 94 | x2_mix = new_x2 * lam + new_x2_b * (1 - lam) 95 | x2_mix = F.dropout(x2_mix, p=0.4, training=self.training) 96 | 97 | new_x3 = self.conv3(x2, edge_index, x2_mix) 98 | new_x3_b = self.conv3(x2_b, edge_index_b, x2_mix) 99 | new_x3 = F.relu(new_x3) 100 | new_x3_b = F.relu(new_x3_b) 101 | 102 | x3_mix = new_x3 * lam + new_x3_b * (1 - lam) 103 | x3_mix = F.dropout(x3_mix, p=0.4, training=self.training) 104 | 105 | x = x3_mix 106 | x = self.lin(x) 107 | return x.log_softmax(dim=-1) 108 | 109 | 110 | # set random seed 111 | SEED = 0 112 | torch.manual_seed(SEED) 113 | if torch.cuda.is_available(): 114 | torch.cuda.manual_seed(SEED) 115 | np.random.seed(SEED) # Numpy module. 116 | random.seed(SEED) # Python random module. 117 | 118 | 119 | # load data 120 | dataset = 'Pubmed' 121 | path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset) 122 | dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures()) 123 | data = dataset[0] 124 | 125 | 126 | # split data 127 | node_id = np.arange(data.num_nodes) 128 | np.random.shuffle(node_id) 129 | data.train_id = node_id[:int(data.num_nodes * 0.6)] 130 | data.val_id = node_id[int(data.num_nodes * 0.6):int(data.num_nodes * 0.8)] 131 | data.test_id = node_id[int(data.num_nodes * 0.8):] 132 | 133 | 134 | # define model 135 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 136 | model = Net(hidden_channels=256, in_channel = dataset.num_node_features, out_channel = dataset.num_classes).to(device) 137 | optimizer = torch.optim.Adam(model.parameters(), lr=0.001) 138 | 139 | 140 | # func train one epoch 141 | def train(data): 142 | model.train() 143 | 144 | if args.mixup: 145 | lam = np.random.beta(4.0, 4.0) 146 | else: 147 | lam = 1.0 148 | 149 | data_b, id_new_value_old = shuffleData(data) 150 | data = data.to(device) 151 | data_b = data_b.to(device) 152 | 153 | optimizer.zero_grad() 154 | 155 | out = model(data.x, data.edge_index, data_b.edge_index, lam, id_new_value_old) 156 | loss = F.nll_loss(out[data.train_id], data.y[data.train_id]) * lam + \ 157 | F.nll_loss(out[data.train_id], data_b.y[data.train_id]) * (1 - lam) 158 | 159 | loss.backward() 160 | optimizer.step() 161 | 162 | return loss.item() 163 | 164 | 165 | # test 166 | @torch.no_grad() 167 | def test(data): 168 | model.eval() 169 | 170 | out = model(data.x.to(device), data.edge_index.to(device), data.edge_index.to(device), 1, np.arange(data.num_nodes)) 171 | pred = out.argmax(dim=-1) 172 | correct = pred.eq(data.y.to(device)) 173 | 174 | accs = [] 175 | for _, id_ in data('train_id', 'val_id', 'test_id'): 176 | accs.append(correct[id_].sum().item() / id_.shape[0]) 177 | return accs 178 | 179 | 180 | best_acc = 0 181 | accord_epoch = 0 182 | accord_train_acc = 0 183 | accord_train_loss = 0 184 | for epoch in range(1, 300): 185 | loss = train(data) 186 | accs = test(data) 187 | print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Train Acc: {accs[0]:.4f}, Test Acc: {accs[2]:.4f}') --------------------------------------------------------------------------------