├── graphdata
│   └── save
│       └── null.txt
├── README.md
├── model.py
├── configs.yaml
├── main.py
├── utils.py
├── data.py
└── model_interface.py

/graphdata/save/null.txt:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Graph Contrastive Learning Meets Graph Meta Learning: A Unified Method for Few-shot Node Tasks
This repository is the implementation of the model COLA from the paper [**Graph Contrastive Learning Meets Graph Meta Learning: A Unified Method for Few-shot Node Tasks**](https://dl.acm.org/doi/abs/10.1145/3589334.3645367).

## Requirements
```
python=3.8
torch=1.13.0
pyg=2.3.0
pytorch-lightning=2.0.1.post0
ogb=1.3.6
pygcl=0.1.2
wandb=0.14.2
ruamel.yaml=0.17.21
```

## Usage
Please use the following command to run the code.
Here is an example of running a 2-way 5-shot task on the CiteSeer dataset.
```
python main.py --dataset=CiteSeer --n_way=2 --k_shot=5
```
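Hyperparameters for each dataset are read from `configs.yaml` (the entry named `<dataset>-GFS`), and since those values are only injected as argparse defaults, any of them can be overridden on the command line. The flag values below are illustrative, not tuned settings:
```
python main.py --dataset=Cora --n_way=2 --k_shot=3 --tau=0.5 --max_epochs=300
```
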
## Citation
```
@inproceedings{liu2024graph,
  title={Graph Contrastive Learning Meets Graph Meta Learning: A Unified Method for Few-shot Node Tasks},
  author={Liu, Hao and Feng, Jiarui and Kong, Lecheng and Tao, Dacheng and Chen, Yixin and Zhang, Muhan},
  booktitle={Proceedings of the ACM on Web Conference 2024},
  pages={365--376},
  year={2024}
}
```

--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import os.path as osp
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.transforms as T
import torch_geometric
from tqdm import tqdm
from torch.optim import Adam

from torch_geometric.nn import GCNConv, GATConv, GraphConv, SGConv
from torch_geometric.datasets import Planetoid

class MLP(torch.nn.Sequential):
    """Simple multi-layer perceptron with PReLU activations and optional dropout layers."""

    def __init__(self, input_dim, hidden_dim, n_layers=1, dropout=0.0):
        layers = []
        in_dim = input_dim
        for _ in range(n_layers - 1):
            layers.append(torch.nn.Linear(in_dim, hidden_dim))
            layers.append(torch.nn.PReLU())
            layers.append(torch.nn.Dropout(dropout))
            in_dim = hidden_dim

        # The final layer maps back to input_dim, so the MLP's output dimension equals its input dimension.
        layers.append(torch.nn.Linear(in_dim, input_dim))

        super().__init__(*layers)

class GNNModel(nn.Module):
    def __init__(
        self,
        input_dim,
        hidden_dim,
        out_dim,
        num_layers=2,
        layer_name="GCN",
        activation_name="relu",
        dp_rate=0.1,
        **kwargs,
    ):
        """
        Args:
            input_dim: Dimension of input features
            hidden_dim: Dimension of hidden features
            out_dim: Dimension of the output features. Usually number of classes in classification
            num_layers: Number of "hidden" graph layers
            layer_name: String of the graph layer to use
            activation_name: String of the activation function to use ("relu" or "prelu")
            dp_rate: Dropout rate to apply throughout the network
            kwargs: Additional arguments for the graph layer (e.g. number of heads for GAT)
        """
        super().__init__()
        gnn_layer_by_name = {"GCN": GCNConv, "GAT": GATConv, "GraphConv": GraphConv, "SGC": SGConv}
        gnn_layer = gnn_layer_by_name[layer_name]
        activation_by_name = {'relu': nn.ReLU(), 'prelu': nn.PReLU()}
        activation = activation_by_name[activation_name]

        layers = []
        in_channels, out_channels = input_dim, hidden_dim
        for _ in range(num_layers):
            layers += [
                gnn_layer(in_channels=in_channels, out_channels=out_channels, **kwargs),
                nn.BatchNorm1d(out_channels, momentum=0.01),
                activation,
                nn.Dropout(dp_rate),
            ]
            in_channels = hidden_dim
        layers += [gnn_layer(in_channels=out_channels, out_channels=out_dim, **kwargs)]
        self.layers = nn.ModuleList(layers)

    def forward(self,
                x,
                edge_index,
                edge_weight=None):
        """
        Args:
            x: Input features per node
            edge_index: List of vertex index pairs representing the edges in the graph (PyTorch Geometric notation)
            edge_weight: Optional edge weights
        """
        for layer in self.layers:
            # For graph layers, we need to add the "edge_index" tensor as additional input
            # All PyTorch Geometric graph layers inherit from the class "MessagePassing", hence
            # we can simply check the class type.
            if isinstance(layer, torch_geometric.nn.MessagePassing):
                x = layer(x, edge_index, edge_weight)
            else:
                x = layer(x)
        return x
--------------------------------------------------------------------------------
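As a quick sanity check of the encoder above, the following minimal sketch builds a two-layer GCN encoder and runs a forward pass on Cora. The hyperparameter values are illustrative (the tuned ones live in `configs.yaml`), and it assumes the dataset can be downloaded into `./graphdata`:
```
import torch
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid

from model import GNNModel

# Illustrative settings only; configs.yaml holds the values actually used for training.
dataset = Planetoid('./graphdata', 'Cora', transform=T.NormalizeFeatures())
data = dataset[0]
encoder = GNNModel(input_dim=dataset.num_features, hidden_dim=256, out_dim=128,
                   num_layers=2, layer_name='GCN', activation_name='prelu', dp_rate=0.1)
with torch.no_grad():
    z = encoder(data.x, data.edge_index)   # node embeddings, shape [num_nodes, 128]
print(z.shape)
```
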
/configs.yaml:
--------------------------------------------------------------------------------
Cora: &Cora
  input_dim: 1433
  num_classes: 7
  num_samples: 2708
  dataset: Cora
  class_split_ratio: [3, 2, 2]

CiteSeer: &CiteSeer
  input_dim: 3703
  num_classes: 6
  num_samples: 3327
  dataset: CiteSeer
  class_split_ratio: [2, 2, 2]

CoraFull: &CoraFull
  input_dim: 8710
  num_classes: 70
  num_samples: 19793
  dataset: CoraFull
  class_split_ratio: [38, 15, 15]

Computers: &Computers
  input_dim: 767
  num_classes: 10
  num_samples: 13752
  dataset: Computers
  class_split_ratio: [4, 3, 3]

Coauthor-CS: &Coauthor-CS
  input_dim: 6805
  num_classes: 15
  num_samples: 18333
  dataset: Coauthor-CS
  class_split_ratio: [5, 5, 5]

ogbn-arxiv: &ogbn-arxiv
  input_dim: 128
  num_classes: 40
  num_samples: 169343
  dataset: ogbn-arxiv
  class_split_ratio: [20, 10, 10]


Cora-GFS: &Cora-GFS
  <<: *Cora
  activation: 'relu'
  base_model: 'GCN'
  num_layers: 2
  hidden_dim: 256
  out_dim: 128
  tau: 0.4
  lr: 5e-4
  weight_decay: 1e-5
  head_hidden_dim: 128
  head_lr: 1e-3
  head_weight_decay: 1e-3
  task_num: 50
  train_task_num: 20
  fs_rate: 1
  num_negatives: 27080
  f1: 0.3
  f2: 0.4
  f3: 0.4
  e1: 0.2
  e2: 0.4
  e3: 0.4

CiteSeer-GFS: &CiteSeer-GFS
  <<: *CiteSeer
  activation: 'relu'
  base_model: 'GCN'
  num_layers: 2
  hidden_dim: 256
  out_dim: 128
  tau: 0.4
  lr: 5e-4
  weight_decay: 1e-5
  dropout: 0
  head_hidden_dim: 512
  head_lr: 1e-3
  head_weight_decay: 1e-3
  task_num: 50
  train_task_num: 20
  fs_rate: 1
  num_negatives: 3327
  f1: 0.3
  f2: 0.4
  f3: 0.4
  e1: 0.2
  e2: 0.4
  e3: 0.4

Computers-GFS: &Computers-GFS
  <<: *Computers
  activation: 'relu'
  base_model: 'GCN'
  num_layers: 1
  hidden_dim: 256
  out_dim: 128
  tau: 0.7
  lr: 5e-4
  weight_decay: 1e-5
  dropout: 0.1
  task_num: 100
  train_task_num: 100
  fs_rate: 1
  num_negatives: 27504
  f1: 0.2
  f2: 0.1
  f3: 0.1
  e1: 0.5
  e2: 0.4
  e3: 0.4

CoraFull-GFS: &CoraFull-GFS
  <<: *CoraFull
  activation: 'relu'
  base_model: 'GCN'
  num_layers: 1
  hidden_dim: 256
  out_dim: 128
  tau: 0.7
  temperature2: 0.7
  lr: 5e-4
  weight_decay: 1e-5
  dropout: 0.1
  task_num: 50
  train_task_num: 50
  fs_rate: 1
  num_negatives: 39586
  f1: 0.2
  f2: 0.1
  f3: 0.1
  e1: 0.3
  e2: 0.2
  e3: 0.2

Coauthor-CS-GFS: &Coauthor-CS-GFS
  <<: *Coauthor-CS
  activation: 'relu'
  base_model: 'GCN'
  num_layers: 1
  hidden_dim: 256
  out_dim: 128
  tau: 0.7
  temperature2: 0.7
  lr: 5e-4
  weight_decay: 1e-5
  dropout: 0.1
  task_num: 50
  train_task_num: 50
  fs_rate: 1
  num_negatives: 39586
  f1: 0.2
  f2: 0.1
  f3: 0.1
  e1: 0.3
  e2: 0.2
  e3: 0.2

ogbn-arxiv-GFS: &ogbn-arxiv-GFS
  <<: *ogbn-arxiv
  activation: 'relu'
  base_model: 'GCN'
  num_layers: 2
  hidden_dim: 256
  out_dim: 128
  tau: 0.5
  lr: 1e-3
  weight_decay: 0
  dropout: 0.1
  head_hidden_dim: 128
  head_lr: 1e-3
  head_weight_decay: 1e-3
  task_num: 100
  train_task_num: 100
  q_query: 20
  fs_rate: 1
  num_negatives: 169343

--------------------------------------------------------------------------------
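Each dataset has a base entry with its graph statistics and a `<name>-GFS` entry that merges it via a YAML anchor and adds the training hyperparameters; `utils.get_args` looks up the entry named `<dataset>-<model_name>`. A new dataset could follow the same pattern. The entry below is purely hypothetical (`MyDataset` and all of its numbers are placeholders, and `data.load_dataset` would also need to know how to construct the dataset):
```
MyDataset: &MyDataset
  input_dim: 500
  num_classes: 10
  num_samples: 20000
  dataset: MyDataset
  class_split_ratio: [4, 3, 3]

MyDataset-GFS: &MyDataset-GFS
  <<: *MyDataset
  activation: 'relu'
  base_model: 'GCN'
  num_layers: 2
  hidden_dim: 256
  out_dim: 128
  tau: 0.5
  lr: 5e-4
  weight_decay: 1e-5
  task_num: 50
  train_task_num: 20
  fs_rate: 1
  num_negatives: 20000
  f1: 0.3
  f2: 0.4
  f3: 0.4
  e1: 0.2
  e2: 0.4
  e3: 0.4
```
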
/main.py:
--------------------------------------------------------------------------------
# Standard libraries
import os

# For downloading pre-trained models
import urllib.request
from urllib.error import HTTPError

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
# PyTorch geometric
import torch_geometric
import torch_geometric.data as geom_data
import torch_geometric.nn as geom_nn
from torch.utils.data import DataLoader, TensorDataset
import copy

# PyTorch Lightning and callbacks
import wandb
import lightning.pytorch as pl

from torch import Tensor
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor, Timer
from lightning.pytorch.callbacks.progress import TQDMProgressBar
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.utilities import model_summary
import GCL.augmentors as A
from GCL.eval import LREvaluator

from model_interface import GFS3

from data import load_dataset
from utils import get_args

from data import FewShotDataManager, IndexDataset, IndexDataset3


def train_gfs3(args,
               max_epochs,
               dataset):
    project_name = '-'.join([args.exp_name, args.dataset, str(args.n_way), str(args.k_shot)])
    logger = WandbLogger(name=args.name, project=project_name, save_dir=args.save_dir)
    logger.log_hyperparams(args)
    trainer = pl.Trainer(
        accelerator="auto",
        devices="auto",
        max_epochs=max_epochs,
        enable_checkpointing=True,
        enable_progress_bar=True,
        enable_model_summary=True,
        logger=logger,
        min_epochs=50,
        log_every_n_steps=1,
        val_check_interval=0.2,
        callbacks=[
            TQDMProgressBar(refresh_rate=20),
            ModelCheckpoint(save_weights_only=True, dirpath=args.save_dir + '/' + args.dataset, monitor="val_acc",
                            mode="max", save_top_k=3),
            LearningRateMonitor(logging_interval="epoch"),
            EarlyStopping(monitor="val_acc", min_delta=0.00, patience=200, verbose=False, mode="max")
        ]
    )
    trainer.logger._default_hp_metric = None  # Optional logging argument that we don't need

    # NOTE: the appended suffix makes this path invalid on purpose, so the
    # "load pretrained checkpoint" branch below is effectively disabled.
    pretrained_filename = args.best_pretrain + 'aaaaa'
    print(pretrained_filename)
    if os.path.isfile(pretrained_filename):
        print(f"Found pretrained model at {pretrained_filename}, loading... ")
        model = GFS3.load_from_checkpoint(pretrained_filename)

    else:
        pl.seed_everything(args.random_seed)

        datamanager = FewShotDataManager(dataset[0], args)
        val_loader = datamanager.get_data_loader(1)
        test_loader = datamanager.get_data_loader(2)
        test_idx = datamanager.split['test']

        val_idx = datamanager.split['valid']
        # test_val_idx = test_idx.extend(val_idx)

        if args.label_mask == 3:
            train_loader = DataLoader(IndexDataset3(dataset[0], test_idx),
                                      batch_size=args.n_way * args.train_task_num,
                                      shuffle=True,
                                      drop_last=True,
                                      num_workers=args.num_workers)
        else:
            train_loader = DataLoader(IndexDataset(dataset[0]),
                                      batch_size=args.n_way * args.train_task_num,
                                      shuffle=True,
                                      drop_last=True,
                                      num_workers=args.num_workers)

        aug1 = A.Compose([A.EdgeRemoving(pe=args.e1), A.FeatureMasking(pf=args.f1)])
        aug2 = A.Compose([A.EdgeRemoving(pe=args.e2), A.FeatureMasking(pf=args.f2)])
        aug3 = A.Compose([A.EdgeRemoving(pe=args.e3), A.FeatureMasking(pf=args.f3)])

        model = GFS3(augmentor=(aug1, aug2, aug3),
                     args=args,
                     data=dataset[0],
                     test_idx=test_idx,
                     encoder_momentum=args.mmt
                     )

        # trainer.fit(model, train_dataloaders=train_loader)
        trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=[val_loader, test_loader])
        # Load best checkpoint after training
        model = GFS3.load_from_checkpoint(trainer.checkpoint_callback.best_model_path,
                                          data=dataset[0])

    trainer.test(model, dataloaders=test_loader)
    trainer.test(model, dataloaders=val_loader)

    wandb.finish()
    return model


def main():
    args = get_args()
    dataset = load_dataset(args)

    GFS_model = train_gfs3(args=args, max_epochs=args.max_epochs, dataset=dataset)


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import argparse
import torch
from ruamel.yaml import YAML

def get_args():
    yaml_path = 'configs.yaml'
    parser = argparse.ArgumentParser(description='Graph contrastive meta learning (COLA) for few-shot node tasks.')

    parser.add_argument('--use_cuda', action='store_true', default=True)
    parser.add_argument('--data_path', type=str, default='./graphdata')
    parser.add_argument('--save_dir', type=str, default='./graphdata/save')

    parser.add_argument('--name', type=str, default='GFS')
    parser.add_argument('--exp_name', type=str, default='F')
    parser.add_argument('--model_name', type=str, default='GFS')
    parser.add_argument('--random_seed', type=int, default=231)
    parser.add_argument('--num_workers', type=int, default=0)

    parser.add_argument('--dataset', type=str, default='Cora')
    parser.add_argument('--num_samples', type=int, default=2708)
    parser.add_argument('--input_dim', type=int, default=1433)
    parser.add_argument('--num_classes', type=int, default=7)
    parser.add_argument('--num_negatives', type=int, default=27080, help='for queue')

    parser.add_argument('--max_epochs', type=int, default=200)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--base_model', type=str, default='GCN')
    parser.add_argument('--activation', type=str, default="relu")
    parser.add_argument('--num_layers', type=int, default=2)
    parser.add_argument('--hidden_dim', type=int, default=256)
    parser.add_argument('--out_dim', type=int, default=128)
    parser.add_argument('--tau', type=float, default=0.5)
    parser.add_argument('--lr', type=float, default=0.0005)
    parser.add_argument('--dropout', type=float, default=0)
    parser.add_argument('--weight_decay', type=float, default=1e-5)
    parser.add_argument('--best_pretrain', type=str, default='no.ckpt')

    parser.add_argument('--head_max_epochs', type=int, default=1)
    parser.add_argument('--head_hidden_dim', type=int, default=128)
    parser.add_argument('--head_lr', type=float, default=1e-2)
    parser.add_argument('--head_weight_decay', type=float, default=0)
    parser.add_argument('--best_head', type=str, default='no_head.ckpt')
    parser.add_argument('--classifier', type=str, default='LR')

    parser.add_argument('--class_split_ratio', type=int, nargs='+', default=[3, 2, 2])
    parser.add_argument('--n_way', type=int, default=2)
    parser.add_argument('--k_shot', type=int, default=5)
    parser.add_argument('--q_query', type=int, default=20)
    parser.add_argument('--task_num', type=int, default=50, help='Number of tasks used for test/validation.')
    parser.add_argument('--train_task_num', type=int, default=20, help='Number of tasks used for calculating fs_loss.')
    parser.add_argument('--temperature2', type=float, default=1)
    parser.add_argument('--fs_rate', type=float, default=1, help='Ratio of fs_loss in the final loss function.')

    parser.add_argument('--label_mask', type=int, default=0)
    parser.add_argument('--khop_mask', type=int, default=0)
    parser.add_argument('--self_mask', action='store_true', help='Exclude the query node itself from the support-set search.')
    parser.add_argument('--k_rate', type=int, default=1)
    parser.add_argument('--mmt', type=float, default=0.9)
    parser.add_argument('--em_scd', type=int, default=0)

    parser.add_argument('--f1', type=float, default=0.3, help='Feature masking ratio for the first augmented view.')
    parser.add_argument('--f2', type=float, default=0.4, help='Feature masking ratio for the second augmented view.')
    parser.add_argument('--f3', type=float, default=0.4, help='Feature masking ratio for the third augmented view.')
    parser.add_argument('--e1', type=float, default=0.2, help='Edge removing ratio for the first augmented view.')
    parser.add_argument('--e2', type=float, default=0.4, help='Edge removing ratio for the second augmented view.')
    parser.add_argument('--e3', type=float, default=0.4, help='Edge removing ratio for the third augmented view.')

    parser.add_argument('--compare_mode', type=str, default='m1')
    parser.add_argument('--model_mode', type=str, default='fs3')

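    # The values above are only fallbacks: after the first parse_args() call below, the
    # entry "<dataset>-<model_name>" (e.g. "Cora-GFS") is read from configs.yaml and injected
    # with parser.set_defaults(), and the arguments are parsed a second time. YAML values
    # therefore override the defaults above, and explicit command-line flags override both.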
    args = parser.parse_args()
    with open(yaml_path) as args_file:
        args_key = "-".join([args.dataset, args.model_name])
        print(args_key)
        try:
            parser.set_defaults(**dict(YAML().load(args_file)[args_key].items()))
        except KeyError:
            raise AssertionError('There is no entry named {} in {}.'.format(args_key, yaml_path))

    args = parser.parse_args()

    return args


def precision_at_k(output, target, top_k=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k."""
    with torch.no_grad():
        maxk = max(top_k)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in top_k:
            correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


def map_class(classes: torch.Tensor, q_query: int) -> torch.Tensor:
    # map the true labels to task labels
    # classes: n_way
    # output: (q_query x n_way) x 1
    exp_classes = classes.unsqueeze(1).expand(classes.size(0), q_query).reshape(-1, 1).squeeze()
    mapping = {x.item(): i for i, x in enumerate(classes)}
    remap_classes = torch.LongTensor([mapping[x.item()] for x in exp_classes])
    return remap_classes

def normalize_0to1(tensor: torch.Tensor):
    min_values, _ = torch.min(tensor, dim=1, keepdim=True)
    max_values, _ = torch.max(tensor, dim=1, keepdim=True)
    normalized_tensor = (tensor - min_values) / (max_values - min_values)
    return normalized_tensor
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
from __future__ import division
from __future__ import print_function

import json
import os.path
from abc import abstractmethod

# utils
import numpy as np
import torch
from ogb.nodeproppred import PygNodePropPredDataset
from torch.utils.data import BatchSampler
from torch.utils.data import DataLoader, Dataset
from torch_geometric.datasets import Planetoid, Amazon, CoraFull, Coauthor, HeterophilousGraphDataset
import torch_geometric.transforms as T
from torch_geometric.utils import to_undirected

def load_dataset(args):
    """
    Create a dictionary that keeps the mapping between each dataset name and its dataset class.
    key: "dataset" -> value: dataset class index, for example "Cora" -> 0
    Class index to dataset class (in PyG): 0: "Planetoid"; 1: "Amazon"; 2: "PygNodePropPredDataset";
    3: "CoraFull"; 4: "Coauthor"; 5: "HeterophilousGraphDataset"
    """
    relation_dic = {}
    available_datasets = [["Cora", "CiteSeer", "PubMed"],
                          ["Computers", "Photo"],
                          ["ogbn-arxiv", "ogbn-mag", "ogbn-products"],
                          ["CoraFull"],
                          ["Coauthor-CS"],
                          ["Roman-empire"]]
    for cls, dataset_lst in enumerate(available_datasets):
        for dataset in dataset_lst:
            relation_dic[dataset] = cls

    # Load the dataset from the args.data_path folder.
    # If it does not exist yet, it will be downloaded automatically.
    if relation_dic[args.dataset] == 0:
        dataset = Planetoid(args.data_path, args.dataset, transform=T.NormalizeFeatures())
    elif relation_dic[args.dataset] == 1:
        dataset = Amazon(args.data_path, args.dataset, transform=T.NormalizeFeatures())
    elif relation_dic[args.dataset] == 2:
        dataset = 
PygNodePropPredDataset(root=args.data_path, name=args.dataset, transform=T.NormalizeFeatures()) 44 | # change arxiv to undirected graph 45 | edge_index = to_undirected(dataset._data.edge_index) 46 | set_dataset_attr(dataset, 'edge_index', edge_index, 47 | edge_index.shape[1]) 48 | # dataset.y = dataset.y.squeeze() 49 | # print(dataset[0].y) 50 | elif relation_dic[args.dataset] == 3: 51 | dataset = CoraFull(root=args.data_path+'/'+args.dataset, transform=T.NormalizeFeatures()) 52 | elif relation_dic[args.dataset] == 4: 53 | dataset = Coauthor(root=args.data_path, name="CS", transform=T.NormalizeFeatures()) 54 | elif relation_dic[args.dataset] == 5: 55 | dataset = HeterophilousGraphDataset(root=args.data_path, name="Roman-empire", transform=T.NormalizeFeatures()) 56 | edge_index = to_undirected(dataset._data.edge_index) 57 | set_dataset_attr(dataset, 'edge_index', edge_index, 58 | edge_index.shape[1]) 59 | else: 60 | raise ValueError(f"Unknown dataset: {args.dataset}. Please choose from {available_datasets}") 61 | 62 | print(f"Successfully load dataset: {args.dataset} from {args.data_path}.") 63 | return dataset 64 | 65 | 66 | # dataset = load_dataset(args) 67 | # data = dataset[0] 68 | # datamanager = FewShotDataManager(data, args) 69 | # train_loader = datamanager.get_data_loader(0) 70 | # val_loader = datamanager.get_data_loader(1) 71 | # test_loader = datamanager.get_data_loader(2) 72 | 73 | 74 | class DataManager: 75 | @abstractmethod 76 | def get_data_loader(self, mode): 77 | pass 78 | 79 | 80 | class FewShotDataManager(DataManager): 81 | def __init__(self, data, args): 82 | super(FewShotDataManager, self).__init__() 83 | self.args = args 84 | data.y = data.y.squeeze() 85 | self.dataset = FewShotDataset(data, args, args.k_shot + args.q_query) 86 | self.split = self.dataset.split 87 | 88 | def get_data_loader(self, mode): 89 | # mode: 0->train, 1->val, 2->test 90 | class_list = self.dataset.__getclass__(mode) 91 | sampler = EpisodeBatchSampler(self.args, class_list, mode) 92 | # sampler = BatchSampler(EpisodeBatchSampler(self.args, class_list), batch_size=10, drop_last=False) 93 | # sampler = np.concatenate(list(sampler), axis=1) 94 | # print(sampler) 95 | data_loader_params = dict(batch_sampler=sampler, 96 | num_workers=self.args.num_workers, 97 | pin_memory=False) 98 | data_loader = DataLoader(self.dataset, **data_loader_params) 99 | return data_loader 100 | 101 | def get_dataset(self): 102 | return self.dataset 103 | 104 | 105 | class FewShotDataset(Dataset): 106 | def __init__(self, data, args, batch_size): 107 | self.data = data 108 | self.args = args 109 | self.batch_size = batch_size 110 | 111 | self.cls_split_lst = self.class_split() 112 | self.cls_dataloader = self.create_subdataloader() 113 | self.split = self.get_split_index() 114 | 115 | def class_split(self): 116 | """ 117 | Split class for train/val/test in meta learning setting. 118 | Save as list: [[train_class_index], [val_class_index], [test_class_index]] 119 | """ 120 | cls_split_file = self.args.data_path + '/' + self.args.dataset + '_class_split.json' 121 | 122 | if os.path.isfile(cls_split_file) and False: 123 | # load list if exists 124 | with open(cls_split_file, 'rb') as f: 125 | cls_split_lst = json.load(f) 126 | print('Complete: Load class split info from %s .' 
% cls_split_file) 127 | 128 | else: 129 | # create list according to class_split_ratio and save 130 | label = torch.unique(self.data.y).cpu().detach() 131 | # if CoraFull dataset, ignore 68,69 label since they only have 15/29 samples 132 | if label.size(0) == 70: 133 | label = label[:-2] 134 | # randomly shuffle 135 | label = label.index_select(0, torch.randperm(label.shape[0])) 136 | train_class, val_class, test_class = torch.split(label, self.args.class_split_ratio) 137 | cls_split_lst = [train_class.tolist(), val_class.tolist(), test_class.tolist()] 138 | 139 | # with open(cls_split_file, 'w') as f: 140 | # json.dump(cls_split_lst, f) 141 | # print('Complete: Save class split info to %s .' % cls_split_file) 142 | print(cls_split_lst) 143 | return cls_split_lst 144 | 145 | def label_to_index(self) -> (dict, torch.tensor): 146 | """ 147 | Generate a dictionary mapping labels to index list 148 | :return: dictionary: {label: [list of index]} 149 | """ 150 | label = torch.unique(self.data.y) 151 | label2index = {} 152 | for i in label: 153 | label2index[int(i)] = torch.nonzero(self.data.y == i).squeeze() 154 | 155 | return label2index, label 156 | 157 | def create_subdataloader(self): 158 | """ 159 | :return: list of subdataloaders for each class i 160 | """ 161 | label2index, label = self.label_to_index() 162 | cls_dataloader = [] 163 | cls_dataloader_params = dict(batch_size=self.batch_size, 164 | shuffle=True, 165 | num_workers=self.args.num_workers, 166 | pin_memory=False) 167 | for c in label: 168 | cls_dataset = ClassDataset(label2index[int(c)]) 169 | cls_dataloader.append(DataLoader(cls_dataset, **cls_dataloader_params)) 170 | 171 | return cls_dataloader 172 | 173 | def get_split_index(self): 174 | """ 175 | :return: dictionary that contains the node index for each split 176 | """ 177 | label2index, label = self.label_to_index() 178 | cls_split_lst = self.cls_split_lst 179 | split = { 180 | 'train': [], 181 | 'valid': [], 182 | 'test': [] 183 | } 184 | 185 | for c in label: 186 | if c in cls_split_lst[0]: 187 | split['train'].extend([int(idx) for idx in label2index[int(c)]]) 188 | elif c in cls_split_lst[1]: 189 | split['valid'].extend([int(idx) for idx in label2index[int(c)]]) 190 | elif c in cls_split_lst[2]: 191 | split['test'].extend([int(idx) for idx in label2index[int(c)]]) 192 | else: 193 | print("label %s does not belong to any class." 
% c) 194 | 195 | return split 196 | 197 | def __getitem__(self, class_index): 198 | return next(iter(self.cls_dataloader[class_index])), class_index 199 | 200 | def __len__(self): 201 | # mode = 0 -> train; 1 -> validation; 2 -> test 202 | return len(torch.unique(self.data.y)) 203 | 204 | def __getclass__(self, mode): 205 | # return available classes under current mode (train/val/test) 206 | # print(self.cls_split_lst) 207 | return self.cls_split_lst[mode] 208 | 209 | 210 | class EpisodeBatchSampler(object): 211 | def __init__(self, args, class_list, mode): 212 | # TODO: change value of episode to some variables 213 | self.episode = 1 214 | self.n_way = args.n_way 215 | self.class_list = class_list 216 | self.mode = mode 217 | self.test_num = args.task_num 218 | 219 | def __len__(self): 220 | return self.episode 221 | 222 | def __iter__(self): 223 | for i in range(self.episode): 224 | batch_class = [] 225 | task_num = self.test_num if self.mode!=0 else 1 226 | for j in range(task_num): 227 | batch_class.append(np.random.choice(self.class_list, self.n_way, replace=False)) 228 | yield np.concatenate(batch_class) 229 | 230 | 231 | class ClassDataset(Dataset): 232 | def __init__(self, label_index): 233 | self.label_index = label_index 234 | 235 | def __getitem__(self, i): 236 | return self.label_index[i] 237 | 238 | def __len__(self): 239 | return self.label_index.shape[0] 240 | 241 | 242 | class IndexDataset(Dataset): 243 | def __init__(self, data): 244 | self.data = data 245 | self.len = data.x.size(0) 246 | 247 | def __getitem__(self, i): 248 | return i 249 | 250 | def __len__(self): 251 | return self.len 252 | 253 | class IndexDataset3(Dataset): 254 | def __init__(self, data, test_idx): 255 | self.data = data 256 | self.data_idx = torch.arange(data.x.size(0)) 257 | 258 | test_idx_tensor = torch.tensor(test_idx) 259 | mask = torch.isin(self.data_idx, test_idx_tensor) 260 | mask_not_in_list = torch.logical_not(mask) 261 | self.train_val_idx = self.data_idx[mask_not_in_list] 262 | self.len = self.train_val_idx.size(0) 263 | 264 | 265 | 266 | def __getitem__(self, i): 267 | return self.train_val_idx[i] 268 | 269 | def __len__(self): 270 | return self.len 271 | 272 | def set_dataset_attr(dataset, name, value, size): 273 | dataset._data_list = None 274 | dataset.data[name] = value 275 | if dataset.slices is not None: 276 | dataset.slices[name] = torch.tensor([0, size], dtype=torch.long) 277 | -------------------------------------------------------------------------------- /model_interface.py: -------------------------------------------------------------------------------- 1 | 2 | import lightning.pytorch as pl 3 | from argparse import ArgumentParser 4 | 5 | from typing import Dict, List 6 | import torch.optim as optim 7 | import numpy as np 8 | 9 | import torch 10 | import torch.nn as nn 11 | from torchmetrics import MetricCollection, Accuracy, F1Score, Recall, Precision, AUROC, ConfusionMatrix 12 | import torch.nn.functional as F 13 | import torch_geometric 14 | 15 | from torch_geometric.utils import k_hop_subgraph 16 | 17 | from model import MLP, GNNModel 18 | from utils import precision_at_k, map_class 19 | 20 | from sklearn.linear_model import LogisticRegression as SKLR 21 | from sklearn.model_selection import GridSearchCV 22 | from sklearn.svm import LinearSVC 23 | 24 | import ignite.distributed as idist 25 | from utils import normalize_0to1 26 | 27 | 28 | # GFS3 29 | class GFS3(pl.LightningModule): 30 | def __init__(self, 31 | args, 32 | data, 33 | augmentor, 34 | test_idx, 35 | 
encoder_momentum: float = 0.999, 36 | encoder_depth=4, 37 | head_depth=2, 38 | softmax_temperature: float = 0.5, 39 | learning_rate: float = 1e-3, 40 | momentum: float = 0.9, 41 | weight_decay: float = 1e-4 42 | ): 43 | super().__init__() 44 | self.save_hyperparameters(ignore=["data"]) 45 | self.args = args 46 | self.data = data 47 | self.test_idx = test_idx 48 | 49 | self.aug1, self.aug2, self.aug3 = augmentor 50 | self.training_step_outputs = [] 51 | 52 | # create encoders and projection heads 53 | self.encoder_q, self.encoder_k, self.pretraining_head_q, self.pretraining_head_k = self._init_encoders(args) 54 | # initialize weights 55 | self.encoder_q.apply(self._init_weights) 56 | self.pretraining_head_q.apply(self._init_weights) 57 | 58 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): 59 | param_k.data.copy_(param_q.data) # initialize 60 | param_k.requires_grad = False # not update by gradient 61 | for paramh_q, paramh_k in zip(self.pretraining_head_q.parameters(), self.pretraining_head_k.parameters()): 62 | paramh_k.data.copy_(paramh_q.data) # initialize 63 | paramh_k.requires_grad = False # not update by gradient 64 | 65 | 66 | def _init_encoders(self, args): 67 | if args.base_model == 'MLP': 68 | encoder_q = MLP(args.input_dim, args.out_dim, args.num_layers) 69 | encoder_k = MLP(args.input_dim, args.out_dim, args.num_layers) 70 | else: 71 | encoder_q = GNNModel( 72 | input_dim=args.input_dim, 73 | hidden_dim=args.hidden_dim, 74 | out_dim=args.out_dim, 75 | num_layers=args.num_layers, 76 | layer_name=args.base_model, 77 | activation_name=args.activation, 78 | dp_rate=args.dropout 79 | ) 80 | encoder_k = GNNModel( 81 | input_dim=args.input_dim, 82 | hidden_dim=args.hidden_dim, 83 | out_dim=args.out_dim, 84 | num_layers=args.num_layers, 85 | layer_name=args.base_model, 86 | activation_name=args.activation, 87 | dp_rate=args.dropout 88 | ) 89 | 90 | # Initialize pretraining_head with MLP 91 | pretraining_head_q = MLP(args.out_dim, args.out_dim) 92 | pretraining_head_k = MLP(args.out_dim, args.out_dim) 93 | 94 | return encoder_q, encoder_k, pretraining_head_q, pretraining_head_k 95 | 96 | def _init_weights(self, module): 97 | if isinstance(module, nn.Linear): 98 | torch.nn.init.xavier_uniform_(module.weight) 99 | module.bias.data.fill_(0.01) 100 | 101 | @torch.no_grad() 102 | def _momentum_update_key_encoder(self): 103 | """Momentum update of the key encoder.""" 104 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): 105 | em = self.hparams.encoder_momentum 106 | if self.current_epoch > 10: 107 | if self.args.em_scd == 1: 108 | # Schedule em 109 | max_em = 0.999 110 | em += (self.current_epoch-10) / 100 111 | em = min(em, max_em) 112 | self.log('momentum',em) 113 | param_k.data = param_k.data * em + param_q.data * (1.0 - em) 114 | for paramh_q, paramh_k in zip(self.pretraining_head_q.parameters(), self.pretraining_head_k.parameters()): 115 | em = self.hparams.encoder_momentum 116 | paramh_k.data = paramh_k.data * em + paramh_q.data * (1.0 - em) 117 | 118 | 119 | def forward(self, 120 | data: torch_geometric.data.data): 121 | x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr 122 | x1, edge_index1, edge_weight1 = self.aug1(x, edge_index, edge_weight) 123 | x2, edge_index2, edge_weight2 = self.aug2(x, edge_index, edge_weight) 124 | x3, edge_index3, edge_weight3 = self.aug3(x, edge_index, edge_weight) 125 | 126 | 127 | # GNN model 128 | z1 = self.encoder_q(x1, edge_index1, edge_weight1) 129 | z1 = 
self.pretraining_head_q(z1) 130 | #z1 = self.prediction_head(z1) 131 | z1 = nn.functional.normalize(z1, dim=1) 132 | 133 | with torch.no_grad(): 134 | z2 = self.encoder_k(x2, edge_index2, edge_weight2) 135 | #z2 = self.pretraining_head_q(z2) 136 | z2 = nn.functional.normalize(z2, dim=1) 137 | 138 | z3 = self.encoder_k(x3, edge_index3, edge_weight3) 139 | #z3 = self.pretraining_head_q(z3) 140 | z3 = nn.functional.normalize(z3, dim=1) 141 | 142 | 143 | return z1, z2, z3 144 | 145 | def _calculate_fs_loss(self, z1, z2, z3, query_idx, queue_mask=None): 146 | z1_query = z1[query_idx] 147 | z2_query = z2[query_idx] 148 | #z3_query = z3[query_idx] 149 | 150 | sim = torch.einsum("nc,bc->nb", [z2_query, z3]) 151 | #print(sim) 152 | if self.args.label_mask or self.args.khop_mask: 153 | sim *= queue_mask 154 | 155 | # select according to topk similarity 156 | k = self.args.k_shot * self.args.k_rate 157 | topnk_idx = torch.topk(sim, k=k, dim=1, largest=True).indices 158 | topk_idx = topnk_idx[:, torch.randperm(topnk_idx.size(1))[:self.args.k_shot]] 159 | #topk_idx = torch.cat((topk_idx, query_idx.view(-1, 1)), 1) 160 | 161 | true_label = 0 162 | total_match = self.args.k_shot * self.args.n_way 163 | 164 | for i in range(self.args.n_way): 165 | true_label += (self.data.y[query_idx[i]] == self.data.y[topk_idx[i]]).sum() 166 | true_ratio = true_label / total_match 167 | 168 | # n_way x k_shot x out_dim 169 | support_embeddings = z3[topk_idx] 170 | 171 | prototypes = support_embeddings.mean(dim=1) 172 | 173 | support_embeddings = support_embeddings.reshape(-1, self.args.out_dim).transpose(0, 1) 174 | 175 | # Supervised contrastive loss function 176 | loss_fs = torch.mm(z1_query, support_embeddings).div(self.args.temperature2).logsumexp(dim=1) - z1_query.mul(prototypes).div(self.args.temperature2).sum(dim=1) 177 | loss_fs = loss_fs.mean() 178 | 179 | return loss_fs, true_ratio 180 | 181 | 182 | def _calculate_mask(self, query_idx): 183 | data = self.data 184 | queue_mask = None 185 | mask = torch.ones([self.args.n_way, self.args.num_samples]) 186 | 187 | # label mask: [n_way, len(queue)] 188 | if self.args.label_mask: 189 | query_label = data.y[query_idx].view(-1,1) 190 | mask = data.y.T == query_label 191 | neg_mask = data.y.T != query_label 192 | if self.args.label_mask == 2: 193 | for i, idx in enumerate(query_idx): 194 | if int(idx) in self.test_idx: 195 | mask[i, :] = 1 196 | 197 | elif self.args.khop_mask != 0: 198 | mask = torch.zeros([self.args.n_way, self.args.num_samples]) 199 | for row, idx in enumerate(query_idx): 200 | subset, _, _, _ = k_hop_subgraph(int(idx), self.args.khop_mask, data.edge_index) 201 | mask[row, subset] = 1 202 | 203 | elif self.args.self_mask: 204 | mask = torch.ones([self.args.n_way, self.args.num_samples]) 205 | for row, idx in enumerate(query_idx): 206 | mask[row, idx] = 0 207 | 208 | return mask 209 | 210 | 211 | def training_step(self, batch, batch_idx): 212 | #print('GFS3 begins!') 213 | self._momentum_update_key_encoder() 214 | z1, z2, z3 = self(data=self.data.to(batch.device)) 215 | assert z1.requires_grad == True 216 | assert z2.requires_grad == False 217 | assert z3.requires_grad == False 218 | 219 | # calculate few-shot loss 220 | loss_fs1 = 0 221 | loss_fs2 = 0 222 | true_ratio = 0 223 | task_num = batch.size()[0] / self.args.n_way 224 | assert int(task_num) == self.args.train_task_num 225 | for i in range(int(task_num)): 226 | query_idx = batch[i*self.args.n_way:i*self.args.n_way+self.args.n_way] 227 | queue_mask = 
self._calculate_mask(query_idx).to(batch.device) 228 | if self.args.compare_mode == 'm1': 229 | loss_fs11, true_ratio1 = self._calculate_fs_loss(z1=z1, z2=z2, z3=z3, query_idx=query_idx, queue_mask=queue_mask) 230 | loss_fs22, true_ratio2 = self._calculate_fs_loss(z1=z1, z2=z3, z3=z2, query_idx=query_idx, queue_mask=queue_mask) 231 | loss_fs1 += loss_fs11 232 | loss_fs2 += loss_fs22 233 | true_ratio += (true_ratio1 + true_ratio2) / 2 234 | 235 | loss_fs1 /= task_num 236 | loss_fs2 /= task_num 237 | loss_fs = (loss_fs1 + loss_fs2) / 2 238 | true_ratio /= task_num 239 | self.log("true_ratio",true_ratio) 240 | 241 | # loss_penalty = self._calculate_label_penalty(z1=z1, z2=z2, query_idx=query_idx) 242 | # loss_penalty += self._calculate_label_penalty(z1=z2, z2=z3, query_idx=query_idx) 243 | # loss_fs += 0.5 * loss_penalty 244 | 245 | log = {"train_loss_fs": loss_fs, "train_loss_fs1": loss_fs1, "train_loss_fs2": loss_fs2} 246 | self.log_dict(log) 247 | #print(type(loss_fs)) 248 | self.training_step_outputs.append(true_ratio) 249 | 250 | return loss_fs 251 | 252 | def on_train_epoch_end(self) -> None: 253 | 254 | epoch_average_true_ratio = torch.stack(self.training_step_outputs).mean() 255 | # epoch_average_loss = torch.stack(x['loss'] for x in self.training_step_outputs).mean() 256 | logs = {'true_ratio_epoch': epoch_average_true_ratio, 'step': self.current_epoch} 257 | self.log_dict(logs) 258 | self.training_step_outputs.clear() 259 | 260 | 261 | def fs_test(self, batch, data, args, mode="val"): 262 | 263 | task, target = batch 264 | 265 | encoder_model = self.encoder_q 266 | encoder_model.eval() 267 | embeddings = encoder_model(data.x, 268 | data.edge_index, 269 | data.edge_attr).detach().cpu().numpy() 270 | 271 | test_acc_all = [] 272 | for i in range(args.task_num): 273 | task_idx = i * args.n_way 274 | random_support = torch.randperm(args.n_way * args.k_shot) 275 | random_query = torch.randperm(args.n_way * args.q_query) 276 | 277 | support_idx = task[task_idx:task_idx + args.n_way, :args.k_shot].reshape(1, -1).squeeze()[random_support].detach().cpu().numpy() 278 | query_idx = task[task_idx:task_idx + args.n_way, args.k_shot:].reshape(1, -1).squeeze()[random_query].detach().cpu().numpy() 279 | 280 | task_target = target[task_idx:task_idx + args.n_way] 281 | support_target = map_class(task_target, args.k_shot)[random_support] 282 | query_target = map_class(task_target, args.q_query)[random_query] 283 | 284 | emb_train = embeddings[support_idx] 285 | emb_test = embeddings[query_idx] 286 | 287 | if args.classifier == 'LR': 288 | clf = SKLR(solver='lbfgs', max_iter=1000, multi_class='auto').fit(emb_train, 289 | support_target.detach().numpy()) 290 | elif args.classifier == 'SVC': 291 | params = {'C': [0.001, 0.01, 0.1, 1, 10, 100, 1000]} 292 | clf = GridSearchCV(LinearSVC(), params, cv=5, scoring='accuracy', verbose=0).fit(emb_train, 293 | support_target.detach().numpy()) 294 | 295 | test_acc = clf.score(emb_test, query_target.detach().numpy()) 296 | test_acc_all.append(test_acc) 297 | 298 | final_mean = np.mean(test_acc_all) 299 | final_std = np.std(test_acc_all) 300 | final_interval = 1.96 * (final_std / np.sqrt(len(test_acc_all))) 301 | 302 | log = {mode+"_acc": final_mean, mode+"_std": final_std, mode+"_interval": final_interval} 303 | self.log_dict(log, 304 | prog_bar=True, 305 | batch_size=args.task_num, 306 | add_dataloader_idx=False 307 | ) 308 | 309 | def validation_step(self, batch, batch_idx, dataloader_idx): 310 | if dataloader_idx == 0: 311 | self.fs_test(batch, 
                         data=self.data.to(batch[0].device), args=self.args, mode="val")
        elif dataloader_idx == 1:
            self.fs_test(batch, data=self.data.to(batch[0].device), args=self.args, mode="t_val")


    def test_step(self, batch, batch_idx):
        self.fs_test(batch, data=self.data.to(batch[0].device), args=self.args, mode="test")

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(),
                               lr=self.args.lr, weight_decay=self.args.weight_decay)
        lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.args.max_epochs
        )
        # optimizer = torch.optim.SGD(
        #     self.parameters(),
        #     self.args.lr,
        #     momentum=self.hparams.momentum,
        #     weight_decay=self.args.weight_decay,
        # )
        # lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        #     optimizer,
        #     self.trainer.max_epochs,
        # )
        return [optimizer], [lr_scheduler]


--------------------------------------------------------------------------------
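The few-shot objective in `_calculate_fs_loss` is computed in a single dense line. The sketch below restates it on random tensors purely as a reading aid, to make the shapes and the prototype-versus-all-supports structure explicit; the sizes and temperature are illustrative placeholders, not model settings.
```
import torch
import torch.nn.functional as F

# Shapes mirror _calculate_fs_loss: n_way query embeddings and, for each query,
# the k_shot support embeddings retrieved by top-k similarity.
n_way, k_shot, d, temperature2 = 2, 5, 128, 1.0
z1_query = F.normalize(torch.randn(n_way, d), dim=1)          # queries from the online encoder
support = F.normalize(torch.randn(n_way, k_shot, d), dim=2)   # row i: supports selected for query i

prototypes = support.mean(dim=1)                              # [n_way, d]
support_flat = support.reshape(-1, d).transpose(0, 1)         # [d, n_way * k_shot]

# -log( exp(z_q . proto_q / t) / sum over all supports of exp(z_q . z_s / t) ), averaged over queries
loss_fs = torch.mm(z1_query, support_flat).div(temperature2).logsumexp(dim=1) \
          - z1_query.mul(prototypes).div(temperature2).sum(dim=1)
print(loss_fs.mean())
```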