├── README.md
├── data_split.py
├── few_shot_data.zip
├── loss.py
├── model.py
├── subgcon.py
├── subgraph.py
├── train.py
└── utils_mp.py
/README.md:
--------------------------------------------------------------------------------
1 | ## COSMIC
2 | Here we provide the code for COSMIC: "_Contrastive Meta-Learning for Few-shot Node Classification_". Our work was accepted at SIGKDD 2023.
3 | 
4 | 
5 | 
6 | ## Instructions
7 | 
8 | First, extract the datasets:
9 | 
10 | ```
11 | unzip few_shot_data.zip
12 | ```
13 | 
14 | 
15 | Then run the training script:
16 | 
17 | ```
18 | python train.py --dataset CoraFull
19 | ```
20 | 
21 | ## Citation
22 | If you find our work helpful, please cite it:
23 | 24 | > @inproceedings{wang2023contrastive, 25 | title={Contrastive Meta-Learning for Few-shot Node Classification}, 26 | author={Wang, Song and Tan, Zhen and Liu, Huan and Li, Jundong}, 27 | booktitle={SIGKDD}, 28 | year={2023} 29 | } 30 | 31 | -------------------------------------------------------------------------------- /data_split.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ogb.nodeproppred import PygNodePropPredDataset 3 | from torch_geometric.datasets import CoraFull, Reddit2, Coauthor, Planetoid, Amazon, DBLP 4 | import random 5 | import numpy as np 6 | import scipy.io as sio 7 | from sklearn import preprocessing 8 | # class_split = {"train": 0.6,"test": 0.4} 9 | 10 | class_split = { 11 | "CoraFull": {"train": 40, 'dev': 15, 'test': 15}, # Sufficient number of base classes 12 | "ogbn-arxiv": {"train": 20, 'dev': 10, 'test': 10}, 13 | "Coauthor-CS": {"train": 5, 'dev': 5, 'test': 5}, 14 | "Amazon-Computer": {"train": 4, 'dev': 3, 'test': 3}, 15 | "Cora": {"train": 3, 'dev': 2, 'test': 2}, 16 | "CiteSeer": {"train": 2, 'dev': 2, 'test': 2}, 17 | "Reddit": {"train": 21, 'dev': 10, 'test': 10}, 18 | 'dblp':{"train": 77, 'dev': 30, 'test': 30}, 19 | } 20 | 21 | 22 | class dblp_data(): 23 | def __init__(self): 24 | self.x=None 25 | self.edge_index=None 26 | self.num_nodes=None 27 | self.y=None 28 | self.num_edges=None 29 | self.num_features=7202 30 | class dblp_dataset(): 31 | def __init__(self,data,num_classes): 32 | self.data=data 33 | self.num_classes=num_classes 34 | 35 | def load_DBLP(root=None, dataset_source='dblp'): 36 | dataset=dblp_data() 37 | n1s = [] 38 | n2s = [] 39 | for line in open("./few_shot_data/{}_network".format(dataset_source)): 40 | n1, n2 = line.strip().split('\t') 41 | if int(n1)>int(n2): 42 | n1s.append(int(n1)) 43 | n2s.append(int(n2)) 44 | 45 | 46 | 47 | num_nodes = max(max(n1s), max(n2s)) + 1 48 | print('nodes num', num_nodes) 49 | 50 | data_train = sio.loadmat("./few_shot_data/{}_train.mat".format(dataset_source)) 51 | data_test = sio.loadmat("./few_shot_data/{}_test.mat".format(dataset_source)) 52 | 53 | labels = np.zeros((num_nodes, 1)) 54 | labels[data_train['Index']] = data_train["Label"] 55 | labels[data_test['Index']] = data_test["Label"] 56 | 57 | features = np.zeros((num_nodes, data_train["Attributes"].shape[1])) 58 | features[data_train['Index']] = data_train["Attributes"].toarray() 59 | features[data_test['Index']] = data_test["Attributes"].toarray() 60 | 61 | 62 | lb = preprocessing.LabelBinarizer() 63 | labels = lb.fit_transform(labels) 64 | 65 | features = torch.FloatTensor(features) 66 | labels = torch.LongTensor(np.where(labels)[1]) 67 | 68 | dataset.edge_index=torch.tensor([n1s,n2s]) 69 | dataset.y=labels 70 | dataset.x=features 71 | dataset.num_nodes=num_nodes 72 | dataset.num_edges=dataset.edge_index.shape[1] 73 | 74 | return dblp_dataset(dataset, num_classes=80+27+30) 75 | 76 | def split(dataset_name): 77 | if dataset_name == 'Cora': 78 | dataset = Planetoid(root='~/dataset/' + dataset_name, name="Cora") 79 | num_nodes = dataset.data.num_nodes 80 | elif dataset_name == 'CiteSeer': 81 | dataset = Planetoid(root='~/dataset/' + dataset_name, name="CiteSeer") 82 | num_nodes = dataset.data.num_nodes 83 | elif dataset_name == 'Amazon-Computer': 84 | dataset = Amazon(root='~/dataset/' + dataset_name, name="Computers") 85 | num_nodes = dataset.data.num_nodes 86 | elif dataset_name == 'Coauthor-CS': 87 | dataset = Coauthor(root='~/dataset/' + dataset_name, name="CS") 88 | 
num_nodes = dataset.data.num_nodes 89 | elif dataset_name == 'CoraFull': 90 | dataset = CoraFull(root='./dataset/' + dataset_name) 91 | num_nodes = dataset.data.num_nodes 92 | elif dataset_name == 'Reddit': 93 | dataset = Reddit2(root='./dataset/' + dataset_name) 94 | num_nodes = dataset.data.num_nodes 95 | elif dataset_name == 'ogbn-arxiv': 96 | dataset = PygNodePropPredDataset(name = dataset_name, root='./dataset/' + dataset_name) 97 | num_nodes = dataset.data.num_nodes 98 | elif dataset_name == 'dblp': 99 | dataset = load_DBLP(root='./few_shot_data/') 100 | num_nodes=dataset.data.num_nodes 101 | else: 102 | print("Dataset not support!") 103 | exit(0) 104 | data = dataset.data 105 | class_list = [i for i in range(dataset.num_classes)] 106 | print("********" * 10) 107 | 108 | 109 | 110 | 111 | 112 | train_num = class_split[dataset_name]["train"] 113 | dev_num = class_split[dataset_name]["dev"] 114 | test_num = class_split[dataset_name]["test"] 115 | 116 | random.shuffle(class_list) 117 | train_class = class_list[: train_num] 118 | dev_class = class_list[train_num : train_num + dev_num] 119 | test_class = class_list[train_num + dev_num :] 120 | print("train_num: {}; dev_num: {}; test_num: {}".format(train_num, dev_num, test_num)) 121 | 122 | id_by_class = {} 123 | for i in class_list: 124 | id_by_class[i] = [] 125 | for id, cla in enumerate(torch.squeeze(data.y).tolist()): 126 | id_by_class[cla].append(id) 127 | 128 | 129 | train_idx = [] 130 | for cla in train_class: 131 | train_idx.extend(id_by_class[cla]) 132 | 133 | degree_inv = num_nodes / (dataset.data.num_edges * 2) 134 | 135 | return data, np.array(train_idx), id_by_class, train_class, dev_class, test_class, degree_inv 136 | 137 | 138 | def test_task_generator(id_by_class, class_list, n_way, k_shot, m_query): 139 | 140 | # sample class indices 141 | class_selected = random.sample(class_list, n_way) 142 | id_support = [] 143 | id_query = [] 144 | for cla in class_selected: 145 | temp = random.sample(id_by_class[cla], k_shot + m_query) 146 | id_support.extend(temp[:k_shot]) 147 | id_query.extend(temp[k_shot:]) 148 | 149 | return np.array(id_support), np.array(id_query), class_selected 150 | 151 | 152 | 153 | -------------------------------------------------------------------------------- /few_shot_data.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SongW-SW/COSMIC/de917d82e0cf25603e52060e41836a3038bd739c/few_shot_data.zip -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from math import sqrt 5 | 6 | 7 | class SupConLoss(nn.Module): 8 | def __init__(self, temperature=0.07, contrast_mode='all', 9 | base_temperature=0.07): 10 | super(SupConLoss, self).__init__() 11 | self.temperature = temperature 12 | self.contrast_mode = contrast_mode 13 | self.base_temperature = base_temperature 14 | 15 | def forward(self, features, labels=None, mask=None): 16 | device = (torch.device('cuda') 17 | if features.is_cuda 18 | else torch.device('cpu')) 19 | 20 | if len(features.shape) < 3: 21 | raise ValueError('`features` needs to be [bsz, n_views, ...],' 22 | 'at least 3 dimensions are required') 23 | if len(features.shape) > 3: 24 | features = features.view(features.shape[0], features.shape[1], -1) 25 | 26 | batch_size = features.shape[0] 27 | 28 | 29 | 30 | 31 | if labels is 
not None and mask is not None: 32 | raise ValueError('Cannot define both `labels` and `mask`') 33 | elif labels is None and mask is None: 34 | mask = torch.eye(batch_size, dtype=torch.float32).to(device) 35 | elif labels is not None: 36 | #----------------> 37 | labels = labels.contiguous().view(-1, 1) 38 | 39 | if labels.shape[0] != batch_size: 40 | raise ValueError('Num of labels does not match num of features') 41 | mask = torch.eq(labels, labels.T).float().to(device) 42 | 43 | else: 44 | mask = mask.float().to(device) 45 | 46 | contrast_count = features.shape[1] 47 | contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0) 48 | if self.contrast_mode == 'one': 49 | anchor_feature = features[:, 0] 50 | anchor_count = 1 51 | elif self.contrast_mode == 'all': 52 | anchor_feature = contrast_feature 53 | anchor_count = contrast_count 54 | else: 55 | raise ValueError('Unknown mode: {}'.format(self.contrast_mode)) 56 | 57 | # compute logits 58 | anchor_dot_contrast = torch.div( 59 | torch.matmul(anchor_feature, contrast_feature.T), 60 | self.temperature) 61 | # for numerical stability 62 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) 63 | logits = anchor_dot_contrast - logits_max.detach() 64 | 65 | 66 | # tile mask 67 | # count=2 68 | mask = mask.repeat(anchor_count, contrast_count) 69 | # mask-out self-contrast cases 70 | 71 | 72 | 73 | #set the daigonal elements as 0 74 | logits_mask = torch.scatter( 75 | torch.ones_like(mask), 76 | 1, 77 | torch.arange(batch_size * anchor_count).view(-1, 1).to(device), 78 | 0 79 | ) 80 | mask = mask * logits_mask 81 | 82 | 83 | #print(mask[25:,25:]) 84 | 85 | #print(contrast_feature.shape) 86 | 87 | # compute log_prob 88 | exp_logits = torch.exp(logits) * logits_mask 89 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) 90 | 91 | # compute mean of log-likelihood over positive 92 | mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1) 93 | 94 | # loss 95 | loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos 96 | loss = torch.mean(loss.view(anchor_count, batch_size)) 97 | 98 | return loss -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import Linear 5 | from torch_geometric.nn import GCNConv, GATConv, GINConv, SAGEConv, SGConv, global_mean_pool, global_max_pool, global_add_pool, SAGPooling 6 | 7 | class Encoder(nn.Module): 8 | def __init__(self, in_channels, hidden_channels, encoder_type='GCN'): 9 | super(Encoder, self).__init__() 10 | self.hidden_channels = hidden_channels 11 | if encoder_type=='GCN': 12 | self.conv1 = GCNConv(in_channels, self.hidden_channels) 13 | elif encoder_type=='GAT': 14 | self.conv1 = GATConv(in_channels, self.hidden_channels) 15 | elif encoder_type=='GraphSAGE': 16 | self.conv1 = SAGEConv(in_channels, self.hidden_channels) 17 | elif encoder_type=='SGC': 18 | self.conv1 = SGConv(in_channels, self.hidden_channels) 19 | elif encoder_type=='GIN': 20 | self.mlp = nn.Linear(in_channels, self.hidden_channels) 21 | self.conv1 = GINConv(self.mlp) 22 | 23 | self.prelu1 = nn.PReLU(self.hidden_channels) 24 | 25 | def forward(self, x, edge_index, edge_attr=None): 26 | 27 | if edge_attr!=None: 28 | x1 = self.conv1(x, edge_index, edge_attr) 29 | else: 30 | x1 = self.conv1(x, edge_index) 31 | x1 = self.prelu1(x1) 32 | x1 = F.normalize(x1) 33 | return x1 34 | 
35 | 36 | 37 | class Pool(nn.Module): 38 | def __init__(self, in_channels, ratio=1.0): 39 | super(Pool, self).__init__() 40 | self.sag_pool = SAGPooling(in_channels, ratio) 41 | self.lin1 = torch.nn.Linear(in_channels * 2, in_channels) 42 | 43 | def forward(self, x, edge, batch, type='mean_pool'): 44 | if type == 'mean_pool': 45 | return global_mean_pool(x, batch) 46 | elif type == 'max_pool': 47 | return global_max_pool(x, batch) 48 | elif type == 'sum_pool': 49 | return global_add_pool(x, batch) 50 | elif type == 'sag_pool': 51 | x1, _, _, batch, _, _ = self.sag_pool(x, edge, batch=batch) 52 | return global_mean_pool(x1, batch) 53 | 54 | 55 | 56 | 57 | class Scorer(nn.Module): 58 | def __init__(self, hidden_size): 59 | super(Scorer, self).__init__() 60 | self.weight = nn.Parameter(torch.Tensor(hidden_size, hidden_size)) 61 | 62 | def forward(self, input1, input2): 63 | output = torch.sigmoid(torch.sum(input1 * torch.matmul(input2, self.weight), dim = -1)) 64 | return output 65 | 66 | 67 | -------------------------------------------------------------------------------- /subgcon.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import Parameter 5 | import torch.nn.functional as F 6 | from torch_geometric.nn.inits import reset 7 | from sklearn.linear_model import LogisticRegression 8 | from sklearn.cluster import KMeans 9 | from loss import SupConLoss 10 | from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score 11 | 12 | 13 | EPS = 1e-15 14 | 15 | class SugbCon(torch.nn.Module): 16 | 17 | def __init__(self, hidden_channels, encoder, pool, scorer, beta, degree_inv): 18 | super(SugbCon, self).__init__() 19 | self.SupConLoss = SupConLoss() 20 | self.encoder = encoder 21 | self.hidden_channels = hidden_channels 22 | self.pool = pool 23 | self.scorer = scorer 24 | self.marginloss = nn.MarginRankingLoss(0.5) 25 | self.sigmoid = nn.Sigmoid() 26 | # self.prompt = nn.parameter.Parameter(torch.rand(4)) 27 | self.reset_parameters() 28 | 29 | def reset_parameters(self): 30 | reset(self.scorer) 31 | reset(self.encoder) 32 | reset(self.pool) 33 | 34 | def forward(self, x, edge_index, batch=None, index=None, edge_attr=None): 35 | r""" Return node and subgraph representations of each node before and after being shuffled """ 36 | hidden = self.encoder(x, edge_index, edge_attr) 37 | if index is None: 38 | return hidden 39 | 40 | z = hidden[index] 41 | summary = self.pool(hidden, edge_index, batch) 42 | return z, summary 43 | 44 | 45 | def loss(self, hidden1, summary1, labels=None): 46 | features = torch.cat([hidden1.unsqueeze(1), summary1.unsqueeze(1)], dim=1) 47 | 48 | loss = self.SupConLoss(features, labels) 49 | return loss 50 | 51 | 52 | def test(self, train_z, train_y, test_z, test_y, solver='lbfgs', 53 | multi_class='auto', *args, **kwargs): 54 | r"""Evaluates latent space quality via a logistic regression downstream task.""" 55 | clf = LogisticRegression(solver=solver, max_iter=500, multi_class=multi_class, *args, 56 | **kwargs).fit(train_z.detach().cpu().numpy(), 57 | train_y) 58 | test_acc = clf.score(test_z.detach().cpu().numpy(), test_y) 59 | return test_acc 60 | 61 | def clustering_test(self, test_z, test_y, n_way, rs=0): 62 | pred_y = KMeans(n_clusters=n_way, random_state=rs).fit(test_z.detach().cpu().numpy()).labels_ 63 | nmi = normalized_mutual_info_score(test_y, pred_y) 64 | ari = adjusted_rand_score(test_y, pred_y) 65 | 66 | return nmi, ari 67 | 
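(Editor's note, not part of the repo.) To make the training interface of SugbCon concrete, here is a minimal sketch of how train.py wires Encoder, Pool, Scorer, and SugbCon together and what SugbCon.loss consumes. The 128-dim input features, the batch of 10 nodes, and the random z/summary tensors below are illustrative stand-ins for the outputs of model(batch.x, batch.edge_index, batch.batch, index) during training.

```
import torch
from model import Encoder, Pool, Scorer
from subgcon import SugbCon

hidden = 1024   # args.hidden_size in train.py
in_dim = 128    # placeholder feature dimension (dataset-dependent)

model = SugbCon(
    hidden_channels=hidden,
    encoder=Encoder(in_dim, hidden, encoder_type='GCN'),
    pool=Pool(in_channels=hidden),
    scorer=Scorer(hidden),
    beta=1.0,         # args.beta
    degree_inv=1.0,   # placeholder; train.py passes the value computed in data_split.split
)

# Stand-ins for the central-node embeddings and pooled subgraph summaries that
# model(batch.x, batch.edge_index, batch.batch, index) returns during training.
z = torch.randn(10, hidden)
summary = torch.randn(10, hidden)
labels = torch.randint(0, 5, (10,))   # one (re-indexed) class label per sampled node

# loss() stacks the two views into [batch, 2, hidden] and applies SupConLoss, so
# nodes of the same class (and each node with its own subgraph summary) act as positives.
loss = model.loss(z, summary, labels)
print(float(loss))
```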
-------------------------------------------------------------------------------- /subgraph.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import Parameter 5 | import torch.nn.functional as F 6 | from torch_geometric.nn.inits import reset 7 | from sklearn.linear_model import LogisticRegression 8 | 9 | 10 | EPS = 1e-15 11 | 12 | class SugbCon(torch.nn.Module): 13 | 14 | def __init__(self, hidden_channels, encoder, pool, scorer): 15 | super(SugbCon, self).__init__() 16 | self.encoder = encoder 17 | self.hidden_channels = hidden_channels 18 | self.pool = pool 19 | self.scorer = scorer 20 | self.marginloss = nn.MarginRankingLoss(0.5) 21 | self.sigmoid = nn.Sigmoid() 22 | self.reset_parameters() 23 | 24 | def reset_parameters(self): 25 | reset(self.scorer) 26 | reset(self.encoder) 27 | reset(self.pool) 28 | 29 | def forward(self, x, edge_index, batch=None, index=None): 30 | r""" Return node and subgraph representations of each node before and after being shuffled """ 31 | hidden = self.encoder(x, edge_index) 32 | if index is None: 33 | return hidden 34 | 35 | z = hidden[index] 36 | summary = self.pool(hidden, edge_index, batch) 37 | return z, summary 38 | 39 | 40 | def loss(self, hidden1, summary1): 41 | r"""Computes the margin objective.""" 42 | 43 | shuf_index = torch.randperm(summary1.size(0)) 44 | 45 | hidden2 = hidden1[shuf_index] 46 | summary2 = summary1[shuf_index] 47 | 48 | logits_aa = torch.sigmoid(torch.sum(hidden1 * summary1, dim = -1)) 49 | logits_bb = torch.sigmoid(torch.sum(hidden2 * summary2, dim = -1)) 50 | logits_ab = torch.sigmoid(torch.sum(hidden1 * summary2, dim = -1)) 51 | logits_ba = torch.sigmoid(torch.sum(hidden2 * summary1, dim = -1)) 52 | 53 | TotalLoss = 0.0 54 | ones = torch.ones(logits_aa.size(0)).cuda(logits_aa.device) 55 | TotalLoss += self.marginloss(logits_aa, logits_ba, ones) 56 | TotalLoss += self.marginloss(logits_bb, logits_ab, ones) 57 | 58 | return TotalLoss 59 | 60 | 61 | def test(self, train_z, train_y, val_z, val_y, test_z, test_y, solver='lbfgs', 62 | multi_class='auto', *args, **kwargs): 63 | r"""Evaluates latent space quality via a logistic regression downstream task.""" 64 | clf = LogisticRegression(solver=solver, multi_class=multi_class, *args, 65 | **kwargs).fit(train_z.detach().cpu().numpy(), 66 | train_y.detach().cpu().numpy()) 67 | val_acc = clf.score(val_z.detach().cpu().numpy(), val_y.detach().cpu().numpy()) 68 | test_acc = clf.score(test_z.detach().cpu().numpy(), test_y.detach().cpu().numpy()) 69 | return val_acc, test_acc -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse, os 2 | import math 3 | import torch 4 | import random 5 | import numpy as np 6 | from torch_geometric.datasets import Planetoid 7 | import torch_geometric.transforms as T 8 | 9 | from utils_mp import Subgraph, preprocess, save_object 10 | from subgcon import SugbCon 11 | from model import Encoder, Scorer, Pool 12 | from sklearn.metrics.cluster import normalized_mutual_info_score 13 | from sklearn.metrics.cluster import adjusted_rand_score 14 | from sklearn.cluster import KMeans 15 | from data_split import * 16 | import time 17 | 18 | def get_parser(): 19 | parser = argparse.ArgumentParser(description='Description: Script to run our model.') 20 | parser.add_argument('--seed', type=int, default=12345, help='Random 
seed.') 21 | parser.add_argument('--dataset', default='CoraFull') 22 | parser.add_argument('--batch_size', type=int, help='batch size', default=100) 23 | parser.add_argument('--subgraph_size', type=int, help='subgraph size default 20', default=20) 24 | parser.add_argument('--n_order', type=int, help='order of neighbor nodes', default=10) 25 | parser.add_argument('--hidden_size', type=int, help='hidden size', default=1024) 26 | parser.add_argument('--n_way', type=int, help='n way', default=5) 27 | parser.add_argument('--k_shot', type=int, help='k shot', default=1) 28 | parser.add_argument('--m_qry', type=int, help='m query', default=10) 29 | parser.add_argument('--test_num', type=int, help='test number', default=100) 30 | parser.add_argument('--patience', type=int, help='epoch patience number', default=10) 31 | parser.add_argument('--beta', type=float, help='G-supcon temperture number', default=1.) 32 | parser.add_argument('--unsup', action='store_true', help='degrade to unsupervised contrastive training (SimCLR).') 33 | return parser 34 | 35 | 36 | 37 | if __name__ == '__main__': 38 | parser = get_parser() 39 | try: 40 | args = parser.parse_args() 41 | except: 42 | exit() 43 | print(args) 44 | test_num = args.test_num 45 | n_way = args.n_way 46 | k_shot = args.k_shot 47 | m_qry = args.m_qry 48 | patience = args.patience 49 | 50 | random.seed(args.seed) 51 | torch.manual_seed(args.seed) 52 | if torch.cuda.is_available(): 53 | torch.cuda.manual_seed(args.seed) 54 | 55 | # Loading data 56 | data, train_idx, id_by_class, train_class, dev_class, test_class, degree_inv = split(args.dataset) 57 | 58 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 59 | 60 | # Setting up the subgraph extractor 61 | #ppr_path = './subgraph/' + args.dataset 62 | current_directory = os.getcwd() 63 | ppr_path=os.path.join(current_directory, 'subgraph') 64 | os.makedirs(ppr_path, exist_ok=True) 65 | ppr_path=os.path.join(ppr_path,args.dataset) 66 | os.makedirs(ppr_path, exist_ok=True) 67 | 68 | 69 | subgraph = Subgraph(data.x, data.edge_index, ppr_path, args.subgraph_size, args.n_order, args, id_by_class, train_class) 70 | subgraph.build() 71 | 72 | def train(model, optimizer): 73 | # Model training 74 | model.train() 75 | optimizer.zero_grad() 76 | 77 | args.unsup=False 78 | 79 | if not args.unsup: 80 | # class_batch = 42 if args.dataset == 'CoraFull' else 24 81 | # class_batch = len(train_class) 82 | class_batch = n_way 83 | sampled_class = random.sample(range(len(train_class)), class_batch) 84 | sample_idx = [] 85 | for c in sampled_class: 86 | # sample_idx.extend(random.sample(id_by_class[c], args.batch_size // class_batch)) 87 | sample_idx.extend(random.sample(id_by_class[c], k_shot)) 88 | sample_labels = torch.squeeze(data.y)[sample_idx] 89 | class_selected = list(set(sample_labels.tolist())) 90 | sample_labels =[class_selected.index(i) for i in sample_labels] 91 | 92 | batch, index, sample_labels = subgraph.search(sample_idx,sample_labels, interp=True) 93 | 94 | z, summary = model(batch.x.cuda(), batch.edge_index.cuda(), batch.batch.cuda(), index.cuda()) 95 | loss = model.loss(z, summary, sample_labels) 96 | else: 97 | sample_idx = random.sample(range(len(train_idx)), args.batch_size) 98 | batch, index = subgraph.search(sample_idx, interp=True) 99 | z, summary = model(batch.x.cuda(), batch.edge_index.cuda(), batch.batch.cuda(), index.cuda()) 100 | loss = model.loss(z, summary, None) 101 | loss.backward() 102 | optimizer.step() 103 | return loss.item() 104 | 105 | 106 | def 
get_all_node_emb(model, node_list): 107 | # Obtain central node embs from subgraphs 108 | list_size = node_list.size 109 | z = torch.Tensor(list_size, args.hidden_size).cuda() 110 | group_nb = math.ceil(list_size/args.batch_size) 111 | 112 | 113 | for i in range(group_nb): 114 | maxx = min(list_size, (i + 1) * args.batch_size) 115 | minn = i * args.batch_size 116 | batch, index = subgraph.search(node_list[minn:maxx], interp=False) 117 | node, _ = model(batch.x.cuda(), batch.edge_index.cuda(), batch.batch.cuda(), index.cuda()) 118 | z[minn:maxx] = node 119 | return z 120 | 121 | 122 | def test(model, eval_class, output_ari=False): 123 | # Model testing 124 | model.eval() 125 | #sample downstream few-shot tasks 126 | test_acc_all = [] 127 | purity_all = 0. 128 | nmi_all = 0. 129 | ari_all = 0. 130 | test_acc = 0. 131 | n_way=args.n_way 132 | k_shot = args.k_shot 133 | 134 | 135 | train_z=model(data.x.cuda(), data.edge_index.cuda()) 136 | 137 | np.save('emb_no_contrast.npy',train_z.detach().cpu().numpy()) 138 | np.save('label.npy',data.y) 139 | 140 | 141 | for i in range(test_num): 142 | test_id_support, test_id_query, test_class_selected = \ 143 | test_task_generator(id_by_class, eval_class, n_way, k_shot, m_qry) 144 | 145 | with torch.no_grad(): 146 | train_z = get_all_node_emb(model, test_id_support) 147 | test_z = get_all_node_emb(model, test_id_query) 148 | 149 | train_y = np.array([test_class_selected.index(i) for i in torch.squeeze(data.y)[test_id_support]]) 150 | test_y = np.array([test_class_selected.index(i) for i in torch.squeeze(data.y)[test_id_query]]) 151 | 152 | 153 | 154 | 155 | # save_object(test_z, "./CoraFull/z") 156 | # save_object(test_y, "./CoraFull/y") 157 | 158 | test_acc = model.test(train_z, train_y, test_z, test_y) 159 | if output_ari: 160 | nmi, ari = model.clustering_test(test_z, test_y, n_way) 161 | nmi_all += nmi/test_num 162 | ari_all += ari/test_num 163 | 164 | test_acc_all.append(test_acc) 165 | 166 | m, s = np.mean(test_acc_all), np.std(test_acc_all) 167 | interval = 1.96 * (s / np.sqrt(len(test_acc_all))) 168 | 169 | #print("="*40) 170 | #print('test_acc = {}'.format(m)) 171 | #print('test_interval = {}'.format(interval)) 172 | if output_ari: 173 | return m, s, interval, nmi_all, ari_all 174 | else: 175 | return m, s, interval 176 | 177 | 178 | def train_eval(): 179 | # Setting up the model and optimizer 180 | model = SugbCon( 181 | hidden_channels=args.hidden_size, encoder=Encoder(data.num_features, args.hidden_size,encoder_type='GCN'), 182 | pool=Pool(in_channels=args.hidden_size), 183 | scorer=Scorer(args.hidden_size), 184 | beta=args.beta, 185 | degree_inv=degree_inv).to(device) 186 | optimizer = torch.optim.Adam(model.parameters(), lr=0.001) 187 | 188 | print('Start training !!!') 189 | best_test_acc = 0 190 | stop_cnt = 0 191 | best_epoch = 0 192 | time_begin=time.time() 193 | for epoch in range(10000): 194 | loss = train(model, optimizer) 195 | if epoch%100==0: 196 | print('epoch = {}, loss = {}'.format(epoch, loss)) 197 | 198 | # validation 199 | if epoch % 10 == 0 and epoch != 0: 200 | test_acc, _, _ = test(model, dev_class) 201 | if test_acc >= best_test_acc: 202 | best_test_acc = test_acc 203 | best_epoch = epoch 204 | stop_cnt = 0 205 | #torch.save(model.state_dict(), 'model.pkl') 206 | else: 207 | stop_cnt += 1 208 | if stop_cnt >= patience: 209 | print('Time', time.time()-time_begin, 'Epoch: {}'.format(epoch)) 210 | break 211 | 212 | # final test 213 | #model.load_state_dict(torch.load('model.pkl')) 214 | acc, std, interval, nmi, ari = 
test(model, test_class, output_ari=True) 215 | print("Current acc mean: " + str(acc)) 216 | print("Current acc std: " + str(std)) 217 | print('nmi: {:.4f} ari: {:.4f}'.format(nmi,ari)) 218 | return acc, std, interval 219 | 220 | 221 | acc_mean = [] 222 | acc_std = [] 223 | acc_interval = [] 224 | for __ in range(5): 225 | m, s, interval = train_eval() 226 | acc_mean.append(m) 227 | acc_std.append(s) 228 | acc_interval.append(interval) 229 | print("****"*20) 230 | print("Final acc: " + str(np.mean(acc_mean))) 231 | print("Final acc std: " + str(np.mean(acc_std))) 232 | print("Final acc interval: " + str(np.mean(acc_interval))) 233 | 234 | 235 | 236 | -------------------------------------------------------------------------------- /utils_mp.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import AsyncIterable 3 | import torch 4 | import numpy as np 5 | from cytoolz import curry 6 | import multiprocessing as mp 7 | from scipy import sparse as sp 8 | from sklearn.preprocessing import normalize, StandardScaler 9 | from torch_geometric.data import Data, Batch 10 | import pickle 11 | import random 12 | def l2_normalize(x): 13 | norm = x.pow(2).sum(1, keepdim=True).pow(1. / 2) 14 | out = x.div(norm+1e-9) 15 | return out 16 | 17 | def standardize(feat, mask): 18 | scaler = StandardScaler() 19 | scaler.fit(feat[mask]) 20 | new_feat = torch.FloatTensor(scaler.transform(feat)) 21 | return new_feat 22 | 23 | 24 | def preprocess(features): 25 | rowsum = np.array(features.sum(1)) 26 | r_inv = np.power(rowsum, -1).flatten() 27 | r_inv[np.isinf(r_inv)] = 0. 28 | r_mat_inv = sp.diags(r_inv) 29 | features = r_mat_inv.dot(features) 30 | return torch.tensor(features) 31 | 32 | 33 | 34 | def save_object(obj, filename): 35 | with open(filename, 'wb') as fout: # Overwrites any existing file. 
36 | pickle.dump(obj, fout, pickle.HIGHEST_PROTOCOL) 37 | 38 | 39 | def load_object(filename): 40 | with open(filename, 'rb') as fin: 41 | obj = pickle.load(fin) 42 | return obj 43 | 44 | 45 | class PPR: 46 | #Node-wise personalized pagerank 47 | def __init__(self, adj_mat, maxsize=200, n_order=2, alpha=0.85): 48 | self.n_order = n_order 49 | self.maxsize = maxsize 50 | self.adj_mat = adj_mat 51 | self.P = normalize(adj_mat, norm='l1', axis=0) 52 | self.d = np.array(adj_mat.sum(1)).squeeze() 53 | 54 | #self.scores=self.cal_scores() 55 | 56 | def cal_scores(self, seed, alpha=0.85): 57 | x = sp.csc_matrix((np.ones(1), ([seed], np.zeros(1, dtype=int))), shape=[self.P.shape[0], 1]) 58 | r = x.copy() 59 | for _ in range(self.n_order): 60 | x = (1 - alpha) * r + alpha * self.P @ x 61 | scores = x.data / (self.d[x.indices] + 1e-9) 62 | 63 | return scores 64 | 65 | 66 | def search(self, seed, alpha=0.85): 67 | x = sp.csc_matrix((np.ones(1), ([seed], np.zeros(1, dtype=int))), shape=[self.P.shape[0], 1]) 68 | r = x.copy() 69 | for _ in range(self.n_order): 70 | x = (1 - alpha) * r + alpha * self.P @ x 71 | scores = x.data / (self.d[x.indices] + 1e-9) 72 | 73 | print(scores) 74 | 75 | 76 | idx = scores.argsort()[::-1][:self.maxsize] 77 | neighbor = np.array(x.indices[idx]) 78 | 79 | seed_idx = np.where(neighbor == seed)[0] 80 | if seed_idx.size == 0: 81 | neighbor = np.append(np.array([seed]), neighbor) 82 | else : 83 | seed_idx = seed_idx[0] 84 | neighbor[seed_idx], neighbor[0] = neighbor[0], neighbor[seed_idx] 85 | 86 | assert np.where(neighbor == seed)[0].size == 1 87 | assert np.where(neighbor == seed)[0][0] == 0 88 | 89 | return neighbor 90 | 91 | @curry 92 | def process(self, path, seed): 93 | ppr_path = os.path.join(path, 'ppr{}'.format(seed)) 94 | if not os.path.isfile(ppr_path) or os.stat(ppr_path).st_size == 0: 95 | print ('Processing node {}.'.format(seed)) 96 | neighbor = self.search(seed) 97 | torch.save(neighbor, ppr_path) 98 | else : 99 | print ('File of node {} exists.'.format(seed)) 100 | 101 | def search_all(self, node_num, path): 102 | neighbor = {} 103 | if os.path.isfile(path+'_neighbor') and os.stat(path+'_neighbor').st_size != 0: 104 | print ("Exists neighbor file") 105 | neighbor = torch.load(path+'_neighbor') 106 | else : 107 | print ("Extracting subgraphs") 108 | os.system('mkdir {}'.format(path)) 109 | with mp.Pool() as pool: 110 | list(pool.imap_unordered(self.process(path), list(range(node_num)), chunksize=1000)) 111 | 112 | print ("Finish Extracting") 113 | for i in range(node_num): 114 | neighbor[i] = torch.load(os.path.join(path, 'ppr{}'.format(i))) 115 | torch.save(neighbor, path+'_neighbor') 116 | os.system('rm -r {}'.format(path)) 117 | print ("Finish Writing") 118 | return neighbor 119 | 120 | 121 | class Subgraph: 122 | #Class for subgraph extraction 123 | 124 | def __init__(self, x, edge_index, path, maxsize=50, n_order=10, args=None,id_by_class=None, train_class=None): 125 | self.x = x 126 | self.path = path 127 | self.edge_index = np.array(edge_index) 128 | self.edge_num = edge_index[0].size(0) 129 | self.node_num = x.size(0) 130 | self.maxsize = maxsize 131 | 132 | self.sp_adj = sp.csc_matrix((np.ones(self.edge_num), (edge_index[0], edge_index[1])), 133 | shape=[self.node_num, self.node_num]) 134 | self.ppr = PPR(self.sp_adj, n_order=n_order) 135 | 136 | self.neighbor = {} 137 | self.adj_list = {} 138 | self.subgraph = {} 139 | 140 | self.id_by_class=id_by_class 141 | self.train_class=train_class 142 | 143 | self.args=args 144 | 145 | def 
process_adj_list(self): 146 | for i in range(self.node_num): 147 | self.adj_list[i] = set() 148 | for i in range(self.edge_num): 149 | u, v = self.edge_index[0][i], self.edge_index[1][i] 150 | self.adj_list[u].add(v) 151 | self.adj_list[v].add(u) 152 | 153 | def adjust_edge(self, idx): 154 | #Generate edges for subgraphs 155 | dic = {} 156 | for i in range(len(idx)): 157 | dic[idx[i]] = i 158 | 159 | new_index = [[], []] 160 | nodes = set(idx) 161 | for i in idx: 162 | edge = list(self.adj_list[i] & nodes) 163 | edge = [dic[_] for _ in edge] 164 | #edge = [_ for _ in edge if _ > i] 165 | new_index[0] += len(edge) * [dic[i]] 166 | new_index[1] += edge 167 | return torch.LongTensor(new_index) 168 | 169 | def adjust_x(self, idx): 170 | #Generate node features for subgraphs 171 | return self.x[idx] 172 | 173 | def build(self): 174 | #Extract subgraphs for all nodes 175 | if os.path.isfile(self.path+'_subgraph') and os.stat(self.path+'_subgraph').st_size != 0: 176 | print ("Exists subgraph file") 177 | self.subgraph = torch.load(self.path+'_subgraph') 178 | return 179 | 180 | self.neighbor = self.ppr.search_all(self.node_num, self.path) 181 | self.process_adj_list() 182 | for i in range(self.node_num): 183 | nodes = self.neighbor[i][:self.maxsize] 184 | x = self.adjust_x(nodes) 185 | edge = self.adjust_edge(nodes) 186 | self.subgraph[i] = Data(x, edge) 187 | torch.save(self.subgraph, self.path+'_subgraph') 188 | 189 | def search(self, node_list, sample_labels=None, interp=True): 190 | from torch_geometric.utils import to_dense_adj, dense_to_sparse 191 | from torch.distributions.beta import Beta 192 | #Extract subgraphs for nodes in the list 193 | batch = [] 194 | index = [] 195 | size = 0 196 | 197 | x_features=[] 198 | x=[] 199 | 200 | 201 | for node in node_list: 202 | #batch.append(self.subgraph[node]) 203 | batch.append(Data(self.subgraph[node].x,self.subgraph[node].edge_index, torch.ones(self.subgraph[node].edge_index.shape[1]))) 204 | 205 | index.append(size) 206 | size += self.subgraph[node].x.size(0) 207 | if interp: 208 | if self.subgraph[node].x.shape[0]<20: 209 | x.append(torch.cat([self.subgraph[node].x,torch.zeros([20-self.subgraph[node].x.shape[0],self.subgraph[node].x.shape[1]])],0)) 210 | else: 211 | x.append(self.subgraph[node].x) 212 | x_features.append(self.subgraph[node].x.mean(0)) 213 | 214 | 215 | #for k in range(10): 216 | # scores=self.ppr.cal_scores(k) 217 | # print(scores.to_dense()) 218 | #print(1/0) 219 | 220 | x_temp = [] 221 | batch_temp=[] 222 | 223 | #sample_num=15 if self.args.dataset!='Coauthor-CS' else 5 224 | sample_num=5 225 | sampled_class = random.sample(range(len(self.train_class)), sample_num ) 226 | sample_idx = [] 227 | for c in sampled_class: 228 | sample_idx.extend(random.sample(self.id_by_class[c], 5)) 229 | for node in sample_idx: 230 | 231 | #batch.append(self.subgraph[node]) 232 | batch_temp.append(Data(self.subgraph[node].x,self.subgraph[node].edge_index, torch.ones(self.subgraph[node].edge_index.shape[1]))) 233 | if interp: 234 | if self.subgraph[node].x.shape[0]<20: 235 | x_temp.append(torch.cat([self.subgraph[node].x,torch.zeros([20-self.subgraph[node].x.shape[0],self.subgraph[node].x.shape[1]])],0)) 236 | else: 237 | x_temp.append(self.subgraph[node].x) 238 | x_features.append(self.subgraph[node].x.mean(0)) 239 | 240 | 241 | 242 | 243 | #N = self.args.n_way 244 | 245 | if interp: 246 | x_features = torch.stack(x_features, 0) 247 | #simi=torch.ones([N*K,N*K])*0.5 248 | 
simi=torch.sigmoid(-l2_normalize(x_features).matmul(l2_normalize(x_features).t())) 249 | 250 | #batch_temp = Batch().from_data_list(batch) 251 | batch_temp = Batch().from_data_list(batch_temp) 252 | 253 | dense_adj = to_dense_adj(batch_temp.edge_index, batch_temp.batch,max_num_nodes=20) 254 | 255 | for i in range(len(sample_idx)): 256 | 257 | target_idx = i 258 | sim = simi[i, target_idx] 259 | 260 | m = Beta(sim*10, 1) 261 | lambda_value = m.sample(dense_adj[i].shape) 262 | 263 | #lambda_value=0.1 264 | 265 | interp_adj = dense_adj[i] * lambda_value + dense_adj[target_idx] * (1 - lambda_value) 266 | edge_index, edge_attr = dense_to_sparse(interp_adj) 267 | 268 | lambda_value = m.sample([dense_adj[i].shape[0], 1]) 269 | 270 | #lambda_value=0.1 271 | try: 272 | x_interp = x[i] * lambda_value + x_temp[target_idx] * (1 - lambda_value) 273 | except: 274 | pass 275 | #print(dense_adj.shape) 276 | #print(lambda_value.shape) 277 | 278 | index.append(size) 279 | size += x_interp.size(0) 280 | batch.append(Data(x_interp, edge_index, edge_attr)) 281 | 282 | #repeat the label 283 | sample_labels=sample_labels 284 | for j in range(sample_num): 285 | sample_labels.extend([j+5]*5) 286 | 287 | 288 | batch = Batch().from_data_list(batch) 289 | index = torch.tensor(index) 290 | if sample_labels!=None: 291 | sample_labels= torch.LongTensor(sample_labels) 292 | return batch, index, sample_labels 293 | else: 294 | return batch, index 295 | 296 | 297 | 298 | 299 | 300 | 301 | 302 | 303 | --------------------------------------------------------------------------------
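(Editor's note, not part of the repo.) The Beta-distribution interpolation inside Subgraph.search (the interp=True branch above) is the least obvious step, so below is a standalone sketch of just that computation under toy assumptions: two already-padded 20-node subgraphs with 16-dim features and a fixed similarity score sim (in the repo, sim is derived from the two subgraphs' mean node features). Entry-wise mixing weights drawn from Beta(10*sim, 1) blend both the dense adjacencies and the node features, and dense_to_sparse converts the mixed adjacency back to an edge list with edge weights, mirroring how the interpolated Data objects are built above.

```
import torch
from torch.distributions.beta import Beta
from torch_geometric.utils import dense_to_sparse

# Two padded subgraphs of equal size (toy data): dense adjacencies and node features.
adj_a = torch.randint(0, 2, (20, 20)).float()
adj_b = torch.randint(0, 2, (20, 20)).float()
x_a, x_b = torch.randn(20, 16), torch.randn(20, 16)

# In the repo, sim comes from the mean node features of the two subgraphs;
# a fixed value is used here purely for illustration.
sim = 0.4
m = Beta(torch.tensor(sim * 10.0), torch.tensor(1.0))

# Entry-wise mixing of the adjacencies, then back to a sparse edge list with weights.
lam_adj = m.sample(adj_a.shape)
mixed_adj = adj_a * lam_adj + adj_b * (1 - lam_adj)
edge_index, edge_attr = dense_to_sparse(mixed_adj)

# Node-wise mixing of the features with freshly drawn weights.
lam_x = m.sample((x_a.size(0), 1))
mixed_x = x_a * lam_x + x_b * (1 - lam_x)
```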