├── .gitignore ├── ImgFilter.py ├── README.md ├── RealWorld.py ├── bestHyperparams.py ├── data ├── 2Dgrid │ └── raw │ │ └── 2Dgrid.mat ├── chameleon.pt ├── film.pt └── squirrel.pt ├── dataset_image.py ├── dataset_utils.py ├── datasets.py └── impl ├── GDataset.py ├── PolyConv.py ├── __init__.py ├── metrics.py ├── models.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | RealWorldWork.py 2 | ImgFilterWork.py 3 | RealWorld 4 | ImgFilter 5 | __pycache__ 6 | nohup.out 7 | impl/__pycache__ 8 | _data 9 | eigenvalues.npy 10 | eigenvectors.npy 11 | data/2Dgrid/processed 12 | data/*.pt 13 | data/ 14 | out 15 | work.py 16 | test.ipynb 17 | tune 18 | *.db 19 | *.npy -------------------------------------------------------------------------------- /ImgFilter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from impl import models, PolyConv, GDataset, utils 4 | import datasets 5 | from torch.optim import Adam 6 | import optuna 7 | import torch.nn as nn 8 | 9 | 10 | def split(): 11 | ''' 12 | Following `"BernNet: Learning Arbitrary Graph Spectral Filters via Bernstein Approximation", 13 | we remove edge pixels. 
def work(conv_layer: int = 10,
         aggr: str = "gcn",
         alpha: float = 1.0,
         lr1: float = 1e-2,
         lr2: float = 1e-2,
         lr3: float = 1e-2,
         wd1: float = 0,
         wd2: float = 0,
         wd3: float = 0,
         **kwargs):
    '''
    Fit a fresh model to every filtered image and return the mean best loss.

    For each repetition (seeded with its index) and each of the 50 target
    columns of `masked_dataset.y`, a model from `buildModel` is trained with
    Adam on the summed squared error for at most 1000 steps. Training of one
    model stops early after 200 consecutive steps without a new best loss.

    Args:
        conv_layer, aggr, alpha: forwarded to `buildModel`.
        lr1, lr2, lr3: learning rates for the `emb`, `conv` and `comb` parts.
        wd1, wd2, wd3: weight decays for the same three parts.
        **kwargs: extra keyword arguments forwarded to `buildModel`.

    Returns:
        The average, over repetitions and images, of the lowest training
        loss reached. The same number is printed before returning.

    Reads the module-level `args` and `masked_dataset`.
    '''
    losses_per_repeat = []
    for seed in range(args.repeat):
        image_losses = []
        losses_per_repeat.append(image_losses)
        utils.set_seed(seed)
        for image_idx in range(50):
            target = masked_dataset.y[:, image_idx].reshape(-1, 1)
            model = buildModel(conv_layer, aggr, alpha, image_idx, **kwargs)
            # One optimizer group per model part so each gets its own lr/wd.
            parts = ((model.emb, wd1, lr1), (model.conv, wd2, lr2),
                     (model.comb, wd3, lr3))
            optimizer = Adam([{
                'params': part.parameters(),
                'weight_decay': decay,
                'lr': rate
            } for part, decay, rate in parts])
            lowest = np.inf
            stale_steps = 0
            model.train()
            for _ in range(1000):
                optimizer.zero_grad()
                pred = model(masked_dataset.edge_index,
                             masked_dataset.edge_attr, masked_dataset.mask)
                step_loss = torch.square(pred - target).sum()
                step_loss.backward()
                optimizer.step()
                step_loss = step_loss.item()
                if step_loss < lowest:
                    lowest = step_loss
                    # the improving step itself counts as 1
                    stale_steps = 1
                else:
                    stale_steps += 1
                if stale_steps > 200:
                    break
            image_losses.append(lowest)
    mean_loss = np.average(losses_per_repeat)
    print(f"end loss {mean_loss:.6e}")
    return mean_loss
", study.best_value) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # How Powerful are Spectral Graph Neural Networks 2 | 3 | This repository is the official implementation of the model in the [following paper](https://arxiv.org/abs/2205.11172v1): 4 | 5 | Xiyuan Wang, Muhan Zhang: How Powerful are Spectral Graph Neural Networks. ICML 2022 6 | 7 | ```{bibtex} 8 | @article{JacobiConv, 9 | author = {Xiyuan Wang and 10 | Muhan Zhang}, 11 | title = {How Powerful are Spectral Graph Neural Networks}, 12 | journal = {ICML}, 13 | year = {2022} 14 | } 15 | ``` 16 | 17 | #### Requirements 18 | Tested combination: Python 3.9.6 + [PyTorch 1.9.0](https://pytorch.org/get-started/previous-versions/) + [PyTorch_Geometric 2.0.3](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html) + [PyTorch Sparse 0.6.12](https://github.com/rusty1s/pytorch_sparse) 19 | 20 | Other required python libraries include: numpy, scikit-learn, optuna, seaborn etc. 21 | 22 | ### Reproduce Our Results 23 | 24 | #### Image Filter Tasks 25 | 26 | To reproduce results of JacobiConv on image datasets: 27 | ``` 28 | python ImgFilter.py --test --repeat 1 --dataset $dataset --fixalpha 29 | ``` 30 | where $dataset is selected from low, high, rejection, band, and comb. 31 | 32 | To reproduce results of linear GNN with other bases: 33 | ``` 34 | python ImgFilter.py --test --$basis --repeat 1 --dataset $dataset --fixalpha 35 | ``` 36 | where $basis is selected from cheby, power, and bern. 37 | 38 | 39 | We use optuna to select hyperparameters. 40 | ``` 41 | python ImgFilter.py --optruns 100 --dataset $dataset --path $dir --name $dataset 42 | ``` 43 | The record file of optuna will be put in directory $dir. 
44 | 45 | #### Real-World Tasks 46 | 47 | To reproduce results of JacobiConv on real-world datasets: 48 | ``` 49 | python RealWorld.py --test --repeat 10 --dataset $dataset --split dense 50 | ``` 51 | where $dataset is selected from pubmed, computers, squirrel, photo, chameleon, film, cora, citeseer, texas, cornell. 52 | 53 | To reproduce results of linear GNN with other bases: 54 | ``` 55 | python RealWorld.py --test --$basis --fixalpha --repeat 10 --dataset $dataset --split dense 56 | ``` 57 | where $basis is selected from cheby, power, and bern. 58 | 59 | To reproduce other ablation studies: 60 | 61 | Unifilter 62 | ``` 63 | python RealWorld.py --test --repeat 10 --dataset $dataset --split dense --sole 64 | ``` 65 | No-PCD 66 | ``` 67 | python RealWorld.py --test --repeat 10 --dataset $dataset --split dense --fixalpha 68 | ``` 69 | NL-RES 70 | ``` 71 | python RealWorld.py --test --repeat 10 --dataset $dataset --split dense --resmultilayer 72 | ``` 73 | NL 74 | ``` 75 | python RealWorld.py --test --repeat 10 --dataset $dataset --split dense --multilayer 76 | ``` 77 | 78 | To select hyperparameters: 79 | ``` 80 | python RealWorld.py --repeat 3 --optruns 400 --split dense --dataset $dataset --path $dir --name $dataset 81 | ``` 82 | The record file of optuna will be put in directory $dir. 
def buildModel(conv_layer: int = 10,
               aggr: str = "gcn",
               alpha: float = 0.2,
               dpb: float = 0.0,
               dpt: float = 0.0,
               **kwargs):
    '''
    Assemble the embedding, polynomial convolution and combination modules.

    Args:
        conv_layer: polynomial depth; the combination gets `conv_layer + 1`
            terms.
        aggr: aggregation name passed to the convolution frame.
        alpha: passed to the convolution frame; replaced by the value from
            `bestHyperparams.fixalpha_alpha` when `args.fixalpha` is set.
        dpb: dropout probability applied to the raw node features.
        dpt: dropout probability applied after the feature transform.
        **kwargs: forwarded to `PolyConv.JacobiConv` (Jacobi basis only).

    Returns:
        A `models.Gmodel` moved to the module-level `device`.

    Reads the module-level `args`, `baseG`, `output_channels` and `device`.
    '''
    in_dim = baseG.x.shape[1]

    # Every variant shares the same head (features + dropout) and tail
    # (dropout); only the transform in between differs.
    layers = [models.TensorMod(baseG.x), nn.Dropout(p=dpb)]
    if args.multilayer:
        layers.append(
            nn.Sequential(nn.Linear(in_dim, output_channels),
                          nn.ReLU(inplace=True),
                          nn.Linear(output_channels, output_channels)))
    elif args.resmultilayer:
        layers.append(nn.Linear(in_dim, output_channels))
        layers.append(
            models.ResBlock(
                nn.Sequential(nn.ReLU(inplace=True),
                              nn.Linear(output_channels, output_channels))))
    else:
        layers.append(nn.Linear(in_dim, output_channels))
    layers.append(nn.Dropout(dpt, inplace=True))
    emb = models.Seq(layers)

    # Basis priority when several flags are given: cheby, legendre, power,
    # and Jacobi when none is set.
    if args.cheby:
        conv_fn = PolyConv.ChebyshevConv
    elif args.legendre:
        conv_fn = PolyConv.LegendreConv
    elif args.power:
        conv_fn = PolyConv.PowerConv
    else:
        from functools import partial
        conv_fn = partial(PolyConv.JacobiConv, **kwargs)

    if args.bern:
        conv = PolyConv.Bern_prop(conv_layer)
    else:
        if args.fixalpha:
            from bestHyperparams import fixalpha_alpha
            # Only "power" and "cheby" have their own entry; every other
            # basis (legendre included) uses the "jacobi" value.
            if args.power:
                basis = "power"
            elif args.cheby:
                basis = "cheby"
            else:
                basis = "jacobi"
            alpha = fixalpha_alpha[args.dataset][basis]
        conv = PolyConv.PolyConvFrame(conv_fn,
                                      depth=conv_layer,
                                      aggr=aggr,
                                      alpha=alpha,
                                      fixed=args.fixalpha)

    comb = models.Combination(output_channels, conv_layer + 1, sole=args.sole)
    return models.Gmodel(emb, conv, comb).to(device)
def test(conv_layer=10,
         aggr="gcn",
         alpha=1.0,
         lr1=1e-2,
         lr2=1e-2,
         lr3=1e-2,
         wd1=0.0,
         wd2=0.0,
         wd3=0.0,
         dpb=0.0,
         dpt=0.0,
         **kwargs):
    '''
    Train with the given hyperparameters and report the test score.

    Each repetition is seeded with its index, re-splits the data, trains a
    fresh model for at most 1000 epochs and records the test score measured
    at the epoch with the best (ties included) validation score. Training
    stops once more than 200 consecutive epochs fail to match that best.
    When `args.savemodel` is set, the weights at each new best epoch are
    written to "<dataset>_<rep>.pt".

    Args:
        conv_layer, aggr, alpha, dpb, dpt, **kwargs: forwarded to
            `buildModel`.
        lr1, lr2, lr3: learning rates for the `emb`, `conv` and `comb` parts.
        wd1, wd2, wd3: weight decays for the same three parts.

    Returns:
        The mean test score over repetitions. The mean and a bootstrap 95%
        interval half-width are also printed.

    Reads the module-level `args`, datasets, `loss_fn` and `score_fn`.
    '''
    tst_scores = []
    val_scores = []  # collected per repetition; not used in the report
    for rep in range(args.repeat):
        print("repeat ", rep)
        utils.set_seed(rep)
        split()
        gnn = buildModel(conv_layer, aggr, alpha, dpb, dpt, **kwargs)
        parts = ((gnn.emb, wd1, lr1), (gnn.conv, wd2, lr2),
                 (gnn.comb, wd3, lr3))
        optimizer = Adam([{
            'params': part.parameters(),
            'weight_decay': decay,
            'lr': rate
        } for part, decay, rate in parts])
        best_val = 0
        tst_at_best = 0
        stale_epochs = 0
        for epoch in range(1000):
            utils.train(optimizer, gnn, trn_dataset, loss_fn)
            score, _ = utils.test(gnn, val_dataset, score_fn, loss_fn=loss_fn)
            if not (score >= best_val):
                # No improvement: count it and give up after too many.
                stale_epochs += 1
                if stale_epochs > 200:
                    break
                continue
            stale_epochs = 0
            best_val = score
            if args.savemodel:
                torch.save(gnn.state_dict(), f"{args.dataset}_{rep}.pt")
            tst_at_best, _ = utils.test(gnn,
                                        tst_dataset,
                                        score_fn,
                                        loss_fn=loss_fn)
        val_scores.append(best_val)
        tst_scores.append(tst_at_best)
    outs = np.array(tst_scores)
    mean_score = np.average(outs)
    boot_means = sns.algorithms.bootstrap(outs, func=np.mean, n_boot=1000)
    error = np.max(np.abs(sns.utils.ci(boot_means, 95) - outs.mean()))
    print(f"avg {mean_score:.4f} error {error:.4f}")
    return np.average(outs)
33 | 'b': 0.95 34 | }, 35 | "low": { 36 | 'a': -0.75, 37 | 'alpha': 1.5, 38 | 'b': 0.75, 39 | 'lr1': 0.05, 40 | 'lr2': 0.001, 41 | 'lr3': 0.05, 42 | 'wd1': 0.0005, 43 | 'wd2': 0.0005, 44 | 'wd3': 0.0001 45 | }, 46 | "rejection": { 47 | 'a': -0.75, 48 | 'alpha': 2.0, 49 | 'b': 0.5, 50 | 'lr1': 0.05, 51 | 'lr2': 0.001, 52 | 'lr3': 0.05, 53 | 'wd1': 0.001, 54 | 'wd2': 0.0001, 55 | 'wd3': 0.001 56 | } 57 | } 58 | 59 | image_filter_alpha = { 60 | "power": { 61 | 'low': 0.5, 62 | 'high': 0.5, 63 | 'band': 1.0, 64 | 'rejection': 1.0, 65 | 'comb': 1.5 66 | }, 67 | "cheby": { 68 | 'low': 1.0, 69 | 'high': 0.5, 70 | 'band': 0.5, 71 | 'rejection': 0.5, 72 | 'comb': 0.5 73 | }, 74 | "jacobi": { 75 | 'low': 1.0, 76 | 'high': 2.0, 77 | 'band': 2.0, 78 | 'rejection': 2.0, 79 | 'comb': 2.0 80 | } 81 | } 82 | 83 | realworld_params = { 84 | 'cora': { 85 | 'a': 2.0, 86 | 'alpha': 0.5, 87 | 'b': -0.25, 88 | 'dpb': 0.5, 89 | 'dpt': 0.7, 90 | 'lr1': 0.05, 91 | 'lr2': 0.01, 92 | 'lr3': 0.01, 93 | 'wd1': 0.001, 94 | 'wd2': 0.0001, 95 | 'wd3': 5e-05 96 | }, 97 | 'citeseer': { 98 | 'a': -0.5, 99 | 'alpha': 0.5, 100 | 'b': -0.5, 101 | 'dpb': 0.9, 102 | 'dpt': 0.8, 103 | 'lr1': 0.05, 104 | 'lr2': 0.001, 105 | 'lr3': 0.01, 106 | 'wd1': 5e-05, 107 | 'wd2': 0.0, 108 | 'wd3': 0.001 109 | }, 110 | 'pubmed': { 111 | 'a': 1.5, 112 | 'alpha': 0.5, 113 | 'b': 0.25, 114 | 'dpb': 0.0, 115 | 'dpt': 0.5, 116 | 'lr1': 0.05, 117 | 'lr2': 0.05, 118 | 'lr3': 0.05, 119 | 'wd1': 0.0005, 120 | 'wd2': 0.0005, 121 | 'wd3': 0.0 122 | }, 123 | 'computers': { 124 | 'a': 1.75, 125 | 'alpha': 1.5, 126 | 'b': -0.5, 127 | 'dpb': 0.8, 128 | 'dpt': 0.2, 129 | 'lr1': 0.05, 130 | 'lr2': 0.05, 131 | 'lr3': 0.05, 132 | 'wd1': 0.0001, 133 | 'wd2': 0.0, 134 | 'wd3': 0.0 135 | }, 136 | 'photo': { 137 | 'a': 1.0, 138 | 'alpha': 1.5, 139 | 'b': 0.25, 140 | 'dpb': 0.3, 141 | 'dpt': 0.3, 142 | 'lr1': 0.05, 143 | 'lr2': 0.0005, 144 | 'lr3': 0.05, 145 | 'wd1': 5e-05, 146 | 'wd2': 0.0, 147 | 'wd3': 0.0 148 | }, 149 | 'chameleon': { 150 
| 'a': 0.0, 151 | 'alpha': 2.0, 152 | 'b': 0.0, 153 | 'dpb': 0.6, 154 | 'dpt': 0.5, 155 | 'lr1': 0.05, 156 | 'lr2': 0.01, 157 | 'lr3': 0.05, 158 | 'wd1': 0.0, 159 | 'wd2': 0.0001, 160 | 'wd3': 0.0005 161 | }, 162 | 'film': { 163 | 'a': -1.0, 164 | 'alpha': 1.0, 165 | 'b': 0.5, 166 | 'dpb': 0.9, 167 | 'dpt': 0.7, 168 | 'lr1': 0.05, 169 | 'lr2': 0.05, 170 | 'lr3': 0.01, 171 | 'wd1': 0.001, 172 | 'wd2': 0.0005, 173 | 'wd3': 0.001 174 | }, 175 | 'squirrel': { 176 | 'a': 0.5, 177 | 'alpha': 2.0, 178 | 'b': 0.25, 179 | 'dpb': 0.4, 180 | 'dpt': 0.1, 181 | 'lr1': 0.01, 182 | 'lr2': 0.01, 183 | 'lr3': 0.05, 184 | 'wd1': 5e-05, 185 | 'wd2': 0.0, 186 | 'wd3': 0.0 187 | }, 188 | "texas": { 189 | 'a': -0.5, 190 | 'alpha': 0.5, 191 | 'b': 0.0, 192 | 'dpb': 0.8, 193 | 'dpt': 0.7, 194 | 'lr1': 0.05, 195 | 'lr2': 0.005, 196 | 'lr3': 0.01, 197 | 'wd1': 0.001, 198 | 'wd2': 0.0005, 199 | 'wd3': 0.0005 200 | }, 201 | "cornell": { 202 | 'a': -0.75, 203 | 'alpha': 0.5, 204 | 'b': 0.25, 205 | 'dpb': 0.4, 206 | 'dpt': 0.7, 207 | 'lr1': 0.05, 208 | 'lr2': 0.005, 209 | 'lr3': 0.001, 210 | 'wd1': 0.0005, 211 | 'wd2': 0.0005, 212 | 'wd3': 0.0001 213 | } 214 | } 215 | 216 | fixalpha_alpha = { 217 | "cora": { 218 | "power": 1.0, 219 | "cheby": 0.5, 220 | "jacobi": 1.0 221 | }, 222 | "citeseer": { 223 | "power": 0.5, 224 | "cheby": 0.5, 225 | "jacobi": 0.5 226 | }, 227 | "pubmed": { 228 | "power": 1.0, 229 | "cheby": 1.0, 230 | "jacobi": 1.0 231 | }, 232 | "computers": { 233 | "power": 2.0, 234 | "cheby": 1.5, 235 | "jacobi": 1.5 236 | }, 237 | "photo": { 238 | "power": 2.0, 239 | "cheby": 1.0, 240 | "jacobi": 1.5 241 | }, 242 | "chameleon": { 243 | "power": 2.0, 244 | "cheby": 2.0, 245 | "jacobi": 2.0 246 | }, 247 | "film": { 248 | "power": 0.5, 249 | "cheby": 1.0, 250 | "jacobi": 0.5 251 | }, 252 | "squirrel": { 253 | "power": 2.0, 254 | "cheby": 2.0, 255 | "jacobi": 2.0 256 | }, 257 | "texas": { 258 | "power": 0.5, 259 | "cheby": 0.5, 260 | "jacobi": 1.0 261 | }, 262 | "cornell": { 263 | 
"power": 0.5, 264 | "cheby": 0.5, 265 | "jacobi": 0.5 266 | }, 267 | } 268 | -------------------------------------------------------------------------------- /data/2Dgrid/raw/2Dgrid.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/JacobiConv/5e9e671fac63e680e0681fa9b4d8074960d2d65e/data/2Dgrid/raw/2Dgrid.mat -------------------------------------------------------------------------------- /data/chameleon.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/JacobiConv/5e9e671fac63e680e0681fa9b4d8074960d2d65e/data/chameleon.pt -------------------------------------------------------------------------------- /data/film.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/JacobiConv/5e9e671fac63e680e0681fa9b4d8074960d2d65e/data/film.pt -------------------------------------------------------------------------------- /data/squirrel.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/JacobiConv/5e9e671fac63e680e0681fa9b4d8074960d2d65e/data/squirrel.pt -------------------------------------------------------------------------------- /dataset_image.py: -------------------------------------------------------------------------------- 1 | # copied from https://github.com/ivam-he/BernNet 2 | # load the image dataset from the `"BernNet: Learning Arbitrary Graph Spectral Filters via Bernstein Approximation" paper 3 | from torch_geometric.data import InMemoryDataset 4 | import torch 5 | from torch_geometric.data.data import Data 6 | import scipy.io as sio 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | from torch_geometric.utils import to_scipy_sparse_matrix 10 | import os 11 | from numpy.linalg import eigh 12 | import math 13 | 14 | filter_type = ['low', 'high', 
def myeign(L):
    """Return ``(eigenvalues, eigenvectors)`` of ``L``, cached under ./data.

    The decomposition is computed with ``numpy.linalg.eigh`` and stored in
    ``./data/eigenvalues.npy`` and ``./data/eigenvectors.npy``. Later calls
    load those files instead of recomputing.

    The cache is reused only when its shapes match ``L`` (``n`` eigenvalues
    and an ``n x n`` eigenvector matrix). Otherwise it is treated as stale,
    recomputed and overwritten. The ``./data`` directory is created when it
    does not exist yet.

    Args:
        L: square matrix accepted by ``numpy.linalg.eigh``.

    Returns:
        Tuple ``(eigenvalues, eigenvectors)``.
    """
    values_path = './data/eigenvalues.npy'
    vectors_path = './data/eigenvectors.npy'
    n = np.shape(L)[0]

    if os.path.exists(values_path) and os.path.exists(vectors_path):
        eigenvalues = np.load(values_path)
        eigenvectors = np.load(vectors_path)
        # Guard against a cache written for a different-sized matrix.
        if eigenvalues.shape == (n,) and eigenvectors.shape == (n, n):
            return eigenvalues, eigenvectors

    eigenvalues, eigenvectors = eigh(L)
    # np.save does not create missing parent directories.
    os.makedirs('./data', exist_ok=True)
    np.save(values_path, eigenvalues)
    np.save(vectors_path, eigenvectors)
    return eigenvalues, eigenvectors
def index_to_mask(index, size):
    """Return a boolean vector of length `size` that is True at `index`.

    The mask is allocated on the same device as `index`.
    """
    flags = torch.zeros(size, dtype=torch.bool, device=index.device)
    flags[index] = True
    return flags


# GPRGNN
def random_planetoid_splits(data,
                            num_classes,
                            percls_trn=20,
                            val_lb=500,
                            Flag=0):
    """Draw a random train/validation/test node split.

    Nodes of every class `0..num_classes-1` (as found in `data.y`) are
    shuffled and the first `percls_trn` of each class form the training set.

    * `Flag == 0`: the remaining nodes of all classes are pooled and
      shuffled; the first `val_lb` become validation, the rest test.
    * otherwise: the next `val_lb` nodes of each class become validation
      and whatever is left (shuffled) becomes test.

    Args:
        data: object exposing `y` (node labels) and `num_nodes`.
        num_classes: number of label values to sample from.
        percls_trn: training nodes taken per class.
        val_lb: validation size (total for `Flag == 0`, per class otherwise).
        Flag: selects the split mode described above.

    Returns:
        `(train_mask, val_mask, test_mask)`, boolean tensors of length
        `data.num_nodes`.
    """
    shuffled_by_class = []
    for label in range(num_classes):
        members = (data.y == label).nonzero().view(-1)
        order = torch.randperm(members.size(0), device=members.device)
        shuffled_by_class.append(members[order])

    def take(start, stop=None):
        # Same slice from every class, concatenated in class order.
        return torch.cat([nodes[start:stop] for nodes in shuffled_by_class],
                         dim=0)

    train_index = take(0, percls_trn)

    if Flag == 0:
        pooled = take(percls_trn)
        pooled = pooled[torch.randperm(pooled.size(0))]
        val_index = pooled[:val_lb]
        test_index = pooled[val_lb:]
    else:
        val_index = take(percls_trn, percls_trn + val_lb)
        test_index = take(percls_trn + val_lb)
        test_index = test_index[torch.randperm(test_index.size(0))]

    total = data.num_nodes
    train_mask = index_to_mask(train_index, size=total)
    val_mask = index_to_mask(val_index, size=total)
    test_mask = index_to_mask(test_index, size=total)
    return train_mask, val_mask, test_mask
class WebKB(InMemoryDataset):
    r"""The WebKB datasets used in the
    "Geom-GCN: Geometric Graph Convolutional Networks" paper.
    Nodes represent web pages and edges represent hyperlinks between them.
    Node features are the bag-of-words representation of web pages.
    The task is to classify the nodes into one of the five categories, student,
    project, course, staff, and faculty.

    Args:
        root (string): Root directory where the dataset should be saved.
        name (string): The name of the dataset (:obj:`"Cornell"`,
            :obj:`"Texas"`, :obj:`"Washington"`, :obj:`"Wisconsin"`);
            matched case-insensitively.
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
    """

    url = (
        'https://raw.githubusercontent.com/graphdml-uiuc-jlu/geom-gcn/master/new_data'
    )

    def __init__(self, root, name, transform=None, pre_transform=None):
        # `name` must be set before the base constructor runs: the directory
        # properties below depend on it.
        self.name = name.lower()
        assert self.name in ['cornell', 'texas', 'washington', 'wisconsin']

        super(WebKB, self).__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_dir(self):
        return osp.join(self.root, self.name, 'raw')

    @property
    def processed_dir(self):
        return osp.join(self.root, self.name, 'processed')

    @property
    def raw_file_names(self):
        return ['out1_node_feature_label.txt', 'out1_graph_edges.txt']

    @property
    def processed_file_names(self):
        return 'data.pt'

    def download(self):
        """Fetch both raw text files for this dataset into `raw_dir`."""
        for file_name in self.raw_file_names:
            download_url(f'{self.url}/{self.name}/{file_name}', self.raw_dir)

    def process(self):
        """Parse the raw text files into one undirected graph and save it."""
        # Node file: tab-separated rows whose second field holds the
        # comma-separated features and whose third field holds the label.
        # The first line and the trailing piece after the last newline are
        # dropped.
        with open(self.raw_paths[0], 'r') as node_file:
            node_rows = node_file.read().split('\n')[1:-1]
        features = [[float(v) for v in row.split('\t')[1].split(',')]
                    for row in node_rows]
        x = torch.tensor(features, dtype=torch.float)
        labels = [int(row.split('\t')[2]) for row in node_rows]
        y = torch.tensor(labels, dtype=torch.long)

        # Edge file: one tab-separated node-id pair per row.
        with open(self.raw_paths[1], 'r') as edge_file:
            edge_rows = edge_file.read().split('\n')[1:-1]
        pairs = [[int(v) for v in row.split('\t')] for row in edge_rows]
        edge_index = torch.tensor(pairs, dtype=torch.long).t().contiguous()
        edge_index = to_undirected(edge_index)
        num_nodes = x.size(0)
        edge_index, _ = coalesce(edge_index, None, num_nodes, num_nodes)

        graph = Data(x=x, edge_index=edge_index, y=y)
        if self.pre_transform is not None:
            graph = self.pre_transform(graph)
        torch.save(self.collate([graph]), self.processed_paths[0])

    def __repr__(self):
        return f'{self.name}()'
# --------------------------------------------------------------------------------
# /datasets.py:
# --------------------------------------------------------------------------------
import torch
from torch_geometric.utils import is_undirected, to_undirected
import dataset_utils as du
import os
import dataset_image
from torch import Tensor, LongTensor


class BaseGraph:
    '''
    A general container for one graph dataset.
    Args:
        x (Tensor): node features, of shape (number of nodes, F).
        edge_index (LongTensor): of shape (2, number of edges).
        edge_weight (Tensor): of shape (number of edges); stored as `edge_attr`.
        y (Tensor): node targets; `num_classes` is the number of distinct values in it.
        mask (LongTensor): node split mask of shape (number of nodes). mask[i] = 0, 1, 2
            means the i-th node is in the train, valid, test set respectively.
    '''
    # Integer code used inside `mask` for each split name.
    _SPLIT_CODE = {"train": 0, "valid": 1, "test": 2}

    def __init__(self, x: Tensor, edge_index: LongTensor, edge_weight: Tensor,
                 y: Tensor, mask: LongTensor):
        self.x = x
        self.edge_index = edge_index
        self.edge_attr = edge_weight
        self.y = y
        self.num_classes = torch.unique(y).shape[0]
        self.num_nodes = x.shape[0]
        self.mask = mask
        # Symmetrise the edge list once, at construction time.
        self.to_undirected()

    def get_split(self, split: str):
        '''
        Return (edge_index, edge_attr, boolean node mask, targets of the masked nodes)
        for `split`, one of "train", "valid", "test". Other names raise KeyError.
        '''
        selected = self.mask == self._SPLIT_CODE[split]
        return self.edge_index, self.edge_attr, selected, self.y[selected]

    def to_undirected(self):
        '''
        Replace the stored edges by their undirected version unless they already are undirected.
        '''
        if is_undirected(self.edge_index):
            return
        self.edge_index, self.edge_attr = to_undirected(
            self.edge_index, self.edge_attr)

    def to(self, device):
        '''
        Move every stored tensor to `device` and return self.
        '''
        for field in ("x", "edge_index", "edge_attr", "y", "mask"):
            setattr(self, field, getattr(self, field).to(device))
        return self
def split(data: "BaseGraph", split: str = "dense"):
    '''
    Split data into train/valid/test sets.
    Args:
        data (BaseGraph): the dataset to split.
        split (str): the split mode, choice: ["sparse", "dense"]
    Returns:
        an int8 tensor with one entry per node: 0 = train, 1 = valid, 2 = test.
        A node selected by none of the three masks keeps -1, so it can never be
        mistaken for a member of a split.
    Raises:
        NotImplementedError: if `split` is neither "dense" nor "sparse".
    '''
    # (train fraction, validation fraction) per mode.
    ratios = {"dense": (0.6, 0.2), "sparse": (0.025, 0.025)}
    if split not in ratios:
        raise NotImplementedError("split is dense or sparse")
    train_ratio, val_ratio = ratios[split]

    n_labels = len(data.y)
    # The training budget is expressed per class, the validation budget in total.
    percls_trn = int(round(train_ratio * n_labels / data.num_classes))
    val_lb = int(round(val_ratio * n_labels))
    train_mask, val_mask, test_mask = du.random_planetoid_splits(
        data, data.num_classes, percls_trn, val_lb)

    # Start from -1 rather than uninitialised memory: torch.empty would leave
    # arbitrary values (possibly 0, 1 or 2) for nodes not covered below.
    mask = torch.full((data.x.shape[0], ),
                      -1,
                      dtype=torch.int8,
                      device=data.x.device)
    mask[train_mask] = 0
    mask[val_mask] = 1
    mask[test_mask] = 2
    return mask
def load_dataset(name: str, split_t="dense"):
    '''
    Load the dataset `name` as a BaseGraph, using ./data/{name}.pt as a cache.

    Node-classification datasets get a fresh random split (mode `split_t`)
    on every call, including when they are read from the cache. The image
    filter datasets carry their own mask and are returned as stored.
    Unknown names raise NotImplementedError.
    '''
    node_tasks = ('cora', 'citeseer', 'pubmed', 'computers', 'photo', 'texas',
                  'cornell', 'chameleon', 'film', 'squirrel')
    filter_tasks = ('low', 'high', 'band', 'rejection', 'comb', 'low_band')
    if name not in node_tasks and name not in filter_tasks:
        raise NotImplementedError()

    savepath = f"./data/{name}.pt"
    cached = os.path.exists(savepath)

    if name in filter_tasks:
        if cached:
            return torch.load(savepath, map_location="cpu")
        x, y, ei, ea, mask = dataset_image.load_img(name)
        graph = BaseGraph(x, ei, ea, y, mask.flatten())
        torch.save(graph, savepath)
        return graph

    if cached:
        graph = torch.load(savepath, map_location="cpu")
        # The cached mask is discarded; the split is always re-drawn.
        graph.mask = split(graph, split=split_t)
        return graph

    ds = du.DataLoader(name)
    data = ds[0]
    data.num_classes = ds.num_classes
    # Every edge gets unit weight.
    unit_weight = torch.ones(data.edge_index.shape[1])
    mask = split(data, split=split_t)
    graph = BaseGraph(data.x, data.edge_index, unit_weight, data.y, mask)
    # Keep the class count reported by the dataset, not the one recounted from y.
    graph.num_classes = data.num_classes
    graph.y = graph.y.to(torch.int64)
    torch.save(graph, savepath)
    return graph
# --------------------------------------------------------------------------------
# /impl/GDataset.py:
# --------------------------------------------------------------------------------
import torch


class GDataset:
    '''
    A class holding one split of a dataset.
    Args:
        edge_index : edge list of the graph
        edge_attr : edge weights
        mask : the mask to show whether a node is in the dataset, of shape (number of nodes)
        y : the target
    Node features are not stored here; `x` is always initialised to None.
    '''
    def __init__(self, edge_index, edge_attr, mask, y):
        self.x = None
        self.edge_index = edge_index
        self.edge_attr = edge_attr
        self.y = y
        self.mask = mask

    def __len__(self):
        # Return a plain int, as the __len__ protocol expects, not a 0-dim tensor.
        return int(torch.sum(self.mask))

    def __getitem__(self, idx):
        return self.mask[idx], self.y[idx]

    def to(self, device):
        '''
        Move the stored tensors to `device` and return self.
        '''
        self.edge_index = self.edge_index.to(device)
        self.edge_attr = self.edge_attr.to(device)
        self.mask = self.mask.to(device)
        self.y = self.y.to(device)
        return self


# --------------------------------------------------------------------------------
# /impl/PolyConv.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torch import Tensor
from scipy.special import comb
from torch_geometric.utils import add_self_loops
from torch_geometric.utils import get_laplacian, degree
from torch_sparse import SparseTensor
from torch_geometric.nn import MessagePassing
import torch.nn.functional as F


def buildAdj(edge_index: Tensor, edge_weight: Tensor, n_node: int, aggr: str):
    '''
    Convert edge_index and edge_weight to a normalized sparse adjacency matrix.
    Args:
        edge_index (Tensor): shape (2, number of edges).
        edge_weight (Tensor): shape (number of edges).
        n_node (int): number of nodes in the graph.
        aggr (str): how adjacency matrix is normalized. choice: ["mean", "sum", "gcn"]
    Raises:
        NotImplementedError: for any other `aggr`.
    '''
    row, col = edge_index[0], edge_index[1]
    # Degrees are computed from the first row of edge_index only; edge
    # weights do not enter the degree.
    deg = degree(row, n_node)
    # Zero degrees are bumped to 1 so the reciprocal / inverse square root
    # below stays finite.
    deg[deg < 0.5] += 1.0
    if aggr == "mean":
        val = (1.0 / deg)[row] * edge_weight
    elif aggr == "sum":
        val = edge_weight
    elif aggr == "gcn":
        deg = torch.pow(deg, -0.5)
        val = deg[row] * edge_weight * deg[col]
    else:
        raise NotImplementedError
    adj = SparseTensor(row=row,
                       col=col,
                       value=val,
                       sparse_sizes=(n_node, n_node)).coalesce()
    return adj.cuda() if edge_index.is_cuda else adj
class PolyConvFrame(nn.Module):
    '''
    A framework for polynomial graph signal filter.
    Args:
        conv_fn: the filter function, like PowerConv, LegendreConv,...
            It is called as conv_fn(L, xs, adj, alphas).
        depth (int): the order of polynomial.
        aggr (str): normalization of the adjacency matrix, passed to buildAdj.
        cached (bool): whether or not to cache the adjacency matrix.
        alpha (float): the parameter to initialize polynomial coefficients.
            Each raw coefficient starts at min(1 / alpha, 1) and is used as
            alpha * tanh(raw coefficient).
        fixed (bool): whether or not to fix the polynomial coefficients.
    '''
    def __init__(self,
                 conv_fn,
                 depth: int = 3,
                 aggr: str = "gcn",
                 cached: bool = True,
                 alpha: float = 1.0,
                 fixed: bool = False):
        super().__init__()
        self.depth = depth
        self.basealpha = alpha
        init_value = float(min(1 / alpha, 1))
        # One raw coefficient per basis order 0..depth.
        self.alphas = nn.ParameterList([
            nn.Parameter(torch.tensor(init_value), requires_grad=not fixed)
            for _ in range(depth + 1)
        ])
        self.cached = cached
        self.aggr = aggr
        self.adj = None
        self.conv_fn = conv_fn

    def forward(self, x: Tensor, edge_index: Tensor, edge_attr: Tensor):
        '''
        Args:
            x: node embeddings. of shape (number of nodes, node feature dimension)
            edge_index and edge_attr: If the adjacency is cached, they will be ignored.
        Returns:
            the outputs of the bases of order 0..depth stacked along dim 1.
        '''
        if self.adj is None or not self.cached:
            self.adj = buildAdj(edge_index, edge_attr, x.shape[0], self.aggr)
        # tanh keeps every effective coefficient inside (-basealpha, basealpha).
        alphas = [self.basealpha * torch.tanh(raw) for raw in self.alphas]
        xs = [self.conv_fn(0, [x], self.adj, alphas)]
        for L in range(1, self.depth + 1):
            xs.append(self.conv_fn(L, xs, self.adj, alphas))
        return torch.stack(xs, dim=1)
# conv_fns to build the polynomial filter. They all share one signature.
# Args:
#     L (int): the order of polynomial basis.
#     xs (List[Tensor]): the node embeddings filtered by the previous bases.
#     adj (SparseTensor): adjacency matrix
#     alphas (List[Float]): List of polynomial coefficients.


def PowerConv(L, xs, adj, alphas):
    '''
    Monomial bases: order L is alphas[L] times the adjacency applied to order L - 1.
    '''
    if L == 0:
        return xs[0]
    propagated = adj @ xs[-1]
    return alphas[L] * propagated


def LegendreConv(L, xs, adj, alphas):
    '''
    Legendre bases, built by a three-term recurrence on the two previous bases.
    Please refer to the paper for the form of the bases.
    '''
    if L == 0:
        return xs[0]
    lead = alphas[L - 1] * (2 - 1 / L)
    out = lead * (adj @ xs[-1])
    if L > 1:
        tail = alphas[L - 1] * alphas[L - 2] * (1 - 1 / L)
        out = out - tail * xs[-2]
    return out


def ChebyshevConv(L, xs, adj, alphas):
    '''
    Chebyshev bases, built by a three-term recurrence on the two previous bases.
    Please refer to the paper for the form of the bases.
    '''
    if L == 0:
        return xs[0]
    out = (2 * alphas[L - 1]) * (adj @ xs[-1])
    if L > 1:
        out = out - (alphas[L - 1] * alphas[L - 2]) * xs[-2]
    return out


def JacobiConv(L, xs, adj, alphas, a=1.0, b=1.0, l=-1.0, r=1.0):
    '''
    Jacobi bases with parameters a, b on the interval [l, r].
    Please refer to the paper for the form of the bases.
    '''
    if L == 0:
        return xs[0]
    prev = xs[-1]
    if L == 1:
        shift = (a - b) / 2 - (a + b + 2) / 2 * (l + r) / (r - l)
        slope = (a + b + 2) / (r - l)
        return (shift * alphas[0]) * prev + (slope * alphas[0]) * (adj @ prev)

    ab = a + b
    # Coefficients of the three-term recurrence; denom normalises all of them.
    denom = 2 * L * (L + ab) * (2 * L - 2 + ab)
    num_adj = (2 * L + ab - 1) * (2 * L + ab) * (2 * L + ab - 2)
    num_self = (2 * L + ab - 1) * (a**2 - b**2)
    num_prev2 = 2 * (L - 1 + a) * (L - 1 + b) * (2 * L + ab)

    w_adj = alphas[L - 1] * (num_adj / denom)
    w_self = alphas[L - 1] * (num_self / denom)
    w_prev2 = alphas[L - 1] * alphas[L - 2] * (num_prev2 / denom)

    # Map the recurrence from [-1, 1] onto [l, r].
    w_adj_scaled = w_adj * (2 / (r - l))
    w_self_scaled = w_adj * ((r + l) / (r - l)) + w_self
    return w_adj_scaled * (adj @ prev) - w_self_scaled * prev - w_prev2 * xs[-2]
class Bern_prop(MessagePassing):
    # Bernstein polynomial filter from the `"BernNet: Learning Arbitrary Graph Spectral Filters via Bernstein Approximation" paper.
    # Copied from the official implementation.
    #
    # forward() returns the K + 1 Bernstein basis terms stacked along dim 1.
    # `bias` is accepted for signature compatibility but is not used.
    def __init__(self, K, bias=True, **kwargs):
        super(Bern_prop, self).__init__(aggr='add', **kwargs)
        self.K = K

    def forward(self, x, edge_index, edge_weight=None):
        #L=I-D^(-0.5)AD^(-0.5)
        edge_index1, norm1 = get_laplacian(edge_index,
                                           edge_weight,
                                           normalization='sym',
                                           dtype=x.dtype,
                                           num_nodes=x.size(0))
        #2I-L
        edge_index2, norm2 = add_self_loops(edge_index1,
                                            -norm1,
                                            fill_value=2.,
                                            num_nodes=x.size(0))

        # tmp[k] holds the input propagated k times through 2I-L.
        tmp = [x]
        for _ in range(self.K):
            x = self.propagate(edge_index2, x=x, norm=norm2, size=None)
            tmp.append(x)

        out = [(comb(self.K, 0) / (2**self.K)) * tmp[self.K]]

        for i in range(self.K):
            # Term i + 1: tmp[K - i - 1] propagated i + 1 times through L.
            x = tmp[self.K - i - 1]
            for _ in range(i + 1):
                x = self.propagate(edge_index1, x=x, norm=norm1, size=None)

            out.append((comb(self.K, i + 1) / (2**self.K)) * x)
        return torch.stack(out, dim=1)

    def message(self, x_j, norm):
        return norm.view(-1, 1) * x_j

    def __repr__(self):
        # This copy has no `temp` attribute, so only K is reported; reading
        # self.temp here raised AttributeError.
        return '{}(K={})'.format(self.__class__.__name__, self.K)
# --------------------------------------------------------------------------------
# /impl/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/GraphPKU/JacobiConv/5e9e671fac63e680e0681fa9b4d8074960d2d65e/impl/__init__.py
# --------------------------------------------------------------------------------
# /impl/metrics.py:
# --------------------------------------------------------------------------------
'''
metric functions
Args:
    pred, label: numpy array of prediction, label
'''
import numpy as np


def r2_score(pred, label):
    '''
    Coefficient of determination of `pred` with respect to `label`.
    '''
    # Imported here because this is the only function that needs scikit-learn.
    import sklearn.metrics
    return sklearn.metrics.r2_score(label, pred)


def multiclass_accuracy(pred, label):
    '''
    Fraction of rows of `pred` whose arg-max along axis 1 equals the label.
    `label` may be flat or a column vector; it is flattened before comparing,
    because an (N, 1) label would otherwise broadcast against the (N,)
    predictions into an (N, N) comparison.
    '''
    pred_i = np.argmax(pred, axis=1)
    label = np.asarray(label).reshape(-1)
    return np.sum(pred_i == label) / label.shape[0]
# --------------------------------------------------------------------------------
# /impl/models.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torch import Tensor
from typing import Iterable
import torch.nn.functional as F


class Seq(nn.Module):
    '''
    An extension of nn.Sequential whose first module may take any arguments.
    Args:
        modlist: an iterable of modules to chain.
    '''
    def __init__(self, modlist: Iterable[nn.Module]):
        super().__init__()
        self.modlist = nn.ModuleList(modlist)

    def forward(self, *args, **kwargs):
        # Only the first module sees the call arguments; every later module
        # receives the previous module's output.
        out = self.modlist[0](*args, **kwargs)
        for mod in list(self.modlist)[1:]:
            out = mod(out)
        return out


class TensorMod(nn.Module):
    '''
    A module that returns a stored tensor, whatever it is called with.
    Args:
        x: the Tensor to return; kept as a parameter that does not require grad.
    '''
    def __init__(self, x: Tensor):
        super().__init__()
        self.x = nn.Parameter(x, requires_grad=False)

    def forward(self, *args, **kwargs):
        return self.x


class ResBlock(nn.Module):
    '''
    Wrap a module with a residual connection: output = input + mod(input).
    '''
    def __init__(self, mod: nn.Module):
        super().__init__()
        self.mod = mod

    def forward(self, x):
        residual = self.mod(x)
        return x + residual


class Combination(nn.Module):
    '''
    A module combining the bases of polynomial filters with learnable weights.
    Args:
        channels (int): number of feature channels.
        depth (int): number of bases to combine.
        sole (bool): whether or not to use the same filter for all output channels.
    '''
    def __init__(self, channels: int, depth: int, sole=False):
        super().__init__()
        # A shared filter needs a single weight per basis, broadcast over channels.
        weight_channels = 1 if sole else channels
        self.comb_weight = nn.Parameter(
            torch.ones((1, depth, weight_channels)))

    def forward(self, x: Tensor):
        '''
        x: node features filtered by bases, of shape (number of nodes, depth, channels).
        Returns the weighted sum over the depth dimension.
        '''
        return torch.sum(x * self.comb_weight, dim=1)
class Gmodel(nn.Module):
    '''
    A framework for GNN models, chaining three stages.
    Args:
        emb (nn.Module): produce node features; called without arguments.
        conv (nn.Module): do message passing.
        comb (nn.Module): combine bases to produce the filter function.
    '''
    def __init__(self, emb: nn.Module, conv: nn.Module, comb: nn.Module):
        super().__init__()
        self.emb = emb
        self.conv = conv
        self.comb = comb

    def forward(self, edge_index: Tensor, edge_weight: Tensor, pos: Tensor):
        '''
        pos: mask of the nodes whose embeddings are needed; it is flattened before indexing.
        '''
        features = self.emb()
        bases = self.conv(features, edge_index, edge_weight)
        node_emb = self.comb(bases)
        return node_emb[pos.flatten()]
# --------------------------------------------------------------------------------
# /impl/utils.py:
# --------------------------------------------------------------------------------
import numpy as np
import random
import torch
import argparse


def set_seed(seed: int):
    '''
    Seed the Python, NumPy and torch (CPU and CUDA) random number generators.
    '''
    random.seed(seed)
    np.random.seed(seed)
    for seeder in (torch.manual_seed, torch.cuda.manual_seed,
                   torch.cuda.manual_seed_all):
        seeder(seed)


def parse_args():
    '''
    Parse the command line, print the resulting namespace and return it.
    '''
    parser = argparse.ArgumentParser(description='')
    # Data settings
    parser.add_argument('--dataset', type=str)
    parser.add_argument('--split', type=str, default="dense")

    # Train settings
    parser.add_argument('--repeat', type=int, default=1)
    parser.add_argument('--test', action='store_true')

    # Optuna Settings
    parser.add_argument('--optruns', type=int, default=50)
    parser.add_argument('--path', type=str, default="")
    parser.add_argument('--name', type=str, default="opt")

    # Model Settings: boolean switches, all off by default.
    model_flags = ('detach', 'savemodel', 'power', 'cheby', 'legendre',
                   'bern', 'sole', 'fixalpha', 'multilayer', 'resmultilayer')
    for flag in model_flags:
        parser.add_argument(f'--{flag}', action="store_true")

    args = parser.parse_args()
    print("args = ", args)
    return args


def train(optimizer, model, ds, loss_fn):
    '''
    Run one optimisation step on the split `ds` and return the loss as a float.
    '''
    optimizer.zero_grad()
    model.train()
    loss = loss_fn(model(ds.edge_index, ds.edge_attr, ds.mask), ds.y)
    loss.backward()
    optimizer.step()
    return loss.item()
@torch.no_grad()
def test(model, ds, metrics, loss_fn=None):
    '''
    Evaluate `model` on the split `ds` without tracking gradients.
    Args:
        model: called as model(ds.edge_index, ds.edge_attr, ds.mask).
        ds: the split to evaluate; its `y` is the target.
        metrics: callable taking (prediction, label) as numpy arrays.
        loss_fn: optional callable taking (prediction, target) tensors.
    Returns:
        (metric value, loss). The loss is None when no `loss_fn` is given.
    '''
    model.eval()
    pred = model(ds.edge_index, ds.edge_attr, ds.mask)
    y = ds.y
    # loss_fn is optional: calling the default None raised TypeError.
    loss = loss_fn(pred, y) if loss_fn is not None else None
    return metrics(pred.cpu().numpy(), y.cpu().numpy()), loss


# Evaluation helper, not a test case: keep pytest from collecting it when this
# module is star-imported into a test module.
test.__test__ = False
# --------------------------------------------------------------------------------