├── .DS_Store ├── GraphExp ├── .ipynb_checkpoints │ ├── params-checkpoint.yaml │ └── search_space-checkpoint.json ├── datasets │ ├── __init__.py │ └── data_util.py ├── evaluator.py ├── main_graph.py ├── models │ ├── DDM.py │ ├── __init__.py │ ├── mlp_gat.py │ └── utils.py ├── run.sh ├── utils │ ├── .ipynb_checkpoints │ │ └── utils-checkpoint.py │ ├── __pycache__ │ │ ├── collect_env.cpython-36.pyc │ │ ├── collect_env.cpython-38.pyc │ │ ├── collect_env.cpython-39.pyc │ │ ├── comm.cpython-36.pyc │ │ ├── comm.cpython-38.pyc │ │ ├── comm.cpython-39.pyc │ │ ├── logger.cpython-36.pyc │ │ ├── logger.cpython-38.pyc │ │ ├── logger.cpython-39.pyc │ │ ├── misc.cpython-36.pyc │ │ ├── misc.cpython-38.pyc │ │ ├── misc.cpython-39.pyc │ │ ├── utils.cpython-38.pyc │ │ └── utils.cpython-39.pyc │ ├── algos.c │ ├── algos.cpython-38-x86_64-linux-gnu.so │ ├── algos.cpython-39-x86_64-linux-gnu.so │ ├── algos.pyx │ ├── build │ │ ├── temp.linux-x86_64-3.8 │ │ │ └── algos.o │ │ └── temp.linux-x86_64-3.9 │ │ │ └── algos.o │ ├── collect_env.py │ ├── comm.py │ ├── logger.py │ ├── metric_logger.py │ ├── misc.py │ ├── setup.py │ └── utils.py └── yamls │ └── MUTAG.yaml ├── NodeExp ├── .ipynb_checkpoints │ ├── Untitled-checkpoint.ipynb │ ├── main_node-checkpoint.py │ ├── params-checkpoint.yaml │ └── search_space-checkpoint.json ├── datasets │ ├── .ipynb_checkpoints │ │ └── data_util-checkpoint.py │ ├── __init__.py │ └── data_util.py ├── evaluator.py ├── main_node.py ├── models │ ├── DDM.py │ ├── __init__.py │ ├── mlp_gat.py │ └── utils.py ├── run.sh ├── utils │ ├── .ipynb_checkpoints │ │ └── utils-checkpoint.py │ ├── __pycache__ │ │ ├── collect_env.cpython-36.pyc │ │ ├── collect_env.cpython-38.pyc │ │ ├── collect_env.cpython-39.pyc │ │ ├── comm.cpython-36.pyc │ │ ├── comm.cpython-38.pyc │ │ ├── comm.cpython-39.pyc │ │ ├── logger.cpython-36.pyc │ │ ├── logger.cpython-38.pyc │ │ ├── logger.cpython-39.pyc │ │ ├── misc.cpython-36.pyc │ │ ├── misc.cpython-38.pyc │ │ ├── misc.cpython-39.pyc │ │ ├── utils.cpython-38.pyc │ │ └── utils.cpython-39.pyc │ ├── algos.c │ ├── algos.cpython-38-x86_64-linux-gnu.so │ ├── algos.cpython-39-x86_64-linux-gnu.so │ ├── algos.pyx │ ├── build │ │ ├── temp.linux-x86_64-3.8 │ │ │ └── algos.o │ │ └── temp.linux-x86_64-3.9 │ │ │ └── algos.o │ ├── collect_env.py │ ├── comm.py │ ├── logger.py │ ├── metric_logger.py │ ├── misc.py │ ├── setup.py │ └── utils.py └── yamls │ └── photo.yaml ├── README.md ├── framework.pdf ├── framework.png ├── json_config ├── amazoncobuycomputer.json ├── amazoncobuypho.json ├── cora-832.json ├── imdbb.json ├── mutag.json ├── pubmed.json └── reddit-b.json ├── nni_search ├── main_graph.py └── run_search.py ├── noise_com.pdf ├── noise_com.png └── requirements.txt /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeXAIS/DDM/c8dbfee8830fa61a354cc7c2e0aef1e8fb3d015a/.DS_Store -------------------------------------------------------------------------------- /GraphExp/.ipynb_checkpoints/params-checkpoint.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | data_name: MUTAG 3 | deg4feat: True 4 | 5 | DATALOADER: 6 | NUM_WORKERS: 0 7 | BATCH_SIZE: 32 8 | 9 | MODEL: 10 | num_hidden: 512 11 | num_layers: 2 12 | nhead: 4 13 | activation: prelu 14 | attn_drop: 0.1 15 | feat_drop: 0.2 16 | norm: layernorm 17 | pooler: mean 18 | beta_schedule: sigmoid 19 | beta_1: 0.0001 20 | beta_T: 0.02 21 | T: 1000 22 | 23 | SOLVER: 24 | optim_type: adamw 25 | optim_type_f: adamw 26 | 
alpha: 0.8 27 | decay: 35 28 | LR: 0.00005 29 | # LR_f: 0.005 30 | weight_decay: 0.0005 31 | # weight_decay_f: 0.0005 32 | MAX_EPOCH: 5 33 | # max_epoch_f: 50 34 | 35 | DEVICE: cuda 36 | seeds: 37 | - 11 38 | eval_T: 39 | - 50 40 | - 100 41 | - 200 42 | 43 | -------------------------------------------------------------------------------- /GraphExp/.ipynb_checkpoints/search_space-checkpoint.json: -------------------------------------------------------------------------------- 1 | { 2 | "feat_drop": { 3 | "_type": "uniform", 4 | "_value": [ 5 | 0.5, 6 | 0.9 7 | ] 8 | }, 9 | "attn_drop": { 10 | "_type": "uniform", 11 | "_value": [ 12 | 0.5, 13 | 0.9 14 | ] 15 | }, 16 | "num_hidden": { 17 | "_type": " quniform", 18 | "_value": [ 19 | 64, 20 | 512, 21 | 16 22 | ] 23 | }, 24 | "nhead": { 25 | "_type": "choice", 26 | "_value": [ 27 | 2, 28 | 4 29 | ] 30 | }, 31 | "LR": { 32 | "_type": "loguniform", 33 | "_value": [ 34 | 0.00001, 35 | 0.001 36 | ] 37 | }, 38 | "alpha": { 39 | "_type": "choice", 40 | "_value": [ 41 | 0.2, 42 | 0.3, 43 | 0.4, 44 | 0.5, 45 | 0.6, 46 | 0.7, 47 | 0.8, 48 | 0.9, 49 | 1 50 | ] 51 | }, 52 | "decay": { 53 | "_type": "quniform", 54 | "_value": [ 55 | 10, 56 | 50, 57 | 10 58 | ] 59 | }, 60 | "norm": { 61 | "_type": "choice", 62 | "_value": [ 63 | "layernorm", 64 | "batchnorm" 65 | ] 66 | }, 67 | "beta_schedule": { 68 | "_type": "choice", 69 | "_value": [ 70 | "linear", 71 | "quad", 72 | "const", 73 | "sigmoid" 74 | ] 75 | }, 76 | "beta_1": { 77 | "_type": "uniform", 78 | "_value": [ 79 | 0.00005, 80 | 0.001 81 | ] 82 | }, 83 | "beta_T": { 84 | "_type": "uniform", 85 | "_value": [ 86 | 0.005, 87 | 0.1 88 | ] 89 | }, 90 | "T": { 91 | "_type": " quniform", 92 | "_value": [ 93 | 100, 94 | 1000, 95 | 100 96 | ] 97 | } 98 | } -------------------------------------------------------------------------------- /GraphExp/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeXAIS/DDM/c8dbfee8830fa61a354cc7c2e0aef1e8fb3d015a/GraphExp/datasets/__init__.py -------------------------------------------------------------------------------- /GraphExp/datasets/data_util.py: -------------------------------------------------------------------------------- 1 | 2 | from collections import namedtuple, Counter 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | import dgl 9 | from dgl.data import ( 10 | load_data, 11 | TUDataset, 12 | CoraGraphDataset, 13 | CiteseerGraphDataset, 14 | PubmedGraphDataset 15 | ) 16 | from ogb.nodeproppred import DglNodePropPredDataset 17 | from dgl.data.ppi import PPIDataset 18 | from dgl.dataloading import GraphDataLoader 19 | 20 | from sklearn.preprocessing import StandardScaler 21 | 22 | 23 | GRAPH_DICT = { 24 | "cora": CoraGraphDataset, 25 | "citeseer": CiteseerGraphDataset, 26 | "pubmed": PubmedGraphDataset, 27 | "ogbn-arxiv": DglNodePropPredDataset 28 | } 29 | 30 | 31 | def preprocess(graph): 32 | feat = graph.ndata["feat"] 33 | graph = dgl.to_bidirected(graph) 34 | graph.ndata["feat"] = feat 35 | 36 | graph = graph.remove_self_loop().add_self_loop() 37 | graph.create_formats_() 38 | return graph 39 | 40 | 41 | def scale_feats(x): 42 | scaler = StandardScaler() 43 | feats = x.numpy() 44 | scaler.fit(feats) 45 | feats = torch.from_numpy(scaler.transform(feats)).float() 46 | return feats 47 | 48 | 49 | def load_dataset(dataset_name): 50 | assert dataset_name in GRAPH_DICT, f"Unknow dataset: {dataset_name}." 
51 | if dataset_name.startswith("ogbn"): 52 | dataset = GRAPH_DICT[dataset_name](dataset_name) 53 | else: 54 | dataset = GRAPH_DICT[dataset_name]() 55 | 56 | if dataset_name == "ogbn-arxiv": 57 | graph, labels = dataset[0] 58 | num_nodes = graph.num_nodes() 59 | 60 | split_idx = dataset.get_idx_split() 61 | train_idx, val_idx, test_idx = split_idx["train"], split_idx["valid"], split_idx["test"] 62 | graph = preprocess(graph) 63 | 64 | if not torch.is_tensor(train_idx): 65 | train_idx = torch.as_tensor(train_idx) 66 | val_idx = torch.as_tensor(val_idx) 67 | test_idx = torch.as_tensor(test_idx) 68 | 69 | feat = graph.ndata["feat"] 70 | feat = scale_feats(feat) 71 | graph.ndata["feat"] = feat 72 | 73 | train_mask = torch.full((num_nodes,), False).index_fill_(0, train_idx, True) 74 | val_mask = torch.full((num_nodes,), False).index_fill_(0, val_idx, True) 75 | test_mask = torch.full((num_nodes,), False).index_fill_(0, test_idx, True) 76 | graph.ndata["label"] = labels.view(-1) 77 | graph.ndata["train_mask"], graph.ndata["val_mask"], graph.ndata["test_mask"] = train_mask, val_mask, test_mask 78 | else: 79 | graph = dataset[0] 80 | graph = graph.remove_self_loop() 81 | graph = graph.add_self_loop() 82 | num_features = graph.ndata["feat"].shape[1] 83 | num_classes = dataset.num_classes 84 | return graph, (num_features, num_classes) 85 | 86 | 87 | def load_inductive_dataset(dataset_name): 88 | if dataset_name == "ppi": 89 | batch_size = 2 90 | # define loss function 91 | # create the dataset 92 | train_dataset = PPIDataset(mode='train') 93 | valid_dataset = PPIDataset(mode='valid') 94 | test_dataset = PPIDataset(mode='test') 95 | train_dataloader = GraphDataLoader(train_dataset, batch_size=batch_size) 96 | valid_dataloader = GraphDataLoader(valid_dataset, batch_size=batch_size, shuffle=False) 97 | test_dataloader = GraphDataLoader(test_dataset, batch_size=batch_size, shuffle=False) 98 | eval_train_dataloader = GraphDataLoader(train_dataset, batch_size=batch_size, shuffle=False) 99 | g = train_dataset[0] 100 | num_classes = train_dataset.num_labels 101 | num_features = g.ndata['feat'].shape[1] 102 | else: 103 | _args = namedtuple("dt", "dataset") 104 | dt = _args(dataset_name) 105 | batch_size = 1 106 | dataset = load_data(dt) 107 | num_classes = dataset.num_classes 108 | 109 | g = dataset[0] 110 | num_features = g.ndata["feat"].shape[1] 111 | 112 | train_mask = g.ndata['train_mask'] 113 | feat = g.ndata["feat"] 114 | feat = scale_feats(feat) 115 | g.ndata["feat"] = feat 116 | 117 | g = g.remove_self_loop() 118 | g = g.add_self_loop() 119 | 120 | train_nid = np.nonzero(train_mask.data.numpy())[0].astype(np.int64) 121 | train_g = dgl.node_subgraph(g, train_nid) 122 | train_dataloader = [train_g] 123 | valid_dataloader = [g] 124 | test_dataloader = valid_dataloader 125 | eval_train_dataloader = [train_g] 126 | 127 | return train_dataloader, valid_dataloader, test_dataloader, eval_train_dataloader, num_features, num_classes 128 | 129 | 130 | 131 | def load_graph_classification_dataset(dataset_name, deg4feat=False, PE=True): 132 | dataset_name = dataset_name.upper() 133 | dataset = TUDataset(dataset_name) 134 | graph, _ = dataset[0] 135 | 136 | if "attr" not in graph.ndata: 137 | if "node_labels" in graph.ndata and not deg4feat: 138 | print("Use node label as node features") 139 | feature_dim = 0 140 | for g, _ in dataset: 141 | feature_dim = max(feature_dim, g.ndata["node_labels"].max().item()) 142 | 143 | feature_dim += 1 144 | x_attr = [] 145 | for g, l in dataset: 146 | node_label = 
g.ndata["node_labels"].view(-1) 147 | feat = F.one_hot(node_label, num_classes=feature_dim).float() 148 | g.ndata["attr"] = feat 149 | x_attr.append(feat) 150 | x_attr = torch.cat(x_attr, dim=0).numpy() 151 | 152 | scaler = StandardScaler() 153 | scaler.fit(x_attr) 154 | for g, l in dataset: 155 | g.ndata['attr'] = torch.from_numpy(scaler.transform(g.ndata['attr'])).float() 156 | 157 | 158 | 159 | else: 160 | print("Using degree as node features") 161 | feature_dim = 0 162 | degrees = [] 163 | for g, _ in dataset: 164 | feature_dim = max(feature_dim, g.in_degrees().max().item()) 165 | degrees.extend(g.in_degrees().tolist()) 166 | MAX_DEGREES = 400 167 | 168 | oversize = 0 169 | for d, n in Counter(degrees).items(): 170 | if d > MAX_DEGREES: 171 | oversize += n 172 | # print(f"N > {MAX_DEGREES}, #NUM: {oversize}, ratio: {oversize/sum(degrees):.8f}") 173 | feature_dim = min(feature_dim, MAX_DEGREES) 174 | 175 | feature_dim += 1 176 | x_attr = [] 177 | for g, l in dataset: 178 | degrees = g.in_degrees() 179 | degrees[degrees > MAX_DEGREES] = MAX_DEGREES 180 | 181 | feat = F.one_hot(degrees, num_classes=feature_dim).float() 182 | g.ndata["attr"] = feat 183 | x_attr.append(feat) 184 | x_attr = torch.cat(x_attr, dim=0).numpy() 185 | scaler = StandardScaler() 186 | scaler.fit(x_attr) 187 | for g, l in dataset: 188 | g.ndata['attr'] = torch.from_numpy(scaler.transform(g.ndata['attr'])).float() 189 | else: 190 | print("******** Use `attr` as node features ********") 191 | feature_dim = graph.ndata["attr"].shape[1] 192 | 193 | labels = torch.tensor([x[1] for x in dataset]) 194 | 195 | num_classes = torch.max(labels).item() + 1 196 | dataset = [(g.remove_self_loop().add_self_loop(), y) for g, y in dataset] 197 | 198 | print(f"******** # Num Graphs: {len(dataset)}, # Num Feat: {feature_dim}, # Num Classes: {num_classes} ********") 199 | 200 | return dataset, (feature_dim, num_classes) 201 | -------------------------------------------------------------------------------- /GraphExp/evaluator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import argparse 4 | 5 | import shutil 6 | import time 7 | import os.path as osp 8 | 9 | import dgl 10 | from dgl.nn.pytorch.glob import SumPooling, AvgPooling, MaxPooling 11 | from dgl.dataloading import GraphDataLoader 12 | from dgl import RandomWalkPE 13 | 14 | import torch 15 | from torch.utils.data.sampler import SubsetRandomSampler 16 | import torch.nn as nn 17 | from dgl.nn.functional import edge_softmax 18 | 19 | from sklearn.model_selection import StratifiedKFold, GridSearchCV 20 | from sklearn.svm import SVC 21 | from sklearn.metrics import f1_score 22 | 23 | from utils.utils import (create_optimizer, create_pooler, set_random_seed, compute_ppr) 24 | 25 | from datasets.data_util import load_graph_classification_dataset 26 | 27 | from models import DDM 28 | 29 | from config import config as cfg 30 | import multiprocessing 31 | from multiprocessing import Pool 32 | 33 | 34 | from utils import comm 35 | from utils.collect_env import collect_env_info 36 | from utils.logger import setup_logger 37 | from utils.misc import mkdir 38 | 39 | 40 | def graph_classification_evaluation(model, T, pooler, dataloader, device, logger): 41 | model.eval() 42 | embed_list = [] 43 | head_list = [] 44 | optim_list = [] 45 | with torch.no_grad(): 46 | for t in T: 47 | x_list = [] 48 | y_list = [] 49 | for i, (batch_g, labels) in enumerate(dataloader): 50 | batch_g = batch_g.to(device) 51 | feat = 
batch_g.ndata["attr"] 52 | out = model.embed(batch_g, feat, t) 53 | out = pooler(batch_g, out) 54 | y_list.append(labels) 55 | x_list.append(out) 56 | head_list.append(1) 57 | embed_list.append(torch.cat(x_list, dim=0).cpu().numpy()) 58 | y_list = torch.cat(y_list, dim=0) 59 | embed_list = np.array(embed_list) 60 | y_list = y_list.cpu().numpy() 61 | test_f1, test_std = evaluate_graph_embeddings_using_svm(T, embed_list, y_list) 62 | logger.info(f"#Test_f1: {test_f1:.4f}±{test_std:.4f}") 63 | return test_f1 64 | 65 | 66 | def inner_func(args): 67 | T = args[0] 68 | train_index = args[1] 69 | test_index = args[2] 70 | embed_list = args[3] 71 | y_list = args[4] 72 | pred_list = [] 73 | for idx in range(len(T)): 74 | embeddings = embed_list[idx] 75 | labels = y_list 76 | x_train = embeddings[train_index] 77 | x_test = embeddings[test_index] 78 | y_train = labels[train_index] 79 | y_test = labels[test_index] 80 | params = {"C": [1e-3, 1e-2, 1e-1, 1, 10]} 81 | svc = SVC(random_state=42) 82 | clf = GridSearchCV(svc, params) 83 | clf.fit(x_train, y_train) 84 | 85 | out = clf.predict(x_test) 86 | pred_list.append(out) 87 | preds = np.stack(pred_list, axis=0) 88 | preds = torch.from_numpy(preds) 89 | preds = torch.mode(preds, dim=0)[0].long().numpy() 90 | f1 = f1_score(y_test, preds, average="micro") 91 | return f1 92 | 93 | 94 | def evaluate_graph_embeddings_using_svm(T, embed_list, y_list): 95 | kf = StratifiedKFold(n_splits=10, shuffle=True, random_state=0) 96 | process_args = [(T, train_index, test_index, embed_list, y_list) 97 | for train_index, test_index in kf.split(embed_list[0], y_list)] 98 | with Pool(10) as p: 99 | result = p.map(inner_func, process_args) 100 | test_f1 = np.mean(result) 101 | test_std = np.std(result) 102 | 103 | return test_f1, test_std 104 | -------------------------------------------------------------------------------- /GraphExp/main_graph.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File Name: main_graph.py 4 | # Author: Yang Run 5 | # Created Time: 2022-10-28 22:32 6 | # Last Modified: - 7 | import numpy as np 8 | 9 | import argparse 10 | 11 | import shutil 12 | import time 13 | import os.path as osp 14 | 15 | import dgl 16 | from dgl.nn.pytorch.glob import SumPooling, AvgPooling, MaxPooling 17 | from dgl.dataloading import GraphDataLoader 18 | from dgl import RandomWalkPE 19 | 20 | import torch 21 | from torch.utils.data.sampler import SubsetRandomSampler 22 | import torch.nn as nn 23 | from dgl.nn.functional import edge_softmax 24 | 25 | from sklearn.model_selection import StratifiedKFold, GridSearchCV 26 | from sklearn.svm import SVC 27 | from sklearn.metrics import f1_score 28 | 29 | from utils.utils import (create_optimizer, create_pooler, set_random_seed, compute_ppr) 30 | 31 | from datasets.data_util import load_graph_classification_dataset 32 | 33 | from models import DDM 34 | 35 | import multiprocessing 36 | from multiprocessing import Pool 37 | 38 | 39 | from utils import comm 40 | from utils.collect_env import collect_env_info 41 | from utils.logger import setup_logger 42 | from utils.misc import mkdir 43 | 44 | from evaluator import graph_classification_evaluation 45 | import yaml 46 | from easydict import EasyDict as edict 47 | 48 | 49 | parser = argparse.ArgumentParser(description='Graph DGL Training') 50 | parser.add_argument('--resume', '-r', action='store_true', default=False, 51 | help='resume from checkpoint') 52 | 
parser.add_argument("--local_rank", type=int, default=0, help="local rank") 53 | parser.add_argument("--seed", type=int, default=1234, help="random seed") 54 | parser.add_argument("--yaml_dir", type=str, default=None) 55 | parser.add_argument("--output_dir", type=str, default=None) 56 | parser.add_argument("--checkpoint_dir", type=str, default=None) 57 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 58 | help='manual epoch number (useful on restarts)') 59 | args = parser.parse_args() 60 | 61 | 62 | def pretrain(model, train_loader, optimizer, device, epoch, logger=None): 63 | model.train() 64 | loss_list = [] 65 | for batch in train_loader: 66 | batch_g, _ = batch 67 | batch_g = batch_g.to(device) 68 | feat = batch_g.ndata["attr"] 69 | loss, loss_dict = model(batch_g, feat) 70 | optimizer.zero_grad() 71 | loss.backward() 72 | optimizer.step() 73 | loss_list.append(loss.item()) 74 | lr = optimizer.param_groups[0]['lr'] 75 | logger.info(f"Epoch {epoch} | train_loss: {np.mean(loss_list):.4f} | lr: {lr:.6f}") 76 | 77 | 78 | def collate_fn(batch): 79 | graphs = [x[0] for x in batch] 80 | labels = [x[1] for x in batch] 81 | batch_g = dgl.batch(graphs) 82 | labels = torch.cat(labels, dim=0) 83 | return batch_g, labels 84 | 85 | 86 | def save_checkpoint(state, is_best, filename): 87 | ckp = osp.join(filename, 'checkpoint.pth.tar') 88 | # ckp = filename + "checkpoint.pth.tar" 89 | torch.save(state, ckp) 90 | if is_best: 91 | shutil.copyfile(ckp, filename+'/model_best.pth.tar') 92 | 93 | 94 | def adjust_learning_rate(optimizer, epoch, alpha, decay, lr): 95 | """Sets the learning rate to the initial LR decayed by 10 every 80 epochs""" 96 | lr = lr * (alpha ** (epoch // decay)) 97 | for param_group in optimizer.param_groups: 98 | param_group['lr'] = lr 99 | 100 | 101 | def main(cfg): 102 | best_f1 = float('-inf') 103 | best_f1_epoch = float('inf') 104 | 105 | if cfg.output_dir: 106 | mkdir(cfg.output_dir) 107 | mkdir(cfg.checkpoint_dir) 108 | 109 | logger = setup_logger("graph", cfg.output_dir, comm.get_rank(), filename='train_log.txt') 110 | logger.info("Rank of current process: {}. 
World size: {}".format(comm.get_rank(), comm.get_world_size())) 111 | logger.info("Environment info:\n" + collect_env_info()) 112 | logger.info("Command line arguments: " + str(args)) 113 | 114 | shutil.copyfile('./params.yaml', cfg.output_dir + '/params.yaml') 115 | shutil.copyfile('./main_graph.py', cfg.output_dir + '/graph.py') 116 | shutil.copyfile('./models/DDM.py', cfg.output_dir + '/DDM.py') 117 | shutil.copyfile('./models/mlp_gat.py', cfg.output_dir + '/mlp_gat.py') 118 | 119 | graphs, (num_features, num_classes) = load_graph_classification_dataset(cfg.DATA.data_name, 120 | deg4feat=cfg.DATA.deg4feat, 121 | PE=False) 122 | cfg.num_features = num_features 123 | 124 | train_idx = torch.arange(len(graphs)) 125 | train_sampler = SubsetRandomSampler(train_idx) 126 | train_loader = GraphDataLoader(graphs, sampler=train_sampler, collate_fn=collate_fn, 127 | batch_size=cfg.DATALOADER.BATCH_SIZE, pin_memory=True) 128 | eval_loader = GraphDataLoader(graphs, collate_fn=collate_fn, batch_size=len(graphs), shuffle=False) 129 | 130 | pooler = create_pooler(cfg.MODEL.pooler) 131 | 132 | acc_list = [] 133 | for i, seed in enumerate(cfg.seeds): 134 | logger.info(f'Run {i}th for seed {seed}') 135 | set_random_seed(seed) 136 | 137 | ml_cfg = cfg.MODEL 138 | ml_cfg.update({'in_dim': num_features}) 139 | model = DDM(**ml_cfg) 140 | total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 141 | logger.info('Total trainable params num : {}'.format(total_trainable_params)) 142 | model.to(cfg.DEVICE) 143 | 144 | optimizer = create_optimizer(cfg.SOLVER.optim_type, model, cfg.SOLVER.LR, cfg.SOLVER.weight_decay) 145 | 146 | start_epoch = 0 147 | if args.resume: 148 | if osp.isfile(cfg.pretrain_checkpoint_dir): 149 | logger.info("=> loading checkpoint '{}'".format(cfg.checkpoint_dir)) 150 | checkpoint = torch.load(cfg.checkpoint_dir, map_location=torch.device('cpu')) 151 | start_epoch = checkpoint['epoch'] 152 | model.load_state_dict(checkpoint['state_dict']) 153 | optimizer.load_state_dict(checkpoint['optimizer']) 154 | logger.info("=> loaded checkpoint '{}' (epoch {})" 155 | .format(cfg.checkpoint_dir, checkpoint['epoch'])) 156 | 157 | logger.info("----------Start Training----------") 158 | 159 | for epoch in range(start_epoch, cfg.SOLVER.MAX_EPOCH): 160 | adjust_learning_rate(optimizer, epoch=epoch, alpha=cfg.SOLVER.alpha, decay=cfg.SOLVER.decay, lr=cfg.SOLVER.LR) 161 | pretrain(model, train_loader, optimizer, cfg.DEVICE, epoch, logger) 162 | if ((epoch + 1) % 1 == 0) & (epoch > 1): 163 | model.eval() 164 | test_f1 = graph_classification_evaluation(model, cfg.eval_T, pooler, eval_loader, 165 | cfg.DEVICE, logger) 166 | is_best = test_f1 > best_f1 167 | if is_best: 168 | best_f1_epoch = epoch 169 | best_f1 = max(test_f1, best_f1) 170 | logger.info(f"Epoch {epoch}: get test f1 score: {test_f1: .3f}") 171 | logger.info(f"best_f1 {best_f1:.3f} at epoch {best_f1_epoch}") 172 | save_checkpoint({'epoch': epoch + 1, 173 | 'state_dict': model.state_dict(), 174 | 'best_f1': best_f1, 175 | 'optimizer': optimizer.state_dict()}, 176 | is_best, filename=cfg.checkpoint_dir) 177 | acc_list.append(best_f1) 178 | final_acc, final_acc_std = np.mean(acc_list), np.std(acc_list) 179 | logger.info((f"# final_acc: {final_acc:.4f}±{final_acc_std:.4f}")) 180 | return final_acc 181 | 182 | 183 | if __name__ == "__main__": 184 | root_dir = osp.abspath(osp.dirname(__file__)) 185 | yaml_dir = osp.join(root_dir, 'params.yaml') 186 | output_dir = osp.join(root_dir, 'log') 187 | checkpoint_dir = 
osp.join(output_dir, "checkpoint") 188 | 189 | yaml_dir = args.yaml_dir if args.yaml_dir else yaml_dir 190 | output_dir = args.output_dir if args.output_dir else output_dir 191 | checkpoint_dir = args.checkpoint_dir if args.checkpoint_dir else checkpoint_dir 192 | 193 | with open(yaml_dir, "r") as f: 194 | config = yaml.load(f, yaml.FullLoader) 195 | cfg = edict(config) 196 | 197 | cfg.output_dir, cfg.checkpoint_dir = output_dir, checkpoint_dir 198 | print(cfg) 199 | f1 = main(cfg) 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | -------------------------------------------------------------------------------- /GraphExp/models/DDM.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File Name: diffusion.py 4 | # Author: Yang Run 5 | # Created Time: 2022-10-29 17:09 6 | # Last Modified: - 7 | 8 | import sys 9 | from typing import Optional 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from torch.nn import init 15 | 16 | import math 17 | import dgl 18 | import dgl.function as fn 19 | from utils.utils import make_edge_weights 20 | from .mlp_gat import Denoising_Unet 21 | import numpy as np 22 | 23 | 24 | def extract(v, t, x_shape): 25 | """ 26 | Extract some coefficients at specified timesteps, then reshape to 27 | [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes. 28 | """ 29 | out = torch.gather(v, index=t, dim=0).float() 30 | return out.view([t.shape[0]] + [1] * (len(x_shape) - 1)) 31 | 32 | 33 | class DDM(nn.Module): 34 | def __init__( 35 | self, 36 | in_dim: int, 37 | num_hidden: int, 38 | num_layers: int, 39 | nhead: int, 40 | activation: str, 41 | feat_drop: float, 42 | attn_drop: float, 43 | norm: Optional[str], 44 | alpha_l: float = 2, 45 | beta_schedule: str = 'linear', 46 | beta_1: float = 0.0001, 47 | beta_T: float = 0.02, 48 | T: int = 1000, 49 | **kwargs 50 | 51 | ): 52 | super(DDM, self).__init__() 53 | self.T = T 54 | beta = get_beta_schedule(beta_schedule, beta_1, beta_T, T) 55 | self.register_buffer( 56 | 'betas', beta 57 | ) 58 | alphas = 1. - self.betas 59 | alphas_bar = torch.cumprod(alphas, dim=0) 60 | 61 | self.register_buffer( 62 | 'sqrt_alphas_bar', torch.sqrt(alphas_bar) 63 | ) 64 | self.register_buffer( 65 | 'sqrt_one_minus_alphas_bar', torch.sqrt(1. 
- alphas_bar) 66 | ) 67 | 68 | self.alpha_l = alpha_l 69 | assert num_hidden % nhead == 0 70 | self.net = Denoising_Unet(in_dim=in_dim, 71 | num_hidden=num_hidden, 72 | out_dim=in_dim, 73 | num_layers=num_layers, 74 | nhead=nhead, 75 | activation=activation, 76 | feat_drop=feat_drop, 77 | attn_drop=attn_drop, 78 | negative_slope=0.2, 79 | norm=norm) 80 | 81 | self.time_embedding = nn.Embedding(T, num_hidden) 82 | 83 | def forward(self, g, x): 84 | with torch.no_grad(): 85 | x = F.layer_norm(x, (x.shape[-1], )) 86 | 87 | t = torch.randint(self.T, size=(x.shape[0], ), device=x.device) 88 | x_t, time_embed, g, label_embed = self.sample_q(t, x, g) 89 | 90 | loss = self.node_denoising(x, x_t, time_embed, g) 91 | loss_item = {"loss": loss.item()} 92 | return loss, loss_item 93 | 94 | def sample_q(self, t, x, g): 95 | miu, std = x.mean(dim=0), x.std(dim=0) 96 | noise = torch.randn_like(x, device=x.device) 97 | with torch.no_grad(): 98 | noise = F.layer_norm(noise, (noise.shape[-1], )) 99 | noise = noise * std + miu 100 | noise = torch.sign(x) * torch.abs(noise) 101 | x_t = ( 102 | extract(self.sqrt_alphas_bar, t, x.shape) * x + 103 | extract(self.sqrt_one_minus_alphas_bar, t, x.shape) * noise 104 | ) 105 | time_embed = self.time_embedding(t) 106 | return x_t, time_embed, g 107 | 108 | def node_denoising(self, x, x_t, time_embed, g): 109 | out, _ = self.net(g, x_t=x_t, time_embed=time_embed) 110 | loss = sce_loss(out, x, self.alpha_l) 111 | 112 | return loss 113 | 114 | def embed(self, g, x, T): 115 | t = torch.full((1, ), T, device=x.device) 116 | with torch.no_grad(): 117 | x = F.layer_norm(x, (x.shape[-1], )) 118 | x_t, time_embed, g = self.sample_q(t, x, g) 119 | _, hidden = self.net(g, x_t=x_t, time_embed=time_embed) 120 | return hidden 121 | 122 | 123 | def loss_fn(x, y, alpha=2): 124 | x = F.normalize(x, p=2, dim=-1) 125 | y = F.normalize(y, p=2, dim=-1) 126 | 127 | loss = (1 - (x * y).sum(dim=-1)).pow_(alpha) 128 | 129 | loss = loss.mean() 130 | return loss 131 | 132 | 133 | def get_beta_schedule(beta_schedule, beta_start, beta_end, num_diffusion_timesteps): 134 | def sigmoid(x): 135 | return 1 / (np.exp(-x) + 1) 136 | 137 | if beta_schedule == "quad": 138 | betas = ( 139 | np.linspace( 140 | beta_start ** 0.5, 141 | beta_end ** 0.5, 142 | num_diffusion_timesteps, 143 | dtype=np.float64, 144 | ) 145 | ** 2 146 | ) 147 | elif beta_schedule == "linear": 148 | betas = np.linspace( 149 | beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64 150 | ) 151 | elif beta_schedule == "const": 152 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) 153 | elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 154 | betas = 1.0 / np.linspace( 155 | num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64 156 | ) 157 | elif beta_schedule == "sigmoid": 158 | betas = np.linspace(-6, 6, num_diffusion_timesteps) 159 | betas = sigmoid(betas) * (beta_end - beta_start) + beta_start 160 | else: 161 | raise NotImplementedError(beta_schedule) 162 | assert betas.shape == (num_diffusion_timesteps,) 163 | return torch.from_numpy(betas) 164 | -------------------------------------------------------------------------------- /GraphExp/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .DDM import DDM 2 | 3 | -------------------------------------------------------------------------------- /GraphExp/models/mlp_gat.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env 
python3 2 | # -*- coding: utf-8 -*- 3 | # File Name: mlp_gat.py 4 | # Author: Yang Run 5 | # Created Time: 2022-12-06 13:48 6 | # Last Modified: - 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | import dgl 12 | from dgl.nn import GINConv 13 | from dgl.nn import GATConv 14 | from dgl.nn import EGATConv 15 | import dgl.function as fn 16 | from dgl.nn.functional import edge_softmax 17 | from .utils import create_activation, create_norm 18 | 19 | 20 | def exists(x): 21 | return x is not None 22 | 23 | 24 | class Denoising_Unet(nn.Module): 25 | def __init__(self, 26 | in_dim, 27 | num_hidden, 28 | out_dim, 29 | num_layers, 30 | nhead, 31 | activation, 32 | feat_drop, 33 | attn_drop, 34 | negative_slope, 35 | norm, 36 | ): 37 | super(Denoising_Unet, self).__init__() 38 | self.out_dim = out_dim 39 | self.num_heads = nhead 40 | self.num_layers = num_layers 41 | self.num_hidden = num_hidden 42 | self.down_layers = nn.ModuleList() 43 | self.up_layers = nn.ModuleList() 44 | self.activation = activation 45 | 46 | self.mlp_in_t = MlpBlock(in_dim=in_dim, hidden_dim=num_hidden*2, out_dim=num_hidden, 47 | norm=norm, activation=activation) 48 | 49 | self.mlp_middle = MlpBlock(num_hidden, num_hidden, num_hidden, norm=norm, activation=activation) 50 | 51 | self.mlp_out = MlpBlock(num_hidden, out_dim, out_dim, norm=norm, activation=activation) 52 | 53 | self.down_layers.append(GATConv(num_hidden, num_hidden // nhead, nhead, feat_drop, attn_drop, negative_slope)) 54 | self.up_layers.append(GATConv(num_hidden, num_hidden, 1, feat_drop, attn_drop, negative_slope)) 55 | 56 | 57 | for _ in range(1, num_layers): 58 | self.down_layers.append(GATConv(num_hidden, num_hidden // nhead, nhead, feat_drop, 59 | attn_drop, negative_slope)) 60 | self.up_layers.append(GATConv(num_hidden, num_hidden // nhead, nhead, feat_drop, 61 | attn_drop, negative_slope)) 62 | self.up_layers = self.up_layers[::-1] 63 | 64 | def forward(self, g, x_t, time_embed): 65 | h_t = self.mlp_in_t(x_t) 66 | down_hidden = [] 67 | for l in range(self.num_layers): 68 | if h_t.ndim > 2: 69 | h_t = h_t + time_embed.unsqueeze(1).repeat(1, h_t.shape[1], 1) 70 | else: 71 | pass 72 | h_t = self.down_layers[l](g, h_t) 73 | h_t = h_t.flatten(1) 74 | down_hidden.append(h_t) 75 | h_middle = self.mlp_middle(h_t) 76 | 77 | h_t = h_middle 78 | out_hidden = [] 79 | for l in range(self.num_layers): 80 | h_t = h_t + down_hidden[self.num_layers - l - 1 ] 81 | if h_t.ndim > 2: 82 | h_t = h_t + time_embed.unsqueeze(1).repeat(1, h_t.shape[1], 1) 83 | else: 84 | pass 85 | h_t = self.up_layers[l](g, h_t) 86 | h_t = h_t.flatten(1) 87 | out_hidden.append(h_t) 88 | out = self.mlp_out(h_t) 89 | out_hidden = torch.cat(out_hidden, dim=-1) 90 | 91 | return out, out_hidden 92 | 93 | 94 | class Residual(nn.Module): 95 | def __init__(self, fnc): 96 | super().__init__() 97 | self.fnc = fnc 98 | 99 | def forward(self, x, *args, **kwargs): 100 | return self.fnc(x, *args, **kwargs) + x 101 | 102 | 103 | class MlpBlock(nn.Module): 104 | def __init__(self, in_dim: int, hidden_dim: int, out_dim: int, 105 | norm: str = 'layernorm', activation: str = 'prelu'): 106 | super(MlpBlock, self).__init__() 107 | self.in_proj = nn.Linear(in_dim, hidden_dim) 108 | self.res_mlp = Residual(nn.Sequential(nn.Linear(hidden_dim, hidden_dim), 109 | create_norm(norm)(hidden_dim), 110 | create_activation(activation), 111 | nn.Linear(hidden_dim, hidden_dim))) 112 | self.out_proj = nn.Linear(hidden_dim, out_dim) 113 | self.act = create_activation(activation) 114 | def 
forward(self, x): 115 | x = self.in_proj(x) 116 | x = self.res_mlp(x) 117 | x = self.out_proj(x) 118 | x = self.act(x) 119 | return x 120 | 121 | -------------------------------------------------------------------------------- /GraphExp/models/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File Name: utils.py 4 | # Author: Yang Run 5 | # Created Time: 2022-10-29 23:35 6 | # Last Modified: - 7 | 8 | import torch 9 | import torch.nn as nn 10 | from functools import partial 11 | 12 | def create_activation(name): 13 | if name == "relu": 14 | return nn.ReLU() 15 | elif name == "gelu": 16 | return nn.GELU() 17 | elif name == "prelu": 18 | return nn.PReLU() 19 | elif name is None: 20 | return nn.Identity() 21 | elif name == "elu": 22 | return nn.ELU() 23 | else: 24 | raise NotImplementedError(f"{name} is not implemented.") 25 | 26 | 27 | def create_norm(name): 28 | if name == "layernorm": 29 | return nn.LayerNorm 30 | elif name == "batchnorm": 31 | return nn.BatchNorm1d 32 | elif name == "graphnorm": 33 | return partial(NormLayer, norm_type="groupnorm") 34 | else: 35 | return nn.Identity 36 | 37 | -------------------------------------------------------------------------------- /GraphExp/run.sh: -------------------------------------------------------------------------------- 1 | python main_graph.py 2 | -------------------------------------------------------------------------------- /GraphExp/utils/.ipynb_checkpoints/utils-checkpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import random 4 | import yaml 5 | import logging 6 | from functools import partial 7 | import numpy as np 8 | 9 | import dgl 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch import optim as optim 14 | from dgl.nn.pytorch.glob import SumPooling, AvgPooling, MaxPooling 15 | from scipy.linalg import fractional_matrix_power, inv 16 | 17 | from sklearn.metrics import f1_score 18 | 19 | 20 | def compute_ppr(graph, alpha=0.2, self_loop=False): 21 | a = graph.adj().to_dense().numpy() 22 | if self_loop: 23 | a = a + np.eye(a.shape[0]) # A^ = A + I_n 24 | d = np.diag(np.sum(a, 1)) # D^ = Sigma A^_ii 25 | dinv = fractional_matrix_power(d, -0.5) # D^(-1/2) 26 | at = np.matmul(np.matmul(dinv, a), dinv) # A~ = D^(-1/2) x A^ x D^(-1/2) 27 | return alpha * inv((np.eye(a.shape[0]) - (1 - alpha) * at)) # a(I_n-(1-a)A~)^-1 28 | 29 | 30 | def accuracy(y_pred, y_true): 31 | y_true = y_true.squeeze().long().cpu().numpy() 32 | preds = y_pred.max(1)[1].long().cpu().numpy() 33 | f1 = f1_score(y_true, preds, average='micro') 34 | # correct = preds.eq(y_true).double() 35 | # correct = correct.sum().item() 36 | # return correct / len(y_true) 37 | return f1 38 | 39 | 40 | def set_random_seed(seed): 41 | random.seed(seed) 42 | np.random.seed(seed) 43 | torch.manual_seed(seed) 44 | torch.cuda.manual_seed(seed) 45 | torch.cuda.manual_seed_all(seed) 46 | torch.backends.cudnn.determinstic = True 47 | # dgl.random.seed(seed) 48 | 49 | 50 | def get_current_lr(optimizer): 51 | return optimizer.state_dict()["param_groups"][0]["lr"] 52 | 53 | 54 | def build_args(): 55 | parser = argparse.ArgumentParser(description="GAT") 56 | parser.add_argument("--seeds", type=int, nargs="+", default=[0]) 57 | parser.add_argument("--dataset", type=str, default="cora") 58 | parser.add_argument("--device", type=int, default=-1) 59 | parser.add_argument("--max_epoch", type=int, 
default=200, 60 | help="number of training epochs") 61 | parser.add_argument("--warmup_steps", type=int, default=-1) 62 | 63 | parser.add_argument("--num_heads", type=int, default=4, 64 | help="number of hidden attention heads") 65 | parser.add_argument("--num_out_heads", type=int, default=1, 66 | help="number of output attention heads") 67 | parser.add_argument("--num_layers", type=int, default=2, 68 | help="number of hidden layers") 69 | parser.add_argument("--num_hidden", type=int, default=256, 70 | help="number of hidden units") 71 | parser.add_argument("--residual", action="store_true", default=False, 72 | help="use residual connection") 73 | parser.add_argument("--in_drop", type=float, default=.2, 74 | help="input feature dropout") 75 | parser.add_argument("--attn_drop", type=float, default=.1, 76 | help="attention dropout") 77 | parser.add_argument("--norm", type=str, default=None) 78 | parser.add_argument("--lr", type=float, default=0.005, 79 | help="learning rate") 80 | parser.add_argument("--weight_decay", type=float, default=5e-4, 81 | help="weight decay") 82 | parser.add_argument("--negative_slope", type=float, default=0.2, 83 | help="the negative slope of leaky relu for GAT") 84 | parser.add_argument("--activation", type=str, default="prelu") 85 | parser.add_argument("--mask_rate", type=float, default=0.5) 86 | parser.add_argument("--drop_edge_rate", type=float, default=0.0) 87 | parser.add_argument("--replace_rate", type=float, default=0.0) 88 | 89 | parser.add_argument("--encoder", type=str, default="gat") 90 | parser.add_argument("--decoder", type=str, default="gat") 91 | parser.add_argument("--loss_fn", type=str, default="sce") 92 | parser.add_argument("--alpha_l", type=float, default=2, help="`pow`coefficient for `sce` loss") 93 | parser.add_argument("--optimizer", type=str, default="adam") 94 | 95 | parser.add_argument("--max_epoch_f", type=int, default=30) 96 | parser.add_argument("--lr_f", type=float, default=0.001, help="learning rate for evaluation") 97 | parser.add_argument("--weight_decay_f", type=float, default=0.0, help="weight decay for evaluation") 98 | parser.add_argument("--linear_prob", action="store_true", default=False) 99 | 100 | parser.add_argument("--load_model", action="store_true") 101 | parser.add_argument("--save_model", action="store_true") 102 | parser.add_argument("--use_cfg", action="store_true") 103 | parser.add_argument("--logging", action="store_true") 104 | parser.add_argument("--scheduler", action="store_true", default=False) 105 | parser.add_argument("--concat_hidden", action="store_true", default=False) 106 | 107 | # for graph classification 108 | parser.add_argument("--pooling", type=str, default="mean") 109 | parser.add_argument("--deg4feat", action="store_true", default=False, help="use node degree as input feature") 110 | parser.add_argument("--batch_size", type=int, default=32) 111 | args = parser.parse_args() 112 | return args 113 | 114 | 115 | def create_activation(name): 116 | if name == "relu": 117 | return nn.ReLU() 118 | elif name == "gelu": 119 | return nn.GELU() 120 | elif name == "prelu": 121 | return nn.PReLU() 122 | elif name is None: 123 | return nn.Identity() 124 | elif name == "elu": 125 | return nn.ELU() 126 | else: 127 | raise NotImplementedError(f"{name} is not implemented.") 128 | 129 | 130 | def create_norm(name): 131 | if name == "layernorm": 132 | return nn.LayerNorm 133 | elif name == "batchnorm": 134 | return nn.BatchNorm1d 135 | elif name == "graphnorm": 136 | return partial(NormLayer, norm_type="groupnorm") 
137 | else: 138 | return nn.Identity 139 | 140 | 141 | def create_optimizer(opt, model, lr, weight_decay, get_num_layer=None, get_layer_scale=None): 142 | opt_lower = opt.lower() 143 | 144 | parameters = model.parameters() 145 | opt_args = dict(lr=lr, weight_decay=weight_decay) 146 | 147 | opt_split = opt_lower.split("_") 148 | opt_lower = opt_split[-1] 149 | if opt_lower == "adam": 150 | optimizer = optim.Adam(parameters, **opt_args) 151 | elif opt_lower == "adamw": 152 | optimizer = optim.AdamW(parameters, **opt_args) 153 | elif opt_lower == "adadelta": 154 | optimizer = optim.Adadelta(parameters, **opt_args) 155 | elif opt_lower == "radam": 156 | optimizer = optim.RAdam(parameters, **opt_args) 157 | elif opt_lower == "sgd": 158 | opt_args["momentum"] = 0.9 159 | return optim.SGD(parameters, **opt_args) 160 | else: 161 | assert False and "Invalid optimizer" 162 | 163 | return optimizer 164 | 165 | 166 | def create_pooler(pooling): 167 | if pooling == "mean": 168 | pooler = AvgPooling() 169 | elif pooling == "max": 170 | pooler = MaxPooling() 171 | elif pooling == "sum": 172 | pooler = SumPooling() 173 | else: 174 | raise NotImplementedError 175 | return pooler 176 | 177 | 178 | 179 | # ------------------- 180 | def mask_edge(graph, mask_prob): 181 | E = graph.num_edges() 182 | 183 | mask_rates = torch.FloatTensor(np.ones(E) * mask_prob) 184 | masks = torch.bernoulli(1 - mask_rates) 185 | mask_idx = masks.nonzero().squeeze(1) 186 | return mask_idx 187 | 188 | def make_edge_weights(graph): 189 | E = graph.num_edges() 190 | weights = torch.FloatTensor(np.ones(E)) 191 | return weights 192 | 193 | def make_noisy_edge_weights(graph): 194 | E = graph.num_edges() 195 | weights = torch.FloatTensor(torch.rand(E)) 196 | return weights 197 | 198 | 199 | 200 | def drop_edge(graph, drop_rate, return_edges=False): 201 | if drop_rate <= 0: 202 | return graph 203 | 204 | n_node = graph.num_nodes() 205 | edge_mask = mask_edge(graph, drop_rate) 206 | src = graph.edges()[0] 207 | dst = graph.edges()[1] 208 | 209 | nsrc = src[edge_mask] 210 | ndst = dst[edge_mask] 211 | 212 | ng = dgl.graph((nsrc, ndst), num_nodes=n_node) 213 | ng = ng.add_self_loop() 214 | 215 | dsrc = src[~edge_mask] 216 | ddst = dst[~edge_mask] 217 | 218 | if return_edges: 219 | return ng, (dsrc, ddst) 220 | return ng 221 | 222 | 223 | def load_best_configs(args, path): 224 | with open(path, "r") as f: 225 | configs = yaml.load(f, yaml.FullLoader) 226 | 227 | if args.dataset not in configs: 228 | logging.info("Best args not found") 229 | return args 230 | 231 | logging.info("Using best configs") 232 | configs = configs[args.dataset] 233 | 234 | for k, v in configs.items(): 235 | if "lr" in k or "weight_decay" in k: 236 | v = float(v) 237 | setattr(args, k, v) 238 | print("------ Use best configs ------") 239 | return args 240 | 241 | 242 | # ------ logging ------ 243 | 244 | class TBLogger(object): 245 | def __init__(self, log_path="./logging_data", name="run"): 246 | super(TBLogger, self).__init__() 247 | 248 | if not os.path.exists(log_path): 249 | os.makedirs(log_path, exist_ok=True) 250 | 251 | self.last_step = 0 252 | self.log_path = log_path 253 | raw_name = os.path.join(log_path, name) 254 | name = raw_name 255 | for i in range(1000): 256 | name = raw_name + str(f"_{i}") 257 | if not os.path.exists(name): 258 | break 259 | self.writer = SummaryWriter(logdir=name) 260 | 261 | def note(self, metrics, step=None): 262 | if step is None: 263 | step = self.last_step 264 | for key, value in metrics.items(): 265 | 
self.writer.add_scalar(key, value, step) 266 | self.last_step = step 267 | 268 | def finish(self): 269 | self.writer.close() 270 | 271 | 272 | class NormLayer(nn.Module): 273 | def __init__(self, hidden_dim, norm_type): 274 | super().__init__() 275 | if norm_type == "batchnorm": 276 | self.norm = nn.BatchNorm1d(hidden_dim) 277 | elif norm_type == "layernorm": 278 | self.norm = nn.LayerNorm(hidden_dim) 279 | elif norm_type == "graphnorm": 280 | self.norm = norm_type 281 | self.weight = nn.Parameter(torch.ones(hidden_dim)) 282 | self.bias = nn.Parameter(torch.zeros(hidden_dim)) 283 | 284 | self.mean_scale = nn.Parameter(torch.ones(hidden_dim)) 285 | else: 286 | raise NotImplementedError 287 | 288 | def forward(self, graph, x): 289 | tensor = x 290 | if self.norm is not None and type(self.norm) != str: 291 | return self.norm(tensor) 292 | elif self.norm is None: 293 | return tensor 294 | 295 | batch_list = graph.batch_num_nodes 296 | batch_size = len(batch_list) 297 | batch_list = torch.Tensor(batch_list).long().to(tensor.device) 298 | batch_index = torch.arange(batch_size).to(tensor.device).repeat_interleave(batch_list) 299 | batch_index = batch_index.view((-1,) + (1,) * (tensor.dim() - 1)).expand_as(tensor) 300 | mean = torch.zeros(batch_size, *tensor.shape[1:]).to(tensor.device) 301 | mean = mean.scatter_add_(0, batch_index, tensor) 302 | mean = (mean.T / batch_list).T 303 | mean = mean.repeat_interleave(batch_list, dim=0) 304 | 305 | sub = tensor - mean * self.mean_scale 306 | 307 | std = torch.zeros(batch_size, *tensor.shape[1:]).to(tensor.device) 308 | std = std.scatter_add_(0, batch_index, sub.pow(2)) 309 | std = ((std.T / batch_list).T + 1e-6).sqrt() 310 | std = std.repeat_interleave(batch_list, dim=0) 311 | return self.weight * sub / std + self.bias 312 | -------------------------------------------------------------------------------- /GraphExp/utils/__pycache__/collect_env.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeXAIS/DDM/c8dbfee8830fa61a354cc7c2e0aef1e8fb3d015a/GraphExp/utils/__pycache__/collect_env.cpython-36.pyc -------------------------------------------------------------------------------- /GraphExp/utils/__pycache__/collect_env.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeXAIS/DDM/c8dbfee8830fa61a354cc7c2e0aef1e8fb3d015a/GraphExp/utils/__pycache__/collect_env.cpython-38.pyc -------------------------------------------------------------------------------- /GraphExp/utils/__pycache__/collect_env.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeXAIS/DDM/c8dbfee8830fa61a354cc7c2e0aef1e8fb3d015a/GraphExp/utils/__pycache__/collect_env.cpython-39.pyc -------------------------------------------------------------------------------- /GraphExp/utils/__pycache__/comm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeXAIS/DDM/c8dbfee8830fa61a354cc7c2e0aef1e8fb3d015a/GraphExp/utils/__pycache__/comm.cpython-36.pyc -------------------------------------------------------------------------------- /GraphExp/utils/__pycache__/comm.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeXAIS/DDM/c8dbfee8830fa61a354cc7c2e0aef1e8fb3d015a/GraphExp/utils/__pycache__/comm.cpython-38.pyc 
-------------------------------------------------------------------------------- /GraphExp/utils/__pycache__/comm.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeXAIS/DDM/c8dbfee8830fa61a354cc7c2e0aef1e8fb3d015a/GraphExp/utils/__pycache__/comm.cpython-39.pyc -------------------------------------------------------------------------------- /GraphExp/utils/__pycache__/logger.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeXAIS/DDM/c8dbfee8830fa61a354cc7c2e0aef1e8fb3d015a/GraphExp/utils/__pycache__/logger.cpython-36.pyc -------------------------------------------------------------------------------- /GraphExp/utils/__pycache__/logger.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeXAIS/DDM/c8dbfee8830fa61a354cc7c2e0aef1e8fb3d015a/GraphExp/utils/__pycache__/logger.cpython-38.pyc -------------------------------------------------------------------------------- /GraphExp/utils/__pycache__/logger.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeXAIS/DDM/c8dbfee8830fa61a354cc7c2e0aef1e8fb3d015a/GraphExp/utils/__pycache__/logger.cpython-39.pyc -------------------------------------------------------------------------------- /GraphExp/utils/__pycache__/misc.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeXAIS/DDM/c8dbfee8830fa61a354cc7c2e0aef1e8fb3d015a/GraphExp/utils/__pycache__/misc.cpython-36.pyc -------------------------------------------------------------------------------- /GraphExp/utils/__pycache__/misc.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeXAIS/DDM/c8dbfee8830fa61a354cc7c2e0aef1e8fb3d015a/GraphExp/utils/__pycache__/misc.cpython-38.pyc -------------------------------------------------------------------------------- /GraphExp/utils/__pycache__/misc.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeXAIS/DDM/c8dbfee8830fa61a354cc7c2e0aef1e8fb3d015a/GraphExp/utils/__pycache__/misc.cpython-39.pyc -------------------------------------------------------------------------------- /GraphExp/utils/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeXAIS/DDM/c8dbfee8830fa61a354cc7c2e0aef1e8fb3d015a/GraphExp/utils/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /GraphExp/utils/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeXAIS/DDM/c8dbfee8830fa61a354cc7c2e0aef1e8fb3d015a/GraphExp/utils/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /GraphExp/utils/algos.cpython-38-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeXAIS/DDM/c8dbfee8830fa61a354cc7c2e0aef1e8fb3d015a/GraphExp/utils/algos.cpython-38-x86_64-linux-gnu.so -------------------------------------------------------------------------------- 
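The `algos` helpers in this utils package are a compiled Cython extension: `algos.pyx` (its source appears a little further below) implements Floyd–Warshall all-pairs shortest paths, `algos.c` is the generated C, and the repository ships prebuilt `algos.cpython-38/-39-x86_64-linux-gnu.so` binaries plus the intermediate `build/.../algos.o` objects alongside a `utils/setup.py`. The contents of that `setup.py` are not reproduced in this excerpt, so what follows is only a minimal sketch of the standard Cython-plus-NumPy build such a module needs, and a toy call to the resulting `floyd_warshall` routine; the script name, import path, and example adjacency matrix are assumptions for illustration, not taken from the repository.

# sketch_setup.py -- hypothetical sketch, NOT the repository's utils/setup.py.
# Run inside GraphExp/utils (or NodeExp/utils) with:
#   python sketch_setup.py build_ext --inplace
import numpy
from setuptools import setup
from Cython.Build import cythonize

setup(
    name="algos",
    ext_modules=cythonize("algos.pyx"),   # algos.pyx -> algos.c -> algos.*.so
    include_dirs=[numpy.get_include()],   # needed because algos.pyx cimports numpy
)

Once built, the extension exposes `floyd_warshall`, which takes a square (unweighted) adjacency matrix and returns a distance matrix together with a matrix of intermediate nodes, with unreachable pairs encoded as 510:

# Toy usage; the import path depends on where the .so was built.
import numpy as np
import algos  # assumption: the in-place .so sits on the current path

adj = np.array([[0, 1, 0],
                [1, 0, 1],
                [0, 1, 0]], dtype=np.int64)   # square adjacency matrix
dist, path = algos.floyd_warshall(adj)        # all-pairs shortest paths; unreachable pairs -> 510
print(dist)

Since the shipped binaries are ABI-tagged for CPython 3.8 and 3.9 on x86_64 Linux only, users on any other interpreter or platform need to rebuild the extension themselves before the utils package will import it.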
/GraphExp/utils/algos.cpython-39-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeXAIS/DDM/c8dbfee8830fa61a354cc7c2e0aef1e8fb3d015a/GraphExp/utils/algos.cpython-39-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /GraphExp/utils/algos.pyx: -------------------------------------------------------------------------------- 1 | 2 | import cython 3 | from cython.parallel cimport prange, parallel 4 | cimport numpy 5 | import numpy 6 | 7 | 8 | def floyd_warshall(adjacency_matrix): 9 | 10 | (nrows, ncols) = adjacency_matrix.shape 11 | assert nrows == ncols 12 | cdef unsigned int n = nrows 13 | 14 | adj_mat_copy = adjacency_matrix.astype(long, order='C', casting='safe', copy=True) 15 | assert adj_mat_copy.flags['C_CONTIGUOUS'] 16 | cdef numpy.ndarray[long, ndim=2, mode='c'] M = adj_mat_copy 17 | cdef numpy.ndarray[long, ndim=2, mode='c'] path = -1 * numpy.ones([n, n], dtype=numpy.int64) 18 | 19 | cdef unsigned int i, j, k 20 | cdef long M_ij, M_ik, cost_ikkj 21 | cdef long* M_ptr = &M[0,0] 22 | cdef long* M_i_ptr 23 | cdef long* M_k_ptr 24 | 25 | # set unreachable nodes distance to 510 26 | for i in range(n): 27 | for j in range(n): 28 | if i == j: 29 | M[i][j] = 0 30 | elif M[i][j] == 0: 31 | M[i][j] = 510 32 | 33 | # floyed algo 34 | for k in range(n): 35 | M_k_ptr = M_ptr + n*k 36 | for i in range(n): 37 | M_i_ptr = M_ptr + n*i 38 | M_ik = M_i_ptr[k] 39 | for j in range(n): 40 | cost_ikkj = M_ik + M_k_ptr[j] 41 | M_ij = M_i_ptr[j] 42 | if M_ij > cost_ikkj: 43 | M_i_ptr[j] = cost_ikkj 44 | path[i][j] = k 45 | 46 | # set unreachable path to 510 47 | for i in range(n): 48 | for j in range(n): 49 | if M[i][j] >= 510: 50 | path[i][j] = 510 51 | M[i][j] = 510 52 | 53 | return M, path 54 | 55 | 56 | def get_all_edges(path, i, j): 57 | cdef int k = path[i][j] 58 | if k == -1: 59 | return [] 60 | else: 61 | return get_all_edges(path, i, k) + [k] + get_all_edges(path, k, j) 62 | 63 | 64 | def gen_edge_input(max_dist, path, edge_feat): 65 | 66 | (nrows, ncols) = path.shape 67 | assert nrows == ncols 68 | cdef unsigned int n = nrows 69 | cdef unsigned int max_dist_copy = max_dist 70 | 71 | path_copy = path.astype(long, order='C', casting='safe', copy=True) 72 | edge_feat_copy = edge_feat.astype(long, order='C', casting='safe', copy=True) 73 | assert path_copy.flags['C_CONTIGUOUS'] 74 | assert edge_feat_copy.flags['C_CONTIGUOUS'] 75 | 76 | cdef numpy.ndarray[long, ndim=4, mode='c'] edge_fea_all = -1 * numpy.ones([n, n, max_dist_copy, edge_feat.shape[-1]], dtype=numpy.int64) 77 | cdef unsigned int i, j, k, num_path, cur 78 | 79 | for i in range(n): 80 | for j in range(n): 81 | if i == j: 82 | continue 83 | if path_copy[i][j] == 510: 84 | continue 85 | path = [i] + get_all_edges(path_copy, i, j) + [j] 86 | num_path = len(path) - 1 87 | for k in range(num_path): 88 | edge_fea_all[i, j, k, :] = edge_feat_copy[path[k], path[k+1], :] 89 | 90 | return edge_fea_all 91 | -------------------------------------------------------------------------------- /GraphExp/utils/build/temp.linux-x86_64-3.8/algos.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeXAIS/DDM/c8dbfee8830fa61a354cc7c2e0aef1e8fb3d015a/GraphExp/utils/build/temp.linux-x86_64-3.8/algos.o -------------------------------------------------------------------------------- /GraphExp/utils/build/temp.linux-x86_64-3.9/algos.o: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeXAIS/DDM/c8dbfee8830fa61a354cc7c2e0aef1e8fb3d015a/GraphExp/utils/build/temp.linux-x86_64-3.9/algos.o -------------------------------------------------------------------------------- /GraphExp/utils/collect_env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | import PIL 3 | 4 | from torch.utils.collect_env import get_pretty_env_info 5 | 6 | 7 | def get_pil_version(): 8 | return "\n Pillow ({})".format(PIL.__version__) 9 | 10 | 11 | def collect_env_info(): 12 | env_str = get_pretty_env_info() 13 | env_str += get_pil_version() 14 | return env_str 15 | -------------------------------------------------------------------------------- /GraphExp/utils/comm.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains primitives for multi-gpu communication. 3 | This is useful when doing distributed training. 4 | """ 5 | import torch 6 | import torch.distributed as dist 7 | 8 | import functools 9 | import pickle 10 | import logging 11 | 12 | 13 | def get_world_size(): 14 | if not dist.is_available(): 15 | return 1 16 | if not dist.is_initialized(): 17 | return 1 18 | return dist.get_world_size() 19 | 20 | 21 | def get_rank(): 22 | if not dist.is_available(): 23 | return 0 24 | if not dist.is_initialized(): 25 | return 0 26 | return dist.get_rank() 27 | 28 | 29 | def is_main_process(): 30 | return get_rank() == 0 31 | 32 | 33 | def synchronize(): 34 | """ 35 | Helper function to synchronize (barrier) among all processes when 36 | using distributed training 37 | """ 38 | if not dist.is_available(): 39 | return 40 | if not dist.is_initialized(): 41 | return 42 | world_size = dist.get_world_size() 43 | if world_size == 1: 44 | return 45 | dist.barrier() 46 | 47 | 48 | def all_gather(data): 49 | """ 50 | Run all_gather on arbitrary picklable data (not necessarily tensors) 51 | Args: 52 | data: any picklable object 53 | Returns: 54 | list[data]: list of data gathered from each rank 55 | """ 56 | world_size = get_world_size() 57 | if world_size == 1: 58 | return [data] 59 | 60 | # serialized to a Tensor 61 | buffer = pickle.dumps(data) 62 | storage = torch.ByteStorage.from_buffer(buffer) 63 | tensor = torch.ByteTensor(storage).to("cuda") 64 | 65 | # obtain Tensor size of each rank 66 | local_size = torch.LongTensor([tensor.numel()]).to("cuda") 67 | size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)] 68 | dist.all_gather(size_list, local_size) 69 | size_list = [int(size.item()) for size in size_list] 70 | max_size = max(size_list) 71 | 72 | # receiving Tensor from all ranks 73 | # we pad the tensor because torch all_gather does not support 74 | # gathering tensors of different shapes 75 | tensor_list = [] 76 | for _ in size_list: 77 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda")) 78 | if local_size != max_size: 79 | padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda") 80 | tensor = torch.cat((tensor, padding), dim=0) 81 | dist.all_gather(tensor_list, tensor) 82 | 83 | data_list = [] 84 | for size, tensor in zip(size_list, tensor_list): 85 | buffer = tensor.cpu().numpy().tobytes()[:size] 86 | data_list.append(pickle.loads(buffer)) 87 | 88 | return data_list 89 | 90 | 91 | @functools.lru_cache() 92 | def _get_global_gloo_group(): 93 | """ 94 | Return a process group 
based on gloo backend, containing all the ranks 95 | The result is cached. 96 | """ 97 | if dist.get_backend() == "nccl": 98 | return dist.new_group(backend="gloo") 99 | else: 100 | return dist.group.WORLD 101 | 102 | 103 | def _serialize_to_tensor(data, group): 104 | backend = dist.get_backend(group) 105 | assert backend in ["gloo", "nccl"] 106 | device = torch.device("cpu" if backend == "gloo" else "cuda") 107 | 108 | buffer = pickle.dumps(data) 109 | if len(buffer) > 1024 ** 3: 110 | logger = logging.getLogger(__name__) 111 | logger.warning( 112 | "Rank {} trying to all-gather {:.2f} GB of data on device {}".format( 113 | get_rank(), len(buffer) / (1024 ** 3), device 114 | ) 115 | ) 116 | storage = torch.ByteStorage.from_buffer(buffer) 117 | tensor = torch.ByteTensor(storage).to(device=device) 118 | return tensor 119 | 120 | 121 | def _pad_to_largest_tensor(tensor, group): 122 | """ 123 | Returns: 124 | list[int]: size of the tensor, on each rank 125 | Tensor: padded tensor that has the max size 126 | """ 127 | world_size = dist.get_world_size(group=group) 128 | assert ( 129 | world_size >= 1 130 | ), "comm.gather/all_gather must be called from ranks within the given group!" 131 | local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device) 132 | size_list = [ 133 | torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size) 134 | ] 135 | dist.all_gather(size_list, local_size, group=group) 136 | size_list = [int(size.item()) for size in size_list] 137 | 138 | max_size = max(size_list) 139 | 140 | # we pad the tensor because torch all_gather does not support 141 | # gathering tensors of different shapes 142 | if local_size != max_size: 143 | padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device) 144 | tensor = torch.cat((tensor, padding), dim=0) 145 | return size_list, tensor 146 | 147 | 148 | def reduce_dict(input_dict, average=True): 149 | """ 150 | Args: 151 | input_dict (dict): all the values will be reduced 152 | average (bool): whether to do average or sum 153 | Reduce the values in the dictionary from all processes so that process with rank 154 | 0 has the averaged results. Returns a dict with the same fields as 155 | input_dict, after reduction. 156 | """ 157 | world_size = get_world_size() 158 | if world_size < 2: 159 | return input_dict 160 | with torch.no_grad(): 161 | names = [] 162 | values = [] 163 | # sort the keys so that they are consistent across processes 164 | for k in sorted(input_dict.keys()): 165 | names.append(k) 166 | values.append(input_dict[k]) 167 | values = torch.stack(values, dim=0) 168 | dist.reduce(values, dst=0) 169 | if dist.get_rank() == 0 and average: 170 | # only main process gets accumulated, so only divide by 171 | # world_size in this case 172 | values /= world_size 173 | reduced_dict = {k: v for k, v in zip(names, values)} 174 | return reduced_dict 175 | 176 | 177 | def gather(data, dst=0, group=None): 178 | """ 179 | Run gather on arbitrary picklable data (not necessarily tensors). 180 | 181 | Args: 182 | data: any picklable object 183 | dst (int): destination rank 184 | group: a torch process group. By default, will use a group which 185 | contains all ranks on gloo backend. 186 | 187 | Returns: 188 | list[data]: on dst, a list of data gathered from each rank. Otherwise, 189 | an empty list. 
190 | """ 191 | if get_world_size() == 1: 192 | return [data] 193 | if group is None: 194 | group = _get_global_gloo_group() 195 | if dist.get_world_size(group=group) == 1: 196 | return [data] 197 | rank = dist.get_rank(group=group) 198 | tensor = _serialize_to_tensor(data, group) 199 | size_list, tensor = _pad_to_largest_tensor(tensor, group) 200 | 201 | # receiving Tensor from all ranks 202 | if rank == dst: 203 | max_size = max(size_list) 204 | tensor_list = [ 205 | torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list 206 | ] 207 | dist.gather(tensor, tensor_list, dst=dst, group=group) 208 | 209 | data_list = [] 210 | for size, tensor in zip(size_list, tensor_list): 211 | buffer = tensor.cpu().numpy().tobytes()[:size] 212 | data_list.append(pickle.loads(buffer)) 213 | return data_list 214 | else: 215 | dist.gather(tensor, [], dst=dst, group=group) 216 | return [] 217 | -------------------------------------------------------------------------------- /GraphExp/utils/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | import logging 3 | import functools 4 | import os 5 | import sys 6 | from termcolor import colored 7 | 8 | 9 | class _ColorfulFormatter(logging.Formatter): 10 | def __init__(self, *args, **kwargs): 11 | self._root_name = kwargs.pop("root_name") + "." 12 | self._abbrev_name = kwargs.pop("abbrev_name", "") 13 | if len(self._abbrev_name): 14 | self._abbrev_name = self._abbrev_name + "." 15 | super(_ColorfulFormatter, self).__init__(*args, **kwargs) 16 | 17 | def formatMessage(self, record): 18 | record.name = record.name.replace(self._root_name, self._abbrev_name) 19 | log = super(_ColorfulFormatter, self).formatMessage(record) 20 | if record.levelno == logging.WARNING: 21 | prefix = colored("WARNING", "red", attrs=["blink"]) 22 | elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: 23 | prefix = colored("ERROR", "red", attrs=["blink", "underline"]) 24 | else: 25 | return log 26 | return prefix + " " + log 27 | 28 | 29 | @functools.lru_cache() # so that calling setup_logger multiple times won't add many handlers 30 | def setup_logger(name, save_dir, distributed_rank, filename="log.txt", color=True, abbrev_name=None): 31 | logger = logging.getLogger(name) 32 | logger.setLevel(logging.DEBUG) 33 | logger.propagate = False 34 | 35 | if abbrev_name is None: 36 | abbrev_name = "ugait" if name == "ugait" else name 37 | plain_formatter = logging.Formatter( 38 | "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S" 39 | ) 40 | 41 | # don't log results for the non-master process 42 | if distributed_rank > 0: 43 | return logger 44 | ch = logging.StreamHandler(stream=sys.stdout) 45 | ch.setLevel(logging.DEBUG) 46 | if color: 47 | formatter = _ColorfulFormatter( 48 | colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s", 49 | datefmt="%m/%d %H:%M:%S", 50 | root_name=name, 51 | abbrev_name=str(abbrev_name), 52 | ) 53 | else: 54 | formatter = plain_formatter 55 | ch.setFormatter(formatter) 56 | logger.addHandler(ch) 57 | 58 | if save_dir: 59 | fh = logging.FileHandler(os.path.join(save_dir, filename)) 60 | fh.setLevel(logging.DEBUG) 61 | fh.setFormatter(formatter) 62 | logger.addHandler(fh) 63 | return logger 64 | -------------------------------------------------------------------------------- /GraphExp/utils/metric_logger.py: 
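[Editor's illustrative example — not part of the repository] A short sketch of how the comm.py and logger.py utilities above are typically combined, mirroring the calls made in the training scripts later in this dump. The single-process setting and the ./log directory are assumptions.

import os
from utils import comm
from utils.logger import setup_logger

os.makedirs("./log", exist_ok=True)
logger = setup_logger("graph", "./log", comm.get_rank(), filename="train_log.txt")
logger.info("rank {} of {}".format(comm.get_rank(), comm.get_world_size()))

# without torch.distributed initialised, world_size is 1 and the collectives
# simply wrap the local object
gathered = comm.all_gather({"loss": 0.123})
logger.info("gathered: {}".format(gathered))  # [{'loss': 0.123}]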
-------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | import torch 3 | 4 | from collections import defaultdict, deque 5 | 6 | 7 | class SmoothedValue(object): 8 | """Track a series of values and provide access to smoothed values over a 9 | window or the global series average. 10 | """ 11 | 12 | def __init__(self, window_size=20): 13 | self.deque = deque(maxlen=window_size) 14 | self.series = [] 15 | self.total = 0.0 16 | self.count = 0 17 | 18 | def update(self, value): 19 | self.deque.append(value) 20 | self.series.append(value) 21 | self.count += 1 22 | self.total += value 23 | 24 | @property 25 | def median(self): 26 | d = torch.tensor(list(self.deque)) 27 | return d.median().item() 28 | 29 | @property 30 | def avg(self): 31 | d = torch.tensor(list(self.deque)) 32 | return d.mean().item() 33 | 34 | @property 35 | def global_avg(self): 36 | return self.total / self.count 37 | 38 | 39 | class MetricLogger(object): 40 | def __init__(self, delimiter="\t"): 41 | self.meters = defaultdict(SmoothedValue) 42 | self.delimiter = delimiter 43 | 44 | def update(self, **kwargs): 45 | for k, v in kwargs.items(): 46 | if isinstance(v, torch.Tensor): 47 | v = v.item() 48 | assert isinstance(v, (float, int)) 49 | self.meters[k].update(v) 50 | 51 | def __getattr__(self, attr): 52 | if attr in self.meters: 53 | return self.meters[attr] 54 | if attr in self.__dict__: 55 | return self.__dict__[attr] 56 | raise AttributeError("'{}' object has no attribute '{}'".format( 57 | type(self).__name__, attr)) 58 | 59 | def __str__(self): 60 | loss_str = [] 61 | for name, meter in self.meters.items(): 62 | loss_str.append( 63 | "{}: {:.4f}".format(name, meter.global_avg) 64 | ) 65 | return self.delimiter.join(loss_str) 66 | -------------------------------------------------------------------------------- /GraphExp/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Tencent, Inc. and its affiliates. All Rights Reserved. 
2 | import errno 3 | import logging 4 | import os 5 | 6 | 7 | def mkdir(path): 8 | try: 9 | os.makedirs(path) 10 | except OSError as e: 11 | if e.errno != errno.EEXIST: 12 | raise 13 | 14 | 15 | def link_file(src, target): 16 | """symbol link the source directories to target.""" 17 | if os.path.isdir(target) or os.path.isfile(target): 18 | os.remove(target) 19 | os.system('ln -s {} {}'.format(src, target)) 20 | 21 | 22 | def print_model_parameters(model,logger, only_num=True): 23 | logger.info('*****************Model Parameter*****************') 24 | if not only_num: 25 | for name, param in model.named_parameters(): 26 | logger.info(name, param.shape, param.requires_grad) 27 | total_num = sum([param.nelement() for param in model.parameters()]) 28 | logger.info('Total params num: {}'.format(total_num)) 29 | logger.info('*****************Finish Parameter****************') 30 | -------------------------------------------------------------------------------- /GraphExp/utils/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File Name: setup.py 4 | # Author: Yang Run 5 | # Created Time: 2022-11-08 00:41 6 | # Last Modified: - 7 | 8 | from distutils.core import setup 9 | from Cython.Build import cythonize 10 | import numpy as np 11 | 12 | setup( 13 | ext_modules=cythonize('algos.pyx'), 14 | include_dirs=[np.get_include()] 15 | ) 16 | -------------------------------------------------------------------------------- /GraphExp/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import random 4 | import yaml 5 | import logging 6 | from functools import partial 7 | import numpy as np 8 | 9 | import dgl 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch import optim as optim 14 | from dgl.nn.pytorch.glob import SumPooling, AvgPooling, MaxPooling 15 | from scipy.linalg import fractional_matrix_power, inv 16 | 17 | from sklearn.metrics import f1_score 18 | 19 | 20 | def compute_ppr(graph, alpha=0.2, self_loop=False): 21 | a = graph.adj().to_dense().numpy() 22 | if self_loop: 23 | a = a + np.eye(a.shape[0]) # A^ = A + I_n 24 | d = np.diag(np.sum(a, 1)) # D^ = Sigma A^_ii 25 | dinv = fractional_matrix_power(d, -0.5) # D^(-1/2) 26 | at = np.matmul(np.matmul(dinv, a), dinv) # A~ = D^(-1/2) x A^ x D^(-1/2) 27 | return alpha * inv((np.eye(a.shape[0]) - (1 - alpha) * at)) # a(I_n-(1-a)A~)^-1 28 | 29 | 30 | def accuracy(y_pred, y_true): 31 | y_true = y_true.squeeze().long().cpu().numpy() 32 | preds = y_pred.max(1)[1].long().cpu().numpy() 33 | f1 = f1_score(y_true, preds, average='micro') 34 | # correct = preds.eq(y_true).double() 35 | # correct = correct.sum().item() 36 | # return correct / len(y_true) 37 | return f1 38 | 39 | 40 | def set_random_seed(seed): 41 | random.seed(seed) 42 | np.random.seed(seed) 43 | torch.manual_seed(seed) 44 | torch.cuda.manual_seed(seed) 45 | torch.cuda.manual_seed_all(seed) 46 | torch.backends.cudnn.determinstic = True 47 | # dgl.random.seed(seed) 48 | 49 | 50 | def get_current_lr(optimizer): 51 | return optimizer.state_dict()["param_groups"][0]["lr"] 52 | 53 | 54 | def build_args(): 55 | parser = argparse.ArgumentParser(description="GAT") 56 | parser.add_argument("--seeds", type=int, nargs="+", default=[0]) 57 | parser.add_argument("--dataset", type=str, default="cora") 58 | parser.add_argument("--device", type=int, default=-1) 59 | parser.add_argument("--max_epoch", type=int, 
default=200, 60 | help="number of training epochs") 61 | parser.add_argument("--warmup_steps", type=int, default=-1) 62 | 63 | parser.add_argument("--num_heads", type=int, default=4, 64 | help="number of hidden attention heads") 65 | parser.add_argument("--num_out_heads", type=int, default=1, 66 | help="number of output attention heads") 67 | parser.add_argument("--num_layers", type=int, default=2, 68 | help="number of hidden layers") 69 | parser.add_argument("--num_hidden", type=int, default=256, 70 | help="number of hidden units") 71 | parser.add_argument("--residual", action="store_true", default=False, 72 | help="use residual connection") 73 | parser.add_argument("--in_drop", type=float, default=.2, 74 | help="input feature dropout") 75 | parser.add_argument("--attn_drop", type=float, default=.1, 76 | help="attention dropout") 77 | parser.add_argument("--norm", type=str, default=None) 78 | parser.add_argument("--lr", type=float, default=0.005, 79 | help="learning rate") 80 | parser.add_argument("--weight_decay", type=float, default=5e-4, 81 | help="weight decay") 82 | parser.add_argument("--negative_slope", type=float, default=0.2, 83 | help="the negative slope of leaky relu for GAT") 84 | parser.add_argument("--activation", type=str, default="prelu") 85 | parser.add_argument("--mask_rate", type=float, default=0.5) 86 | parser.add_argument("--drop_edge_rate", type=float, default=0.0) 87 | parser.add_argument("--replace_rate", type=float, default=0.0) 88 | 89 | parser.add_argument("--encoder", type=str, default="gat") 90 | parser.add_argument("--decoder", type=str, default="gat") 91 | parser.add_argument("--loss_fn", type=str, default="sce") 92 | parser.add_argument("--alpha_l", type=float, default=2, help="`pow`coefficient for `sce` loss") 93 | parser.add_argument("--optimizer", type=str, default="adam") 94 | 95 | parser.add_argument("--max_epoch_f", type=int, default=30) 96 | parser.add_argument("--lr_f", type=float, default=0.001, help="learning rate for evaluation") 97 | parser.add_argument("--weight_decay_f", type=float, default=0.0, help="weight decay for evaluation") 98 | parser.add_argument("--linear_prob", action="store_true", default=False) 99 | 100 | parser.add_argument("--load_model", action="store_true") 101 | parser.add_argument("--save_model", action="store_true") 102 | parser.add_argument("--use_cfg", action="store_true") 103 | parser.add_argument("--logging", action="store_true") 104 | parser.add_argument("--scheduler", action="store_true", default=False) 105 | parser.add_argument("--concat_hidden", action="store_true", default=False) 106 | 107 | # for graph classification 108 | parser.add_argument("--pooling", type=str, default="mean") 109 | parser.add_argument("--deg4feat", action="store_true", default=False, help="use node degree as input feature") 110 | parser.add_argument("--batch_size", type=int, default=32) 111 | args = parser.parse_args() 112 | return args 113 | 114 | 115 | def create_activation(name): 116 | if name == "relu": 117 | return nn.ReLU() 118 | elif name == "gelu": 119 | return nn.GELU() 120 | elif name == "prelu": 121 | return nn.PReLU() 122 | elif name is None: 123 | return nn.Identity() 124 | elif name == "elu": 125 | return nn.ELU() 126 | else: 127 | raise NotImplementedError(f"{name} is not implemented.") 128 | 129 | 130 | def create_norm(name): 131 | if name == "layernorm": 132 | return nn.LayerNorm 133 | elif name == "batchnorm": 134 | return nn.BatchNorm1d 135 | elif name == "graphnorm": 136 | return partial(NormLayer, norm_type="groupnorm") 
137 | else: 138 | return nn.Identity 139 | 140 | 141 | def create_optimizer(opt, model, lr, weight_decay, get_num_layer=None, get_layer_scale=None): 142 | opt_lower = opt.lower() 143 | 144 | parameters = model.parameters() 145 | opt_args = dict(lr=lr, weight_decay=weight_decay) 146 | 147 | opt_split = opt_lower.split("_") 148 | opt_lower = opt_split[-1] 149 | if opt_lower == "adam": 150 | optimizer = optim.Adam(parameters, **opt_args) 151 | elif opt_lower == "adamw": 152 | optimizer = optim.AdamW(parameters, **opt_args) 153 | elif opt_lower == "adadelta": 154 | optimizer = optim.Adadelta(parameters, **opt_args) 155 | elif opt_lower == "radam": 156 | optimizer = optim.RAdam(parameters, **opt_args) 157 | elif opt_lower == "sgd": 158 | opt_args["momentum"] = 0.9 159 | return optim.SGD(parameters, **opt_args) 160 | else: 161 | assert False and "Invalid optimizer" 162 | 163 | return optimizer 164 | 165 | 166 | def create_pooler(pooling): 167 | if pooling == "mean": 168 | pooler = AvgPooling() 169 | elif pooling == "max": 170 | pooler = MaxPooling() 171 | elif pooling == "sum": 172 | pooler = SumPooling() 173 | else: 174 | raise NotImplementedError 175 | return pooler 176 | 177 | 178 | 179 | # ------------------- 180 | def mask_edge(graph, mask_prob): 181 | E = graph.num_edges() 182 | 183 | mask_rates = torch.FloatTensor(np.ones(E) * mask_prob) 184 | masks = torch.bernoulli(1 - mask_rates) 185 | mask_idx = masks.nonzero().squeeze(1) 186 | return mask_idx 187 | 188 | def make_edge_weights(graph): 189 | E = graph.num_edges() 190 | weights = torch.FloatTensor(np.ones(E)) 191 | return weights 192 | 193 | def make_noisy_edge_weights(graph): 194 | E = graph.num_edges() 195 | weights = torch.FloatTensor(torch.rand(E)) 196 | return weights 197 | 198 | 199 | 200 | def drop_edge(graph, drop_rate, return_edges=False): 201 | if drop_rate <= 0: 202 | return graph 203 | 204 | n_node = graph.num_nodes() 205 | edge_mask = mask_edge(graph, drop_rate) 206 | src = graph.edges()[0] 207 | dst = graph.edges()[1] 208 | 209 | nsrc = src[edge_mask] 210 | ndst = dst[edge_mask] 211 | 212 | ng = dgl.graph((nsrc, ndst), num_nodes=n_node) 213 | ng = ng.add_self_loop() 214 | 215 | dsrc = src[~edge_mask] 216 | ddst = dst[~edge_mask] 217 | 218 | if return_edges: 219 | return ng, (dsrc, ddst) 220 | return ng 221 | 222 | 223 | def load_best_configs(args, path): 224 | with open(path, "r") as f: 225 | configs = yaml.load(f, yaml.FullLoader) 226 | 227 | if args.dataset not in configs: 228 | logging.info("Best args not found") 229 | return args 230 | 231 | logging.info("Using best configs") 232 | configs = configs[args.dataset] 233 | 234 | for k, v in configs.items(): 235 | if "lr" in k or "weight_decay" in k: 236 | v = float(v) 237 | setattr(args, k, v) 238 | print("------ Use best configs ------") 239 | return args 240 | 241 | 242 | # ------ logging ------ 243 | 244 | class TBLogger(object): 245 | def __init__(self, log_path="./logging_data", name="run"): 246 | super(TBLogger, self).__init__() 247 | 248 | if not os.path.exists(log_path): 249 | os.makedirs(log_path, exist_ok=True) 250 | 251 | self.last_step = 0 252 | self.log_path = log_path 253 | raw_name = os.path.join(log_path, name) 254 | name = raw_name 255 | for i in range(1000): 256 | name = raw_name + str(f"_{i}") 257 | if not os.path.exists(name): 258 | break 259 | self.writer = SummaryWriter(logdir=name) 260 | 261 | def note(self, metrics, step=None): 262 | if step is None: 263 | step = self.last_step 264 | for key, value in metrics.items(): 265 | 
self.writer.add_scalar(key, value, step) 266 | self.last_step = step 267 | 268 | def finish(self): 269 | self.writer.close() 270 | 271 | 272 | class NormLayer(nn.Module): 273 | def __init__(self, hidden_dim, norm_type): 274 | super().__init__() 275 | if norm_type == "batchnorm": 276 | self.norm = nn.BatchNorm1d(hidden_dim) 277 | elif norm_type == "layernorm": 278 | self.norm = nn.LayerNorm(hidden_dim) 279 | elif norm_type == "graphnorm": 280 | self.norm = norm_type 281 | self.weight = nn.Parameter(torch.ones(hidden_dim)) 282 | self.bias = nn.Parameter(torch.zeros(hidden_dim)) 283 | 284 | self.mean_scale = nn.Parameter(torch.ones(hidden_dim)) 285 | else: 286 | raise NotImplementedError 287 | 288 | def forward(self, graph, x): 289 | tensor = x 290 | if self.norm is not None and type(self.norm) != str: 291 | return self.norm(tensor) 292 | elif self.norm is None: 293 | return tensor 294 | 295 | batch_list = graph.batch_num_nodes 296 | batch_size = len(batch_list) 297 | batch_list = torch.Tensor(batch_list).long().to(tensor.device) 298 | batch_index = torch.arange(batch_size).to(tensor.device).repeat_interleave(batch_list) 299 | batch_index = batch_index.view((-1,) + (1,) * (tensor.dim() - 1)).expand_as(tensor) 300 | mean = torch.zeros(batch_size, *tensor.shape[1:]).to(tensor.device) 301 | mean = mean.scatter_add_(0, batch_index, tensor) 302 | mean = (mean.T / batch_list).T 303 | mean = mean.repeat_interleave(batch_list, dim=0) 304 | 305 | sub = tensor - mean * self.mean_scale 306 | 307 | std = torch.zeros(batch_size, *tensor.shape[1:]).to(tensor.device) 308 | std = std.scatter_add_(0, batch_index, sub.pow(2)) 309 | std = ((std.T / batch_list).T + 1e-6).sqrt() 310 | std = std.repeat_interleave(batch_list, dim=0) 311 | return self.weight * sub / std + self.bias 312 | -------------------------------------------------------------------------------- /GraphExp/yamls/MUTAG.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | data_name: MUTAG 3 | deg4feat: True 4 | 5 | DATALOADER: 6 | NUM_WORKERS: 0 7 | BATCH_SIZE: 32 8 | 9 | MODEL: 10 | num_hidden: 512 11 | num_layers: 2 12 | nhead: 4 13 | activation: prelu 14 | attn_drop: 0.1 15 | feat_drop: 0.2 16 | norm: layernorm 17 | pooler: mean 18 | beta_schedule: sigmoid 19 | beta_1: 0.000335 20 | beta_T: 0.03379 21 | T: 728 22 | 23 | SOLVER: 24 | optim_type: adamw 25 | optim_type_f: adamw 26 | alpha: 1 27 | decay: 30 28 | LR: 0.000292 29 | weight_decay: 0.0005 30 | MAX_EPOCH: 100 31 | 32 | DEVICE: cuda 33 | seeds: 34 | - 11 35 | eval_T: 36 | - 50 37 | - 100 38 | - 200 39 | 40 | -------------------------------------------------------------------------------- /NodeExp/.ipynb_checkpoints/Untitled-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [], 3 | "metadata": {}, 4 | "nbformat": 4, 5 | "nbformat_minor": 5 6 | } 7 | -------------------------------------------------------------------------------- /NodeExp/.ipynb_checkpoints/main_node-checkpoint.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File Name: main_graph.py 4 | # Author: Yang Run 5 | # Created Time: 2022-10-28 22:32 6 | # Last Modified: - 7 | import numpy as np 8 | import json 9 | 10 | import argparse 11 | 12 | import shutil 13 | import time 14 | import os.path as osp 15 | 16 | import dgl 17 | import torch 18 | import torch.nn as nn 19 | 20 | from utils.utils import (create_optimizer, 
create_pooler, set_random_seed, compute_ppr) 21 | from datasets.data_util import load_dataset 22 | from models import DDM 23 | # from config import config as cfg 24 | import multiprocessing 25 | from multiprocessing import Pool 26 | 27 | from utils import comm 28 | from utils.collect_env import collect_env_info 29 | from utils.logger import setup_logger 30 | from utils.misc import mkdir 31 | 32 | from evaluator import node_classification_evaluation 33 | import yaml 34 | import nni 35 | from easydict import EasyDict as edict 36 | 37 | 38 | parser = argparse.ArgumentParser(description='Graph DGL Training') 39 | parser.add_argument('--resume', '-r', action='store_true', default=False, 40 | help='resume from checkpoint') 41 | parser.add_argument("--local_rank", type=int, default=0, help="local rank") 42 | parser.add_argument("--seed", type=int, default=1234, help="random seed") 43 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 44 | help='manual epoch number (useful on restarts)') 45 | args = parser.parse_args() 46 | 47 | 48 | 49 | def pretrain(model, graph, feat, optimizer, epoch, logger, loss_record): 50 | logger.info("start epoch {}.".format(epoch)) 51 | model.train() 52 | loss, loss_dict = model(graph, feat) 53 | loss_record.append({'epoch': epoch, 'loss': loss.item()}) 54 | optimizer.zero_grad() 55 | loss.backward() 56 | optimizer.step() 57 | lr = optimizer.param_groups[0]['lr'] 58 | logger.info(f"# Epoch {epoch}: train_loss: {loss.item():.4f} | lr: {lr:.6f}") 59 | 60 | 61 | def save_checkpoint(state, is_best, filename): 62 | ckp = osp.join(filename, 'checkpoint.pth.tar') 63 | torch.save(state, ckp) 64 | if is_best: 65 | shutil.copyfile(ckp, filename+'/model_best.pth.tar') 66 | 67 | 68 | def adjust_learning_rate(optimizer, epoch, alpha, decay, lr): 69 | """Sets the learning rate to the initial LR decayed by 10 every 80 epochs""" 70 | lr = lr * (alpha ** (epoch // decay)) 71 | for param_group in optimizer.param_groups: 72 | param_group['lr'] = lr 73 | 74 | 75 | def main(cfg): 76 | 77 | if cfg.output_dir: 78 | mkdir(cfg.output_dir) 79 | mkdir(cfg.checkpoint_dir) 80 | 81 | logger = setup_logger("graph", cfg.output_dir, comm.get_rank(), filename='train_log.txt') 82 | logger.info("Rank of current process: {}. 
World size: {}".format(comm.get_rank(), comm.get_world_size())) 83 | logger.info("Environment info:\n" + collect_env_info()) 84 | logger.info("Command line arguments: " + str(args)) 85 | 86 | shutil.copyfile('./params.yaml', cfg.output_dir + '/params.yaml') 87 | shutil.copyfile('./main_node.py', cfg.output_dir + '/node.py') 88 | shutil.copyfile('./models/DDM.py', cfg.output_dir + '/DDM.py') 89 | shutil.copyfile('./models/mlp_gat.py', cfg.output_dir + '/mlp_gat.py') 90 | 91 | graph, (num_features, num_classes) = load_dataset(cfg.DATA.data_name) 92 | 93 | 94 | acc_list = [] 95 | for i, seed in enumerate(cfg.seeds): 96 | best_acc = float('-inf') 97 | best_estp_acc = float('-inf') 98 | best_acc_epoch = float('inf') 99 | logger.info(f'Run {i}th for seed {seed}') 100 | set_random_seed(seed) 101 | 102 | ml_cfg = cfg.MODEL 103 | ml_cfg.update({'in_dim': num_features}) 104 | model = DDM(**ml_cfg) 105 | total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 106 | logger.info('Total trainable params num : {}'.format(total_trainable_params)) 107 | model.to(cfg.DEVICE) 108 | 109 | optimizer = create_optimizer(cfg.SOLVER.optim_type, model, cfg.SOLVER.LR, cfg.SOLVER.weight_decay) 110 | 111 | start_epoch = 0 112 | if args.resume: 113 | if osp.isfile(cfg.pretrain_checkpoint_dir): 114 | logger.info("=> loading checkpoint '{}'".format(cfg.checkpoint_dir)) 115 | checkpoint = torch.load(cfg.checkpoint_dir, map_location=torch.device('cpu')) 116 | start_epoch = checkpoint['epoch'] 117 | model.load_state_dict(checkpoint['state_dict']) 118 | optimizer.load_state_dict(checkpoint['optimizer']) 119 | logger.info("=> loaded checkpoint '{}' (epoch {})" 120 | .format(cfg.checkpoint_dir, checkpoint['epoch'])) 121 | 122 | logger.info("----------Start Training----------") 123 | 124 | graph = graph.to(cfg.DEVICE) 125 | feat = graph.ndata['feat'] 126 | loss_record = [] 127 | for epoch in range(start_epoch, cfg.SOLVER.MAX_EPOCH): 128 | adjust_learning_rate(optimizer, epoch=epoch, alpha=cfg.SOLVER.alpha, decay=cfg.SOLVER.decay, lr=cfg.SOLVER.LR) 129 | pretrain(model, graph, feat, optimizer, epoch, logger, loss_record) 130 | if ((epoch + 1) % 1 == 0) & (epoch > 100): 131 | model.eval() 132 | acc = node_classification_evaluation(model, cfg.eval_T, graph, feat, num_classes, 133 | cfg.SOLVER.optim_type, 134 | cfg.SOLVER.LR_f, cfg.SOLVER.weight_decay_f, 135 | cfg.SOLVER.max_epoch_f, cfg.DEVICE, logger) 136 | nni.report_intermediate_result(acc) 137 | is_best = acc > best_acc 138 | if is_best: 139 | best_acc_epoch = epoch 140 | best_acc = max(acc, best_acc) 141 | logger.info(f"Epoch {epoch}: get test acc score: {acc: .3f}") 142 | logger.info(f"best_f1 {best_acc:.3f} at epoch {best_acc_epoch}") 143 | # save_checkpoint({'epoch': epoch + 1, 144 | # 'state_dict': model.state_dict(), 145 | # 'best_f1': best_f1, 146 | # 'optimizer': optimizer.state_dict()}, 147 | # is_best, filename=cfg.checkpoint_dir) 148 | acc_list.append(best_acc) 149 | file = './record_dm.json' 150 | with open(file, "w") as w: 151 | json.dump(loss_record, w) 152 | 153 | final_acc, final_acc_std = np.mean(acc_list), np.std(acc_list) 154 | logger.info((f"# final_acc: {final_acc:.4f}±{final_acc_std:.4f}")) 155 | return final_acc 156 | 157 | 158 | if __name__ == "__main__": 159 | root_dir = osp.abspath(osp.dirname(__file__)) 160 | yaml_dir = osp.join(root_dir, 'params.yaml') 161 | output_dir = osp.join(root_dir, 'log-recordloss') 162 | checkpoint_dir = osp.join(output_dir, "checkpoint") 163 | 164 | with open(yaml_dir, "r") as f: 165 | config 
= yaml.load(f, yaml.FullLoader) 166 | cfg = edict(config) 167 | 168 | cfg.output_dir, cfg.checkpoint_dir = output_dir, checkpoint_dir 169 | optimized_params = nni.get_next_parameter() 170 | # optimized_params = {} 171 | SOLVER_params = {} 172 | for key, value in cfg.SOLVER.items(): 173 | param_type = type(value) 174 | sp = optimized_params.get(key, value) 175 | if type(sp) == float and param_type == int: 176 | sp = int(sp) 177 | SOLVER_params[key] = sp 178 | cfg.SOLVER.update(SOLVER_params) 179 | MODEL_params = {} 180 | for key, value in cfg.MODEL.items(): 181 | param_type = type(value) 182 | sp = optimized_params.get(key, value) 183 | if type(sp) == float and param_type == int: 184 | sp = int(sp) 185 | MODEL_params[key] = sp 186 | cfg.MODEL.update(MODEL_params) 187 | 188 | print('---new cfg------') 189 | print(cfg) 190 | f1 = main(cfg) 191 | nni.report_final_result(f1) 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | -------------------------------------------------------------------------------- /NodeExp/.ipynb_checkpoints/params-checkpoint.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | data_name: MUTAG 3 | deg4feat: True 4 | 5 | DATALOADER: 6 | NUM_WORKERS: 0 7 | BATCH_SIZE: 32 8 | 9 | MODEL: 10 | num_hidden: 512 11 | num_layers: 2 12 | nhead: 4 13 | activation: prelu 14 | attn_drop: 0.1 15 | feat_drop: 0.2 16 | norm: layernorm 17 | pooler: mean 18 | beta_schedule: sigmoid 19 | beta_1: 0.0001 20 | beta_T: 0.02 21 | T: 1000 22 | 23 | SOLVER: 24 | optim_type: adamw 25 | optim_type_f: adamw 26 | alpha: 0.8 27 | decay: 35 28 | LR: 0.00005 29 | # LR_f: 0.005 30 | weight_decay: 0.0005 31 | # weight_decay_f: 0.0005 32 | MAX_EPOCH: 5 33 | # max_epoch_f: 50 34 | 35 | DEVICE: cuda 36 | seeds: 37 | - 11 38 | eval_T: 39 | - 50 40 | - 100 41 | - 200 42 | 43 | -------------------------------------------------------------------------------- /NodeExp/.ipynb_checkpoints/search_space-checkpoint.json: -------------------------------------------------------------------------------- 1 | { 2 | "feat_drop": { 3 | "_type": "uniform", 4 | "_value": [ 5 | 0.5, 6 | 0.9 7 | ] 8 | }, 9 | "attn_drop": { 10 | "_type": "uniform", 11 | "_value": [ 12 | 0.5, 13 | 0.9 14 | ] 15 | }, 16 | "num_hidden": { 17 | "_type": " quniform", 18 | "_value": [ 19 | 64, 20 | 512, 21 | 16 22 | ] 23 | }, 24 | "nhead": { 25 | "_type": "choice", 26 | "_value": [ 27 | 2, 28 | 4 29 | ] 30 | }, 31 | "LR": { 32 | "_type": "loguniform", 33 | "_value": [ 34 | 0.00001, 35 | 0.001 36 | ] 37 | }, 38 | "alpha": { 39 | "_type": "choice", 40 | "_value": [ 41 | 0.2, 42 | 0.3, 43 | 0.4, 44 | 0.5, 45 | 0.6, 46 | 0.7, 47 | 0.8, 48 | 0.9, 49 | 1 50 | ] 51 | }, 52 | "decay": { 53 | "_type": "quniform", 54 | "_value": [ 55 | 10, 56 | 50, 57 | 10 58 | ] 59 | }, 60 | "norm": { 61 | "_type": "choice", 62 | "_value": [ 63 | "layernorm", 64 | "batchnorm" 65 | ] 66 | }, 67 | "beta_schedule": { 68 | "_type": "choice", 69 | "_value": [ 70 | "linear", 71 | "quad", 72 | "const", 73 | "sigmoid" 74 | ] 75 | }, 76 | "beta_1": { 77 | "_type": "uniform", 78 | "_value": [ 79 | 0.00005, 80 | 0.001 81 | ] 82 | }, 83 | "beta_T": { 84 | "_type": "uniform", 85 | "_value": [ 86 | 0.005, 87 | 0.1 88 | ] 89 | }, 90 | "T": { 91 | "_type": " quniform", 92 | "_value": [ 93 | 100, 94 | 1000, 95 | 100 96 | ] 97 | } 98 | } -------------------------------------------------------------------------------- /NodeExp/datasets/.ipynb_checkpoints/data_util-checkpoint.py: 
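[Editor's illustrative example — not part of the repository] A minimal sketch of the NNI parameter-merging logic used by the checkpoint script above: values suggested by nni.get_next_parameter() override the YAML defaults section by section, with floats cast back to int where the default was an int. The hard-coded dictionary below stands in for an actual NNI suggestion and is purely hypothetical.

import yaml
from easydict import EasyDict as edict

with open("params.yaml", "r") as f:
    cfg = edict(yaml.load(f, yaml.FullLoader))

# stand-in for nni.get_next_parameter(); quniform samples arrive as floats
suggested = {"LR": 0.0003, "T": 700.0, "num_hidden": 256.0}

for section in ("SOLVER", "MODEL"):
    merged = {}
    for key, default in cfg[section].items():
        value = suggested.get(key, default)
        if isinstance(value, float) and isinstance(default, int):
            value = int(value)  # restore integer hyperparameters such as T or num_hidden
        merged[key] = value
    cfg[section].update(merged)

print(cfg.SOLVER.LR, cfg.MODEL.T, cfg.MODEL.num_hidden)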
-------------------------------------------------------------------------------- 1 | 2 | from collections import namedtuple, Counter 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | import dgl 9 | from dgl.data import ( 10 | load_data, 11 | TUDataset, 12 | CoraGraphDataset, 13 | CiteseerGraphDataset, 14 | PubmedGraphDataset 15 | ) 16 | from dgl.data import AmazonCoBuyPhotoDataset, AmazonCoBuyComputerDataset 17 | from ogb.nodeproppred import DglNodePropPredDataset 18 | from dgl.data.ppi import PPIDataset 19 | from dgl.dataloading import GraphDataLoader 20 | 21 | from sklearn.preprocessing import StandardScaler 22 | 23 | 24 | GRAPH_DICT = { 25 | "cora": CoraGraphDataset, 26 | "citeseer": CiteseerGraphDataset, 27 | "pubmed": PubmedGraphDataset, 28 | "ogbn-arxiv": DglNodePropPredDataset, 29 | "photo": AmazonCoBuyPhotoDataset, 30 | "comp": AmazonCoBuyComputerDataset 31 | } 32 | 33 | 34 | def preprocess(graph): 35 | feat = graph.ndata["feat"] 36 | graph = dgl.to_bidirected(graph) 37 | graph.ndata["feat"] = feat 38 | 39 | graph = graph.remove_self_loop().add_self_loop() 40 | graph.create_formats_() 41 | return graph 42 | 43 | 44 | def scale_feats(x): 45 | scaler = StandardScaler() 46 | feats = x.numpy() 47 | scaler.fit(feats) 48 | feats = torch.from_numpy(scaler.transform(feats)).float() 49 | return feats 50 | 51 | 52 | def load_dataset(dataset_name): 53 | cograph = ['photo', 'comp', 'cs', 'physics'] 54 | assert dataset_name in GRAPH_DICT, f"Unknow dataset: {dataset_name}." 55 | if dataset_name.startswith("ogbn"): 56 | dataset = GRAPH_DICT[dataset_name](dataset_name) 57 | else: 58 | dataset = GRAPH_DICT[dataset_name]() 59 | 60 | if dataset_name == "ogbn-arxiv": 61 | graph, labels = dataset[0] 62 | num_nodes = graph.num_nodes() 63 | 64 | split_idx = dataset.get_idx_split() 65 | train_idx, val_idx, test_idx = split_idx["train"], split_idx["valid"], split_idx["test"] 66 | graph = preprocess(graph) 67 | 68 | if not torch.is_tensor(train_idx): 69 | train_idx = torch.as_tensor(train_idx) 70 | val_idx = torch.as_tensor(val_idx) 71 | test_idx = torch.as_tensor(test_idx) 72 | 73 | feat = graph.ndata["feat"] 74 | feat = scale_feats(feat) 75 | graph.ndata["feat"] = feat 76 | 77 | train_mask = torch.full((num_nodes,), False).index_fill_(0, train_idx, True) 78 | val_mask = torch.full((num_nodes,), False).index_fill_(0, val_idx, True) 79 | test_mask = torch.full((num_nodes,), False).index_fill_(0, test_idx, True) 80 | graph.ndata["label"] = labels.view(-1) 81 | graph.ndata["train_mask"], graph.ndata["val_mask"], graph.ndata["test_mask"] = train_mask, val_mask, test_mask 82 | if dataset_name in cograph: 83 | graph = dataset[0] 84 | graph = graph.remove_self_loop() 85 | graph = graph.add_self_loop() 86 | train_ratio = 0.1 87 | val_ratio = 0.1 88 | test_ratio = 0.8 89 | 90 | N = graph.number_of_nodes() 91 | train_num = int(N * train_ratio) 92 | val_num = int(N * (train_ratio + val_ratio)) 93 | 94 | idx = np.arange(N) 95 | np.random.shuffle(idx) 96 | 97 | train_idx = idx[:train_num] 98 | val_idx = idx[train_num:val_num] 99 | test_idx = idx[val_num:] 100 | 101 | train_idx = torch.tensor(train_idx) 102 | val_idx = torch.tensor(val_idx) 103 | test_idx = torch.tensor(test_idx) 104 | train_mask = torch.full((N,), False).index_fill_(0, train_idx, True) 105 | val_mask = torch.full((N,), False).index_fill_(0, val_idx, True) 106 | test_mask = torch.full((N,), False).index_fill_(0, test_idx, True) 107 | graph.ndata["train_mask"], graph.ndata["val_mask"], graph.ndata["test_mask"] = 
train_mask, val_mask, test_mask 108 | 109 | else: 110 | graph = dataset[0] 111 | # feat = graph.ndata["feat"] 112 | # feat = scale_feats(feat) 113 | # graph.ndata["feat"] = feat 114 | graph = graph.remove_self_loop() 115 | graph = graph.add_self_loop() 116 | num_features = graph.ndata["feat"].shape[1] 117 | num_classes = dataset.num_classes 118 | return graph, (num_features, num_classes) 119 | 120 | 121 | def load_inductive_dataset(dataset_name): 122 | if dataset_name == "ppi": 123 | batch_size = 2 124 | # define loss function 125 | # create the dataset 126 | train_dataset = PPIDataset(mode='train') 127 | valid_dataset = PPIDataset(mode='valid') 128 | test_dataset = PPIDataset(mode='test') 129 | train_dataloader = GraphDataLoader(train_dataset, batch_size=batch_size) 130 | valid_dataloader = GraphDataLoader(valid_dataset, batch_size=batch_size, shuffle=False) 131 | test_dataloader = GraphDataLoader(test_dataset, batch_size=batch_size, shuffle=False) 132 | eval_train_dataloader = GraphDataLoader(train_dataset, batch_size=batch_size, shuffle=False) 133 | g = train_dataset[0] 134 | num_classes = train_dataset.num_labels 135 | num_features = g.ndata['feat'].shape[1] 136 | else: 137 | _args = namedtuple("dt", "dataset") 138 | dt = _args(dataset_name) 139 | batch_size = 1 140 | dataset = load_data(dt) 141 | num_classes = dataset.num_classes 142 | 143 | g = dataset[0] 144 | num_features = g.ndata["feat"].shape[1] 145 | 146 | train_mask = g.ndata['train_mask'] 147 | feat = g.ndata["feat"] 148 | feat = scale_feats(feat) 149 | g.ndata["feat"] = feat 150 | 151 | g = g.remove_self_loop() 152 | g = g.add_self_loop() 153 | 154 | train_nid = np.nonzero(train_mask.data.numpy())[0].astype(np.int64) 155 | train_g = dgl.node_subgraph(g, train_nid) 156 | train_dataloader = [train_g] 157 | valid_dataloader = [g] 158 | test_dataloader = valid_dataloader 159 | eval_train_dataloader = [train_g] 160 | 161 | return train_dataloader, valid_dataloader, test_dataloader, eval_train_dataloader, num_features, num_classes 162 | 163 | 164 | 165 | def load_graph_classification_dataset(dataset_name, deg4feat=False, PE=True): 166 | dataset_name = dataset_name.upper() 167 | dataset = TUDataset(dataset_name) 168 | graph, _ = dataset[0] 169 | 170 | if "attr" not in graph.ndata: 171 | if "node_labels" in graph.ndata and not deg4feat: 172 | print("Use node label as node features") 173 | feature_dim = 0 174 | for g, _ in dataset: 175 | feature_dim = max(feature_dim, g.ndata["node_labels"].max().item()) 176 | 177 | feature_dim += 1 178 | x_attr = [] 179 | for g, l in dataset: 180 | node_label = g.ndata["node_labels"].view(-1) 181 | feat = F.one_hot(node_label, num_classes=feature_dim).float() 182 | g.ndata["attr"] = feat 183 | x_attr.append(feat) 184 | x_attr = torch.cat(x_attr, dim=0).numpy() 185 | 186 | scaler = StandardScaler() 187 | scaler.fit(x_attr) 188 | for g, l in dataset: 189 | g.ndata['attr'] = torch.from_numpy(scaler.transform(g.ndata['attr'])).float() 190 | 191 | 192 | 193 | else: 194 | print("Using degree as node features") 195 | feature_dim = 0 196 | degrees = [] 197 | for g, _ in dataset: 198 | feature_dim = max(feature_dim, g.in_degrees().max().item()) 199 | degrees.extend(g.in_degrees().tolist()) 200 | MAX_DEGREES = 400 201 | 202 | oversize = 0 203 | for d, n in Counter(degrees).items(): 204 | if d > MAX_DEGREES: 205 | oversize += n 206 | # print(f"N > {MAX_DEGREES}, #NUM: {oversize}, ratio: {oversize/sum(degrees):.8f}") 207 | feature_dim = min(feature_dim, MAX_DEGREES) 208 | 209 | feature_dim += 1 210 | x_attr = 
[] 211 | for g, l in dataset: 212 | degrees = g.in_degrees() 213 | degrees[degrees > MAX_DEGREES] = MAX_DEGREES 214 | 215 | feat = F.one_hot(degrees, num_classes=feature_dim).float() 216 | g.ndata["attr"] = feat 217 | x_attr.append(feat) 218 | x_attr = torch.cat(x_attr, dim=0).numpy() 219 | scaler = StandardScaler() 220 | scaler.fit(x_attr) 221 | for g, l in dataset: 222 | g.ndata['attr'] = torch.from_numpy(scaler.transform(g.ndata['attr'])).float() 223 | else: 224 | print("******** Use `attr` as node features ********") 225 | feature_dim = graph.ndata["attr"].shape[1] 226 | 227 | labels = torch.tensor([x[1] for x in dataset]) 228 | 229 | num_classes = torch.max(labels).item() + 1 230 | dataset = [(g.remove_self_loop().add_self_loop(), y) for g, y in dataset] 231 | 232 | print(f"******** # Num Graphs: {len(dataset)}, # Num Feat: {feature_dim}, # Num Classes: {num_classes} ********") 233 | 234 | return dataset, (feature_dim, num_classes) 235 | -------------------------------------------------------------------------------- /NodeExp/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeXAIS/DDM/c8dbfee8830fa61a354cc7c2e0aef1e8fb3d015a/NodeExp/datasets/__init__.py -------------------------------------------------------------------------------- /NodeExp/datasets/data_util.py: -------------------------------------------------------------------------------- 1 | 2 | from collections import namedtuple, Counter 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | import dgl 9 | from dgl.data import ( 10 | load_data, 11 | TUDataset, 12 | CoraGraphDataset, 13 | CiteseerGraphDataset, 14 | PubmedGraphDataset 15 | ) 16 | from dgl.data import AmazonCoBuyPhotoDataset, AmazonCoBuyComputerDataset 17 | from ogb.nodeproppred import DglNodePropPredDataset 18 | from dgl.data.ppi import PPIDataset 19 | from dgl.dataloading import GraphDataLoader 20 | 21 | from sklearn.preprocessing import StandardScaler 22 | 23 | 24 | GRAPH_DICT = { 25 | "cora": CoraGraphDataset, 26 | "citeseer": CiteseerGraphDataset, 27 | "pubmed": PubmedGraphDataset, 28 | "ogbn-arxiv": DglNodePropPredDataset, 29 | "photo": AmazonCoBuyPhotoDataset, 30 | "comp": AmazonCoBuyComputerDataset 31 | } 32 | 33 | 34 | def preprocess(graph): 35 | feat = graph.ndata["feat"] 36 | graph = dgl.to_bidirected(graph) 37 | graph.ndata["feat"] = feat 38 | 39 | graph = graph.remove_self_loop().add_self_loop() 40 | graph.create_formats_() 41 | return graph 42 | 43 | 44 | def scale_feats(x): 45 | scaler = StandardScaler() 46 | feats = x.numpy() 47 | scaler.fit(feats) 48 | feats = torch.from_numpy(scaler.transform(feats)).float() 49 | return feats 50 | 51 | 52 | def load_dataset(dataset_name): 53 | cograph = ['photo', 'comp', 'cs', 'physics'] 54 | assert dataset_name in GRAPH_DICT, f"Unknow dataset: {dataset_name}." 
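# (Editor's note: illustrative comment added to the listing, not in the original file.)
# The branches below build the train/val/test node masks: "ogbn-arxiv" uses the
# official split from dataset.get_idx_split(), while the co-purchase graphs
# ("photo", "comp") receive a random 10%/10%/80% node split stored in graph.ndata.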
55 | if dataset_name.startswith("ogbn"): 56 | dataset = GRAPH_DICT[dataset_name](dataset_name) 57 | else: 58 | dataset = GRAPH_DICT[dataset_name]() 59 | 60 | if dataset_name == "ogbn-arxiv": 61 | graph, labels = dataset[0] 62 | num_nodes = graph.num_nodes() 63 | 64 | split_idx = dataset.get_idx_split() 65 | train_idx, val_idx, test_idx = split_idx["train"], split_idx["valid"], split_idx["test"] 66 | graph = preprocess(graph) 67 | 68 | if not torch.is_tensor(train_idx): 69 | train_idx = torch.as_tensor(train_idx) 70 | val_idx = torch.as_tensor(val_idx) 71 | test_idx = torch.as_tensor(test_idx) 72 | 73 | feat = graph.ndata["feat"] 74 | feat = scale_feats(feat) 75 | graph.ndata["feat"] = feat 76 | 77 | train_mask = torch.full((num_nodes,), False).index_fill_(0, train_idx, True) 78 | val_mask = torch.full((num_nodes,), False).index_fill_(0, val_idx, True) 79 | test_mask = torch.full((num_nodes,), False).index_fill_(0, test_idx, True) 80 | graph.ndata["label"] = labels.view(-1) 81 | graph.ndata["train_mask"], graph.ndata["val_mask"], graph.ndata["test_mask"] = train_mask, val_mask, test_mask 82 | if dataset_name in cograph: 83 | graph = dataset[0] 84 | graph = graph.remove_self_loop() 85 | graph = graph.add_self_loop() 86 | train_ratio = 0.1 87 | val_ratio = 0.1 88 | test_ratio = 0.8 89 | 90 | N = graph.number_of_nodes() 91 | train_num = int(N * train_ratio) 92 | val_num = int(N * (train_ratio + val_ratio)) 93 | 94 | idx = np.arange(N) 95 | np.random.shuffle(idx) 96 | 97 | train_idx = idx[:train_num] 98 | val_idx = idx[train_num:val_num] 99 | test_idx = idx[val_num:] 100 | 101 | train_idx = torch.tensor(train_idx) 102 | val_idx = torch.tensor(val_idx) 103 | test_idx = torch.tensor(test_idx) 104 | train_mask = torch.full((N,), False).index_fill_(0, train_idx, True) 105 | val_mask = torch.full((N,), False).index_fill_(0, val_idx, True) 106 | test_mask = torch.full((N,), False).index_fill_(0, test_idx, True) 107 | graph.ndata["train_mask"], graph.ndata["val_mask"], graph.ndata["test_mask"] = train_mask, val_mask, test_mask 108 | 109 | else: 110 | graph = dataset[0] 111 | graph = graph.remove_self_loop() 112 | graph = graph.add_self_loop() 113 | num_features = graph.ndata["feat"].shape[1] 114 | num_classes = dataset.num_classes 115 | return graph, (num_features, num_classes) 116 | 117 | 118 | def load_inductive_dataset(dataset_name): 119 | if dataset_name == "ppi": 120 | batch_size = 2 121 | # define loss function 122 | # create the dataset 123 | train_dataset = PPIDataset(mode='train') 124 | valid_dataset = PPIDataset(mode='valid') 125 | test_dataset = PPIDataset(mode='test') 126 | train_dataloader = GraphDataLoader(train_dataset, batch_size=batch_size) 127 | valid_dataloader = GraphDataLoader(valid_dataset, batch_size=batch_size, shuffle=False) 128 | test_dataloader = GraphDataLoader(test_dataset, batch_size=batch_size, shuffle=False) 129 | eval_train_dataloader = GraphDataLoader(train_dataset, batch_size=batch_size, shuffle=False) 130 | g = train_dataset[0] 131 | num_classes = train_dataset.num_labels 132 | num_features = g.ndata['feat'].shape[1] 133 | else: 134 | _args = namedtuple("dt", "dataset") 135 | dt = _args(dataset_name) 136 | batch_size = 1 137 | dataset = load_data(dt) 138 | num_classes = dataset.num_classes 139 | 140 | g = dataset[0] 141 | num_features = g.ndata["feat"].shape[1] 142 | 143 | train_mask = g.ndata['train_mask'] 144 | feat = g.ndata["feat"] 145 | feat = scale_feats(feat) 146 | g.ndata["feat"] = feat 147 | 148 | g = g.remove_self_loop() 149 | g = g.add_self_loop() 150 | 
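# (Editor's note: illustrative comment added to the listing, not in the original file.)
# For the non-PPI inductive datasets, training uses only the subgraph induced by
# the train_mask nodes (dgl.node_subgraph below), while validation and testing
# run on the full graph.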
151 | train_nid = np.nonzero(train_mask.data.numpy())[0].astype(np.int64) 152 | train_g = dgl.node_subgraph(g, train_nid) 153 | train_dataloader = [train_g] 154 | valid_dataloader = [g] 155 | test_dataloader = valid_dataloader 156 | eval_train_dataloader = [train_g] 157 | 158 | return train_dataloader, valid_dataloader, test_dataloader, eval_train_dataloader, num_features, num_classes 159 | 160 | 161 | 162 | def load_graph_classification_dataset(dataset_name, deg4feat=False, PE=True): 163 | dataset_name = dataset_name.upper() 164 | dataset = TUDataset(dataset_name) 165 | graph, _ = dataset[0] 166 | 167 | if "attr" not in graph.ndata: 168 | if "node_labels" in graph.ndata and not deg4feat: 169 | print("Use node label as node features") 170 | feature_dim = 0 171 | for g, _ in dataset: 172 | feature_dim = max(feature_dim, g.ndata["node_labels"].max().item()) 173 | 174 | feature_dim += 1 175 | x_attr = [] 176 | for g, l in dataset: 177 | node_label = g.ndata["node_labels"].view(-1) 178 | feat = F.one_hot(node_label, num_classes=feature_dim).float() 179 | g.ndata["attr"] = feat 180 | x_attr.append(feat) 181 | x_attr = torch.cat(x_attr, dim=0).numpy() 182 | 183 | scaler = StandardScaler() 184 | scaler.fit(x_attr) 185 | for g, l in dataset: 186 | g.ndata['attr'] = torch.from_numpy(scaler.transform(g.ndata['attr'])).float() 187 | 188 | 189 | 190 | else: 191 | print("Using degree as node features") 192 | feature_dim = 0 193 | degrees = [] 194 | for g, _ in dataset: 195 | feature_dim = max(feature_dim, g.in_degrees().max().item()) 196 | degrees.extend(g.in_degrees().tolist()) 197 | MAX_DEGREES = 400 198 | 199 | oversize = 0 200 | for d, n in Counter(degrees).items(): 201 | if d > MAX_DEGREES: 202 | oversize += n 203 | # print(f"N > {MAX_DEGREES}, #NUM: {oversize}, ratio: {oversize/sum(degrees):.8f}") 204 | feature_dim = min(feature_dim, MAX_DEGREES) 205 | 206 | feature_dim += 1 207 | x_attr = [] 208 | for g, l in dataset: 209 | degrees = g.in_degrees() 210 | degrees[degrees > MAX_DEGREES] = MAX_DEGREES 211 | 212 | feat = F.one_hot(degrees, num_classes=feature_dim).float() 213 | g.ndata["attr"] = feat 214 | x_attr.append(feat) 215 | x_attr = torch.cat(x_attr, dim=0).numpy() 216 | scaler = StandardScaler() 217 | scaler.fit(x_attr) 218 | for g, l in dataset: 219 | g.ndata['attr'] = torch.from_numpy(scaler.transform(g.ndata['attr'])).float() 220 | else: 221 | print("******** Use `attr` as node features ********") 222 | feature_dim = graph.ndata["attr"].shape[1] 223 | 224 | labels = torch.tensor([x[1] for x in dataset]) 225 | 226 | num_classes = torch.max(labels).item() + 1 227 | dataset = [(g.remove_self_loop().add_self_loop(), y) for g, y in dataset] 228 | 229 | print(f"******** # Num Graphs: {len(dataset)}, # Num Feat: {feature_dim}, # Num Classes: {num_classes} ********") 230 | 231 | return dataset, (feature_dim, num_classes) 232 | -------------------------------------------------------------------------------- /NodeExp/evaluator.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # File Name: evaluation.py 4 | # Author: YangRun 5 | # Created Time: 2022-12-02 23:02 6 | # Last Modified: - 7 | import copy 8 | import numpy as np 9 | from tqdm import tqdm 10 | import torch 11 | import torch.nn as nn 12 | from sklearn.metrics import f1_score 13 | from utils.utils import create_optimizer, accuracy, set_random_seed 14 | 15 | 16 | def node_classification_evaluation(model, T, graph, feat, num_classes, optim_type, 17 | 
lr_f, weight_decay_f, max_epoch_f, device, logger=None): 18 | model.eval() 19 | embed_list = [] 20 | head_list = [] 21 | optim_list = [] 22 | with torch.no_grad(): 23 | for t in T: 24 | repr = model.embed(graph, feat, t) 25 | embed_list.append(repr) 26 | embed_dim = repr.shape[1] 27 | head = LogisticRegression(embed_dim, num_classes) 28 | head.to(device) 29 | head_list.append(head) 30 | optimizer_f = create_optimizer(optim_type, head, lr_f, weight_decay_f) 31 | optim_list.append(optimizer_f) 32 | test_acc = linear_probing_for_transductive_node_classiifcation(head_list, 33 | graph, 34 | embed_list, 35 | optim_list, 36 | max_epoch_f, 37 | device, logger) 38 | return test_acc 39 | 40 | 41 | def linear_probing_for_transductive_node_classiifcation(models, graph, embed_list, optimizers, 42 | max_epoch, device, logger=None): 43 | criterion = torch.nn.CrossEntropyLoss() 44 | train_mask = graph.ndata["train_mask"] 45 | val_mask = graph.ndata["val_mask"] 46 | test_mask = graph.ndata["test_mask"] 47 | labels = graph.ndata["label"] 48 | pred_list = [] 49 | 50 | epoch_iter = range(max_epoch) 51 | for idx, model in enumerate(models): 52 | best_val_acc = 0 53 | best_val_epoch = 0 54 | best_model = None 55 | optim = optimizers[idx] 56 | for epoch in epoch_iter: 57 | model.train() 58 | out = model(embed_list[idx]) 59 | loss = criterion(out[train_mask], labels[train_mask]) 60 | optim.zero_grad() 61 | loss.backward() 62 | optim.step() 63 | 64 | with torch.no_grad(): 65 | model.eval() 66 | pred = model(embed_list[idx]) 67 | val_acc = accuracy(pred[val_mask], labels[val_mask]) 68 | # val_acc = accuracy(pred[test_mask], labels[test_mask]) 69 | 70 | if val_acc >= best_val_acc: 71 | best_val_acc = val_acc 72 | best_val_epoch = epoch 73 | best_model = copy.deepcopy(model) 74 | 75 | best_model.eval() 76 | with torch.no_grad(): 77 | pred = best_model(embed_list[idx]) 78 | pred = pred.max(1)[1].long() 79 | pred_list.append(pred) 80 | final_pred = torch.stack(pred_list, dim=0) 81 | final_pred = torch.mode(final_pred, dim=0)[0][test_mask] 82 | y_true = labels[test_mask] 83 | correct = final_pred.eq(y_true).double() 84 | test_acc = correct.sum().item() / len(y_true) 85 | logger.info(f"--- Testf1: {test_acc:.5f}, Best Valacc: {best_val_acc:.5f} in epoch {best_val_epoch} --- ") 86 | 87 | return test_acc 88 | 89 | 90 | def inductive_node_classification_evaluation(model, T, loaders, num_classes, optim_type, 91 | lr_f, weight_decay_f, max_epoch_f, device, logger=None): 92 | model.eval() 93 | if len(loaders[0]) > 1: 94 | x_all = {"train": [], "val": [], "test": []} 95 | y_all = {"train": [], "val": [], "test": []} 96 | 97 | with torch.no_grad(): 98 | for key, loader in zip(["train", "val", "test"], loaders): 99 | for subgraph in loader: 100 | embed_list = [] 101 | for t in T: 102 | subgraph = subgraph.to(device) 103 | feat = subgraph.ndata["feat"] 104 | x = model.embed(subgraph, feat, t) 105 | embed_list.append(x) 106 | x_all[key].append(embed_list) 107 | y_all[key].append(subgraph.ndata["label"]) 108 | head_list = [] 109 | optim_list = [] 110 | in_dim = x_all["train"][0][0].shape[1] 111 | for t in T: 112 | head = LogisticRegression(in_dim, num_classes).to(device) 113 | optimizer_f = create_optimizer(optim_type, head, lr_f, weight_decay_f) 114 | head_list.append(head) 115 | optim_list.append(optimizer_f) 116 | 117 | final_acc = multi_graph_linear_probing(head_list, x_all, y_all, optim_list, 118 | max_epoch_f, device, logger) 119 | return final_acc 120 | 121 | else: 122 | x_all = {"train": None, "val": None, "test": None} 123 
| y_all = {"train": None, "val": None, "test": None} 124 | 125 | with torch.no_grad(): 126 | for key, loader in zip(["train", "val", "test"], loaders): 127 | for subgraph in loader: 128 | embed_list = [] 129 | for t in T: 130 | subgraph = subgraph.to(device) 131 | feat = subgraph.ndata["feat"] 132 | x = model.embed(subgraph, feat, t) 133 | mask = subgraph.ndata[f"{key}_mask"] 134 | embed_list.append(x[mask]) 135 | x_all[key] = embed_list 136 | y_all[key] = subgraph.ndata["label"][mask] 137 | head_list = [] 138 | optim_list = [] 139 | in_dim = x_all["train"][0].shape[1] 140 | for t in T: 141 | head = LogisticRegression(in_dim, num_classes).to(device) 142 | optimizer_f = create_optimizer(optim_type, head, lr_f, weight_decay_f) 143 | head_list.append(head) 144 | optim_list.append(optimizer_f) 145 | 146 | num_train, num_val, num_test = [x[0].shape[0] for x in x_all.values()] 147 | num_nodes = num_train + num_val + num_test 148 | train_mask = torch.arange(num_train, device=device) 149 | val_mask = torch.arange(num_train, num_train + num_val, device=device) 150 | test_mask = torch.arange(num_train + num_val, num_nodes, device=device) 151 | 152 | final_acc = linear_probing_for_inductive_node_classiifcation(head_list, x_all, y_all, 153 | (train_mask, val_mask, test_mask), 154 | optim_list, max_epoch_f, device, logger) 155 | return final_acc 156 | 157 | 158 | def multi_graph_linear_probing(models, x_all, y_all, optimizers, 159 | max_epoch, device, logger=None): 160 | criterion = torch.nn.BCEWithLogitsLoss() 161 | 162 | best_val_acc = 0 163 | best_val_epoch = 0 164 | best_model = None 165 | pred_list = [] 166 | 167 | epoch_iter = range(max_epoch) 168 | for idx, model in enumerate(models): 169 | best_val_acc = 0 170 | best_val_epoch = 0 171 | best_model = None 172 | optim = optimizers[idx] 173 | for epoch in tqdm(epoch_iter): 174 | model.train() 175 | for x, y in zip(x_all['train'], y_all['train']): 176 | out = model(x[idx]) 177 | loss = criterion(out, y) 178 | optim.zero_grad() 179 | loss.backward() 180 | 181 | with torch.no_grad(): 182 | model.eval() 183 | # val_out = [] 184 | test_out = [] 185 | for x in x_all['test']: 186 | # pred = model(x_all['val'][idx]) 187 | # val_acc = accuracy(pred, y_all['val']) 188 | pred = model(x[idx]) 189 | test_out.append(pred) 190 | test_out = torch.cat(test_out, dim=0).cpu().numpy() 191 | test_out = np.where(test_out >= 0.5, 1, 0) 192 | test_label = torch.cat(y_all["test"], dim=0).cpu().numpy() 193 | val_acc = f1_score(test_label, test_out, average="micro") 194 | 195 | if val_acc >= best_val_acc: 196 | best_val_acc = val_acc 197 | best_val_epoch = epoch 198 | best_model = copy.deepcopy(model) 199 | 200 | best_model.eval() 201 | with torch.no_grad(): 202 | for x in x_all['test']: 203 | pred = best_model(x[idx]) 204 | # pred = pred.max(1)[1].long() 205 | pred_list.append(pred) 206 | final_pred = torch.stack(pred_list, dim=0) 207 | final_pred = torch.sigmoid(torch.mean(final_pred, dim=0)).cpu().numpy() 208 | final_pred = np.where(final_pred >= 0.5, 1, 0) 209 | # final_pred = torch.mode(final_pred, dim=0)[0].long().cpu().numpy() 210 | y_true = torch.cat(y_all["test"], dim=0).cpu().numpy() 211 | # y_true = y_all['test'].squeeze().long().cpu().numpy() 212 | test_acc = f1_score(y_true, final_pred, average='micro') 213 | logger.info(f"--- Testf1: {test_acc:.5f}, Best Valf1: {best_val_acc:.5f} in epoch {best_val_epoch} --- ") 214 | 215 | return test_acc 216 | 217 | 218 | def linear_probing_for_inductive_node_classiifcation(models, x_all, y_all, mask, optimizers, 219 | 
max_epoch, device, logger=None): 220 | # criterion = torch.nn.BCEWithLogitsLoss() 221 | criterion = torch.nn.CrossEntropyLoss() 222 | train_mask, val_mask, test_mask = mask 223 | 224 | best_val_acc = 0 225 | best_val_epoch = 0 226 | best_model = None 227 | pred_list = [] 228 | 229 | epoch_iter = range(max_epoch) 230 | for idx, model in enumerate(models): 231 | best_val_acc = 0 232 | best_val_epoch = 0 233 | best_model = None 234 | optim = optimizers[idx] 235 | for epoch in epoch_iter: 236 | model.train() 237 | 238 | out = model(x_all['train'][idx]) 239 | loss = criterion(out, y_all['train']) 240 | optim.zero_grad() 241 | loss.backward() 242 | 243 | with torch.no_grad(): 244 | model.eval() 245 | pred = model(x_all['test'][idx]) 246 | val_acc = accuracy(pred, y_all['test']) 247 | # pred = model(x_all['test'][idx]).max(1)[1].long().cpu().numpy() 248 | # test_label = y_all["test"].cpu().numpy() 249 | # val_acc = f1_score(test_label, pred, average="micro") 250 | 251 | if val_acc >= best_val_acc: 252 | best_val_acc = val_acc 253 | best_val_epoch = epoch 254 | best_model = copy.deepcopy(model) 255 | 256 | best_model.eval() 257 | with torch.no_grad(): 258 | pred = best_model(x_all['test'][idx]) 259 | pred = pred.max(1)[1].long() 260 | pred_list.append(pred) 261 | final_pred = torch.stack(pred_list, dim=0) 262 | final_pred = torch.mode(final_pred, dim=0)[0].long().cpu().numpy() 263 | y_true = y_all['test'].squeeze().long().cpu().numpy() 264 | test_acc = f1_score(y_true, final_pred, average='micro') 265 | logger.info(f"--- Testf1: {test_acc:.5f}, Best Valf1: {best_val_acc:.5f} in epoch {best_val_epoch} --- ") 266 | 267 | return test_acc 268 | 269 | 270 | class LogisticRegression(nn.Module): 271 | def __init__(self, num_dim, num_class): 272 | super().__init__() 273 | self.net_1 = nn.Sequential(nn.Linear(num_dim, num_dim), 274 | nn.ReLU(), 275 | # nn.BatchNorm1d(num_dim)) 276 | nn.LayerNorm(num_dim)) 277 | self.net_2 = nn.Sequential(nn.Linear(num_dim, num_dim), 278 | nn.ReLU(), 279 | # nn.BatchNorm1d(num_dim)) 280 | nn.LayerNorm(num_dim)) 281 | self.norm = nn.LayerNorm(num_dim) 282 | self.fc = nn.Linear(num_dim, num_class) 283 | 284 | def forward(self, x): 285 | # x = self.net_1(x) 286 | # x = self.net_2(x) 287 | # x = self.norm(x) 288 | logits = self.fc(x) 289 | return logits 290 | 291 | -------------------------------------------------------------------------------- /NodeExp/main_node.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File Name: main_graph.py 4 | # Author: Run Yang 5 | # Created Time: 2022-10-28 22:32 6 | # Last Modified: - 7 | import numpy as np 8 | import json 9 | 10 | import argparse 11 | 12 | import shutil 13 | import time 14 | import os.path as osp 15 | 16 | import dgl 17 | import torch 18 | import torch.nn as nn 19 | 20 | from utils.utils import (create_optimizer, create_pooler, set_random_seed, compute_ppr) 21 | from datasets.data_util import load_dataset 22 | from models import DDM 23 | import multiprocessing 24 | from multiprocessing import Pool 25 | 26 | from utils import comm 27 | from utils.collect_env import collect_env_info 28 | from utils.logger import setup_logger 29 | from utils.misc import mkdir 30 | 31 | from evaluator import node_classification_evaluation 32 | import yaml 33 | import nni 34 | from easydict import EasyDict as edict 35 | 36 | 37 | parser = argparse.ArgumentParser(description='Graph DGL Training') 38 | parser.add_argument('--resume', '-r', 
action='store_true', default=False, 39 | help='resume from checkpoint') 40 | parser.add_argument("--local_rank", type=int, default=0, help="local rank") 41 | parser.add_argument("--seed", type=int, default=1234, help="random seed") 42 | parser.add_argument("--yaml_dir", type=str, default=None) 43 | parser.add_argument("--output_dir", type=str, default=None) 44 | parser.add_argument("--checkpoint_dir", type=str, default=None) 45 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 46 | help='manual epoch number (useful on restarts)') 47 | args = parser.parse_args() 48 | 49 | 50 | 51 | def pretrain(model, graph, feat, optimizer, epoch, logger): 52 | logger.info("start epoch {}.".format(epoch)) 53 | model.train() 54 | loss, loss_dict = model(graph, feat) 55 | optimizer.zero_grad() 56 | loss.backward() 57 | optimizer.step() 58 | lr = optimizer.param_groups[0]['lr'] 59 | logger.info(f"# Epoch {epoch}: train_loss: {loss.item():.4f} | lr: {lr:.6f}") 60 | 61 | 62 | def save_checkpoint(state, is_best, filename): 63 | ckp = osp.join(filename, 'checkpoint.pth.tar') 64 | torch.save(state, ckp) 65 | if is_best: 66 | shutil.copyfile(ckp, filename+'/model_best.pth.tar') 67 | 68 | 69 | def adjust_learning_rate(optimizer, epoch, alpha, decay, lr): 70 | """Sets the learning rate to the initial LR decayed by 10 every 80 epochs""" 71 | lr = lr * (alpha ** (epoch // decay)) 72 | for param_group in optimizer.param_groups: 73 | param_group['lr'] = lr 74 | 75 | 76 | def main(cfg): 77 | 78 | if cfg.output_dir: 79 | mkdir(cfg.output_dir) 80 | mkdir(cfg.checkpoint_dir) 81 | 82 | logger = setup_logger("graph", cfg.output_dir, comm.get_rank(), filename='train_log.txt') 83 | logger.info("Rank of current process: {}. World size: {}".format(comm.get_rank(), comm.get_world_size())) 84 | logger.info("Environment info:\n" + collect_env_info()) 85 | logger.info("Command line arguments: " + str(args)) 86 | 87 | shutil.copyfile('./params.yaml', cfg.output_dir + '/params.yaml') 88 | shutil.copyfile('./main_node.py', cfg.output_dir + '/node.py') 89 | shutil.copyfile('./models/DDM.py', cfg.output_dir + '/DDM.py') 90 | shutil.copyfile('./models/mlp_gat.py', cfg.output_dir + '/mlp_gat.py') 91 | 92 | graph, (num_features, num_classes) = load_dataset(cfg.DATA.data_name) 93 | 94 | acc_list = [] 95 | for i, seed in enumerate(cfg.seeds): 96 | best_acc = float('-inf') 97 | best_acc_epoch = float('inf') 98 | logger.info(f'Run {i}th for seed {seed}') 99 | set_random_seed(seed) 100 | 101 | ml_cfg = cfg.MODEL 102 | ml_cfg.update({'in_dim': num_features}) 103 | model = DDM(**ml_cfg) 104 | total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 105 | logger.info('Total trainable params num : {}'.format(total_trainable_params)) 106 | model.to(cfg.DEVICE) 107 | 108 | optimizer = create_optimizer(cfg.SOLVER.optim_type, model, cfg.SOLVER.LR, cfg.SOLVER.weight_decay) 109 | 110 | start_epoch = 0 111 | if args.resume: 112 | if osp.isfile(cfg.pretrain_checkpoint_dir): 113 | logger.info("=> loading checkpoint '{}'".format(cfg.checkpoint_dir)) 114 | checkpoint = torch.load(cfg.checkpoint_dir, map_location=torch.device('cpu')) 115 | start_epoch = checkpoint['epoch'] 116 | model.load_state_dict(checkpoint['state_dict']) 117 | optimizer.load_state_dict(checkpoint['optimizer']) 118 | logger.info("=> loaded checkpoint '{}' (epoch {})" 119 | .format(cfg.checkpoint_dir, checkpoint['epoch'])) 120 | 121 | logger.info("----------Start Training----------") 122 | graph = graph.to(cfg.DEVICE) 123 | feat = 
graph.ndata['feat'] 124 | for epoch in range(start_epoch, cfg.SOLVER.MAX_EPOCH): 125 | adjust_learning_rate(optimizer, epoch=epoch, alpha=cfg.SOLVER.alpha, 126 | decay=cfg.SOLVER.decay, lr=cfg.SOLVER.LR) 127 | pretrain(model, graph, feat, optimizer, epoch, logger) 128 | # custom eval frequency 129 | if ((epoch + 1) % 1 == 0) & (epoch > 10): 130 | model.eval() 131 | acc = node_classification_evaluation(model, cfg.eval_T, graph, feat, num_classes, 132 | cfg.SOLVER.optim_type, 133 | cfg.SOLVER.LR_f, cfg.SOLVER.weight_decay_f, 134 | cfg.SOLVER.max_epoch_f, cfg.DEVICE, logger) 135 | is_best = acc > best_acc 136 | if is_best: 137 | best_acc_epoch = epoch 138 | best_acc = max(acc, best_acc) 139 | logger.info(f"Epoch {epoch}: get test acc score: {acc: .3f}") 140 | logger.info(f"best_f1 {best_acc:.3f} at epoch {best_acc_epoch}") 141 | save_checkpoint({'epoch': epoch + 1, 142 | 'state_dict': model.state_dict(), 143 | 'best_acc': best_acc, 144 | 'optimizer': optimizer.state_dict()}, 145 | is_best, filename=cfg.checkpoint_dir) 146 | acc_list.append(best_acc) 147 | 148 | final_acc, final_acc_std = np.mean(acc_list), np.std(acc_list) 149 | logger.info((f"# final_acc: {final_acc:.4f}±{final_acc_std:.4f}")) 150 | return final_acc 151 | 152 | 153 | if __name__ == "__main__": 154 | root_dir = osp.abspath(osp.dirname(__file__)) 155 | yaml_dir = osp.join(root_dir, 'params.yaml') 156 | output_dir = osp.join(root_dir, 'log') 157 | checkpoint_dir = osp.join(output_dir, "checkpoint") 158 | 159 | yaml_dir = args.yaml_dir if args.yaml_dir else yaml_dir 160 | output_dir = args.output_dir if args.output_dir else output_dir 161 | checkpoint_dir = args.checkpoint_dir if args.checkpoint_dir else checkpoint_dir 162 | 163 | with open(yaml_dir, "r") as f: 164 | config = yaml.load(f, yaml.FullLoader) 165 | cfg = edict(config) 166 | 167 | cfg.output_dir, cfg.checkpoint_dir = output_dir, checkpoint_dir 168 | 169 | f1 = main(cfg) 170 | -------------------------------------------------------------------------------- /NodeExp/models/DDM.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File Name: diffusion.py 4 | # Author: Yang Run 5 | # Created Time: 2022-10-29 17:09 6 | # Last Modified: - 7 | 8 | import sys 9 | from typing import Optional 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from torch.nn import init 15 | 16 | import math 17 | import dgl 18 | import dgl.function as fn 19 | from utils.utils import make_edge_weights 20 | from .mlp_gat import Denoising_Unet 21 | import numpy as np 22 | from dgl import SIGNDiffusion 23 | 24 | 25 | 26 | def extract(v, t, x_shape): 27 | """ 28 | Extract some coefficients at specified timesteps, then reshape to 29 | [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes. 
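    Shape sketch (illustrative, not part of the original docstring): with v of shape (T,) and x_shape == (B, D), an index t of shape (B,) gathers one coefficient per sample and the result is viewed as (B, 1), so it broadcasts against x.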
30 | """ 31 | out = torch.gather(v, index=t, dim=0).float() 32 | return out.view([t.shape[0]] + [1] * (len(x_shape) - 1)) 33 | 34 | 35 | class DDM(nn.Module): 36 | def __init__( 37 | self, 38 | in_dim: int, 39 | num_hidden: int, 40 | num_layers: int, 41 | nhead: int, 42 | activation: str, 43 | feat_drop: float, 44 | attn_drop: float, 45 | norm: Optional[str], 46 | alpha_l: float = 2, 47 | beta_schedule: str = 'linear', 48 | beta_1: float = 0.0001, 49 | beta_T: float = 0.02, 50 | T: int = 1000, 51 | **kwargs 52 | 53 | ): 54 | super(DDM, self).__init__() 55 | self.T = T 56 | beta = get_beta_schedule(beta_schedule, beta_1, beta_T, T) 57 | self.register_buffer( 58 | 'betas', beta 59 | ) 60 | alphas = 1. - self.betas 61 | alphas_bar = torch.cumprod(alphas, dim=0) 62 | 63 | self.register_buffer( 64 | 'sqrt_alphas_bar', torch.sqrt(alphas_bar) 65 | ) 66 | self.register_buffer( 67 | 'sqrt_one_minus_alphas_bar', torch.sqrt(1. - alphas_bar) 68 | ) 69 | 70 | self.alpha_l = alpha_l 71 | assert num_hidden % nhead == 0 72 | self.net = Denoising_Unet(in_dim=in_dim, 73 | num_hidden=num_hidden, 74 | out_dim=in_dim, 75 | num_layers=num_layers, 76 | nhead=nhead, 77 | activation=activation, 78 | feat_drop=feat_drop, 79 | attn_drop=attn_drop, 80 | negative_slope=0.2, 81 | norm=norm) 82 | 83 | self.time_embedding = nn.Embedding(T, num_hidden) 84 | self.norm_x = nn.LayerNorm(in_dim, elementwise_affine=False) 85 | 86 | def forward(self, g, x): 87 | with torch.no_grad(): 88 | x = F.layer_norm(x, (x.shape[-1], )) 89 | 90 | t = torch.randint(self.T, size=(x.shape[0], ), device=x.device) 91 | x_t, time_embed, g = self.sample_q(t, x, g) 92 | 93 | loss = self.node_denoising(x, x_t, time_embed, g) 94 | loss_item = {"loss": loss.item()} 95 | return loss, loss_item 96 | 97 | def sample_q(self, t, x, g): 98 | if not self.training: 99 | def udf_std(nodes): 100 | return {"std": nodes.mailbox['m'].std(dim=1, unbiased=False)} 101 | g.update_all(fn.copy_u('feat', 'm'), udf_std) 102 | g.update_all(fn.copy_u('feat', 'm'), fn.mean('m', 'miu')) 103 | 104 | miu, std = g.ndata['std'], g.ndata['miu'] 105 | else: 106 | miu, std = x.mean(dim=0), x.std(dim=0) 107 | noise = torch.randn_like(x, device=x.device) 108 | noise = noise * std + miu 109 | noise = self.norm_x(noise) 110 | noise = torch.sign(x) * torch.abs(noise) 111 | x_t = ( 112 | extract(self.sqrt_alphas_bar, t, x.shape) * x + 113 | extract(self.sqrt_one_minus_alphas_bar, t, x.shape) * noise 114 | ) 115 | time_embed = self.time_embedding(t) 116 | return x_t, time_embed, g 117 | 118 | def node_denoising(self, x, x_t, time_embed, g): 119 | out, _ = self.net(g, x_t=x_t, time_embed=time_embed) 120 | loss = loss_fn(out, x, self.alpha_l) 121 | 122 | return loss 123 | 124 | def embed(self, g, x, T): 125 | t = torch.full((1, ), T, device=x.device) 126 | with torch.no_grad(): 127 | x = F.layer_norm(x, (x.shape[-1], )) 128 | x_t, time_embed, g = self.sample_q(t, x, g) 129 | _, hidden = self.net(g, x_t=x_t, time_embed=time_embed) 130 | return hidden 131 | 132 | 133 | def loss_fn(x, y, alpha=2): 134 | x = F.normalize(x, p=2, dim=-1) 135 | y = F.normalize(y, p=2, dim=-1) 136 | 137 | loss = (1 - (x * y).sum(dim=-1)).pow_(alpha) 138 | 139 | loss = loss.mean() 140 | return loss 141 | 142 | 143 | def get_beta_schedule(beta_schedule, beta_start, beta_end, num_diffusion_timesteps): 144 | def sigmoid(x): 145 | return 1 / (np.exp(-x) + 1) 146 | 147 | if beta_schedule == "quad": 148 | betas = ( 149 | np.linspace( 150 | beta_start ** 0.5, 151 | beta_end ** 0.5, 152 | num_diffusion_timesteps, 153 | 
dtype=np.float64, 154 | ) 155 | ** 2 156 | ) 157 | elif beta_schedule == "linear": 158 | betas = np.linspace( 159 | beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64 160 | ) 161 | elif beta_schedule == "const": 162 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) 163 | elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 164 | betas = 1.0 / np.linspace( 165 | num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64 166 | ) 167 | elif beta_schedule == "sigmoid": 168 | betas = np.linspace(-6, 6, num_diffusion_timesteps) 169 | betas = sigmoid(betas) * (beta_end - beta_start) + beta_start 170 | else: 171 | raise NotImplementedError(beta_schedule) 172 | assert betas.shape == (num_diffusion_timesteps,) 173 | return torch.from_numpy(betas) 174 | -------------------------------------------------------------------------------- /NodeExp/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .DDM import DDM 2 | 3 | -------------------------------------------------------------------------------- /NodeExp/models/mlp_gat.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File Name: mlp_gat.py 4 | # Author: Run Yang 5 | # Created Time: 2022-12-06 13:48 6 | # Last Modified: - 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | import dgl 12 | from dgl.nn import GINConv 13 | from dgl.nn import GATConv 14 | from dgl.nn import EGATConv 15 | import dgl.function as fn 16 | from dgl.nn.functional import edge_softmax 17 | from .utils import create_activation, create_norm 18 | 19 | 20 | def exists(x): 21 | return x is not None 22 | 23 | 24 | class Denoising_Unet(nn.Module): 25 | def __init__(self, 26 | in_dim, 27 | num_hidden, 28 | out_dim, 29 | num_layers, 30 | nhead, 31 | activation, 32 | feat_drop, 33 | attn_drop, 34 | negative_slope, 35 | norm, 36 | ): 37 | super(Denoising_Unet, self).__init__() 38 | self.out_dim = out_dim 39 | self.num_heads = nhead 40 | self.num_layers = num_layers 41 | self.num_hidden = num_hidden 42 | self.down_layers = nn.ModuleList() 43 | self.up_layers = nn.ModuleList() 44 | self.activation = activation 45 | 46 | self.mlp_in_t = MlpBlock(in_dim=in_dim, hidden_dim=num_hidden*2, out_dim=num_hidden, 47 | norm=norm, activation=activation) 48 | 49 | self.mlp_middle = MlpBlock(num_hidden, num_hidden, num_hidden, norm=norm, activation=activation) 50 | 51 | self.mlp_out = MlpBlock(num_hidden, out_dim, out_dim, norm=norm, activation=activation) 52 | 53 | self.down_layers.append(GATConv(num_hidden, num_hidden // nhead, nhead, feat_drop, attn_drop, negative_slope)) 54 | self.up_layers.append(GATConv(num_hidden, num_hidden, 1, feat_drop, attn_drop, negative_slope)) 55 | 56 | 57 | for _ in range(1, num_layers): 58 | self.down_layers.append(GATConv(num_hidden, num_hidden // nhead, nhead, feat_drop, 59 | attn_drop, negative_slope)) 60 | self.up_layers.append(GATConv(num_hidden, num_hidden // nhead, nhead, feat_drop, 61 | attn_drop, negative_slope)) 62 | self.up_layers = self.up_layers[::-1] 63 | 64 | def forward(self, g, x_t, time_embed): 65 | h_t = self.mlp_in_t(x_t) 66 | down_hidden = [] 67 | for l in range(self.num_layers): 68 | if h_t.ndim > 2: 69 | h_t = h_t + time_embed.unsqueeze(1).repeat(1, h_t.shape[1], 1) 70 | else: 71 | pass 72 | h_t = self.down_layers[l](g, h_t) 73 | h_t = h_t.flatten(1) 74 | down_hidden.append(h_t) 75 | h_middle = 
self.mlp_middle(h_t) 76 | 77 | h_t = h_middle 78 | out_hidden = [] 79 | for l in range(self.num_layers): 80 | h_t = h_t + down_hidden[self.num_layers - l - 1 ] 81 | if h_t.ndim > 2: 82 | h_t = h_t + time_embed.unsqueeze(1).repeat(1, h_t.shape[1], 1) 83 | else: 84 | pass 85 | h_t = self.up_layers[l](g, h_t) 86 | h_t = h_t.flatten(1) 87 | out_hidden.append(h_t) 88 | out = self.mlp_out(h_t) 89 | out_hidden = torch.cat(out_hidden, dim=-1) 90 | 91 | return out, out_hidden 92 | 93 | 94 | class Residual(nn.Module): 95 | def __init__(self, fnc): 96 | super().__init__() 97 | self.fnc = fnc 98 | 99 | def forward(self, x, *args, **kwargs): 100 | return self.fnc(x, *args, **kwargs) + x 101 | 102 | 103 | class MlpBlock(nn.Module): 104 | def __init__(self, in_dim: int, hidden_dim: int, out_dim: int, 105 | norm: str = 'layernorm', activation: str = 'prelu'): 106 | super(MlpBlock, self).__init__() 107 | self.in_proj = nn.Linear(in_dim, hidden_dim) 108 | self.res_mlp = Residual(nn.Sequential(nn.Linear(hidden_dim, hidden_dim), 109 | create_norm(norm)(hidden_dim), 110 | create_activation(activation), 111 | nn.Linear(hidden_dim, hidden_dim))) 112 | self.out_proj = nn.Linear(hidden_dim, out_dim) 113 | self.act = create_activation(activation) 114 | def forward(self, x): 115 | x = self.in_proj(x) 116 | x = self.res_mlp(x) 117 | x = self.out_proj(x) 118 | x = self.act(x) 119 | return x 120 | 121 | -------------------------------------------------------------------------------- /NodeExp/models/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File Name: utils.py 4 | # Author: Yang Run 5 | # Created Time: 2022-10-29 23:35 6 | # Last Modified: - 7 | 8 | import torch 9 | import torch.nn as nn 10 | from functools import partial 11 | 12 | def create_activation(name): 13 | if name == "relu": 14 | return nn.ReLU() 15 | elif name == "gelu": 16 | return nn.GELU() 17 | elif name == "prelu": 18 | return nn.PReLU() 19 | elif name is None: 20 | return nn.Identity() 21 | elif name == "elu": 22 | return nn.ELU() 23 | else: 24 | raise NotImplementedError(f"{name} is not implemented.") 25 | 26 | 27 | def create_norm(name): 28 | if name == "layernorm": 29 | return nn.LayerNorm 30 | elif name == "batchnorm": 31 | return nn.BatchNorm1d 32 | elif name == "graphnorm": 33 | return partial(NormLayer, norm_type="groupnorm") 34 | else: 35 | return nn.Identity 36 | 37 | -------------------------------------------------------------------------------- /NodeExp/run.sh: -------------------------------------------------------------------------------- 1 | python main_graph.py 2 | -------------------------------------------------------------------------------- /NodeExp/utils/.ipynb_checkpoints/utils-checkpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import random 4 | import yaml 5 | import logging 6 | from functools import partial 7 | import numpy as np 8 | 9 | import dgl 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch import optim as optim 14 | from dgl.nn.pytorch.glob import SumPooling, AvgPooling, MaxPooling 15 | from scipy.linalg import fractional_matrix_power, inv 16 | 17 | from sklearn.metrics import f1_score 18 | 19 | 20 | def compute_ppr(graph, alpha=0.2, self_loop=False): 21 | a = graph.adj().to_dense().numpy() 22 | if self_loop: 23 | a = a + np.eye(a.shape[0]) # A^ = A + I_n 24 | d = np.diag(np.sum(a, 1)) # D^ = Sigma A^_ii 25 | dinv = 
fractional_matrix_power(d, -0.5) # D^(-1/2) 26 | at = np.matmul(np.matmul(dinv, a), dinv) # A~ = D^(-1/2) x A^ x D^(-1/2) 27 | return alpha * inv((np.eye(a.shape[0]) - (1 - alpha) * at)) # a(I_n-(1-a)A~)^-1 28 | 29 | 30 | def accuracy(y_pred, y_true): 31 | y_true = y_true.squeeze().long().cpu().numpy() 32 | preds = y_pred.max(1)[1].long().cpu().numpy() 33 | f1 = f1_score(y_true, preds, average='micro') 34 | # correct = preds.eq(y_true).double() 35 | # correct = correct.sum().item() 36 | # return correct / len(y_true) 37 | return f1 38 | 39 | 40 | def set_random_seed(seed): 41 | random.seed(seed) 42 | np.random.seed(seed) 43 | torch.manual_seed(seed) 44 | torch.cuda.manual_seed(seed) 45 | torch.cuda.manual_seed_all(seed) 46 | torch.backends.cudnn.determinstic = True 47 | # dgl.random.seed(seed) 48 | 49 | 50 | def get_current_lr(optimizer): 51 | return optimizer.state_dict()["param_groups"][0]["lr"] 52 | 53 | 54 | def build_args(): 55 | parser = argparse.ArgumentParser(description="GAT") 56 | parser.add_argument("--seeds", type=int, nargs="+", default=[0]) 57 | parser.add_argument("--dataset", type=str, default="cora") 58 | parser.add_argument("--device", type=int, default=-1) 59 | parser.add_argument("--max_epoch", type=int, default=200, 60 | help="number of training epochs") 61 | parser.add_argument("--warmup_steps", type=int, default=-1) 62 | 63 | parser.add_argument("--num_heads", type=int, default=4, 64 | help="number of hidden attention heads") 65 | parser.add_argument("--num_out_heads", type=int, default=1, 66 | help="number of output attention heads") 67 | parser.add_argument("--num_layers", type=int, default=2, 68 | help="number of hidden layers") 69 | parser.add_argument("--num_hidden", type=int, default=256, 70 | help="number of hidden units") 71 | parser.add_argument("--residual", action="store_true", default=False, 72 | help="use residual connection") 73 | parser.add_argument("--in_drop", type=float, default=.2, 74 | help="input feature dropout") 75 | parser.add_argument("--attn_drop", type=float, default=.1, 76 | help="attention dropout") 77 | parser.add_argument("--norm", type=str, default=None) 78 | parser.add_argument("--lr", type=float, default=0.005, 79 | help="learning rate") 80 | parser.add_argument("--weight_decay", type=float, default=5e-4, 81 | help="weight decay") 82 | parser.add_argument("--negative_slope", type=float, default=0.2, 83 | help="the negative slope of leaky relu for GAT") 84 | parser.add_argument("--activation", type=str, default="prelu") 85 | parser.add_argument("--mask_rate", type=float, default=0.5) 86 | parser.add_argument("--drop_edge_rate", type=float, default=0.0) 87 | parser.add_argument("--replace_rate", type=float, default=0.0) 88 | 89 | parser.add_argument("--encoder", type=str, default="gat") 90 | parser.add_argument("--decoder", type=str, default="gat") 91 | parser.add_argument("--loss_fn", type=str, default="sce") 92 | parser.add_argument("--alpha_l", type=float, default=2, help="`pow`coefficient for `sce` loss") 93 | parser.add_argument("--optimizer", type=str, default="adam") 94 | 95 | parser.add_argument("--max_epoch_f", type=int, default=30) 96 | parser.add_argument("--lr_f", type=float, default=0.001, help="learning rate for evaluation") 97 | parser.add_argument("--weight_decay_f", type=float, default=0.0, help="weight decay for evaluation") 98 | parser.add_argument("--linear_prob", action="store_true", default=False) 99 | 100 | parser.add_argument("--load_model", action="store_true") 101 | parser.add_argument("--save_model", 
action="store_true") 102 | parser.add_argument("--use_cfg", action="store_true") 103 | parser.add_argument("--logging", action="store_true") 104 | parser.add_argument("--scheduler", action="store_true", default=False) 105 | parser.add_argument("--concat_hidden", action="store_true", default=False) 106 | 107 | # for graph classification 108 | parser.add_argument("--pooling", type=str, default="mean") 109 | parser.add_argument("--deg4feat", action="store_true", default=False, help="use node degree as input feature") 110 | parser.add_argument("--batch_size", type=int, default=32) 111 | args = parser.parse_args() 112 | return args 113 | 114 | 115 | def create_activation(name): 116 | if name == "relu": 117 | return nn.ReLU() 118 | elif name == "gelu": 119 | return nn.GELU() 120 | elif name == "prelu": 121 | return nn.PReLU() 122 | elif name is None: 123 | return nn.Identity() 124 | elif name == "elu": 125 | return nn.ELU() 126 | else: 127 | raise NotImplementedError(f"{name} is not implemented.") 128 | 129 | 130 | def create_norm(name): 131 | if name == "layernorm": 132 | return nn.LayerNorm 133 | elif name == "batchnorm": 134 | return nn.BatchNorm1d 135 | elif name == "graphnorm": 136 | return partial(NormLayer, norm_type="groupnorm") 137 | else: 138 | return nn.Identity 139 | 140 | 141 | def create_optimizer(opt, model, lr, weight_decay, get_num_layer=None, get_layer_scale=None): 142 | opt_lower = opt.lower() 143 | 144 | parameters = model.parameters() 145 | opt_args = dict(lr=lr, weight_decay=weight_decay) 146 | 147 | opt_split = opt_lower.split("_") 148 | opt_lower = opt_split[-1] 149 | if opt_lower == "adam": 150 | optimizer = optim.Adam(parameters, **opt_args) 151 | elif opt_lower == "adamw": 152 | optimizer = optim.AdamW(parameters, **opt_args) 153 | elif opt_lower == "adadelta": 154 | optimizer = optim.Adadelta(parameters, **opt_args) 155 | elif opt_lower == "radam": 156 | optimizer = optim.RAdam(parameters, **opt_args) 157 | elif opt_lower == "sgd": 158 | opt_args["momentum"] = 0.9 159 | return optim.SGD(parameters, **opt_args) 160 | else: 161 | assert False and "Invalid optimizer" 162 | 163 | return optimizer 164 | 165 | 166 | def create_pooler(pooling): 167 | if pooling == "mean": 168 | pooler = AvgPooling() 169 | elif pooling == "max": 170 | pooler = MaxPooling() 171 | elif pooling == "sum": 172 | pooler = SumPooling() 173 | else: 174 | raise NotImplementedError 175 | return pooler 176 | 177 | 178 | 179 | # ------------------- 180 | def mask_edge(graph, mask_prob): 181 | E = graph.num_edges() 182 | 183 | mask_rates = torch.FloatTensor(np.ones(E) * mask_prob) 184 | masks = torch.bernoulli(1 - mask_rates) 185 | mask_idx = masks.nonzero().squeeze(1) 186 | return mask_idx 187 | 188 | def make_edge_weights(graph): 189 | E = graph.num_edges() 190 | weights = torch.FloatTensor(np.ones(E)) 191 | return weights 192 | 193 | def make_noisy_edge_weights(graph): 194 | E = graph.num_edges() 195 | weights = torch.FloatTensor(torch.rand(E)) 196 | return weights 197 | 198 | 199 | 200 | def drop_edge(graph, drop_rate, return_edges=False): 201 | if drop_rate <= 0: 202 | return graph 203 | 204 | n_node = graph.num_nodes() 205 | edge_mask = mask_edge(graph, drop_rate) 206 | src = graph.edges()[0] 207 | dst = graph.edges()[1] 208 | 209 | nsrc = src[edge_mask] 210 | ndst = dst[edge_mask] 211 | 212 | ng = dgl.graph((nsrc, ndst), num_nodes=n_node) 213 | ng = ng.add_self_loop() 214 | 215 | dsrc = src[~edge_mask] 216 | ddst = dst[~edge_mask] 217 | 218 | if return_edges: 219 | return ng, (dsrc, ddst) 220 | 
return ng 221 | 222 | 223 | def load_best_configs(args, path): 224 | with open(path, "r") as f: 225 | configs = yaml.load(f, yaml.FullLoader) 226 | 227 | if args.dataset not in configs: 228 | logging.info("Best args not found") 229 | return args 230 | 231 | logging.info("Using best configs") 232 | configs = configs[args.dataset] 233 | 234 | for k, v in configs.items(): 235 | if "lr" in k or "weight_decay" in k: 236 | v = float(v) 237 | setattr(args, k, v) 238 | print("------ Use best configs ------") 239 | return args 240 | 241 | 242 | # ------ logging ------ 243 | 244 | class TBLogger(object): 245 | def __init__(self, log_path="./logging_data", name="run"): 246 | super(TBLogger, self).__init__() 247 | 248 | if not os.path.exists(log_path): 249 | os.makedirs(log_path, exist_ok=True) 250 | 251 | self.last_step = 0 252 | self.log_path = log_path 253 | raw_name = os.path.join(log_path, name) 254 | name = raw_name 255 | for i in range(1000): 256 | name = raw_name + str(f"_{i}") 257 | if not os.path.exists(name): 258 | break 259 | self.writer = SummaryWriter(logdir=name) 260 | 261 | def note(self, metrics, step=None): 262 | if step is None: 263 | step = self.last_step 264 | for key, value in metrics.items(): 265 | self.writer.add_scalar(key, value, step) 266 | self.last_step = step 267 | 268 | def finish(self): 269 | self.writer.close() 270 | 271 | 272 | class NormLayer(nn.Module): 273 | def __init__(self, hidden_dim, norm_type): 274 | super().__init__() 275 | if norm_type == "batchnorm": 276 | self.norm = nn.BatchNorm1d(hidden_dim) 277 | elif norm_type == "layernorm": 278 | self.norm = nn.LayerNorm(hidden_dim) 279 | elif norm_type == "graphnorm": 280 | self.norm = norm_type 281 | self.weight = nn.Parameter(torch.ones(hidden_dim)) 282 | self.bias = nn.Parameter(torch.zeros(hidden_dim)) 283 | 284 | self.mean_scale = nn.Parameter(torch.ones(hidden_dim)) 285 | else: 286 | raise NotImplementedError 287 | 288 | def forward(self, graph, x): 289 | tensor = x 290 | if self.norm is not None and type(self.norm) != str: 291 | return self.norm(tensor) 292 | elif self.norm is None: 293 | return tensor 294 | 295 | batch_list = graph.batch_num_nodes 296 | batch_size = len(batch_list) 297 | batch_list = torch.Tensor(batch_list).long().to(tensor.device) 298 | batch_index = torch.arange(batch_size).to(tensor.device).repeat_interleave(batch_list) 299 | batch_index = batch_index.view((-1,) + (1,) * (tensor.dim() - 1)).expand_as(tensor) 300 | mean = torch.zeros(batch_size, *tensor.shape[1:]).to(tensor.device) 301 | mean = mean.scatter_add_(0, batch_index, tensor) 302 | mean = (mean.T / batch_list).T 303 | mean = mean.repeat_interleave(batch_list, dim=0) 304 | 305 | sub = tensor - mean * self.mean_scale 306 | 307 | std = torch.zeros(batch_size, *tensor.shape[1:]).to(tensor.device) 308 | std = std.scatter_add_(0, batch_index, sub.pow(2)) 309 | std = ((std.T / batch_list).T + 1e-6).sqrt() 310 | std = std.repeat_interleave(batch_list, dim=0) 311 | return self.weight * sub / std + self.bias 312 | -------------------------------------------------------------------------------- /NodeExp/utils/__pycache__/collect_env.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeXAIS/DDM/c8dbfee8830fa61a354cc7c2e0aef1e8fb3d015a/NodeExp/utils/__pycache__/collect_env.cpython-36.pyc -------------------------------------------------------------------------------- /NodeExp/utils/__pycache__/collect_env.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeXAIS/DDM/c8dbfee8830fa61a354cc7c2e0aef1e8fb3d015a/NodeExp/utils/__pycache__/collect_env.cpython-38.pyc -------------------------------------------------------------------------------- /NodeExp/utils/__pycache__/collect_env.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeXAIS/DDM/c8dbfee8830fa61a354cc7c2e0aef1e8fb3d015a/NodeExp/utils/__pycache__/collect_env.cpython-39.pyc -------------------------------------------------------------------------------- /NodeExp/utils/__pycache__/comm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeXAIS/DDM/c8dbfee8830fa61a354cc7c2e0aef1e8fb3d015a/NodeExp/utils/__pycache__/comm.cpython-36.pyc -------------------------------------------------------------------------------- /NodeExp/utils/__pycache__/comm.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeXAIS/DDM/c8dbfee8830fa61a354cc7c2e0aef1e8fb3d015a/NodeExp/utils/__pycache__/comm.cpython-38.pyc -------------------------------------------------------------------------------- /NodeExp/utils/__pycache__/comm.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeXAIS/DDM/c8dbfee8830fa61a354cc7c2e0aef1e8fb3d015a/NodeExp/utils/__pycache__/comm.cpython-39.pyc -------------------------------------------------------------------------------- /NodeExp/utils/__pycache__/logger.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeXAIS/DDM/c8dbfee8830fa61a354cc7c2e0aef1e8fb3d015a/NodeExp/utils/__pycache__/logger.cpython-36.pyc -------------------------------------------------------------------------------- /NodeExp/utils/__pycache__/logger.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeXAIS/DDM/c8dbfee8830fa61a354cc7c2e0aef1e8fb3d015a/NodeExp/utils/__pycache__/logger.cpython-38.pyc -------------------------------------------------------------------------------- /NodeExp/utils/__pycache__/logger.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeXAIS/DDM/c8dbfee8830fa61a354cc7c2e0aef1e8fb3d015a/NodeExp/utils/__pycache__/logger.cpython-39.pyc -------------------------------------------------------------------------------- /NodeExp/utils/__pycache__/misc.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeXAIS/DDM/c8dbfee8830fa61a354cc7c2e0aef1e8fb3d015a/NodeExp/utils/__pycache__/misc.cpython-36.pyc -------------------------------------------------------------------------------- /NodeExp/utils/__pycache__/misc.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeXAIS/DDM/c8dbfee8830fa61a354cc7c2e0aef1e8fb3d015a/NodeExp/utils/__pycache__/misc.cpython-38.pyc -------------------------------------------------------------------------------- /NodeExp/utils/__pycache__/misc.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/NeXAIS/DDM/c8dbfee8830fa61a354cc7c2e0aef1e8fb3d015a/NodeExp/utils/__pycache__/misc.cpython-39.pyc -------------------------------------------------------------------------------- /NodeExp/utils/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeXAIS/DDM/c8dbfee8830fa61a354cc7c2e0aef1e8fb3d015a/NodeExp/utils/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /NodeExp/utils/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeXAIS/DDM/c8dbfee8830fa61a354cc7c2e0aef1e8fb3d015a/NodeExp/utils/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /NodeExp/utils/algos.cpython-38-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeXAIS/DDM/c8dbfee8830fa61a354cc7c2e0aef1e8fb3d015a/NodeExp/utils/algos.cpython-38-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /NodeExp/utils/algos.cpython-39-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeXAIS/DDM/c8dbfee8830fa61a354cc7c2e0aef1e8fb3d015a/NodeExp/utils/algos.cpython-39-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /NodeExp/utils/algos.pyx: -------------------------------------------------------------------------------- 1 | 2 | import cython 3 | from cython.parallel cimport prange, parallel 4 | cimport numpy 5 | import numpy 6 | 7 | 8 | def floyd_warshall(adjacency_matrix): 9 | 10 | (nrows, ncols) = adjacency_matrix.shape 11 | assert nrows == ncols 12 | cdef unsigned int n = nrows 13 | 14 | adj_mat_copy = adjacency_matrix.astype(long, order='C', casting='safe', copy=True) 15 | assert adj_mat_copy.flags['C_CONTIGUOUS'] 16 | cdef numpy.ndarray[long, ndim=2, mode='c'] M = adj_mat_copy 17 | cdef numpy.ndarray[long, ndim=2, mode='c'] path = -1 * numpy.ones([n, n], dtype=numpy.int64) 18 | 19 | cdef unsigned int i, j, k 20 | cdef long M_ij, M_ik, cost_ikkj 21 | cdef long* M_ptr = &M[0,0] 22 | cdef long* M_i_ptr 23 | cdef long* M_k_ptr 24 | 25 | # set unreachable nodes distance to 510 26 | for i in range(n): 27 | for j in range(n): 28 | if i == j: 29 | M[i][j] = 0 30 | elif M[i][j] == 0: 31 | M[i][j] = 510 32 | 33 | # floyed algo 34 | for k in range(n): 35 | M_k_ptr = M_ptr + n*k 36 | for i in range(n): 37 | M_i_ptr = M_ptr + n*i 38 | M_ik = M_i_ptr[k] 39 | for j in range(n): 40 | cost_ikkj = M_ik + M_k_ptr[j] 41 | M_ij = M_i_ptr[j] 42 | if M_ij > cost_ikkj: 43 | M_i_ptr[j] = cost_ikkj 44 | path[i][j] = k 45 | 46 | # set unreachable path to 510 47 | for i in range(n): 48 | for j in range(n): 49 | if M[i][j] >= 510: 50 | path[i][j] = 510 51 | M[i][j] = 510 52 | 53 | return M, path 54 | 55 | 56 | def get_all_edges(path, i, j): 57 | cdef int k = path[i][j] 58 | if k == -1: 59 | return [] 60 | else: 61 | return get_all_edges(path, i, k) + [k] + get_all_edges(path, k, j) 62 | 63 | 64 | def gen_edge_input(max_dist, path, edge_feat): 65 | 66 | (nrows, ncols) = path.shape 67 | assert nrows == ncols 68 | cdef unsigned int n = nrows 69 | cdef unsigned int max_dist_copy = max_dist 70 | 71 | path_copy = path.astype(long, order='C', 
casting='safe', copy=True) 72 | edge_feat_copy = edge_feat.astype(long, order='C', casting='safe', copy=True) 73 | assert path_copy.flags['C_CONTIGUOUS'] 74 | assert edge_feat_copy.flags['C_CONTIGUOUS'] 75 | 76 | cdef numpy.ndarray[long, ndim=4, mode='c'] edge_fea_all = -1 * numpy.ones([n, n, max_dist_copy, edge_feat.shape[-1]], dtype=numpy.int64) 77 | cdef unsigned int i, j, k, num_path, cur 78 | 79 | for i in range(n): 80 | for j in range(n): 81 | if i == j: 82 | continue 83 | if path_copy[i][j] == 510: 84 | continue 85 | path = [i] + get_all_edges(path_copy, i, j) + [j] 86 | num_path = len(path) - 1 87 | for k in range(num_path): 88 | edge_fea_all[i, j, k, :] = edge_feat_copy[path[k], path[k+1], :] 89 | 90 | return edge_fea_all 91 | -------------------------------------------------------------------------------- /NodeExp/utils/build/temp.linux-x86_64-3.8/algos.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeXAIS/DDM/c8dbfee8830fa61a354cc7c2e0aef1e8fb3d015a/NodeExp/utils/build/temp.linux-x86_64-3.8/algos.o -------------------------------------------------------------------------------- /NodeExp/utils/build/temp.linux-x86_64-3.9/algos.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeXAIS/DDM/c8dbfee8830fa61a354cc7c2e0aef1e8fb3d015a/NodeExp/utils/build/temp.linux-x86_64-3.9/algos.o -------------------------------------------------------------------------------- /NodeExp/utils/collect_env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | import PIL 3 | 4 | from torch.utils.collect_env import get_pretty_env_info 5 | 6 | 7 | def get_pil_version(): 8 | return "\n Pillow ({})".format(PIL.__version__) 9 | 10 | 11 | def collect_env_info(): 12 | env_str = get_pretty_env_info() 13 | env_str += get_pil_version() 14 | return env_str 15 | -------------------------------------------------------------------------------- /NodeExp/utils/comm.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains primitives for multi-gpu communication. 3 | This is useful when doing distributed training. 
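When torch.distributed is unavailable or not initialized, the helpers below fall back to single-process behaviour (world size 1, rank 0, and synchronize/all_gather become no-ops).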
4 | """ 5 | import torch 6 | import torch.distributed as dist 7 | 8 | import functools 9 | import pickle 10 | import logging 11 | 12 | 13 | def get_world_size(): 14 | if not dist.is_available(): 15 | return 1 16 | if not dist.is_initialized(): 17 | return 1 18 | return dist.get_world_size() 19 | 20 | 21 | def get_rank(): 22 | if not dist.is_available(): 23 | return 0 24 | if not dist.is_initialized(): 25 | return 0 26 | return dist.get_rank() 27 | 28 | 29 | def is_main_process(): 30 | return get_rank() == 0 31 | 32 | 33 | def synchronize(): 34 | """ 35 | Helper function to synchronize (barrier) among all processes when 36 | using distributed training 37 | """ 38 | if not dist.is_available(): 39 | return 40 | if not dist.is_initialized(): 41 | return 42 | world_size = dist.get_world_size() 43 | if world_size == 1: 44 | return 45 | dist.barrier() 46 | 47 | 48 | def all_gather(data): 49 | """ 50 | Run all_gather on arbitrary picklable data (not necessarily tensors) 51 | Args: 52 | data: any picklable object 53 | Returns: 54 | list[data]: list of data gathered from each rank 55 | """ 56 | world_size = get_world_size() 57 | if world_size == 1: 58 | return [data] 59 | 60 | # serialized to a Tensor 61 | buffer = pickle.dumps(data) 62 | storage = torch.ByteStorage.from_buffer(buffer) 63 | tensor = torch.ByteTensor(storage).to("cuda") 64 | 65 | # obtain Tensor size of each rank 66 | local_size = torch.LongTensor([tensor.numel()]).to("cuda") 67 | size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)] 68 | dist.all_gather(size_list, local_size) 69 | size_list = [int(size.item()) for size in size_list] 70 | max_size = max(size_list) 71 | 72 | # receiving Tensor from all ranks 73 | # we pad the tensor because torch all_gather does not support 74 | # gathering tensors of different shapes 75 | tensor_list = [] 76 | for _ in size_list: 77 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda")) 78 | if local_size != max_size: 79 | padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda") 80 | tensor = torch.cat((tensor, padding), dim=0) 81 | dist.all_gather(tensor_list, tensor) 82 | 83 | data_list = [] 84 | for size, tensor in zip(size_list, tensor_list): 85 | buffer = tensor.cpu().numpy().tobytes()[:size] 86 | data_list.append(pickle.loads(buffer)) 87 | 88 | return data_list 89 | 90 | 91 | @functools.lru_cache() 92 | def _get_global_gloo_group(): 93 | """ 94 | Return a process group based on gloo backend, containing all the ranks 95 | The result is cached. 
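    A gloo group is used so that the serialized CPU byte tensors produced by _serialize_to_tensor can be gathered even when the default backend is NCCL.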
96 | """ 97 | if dist.get_backend() == "nccl": 98 | return dist.new_group(backend="gloo") 99 | else: 100 | return dist.group.WORLD 101 | 102 | 103 | def _serialize_to_tensor(data, group): 104 | backend = dist.get_backend(group) 105 | assert backend in ["gloo", "nccl"] 106 | device = torch.device("cpu" if backend == "gloo" else "cuda") 107 | 108 | buffer = pickle.dumps(data) 109 | if len(buffer) > 1024 ** 3: 110 | logger = logging.getLogger(__name__) 111 | logger.warning( 112 | "Rank {} trying to all-gather {:.2f} GB of data on device {}".format( 113 | get_rank(), len(buffer) / (1024 ** 3), device 114 | ) 115 | ) 116 | storage = torch.ByteStorage.from_buffer(buffer) 117 | tensor = torch.ByteTensor(storage).to(device=device) 118 | return tensor 119 | 120 | 121 | def _pad_to_largest_tensor(tensor, group): 122 | """ 123 | Returns: 124 | list[int]: size of the tensor, on each rank 125 | Tensor: padded tensor that has the max size 126 | """ 127 | world_size = dist.get_world_size(group=group) 128 | assert ( 129 | world_size >= 1 130 | ), "comm.gather/all_gather must be called from ranks within the given group!" 131 | local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device) 132 | size_list = [ 133 | torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size) 134 | ] 135 | dist.all_gather(size_list, local_size, group=group) 136 | size_list = [int(size.item()) for size in size_list] 137 | 138 | max_size = max(size_list) 139 | 140 | # we pad the tensor because torch all_gather does not support 141 | # gathering tensors of different shapes 142 | if local_size != max_size: 143 | padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device) 144 | tensor = torch.cat((tensor, padding), dim=0) 145 | return size_list, tensor 146 | 147 | 148 | def reduce_dict(input_dict, average=True): 149 | """ 150 | Args: 151 | input_dict (dict): all the values will be reduced 152 | average (bool): whether to do average or sum 153 | Reduce the values in the dictionary from all processes so that process with rank 154 | 0 has the averaged results. Returns a dict with the same fields as 155 | input_dict, after reduction. 156 | """ 157 | world_size = get_world_size() 158 | if world_size < 2: 159 | return input_dict 160 | with torch.no_grad(): 161 | names = [] 162 | values = [] 163 | # sort the keys so that they are consistent across processes 164 | for k in sorted(input_dict.keys()): 165 | names.append(k) 166 | values.append(input_dict[k]) 167 | values = torch.stack(values, dim=0) 168 | dist.reduce(values, dst=0) 169 | if dist.get_rank() == 0 and average: 170 | # only main process gets accumulated, so only divide by 171 | # world_size in this case 172 | values /= world_size 173 | reduced_dict = {k: v for k, v in zip(names, values)} 174 | return reduced_dict 175 | 176 | 177 | def gather(data, dst=0, group=None): 178 | """ 179 | Run gather on arbitrary picklable data (not necessarily tensors). 180 | 181 | Args: 182 | data: any picklable object 183 | dst (int): destination rank 184 | group: a torch process group. By default, will use a group which 185 | contains all ranks on gloo backend. 186 | 187 | Returns: 188 | list[data]: on dst, a list of data gathered from each rank. Otherwise, 189 | an empty list. 
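    Example (illustrative; `local_stats` is a placeholder name): results = gather(local_stats, dst=0) returns the full list of per-rank objects on rank 0 and [] on every other rank.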
190 | """ 191 | if get_world_size() == 1: 192 | return [data] 193 | if group is None: 194 | group = _get_global_gloo_group() 195 | if dist.get_world_size(group=group) == 1: 196 | return [data] 197 | rank = dist.get_rank(group=group) 198 | tensor = _serialize_to_tensor(data, group) 199 | size_list, tensor = _pad_to_largest_tensor(tensor, group) 200 | 201 | # receiving Tensor from all ranks 202 | if rank == dst: 203 | max_size = max(size_list) 204 | tensor_list = [ 205 | torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list 206 | ] 207 | dist.gather(tensor, tensor_list, dst=dst, group=group) 208 | 209 | data_list = [] 210 | for size, tensor in zip(size_list, tensor_list): 211 | buffer = tensor.cpu().numpy().tobytes()[:size] 212 | data_list.append(pickle.loads(buffer)) 213 | return data_list 214 | else: 215 | dist.gather(tensor, [], dst=dst, group=group) 216 | return [] 217 | -------------------------------------------------------------------------------- /NodeExp/utils/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | import logging 3 | import functools 4 | import os 5 | import sys 6 | from termcolor import colored 7 | 8 | 9 | class _ColorfulFormatter(logging.Formatter): 10 | def __init__(self, *args, **kwargs): 11 | self._root_name = kwargs.pop("root_name") + "." 12 | self._abbrev_name = kwargs.pop("abbrev_name", "") 13 | if len(self._abbrev_name): 14 | self._abbrev_name = self._abbrev_name + "." 15 | super(_ColorfulFormatter, self).__init__(*args, **kwargs) 16 | 17 | def formatMessage(self, record): 18 | record.name = record.name.replace(self._root_name, self._abbrev_name) 19 | log = super(_ColorfulFormatter, self).formatMessage(record) 20 | if record.levelno == logging.WARNING: 21 | prefix = colored("WARNING", "red", attrs=["blink"]) 22 | elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: 23 | prefix = colored("ERROR", "red", attrs=["blink", "underline"]) 24 | else: 25 | return log 26 | return prefix + " " + log 27 | 28 | 29 | @functools.lru_cache() # so that calling setup_logger multiple times won't add many handlers 30 | def setup_logger(name, save_dir, distributed_rank, filename="log.txt", color=True, abbrev_name=None): 31 | logger = logging.getLogger(name) 32 | logger.setLevel(logging.DEBUG) 33 | logger.propagate = False 34 | 35 | if abbrev_name is None: 36 | abbrev_name = "ugait" if name == "ugait" else name 37 | plain_formatter = logging.Formatter( 38 | "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S" 39 | ) 40 | 41 | # don't log results for the non-master process 42 | if distributed_rank > 0: 43 | return logger 44 | ch = logging.StreamHandler(stream=sys.stdout) 45 | ch.setLevel(logging.DEBUG) 46 | if color: 47 | formatter = _ColorfulFormatter( 48 | colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s", 49 | datefmt="%m/%d %H:%M:%S", 50 | root_name=name, 51 | abbrev_name=str(abbrev_name), 52 | ) 53 | else: 54 | formatter = plain_formatter 55 | ch.setFormatter(formatter) 56 | logger.addHandler(ch) 57 | 58 | if save_dir: 59 | fh = logging.FileHandler(os.path.join(save_dir, filename)) 60 | fh.setLevel(logging.DEBUG) 61 | fh.setFormatter(formatter) 62 | logger.addHandler(fh) 63 | return logger 64 | -------------------------------------------------------------------------------- /NodeExp/utils/metric_logger.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | import torch 3 | 4 | from collections import defaultdict, deque 5 | 6 | 7 | class SmoothedValue(object): 8 | """Track a series of values and provide access to smoothed values over a 9 | window or the global series average. 10 | """ 11 | 12 | def __init__(self, window_size=20): 13 | self.deque = deque(maxlen=window_size) 14 | self.series = [] 15 | self.total = 0.0 16 | self.count = 0 17 | 18 | def update(self, value): 19 | self.deque.append(value) 20 | self.series.append(value) 21 | self.count += 1 22 | self.total += value 23 | 24 | @property 25 | def median(self): 26 | d = torch.tensor(list(self.deque)) 27 | return d.median().item() 28 | 29 | @property 30 | def avg(self): 31 | d = torch.tensor(list(self.deque)) 32 | return d.mean().item() 33 | 34 | @property 35 | def global_avg(self): 36 | return self.total / self.count 37 | 38 | 39 | class MetricLogger(object): 40 | def __init__(self, delimiter="\t"): 41 | self.meters = defaultdict(SmoothedValue) 42 | self.delimiter = delimiter 43 | 44 | def update(self, **kwargs): 45 | for k, v in kwargs.items(): 46 | if isinstance(v, torch.Tensor): 47 | v = v.item() 48 | assert isinstance(v, (float, int)) 49 | self.meters[k].update(v) 50 | 51 | def __getattr__(self, attr): 52 | if attr in self.meters: 53 | return self.meters[attr] 54 | if attr in self.__dict__: 55 | return self.__dict__[attr] 56 | raise AttributeError("'{}' object has no attribute '{}'".format( 57 | type(self).__name__, attr)) 58 | 59 | def __str__(self): 60 | loss_str = [] 61 | for name, meter in self.meters.items(): 62 | loss_str.append( 63 | "{}: {:.4f}".format(name, meter.global_avg) 64 | ) 65 | return self.delimiter.join(loss_str) 66 | -------------------------------------------------------------------------------- /NodeExp/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Tencent, Inc. and its affiliates. All Rights Reserved. 
2 | import errno 3 | import logging 4 | import os 5 | 6 | 7 | def mkdir(path): 8 | try: 9 | os.makedirs(path) 10 | except OSError as e: 11 | if e.errno != errno.EEXIST: 12 | raise 13 | 14 | 15 | def link_file(src, target): 16 | """symbol link the source directories to target.""" 17 | if os.path.isdir(target) or os.path.isfile(target): 18 | os.remove(target) 19 | os.system('ln -s {} {}'.format(src, target)) 20 | 21 | 22 | def print_model_parameters(model,logger, only_num=True): 23 | logger.info('*****************Model Parameter*****************') 24 | if not only_num: 25 | for name, param in model.named_parameters(): 26 | logger.info(name, param.shape, param.requires_grad) 27 | total_num = sum([param.nelement() for param in model.parameters()]) 28 | logger.info('Total params num: {}'.format(total_num)) 29 | logger.info('*****************Finish Parameter****************') 30 | -------------------------------------------------------------------------------- /NodeExp/utils/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File Name: setup.py 4 | # Author: Yang Run 5 | # Created Time: 2022-11-08 00:41 6 | # Last Modified: - 7 | 8 | from distutils.core import setup 9 | from Cython.Build import cythonize 10 | import numpy as np 11 | 12 | setup( 13 | ext_modules=cythonize('algos.pyx'), 14 | include_dirs=[np.get_include()] 15 | ) 16 | -------------------------------------------------------------------------------- /NodeExp/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import random 4 | import yaml 5 | import logging 6 | from functools import partial 7 | import numpy as np 8 | 9 | import dgl 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch import optim as optim 14 | from dgl.nn.pytorch.glob import SumPooling, AvgPooling, MaxPooling 15 | from scipy.linalg import fractional_matrix_power, inv 16 | 17 | from sklearn.metrics import f1_score 18 | 19 | 20 | def compute_ppr(graph, alpha=0.2, self_loop=False): 21 | a = graph.adj().to_dense().numpy() 22 | if self_loop: 23 | a = a + np.eye(a.shape[0]) # A^ = A + I_n 24 | d = np.diag(np.sum(a, 1)) # D^ = Sigma A^_ii 25 | dinv = fractional_matrix_power(d, -0.5) # D^(-1/2) 26 | at = np.matmul(np.matmul(dinv, a), dinv) # A~ = D^(-1/2) x A^ x D^(-1/2) 27 | return alpha * inv((np.eye(a.shape[0]) - (1 - alpha) * at)) # a(I_n-(1-a)A~)^-1 28 | 29 | 30 | def accuracy(y_pred, y_true): 31 | y_true = y_true.squeeze().long().cpu().numpy() 32 | preds = y_pred.max(1)[1].long().cpu().numpy() 33 | f1 = f1_score(y_true, preds, average='micro') 34 | # correct = preds.eq(y_true).double() 35 | # correct = correct.sum().item() 36 | # return correct / len(y_true) 37 | return f1 38 | 39 | 40 | def set_random_seed(seed): 41 | random.seed(seed) 42 | np.random.seed(seed) 43 | torch.manual_seed(seed) 44 | torch.cuda.manual_seed(seed) 45 | torch.cuda.manual_seed_all(seed) 46 | torch.backends.cudnn.determinstic = True 47 | # dgl.random.seed(seed) 48 | 49 | 50 | def get_current_lr(optimizer): 51 | return optimizer.state_dict()["param_groups"][0]["lr"] 52 | 53 | 54 | def build_args(): 55 | parser = argparse.ArgumentParser(description="GAT") 56 | parser.add_argument("--seeds", type=int, nargs="+", default=[0]) 57 | parser.add_argument("--dataset", type=str, default="cora") 58 | parser.add_argument("--device", type=int, default=-1) 59 | parser.add_argument("--max_epoch", type=int, 
default=200, 60 | help="number of training epochs") 61 | parser.add_argument("--warmup_steps", type=int, default=-1) 62 | 63 | parser.add_argument("--num_heads", type=int, default=4, 64 | help="number of hidden attention heads") 65 | parser.add_argument("--num_out_heads", type=int, default=1, 66 | help="number of output attention heads") 67 | parser.add_argument("--num_layers", type=int, default=2, 68 | help="number of hidden layers") 69 | parser.add_argument("--num_hidden", type=int, default=256, 70 | help="number of hidden units") 71 | parser.add_argument("--residual", action="store_true", default=False, 72 | help="use residual connection") 73 | parser.add_argument("--in_drop", type=float, default=.2, 74 | help="input feature dropout") 75 | parser.add_argument("--attn_drop", type=float, default=.1, 76 | help="attention dropout") 77 | parser.add_argument("--norm", type=str, default=None) 78 | parser.add_argument("--lr", type=float, default=0.005, 79 | help="learning rate") 80 | parser.add_argument("--weight_decay", type=float, default=5e-4, 81 | help="weight decay") 82 | parser.add_argument("--negative_slope", type=float, default=0.2, 83 | help="the negative slope of leaky relu for GAT") 84 | parser.add_argument("--activation", type=str, default="prelu") 85 | parser.add_argument("--mask_rate", type=float, default=0.5) 86 | parser.add_argument("--drop_edge_rate", type=float, default=0.0) 87 | parser.add_argument("--replace_rate", type=float, default=0.0) 88 | 89 | parser.add_argument("--encoder", type=str, default="gat") 90 | parser.add_argument("--decoder", type=str, default="gat") 91 | parser.add_argument("--loss_fn", type=str, default="sce") 92 | parser.add_argument("--alpha_l", type=float, default=2, help="`pow`coefficient for `sce` loss") 93 | parser.add_argument("--optimizer", type=str, default="adam") 94 | 95 | parser.add_argument("--max_epoch_f", type=int, default=30) 96 | parser.add_argument("--lr_f", type=float, default=0.001, help="learning rate for evaluation") 97 | parser.add_argument("--weight_decay_f", type=float, default=0.0, help="weight decay for evaluation") 98 | parser.add_argument("--linear_prob", action="store_true", default=False) 99 | 100 | parser.add_argument("--load_model", action="store_true") 101 | parser.add_argument("--save_model", action="store_true") 102 | parser.add_argument("--use_cfg", action="store_true") 103 | parser.add_argument("--logging", action="store_true") 104 | parser.add_argument("--scheduler", action="store_true", default=False) 105 | parser.add_argument("--concat_hidden", action="store_true", default=False) 106 | 107 | # for graph classification 108 | parser.add_argument("--pooling", type=str, default="mean") 109 | parser.add_argument("--deg4feat", action="store_true", default=False, help="use node degree as input feature") 110 | parser.add_argument("--batch_size", type=int, default=32) 111 | args = parser.parse_args() 112 | return args 113 | 114 | 115 | def create_activation(name): 116 | if name == "relu": 117 | return nn.ReLU() 118 | elif name == "gelu": 119 | return nn.GELU() 120 | elif name == "prelu": 121 | return nn.PReLU() 122 | elif name is None: 123 | return nn.Identity() 124 | elif name == "elu": 125 | return nn.ELU() 126 | else: 127 | raise NotImplementedError(f"{name} is not implemented.") 128 | 129 | 130 | def create_norm(name): 131 | if name == "layernorm": 132 | return nn.LayerNorm 133 | elif name == "batchnorm": 134 | return nn.BatchNorm1d 135 | elif name == "graphnorm": 136 | return partial(NormLayer, norm_type="groupnorm") 
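    # Note: any other value (including None) falls through to nn.Identity in the else branch below.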
137 | else: 138 | return nn.Identity 139 | 140 | 141 | def create_optimizer(opt, model, lr, weight_decay, get_num_layer=None, get_layer_scale=None): 142 | opt_lower = opt.lower() 143 | 144 | parameters = model.parameters() 145 | opt_args = dict(lr=lr, weight_decay=weight_decay) 146 | 147 | opt_split = opt_lower.split("_") 148 | opt_lower = opt_split[-1] 149 | if opt_lower == "adam": 150 | optimizer = optim.Adam(parameters, **opt_args) 151 | elif opt_lower == "adamw": 152 | optimizer = optim.AdamW(parameters, **opt_args) 153 | elif opt_lower == "adadelta": 154 | optimizer = optim.Adadelta(parameters, **opt_args) 155 | elif opt_lower == "radam": 156 | optimizer = optim.RAdam(parameters, **opt_args) 157 | elif opt_lower == "sgd": 158 | opt_args["momentum"] = 0.9 159 | return optim.SGD(parameters, **opt_args) 160 | else: 161 | assert False and "Invalid optimizer" 162 | 163 | return optimizer 164 | 165 | 166 | def create_pooler(pooling): 167 | if pooling == "mean": 168 | pooler = AvgPooling() 169 | elif pooling == "max": 170 | pooler = MaxPooling() 171 | elif pooling == "sum": 172 | pooler = SumPooling() 173 | else: 174 | raise NotImplementedError 175 | return pooler 176 | 177 | 178 | 179 | # ------------------- 180 | def mask_edge(graph, mask_prob): 181 | E = graph.num_edges() 182 | 183 | mask_rates = torch.FloatTensor(np.ones(E) * mask_prob) 184 | masks = torch.bernoulli(1 - mask_rates) 185 | mask_idx = masks.nonzero().squeeze(1) 186 | return mask_idx 187 | 188 | def make_edge_weights(graph): 189 | E = graph.num_edges() 190 | weights = torch.FloatTensor(np.ones(E)) 191 | return weights 192 | 193 | def make_noisy_edge_weights(graph): 194 | E = graph.num_edges() 195 | weights = torch.FloatTensor(torch.rand(E)) 196 | return weights 197 | 198 | 199 | 200 | def drop_edge(graph, drop_rate, return_edges=False): 201 | if drop_rate <= 0: 202 | return graph 203 | 204 | n_node = graph.num_nodes() 205 | edge_mask = mask_edge(graph, drop_rate) 206 | src = graph.edges()[0] 207 | dst = graph.edges()[1] 208 | 209 | nsrc = src[edge_mask] 210 | ndst = dst[edge_mask] 211 | 212 | ng = dgl.graph((nsrc, ndst), num_nodes=n_node) 213 | ng = ng.add_self_loop() 214 | 215 | dsrc = src[~edge_mask] 216 | ddst = dst[~edge_mask] 217 | 218 | if return_edges: 219 | return ng, (dsrc, ddst) 220 | return ng 221 | 222 | 223 | def load_best_configs(args, path): 224 | with open(path, "r") as f: 225 | configs = yaml.load(f, yaml.FullLoader) 226 | 227 | if args.dataset not in configs: 228 | logging.info("Best args not found") 229 | return args 230 | 231 | logging.info("Using best configs") 232 | configs = configs[args.dataset] 233 | 234 | for k, v in configs.items(): 235 | if "lr" in k or "weight_decay" in k: 236 | v = float(v) 237 | setattr(args, k, v) 238 | print("------ Use best configs ------") 239 | return args 240 | 241 | 242 | # ------ logging ------ 243 | 244 | class TBLogger(object): 245 | def __init__(self, log_path="./logging_data", name="run"): 246 | super(TBLogger, self).__init__() 247 | 248 | if not os.path.exists(log_path): 249 | os.makedirs(log_path, exist_ok=True) 250 | 251 | self.last_step = 0 252 | self.log_path = log_path 253 | raw_name = os.path.join(log_path, name) 254 | name = raw_name 255 | for i in range(1000): 256 | name = raw_name + str(f"_{i}") 257 | if not os.path.exists(name): 258 | break 259 | self.writer = SummaryWriter(logdir=name) 260 | 261 | def note(self, metrics, step=None): 262 | if step is None: 263 | step = self.last_step 264 | for key, value in metrics.items(): 265 | 
self.writer.add_scalar(key, value, step) 266 | self.last_step = step 267 | 268 | def finish(self): 269 | self.writer.close() 270 | 271 | 272 | class NormLayer(nn.Module): 273 | def __init__(self, hidden_dim, norm_type): 274 | super().__init__() 275 | if norm_type == "batchnorm": 276 | self.norm = nn.BatchNorm1d(hidden_dim) 277 | elif norm_type == "layernorm": 278 | self.norm = nn.LayerNorm(hidden_dim) 279 | elif norm_type == "graphnorm": 280 | self.norm = norm_type 281 | self.weight = nn.Parameter(torch.ones(hidden_dim)) 282 | self.bias = nn.Parameter(torch.zeros(hidden_dim)) 283 | 284 | self.mean_scale = nn.Parameter(torch.ones(hidden_dim)) 285 | else: 286 | raise NotImplementedError 287 | 288 | def forward(self, graph, x): 289 | tensor = x 290 | if self.norm is not None and type(self.norm) != str: 291 | return self.norm(tensor) 292 | elif self.norm is None: 293 | return tensor 294 | 295 | batch_list = graph.batch_num_nodes 296 | batch_size = len(batch_list) 297 | batch_list = torch.Tensor(batch_list).long().to(tensor.device) 298 | batch_index = torch.arange(batch_size).to(tensor.device).repeat_interleave(batch_list) 299 | batch_index = batch_index.view((-1,) + (1,) * (tensor.dim() - 1)).expand_as(tensor) 300 | mean = torch.zeros(batch_size, *tensor.shape[1:]).to(tensor.device) 301 | mean = mean.scatter_add_(0, batch_index, tensor) 302 | mean = (mean.T / batch_list).T 303 | mean = mean.repeat_interleave(batch_list, dim=0) 304 | 305 | sub = tensor - mean * self.mean_scale 306 | 307 | std = torch.zeros(batch_size, *tensor.shape[1:]).to(tensor.device) 308 | std = std.scatter_add_(0, batch_index, sub.pow(2)) 309 | std = ((std.T / batch_list).T + 1e-6).sqrt() 310 | std = std.repeat_interleave(batch_list, dim=0) 311 | return self.weight * sub / std + self.bias 312 | -------------------------------------------------------------------------------- /NodeExp/yamls/photo.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | data_name: photo 3 | 4 | DATALOADER: 5 | NUM_WORKERS: 0 6 | 7 | MODEL: 8 | num_hidden: 512 9 | num_layers: 2 10 | nhead: 2 11 | activation: prelu 12 | attn_drop: 0.3 13 | feat_drop: 0.3 14 | norm: layernorm 15 | pooler: mean 16 | beta_schedule: const 17 | beta_1: 0.0003185531168122948 18 | beta_T: 0.02895219089515978 19 | T: 520 20 | 21 | SOLVER: 22 | optim_type: adam 23 | optim_type_f: adamw 24 | alpha: 1 25 | decay: 40 26 | LR: 0.0004108003233753939 27 | LR_f: 0.0002802608864050129 28 | weight_decay: 0 29 | weight_decay_f: 0.0000043895092766227186 30 | # MAX_EPOCH: 150 31 | MAX_EPOCH: 100 32 | max_epoch_f: 100 33 | 34 | DEVICE: cuda 35 | seeds: 36 | - 0 37 | eval_T: 38 | - 50 39 | - 100 40 | - 200 41 | 42 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Directional Diffusion Models 2 | ## [NeurIPS 2023](https://arxiv.org/abs/2306.13210) 3 | 4 | Run Yang1, Yuling Yang1, Fan Zhou1, Qiang Sun2
5 | 1Shanghai University of Finance and Economics, 2University of Toronto 6 | 7 | We introduce a novel class of models termed **directional diffusion models (DDM)**, which adopt data-dependent, anisotropic, and directional noises in the forward diffusion process. This code is an implementation of **DDM** on 12 public graph datasets. 8 | ### Graph classification datasets 9 | - IMDB-B 10 | - IMDB-M 11 | - COLLAB 12 | - REDDIT-B 13 | - PROTEINS 14 | - MUTAG 15 | ### Node classification datasets 16 | - CORA 17 | - Citeseer 18 | - PubMed 19 | - Ogbn-arxiv 20 | - Amazon-Computer 21 | - Amazon-Photo 22 | 23 | ## Framework 24 | ![framework](./framework.png) 25 | ## Usage 26 | ```shell 27 | conda create -n ddm python=3.8 28 | conda activate ddm 29 | cd ddm-nni 30 | pip install -r requirements.txt 31 | ``` 32 | 33 | cd to the experiment directory (```MUTAG``` for example): 34 | ```shell 35 | cd GraphExp 36 | python main_graph.py --yaml_dir ./yamls/MUTAG.yaml 37 | ``` 38 | **Since diffusion-based methods are sensitive to hyperparameters, we recommend using a hyperparameter search tool such as NNI to obtain better results (see ```nni_search/run_search.py```; a minimal launch sketch is also given at the end of this document).** 39 | **With such a search, you can often achieve better results than those reported in the paper.** 40 | 41 | ## Performance 42 | ### Directional noise vs. white noise 43 | ![noise](./noise_com.png) 44 | ### Graph classification (F1-score) 45 | | |IMDB-B|IMDB-M|COLLAB|REDDIT-B|PROTEINS|MUTAG| 46 | |:---:|:----:|:----:|:----:|:------:|:------:|:---:| 47 | |GIN[1] | 75.1±5.1 | 52.3±2.8 | 80.2±1.9 | 92.4±2.5 | 76.2±2.8 | 89.4±5.6 | 48 | |DiffPool[2]| 72.6±3.9 | - | 78.9±2.3 | 92.1±2.6 | 75.1±2.3 | 85.0±10.3 | 49 | |Infograph[3] | 73.03±0.87| 49.69±0.53| 70.65±1.13 | 82.50±1.42 | 74.44±0.31| 89.01±1.13 | 50 | |GraphCL[4] | 71.14±0.44| 48.58±0.67| 71.36±1.15 | 89.53±0.84 | 74.39±0.45| 86.80±1.34 | 51 | |JOAO[5] | 70.21±3.08| 49.20±0.77| 69.50±0.36 | 85.29±1.35 | 74.55±0.41| 87.35±1.02 | 52 | |GCC[6] | 72 | 49.4 | 78.9 | 89.8 | - | - | 53 | |MVGRL[7] | 74.20±0.70| 51.20±0.50| - | 84.50±0.60 | - | 89.70±1.10 | 54 | |GraphMAE[8]| 75.52±0.66| 51.63±0.52| 80.32±0.46 | 88.01±0.19 | 75.30±0.39| 88.19±1.26 | 55 | |**DDM** |**76.40±0.22**|**52.53±0.31**|**81.72±0.31**|89.15±1.3|**75.74±0.50**|**91.51±1.45**| 56 | ### Node classification (F1-score) 57 | |Dataset | Cora | Citeseer | PubMed | Ogbn-arxiv | Computer| Photo | 58 | |:---:|:----:|:----:|:----:|:------:|:------:|:---:| 59 | |GAT | 83.0 ± 0.7 | 72.5 ± 0.7 | 79.0 ± 0.3 | 72.10 ± 0.13 | 86.93 ± 0.29 | 92.56 ± 0.35 | 60 | |DGI[9] | 82.3 ± 0.6 | 71.8 ± 0.7 | 76.8 ± 0.6 | 70.34 ± 0.16 | 83.95 ± 0.47 | 91.61 ± 0.22 | 61 | |MVGRL[7] | 83.5 ± 0.4 | 73.3 ± 0.5 | 80.1 ± 0.7 | - | 87.52 ± 0.11 | 91.74 ± 0.07 | 62 | |BGRL[10] | 82.7 ± 0.6 | 71.1 ± 0.8 | 79.6 ± 0.5 | 71.64 ± 0.12 | 89.68 ± 0.31 | 92.87 ± 0.27 | 63 | |InfoGCL[11] | 83.5 ± 0.3 | 73.5 ± 0.4 | 79.1 ± 0.2 | - | - | - | 64 | |CCA-SSG[12] | 84.0 ± 0.4 | 73.1 ± 0.3 | 81.0 ± 0.4 | 71.24 ± 0.20 | 88.74 ± 0.28 | 93.14 ± 0.14 | 65 | |GPT-GNN[13] | 80.1 ± 1.0 | 68.4 ± 1.6 | 76.3 ± 0.8 | - | - | - | 66 | |GraphMAE[8] | 84.2 ± 0.4 | 73.4 ± 0.4 | 81.1 ± 0.4 | 71.75 ± 0.17 | 88.63 ± 0.17 | 93.63 ± 0.22 | 67 | |**DDM** |**83.4±0.2**|**74.3±0.3**|**81.7±0.8**|71.29±0.18|**90.56±0.21**|**95.09±0.18**| 68 | 69 | ## References 70 | 71 | [1]:Xu, K., Hu, W., Leskovec, J., and Jegelka, S. (2018). How powerful are graph neural networks? 72 | arXiv preprint arXiv:1810.00826.<br>
73 | [2]:Ying, Z., You, J., Morris, C., Ren, X., Hamilton, W., and Leskovec, J. (2018). Hierarchical graph 74 | representation learning with differentiable pooling. Advances in neural information processing 75 | systems, 31.
76 | [3]:Sun, F.-Y., Hoffmann, J., Verma, V., and Tang, J. (2019). Infograph: Unsupervised and 77 | semi-supervised graph-level representation learning via mutual information maximization. arXiv preprint 78 | arXiv:1908.01000.<br>
79 | [4]:You, Y., Chen, T., Sui, Y., Chen, T., Wang, Z., and Shen, Y. (2020). Graph contrastive learning with 80 | augmentations. Advances in neural information processing systems, 33:5812–5823.
81 | [5]:You, Y., Chen, T., Shen, Y., and Wang, Z. (2021). Graph contrastive learning automated. In 82 | International Conference on Machine Learning, pages 12121–12132. PMLR.
83 | [6]:Qiu, J., Chen, Q., Dong, Y., Zhang, J., Yang, H., Ding, M., Wang, K., and Tang, J. (2020). Gcc: 84 | Graph contrastive coding for graph neural network pre-training. In Proceedings of the 26th ACM 85 | SIGKDD international conference on knowledge discovery & data mining, pages 1150–1160.
86 | [7]:Hassani, K. and Khasahmadi, A. H. (2020). Contrastive multi-view representation learning on graphs. 87 | In International conference on machine learning, pages 4116–4126. PMLR.
88 | [8]:Hou, Z., Liu, X., Dong, Y., Wang, C., Tang, J., et al. (2022). Graphmae: Self-supervised masked 89 | graph autoencoders. arXiv preprint arXiv:2205.10803.
90 | [9]:Veličković, P., Fedus, W., Hamilton, W. L., Liò, P., Bengio, Y., and Hjelm, R. D. (2019). Deep graph 91 | infomax. ICLR (Poster), 2(3):4.<br>
92 | [10]:Thakoor, S., Tallec, C., Azar, M. G., Azabou, M., Dyer, E. L., Munos, R., Veličković, P., and 93 | Valko, M. (2021). Large-scale representation learning on graphs via bootstrapping. arXiv preprint 94 | arXiv:2102.06514.<br>
95 | [11]:Xu, D., Cheng, W., Luo, D., Chen, H., and Zhang, X. (2021). Infogcl: Information-aware graph 96 | contrastive learning. Advances in Neural Information Processing Systems, 34:30414–30425.
97 | [12]:Zhang, H., Wu, Q., Yan, J., Wipf, D., and Yu, P. S. (2021). From canonical correlation analysis 98 | to self-supervised graph neural networks. Advances in Neural Information Processing Systems, 99 | 34:76–89.
100 | [13]:Hu, Z., Dong, Y., Wang, K., Chang, K.-W., and Sun, Y. (2020b). Gpt-gnn: Generative pre-training of 101 | graph neural networks. In Proceedings of the 26th ACM SIGKDD International Conference on 102 | Knowledge Discovery & Data Mining, pages 1857–1867. 103 | 104 | 105 | -------------------------------------------------------------------------------- /framework.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeXAIS/DDM/c8dbfee8830fa61a354cc7c2e0aef1e8fb3d015a/framework.pdf -------------------------------------------------------------------------------- /framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeXAIS/DDM/c8dbfee8830fa61a354cc7c2e0aef1e8fb3d015a/framework.png -------------------------------------------------------------------------------- /json_config/amazoncobuycomputer.json: -------------------------------------------------------------------------------- 1 | { 2 | "feat_drop": 0.4, 3 | "attn_drop": 0.2, 4 | "nhead": 4, 5 | "num_hidden": 512, 6 | "LR": 0.00018765505300375018, 7 | "LR_f": 0.03615207171929937, 8 | "weight_decay_f": 0.0007562446446127072, 9 | "alpha": 1, 10 | "alpha_l": 1, 11 | "decay": 20, 12 | "norm": "batchnorm", 13 | "beta_schedule": "quad", 14 | "beta_1": 0.0003068153897945834, 15 | "beta_T": 0.007175081620921913, 16 | "T": 688.8956772672298 17 | } -------------------------------------------------------------------------------- /json_config/amazoncobuypho.json: -------------------------------------------------------------------------------- 1 | { 2 | "feat_drop": 0.2, 3 | "attn_drop": 0.30000000000000004, 4 | "nhead": 4, 5 | "num_hidden": 1024, 6 | "LR": 0.00014911063734877325, 7 | "LR_f": 0.001868720327448783, 8 | "weight_decay_f": 0.000016353299469262238, 9 | "alpha": 1, 10 | "alpha_l": 1, 11 | "decay": 20, 12 | "norm": "batchnorm", 13 | "beta_schedule": "sigmoid", 14 | "beta_1": 0.00009064518879876235, 15 | "beta_T": 0.06145430769759989, 16 | "T": 423.57693268035916 17 | } -------------------------------------------------------------------------------- /json_config/cora-832.json: -------------------------------------------------------------------------------- 1 | { 2 | "feat_drop": 0.1, 3 | "attn_drop": 0.30000000000000004, 4 | "nhead": 4, 5 | "num_hidden": 1024, 6 | "LR": 0.000059052598224579665, 7 | "LR_f": 0.0008306387422633891, 8 | "weight_decay_f": 0.000004357724640129446, 9 | "alpha": 1, 10 | "alpha_l": 3, 11 | "decay": 30, 12 | "norm": "layernorm", 13 | "beta_schedule": "sigmoid", 14 | "beta_1": 0.000498919890653928, 15 | "beta_T": 0.052885733911001744, 16 | "T": 870.1535571539575 17 | } -------------------------------------------------------------------------------- /json_config/imdbb.json: -------------------------------------------------------------------------------- 1 | { 2 | "feat_drop": 0.4, 3 | "attn_drop": 0.4, 4 | "nhead": 2, 5 | "num_hidden": 128, 6 | "LR": 0.00009091421925639108, 7 | "alpha": 0.8, 8 | "decay": 40, 9 | "norm": "layernorm", 10 | "beta_schedule": "sigmoid", 11 | "beta_1": 0.0003539425420564767, 12 | "beta_T": 0.04332511993329242, 13 | "T": 278.6269855370546 14 | } -------------------------------------------------------------------------------- /json_config/mutag.json: -------------------------------------------------------------------------------- 1 | { 2 | "feat_drop": 0.2, 3 | "attn_drop": 0.1, 4 | "nhead": 4, 5 | "LR": 0.00029244475954904524, 6 | "alpha": 1, 7 | 
"decay": 30, 8 | "norm": "layernorm", 9 | "beta_schedule": "sigmoid", 10 | "beta_1": 0.0003357064467122699, 11 | "beta_T": 0.03379792460462926, 12 | "T": 728.219972794016 13 | } -------------------------------------------------------------------------------- /json_config/pubmed.json: -------------------------------------------------------------------------------- 1 | { 2 | "feat_drop": 0.2, 3 | "attn_drop": 0.4, 4 | "nhead": 4, 5 | "num_hidden": 1024, 6 | "LR": 0.00017533667447018291, 7 | "LR_f": 0.0011968000247534107, 8 | "weight_decay_f": 0.0008085461520382234, 9 | "alpha": 1, 10 | "alpha_l": 2, 11 | "decay": 10, 12 | "norm": "layernorm", 13 | "beta_schedule": "const", 14 | "beta_1": 0.00045192826273998145, 15 | "beta_T": 0.09275724691731281, 16 | "T": 408.8727660489 17 | } -------------------------------------------------------------------------------- /json_config/reddit-b.json: -------------------------------------------------------------------------------- 1 | { 2 | "feat_drop": 0.4, 3 | "attn_drop": 0.2, 4 | "nhead": 4, 5 | "num_hidden": 512, 6 | "LR": 0.00018765505300375018, 7 | "LR_f": 0.03615207171929937, 8 | "weight_decay_f": 0.0007562446446127072, 9 | "alpha": 1, 10 | "alpha_l": 1, 11 | "decay": 20, 12 | "norm": "batchnorm", 13 | "beta_schedule": "quad", 14 | "beta_1": 0.0003068153897945834, 15 | "beta_T": 0.007175081620921913, 16 | "T": 688.8956772672298 17 | } -------------------------------------------------------------------------------- /nni_search/main_graph.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File Name: main_graph.py 4 | # Author: Yang Run 5 | # Created Time: 2022-10-28 22:32 6 | # Last Modified: - 7 | import numpy as np 8 | 9 | import argparse 10 | 11 | import shutil 12 | import time 13 | import os.path as osp 14 | 15 | import dgl 16 | from dgl.nn.pytorch.glob import SumPooling, AvgPooling, MaxPooling 17 | from dgl.dataloading import GraphDataLoader 18 | from dgl import RandomWalkPE 19 | 20 | import torch 21 | from torch.utils.data.sampler import SubsetRandomSampler 22 | import torch.nn as nn 23 | from dgl.nn.functional import edge_softmax 24 | 25 | from sklearn.model_selection import StratifiedKFold, GridSearchCV 26 | from sklearn.svm import SVC 27 | from sklearn.metrics import f1_score 28 | 29 | from utils.utils import (create_optimizer, create_pooler, set_random_seed, compute_ppr) 30 | 31 | from datasets.data_util import load_graph_classification_dataset 32 | 33 | from models import DDM 34 | 35 | # from config import config as cfg 36 | import multiprocessing 37 | from multiprocessing import Pool 38 | 39 | 40 | from utils import comm 41 | from utils.collect_env import collect_env_info 42 | from utils.logger import setup_logger 43 | from utils.misc import mkdir 44 | 45 | from evaluator import graph_classification_evaluation 46 | import yaml 47 | import nni 48 | from easydict import EasyDict as edict 49 | 50 | 51 | parser = argparse.ArgumentParser(description='Graph DGL Training') 52 | parser.add_argument('--resume', '-r', action='store_true', default=False, 53 | help='resume from checkpoint') 54 | parser.add_argument("--local_rank", type=int, default=0, help="local rank") 55 | parser.add_argument("--seed", type=int, default=1234, help="random seed") 56 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 57 | help='manual epoch number (useful on restarts)') 58 | args = parser.parse_args() 59 | 60 | 61 | def pretrain(model, train_loader, 
optimizer, device, epoch, logger=None): 62 | model.train() 63 | loss_list = [] 64 | for batch in train_loader: 65 | batch_g, _ = batch 66 | batch_g = batch_g.to(device) 67 | feat = batch_g.ndata["attr"] 68 | loss, loss_dict = model(batch_g, feat) 69 | optimizer.zero_grad() 70 | loss.backward() 71 | optimizer.step() 72 | loss_list.append(loss.item()) 73 | lr = optimizer.param_groups[0]['lr'] 74 | logger.info(f"Epoch {epoch} | train_loss: {np.mean(loss_list):.4f} | lr: {lr:.6f}") 75 | 76 | 77 | def collate_fn(batch): 78 | graphs = [x[0] for x in batch] 79 | labels = [x[1] for x in batch] 80 | batch_g = dgl.batch(graphs) 81 | labels = torch.cat(labels, dim=0) 82 | return batch_g, labels 83 | 84 | 85 | def save_checkpoint(state, is_best, filename): 86 | ckp = osp.join(filename, 'checkpoint.pth.tar') 87 | # ckp = filename + "checkpoint.pth.tar" 88 | torch.save(state, ckp) 89 | if is_best: 90 | shutil.copyfile(ckp, filename+'/model_best.pth.tar') 91 | 92 | 93 | def adjust_learning_rate(optimizer, epoch, alpha, decay, lr): 94 | """Sets the learning rate to the initial LR decayed by 10 every 80 epochs""" 95 | lr = lr * (alpha ** (epoch // decay)) 96 | for param_group in optimizer.param_groups: 97 | param_group['lr'] = lr 98 | 99 | 100 | def main(cfg): 101 | best_f1 = float('-inf') 102 | best_f1_epoch = float('inf') 103 | 104 | if cfg.output_dir: 105 | mkdir(cfg.output_dir) 106 | mkdir(cfg.checkpoint_dir) 107 | 108 | logger = setup_logger("graph", cfg.output_dir, comm.get_rank(), filename='train_log.txt') 109 | logger.info("Rank of current process: {}. World size: {}".format(comm.get_rank(), comm.get_world_size())) 110 | logger.info("Environment info:\n" + collect_env_info()) 111 | logger.info("Command line arguments: " + str(args)) 112 | 113 | shutil.copyfile('./params.yaml', cfg.output_dir + '/params.yaml') 114 | shutil.copyfile('./main_graph.py', cfg.output_dir + '/graph.py') 115 | shutil.copyfile('./models/DDM.py', cfg.output_dir + '/DDM.py') 116 | shutil.copyfile('./models/mlp_gat.py', cfg.output_dir + '/mlp_gat.py') 117 | 118 | graphs, (num_features, num_classes) = load_graph_classification_dataset(cfg.DATA.data_name, 119 | deg4feat=cfg.DATA.deg4feat, 120 | PE=False) 121 | cfg.num_features = num_features 122 | 123 | train_idx = torch.arange(len(graphs)) 124 | train_sampler = SubsetRandomSampler(train_idx) 125 | train_loader = GraphDataLoader(graphs, sampler=train_sampler, collate_fn=collate_fn, 126 | batch_size=cfg.DATALOADER.BATCH_SIZE, pin_memory=True) 127 | # eval_loader = GraphDataLoader(graphs, collate_fn=collate_fn, batch_size=len(graphs), shuffle=False) 128 | b0, b1 = [], [] 129 | for g, label in graphs: 130 | if label == 0: 131 | b0.append((g, label)) 132 | else: 133 | b1.append((g, label)) 134 | 135 | eval_loader = [] 136 | eval_loader.append(collate_fn(b0)) 137 | eval_loader.append(collate_fn(b1)) 138 | pooler = create_pooler(cfg.MODEL.pooler) 139 | 140 | acc_list = [] 141 | for i, seed in enumerate(cfg.seeds): 142 | logger.info(f'Run {i}th for seed {seed}') 143 | set_random_seed(seed) 144 | 145 | ml_cfg = cfg.MODEL 146 | ml_cfg.update({'in_dim': num_features}) 147 | model = DDM(**ml_cfg) 148 | total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 149 | logger.info('Total trainable params num : {}'.format(total_trainable_params)) 150 | model.to(cfg.DEVICE) 151 | 152 | optimizer = create_optimizer(cfg.SOLVER.optim_type, model, cfg.SOLVER.LR, cfg.SOLVER.weight_decay) 153 | 154 | start_epoch = 0 155 | if args.resume: 156 | if 
osp.isfile(cfg.pretrain_checkpoint_dir): 157 | logger.info("=> loading checkpoint '{}'".format(cfg.checkpoint_dir)) 158 | checkpoint = torch.load(cfg.checkpoint_dir, map_location=torch.device('cpu')) 159 | start_epoch = checkpoint['epoch'] 160 | model.load_state_dict(checkpoint['state_dict']) 161 | optimizer.load_state_dict(checkpoint['optimizer']) 162 | logger.info("=> loaded checkpoint '{}' (epoch {})" 163 | .format(cfg.checkpoint_dir, checkpoint['epoch'])) 164 | 165 | logger.info("----------Start Training----------") 166 | 167 | for epoch in range(start_epoch, cfg.SOLVER.MAX_EPOCH): 168 | adjust_learning_rate(optimizer, epoch=epoch, alpha=cfg.SOLVER.alpha, decay=cfg.SOLVER.decay, lr=cfg.SOLVER.LR) 169 | pretrain(model, train_loader, optimizer, cfg.DEVICE, epoch, logger) 170 | if ((epoch + 1) % 1 == 0) & (epoch > 1): 171 | model.eval() 172 | test_f1 = graph_classification_evaluation(model, cfg.eval_T, pooler, eval_loader, 173 | cfg.DEVICE, logger) 174 | nni.report_intermediate_result(test_f1) 175 | is_best = test_f1 > best_f1 176 | if is_best: 177 | best_f1_epoch = epoch 178 | best_f1 = max(test_f1, best_f1) 179 | logger.info(f"Epoch {epoch}: get test f1 score: {test_f1: .3f}") 180 | logger.info(f"best_f1 {best_f1:.3f} at epoch {best_f1_epoch}") 181 | save_checkpoint({'epoch': epoch + 1, 182 | 'state_dict': model.state_dict(), 183 | 'best_f1': best_f1, 184 | 'optimizer': optimizer.state_dict()}, 185 | is_best, filename=cfg.checkpoint_dir) 186 | acc_list.append(best_f1) 187 | final_acc, final_acc_std = np.mean(acc_list), np.std(acc_list) 188 | logger.info((f"# final_acc: {final_acc:.4f}±{final_acc_std:.4f}")) 189 | return final_acc 190 | 191 | 192 | if __name__ == "__main__": 193 | root_dir = osp.abspath(osp.dirname(__file__)) 194 | yaml_dir = osp.join(root_dir, 'params.yaml') 195 | output_dir = osp.join(root_dir, 'log') 196 | checkpoint_dir = osp.join(output_dir, "checkpoint") 197 | 198 | with open(yaml_dir, "r") as f: 199 | config = yaml.load(f, yaml.FullLoader) 200 | cfg = edict(config) 201 | 202 | cfg.output_dir, cfg.checkpoint_dir = output_dir, checkpoint_dir 203 | optimized_params = nni.get_next_parameter() 204 | # optimized_params = {} 205 | SOLVER_params = {} 206 | for key, value in cfg.SOLVER.items(): 207 | param_type = type(value) 208 | sp = optimized_params.get(key, value) 209 | if type(sp) == float and param_type == int: 210 | sp = int(sp) 211 | SOLVER_params[key] = sp 212 | cfg.SOLVER.update(SOLVER_params) 213 | MODEL_params = {} 214 | for key, value in cfg.MODEL.items(): 215 | param_type = type(value) 216 | sp = optimized_params.get(key, value) 217 | if type(sp) == float and param_type == int: 218 | sp = int(sp) 219 | MODEL_params[key] = sp 220 | cfg.MODEL.update(MODEL_params) 221 | 222 | print('---new cfg------') 223 | print(cfg) 224 | f1 = main(cfg) 225 | nni.report_final_result(f1) 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 | 235 | 236 | 237 | 238 | 239 | -------------------------------------------------------------------------------- /nni_search/run_search.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # File Name: run_search.py 4 | # Author: Yizhuo Quan 5 | # Created Time: 2023-03-05 09:38 6 | # Last Modified: - 7 | 8 | import json 9 | import signal 10 | 11 | file = open('search_space.json', 'r') 12 | search_space = json.load(file) 13 | 14 | from nni.experiment import Experiment 15 | experiment = Experiment('local') 16 | 17 | experiment.config.trial_command = 'python 
main_graph.py' 18 | experiment.config.trial_code_directory = '.' 19 | experiment.config.experiment_working_directory = '.' 20 | 21 | 22 | experiment.config.search_space = search_space 23 | experiment.config.tuner.name = 'TPE' 24 | 25 | experiment.config.tuner.class_args['optimize_mode'] = 'maximize' 26 | experiment.config.max_trial_number = 300 27 | experiment.config.trial_concurrency = 1 28 | experiment.run(6006) 29 | signal.pause() 30 | # experiment.stop() 31 | -------------------------------------------------------------------------------- /noise_com.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeXAIS/DDM/c8dbfee8830fa61a354cc7c2e0aef1e8fb3d015a/noise_com.pdf -------------------------------------------------------------------------------- /noise_com.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeXAIS/DDM/c8dbfee8830fa61a354cc7c2e0aef1e8fb3d015a/noise_com.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.0.0 2 | anyio==3.4.0 3 | argon2-cffi==21.1.0 4 | attrs==21.2.0 5 | Babel==2.9.1 6 | backcall==0.2.0 7 | bleach==4.1.0 8 | Bottleneck @ file:///opt/conda/conda-bld/bottleneck_1657175564434/work 9 | brotlipy==0.7.0 10 | cachetools==4.2.4 11 | certifi @ file:///croot/certifi_1671487769961/work/certifi 12 | cffi @ file:///tmp/build/80754af9/cffi_1625807838443/work 13 | chardet @ file:///tmp/build/80754af9/chardet_1607706746162/work 14 | conda==23.1.0 15 | conda-package-handling @ file:///tmp/build/80754af9/conda-package-handling_1618262148928/work 16 | cryptography @ file:///tmp/build/80754af9/cryptography_1616769286105/work 17 | cycler @ file:///tmp/build/80754af9/cycler_1637851556182/work 18 | Cython==0.29.35 19 | debugpy==1.5.1 20 | decorator==5.1.0 21 | defusedxml==0.7.1 22 | dgl==1.0.0+cu113 23 | dgl-cu113 @ file:///ddm/reddit-binary/dgl_cu113-100.1.5-py3-none-any.whl 24 | easydict==1.9 25 | entrypoints==0.3 26 | fonttools==4.28.2 27 | google-auth==2.3.3 28 | google-auth-oauthlib==0.4.6 29 | grpcio==1.42.0 30 | idna @ file:///home/linux1/recipes/ci/idna_1610986105248/work 31 | importlib-metadata==4.8.2 32 | importlib-resources==5.4.0 33 | ipykernel==6.5.1 34 | ipython==7.29.0 35 | ipython-genutils==0.2.0 36 | ipywidgets==7.6.5 37 | jedi==0.18.1 38 | Jinja2==3.0.3 39 | joblib==1.1.0 40 | json5==0.9.6 41 | jsonschema==4.2.1 42 | jupyter-client==7.1.0 43 | jupyter-core==4.9.1 44 | jupyter-server==1.12.0 45 | jupyterlab==3.2.4 46 | jupyterlab-language-pack-zh-CN==3.2.post2 47 | jupyterlab-pygments==0.1.2 48 | jupyterlab-server==2.8.2 49 | jupyterlab-widgets==1.0.2 50 | kiwisolver @ file:///croot/kiwisolver_1672387140495/work 51 | littleutils==0.2.2 52 | Markdown==3.3.6 53 | MarkupSafe==2.0.1 54 | matplotlib==3.5.0 55 | matplotlib-inline==0.1.3 56 | mistune==0.8.4 57 | mkl-fft==1.3.1 58 | mkl-random @ file:///tmp/build/80754af9/mkl_random_1626186064646/work 59 | mkl-service==2.4.0 60 | nbclassic==0.3.4 61 | nbclient==0.5.9 62 | nbconvert==6.3.0 63 | nbformat==5.1.3 64 | nest-asyncio==1.5.1 65 | networkx==3.0rc1 66 | notebook==6.4.6 67 | numexpr @ file:///croot/numexpr_1668713893690/work 68 | numpy @ file:///croot/numpy_and_numpy_base_1672336185480/work 69 | oauthlib==3.1.1 70 | ogb==1.3.5 71 | outdated==0.2.2 72 | packaging @ file:///croot/packaging_1671697413597/work 73 | pandas==1.5.2 74 | 
pandocfilters==1.5.0 75 | parso==0.8.2 76 | patsy==0.5.2 77 | pexpect==4.8.0 78 | pickleshare==0.7.5 79 | Pillow==8.4.0 80 | pluggy @ file:///tmp/build/80754af9/pluggy_1648042571233/work 81 | pmdarima==2.0.2 82 | prometheus-client==0.12.0 83 | prompt-toolkit==3.0.22 84 | protobuf==3.19.1 85 | psutil==5.9.4 86 | ptyprocess==0.7.0 87 | pyasn1==0.4.8 88 | pyasn1-modules==0.2.8 89 | pycosat==0.6.3 90 | pycparser @ file:///tmp/build/80754af9/pycparser_1594388511720/work 91 | Pygments==2.10.0 92 | pyOpenSSL @ file:///tmp/build/80754af9/pyopenssl_1608057966937/work 93 | pyparsing @ file:///opt/conda/conda-bld/pyparsing_1661452539315/work 94 | pyrsistent==0.18.0 95 | PySocks @ file:///tmp/build/80754af9/pysocks_1605305779399/work 96 | python-dateutil @ file:///tmp/build/80754af9/python-dateutil_1626374649649/work 97 | pytz @ file:///croot/pytz_1671697431263/work 98 | PyYAML==6.0 99 | pyzmq==22.3.0 100 | requests @ file:///tmp/build/80754af9/requests_1608241421344/work 101 | requests-oauthlib==1.3.0 102 | rsa==4.8 103 | ruamel-yaml-conda @ file:///tmp/build/80754af9/ruamel_yaml_1616016699510/work 104 | ruamel.yaml @ file:///croot/ruamel.yaml_1666304550667/work 105 | ruamel.yaml.clib @ file:///croot/ruamel.yaml.clib_1666302247304/work 106 | scikit-learn==1.1.3 107 | scipy==1.9.3 108 | seaborn @ file:///croot/seaborn_1673479180098/work 109 | Send2Trash==1.8.0 110 | setuptools-scm==6.3.2 111 | six @ file:///tmp/build/80754af9/six_1623709665295/work 112 | sniffio==1.2.0 113 | statsmodels==0.13.5 114 | supervisor==4.2.2 115 | tensorboard==2.7.0 116 | tensorboard-data-server==0.6.1 117 | tensorboard-plugin-wit==1.8.0 118 | tensorboardX==2.5.1 119 | termcolor==2.1.1 120 | terminado==0.12.1 121 | testpath==0.5.0 122 | threadpoolctl==3.1.0 123 | tomli==1.2.2 124 | toolz @ file:///croot/toolz_1667464077321/work 125 | torch @ http://download.pytorch.org/whl/cu113/torch-1.10.0%2Bcu113-cp38-cp38-linux_x86_64.whl 126 | torchvision @ http://download.pytorch.org/whl/cu113/torchvision-0.11.1%2Bcu113-cp38-cp38-linux_x86_64.whl 127 | tornado @ file:///opt/conda/conda-bld/tornado_1662061693373/work 128 | tqdm @ file:///tmp/build/80754af9/tqdm_1625563689033/work 129 | traitlets==5.1.1 130 | typing-extensions==4.0.0 131 | urllib3 @ file:///tmp/build/80754af9/urllib3_1625084269274/work 132 | wcwidth==0.2.5 133 | webencodings==0.5.1 134 | websocket-client==1.2.1 135 | Werkzeug==2.0.2 136 | widgetsnbextension==3.5.2 137 | zipp==3.6.0 138 | --------------------------------------------------------------------------------
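The README above recommends tuning the diffusion hyperparameters with NNI rather than relying on the fixed yaml defaults. The following is a minimal sketch of such a launch, using the same NNI `Experiment` API as `nni_search/run_search.py`; the hyperparameter names mirror keys that appear in `json_config/*.json`, but the ranges and trial budget are illustrative assumptions, not the settings behind the reported results.

```python
# Minimal NNI search sketch (illustrative ranges and budget, not the tuned values).
# Assumes it is run from the nni_search directory, next to main_graph.py and params.yaml.
import signal
from nni.experiment import Experiment

# Example search space; keys mirror json_config/*.json, ranges are assumptions.
search_space = {
    "LR":            {"_type": "loguniform", "_value": [1e-5, 1e-2]},
    "num_hidden":    {"_type": "choice",     "_value": [128, 256, 512, 1024]},
    "nhead":         {"_type": "choice",     "_value": [2, 4]},
    "attn_drop":     {"_type": "quniform",   "_value": [0.1, 0.4, 0.1]},
    "feat_drop":     {"_type": "quniform",   "_value": [0.1, 0.4, 0.1]},
    "beta_1":        {"_type": "loguniform", "_value": [1e-5, 1e-3]},
    "beta_T":        {"_type": "loguniform", "_value": [5e-3, 1e-1]},
    "T":             {"_type": "uniform",    "_value": [200, 1000]},
    "beta_schedule": {"_type": "choice",     "_value": ["const", "quad", "sigmoid"]},
}

experiment = Experiment('local')
experiment.config.trial_command = 'python main_graph.py'   # trial reports F1 via nni.report_*
experiment.config.trial_code_directory = '.'
experiment.config.search_space = search_space
experiment.config.tuner.name = 'TPE'
experiment.config.tuner.class_args['optimize_mode'] = 'maximize'
experiment.config.max_trial_number = 100                    # adjust to your compute budget
experiment.config.trial_concurrency = 1

experiment.run(6006)   # NNI web UI at http://localhost:6006
signal.pause()         # keep the experiment alive until interrupted
```

Because `nni_search/main_graph.py` casts float suggestions back to `int` for integer-typed config keys such as `T`, continuous ranges are acceptable there; each trial reports its final mean F1 through `nni.report_final_result`, which is the quantity the TPE tuner maximizes.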
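Separately, the `alpha` and `decay` entries under `SOLVER` (see `NodeExp/yamls/photo.yaml` and `json_config/*.json`) control the stepwise learning-rate schedule applied by `adjust_learning_rate` in `nni_search/main_graph.py`: the rate used at epoch `e` is `LR * alpha ** (e // decay)`, i.e. it is multiplied by `alpha` once every `decay` epochs, and with `alpha: 1` it stays constant. A small check of that rule with illustrative numbers:

```python
# Staircase schedule implemented by adjust_learning_rate:
#     lr(epoch) = LR * alpha ** (epoch // decay)
# The numbers below are illustrative (alpha=0.8 and decay=40, as in json_config/imdbb.json).
LR, alpha, decay = 9.09e-05, 0.8, 40

for epoch in (0, 39, 40, 80, 120):
    lr = LR * (alpha ** (epoch // decay))
    print(epoch, lr)
# Epochs 0-39 keep the base rate (9.09e-05); it is multiplied by 0.8 at epoch 40
# (7.27e-05), again at epoch 80 (5.82e-05), and so on.
```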