├── .gitignore ├── LICENSE ├── README.md ├── asserts └── overview.png ├── configs ├── citeseer.yaml ├── cora.yaml ├── mag-scholar-f.yaml ├── ogbn-arxiv.yaml ├── ogbn-papers100M.yaml ├── ogbn-products.yaml └── pubmed.yaml ├── datasets ├── __init__.py ├── data_proc.py ├── lc_sampler.py ├── localclustering.py └── saint_sampler.py ├── main_full_batch.py ├── main_large.py ├── models ├── __init__.py ├── edcoder.py ├── finetune.py ├── gat.py ├── gcn.py └── loss_func.py ├── requirements.txt ├── run_fullbatch.sh ├── run_minibatch.sh └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 THUDM 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

# GraphMAE2: A Decoding-Enhanced Masked Self-Supervised Graph Learner

3 | 4 | Implementation for the WWW'23 paper: [GraphMAE2: A Decoding-Enhanced Masked Self-Supervised 5 | Graph Learner](https://arxiv.org/abs/2304.04779). 6 | 7 | 8 | [GraphMAE] The predecessor of this work, [GraphMAE: Self-Supervised Masked Graph Autoencoders](https://arxiv.org/abs/2205.10803), can be found [here](https://github.com/THUDM/GraphMAE). 9 | 10 |

## ❗ Update

11 | 12 | [2023-04-19] We have made **checkpoints** of pre-trained models on different datasets available - feel free to download them from [Google Drive](https://drive.google.com/drive/folders/1GiuP0PtIZaYlJWIrjvu73ZQCJGr6kGkh). 13 | 14 |
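A downloaded checkpoint is a plain PyTorch `state_dict` and can be restored into a freshly built model. The snippet below is a minimal sketch rather than the exact evaluation pipeline: the `checkpoint.pt` path and the `num_features` value are placeholders, and the main scripts (`main_large.py`, `main_full_batch.py`) perform the same restore through their load-model branch.

```python
# Minimal sketch: restoring a downloaded checkpoint into a freshly built model.
# Assumptions: the file is a plain state_dict (as saved by the training scripts);
# "checkpoint.pt" and num_features=128 are placeholders for your own values.
import torch

from utils import build_args, load_best_configs
from models import build_model

args = build_args()
args = load_best_configs(args)   # the main scripts only do this when --use_cfg is passed
args.num_features = 128          # set to the feature dimension of your dataset

model = build_model(args)
model.load_state_dict(torch.load("checkpoint.pt", map_location="cpu"))
model.eval()
```

After loading, evaluation proceeds as in `main_large.py`, whose load-model branch restores the weights from `args.checkpoint_path` before linear probing or finetuning.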

## Dependencies

15 | 16 | * Python >= 3.7 17 | * [Pytorch](https://pytorch.org/) >= 1.9.0 18 | * pyyaml == 5.4.1 19 | 20 | 21 |
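The pinned `pyyaml` is what reads the per-dataset hyperparameter files under `configs/` when `--use_cfg` is passed. The sketch below is illustrative only: the helper name `apply_yaml_config` is hypothetical, and the actual logic lives in `utils.load_best_configs`, which is not shown in this listing.

```python
# Illustrative only -- not the repo's load_best_configs implementation.
# Reads a per-dataset config such as configs/cora.yaml and copies its keys
# onto the argparse namespace, so YAML values override the CLI defaults.
import argparse
import yaml

def apply_yaml_config(args: argparse.Namespace, path: str) -> argparse.Namespace:
    with open(path, "r") as f:
        cfg = yaml.safe_load(f)      # e.g. {"lr": 0.001, "num_hidden": 1024, ...}
    for key, value in cfg.items():
        setattr(args, key, value)
    return args

# usage sketch: args = apply_yaml_config(args, "configs/cora.yaml")
```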

## Quick Start

22 | 23 | For a quick start, you can run the provided scripts: 24 | 25 | **Node classification** 26 | 27 | ```bash 28 | sh run_minibatch.sh <dataset_name> <gpu_id> # for mini-batch node classification 29 | # example: sh run_minibatch.sh ogbn-arxiv 0 30 | sh run_fullbatch.sh <dataset_name> <gpu_id> # for full-batch node classification 31 | # example: sh run_fullbatch.sh cora 0 32 | 33 | # Or you can run the code manually: 34 | # for mini-batch node classification 35 | python main_large.py --dataset ogbn-arxiv --encoder gat --decoder gat --seed 0 --device 0 36 | # for full-batch node classification 37 | python main_full_batch.py --dataset cora --encoder gat --decoder gat --seed 0 --device 0 38 | ``` 39 | 40 | Supported datasets: 41 | 42 | * mini-batch node classification: `ogbn-arxiv`, `ogbn-products`, `mag-scholar-f`, `ogbn-papers100M` 43 | * full-batch node classification: `cora`, `citeseer`, `pubmed` 44 | 45 | Run the provided scripts or add `--use_cfg` to the command to reproduce the reported results. 46 | 47 | **For large-scale graphs** 48 | Before starting mini-batch training, you need to generate local clusters if you want to use local clustering for training. By default, the program loads datasets from `./data` and saves the generated local clusters to `./lc_ego_graphs`. To generate local clusters, first install [localclustering](https://github.com/kfoynt/LocalGraphClustering) and then run the following command: 49 | 50 | ``` 51 | python ./datasets/localclustering.py --dataset <dataset_name> --data_dir <data_dir> 52 | ``` 53 | We also provide pre-generated local clusters, which can be downloaded [here](https://cloud.tsinghua.edu.cn/d/64f859f389ca43eda472/) and placed into `lc_ego_graphs` for use. 54 | 55 | 56 | 57 |
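For reference, the file written by `datasets/localclustering.py` is a torch-serialized list `[train, valid, test]` of ego-graph node arrays, one per labeled node, with the seed node stored first. Below is a small sanity-check sketch; the filename assumes the script's defaults (`--save_dir lc_ego_graphs`, `--ego_size 256`) and the `ogbn-arxiv` dataset, so adjust it to your run.

```python
# Sketch: inspecting a (pre-)generated local-cluster file.
# Assumes the "{dataset}-lc-ego-graphs-{ego_size}.pt" naming used by
# datasets/localclustering.py with its default --ego_size 256.
import torch

train_egos, valid_egos, test_egos = torch.load("lc_ego_graphs/ogbn-arxiv-lc-ego-graphs-256.pt")
print(len(train_egos), len(valid_egos), len(test_egos))  # one ego-graph per train/valid/test node
print(train_egos[0][:10])                                # node ids of the first ego-graph; the seed node comes first
```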

## Datasets

58 | 59 | During execution, the code downloads the OGB and small-scale datasets (Cora, Citeseer, and PubMed) automatically. For the MAG-SCHOLAR dataset, you can download the raw data from [here](https://figshare.com/articles/dataset/mag_scholar/12696653) or use our processed version, available [here](https://cloud.tsinghua.edu.cn/d/776e73d84d47454c958d/) (the four feature files have to be merged into a single `feature_f.npy`; see the sketch below). Once you have the dataset, place it into the `./data/mag_scholar_f` folder. The folder should contain the following files: 60 | ``` 61 | - mag_scholar_f 62 | |--- edge_index_f.npy 63 | |--- split_idx_f.pt 64 | |--- feature_f.npy 65 | |--- label_f.npy 66 | ``` 67 | 68 | We will soon provide [SAINTSampler](https://arxiv.org/abs/1907.04931) as a baseline. 69 | 70 | 71 |
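A hypothetical sketch of that merge step: the shard filenames below are placeholders for the four downloaded feature files, and the node-axis concatenation is an assumption, so substitute the actual names and ordering from the download.

```python
# Hypothetical sketch: merging the four downloaded feature shards into feature_f.npy.
# The shard names are placeholders; the node-axis concatenation is an assumption.
import numpy as np

shards = ["feature_f_part0.npy", "feature_f_part1.npy", "feature_f_part2.npy", "feature_f_part3.npy"]
feature = np.concatenate([np.load(p) for p in shards], axis=0)
np.save("data/mag_scholar_f/feature_f.npy", feature)
print(feature.shape)  # expected: (num_nodes, feature_dim)
```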

## Experimental Results

72 | 73 | Experimental results of node classification on large-scale datasets (Accuracy, %): 74 | 75 | | | Ogbn-arxiv | Ogbn-products | Mag-Scholar-F | Ogbn-papers100M | 76 | | ------------------ | ------------ | ------------ | ------------ | -------------- | 77 | | MLP | 55.50±0.23 | 61.06±0.08 | 39.11±0.21 | 47.24±0.31 | 78 | | SGC | 66.92±0.08 | 74.87±0.25 | 54.68±0.23 | 63.29±0.19 | 79 | | Random-Init | 68.14±0.02 | 74.04±0.06 | 56.57±0.03 | 61.55±0.12 | 80 | | CCA-SSG | 68.57±0.02 | 75.27±0.05 | 51.55±0.03 | 55.67±0.15 | 81 | | GRACE | 69.34±0.01 | 79.47±0.59 | 57.39±0.02 | 61.21±0.12 | 82 | | BGRL | 70.51±0.03 | 78.59±0.02 | 57.57±0.01 | 62.18±0.15 | 83 | | GGD | - | 75.70±0.40 | - | 63.50±0.50 | 84 | | GraphMAE | 71.03±0.02 | 78.89±0.01 | 58.75±0.03 | 62.54±0.09 | 85 | | **GraphMAE2** | **71.89±0.03** | **81.59±0.02** | **59.24±0.01** | **64.89±0.04** | 86 | 87 | 88 | 89 |

## Citing

90 | 91 | If you find this work is helpful to your research, please consider citing our paper: 92 | 93 | ``` 94 | @inproceedings{hou2023graphmae2, 95 | title={GraphMAE2: A Decoding-Enhanced Masked Self-Supervised Graph Learner}, 96 | author={Zhenyu Hou, Yufei He, Yukuo Cen, Xiao Liu, Yuxiao Dong, Evgeny Kharlamov, Jie Tang}, 97 | booktitle={Proceedings of the ACM Web Conference 2023 (WWW’23)}, 98 | year={2023} 99 | } 100 | ``` 101 | -------------------------------------------------------------------------------- /asserts/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/GraphMAE2/bc473e8b750d6e90b7cba1c6bb20d858f617d90c/asserts/overview.png -------------------------------------------------------------------------------- /configs/citeseer.yaml: -------------------------------------------------------------------------------- 1 | lr: 0.0005 # 0.0005 2 | lr_f: 0.025 3 | num_hidden: 1024 4 | num_heads: 4 5 | num_out_heads: 1 6 | num_layers: 2 7 | weight_decay: 1e-4 8 | weight_decay_f: 1e-2 9 | max_epoch: 500 10 | max_epoch_f: 500 11 | mask_rate: 0.5 12 | num_layers: 2 13 | encoder: gat 14 | decoder: gat 15 | activation: prelu 16 | attn_drop: 0.1 17 | linear_prob: True 18 | in_drop: 0.2 19 | loss_fn: sce 20 | drop_edge_rate: 0.0 21 | optimizer: adam 22 | replace_rate: 0.0 23 | alpha_l: 1 24 | scheduler: True 25 | remask_method: fixed 26 | momentum: 1 27 | lam: 0.1 -------------------------------------------------------------------------------- /configs/cora.yaml: -------------------------------------------------------------------------------- 1 | lr: 0.001 2 | lr_f: 0.025 3 | num_hidden: 1024 4 | num_heads: 8 5 | num_out_heads: 1 6 | num_layers: 2 7 | weight_decay: 2e-4 8 | weight_decay_f: 1e-4 9 | max_epoch: 2000 10 | max_epoch_f: 300 11 | mask_rate: 0.5 12 | num_layers: 2 13 | encoder: gat 14 | decoder: gat 15 | activation: prelu 16 | attn_drop: 0.1 17 | linear_prob: True 18 | in_drop: 0.2 19 | loss_fn: sce 20 | drop_edge_rate: 0.0 21 | optimizer: adam 22 | replace_rate: 0.15 23 | alpha_l: 3 24 | scheduler: True 25 | remask_method: fixed 26 | momentum: 0 27 | lam: 0.15 28 | -------------------------------------------------------------------------------- /configs/mag-scholar-f.yaml: -------------------------------------------------------------------------------- 1 | lr: 0.001 2 | lr_f: 0.001 3 | num_hidden: 1024 4 | num_heads: 8 5 | num_out_heads: 1 6 | num_layers: 4 7 | weight_decay: 0.04 8 | weight_decay_f: 0 9 | max_epoch: 10 10 | max_epoch_f: 1000 11 | batch_size: 512 12 | batch_size_f: 256 13 | mask_rate: 0.5 14 | num_layers: 4 15 | encoder: gat 16 | decoder: gat 17 | activation: prelu 18 | attn_drop: 0.2 19 | linear_prob: True 20 | in_drop: 0.2 21 | loss_fn: sce 22 | drop_edge_rate: 0.5 23 | optimizer: adamw 24 | alpha_l: 2 25 | scheduler: True 26 | remask_method: random 27 | momentum: 0.996 28 | lam: 0.1 29 | delayed_ema_epoch: 0 30 | num_remasking: 3 -------------------------------------------------------------------------------- /configs/ogbn-arxiv.yaml: -------------------------------------------------------------------------------- 1 | lr: 0.0025 2 | lr_f: 0.005 3 | num_hidden: 1024 4 | num_heads: 8 5 | num_out_heads: 1 6 | num_layers: 4 7 | weight_decay: 0.06 8 | weight_decay_f: 1e-4 9 | max_epoch: 60 10 | max_epoch_f: 1000 11 | batch_size: 512 12 | batch_size_f: 256 13 | mask_rate: 0.5 14 | num_layers: 4 15 | encoder: gat 16 | decoder: gat 17 | activation: prelu 18 | attn_drop: 0.1 19 | 
linear_prob: True 20 | in_drop: 0.2 21 | loss_fn: sce 22 | drop_edge_rate: 0.5 23 | optimizer: adamw 24 | alpha_l: 6 25 | scheduler: True 26 | remask_method: random 27 | momentum: 0.996 28 | lam: 10.0 29 | delayed_ema_epoch: 40 30 | num_remasking: 3 -------------------------------------------------------------------------------- /configs/ogbn-papers100M.yaml: -------------------------------------------------------------------------------- 1 | lr: 0.001 2 | lr_f: 0.001 3 | num_hidden: 1024 4 | num_heads: 4 5 | num_out_heads: 1 6 | num_layers: 4 7 | weight_decay: 0.05 8 | weight_decay_f: 0 9 | max_epoch: 10 10 | max_epoch_f: 1000 11 | batch_size: 512 12 | batch_size_f: 256 13 | mask_rate: 0.5 14 | num_layers: 4 15 | encoder: gat 16 | decoder: gat 17 | activation: prelu 18 | attn_drop: 0.2 19 | linear_prob: True 20 | in_drop: 0.2 21 | loss_fn: sce 22 | drop_edge_rate: 0.5 23 | optimizer: adamw 24 | alpha_l: 2 25 | scheduler: True 26 | remask_method: random 27 | momentum: 0.996 28 | lam: 10.0 29 | delayed_ema_epoch: 0 30 | num_remasking: 3 -------------------------------------------------------------------------------- /configs/ogbn-products.yaml: -------------------------------------------------------------------------------- 1 | lr: 0.002 2 | lr_f: 0.001 3 | num_hidden: 1024 4 | num_heads: 4 5 | num_out_heads: 1 6 | num_layers: 4 7 | weight_decay: 0.04 8 | weight_decay_f: 0 9 | max_epoch: 20 10 | max_epoch_f: 1000 11 | batch_size: 512 12 | batch_size_f: 256 13 | mask_rate: 0.5 14 | num_layers: 4 15 | encoder: gat 16 | decoder: gat 17 | activation: prelu 18 | attn_drop: 0.2 19 | linear_prob: True 20 | in_drop: 0.2 21 | loss_fn: sce 22 | drop_edge_rate: 0.5 23 | optimizer: adamw 24 | alpha_l: 3 25 | scheduler: True 26 | remask_method: random 27 | momentum: 0.996 28 | lam: 5.0 29 | delayed_ema_epoch: 0 30 | num_remasking: 3 -------------------------------------------------------------------------------- /configs/pubmed.yaml: -------------------------------------------------------------------------------- 1 | lr: 0.005 2 | lr_f: 0.025 3 | num_hidden: 512 4 | num_heads: 2 5 | num_out_heads: 1 6 | num_layers: 2 7 | weight_decay: 1e-5 8 | weight_decay_f: 5e-4 9 | max_epoch: 2000 10 | max_epoch_f: 500 11 | mask_rate: 0.9 12 | num_layers: 2 13 | encoder: gat 14 | decoder: gat 15 | activation: prelu 16 | attn_drop: 0.1 17 | linear_prob: True 18 | in_drop: 0.2 19 | loss_fn: sce 20 | drop_edge_rate: 0.0 21 | optimizer: adam 22 | replace_rate: 0.0 23 | alpha_l: 4 24 | scheduler: True 25 | remask_method: fixed 26 | momentum: 0.995 27 | lam: 1 28 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/GraphMAE2/bc473e8b750d6e90b7cba1c6bb20d858f617d90c/datasets/__init__.py -------------------------------------------------------------------------------- /datasets/data_proc.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | 5 | import dgl 6 | from dgl.data import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset 7 | from ogb.nodeproppred import DglNodePropPredDataset 8 | 9 | from sklearn.preprocessing import StandardScaler 10 | 11 | 12 | GRAPH_DICT = { 13 | "cora": CoraGraphDataset, 14 | "citeseer": CiteseerGraphDataset, 15 | "pubmed": PubmedGraphDataset, 16 | "ogbn-arxiv": DglNodePropPredDataset, 17 | } 18 | 19 | def load_small_dataset(dataset_name): 20 | 
assert dataset_name in GRAPH_DICT, f"Unknow dataset: {dataset_name}." 21 | if dataset_name.startswith("ogbn"): 22 | dataset = GRAPH_DICT[dataset_name](dataset_name) 23 | else: 24 | dataset = GRAPH_DICT[dataset_name]() 25 | 26 | if dataset_name == "ogbn-arxiv": 27 | graph, labels = dataset[0] 28 | num_nodes = graph.num_nodes() 29 | 30 | split_idx = dataset.get_idx_split() 31 | train_idx, val_idx, test_idx = split_idx["train"], split_idx["valid"], split_idx["test"] 32 | graph = preprocess(graph) 33 | 34 | if not torch.is_tensor(train_idx): 35 | train_idx = torch.as_tensor(train_idx) 36 | val_idx = torch.as_tensor(val_idx) 37 | test_idx = torch.as_tensor(test_idx) 38 | 39 | feat = graph.ndata["feat"] 40 | feat = scale_feats(feat) 41 | graph.ndata["feat"] = feat 42 | 43 | train_mask = torch.full((num_nodes,), False).index_fill_(0, train_idx, True) 44 | val_mask = torch.full((num_nodes,), False).index_fill_(0, val_idx, True) 45 | test_mask = torch.full((num_nodes,), False).index_fill_(0, test_idx, True) 46 | graph.ndata["label"] = labels.view(-1) 47 | graph.ndata["train_mask"], graph.ndata["val_mask"], graph.ndata["test_mask"] = train_mask, val_mask, test_mask 48 | else: 49 | graph = dataset[0] 50 | graph = graph.remove_self_loop() 51 | graph = graph.add_self_loop() 52 | num_features = graph.ndata["feat"].shape[1] 53 | num_classes = dataset.num_classes 54 | return graph, (num_features, num_classes) 55 | 56 | def preprocess(graph): 57 | # make bidirected 58 | if "feat" in graph.ndata: 59 | feat = graph.ndata["feat"] 60 | else: 61 | feat = None 62 | src, dst = graph.all_edges() 63 | # graph.add_edges(dst, src) 64 | graph = dgl.to_bidirected(graph) 65 | if feat is not None: 66 | graph.ndata["feat"] = feat 67 | 68 | # add self-loop 69 | graph = graph.remove_self_loop().add_self_loop() 70 | # graph.create_formats_() 71 | return graph 72 | 73 | 74 | def scale_feats(x): 75 | logging.info("### scaling features ###") 76 | scaler = StandardScaler() 77 | feats = x.numpy() 78 | scaler.fit(feats) 79 | feats = torch.from_numpy(scaler.transform(feats)).float() 80 | return feats 81 | -------------------------------------------------------------------------------- /datasets/lc_sampler.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | import dgl 6 | from ogb.nodeproppred import DglNodePropPredDataset 7 | 8 | from .data_proc import preprocess, scale_feats 9 | from utils import mask_edge 10 | 11 | import logging 12 | import torch.multiprocessing 13 | from torch.utils.data import DataLoader 14 | from tqdm import tqdm 15 | torch.multiprocessing.set_sharing_strategy('file_system') 16 | 17 | 18 | logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO) 19 | 20 | # def collect_topk_ppr(graph, nodes, topk, alpha, epsilon): 21 | # if torch.is_tensor(nodes): 22 | # nodes = nodes.numpy() 23 | # row, col = graph.edges() 24 | # row = row.numpy() 25 | # col = col.numpy() 26 | # num_nodes = graph.num_nodes() 27 | 28 | # neighbors = build_topk_ppr((row, col), alpha, epsilon, nodes, topk, num_nodes=num_nodes) 29 | # return neighbors 30 | 31 | # --------------------------------------------------------------------------------------------------------------------- 32 | 33 | 34 | def load_dataset(data_dir, dataset_name): 35 | if dataset_name.startswith("ogbn"): 36 | dataset = DglNodePropPredDataset(dataset_name, root=os.path.join(data_dir, "dataset")) 37 | graph, label = dataset[0] 38 | 39 | if "year" in 
graph.ndata: 40 | del graph.ndata["year"] 41 | if not graph.is_multigraph: 42 | logging.info("--- to undirected graph ---") 43 | graph = preprocess(graph) 44 | graph = graph.remove_self_loop().add_self_loop() 45 | 46 | split_idx = dataset.get_idx_split() 47 | label = label.view(-1) 48 | 49 | feats = graph.ndata.pop("feat") 50 | if dataset_name in ("ogbn-arxiv","ogbn-papers100M"): 51 | feats = scale_feats(feats) 52 | elif dataset_name == "mag-scholar-f": 53 | edge_index = np.load(os.path.join(data_dir, dataset_name, "edge_index_f.npy")) 54 | feats = torch.from_numpy(np.load(os.path.join(data_dir, "feature_f.npy"))).float() 55 | 56 | graph = dgl.DGLGraph((edge_index[0], edge_index[1])) 57 | 58 | graph = dgl.remove_self_loop(graph) 59 | graph = dgl.add_self_loop(graph) 60 | 61 | label = torch.from_numpy(np.load(os.path.join(data_dir, "label_f.npy"))).to(torch.long) 62 | split_idx = torch.load(os.path.join(data_dir, "split_idx_f.pt")) 63 | 64 | # graph.ndata["feat"] = feats 65 | # graph.ndata["label"] = label 66 | 67 | return feats, graph, label, split_idx 68 | 69 | class LinearProbingDataLoader(DataLoader): 70 | def __init__(self, idx, feats, labels=None, **kwargs): 71 | self.labels = labels 72 | self.feats = feats 73 | 74 | kwargs["collate_fn"] = self.__collate_fn__ 75 | super().__init__(dataset=idx, **kwargs) 76 | 77 | def __collate_fn__(self, batch_idx): 78 | feats = self.feats[batch_idx] 79 | label = self.labels[batch_idx] 80 | 81 | return feats, label 82 | 83 | class OnlineLCLoader(DataLoader): 84 | def __init__(self, root_nodes, graph, feats, labels=None, drop_edge_rate=0, **kwargs): 85 | self.graph = graph 86 | self.labels = labels 87 | self._drop_edge_rate = drop_edge_rate 88 | self.ego_graph_nodes = root_nodes 89 | self.feats = feats 90 | 91 | dataset = np.arange(len(root_nodes)) 92 | kwargs["collate_fn"] = self.__collate_fn__ 93 | super().__init__(dataset, **kwargs) 94 | 95 | def drop_edge(self, g): 96 | if self._drop_edge_rate <= 0: 97 | return g, g 98 | 99 | g = g.remove_self_loop() 100 | mask_index1 = mask_edge(g, self._drop_edge_rate) 101 | mask_index2 = mask_edge(g, self._drop_edge_rate) 102 | g1 = dgl.remove_edges(g, mask_index1).add_self_loop() 103 | g2 = dgl.remove_edges(g, mask_index2).add_self_loop() 104 | return g1, g2 105 | 106 | def __collate_fn__(self, batch_idx): 107 | ego_nodes = [self.ego_graph_nodes[i] for i in batch_idx] 108 | subgs = [self.graph.subgraph(ego_nodes[i]) for i in range(len(ego_nodes))] 109 | sg = dgl.batch(subgs) 110 | 111 | nodes = torch.from_numpy(np.concatenate(ego_nodes)).long() 112 | num_nodes = [x.shape[0] for x in ego_nodes] 113 | cum_num_nodes = np.cumsum([0] + num_nodes)[:-1] 114 | 115 | if self._drop_edge_rate > 0: 116 | drop_g1, drop_g2 = self.drop_edge(sg) 117 | 118 | sg = sg.remove_self_loop().add_self_loop() 119 | sg.ndata["feat"] = self.feats[nodes] 120 | targets = torch.from_numpy(cum_num_nodes) 121 | 122 | if self.labels != None: 123 | label = self.labels[batch_idx] 124 | else: 125 | label = None 126 | 127 | if self._drop_edge_rate > 0: 128 | return sg, targets, label, nodes, drop_g1, drop_g2 129 | else: 130 | return sg, targets, label, nodes 131 | 132 | 133 | def setup_training_data(dataset_name, data_dir, ego_graphs_file_path): 134 | feats, graph, labels, split_idx = load_dataset(data_dir, dataset_name) 135 | 136 | train_lbls = labels[split_idx["train"]] 137 | val_lbls = labels[split_idx["valid"]] 138 | test_lbls = labels[split_idx["test"]] 139 | 140 | labels = torch.cat([train_lbls, val_lbls, test_lbls]) 141 | 142 | 
os.makedirs(os.path.dirname(ego_graphs_file_path), exist_ok=True) 143 | 144 | if not os.path.exists(ego_graphs_file_path): 145 | raise FileNotFoundError(f"{ego_graphs_file_path} doesn't exist") 146 | else: 147 | nodes = torch.load(ego_graphs_file_path) 148 | 149 | return feats, graph, labels, split_idx, nodes 150 | 151 | 152 | def setup_training_dataloder(loader_type, training_nodes, graph, feats, batch_size, drop_edge_rate=0, pretrain_clustergcn=False, cluster_iter_data=None): 153 | num_workers = 8 154 | 155 | if loader_type == "lc": 156 | assert training_nodes is not None 157 | else: 158 | raise NotImplementedError(f"{loader_type} is not implemented yet") 159 | 160 | # print(" -------- drop edge rate: {} --------".format(drop_edge_rate)) 161 | dataloader = OnlineLCLoader(training_nodes, graph, feats=feats, drop_edge_rate=drop_edge_rate, batch_size=batch_size, shuffle=True, drop_last=False, persistent_workers=True, num_workers=num_workers) 162 | return dataloader 163 | 164 | 165 | def setup_eval_dataloder(loader_type, graph, feats, ego_graph_nodes=None, batch_size=128, shuffle=False): 166 | num_workers = 8 167 | if loader_type == "lc": 168 | assert ego_graph_nodes is not None 169 | else: 170 | raise NotImplementedError(f"{loader_type} is not implemented yet") 171 | 172 | dataloader = OnlineLCLoader(ego_graph_nodes, graph, feats, batch_size=batch_size, shuffle=shuffle, drop_last=False, persistent_workers=True, num_workers=num_workers) 173 | return dataloader 174 | 175 | 176 | def setup_finetune_dataloder(loader_type, graph, feats, ego_graph_nodes, labels, batch_size, shuffle=False): 177 | num_workers = 8 178 | 179 | if loader_type == "lc": 180 | assert ego_graph_nodes is not None 181 | else: 182 | raise NotImplementedError(f"{loader_type} is not implemented yet") 183 | 184 | dataloader = OnlineLCLoader(ego_graph_nodes, graph, feats, labels=labels, feats=feats, batch_size=batch_size, shuffle=shuffle, drop_last=False, num_workers=num_workers, persistent_workers=True) 185 | 186 | return dataloader 187 | -------------------------------------------------------------------------------- /datasets/localclustering.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from collections import namedtuple 3 | import multiprocessing 4 | import os 5 | 6 | import numpy as np 7 | from localgraphclustering import * 8 | from scipy.sparse import csr_matrix 9 | from ogb.nodeproppred import DglNodePropPredDataset 10 | import torch 11 | import logging 12 | 13 | import dgl 14 | from dgl.data import load_data 15 | 16 | 17 | def my_sweep_cut(g, node): 18 | vol_sum = 0.0 19 | in_edge = 0.0 20 | conds = np.zeros_like(node, dtype=np.float32) 21 | for i in range(len(node)): 22 | idx = node[i] 23 | vol_sum += g.d[idx] 24 | denominator = min(vol_sum, g.vol_G - vol_sum) 25 | if denominator == 0.0: 26 | denominator = 1.0 27 | in_edge += 2*sum([g.adjacency_matrix[idx,prev] for prev in node[:i+1]]) 28 | cut = vol_sum - in_edge 29 | conds[i] = cut/denominator 30 | return conds 31 | 32 | 33 | def calc_local_clustering(args): 34 | i, log_steps, num_iter, ego_size, method = args 35 | if i % log_steps == 0: 36 | print(i) 37 | node, ppr = approximate_PageRank(graphlocal, [i], iterations=num_iter, method=method, normalize=False) 38 | d_inv = graphlocal.dn[node] 39 | d_inv[d_inv > 1.0] = 1.0 40 | ppr_d_inv = ppr * d_inv 41 | output = list(zip(node, ppr_d_inv))[:ego_size] 42 | node, ppr_d_inv = zip(*sorted(output, key=lambda x: x[1], reverse=True)) 43 | assert node[0] == i 44 | 
node = np.array(node, dtype=np.int32) 45 | conds = my_sweep_cut(graphlocal, node) 46 | return node, conds 47 | 48 | 49 | def step1_local_clustering(data, name, idx_split, ego_size=128, num_iter=1000, log_steps=10000, num_workers=16, method='acl', save_dir=None): 50 | if save_dir is None: 51 | save_path = f"{name}-lc-ego-graphs-{ego_size}.pt" 52 | else: 53 | if not os.path.exists(save_dir): 54 | os.makedirs(save_dir, exist_ok=True) 55 | 56 | save_path = os.path.join(save_dir, f"{name}-lc-ego-graphs-{ego_size}.pt") 57 | 58 | N = data.num_nodes() 59 | edge_index = data.edges() 60 | edge_index = (edge_index[0].numpy(), edge_index[1].numpy()) 61 | adj = csr_matrix((np.ones(edge_index[0].shape[0]), edge_index), shape=(N, N)) 62 | 63 | global graphlocal 64 | graphlocal = GraphLocal.from_sparse_adjacency(adj) 65 | print('graphlocal generated') 66 | 67 | train_idx = idx_split["train"].cpu().numpy() 68 | valid_idx = idx_split["valid"].cpu().numpy() 69 | test_idx = idx_split["test"].cpu().numpy() 70 | 71 | with multiprocessing.Pool(num_workers) as pool: 72 | ego_graphs_train, conds_train = zip(*pool.imap(calc_local_clustering, [(i, log_steps, num_iter, ego_size, method) for i in train_idx], chunksize=512)) 73 | 74 | with multiprocessing.Pool(num_workers) as pool: 75 | ego_graphs_valid, conds_valid = zip(*pool.imap(calc_local_clustering, [(i, log_steps, num_iter, ego_size, method) for i in valid_idx], chunksize=512)) 76 | 77 | with multiprocessing.Pool(num_workers) as pool: 78 | ego_graphs_test, conds_test = zip(*pool.imap(calc_local_clustering, [(i, log_steps, num_iter, ego_size, method) for i in test_idx], chunksize=512)) 79 | 80 | ego_graphs = [] 81 | conds = [] 82 | ego_graphs.extend(ego_graphs_train) 83 | ego_graphs.extend(ego_graphs_valid) 84 | ego_graphs.extend(ego_graphs_test) 85 | conds.extend(conds_train) 86 | conds.extend(conds_valid) 87 | conds.extend(conds_test) 88 | 89 | ego_graphs = [ego_graphs_train, ego_graphs_valid, ego_graphs_test] 90 | 91 | torch.save(ego_graphs, save_path) 92 | 93 | 94 | def preprocess(graph): 95 | # make bidirected 96 | if "feat" in graph.ndata: 97 | feat = graph.ndata["feat"] 98 | else: 99 | feat = None 100 | # src, dst = graph.all_edges() 101 | # graph.add_edges(dst, src) 102 | graph = dgl.to_bidirected(graph) 103 | if feat is not None: 104 | graph.ndata["feat"] = feat 105 | 106 | # add self-loop 107 | graph = graph.remove_self_loop().add_self_loop() 108 | # graph.create_formats_() 109 | return graph 110 | 111 | 112 | def load_dataset(data_dir, dataset_name): 113 | if dataset_name.startswith("ogbn"): 114 | dataset = DglNodePropPredDataset(dataset_name, root=os.path.join(data_dir, "dataset")) 115 | graph, label = dataset[0] 116 | 117 | if "year" in graph.ndata: 118 | del graph.ndata["year"] 119 | if not graph.is_multigraph: 120 | graph = preprocess(graph) 121 | # graph = graph.remove_self_loop().add_self_loop() 122 | 123 | split_idx = dataset.get_idx_split() 124 | label = label.view(-1) 125 | 126 | elif dataset_name == "mag-scholar-f": 127 | edge_index = np.load(os.path.join(data_dir, dataset_name, "edge_index_f.npy")) 128 | print(len(edge_index[0])) 129 | graph = dgl.DGLGraph((edge_index[0], edge_index[1])) 130 | print(graph) 131 | num_nodes = graph.num_nodes() 132 | assert num_nodes == 12403930 133 | split_idx = torch.load(os.path.join(data_dir, dataset_name, "split_idx_f.pt")) 134 | else: 135 | raise NotImplementedError 136 | 137 | return graph, split_idx 138 | 139 | 140 | if __name__ == "__main__": 141 | parser = argparse.ArgumentParser(description='LCGNN 
(Preprocessing)') 142 | parser.add_argument('--dataset', type=str, default='flickr') 143 | parser.add_argument("--data_dir", type=str, default="data") 144 | parser.add_argument("--save_dir", type=str, default="lc_ego_graphs") 145 | parser.add_argument('--ego_size', type=int, default=256) 146 | parser.add_argument('--num_iter', type=int, default=1000) 147 | parser.add_argument('--log_steps', type=int, default=10000) 148 | parser.add_argument('--seed', type=int, default=0) 149 | parser.add_argument('--method', type=str, default='acl') 150 | parser.add_argument('--num_workers', type=int, default=16) 151 | args = parser.parse_args() 152 | print(args) 153 | 154 | np.random.seed(args.seed) 155 | 156 | graph, split_idx = load_dataset(args.data_dir, args.dataset) 157 | step1_local_clustering(graph, args.dataset, split_idx, args.ego_size, args.num_iter, args.log_steps, args.num_workers, args.method, args.save_dir) 158 | -------------------------------------------------------------------------------- /datasets/saint_sampler.py: -------------------------------------------------------------------------------- 1 | """This file is modified from https://github.com/dmlc/dgl/blob/master/python/dgl/dataloading/graphsaint.py""" 2 | 3 | import os 4 | import time 5 | import math 6 | import torch as th 7 | from torch.utils.data import DataLoader 8 | import random 9 | import numpy as np 10 | import dgl.function as fn 11 | import dgl 12 | from dgl.sampling import random_walk, pack_traces 13 | from tqdm import tqdm 14 | 15 | 16 | class SAINTSampler: 17 | """ 18 | Description 19 | ----------- 20 | SAINTSampler implements the sampler described in GraphSAINT. This sampler implements offline sampling in 21 | pre-sampling phase as well as fully offline sampling, fully online sampling in training phase. 22 | Users can conveniently set param 'online' of the sampler to choose different modes. 23 | 24 | Parameters 25 | ---------- 26 | node_budget : int 27 | the expected number of nodes in each subgraph, which is specifically explained in the paper. Actually this 28 | param specifies the times of sampling nodes from the original graph with replacement. The meaning of edge_budget 29 | is similar to the node_budget. 30 | dn : str 31 | name of dataset. 32 | g : DGLGraph 33 | the full graph. 34 | train_nid : list 35 | ids of training nodes. 36 | num_workers_sampler : int 37 | number of processes to sample subgraphs in pre-sampling procedure using torch.dataloader. 38 | num_subg_sampler : int, optional 39 | the max number of subgraphs sampled in pre-sampling phase for computing normalization coefficients in the beginning. 40 | Actually this param is used as ``__len__`` of sampler in pre-sampling phase. 41 | Please make sure that num_subg_sampler is greater than batch_size_sampler so that we can sample enough subgraphs. 42 | Defaults: 10000 43 | batch_size_sampler : int, optional 44 | the number of subgraphs sampled by each process concurrently in pre-sampling phase. 45 | Defaults: 200 46 | online : bool, optional 47 | If `True`, we employ online sampling in training phase. Otherwise employing offline sampling. 48 | Defaults: True 49 | num_subg : int, optional 50 | the expected number of sampled subgraphs in pre-sampling phase. 51 | It is actually the 'N' in the original paper. Note that this param is different from the num_subg_sampler. 52 | This param is just used to control the number of pre-sampled subgraphs. 
53 | Defaults: 50 54 | """ 55 | 56 | def __init__(self, node_budget, dn, g, num_workers_sampler, num_subg_sampler=10000, 57 | batch_size_sampler=200, online=True, num_subg=50, full=True): 58 | self.g = g.cpu() 59 | self.node_budget = node_budget 60 | # self.train_g: dgl.graph = g.subgraph(train_nid) 61 | self.train_g = g 62 | self.dn, self.num_subg = dn, num_subg 63 | self.node_counter = th.zeros((self.train_g.num_nodes(),)) 64 | self.edge_counter = th.zeros((self.train_g.num_edges(),)) 65 | self.prob = None 66 | self.num_subg_sampler = num_subg_sampler 67 | self.batch_size_sampler = batch_size_sampler 68 | self.num_workers_sampler = num_workers_sampler 69 | self.train = False 70 | self.online = online 71 | self.full = full 72 | 73 | assert self.num_subg_sampler >= self.batch_size_sampler, "num_subg_sampler should be greater than batch_size_sampler" 74 | graph_fn, norm_fn = self.__generate_fn__() 75 | 76 | if os.path.exists(graph_fn): 77 | self.subgraphs = np.load(graph_fn, allow_pickle=True) 78 | # aggr_norm, loss_norm = np.load(norm_fn, allow_pickle=True) 79 | else: 80 | os.makedirs('./subgraphs/', exist_ok=True) 81 | 82 | self.subgraphs = [] 83 | self.N, sampled_nodes = 0, 0 84 | # N: the number of pre-sampled subgraphs 85 | 86 | # Employ parallelism to speed up the sampling procedure 87 | loader = DataLoader(self, batch_size=self.batch_size_sampler, shuffle=True, 88 | num_workers=self.num_workers_sampler, collate_fn=self.__collate_fn__, drop_last=False) 89 | 90 | t = time.perf_counter() 91 | for num_nodes, subgraphs_nids, subgraphs_eids in tqdm(loader, desc="preprocessing saint subgraphs"): 92 | 93 | self.subgraphs.extend(subgraphs_nids) 94 | sampled_nodes += num_nodes 95 | 96 | _subgraphs, _node_counts = np.unique(np.concatenate(subgraphs_nids), return_counts=True) 97 | sampled_nodes_idx = th.from_numpy(_subgraphs) 98 | _node_counts = th.from_numpy(_node_counts) 99 | self.node_counter[sampled_nodes_idx] += _node_counts 100 | 101 | _subgraphs_eids, _edge_counts = np.unique(np.concatenate(subgraphs_eids), return_counts=True) 102 | sampled_edges_idx = th.from_numpy(_subgraphs_eids) 103 | _edge_counts = th.from_numpy(_edge_counts) 104 | self.edge_counter[sampled_edges_idx] += _edge_counts 105 | 106 | self.N += len(subgraphs_nids) # number of subgraphs 107 | 108 | # print("sampled_nodes: ", sampled_nodes, " --> ", self.train_g.num_nodes(), num_subg) 109 | if sampled_nodes > self.train_g.num_nodes() * num_subg: 110 | break 111 | 112 | print(f'Sampling time: [{time.perf_counter() - t:.2f}s]') 113 | np.save(graph_fn, self.subgraphs) 114 | 115 | # t = time.perf_counter() 116 | # aggr_norm, loss_norm = self.__compute_norm__() 117 | # print(f'Normalization time: [{time.perf_counter() - t:.2f}s]') 118 | # np.save(norm_fn, (aggr_norm, loss_norm)) 119 | 120 | # self.train_g.ndata['l_n'] = th.Tensor(loss_norm) 121 | # self.train_g.edata['w'] = th.Tensor(aggr_norm) 122 | # self.__compute_degree_norm() # basically normalizing adjacent matrix 123 | 124 | random.shuffle(self.subgraphs) 125 | self.__clear__() 126 | print("The number of subgraphs is: ", len(self.subgraphs)) 127 | 128 | self.train = True 129 | 130 | def __len__(self): 131 | if self.train is False: 132 | return self.num_subg_sampler 133 | else: 134 | if self.full: 135 | return len(self.subgraphs) 136 | else: 137 | return math.ceil(self.train_g.num_nodes() / self.node_budget) 138 | 139 | def __getitem__(self, idx): 140 | # Only when sampling subgraphs in training procedure and need to utilize sampled subgraphs and we still 141 | # have 
sampled subgraphs we can fetch a subgraph from sampled subgraphs 142 | if self.train: 143 | if self.online: 144 | subgraph = self.__sample__() 145 | return dgl.node_subgraph(self.train_g, subgraph) 146 | else: 147 | return dgl.node_subgraph(self.train_g, self.subgraphs[idx]) 148 | else: 149 | subgraph_nids = self.__sample__() 150 | num_nodes = len(subgraph_nids) 151 | subgraph_eids = dgl.node_subgraph(self.train_g, subgraph_nids).edata[dgl.EID] 152 | return num_nodes, subgraph_nids, subgraph_eids 153 | 154 | def __collate_fn__(self, batch): 155 | if self.train: # sample only one graph each epoch, batch_size in training phase in 1 156 | return batch[0] 157 | else: 158 | sum_num_nodes = 0 159 | subgraphs_nids_list = [] 160 | subgraphs_eids_list = [] 161 | for num_nodes, subgraph_nids, subgraph_eids in batch: 162 | sum_num_nodes += num_nodes 163 | subgraphs_nids_list.append(subgraph_nids) 164 | subgraphs_eids_list.append(subgraph_eids) 165 | return sum_num_nodes, subgraphs_nids_list, subgraphs_eids_list 166 | 167 | def __clear__(self): 168 | self.prob = None 169 | self.node_counter = None 170 | self.edge_counter = None 171 | self.g = None 172 | 173 | def __generate_fn__(self): 174 | raise NotImplementedError 175 | 176 | def __compute_norm__(self): 177 | 178 | self.node_counter[self.node_counter == 0] = 1 179 | self.edge_counter[self.edge_counter == 0] = 1 180 | 181 | loss_norm = self.N / self.node_counter / self.train_g.num_nodes() 182 | 183 | self.train_g.ndata['n_c'] = self.node_counter 184 | self.train_g.edata['e_c'] = self.edge_counter 185 | self.train_g.apply_edges(fn.v_div_e('n_c', 'e_c', 'a_n')) 186 | aggr_norm = self.train_g.edata.pop('a_n') 187 | 188 | self.train_g.ndata.pop('n_c') 189 | self.train_g.edata.pop('e_c') 190 | 191 | return aggr_norm.numpy(), loss_norm.numpy() 192 | 193 | def __compute_degree_norm(self): 194 | 195 | self.train_g.ndata['train_D_norm'] = 1. / self.train_g.in_degrees().float().clamp(min=1).unsqueeze(1) 196 | self.g.ndata['full_D_norm'] = 1. / self.g.in_degrees().float().clamp(min=1).unsqueeze(1) 197 | 198 | def __sample__(self): 199 | raise NotImplementedError 200 | 201 | def get_generated_subgraph_nodes(self): 202 | return self.subgraphs 203 | 204 | 205 | 206 | class SAINTRandomWalkLoader(SAINTSampler): 207 | """ 208 | Description 209 | ----------- 210 | GraphSAINT with random walk sampler 211 | 212 | Parameters 213 | ---------- 214 | num_roots : int 215 | the number of roots to generate random walks. 216 | length : int 217 | the length of each random walk. 
218 | 219 | """ 220 | def __init__(self, feats, num_roots, length, **kwargs): 221 | self.num_roots, self.length = num_roots, length 222 | self.feats = feats 223 | super(SAINTRandomWalkLoader, self).__init__(node_budget=num_roots * length, **kwargs) 224 | 225 | def __generate_fn__(self): 226 | graph_fn = os.path.join('./subgraphs/{}_RW_{}_{}_{}.npy'.format(self.dn, self.num_roots, 227 | self.length, self.num_subg)) 228 | norm_fn = os.path.join('./subgraphs/{}_RW_{}_{}_{}_norm.npy'.format(self.dn, self.num_roots, 229 | self.length, self.num_subg)) 230 | return graph_fn, norm_fn 231 | 232 | def __sample__(self): 233 | sampled_roots = th.randint(0, self.train_g.num_nodes(), (self.num_roots,)) 234 | traces, types = random_walk(self.train_g, nodes=sampled_roots, length=self.length) 235 | sampled_nodes, _, _, _ = pack_traces(traces, types) 236 | sampled_nodes = sampled_nodes.unique() 237 | return sampled_nodes.numpy() 238 | 239 | def __collate_fn__(self, batch): 240 | if self.train: # sample only one graph each epoch, batch_size in training phase in 1 241 | subg = batch[0] 242 | node_ids = subg.ndata["_ID"] 243 | feats = self.feats[node_ids] 244 | return feats, subg 245 | else: 246 | sum_num_nodes = 0 247 | subgraphs_nids_list = [] 248 | subgraphs_eids_list = [] 249 | for num_nodes, subgraph_nids, subgraph_eids in batch: 250 | sum_num_nodes += num_nodes 251 | subgraphs_nids_list.append(subgraph_nids) 252 | subgraphs_eids_list.append(subgraph_eids) 253 | return sum_num_nodes, subgraphs_nids_list, subgraphs_eids_list 254 | 255 | 256 | def build_saint_dataloader(feats, graph, dataset_name, online=False, **kwargs): 257 | num_nodes = graph.num_nodes() 258 | num_subg_sampler = 20000 # the max num of subgraphs 259 | num_subg = 50 # the max times a node be sampled 260 | full = True 261 | batch_size_sampler = 200 # batch_size of nodes 262 | 263 | num_roots = max(num_nodes // 100, 2000) 264 | # num_roots = 2000 # starting node for each multiple random walk 265 | length = 4 266 | num_workers = 4 267 | 268 | params = { 269 | 'dn': dataset_name, 'g': graph, 'num_workers_sampler': 4, 270 | 'num_subg_sampler': num_subg_sampler, 271 | 'batch_size_sampler': batch_size_sampler, 272 | 'online': online, 'num_subg': num_subg, 273 | 'full': full 274 | } 275 | 276 | saint_sampler = SAINTRandomWalkLoader(feats, num_roots, length, **params) 277 | loader = DataLoader(saint_sampler, collate_fn=saint_sampler.__collate_fn__, batch_size=1, **kwargs) 278 | return loader 279 | 280 | 281 | def get_saint_subgraphs(graph, dataset_name, online=False): 282 | num_nodes = graph.num_nodes() 283 | num_subg_sampler = 20000 # the max num of subgraphs 284 | num_subg = 50 # the max times a node be sampled 285 | full = True 286 | batch_size_sampler = 200 # batch_size of nodes 287 | 288 | num_roots = max(num_nodes // 100, 2000) 289 | # num_roots = 2000 # starting node for each multiple random walk 290 | length = 4 291 | num_workers = 4 292 | 293 | params = { 294 | 'dn': dataset_name, 'g': graph, 'num_workers_sampler': 4, 295 | 'num_subg_sampler': num_subg_sampler, 296 | 'batch_size_sampler': batch_size_sampler, 297 | 'online': online, 'num_subg': num_subg, 298 | 'full': full 299 | } 300 | saint_sampler = SAINTRandomWalkLoader(None, num_roots, length, **params) 301 | subgraph_nodes = saint_sampler.get_generated_subgraph_nodes() 302 | return subgraph_nodes 303 | -------------------------------------------------------------------------------- /main_full_batch.py: -------------------------------------------------------------------------------- 1 
| import logging 2 | import numpy as np 3 | from tqdm import tqdm 4 | import torch 5 | 6 | from utils import ( 7 | build_args, 8 | create_optimizer, 9 | set_random_seed, 10 | TBLogger, 11 | get_current_lr, 12 | load_best_configs, 13 | ) 14 | from datasets.data_proc import load_small_dataset 15 | from models.finetune import linear_probing_full_batch 16 | from models import build_model 17 | 18 | 19 | logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO) 20 | 21 | 22 | def pretrain(model, graph, feat, optimizer, max_epoch, device, scheduler, num_classes, lr_f, weight_decay_f, max_epoch_f, linear_prob, logger=None): 23 | logging.info("start training..") 24 | graph = graph.to(device) 25 | x = feat.to(device) 26 | 27 | target_nodes = torch.arange(x.shape[0], device=x.device, dtype=torch.long) 28 | epoch_iter = tqdm(range(max_epoch)) 29 | 30 | for epoch in epoch_iter: 31 | model.train() 32 | 33 | loss = model(graph, x, targets=target_nodes) 34 | 35 | loss_dict = {"loss": loss.item()} 36 | optimizer.zero_grad() 37 | loss.backward() 38 | optimizer.step() 39 | if scheduler is not None: 40 | scheduler.step() 41 | 42 | epoch_iter.set_description(f"# Epoch {epoch}: train_loss: {loss.item():.4f}") 43 | if logger is not None: 44 | loss_dict["lr"] = get_current_lr(optimizer) 45 | logger.note(loss_dict, step=epoch) 46 | 47 | if (epoch + 1) % 200 == 0: 48 | linear_probing_full_batch(model, graph, x, num_classes, lr_f, weight_decay_f, max_epoch_f, device, linear_prob, mute=True) 49 | 50 | return model 51 | 52 | 53 | def main(args): 54 | device = args.device if args.device >= 0 else "cpu" 55 | seeds = args.seeds 56 | dataset_name = args.dataset 57 | max_epoch = args.max_epoch 58 | max_epoch_f = args.max_epoch_f 59 | num_hidden = args.num_hidden 60 | num_layers = args.num_layers 61 | encoder_type = args.encoder 62 | decoder_type = args.decoder 63 | replace_rate = args.replace_rate 64 | 65 | optim_type = args.optimizer 66 | loss_fn = args.loss_fn 67 | 68 | lr = args.lr 69 | weight_decay = args.weight_decay 70 | lr_f = args.lr_f 71 | weight_decay_f = args.weight_decay_f 72 | linear_prob = args.linear_prob 73 | load_model = args.load_model 74 | logs = args.logging 75 | use_scheduler = args.scheduler 76 | 77 | graph, (num_features, num_classes) = load_small_dataset(dataset_name) 78 | args.num_features = num_features 79 | 80 | acc_list = [] 81 | estp_acc_list = [] 82 | for i, seed in enumerate(seeds): 83 | print(f"####### Run {i} for seed {seed}") 84 | set_random_seed(seed) 85 | 86 | if logs: 87 | logger = TBLogger(name=f"{dataset_name}_loss_{loss_fn}_rpr_{replace_rate}_nh_{num_hidden}_nl_{num_layers}_lr_{lr}_mp_{max_epoch}_mpf_{max_epoch_f}_wd_{weight_decay}_wdf_{weight_decay_f}_{encoder_type}_{decoder_type}") 88 | else: 89 | logger = None 90 | 91 | model = build_model(args) 92 | model.to(device) 93 | optimizer = create_optimizer(optim_type, model, lr, weight_decay) 94 | 95 | if use_scheduler: 96 | logging.info("Use schedular") 97 | scheduler = lambda epoch :( 1 + np.cos((epoch) * np.pi / max_epoch) ) * 0.5 98 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler) 99 | else: 100 | scheduler = None 101 | 102 | x = graph.ndata["feat"] 103 | if not load_model: 104 | model = pretrain(model, graph, x, optimizer, max_epoch, device, scheduler, num_classes, lr_f, weight_decay_f, max_epoch_f, linear_prob, logger) 105 | model = model.cpu() 106 | 107 | if load_model: 108 | logging.info("Loading Model ... 
") 109 | model.load_state_dict(torch.load("checkpoint.pt")) 110 | 111 | model = model.to(device) 112 | model.eval() 113 | 114 | final_acc, estp_acc = linear_probing_full_batch(model, graph, x, num_classes, lr_f, weight_decay_f, max_epoch_f, device, linear_prob) 115 | acc_list.append(final_acc) 116 | estp_acc_list.append(estp_acc) 117 | 118 | if logger is not None: 119 | logger.finish() 120 | 121 | final_acc, final_acc_std = np.mean(acc_list), np.std(acc_list) 122 | estp_acc, estp_acc_std = np.mean(estp_acc_list), np.std(estp_acc_list) 123 | print(f"# final_acc: {final_acc:.4f}±{final_acc_std:.4f}") 124 | print(f"# early-stopping_acc: {estp_acc:.4f}±{estp_acc_std:.4f}") 125 | 126 | 127 | # Press the green button in the gutter to run the script. 128 | if __name__ == "__main__": 129 | args = build_args() 130 | if args.use_cfg: 131 | args = load_best_configs(args) 132 | print(args) 133 | main(args) 134 | -------------------------------------------------------------------------------- /main_large.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import numpy as np 4 | from tqdm import tqdm 5 | 6 | import torch 7 | 8 | from utils import ( 9 | WandbLogger, 10 | build_args, 11 | create_optimizer, 12 | set_random_seed, 13 | load_best_configs, 14 | show_occupied_memory, 15 | ) 16 | from models import build_model 17 | from datasets.lc_sampler import ( 18 | setup_training_dataloder, 19 | setup_training_data, 20 | ) 21 | from models.finetune import linear_probing_minibatch, finetune 22 | 23 | import warnings 24 | 25 | warnings.filterwarnings("ignore") 26 | logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO) 27 | 28 | 29 | def evaluate( 30 | model, 31 | graph, feats, labels, 32 | split_idx, 33 | lr_f, weight_decay_f, max_epoch_f, 34 | linear_prob=True, 35 | device=0, 36 | batch_size=256, 37 | logger=None, ego_graph_nodes=None, 38 | label_rate=1.0, 39 | full_graph_forward=False, 40 | shuffle=True, 41 | ): 42 | logging.info("Using `lc` for evaluation...") 43 | num_train, num_val, num_test = [split_idx[k].shape[0] for k in ["train", "valid", "test"]] 44 | print(num_train,num_val,num_test) 45 | 46 | train_g_idx = np.arange(0, num_train) 47 | val_g_idx = np.arange(num_train, num_train+num_val) 48 | test_g_idx = np.arange(num_train+num_val, num_train+num_val+num_test) 49 | 50 | train_ego_graph_nodes = [ego_graph_nodes[i] for i in train_g_idx] 51 | val_ego_graph_nodes = [ego_graph_nodes[i] for i in val_g_idx] 52 | test_ego_graph_nodes = [ego_graph_nodes[i] for i in test_g_idx] 53 | 54 | train_lbls, val_lbls, test_lbls = labels[train_g_idx], labels[val_g_idx], labels[test_g_idx] 55 | 56 | # labels = [train_lbls, val_lbls, test_lbls] 57 | assert len(train_ego_graph_nodes) == len(train_lbls) 58 | assert len(val_ego_graph_nodes) == len(val_lbls) 59 | assert len(test_ego_graph_nodes) == len(test_lbls) 60 | 61 | print(f"num_train: {num_train}, num_val: {num_val}, num_test: {num_test}") 62 | logging.info(f"-- train_ego_nodes:{len(train_ego_graph_nodes)}, val_ego_nodes:{len(val_ego_graph_nodes)}, test_ego_nodes:{len(test_ego_graph_nodes)} ---") 63 | 64 | 65 | if linear_prob: 66 | result = linear_probing_minibatch(model, graph, feats, [train_ego_graph_nodes, val_ego_graph_nodes, test_ego_graph_nodes], [train_lbls, val_lbls, test_lbls], lr_f=lr_f, weight_decay_f=weight_decay_f, max_epoch_f=max_epoch_f, batch_size=batch_size, device=device, shuffle=shuffle) 67 | else: 68 | max_epoch_f = max_epoch_f // 2 69 | 70 
| if label_rate < 1.0: 71 | rand_idx = np.arange(len(train_ego_graph_nodes)) 72 | np.random.shuffle(rand_idx) 73 | rand_idx = rand_idx[:int(label_rate * len(train_ego_graph_nodes))] 74 | train_ego_graph_nodes = [train_ego_graph_nodes[i] for i in rand_idx] 75 | train_lbls = train_lbls[rand_idx] 76 | 77 | logging.info(f"-- train_ego_nodes:{len(train_ego_graph_nodes)}, val_ego_nodes:{len(val_ego_graph_nodes)}, test_ego_nodes:{len(test_ego_graph_nodes)} ---") 78 | 79 | # train_lbls = (all_train_lbls, train_lbls) 80 | result = finetune( 81 | model, graph, feats, 82 | [train_ego_graph_nodes, val_ego_graph_nodes, test_ego_graph_nodes], 83 | [train_lbls, val_lbls, test_lbls], 84 | split_idx=split_idx, 85 | lr_f=lr_f, weight_decay_f=weight_decay_f, max_epoch_f=max_epoch_f, use_scheduler=True, batch_size=batch_size, device=device, logger=logger, full_graph_forward=full_graph_forward, 86 | ) 87 | return result 88 | 89 | 90 | def pretrain(model, feats, graph, ego_graph_nodes, max_epoch, device, use_scheduler, lr, weight_decay, batch_size=512, sampling_method="lc", optimizer="adam", drop_edge_rate=0): 91 | logging.info("start training..") 92 | 93 | model = model.to(device) 94 | optimizer = create_optimizer(optimizer, model, lr, weight_decay) 95 | 96 | dataloader = setup_training_dataloder( 97 | sampling_method, ego_graph_nodes, graph, feats, batch_size=batch_size, drop_edge_rate=drop_edge_rate) 98 | 99 | logging.info(f"After creating dataloader: Memory: {show_occupied_memory():.2f} MB") 100 | if use_scheduler and max_epoch > 0: 101 | logging.info("Use scheduler") 102 | scheduler = lambda epoch :( 1 + np.cos((epoch) * np.pi / max_epoch) ) * 0.5 103 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler) 104 | else: 105 | scheduler = None 106 | 107 | for epoch in range(max_epoch): 108 | epoch_iter = tqdm(dataloader) 109 | losses = [] 110 | # assert (graph.in_degrees() > 0).all(), "after loading" 111 | 112 | for batch_g in epoch_iter: 113 | model.train() 114 | if drop_edge_rate > 0: 115 | batch_g, targets, _, node_idx, drop_g1, drop_g2 = batch_g 116 | batch_g = batch_g.to(device) 117 | drop_g1 = drop_g1.to(device) 118 | drop_g2 = drop_g2.to(device) 119 | x = batch_g.ndata.pop("feat") 120 | loss = model(batch_g, x, targets, epoch, drop_g1, drop_g2) 121 | else: 122 | batch_g, targets, _, node_idx = batch_g 123 | batch_g = batch_g.to(device) 124 | x = batch_g.ndata.pop("feat") 125 | loss = model(batch_g, x, targets, epoch) 126 | 127 | optimizer.zero_grad() 128 | loss.backward() 129 | torch.nn.utils.clip_grad_norm_(model.parameters(), 3) 130 | optimizer.step() 131 | 132 | epoch_iter.set_description(f"train_loss: {loss.item():.4f}, Memory: {show_occupied_memory():.2f} MB") 133 | losses.append(loss.item()) 134 | 135 | if scheduler is not None: 136 | scheduler.step() 137 | 138 | torch.save(model.state_dict(), os.path.join(model_dir, model_name)) 139 | 140 | print(f"# Epoch {epoch} | train_loss: {np.mean(losses):.4f}, Memory: {show_occupied_memory():.2f} MB") 141 | 142 | return model 143 | 144 | 145 | if __name__ == "__main__": 146 | args = build_args() 147 | if args.use_cfg: 148 | args = load_best_configs(args) 149 | 150 | if args.device < 0: 151 | device = "cpu" 152 | else: 153 | device = "cuda:{}".format(args.device) 154 | seeds = args.seeds 155 | dataset_name = args.dataset 156 | max_epoch = args.max_epoch 157 | max_epoch_f = args.max_epoch_f 158 | num_hidden = args.num_hidden 159 | num_layers = args.num_layers 160 | encoder_type = args.encoder 161 | decoder_type = args.decoder 162 | 
encoder = args.encoder 163 | decoder = args.decoder 164 | num_hidden = args.num_hidden 165 | drop_edge_rate = args.drop_edge_rate 166 | 167 | optim_type = args.optimizer 168 | loss_fn = args.loss_fn 169 | 170 | lr = args.lr 171 | weight_decay = args.weight_decay 172 | lr_f = args.lr_f 173 | weight_decay_f = args.weight_decay_f 174 | linear_prob = args.linear_prob 175 | load_model = args.load_model 176 | no_pretrain = args.no_pretrain 177 | logs = args.logging 178 | use_scheduler = args.scheduler 179 | batch_size = args.batch_size 180 | batch_size_f = args.batch_size_f 181 | sampling_method = args.sampling_method 182 | ego_graph_file_path = args.ego_graph_file_path 183 | data_dir = args.data_dir 184 | 185 | n_procs = torch.cuda.device_count() 186 | optimizer_type = args.optimizer 187 | label_rate = args.label_rate 188 | lam = args.lam 189 | full_graph_forward = hasattr(args, "full_graph_forward") and args.full_graph_forward and not linear_prob 190 | 191 | model_dir = "checkpoints" 192 | os.makedirs(model_dir, exist_ok=True) 193 | 194 | set_random_seed(0) 195 | print(args) 196 | 197 | logging.info(f"Before loading data, occupied memory: {show_occupied_memory():.2f} MB") # in MB 198 | feats, graph, labels, split_idx, ego_graph_nodes = setup_training_data(dataset_name, data_dir, ego_graph_file_path) 199 | if dataset_name == "ogbn-papers100M": 200 | pretrain_ego_graph_nodes = ego_graph_nodes[0] + ego_graph_nodes[1] + ego_graph_nodes[2] + ego_graph_nodes[3] 201 | else: 202 | pretrain_ego_graph_nodes = ego_graph_nodes[0] + ego_graph_nodes[1] + ego_graph_nodes[2] 203 | ego_graph_nodes = ego_graph_nodes[0] + ego_graph_nodes[1] + ego_graph_nodes[2] # * merge train/val/test = all 204 | 205 | logging.info(f"After loading data, occupied memory: {show_occupied_memory():.2f} MB") # in MB 206 | 207 | args.num_features = feats.shape[1] 208 | 209 | if logs: 210 | logger = WandbLogger(log_path=f"{dataset_name}_loss_{loss_fn}_nh_{num_hidden}_nl_{num_layers}_lr_{lr}_mp_{max_epoch}_mpf_{max_epoch_f}_wd_{weight_decay}_wdf_{weight_decay_f}_{encoder_type}_{decoder_type}", project="GraphMAE2", args=args) 211 | else: 212 | logger = None 213 | model_name = f"{encoder}_{decoder}_{num_hidden}_{num_layers}_{dataset_name}_{args.mask_rate}_{num_hidden}_checkpoint.pt" 214 | 215 | model = build_model(args) 216 | 217 | if not args.no_pretrain: 218 | # ------------- pretraining starts ---------------- 219 | if not load_model: 220 | logging.info("---- start pretraining ----") 221 | model = pretrain(model, feats, graph, pretrain_ego_graph_nodes, max_epoch=max_epoch, device=device, use_scheduler=use_scheduler, lr=lr, 222 | weight_decay=weight_decay, batch_size=batch_size, drop_edge_rate=drop_edge_rate, 223 | sampling_method=sampling_method, optimizer=optimizer_type) 224 | 225 | model = model.cpu() 226 | logging.info(f"saving model to {model_dir}/{model_name}...") 227 | torch.save(model.state_dict(), os.path.join(model_dir, model_name)) 228 | # ------------- pretraining ends ---------------- 229 | 230 | if load_model: 231 | model.load_state_dict(torch.load(os.path.join(args.checkpoint_path))) 232 | logging.info(f"Loading Model from {args.checkpoint_path}...") 233 | else: 234 | logging.info("--- no pretrain ---") 235 | 236 | model = model.to(device) 237 | model.eval() 238 | 239 | logging.info("---- start finetuning / evaluation ----") 240 | 241 | final_accs = [] 242 | for i,_ in enumerate(seeds): 243 | print(f"####### Run seed {seeds[i]}") 244 | set_random_seed(seeds[i]) 245 | eval_model = build_model(args) 246 | 
eval_model.load_state_dict(model.state_dict()) 247 | eval_model.to(device) 248 | 249 | print(f"features size : {feats.shape[1]}") 250 | logging.info("start evaluation...") 251 | final_acc = evaluate( 252 | eval_model, graph, feats, labels, split_idx, 253 | lr_f, weight_decay_f, max_epoch_f, 254 | device=device, 255 | batch_size=batch_size_f, 256 | ego_graph_nodes=ego_graph_nodes, 257 | linear_prob=linear_prob, 258 | label_rate=label_rate, 259 | full_graph_forward=full_graph_forward, 260 | shuffle=False if dataset_name == "ogbn-papers100M" else True 261 | ) 262 | 263 | final_accs.append(float(final_acc)) 264 | 265 | print(f"Run {seeds[i]} | TestAcc: {final_acc:.4f}") 266 | 267 | print(f"# final_acc: {np.mean(final_accs):.4f}, std: {np.std(final_accs):.4f}") 268 | 269 | if logger is not None: 270 | logger.finish() 271 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .edcoder import PreModel 2 | 3 | 4 | def build_model(args): 5 | num_heads = args.num_heads 6 | num_out_heads = args.num_out_heads 7 | num_hidden = args.num_hidden 8 | num_layers = args.num_layers 9 | residual = args.residual 10 | attn_drop = args.attn_drop 11 | in_drop = args.in_drop 12 | norm = args.norm 13 | negative_slope = args.negative_slope 14 | encoder_type = args.encoder 15 | decoder_type = args.decoder 16 | mask_rate = args.mask_rate 17 | remask_rate = args.remask_rate 18 | mask_method = args.mask_method 19 | drop_edge_rate = args.drop_edge_rate 20 | 21 | activation = args.activation 22 | loss_fn = args.loss_fn 23 | alpha_l = args.alpha_l 24 | 25 | num_features = args.num_features 26 | num_dec_layers = args.num_dec_layers 27 | num_remasking = args.num_remasking 28 | lam = args.lam 29 | delayed_ema_epoch = args.delayed_ema_epoch 30 | replace_rate = args.replace_rate 31 | remask_method = args.remask_method 32 | momentum = args.momentum 33 | zero_init = args.dataset in ("cora", "pubmed", "citeseer") 34 | 35 | model = PreModel( 36 | in_dim=num_features, 37 | num_hidden=num_hidden, 38 | num_layers=num_layers, 39 | num_dec_layers=num_dec_layers, 40 | num_remasking=num_remasking, 41 | nhead=num_heads, 42 | nhead_out=num_out_heads, 43 | activation=activation, 44 | feat_drop=in_drop, 45 | attn_drop=attn_drop, 46 | negative_slope=negative_slope, 47 | residual=residual, 48 | encoder_type=encoder_type, 49 | decoder_type=decoder_type, 50 | mask_rate=mask_rate, 51 | remask_rate=remask_rate, 52 | mask_method=mask_method, 53 | norm=norm, 54 | loss_fn=loss_fn, 55 | drop_edge_rate=drop_edge_rate, 56 | alpha_l=alpha_l, 57 | lam=lam, 58 | delayed_ema_epoch=delayed_ema_epoch, 59 | replace_rate=replace_rate, 60 | remask_method=remask_method, 61 | momentum=momentum, 62 | zero_init=zero_init, 63 | ) 64 | return model 65 | -------------------------------------------------------------------------------- /models/edcoder.py: -------------------------------------------------------------------------------- 1 | from itertools import chain 2 | 3 | from typing import Optional 4 | import torch 5 | import torch.nn as nn 6 | from functools import partial 7 | 8 | from .gat import GAT 9 | 10 | from .loss_func import sce_loss 11 | 12 | 13 | def setup_module(m_type, enc_dec, in_dim, num_hidden, out_dim, num_layers, dropout, activation, residual, norm, nhead, nhead_out, attn_drop, negative_slope=0.2, concat_out=True, **kwargs) -> nn.Module: 14 | if m_type in ("gat", "tsgat"): 15 | mod = GAT( 16 | in_dim=in_dim, 17 | 
num_hidden=num_hidden, 18 | out_dim=out_dim, 19 | num_layers=num_layers, 20 | nhead=nhead, 21 | nhead_out=nhead_out, 22 | concat_out=concat_out, 23 | activation=activation, 24 | feat_drop=dropout, 25 | attn_drop=attn_drop, 26 | negative_slope=negative_slope, 27 | residual=residual, 28 | norm=norm, 29 | encoding=(enc_dec == "encoding"), 30 | **kwargs, 31 | ) 32 | elif m_type == "mlp": 33 | # * just for decoder 34 | mod = nn.Sequential( 35 | nn.Linear(in_dim, num_hidden * 2), 36 | nn.PReLU(), 37 | nn.Dropout(0.2), 38 | nn.Linear(num_hidden * 2, out_dim) 39 | ) 40 | elif m_type == "linear": 41 | mod = nn.Linear(in_dim, out_dim) 42 | else: 43 | raise NotImplementedError 44 | 45 | return mod 46 | 47 | class PreModel(nn.Module): 48 | def __init__( 49 | self, 50 | in_dim: int, 51 | num_hidden: int, 52 | num_layers: int, 53 | num_dec_layers: int, 54 | num_remasking: int, 55 | nhead: int, 56 | nhead_out: int, 57 | activation: str, 58 | feat_drop: float, 59 | attn_drop: float, 60 | negative_slope: float, 61 | residual: bool, 62 | norm: Optional[str], 63 | mask_rate: float = 0.3, 64 | remask_rate: float = 0.5, 65 | remask_method: str = "random", 66 | mask_method: str = "random", 67 | encoder_type: str = "gat", 68 | decoder_type: str = "gat", 69 | loss_fn: str = "byol", 70 | drop_edge_rate: float = 0.0, 71 | alpha_l: float = 2, 72 | lam: float = 1.0, 73 | delayed_ema_epoch: int = 0, 74 | momentum: float = 0.996, 75 | replace_rate: float = 0.0, 76 | zero_init: bool = False, 77 | ): 78 | super(PreModel, self).__init__() 79 | self._mask_rate = mask_rate 80 | self._remask_rate = remask_rate 81 | self._mask_method = mask_method 82 | self._alpha_l = alpha_l 83 | self._delayed_ema_epoch = delayed_ema_epoch 84 | 85 | self.num_remasking = num_remasking 86 | self._encoder_type = encoder_type 87 | self._decoder_type = decoder_type 88 | self._drop_edge_rate = drop_edge_rate 89 | self._output_hidden_size = num_hidden 90 | self._momentum = momentum 91 | self._replace_rate = replace_rate 92 | self._num_remasking = num_remasking 93 | self._remask_method = remask_method 94 | 95 | self._token_rate = 1 - self._replace_rate 96 | self._lam = lam 97 | 98 | assert num_hidden % nhead == 0 99 | assert num_hidden % nhead_out == 0 100 | if encoder_type in ("gat",): 101 | enc_num_hidden = num_hidden // nhead 102 | enc_nhead = nhead 103 | else: 104 | enc_num_hidden = num_hidden 105 | enc_nhead = 1 106 | 107 | dec_in_dim = num_hidden 108 | dec_num_hidden = num_hidden // nhead if decoder_type in ("gat",) else num_hidden 109 | 110 | # build encoder 111 | self.encoder = setup_module( 112 | m_type=encoder_type, 113 | enc_dec="encoding", 114 | in_dim=in_dim, 115 | num_hidden=enc_num_hidden, 116 | out_dim=enc_num_hidden, 117 | num_layers=num_layers, 118 | nhead=enc_nhead, 119 | nhead_out=enc_nhead, 120 | concat_out=True, 121 | activation=activation, 122 | dropout=feat_drop, 123 | attn_drop=attn_drop, 124 | negative_slope=negative_slope, 125 | residual=residual, 126 | norm=norm, 127 | ) 128 | 129 | self.decoder = setup_module( 130 | m_type=decoder_type, 131 | enc_dec="decoding", 132 | in_dim=dec_in_dim, 133 | num_hidden=dec_num_hidden, 134 | out_dim=in_dim, 135 | nhead_out=nhead_out, 136 | num_layers=num_dec_layers, 137 | nhead=nhead, 138 | activation=activation, 139 | dropout=feat_drop, 140 | attn_drop=attn_drop, 141 | negative_slope=negative_slope, 142 | residual=residual, 143 | norm=norm, 144 | concat_out=True, 145 | ) 146 | 147 | self.enc_mask_token = nn.Parameter(torch.zeros(1, in_dim)) 148 | self.dec_mask_token = 
nn.Parameter(torch.zeros(1, num_hidden)) 149 | 150 | self.encoder_to_decoder = nn.Linear(dec_in_dim, dec_in_dim, bias=False) 151 | 152 | if not zero_init: 153 | self.reset_parameters_for_token() 154 | 155 | 156 | # * setup loss function 157 | self.criterion = self.setup_loss_fn(loss_fn, alpha_l) 158 | 159 | self.projector = nn.Sequential( 160 | nn.Linear(num_hidden, 256), 161 | nn.PReLU(), 162 | nn.Linear(256, num_hidden), 163 | ) 164 | self.projector_ema = nn.Sequential( 165 | nn.Linear(num_hidden, 256), 166 | nn.PReLU(), 167 | nn.Linear(256, num_hidden), 168 | ) 169 | self.predictor = nn.Sequential( 170 | nn.PReLU(), 171 | nn.Linear(num_hidden, num_hidden) 172 | ) 173 | 174 | self.encoder_ema = setup_module( 175 | m_type=encoder_type, 176 | enc_dec="encoding", 177 | in_dim=in_dim, 178 | num_hidden=enc_num_hidden, 179 | out_dim=enc_num_hidden, 180 | num_layers=num_layers, 181 | nhead=enc_nhead, 182 | nhead_out=enc_nhead, 183 | concat_out=True, 184 | activation=activation, 185 | dropout=feat_drop, 186 | attn_drop=attn_drop, 187 | negative_slope=negative_slope, 188 | residual=residual, 189 | norm=norm, 190 | ) 191 | self.encoder_ema.load_state_dict(self.encoder.state_dict()) 192 | self.projector_ema.load_state_dict(self.projector.state_dict()) 193 | 194 | for p in self.encoder_ema.parameters(): 195 | p.requires_grad = False 196 | p.detach_() 197 | for p in self.projector_ema.parameters(): 198 | p.requires_grad = False 199 | p.detach_() 200 | 201 | self.print_num_parameters() 202 | 203 | def print_num_parameters(self): 204 | num_encoder_params = [p.numel() for p in self.encoder.parameters() if p.requires_grad] 205 | num_decoder_params = [p.numel() for p in self.decoder.parameters() if p.requires_grad] 206 | num_params = [p.numel() for p in self.parameters() if p.requires_grad] 207 | 208 | print(f"num_encoder_params: {sum(num_encoder_params)}, num_decoder_params: {sum(num_decoder_params)}, num_params_in_total: {sum(num_params)}") 209 | 210 | def reset_parameters_for_token(self): 211 | nn.init.xavier_normal_(self.enc_mask_token) 212 | nn.init.xavier_normal_(self.dec_mask_token) 213 | nn.init.xavier_normal_(self.encoder_to_decoder.weight, gain=1.414) 214 | 215 | @property 216 | def output_hidden_dim(self): 217 | return self._output_hidden_size 218 | 219 | def setup_loss_fn(self, loss_fn, alpha_l): 220 | if loss_fn == "mse": 221 | print(f"=== Use mse_loss ===") 222 | criterion = nn.MSELoss() 223 | elif loss_fn == "sce": 224 | print(f"=== Use sce_loss and alpha_l={alpha_l} ===") 225 | criterion = partial(sce_loss, alpha=alpha_l) 226 | else: 227 | raise NotImplementedError 228 | return criterion 229 | 230 | def forward(self, g, x, targets=None, epoch=0, drop_g1=None, drop_g2=None): # ---- attribute reconstruction ---- 231 | loss = self.mask_attr_prediction(g, x, targets, epoch, drop_g1, drop_g2) 232 | 233 | return loss 234 | 235 | def mask_attr_prediction(self, g, x, targets, epoch, drop_g1=None, drop_g2=None): 236 | pre_use_g, use_x, (mask_nodes, keep_nodes) = self.encoding_mask_noise(g, x, self._mask_rate) 237 | use_g = drop_g1 if drop_g1 is not None else g 238 | 239 | enc_rep = self.encoder(use_g, use_x,) 240 | 241 | with torch.no_grad(): 242 | drop_g2 = drop_g2 if drop_g2 is not None else g 243 | latent_target = self.encoder_ema(drop_g2, x,) 244 | if targets is not None: 245 | latent_target = self.projector_ema(latent_target[targets]) 246 | else: 247 | latent_target = self.projector_ema(latent_target[keep_nodes]) 248 | 249 | if targets is not None: 250 | latent_pred = 
self.projector(enc_rep[targets]) 251 | latent_pred = self.predictor(latent_pred) 252 | loss_latent = sce_loss(latent_pred, latent_target, 1) 253 | else: 254 | latent_pred = self.projector(enc_rep[keep_nodes]) 255 | latent_pred = self.predictor(latent_pred) 256 | loss_latent = sce_loss(latent_pred, latent_target, 1) 257 | 258 | # ---- attribute reconstruction ---- 259 | origin_rep = self.encoder_to_decoder(enc_rep) 260 | 261 | loss_rec_all = 0 262 | if self._remask_method == "random": 263 | for i in range(self._num_remasking): 264 | rep = origin_rep.clone() 265 | rep, remask_nodes, rekeep_nodes = self.random_remask(use_g, rep, self._remask_rate) 266 | recon = self.decoder(pre_use_g, rep) 267 | 268 | x_init = x[mask_nodes] 269 | x_rec = recon[mask_nodes] 270 | loss_rec = self.criterion(x_init, x_rec) 271 | loss_rec_all += loss_rec 272 | loss_rec = loss_rec_all 273 | elif self._remask_method == "fixed": 274 | rep = self.fixed_remask(g, origin_rep, mask_nodes) 275 | x_rec = self.decoder(pre_use_g, rep)[mask_nodes] 276 | x_init = x[mask_nodes] 277 | loss_rec = self.criterion(x_init, x_rec) 278 | else: 279 | raise NotImplementedError 280 | 281 | loss = loss_rec + self._lam * loss_latent 282 | 283 | if epoch >= self._delayed_ema_epoch: 284 | self.ema_update() 285 | return loss 286 | 287 | def ema_update(self): 288 | def update(student, teacher): 289 | with torch.no_grad(): 290 | # m = momentum_schedule[it] # momentum parameter 291 | m = self._momentum 292 | for param_q, param_k in zip(student.parameters(), teacher.parameters()): 293 | param_k.data.mul_(m).add_((1 - m) * param_q.detach().data) 294 | update(self.encoder, self.encoder_ema) 295 | update(self.projector, self.projector_ema) 296 | 297 | def embed(self, g, x): 298 | rep = self.encoder(g, x) 299 | return rep 300 | 301 | def get_encoder(self): 302 | #self.encoder.reset_classifier(out_size) 303 | return self.encoder 304 | 305 | def reset_encoder(self, out_size): 306 | self.encoder.reset_classifier(out_size) 307 | 308 | @property 309 | def enc_params(self): 310 | return self.encoder.parameters() 311 | 312 | @property 313 | def dec_params(self): 314 | return chain(*[self.encoder_to_decoder.parameters(), self.decoder.parameters()]) 315 | 316 | def output_grad(self): 317 | grad_dict = {} 318 | for n, p in self.named_parameters(): 319 | if p.grad is not None: 320 | grad_dict[n] = p.grad.abs().mean().item() 321 | return grad_dict 322 | 323 | def encoding_mask_noise(self, g, x, mask_rate=0.3): 324 | num_nodes = g.num_nodes() 325 | perm = torch.randperm(num_nodes, device=x.device) 326 | num_mask_nodes = int(mask_rate * num_nodes) 327 | 328 | # exclude isolated nodes 329 | # isolated_nodes = torch.where(g.in_degrees() <= 1)[0] 330 | # mask_nodes = perm[: num_mask_nodes] 331 | # mask_nodes = torch.index_fill(torch.full((num_nodes,), False, device=device), 0, mask_nodes, True) 332 | # mask_nodes[isolated_nodes] = False 333 | # keep_nodes = torch.where(~mask_nodes)[0] 334 | # mask_nodes = torch.where(mask_nodes)[0] 335 | # num_mask_nodes = mask_nodes.shape[0] 336 | 337 | # random masking 338 | num_mask_nodes = int(mask_rate * num_nodes) 339 | mask_nodes = perm[: num_mask_nodes] 340 | keep_nodes = perm[num_mask_nodes: ] 341 | 342 | out_x = x.clone() 343 | token_nodes = mask_nodes 344 | out_x[mask_nodes] = 0.0 345 | 346 | out_x[token_nodes] += self.enc_mask_token 347 | use_g = g.clone() 348 | 349 | return use_g, out_x, (mask_nodes, keep_nodes) 350 | 351 | def random_remask(self,g,rep,remask_rate=0.5): 352 | 353 | num_nodes = g.num_nodes() 354 | perm = 
torch.randperm(num_nodes, device=rep.device) 355 | num_remask_nodes = int(remask_rate * num_nodes) 356 | remask_nodes = perm[: num_remask_nodes] 357 | rekeep_nodes = perm[num_remask_nodes: ] 358 | 359 | rep = rep.clone() 360 | rep[remask_nodes] = 0 361 | rep[remask_nodes] += self.dec_mask_token 362 | 363 | return rep, remask_nodes, rekeep_nodes 364 | 365 | def fixed_remask(self, g, rep, masked_nodes): 366 | rep[masked_nodes] = 0 367 | return rep -------------------------------------------------------------------------------- /models/finetune.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | 4 | from tqdm import tqdm 5 | import numpy as np 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.utils.data import DataLoader 10 | 11 | from datasets.lc_sampler import setup_eval_dataloder, setup_finetune_dataloder, LinearProbingDataLoader 12 | from utils import accuracy, set_random_seed, show_occupied_memory, get_current_lr 13 | 14 | import wandb 15 | 16 | 17 | def linear_probing_minibatch( 18 | model, graph, 19 | feats, ego_graph_nodes, labels, 20 | lr_f, weight_decay_f, max_epoch_f, 21 | device, batch_size=-1, shuffle=True): 22 | logging.info("-- Linear Probing in downstream tasks ---") 23 | train_ego_graph_nodes, val_ego_graph_nodes, test_ego_graph_nodes = ego_graph_nodes 24 | num_train, num_val = len(train_ego_graph_nodes), len(val_ego_graph_nodes) 25 | train_lbls, val_lbls, test_lbls = labels 26 | # if dataset_name in ["ogbn-papers100M", "mag-scholar-f", "mag-scholar-c","ogbn-arxiv","ogbn-products"]: 27 | # if dataset_name in ["ogbn-papers100M", "mag-scholar-f", "mag-scholar-c", "ogbn-arxiv", "ogbn-products"]: 28 | eval_loader = setup_eval_dataloder("lc", graph, feats, train_ego_graph_nodes+val_ego_graph_nodes+test_ego_graph_nodes, 512) 29 | 30 | with torch.no_grad(): 31 | model.eval() 32 | embeddings = [] 33 | 34 | for batch in tqdm(eval_loader, desc="Infering..."): 35 | batch_g, targets, _, node_idx = batch 36 | batch_g = batch_g.to(device) 37 | x = batch_g.ndata.pop("feat") 38 | targets = targets.to(device) 39 | 40 | batch_emb = model.embed(batch_g, x)[targets] 41 | embeddings.append(batch_emb.cpu()) 42 | embeddings = torch.cat(embeddings, dim=0) 43 | 44 | train_emb, val_emb, test_emb = embeddings[:num_train], embeddings[num_train:num_train+num_val], embeddings[num_train+num_val:] 45 | 46 | batch_size = 5120 47 | acc = [] 48 | seeds = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] 49 | for i,_ in enumerate(seeds): 50 | print(f"####### Run seed {seeds[i]} for LinearProbing...") 51 | set_random_seed(seeds[i]) 52 | print(f"training sample:{len(train_emb)}") 53 | test_acc = node_classification_linear_probing( 54 | (train_emb, val_emb, test_emb), 55 | (train_lbls, val_lbls, test_lbls), 56 | lr_f, weight_decay_f, max_epoch_f, device, batch_size=batch_size, shuffle=shuffle) 57 | acc.append(test_acc) 58 | 59 | print(f"# final_acc: {np.mean(acc):.4f}, std: {np.std(acc):.4f}") 60 | 61 | return np.mean(acc) 62 | 63 | 64 | 65 | class LogisticRegression(nn.Module): 66 | def __init__(self, num_dim, num_class): 67 | super().__init__() 68 | self.linear = nn.Linear(num_dim, num_class) 69 | 70 | def forward(self, g, x, *args): 71 | logits = self.linear(x) 72 | return logits 73 | 74 | 75 | def node_classification_linear_probing(embeddings, labels, lr, weight_decay, max_epoch, device, mute=False, batch_size=-1, shuffle=True): 76 | criterion = torch.nn.CrossEntropyLoss() 77 | 78 | train_emb, val_emb, test_emb = embeddings 79 | train_label, 
val_label, test_label = labels 80 | train_label = train_label.to(torch.long) 81 | val_label = val_label.to(torch.long) 82 | test_label = test_label.to(torch.long) 83 | 84 | best_val_acc = 0 85 | best_val_epoch = 0 86 | best_model = None 87 | 88 | if not mute: 89 | epoch_iter = tqdm(range(max_epoch)) 90 | else: 91 | epoch_iter = range(max_epoch) 92 | 93 | encoder = LogisticRegression(train_emb.shape[1], int(train_label.max().item() + 1)) 94 | encoder = encoder.to(device) 95 | optimizer = torch.optim.Adam(encoder.parameters(), lr=lr, weight_decay=weight_decay) 96 | 97 | if batch_size > 0: 98 | train_loader = LinearProbingDataLoader(np.arange(len(train_emb)), train_emb, train_label, batch_size=batch_size, num_workers=4, persistent_workers=True, shuffle=shuffle) 99 | # train_loader = DataLoader(np.arange(len(train_emb)), batch_size=batch_size, shuffle=False) 100 | val_loader = LinearProbingDataLoader(np.arange(len(val_emb)), val_emb, val_label, batch_size=batch_size, num_workers=4, persistent_workers=True,shuffle=False) 101 | test_loader = LinearProbingDataLoader(np.arange(len(test_emb)), test_emb, test_label, batch_size=batch_size, num_workers=4, persistent_workers=True,shuffle=False) 102 | else: 103 | train_loader = [np.arange(len(train_emb))] 104 | val_loader = [np.arange(len(val_emb))] 105 | test_loader = [np.arange(len(test_emb))] 106 | 107 | def eval_forward(loader, _label): 108 | pred_all = [] 109 | for batch_x, _ in loader: 110 | batch_x = batch_x.to(device) 111 | pred = encoder(None, batch_x) 112 | pred_all.append(pred.cpu()) 113 | pred = torch.cat(pred_all, dim=0) 114 | acc = accuracy(pred, _label) 115 | return acc 116 | 117 | for epoch in epoch_iter: 118 | encoder.train() 119 | 120 | for batch_x, batch_label in train_loader: 121 | batch_x = batch_x.to(device) 122 | batch_label = batch_label.to(device) 123 | pred = encoder(None, batch_x) 124 | loss = criterion(pred, batch_label) 125 | optimizer.zero_grad() 126 | loss.backward() 127 | # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=3) 128 | optimizer.step() 129 | 130 | with torch.no_grad(): 131 | encoder.eval() 132 | val_acc = eval_forward(val_loader, val_label) 133 | 134 | if val_acc >= best_val_acc: 135 | best_val_acc = val_acc 136 | best_val_epoch = epoch 137 | best_model = copy.deepcopy(encoder) 138 | 139 | if not mute: 140 | epoch_iter.set_description(f"# Epoch: {epoch}, train_loss:{loss.item(): .4f}, val_acc:{val_acc:.4f}") 141 | 142 | best_model.eval() 143 | encoder = best_model 144 | with torch.no_grad(): 145 | test_acc = eval_forward(test_loader, test_label) 146 | if mute: 147 | print(f"# IGNORE: --- TestAcc: {test_acc:.4f}, Best ValAcc: {best_val_acc:.4f} in epoch {best_val_epoch} --- ") 148 | else: 149 | print(f"--- TestAcc: {test_acc:.4f}, Best ValAcc: {best_val_acc:.4f} in epoch {best_val_epoch} --- ") 150 | 151 | return test_acc 152 | 153 | 154 | def finetune( 155 | model, 156 | graph, 157 | feats, 158 | ego_graph_nodes, 159 | labels, 160 | split_idx, 161 | lr_f, weight_decay_f, max_epoch_f, 162 | use_scheduler, batch_size, 163 | device, 164 | logger=None, 165 | full_graph_forward=False, 166 | ): 167 | logging.info("-- Finetuning in downstream tasks ---") 168 | train_egs, val_egs, test_egs = ego_graph_nodes 169 | print(f"num of egos:{len(train_egs)},{len(val_egs)},{len(test_egs)}") 170 | 171 | print(graph.num_nodes()) 172 | 173 | train_nid = split_idx["train"].numpy() 174 | val_nid = split_idx["valid"].numpy() 175 | test_nid = split_idx["test"].numpy() 176 | 177 | train_lbls, val_lbls, test_lbls = [x.long() 
for x in labels] 178 | print(f"num of labels:{len(train_lbls)},{len(val_lbls)},{len(test_lbls)}") 179 | 180 | num_classes = max(max(train_lbls.max().item(), val_lbls.max().item()), test_lbls.max().item()) + 1 181 | 182 | model = model.get_encoder() 183 | model.reset_classifier(int(num_classes)) 184 | model = model.to(device) 185 | criterion = torch.nn.CrossEntropyLoss() 186 | 187 | train_loader = setup_finetune_dataloder("lc", graph, feats, train_egs, train_lbls, batch_size=batch_size, shuffle=True) 188 | val_loader = setup_finetune_dataloder("lc", graph, feats, val_egs, val_lbls, batch_size=batch_size, shuffle=False) 189 | test_loader = setup_finetune_dataloder("lc", graph, feats, test_egs, test_lbls, batch_size=batch_size, shuffle=False) 190 | 191 | #optimizer = torch.optim.Adam(model.parameters(), lr=lr_f, weight_decay=weight_decay_f) 192 | 193 | optimizer = torch.optim.AdamW(model.parameters(), lr=lr_f, weight_decay=weight_decay_f) 194 | 195 | if use_scheduler and max_epoch_f > 0: 196 | logging.info("Use schedular") 197 | warmup_epochs = int(max_epoch_f * 0.1) 198 | # scheduler = lambda epoch :( 1 + np.cos((epoch) * np.pi / max_epoch_f) ) * 0.5 199 | scheduler = lambda epoch: epoch / warmup_epochs if epoch < warmup_epochs else ( 1 + np.cos((epoch - warmup_epochs) * np.pi / (max_epoch_f - warmup_epochs))) * 0.5 200 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler) 201 | else: 202 | scheduler = None 203 | 204 | def eval_with_lc(model, loader): 205 | pred_counts = [] 206 | model.eval() 207 | epoch_iter = tqdm(loader) 208 | with torch.no_grad(): 209 | for batch in epoch_iter: 210 | batch_g, targets, batch_lbls, node_idx = batch 211 | batch_g = batch_g.to(device) 212 | batch_lbls = batch_lbls.to(device) 213 | x = batch_g.ndata.pop("feat") 214 | 215 | prediction = model(batch_g, x) 216 | prediction = prediction[targets] 217 | pred_counts.append((prediction.argmax(1) == batch_lbls)) 218 | pred_counts = torch.cat(pred_counts) 219 | acc = pred_counts.float().sum() / pred_counts.shape[0] 220 | return acc 221 | 222 | def eval_full_prop(model, g, nfeat, val_nid, test_nid, batch_size, device): 223 | model.eval() 224 | 225 | with torch.no_grad(): 226 | pred = model.inference(g, nfeat, batch_size, device) 227 | model.train() 228 | 229 | return accuracy(pred[val_nid], val_lbls.cpu()), accuracy(pred[test_nid], test_lbls.cpu()) 230 | 231 | best_val_acc = 0 232 | best_model = None 233 | best_epoch = 0 234 | test_acc = 0 235 | early_stop_cnt = 0 236 | 237 | for epoch in range(max_epoch_f): 238 | if epoch == 0: 239 | scheduler.step() 240 | continue 241 | if early_stop_cnt >= 10: 242 | break 243 | epoch_iter = tqdm(train_loader) 244 | losses = [] 245 | model.train() 246 | 247 | for batch_g, targets, batch_lbls, node_idx in epoch_iter: 248 | batch_g = batch_g.to(device) 249 | targets = targets.to(device) 250 | batch_lbls = batch_lbls.to(device) 251 | x = batch_g.ndata.pop("feat") 252 | 253 | prediction = model(batch_g, x) 254 | prediction = prediction[targets] 255 | loss = criterion(prediction, batch_lbls) 256 | 257 | optimizer.zero_grad() 258 | loss.backward() 259 | 260 | torch.nn.utils.clip_grad_norm_(model.parameters(), 5) 261 | optimizer.step() 262 | 263 | metrics = {"finetune_loss": loss} 264 | wandb.log(metrics) 265 | 266 | if logger is not None: 267 | logger.log(metrics) 268 | 269 | epoch_iter.set_description(f"Finetuning | train_loss: {loss.item():.4f}, Memory: {show_occupied_memory():.2f} MB") 270 | losses.append(loss.item()) 271 | 272 | if scheduler is not None: 273 | 
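# [Editor's note] one scheduler step per finetuning epoch: LambdaLR rescales the base lr_f by
# the warmup-then-cosine multiplier defined above (linear ramp over the first ~10% of epochs,
# then cosine decay towards zero).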
scheduler.step() 274 | 275 | if not full_graph_forward: 276 | if epoch > 0: 277 | val_acc = eval_with_lc(model, val_loader) 278 | _test_acc = 0 279 | else: 280 | if epoch > 0 and epoch % 1 == 0: 281 | val_acc, _test_acc = eval_full_prop(model, graph, feats, val_nid, test_nid, 10000, device) 282 | model = model.to(device) 283 | 284 | print('val Acc {:.4f}'.format(val_acc)) 285 | if val_acc > best_val_acc: 286 | best_model = copy.deepcopy(model) 287 | best_val_acc = val_acc 288 | test_acc = _test_acc 289 | best_epoch = epoch 290 | early_stop_cnt = 0 291 | else: 292 | early_stop_cnt += 1 293 | 294 | if not full_graph_forward: 295 | print("val Acc {:.4f}, Best Val Acc {:.4f}".format(val_acc, best_val_acc)) 296 | else: 297 | print("Val Acc {:.4f}, Best Val Acc {:.4f} Test Acc {:.4f}".format(val_acc, best_val_acc, test_acc)) 298 | 299 | metrics = {"epoch_val_acc": val_acc, 300 | "test_acc": test_acc, 301 | "epoch": epoch, 302 | "lr_f": get_current_lr(optimizer)} 303 | 304 | wandb.log(metrics) 305 | if logger is not None: 306 | logger.log(metrics) 307 | print(f"# Finetuning - Epoch {epoch} | train_loss: {np.mean(losses):.4f}, ValAcc: {val_acc:.4f}, TestAcc: {test_acc:.4f}, Memory: {show_occupied_memory():.2f} MB") 308 | 309 | model = best_model 310 | if not full_graph_forward: 311 | test_acc = eval_with_lc(model, test_loader) 312 | 313 | print(f"Finetune | TestAcc: {test_acc:.4f} from Epoch {best_epoch}") 314 | return test_acc 315 | 316 | 317 | def linear_probing_full_batch(model, graph, x, num_classes, lr_f, weight_decay_f, max_epoch_f, device, linear_prob=True, mute=False): 318 | model.eval() 319 | with torch.no_grad(): 320 | x = model.embed(graph.to(device), x.to(device)) 321 | in_feat = x.shape[1] 322 | encoder = LogisticRegression(in_feat, num_classes) 323 | 324 | num_finetune_params = [p.numel() for p in encoder.parameters() if p.requires_grad] 325 | if not mute: 326 | print(f"num parameters for finetuning: {sum(num_finetune_params)}") 327 | 328 | encoder.to(device) 329 | optimizer_f = torch.optim.Adam(encoder.parameters(), lr=lr_f, weight_decay=weight_decay_f) 330 | final_acc, estp_acc = _linear_probing_full_batch(encoder, graph, x, optimizer_f, max_epoch_f, device, mute) 331 | return final_acc, estp_acc 332 | 333 | 334 | def _linear_probing_full_batch(model, graph, feat, optimizer, max_epoch, device, mute=False): 335 | criterion = torch.nn.CrossEntropyLoss() 336 | 337 | graph = graph.to(device) 338 | x = feat.to(device) 339 | 340 | train_mask = graph.ndata["train_mask"] 341 | val_mask = graph.ndata["val_mask"] 342 | test_mask = graph.ndata["test_mask"] 343 | labels = graph.ndata["label"] 344 | 345 | best_val_acc = 0 346 | best_val_epoch = 0 347 | best_model = None 348 | 349 | if not mute: 350 | epoch_iter = tqdm(range(max_epoch)) 351 | else: 352 | epoch_iter = range(max_epoch) 353 | 354 | for epoch in epoch_iter: 355 | model.train() 356 | out = model(graph, x) 357 | loss = criterion(out[train_mask], labels[train_mask]) 358 | optimizer.zero_grad() 359 | loss.backward() 360 | # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=3) 361 | optimizer.step() 362 | 363 | with torch.no_grad(): 364 | model.eval() 365 | pred = model(graph, x) 366 | val_acc = accuracy(pred[val_mask], labels[val_mask]) 367 | val_loss = criterion(pred[val_mask], labels[val_mask]) 368 | test_acc = accuracy(pred[test_mask], labels[test_mask]) 369 | test_loss = criterion(pred[test_mask], labels[test_mask]) 370 | 371 | if val_acc >= best_val_acc: 372 | best_val_acc = val_acc 373 | best_val_epoch = epoch 374 | best_model
= copy.deepcopy(model) 375 | 376 | if not mute: 377 | epoch_iter.set_description(f"# Epoch: {epoch}, train_loss:{loss.item(): .4f}, val_loss:{val_loss.item(): .4f}, val_acc:{val_acc}, test_loss:{test_loss.item(): .4f}, test_acc:{test_acc: .4f}") 378 | 379 | best_model.eval() 380 | with torch.no_grad(): 381 | pred = best_model(graph, x) 382 | estp_test_acc = accuracy(pred[test_mask], labels[test_mask]) 383 | if mute: 384 | print(f"# IGNORE: --- TestAcc: {test_acc:.4f}, early-stopping-TestAcc: {estp_test_acc:.4f}, Best ValAcc: {best_val_acc:.4f} in epoch {best_val_epoch} --- ") 385 | else: 386 | print(f"--- TestAcc: {test_acc:.4f}, early-stopping-TestAcc: {estp_test_acc:.4f}, Best ValAcc: {best_val_acc:.4f} in epoch {best_val_epoch} --- ") 387 | 388 | return test_acc, estp_test_acc 389 | -------------------------------------------------------------------------------- /models/gat.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | import dgl 7 | import dgl.function as fn 8 | from dgl.ops import edge_softmax 9 | from dgl.utils import expand_as_pair 10 | 11 | 12 | from utils import create_activation, create_norm 13 | 14 | 15 | class GAT(nn.Module): 16 | def __init__(self, 17 | in_dim, 18 | num_hidden, 19 | out_dim, 20 | num_layers, 21 | nhead, 22 | nhead_out, 23 | activation, 24 | feat_drop, 25 | attn_drop, 26 | negative_slope, 27 | residual, 28 | norm, 29 | concat_out=False, 30 | encoding=False, 31 | ): 32 | super(GAT, self).__init__() 33 | self.out_dim = out_dim 34 | self.num_heads = nhead 35 | self.num_heads_out = nhead_out 36 | self.num_hidden = num_hidden 37 | self.num_layers = num_layers 38 | self.gat_layers = nn.ModuleList() 39 | self.activation = activation 40 | self.concat_out = concat_out 41 | 42 | last_activation = create_activation(activation) if encoding else None 43 | last_residual = (encoding and residual) 44 | last_norm = norm if encoding else None 45 | 46 | hidden_in = in_dim 47 | hidden_out = out_dim 48 | 49 | if num_layers == 1: 50 | self.gat_layers.append(GATConv( 51 | hidden_in, hidden_out, nhead_out, 52 | feat_drop, attn_drop, negative_slope, last_residual, norm=last_norm, concat_out=concat_out)) 53 | else: 54 | # input projection (no residual) 55 | self.gat_layers.append(GATConv( 56 | hidden_in, num_hidden, nhead, 57 | feat_drop, attn_drop, negative_slope, residual, create_activation(activation), norm=norm, concat_out=concat_out)) 58 | # hidden layers 59 | 60 | for l in range(1, num_layers - 1): 61 | # due to multi-head, the in_dim = num_hidden * num_heads 62 | self.gat_layers.append(GATConv( 63 | num_hidden * nhead, num_hidden, nhead, 64 | feat_drop, attn_drop, negative_slope, residual, create_activation(activation), norm=norm, concat_out=concat_out)) 65 | 66 | # output projection 67 | self.gat_layers.append(GATConv( 68 | num_hidden * nhead, hidden_out, nhead_out, 69 | feat_drop, attn_drop, negative_slope, last_residual, activation=last_activation, norm=last_norm, concat_out=concat_out)) 70 | self.head = nn.Identity() 71 | 72 | def forward(self, g, inputs): 73 | h = inputs 74 | 75 | for l in range(self.num_layers): 76 | h = self.gat_layers[l](g, h) 77 | 78 | if self.head is not None: 79 | return self.head(h) 80 | else: 81 | return h 82 | 83 | def inference(self, g, x, batch_size, device, emb=False): 84 | """ 85 | Inference with the GAT model on full neighbors (i.e. without neighbor sampling). 86 | g : the entire graph. 87 | x : the input of entire node set. 
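[Editor's note] The layer-wise loop below materialises each layer's output for every node in a
CPU buffer (`y`) before moving on to the next layer, so peak memory grows with
num_nodes * hidden_size rather than with the size of multi-hop neighborhoods.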
88 | The inference code is written in a fashion that it could handle any number of nodes and 89 | layers. 90 | """ 91 | num_heads = self.num_heads 92 | num_heads_out = self.num_heads_out 93 | for l, layer in enumerate(self.gat_layers): 94 | if l < self.num_layers - 1: 95 | y = torch.zeros(g.num_nodes(), self.num_hidden * num_heads if l != len(self.gat_layers) - 1 else self.num_classes) 96 | else: 97 | if emb == False: 98 | y = torch.zeros(g.num_nodes(), self.num_hidden if l != len(self.gat_layers) - 1 else self.num_classes) 99 | else: 100 | y = torch.zeros(g.num_nodes(), self.out_dim * num_heads_out) 101 | sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1) 102 | dataloader = dgl.dataloading.DataLoader( 103 | g, 104 | torch.arange(g.num_nodes()), 105 | sampler, 106 | batch_size=batch_size, 107 | shuffle=False, 108 | drop_last=False, 109 | num_workers=8) 110 | 111 | for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader): 112 | block = blocks[0].int().to(device) 113 | h = x[input_nodes].to(device) 114 | if l < self.num_layers - 1: 115 | h = layer(block, h) 116 | else: 117 | h = layer(block, h) 118 | 119 | if l == len(self.gat_layers) - 1 and (emb == False): 120 | h = self.head(h) 121 | y[output_nodes] = h.cpu() 122 | x = y 123 | return y 124 | 125 | def reset_classifier(self, num_classes): 126 | self.num_classes = num_classes 127 | self.is_pretraining = False 128 | self.head = nn.Linear(self.num_heads * self.out_dim, num_classes) 129 | 130 | 131 | 132 | class GATConv(nn.Module): 133 | def __init__(self, 134 | in_feats, 135 | out_feats, 136 | num_heads, 137 | feat_drop=0., 138 | attn_drop=0., 139 | negative_slope=0.2, 140 | residual=False, 141 | activation=None, 142 | allow_zero_in_degree=False, 143 | bias=True, 144 | norm=None, 145 | concat_out=True): 146 | super(GATConv, self).__init__() 147 | self._num_heads = num_heads 148 | self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats) 149 | self._out_feats = out_feats 150 | self._allow_zero_in_degree = allow_zero_in_degree 151 | self._concat_out = concat_out 152 | 153 | if isinstance(in_feats, tuple): 154 | self.fc_src = nn.Linear( 155 | self._in_src_feats, out_feats * num_heads, bias=False) 156 | self.fc_dst = nn.Linear( 157 | self._in_dst_feats, out_feats * num_heads, bias=False) 158 | else: 159 | self.fc = nn.Linear( 160 | self._in_src_feats, out_feats * num_heads, bias=False) 161 | self.attn_l = nn.Parameter(torch.FloatTensor(size=(1, num_heads, out_feats))) 162 | self.attn_r = nn.Parameter(torch.FloatTensor(size=(1, num_heads, out_feats))) 163 | self.feat_drop = nn.Dropout(feat_drop) 164 | self.attn_drop = nn.Dropout(attn_drop) 165 | self.leaky_relu = nn.LeakyReLU(negative_slope) 166 | if bias: 167 | self.bias = nn.Parameter(torch.FloatTensor(size=(num_heads * out_feats,))) 168 | else: 169 | self.register_buffer('bias', None) 170 | if residual: 171 | if self._in_dst_feats != out_feats * num_heads: 172 | self.res_fc = nn.Linear( 173 | self._in_dst_feats, num_heads * out_feats, bias=False) 174 | else: 175 | self.res_fc = None 176 | else: 177 | self.register_buffer('res_fc', None) 178 | self.reset_parameters() 179 | self.activation = activation 180 | 181 | self.norm = norm 182 | if norm is not None: 183 | self.norm = create_norm(norm)(num_heads * out_feats) 184 | self.set_allow_zero_in_degree(False) 185 | 186 | def reset_parameters(self): 187 | """ 188 | 189 | Description 190 | ----------- 191 | Reinitialize learnable parameters. 
192 | 193 | Note 194 | ---- 195 | The fc weights :math:`W^{(l)}` are initialized using Glorot uniform initialization. 196 | The attention weights are using xavier initialization method. 197 | """ 198 | gain = nn.init.calculate_gain('relu') 199 | if hasattr(self, 'fc'): 200 | nn.init.xavier_normal_(self.fc.weight, gain=gain) 201 | else: 202 | nn.init.xavier_normal_(self.fc_src.weight, gain=gain) 203 | nn.init.xavier_normal_(self.fc_dst.weight, gain=gain) 204 | nn.init.xavier_normal_(self.attn_l, gain=gain) 205 | nn.init.xavier_normal_(self.attn_r, gain=gain) 206 | if self.bias is not None: 207 | nn.init.constant_(self.bias, 0) 208 | if isinstance(self.res_fc, nn.Linear): 209 | nn.init.xavier_normal_(self.res_fc.weight, gain=gain) 210 | 211 | def set_allow_zero_in_degree(self, set_value): 212 | self._allow_zero_in_degree = set_value 213 | 214 | def forward(self, graph, feat, get_attention=False): 215 | with graph.local_scope(): 216 | if not self._allow_zero_in_degree: 217 | if (graph.in_degrees() == 0).any(): 218 | raise RuntimeError('There are 0-in-degree nodes in the graph, ' 219 | 'output for those nodes will be invalid. ' 220 | 'This is harmful for some applications, ' 221 | 'causing silent performance regression. ' 222 | 'Adding self-loop on the input graph by ' 223 | 'calling `g = dgl.add_self_loop(g)` will resolve ' 224 | 'the issue. Setting ``allow_zero_in_degree`` ' 225 | 'to be `True` when constructing this module will ' 226 | 'suppress the check and let the code run.') 227 | 228 | if isinstance(feat, tuple): 229 | src_prefix_shape = feat[0].shape[:-1] 230 | dst_prefix_shape = feat[1].shape[:-1] 231 | h_src = self.feat_drop(feat[0]) 232 | # h_dst = self.feat_drop(feat[1]) 233 | h_dst = feat[1] 234 | 235 | if not hasattr(self, 'fc_src'): 236 | feat_src = self.fc(h_src).view( 237 | *src_prefix_shape, self._num_heads, self._out_feats) 238 | feat_dst = self.fc(h_dst).view( 239 | *dst_prefix_shape, self._num_heads, self._out_feats) 240 | else: 241 | feat_src = self.fc_src(h_src).view( 242 | *src_prefix_shape, self._num_heads, self._out_feats) 243 | feat_dst = self.fc_dst(h_dst).view( 244 | *dst_prefix_shape, self._num_heads, self._out_feats) 245 | else: 246 | src_prefix_shape = dst_prefix_shape = feat.shape[:-1] 247 | h_src = h_dst = self.feat_drop(feat) 248 | feat_src = feat_dst = self.fc(h_src).view( 249 | *src_prefix_shape, self._num_heads, self._out_feats) 250 | if graph.is_block: 251 | feat_dst = feat_src[:graph.number_of_dst_nodes()] 252 | h_dst = h_dst[:graph.number_of_dst_nodes()] 253 | dst_prefix_shape = (graph.number_of_dst_nodes(),) + dst_prefix_shape[1:] 254 | # NOTE: GAT paper uses "first concatenation then linear projection" 255 | # to compute attention scores, while ours is "first projection then 256 | # addition", the two approaches are mathematically equivalent: 257 | # We decompose the weight vector a mentioned in the paper into 258 | # [a_l || a_r], then 259 | # a^T [Wh_i || Wh_j] = a_l Wh_i + a_r Wh_j 260 | # Our implementation is much efficient because we do not need to 261 | # save [Wh_i || Wh_j] on edges, which is not memory-efficient. Plus, 262 | # addition could be optimized with DGL's built-in function u_add_v, 263 | # which further speeds up computation and saves memory footprint. 
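# --- [Editor's aside] A quick numerical check of the identity above; illustration only,
# not part of gat.py (the dimension and names below are hypothetical). ---
import torch
d = 8
a_l, a_r = torch.randn(d), torch.randn(d)                 # a = [a_l || a_r]
wh_i, wh_j = torch.randn(d), torch.randn(d)               # projected features Wh_i, Wh_j
lhs = torch.cat([a_l, a_r]) @ torch.cat([wh_i, wh_j])     # a^T [Wh_i || Wh_j]
rhs = a_l @ wh_i + a_r @ wh_j                             # a_l Wh_i + a_r Wh_j
assert torch.allclose(lhs, rhs)
# --- [End of editor's aside] ---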
264 | el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1) 265 | er = (feat_dst * self.attn_r).sum(dim=-1).unsqueeze(-1) 266 | graph.srcdata.update({'ft': feat_src, 'el': el}) 267 | graph.dstdata.update({'er': er}) 268 | # compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively. 269 | graph.apply_edges(fn.u_add_v('el', 'er', 'e')) 270 | e = self.leaky_relu(graph.edata.pop('e')) 271 | # e[e == 0] = -1e3 272 | # e = graph.edata.pop('e') 273 | # compute softmax 274 | graph.edata['a'] = self.attn_drop(edge_softmax(graph, e)) 275 | # message passing 276 | graph.update_all(fn.u_mul_e('ft', 'a', 'm'), 277 | fn.sum('m', 'ft')) 278 | rst = graph.dstdata['ft'] 279 | 280 | # bias 281 | if self.bias is not None: 282 | rst = rst + self.bias.view( 283 | *((1,) * len(dst_prefix_shape)), self._num_heads, self._out_feats) 284 | 285 | # residual 286 | if self.res_fc is not None: 287 | # Use -1 rather than self._num_heads to handle broadcasting 288 | resval = self.res_fc(h_dst).view(*dst_prefix_shape, -1, self._out_feats) 289 | rst = rst + resval 290 | 291 | if self._concat_out: 292 | rst = rst.flatten(1) 293 | else: 294 | rst = torch.mean(rst, dim=1) 295 | 296 | if self.norm is not None: 297 | rst = self.norm(rst) 298 | 299 | # activation 300 | if self.activation: 301 | rst = self.activation(rst) 302 | 303 | if get_attention: 304 | return rst, graph.edata['a'] 305 | else: 306 | return rst 307 | -------------------------------------------------------------------------------- /models/gcn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import dgl 6 | import dgl.function as fn 7 | from dgl.utils import expand_as_pair 8 | 9 | from utils import create_activation, create_norm 10 | 11 | 12 | class GCN(nn.Module): 13 | def __init__(self, 14 | in_dim, 15 | num_hidden, 16 | out_dim, 17 | num_layers, 18 | dropout, 19 | activation, 20 | residual, 21 | norm, 22 | encoding=False 23 | ): 24 | super(GCN, self).__init__() 25 | self.out_dim = out_dim 26 | self.num_layers = num_layers 27 | self.gcn_layers = nn.ModuleList() 28 | self.activation = activation 29 | self.dropout = dropout 30 | 31 | last_activation = create_activation(activation) if encoding else None 32 | last_residual = encoding and residual 33 | last_norm = norm if encoding else None 34 | 35 | if num_layers == 1: 36 | self.gcn_layers.append(GraphConv( 37 | in_dim, out_dim, residual=last_residual, norm=last_norm, activation=last_activation)) 38 | else: 39 | # input projection (no residual) 40 | self.gcn_layers.append(GraphConv( 41 | in_dim, num_hidden, residual=residual, norm=norm, activation=create_activation(activation))) 42 | # hidden layers 43 | for l in range(1, num_layers - 1): 44 | # due to multi-head, the in_dim = num_hidden * num_heads 45 | self.gcn_layers.append(GraphConv( 46 | num_hidden, num_hidden, residual=residual, norm=norm, activation=create_activation(activation))) 47 | # output projection 48 | self.gcn_layers.append(GraphConv( 49 | num_hidden, out_dim, residual=last_residual, activation=last_activation, norm=last_norm)) 50 | 51 | # if norm is not None: 52 | # self.norms = nn.ModuleList([ 53 | # norm(num_hidden) 54 | # for _ in range(num_layers - 1) 55 | # ]) 56 | # if not encoding: 57 | # self.norms.append(norm(out_dim)) 58 | # else: 59 | # self.norms = None 60 | self.norms = None 61 | self.head = nn.Identity() 62 | 63 | def forward(self, g, inputs, return_hidden=False): 64 | h = inputs 65 | hidden_list 
= [] 66 | for l in range(self.num_layers): 67 | h = F.dropout(h, p=self.dropout, training=self.training) 68 | h = self.gcn_layers[l](g, h) 69 | if self.norms is not None and l != self.num_layers - 1: 70 | h = self.norms[l](h) 71 | hidden_list.append(h) 72 | # output projection 73 | if self.norms is not None and len(self.norms) == self.num_layers: 74 | h = self.norms[-1](h) 75 | if return_hidden: 76 | return self.head(h), hidden_list 77 | else: 78 | return self.head(h) 79 | 80 | def reset_classifier(self, num_classes): 81 | self.head = nn.Linear(self.out_dim, num_classes) 82 | 83 | 84 | class GraphConv(nn.Module): 85 | def __init__(self, 86 | in_dim, 87 | out_dim, 88 | norm=None, 89 | activation=None, 90 | residual=True, 91 | ): 92 | super().__init__() 93 | self._in_feats = in_dim 94 | self._out_feats = out_dim 95 | 96 | self.fc = nn.Linear(in_dim, out_dim) 97 | 98 | if residual: 99 | if self._in_feats != self._out_feats: 100 | self.res_fc = nn.Linear( 101 | self._in_feats, self._out_feats, bias=False) 102 | print("! Linear Residual !") 103 | else: 104 | print("Identity Residual ") 105 | self.res_fc = nn.Identity() 106 | else: 107 | self.register_buffer('res_fc', None) 108 | 109 | # if norm == "batchnorm": 110 | # self.norm = nn.BatchNorm1d(out_dim) 111 | # elif norm == "layernorm": 112 | # self.norm = nn.LayerNorm(out_dim) 113 | # else: 114 | # self.norm = None 115 | 116 | self.norm = norm 117 | if norm is not None: 118 | self.norm = create_norm(norm)(out_dim) 119 | self._activation = activation 120 | 121 | self.reset_parameters() 122 | 123 | def reset_parameters(self): 124 | self.fc.reset_parameters() 125 | 126 | def forward(self, graph, feat): 127 | with graph.local_scope(): 128 | aggregate_fn = fn.copy_src('h', 'm') 129 | # if edge_weight is not None: 130 | # assert edge_weight.shape[0] == graph.number_of_edges() 131 | # graph.edata['_edge_weight'] = edge_weight 132 | # aggregate_fn = fn.u_mul_e('h', '_edge_weight', 'm') 133 | 134 | # (BarclayII) For RGCN on heterogeneous graphs we need to support GCN on bipartite. 135 | feat_src, feat_dst = expand_as_pair(feat, graph) 136 | # if self._norm in ['left', 'both']: 137 | degs = graph.out_degrees().float().clamp(min=1) 138 | norm = torch.pow(degs, -0.5) 139 | shp = norm.shape + (1,) * (feat_src.dim() - 1) 140 | norm = torch.reshape(norm, shp) 141 | feat_src = feat_src * norm 142 | 143 | # if self._in_feats > self._out_feats: 144 | # # mult W first to reduce the feature size for aggregation. 
145 | # # if weight is not None: 146 | # # feat_src = th.matmul(feat_src, weight) 147 | # graph.srcdata['h'] = feat_src 148 | # graph.update_all(aggregate_fn, fn.sum(msg='m', out='h')) 149 | # rst = graph.dstdata['h'] 150 | # else: 151 | # aggregate first then mult W 152 | graph.srcdata['h'] = feat_src 153 | graph.update_all(aggregate_fn, fn.sum(msg='m', out='h')) 154 | rst = graph.dstdata['h'] 155 | 156 | rst = self.fc(rst) 157 | 158 | # if self._norm in ['right', 'both']: 159 | degs = graph.in_degrees().float().clamp(min=1) 160 | norm = torch.pow(degs, -0.5) 161 | shp = norm.shape + (1,) * (feat_dst.dim() - 1) 162 | norm = torch.reshape(norm, shp) 163 | rst = rst * norm 164 | 165 | if self.res_fc is not None: 166 | rst = rst + self.res_fc(feat_dst) 167 | 168 | if self.norm is not None: 169 | rst = self.norm(rst) 170 | 171 | if self._activation is not None: 172 | rst = self._activation(rst) 173 | 174 | return rst 175 | -------------------------------------------------------------------------------- /models/loss_func.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | def auc_pair_loss(x, y, z): 9 | x = F.normalize(x, p=2, dim=-1) 10 | y = F.normalize(y, p=2, dim=-1) 11 | z = F.normalize(z, p=2, dim=-1) 12 | 13 | sim = (x * y).sum(dim=-1) 14 | dissim = (x * z).sum(dim=-1) 15 | loss = (1 - sim + dissim).mean() 16 | # loss = (1 - sim).mean() 17 | return loss 18 | 19 | 20 | def sce_loss(x, y, alpha=3): 21 | x = F.normalize(x, p=2, dim=-1) 22 | y = F.normalize(y, p=2, dim=-1) 23 | 24 | # loss = - (x * y).sum(dim=-1) 25 | # loss = (x_h - y_h).norm(dim=1).pow(alpha) 26 | 27 | loss = (1 - (x * y).sum(dim=-1)).pow_(alpha) 28 | 29 | loss = loss.mean() 30 | return loss 31 | 32 | 33 | class DINOLoss(nn.Module): 34 | def __init__(self, out_dim, warmup_teacher_temp, teacher_temp, 35 | warmup_teacher_temp_epochs, nepochs, student_temp=0.1, 36 | center_momentum=0.9): 37 | super().__init__() 38 | self.student_temp = student_temp 39 | self.center_momentum = center_momentum 40 | self.register_buffer("center", torch.zeros(1, out_dim)) 41 | # we apply a warm up for the teacher temperature because 42 | # a too high temperature makes the training instable at the beginning 43 | self.teacher_temp_schedule = np.concatenate(( 44 | np.linspace(warmup_teacher_temp, 45 | teacher_temp, warmup_teacher_temp_epochs), 46 | np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp 47 | )) 48 | 49 | def forward(self, student_output, teacher_output, epoch): 50 | """ 51 | Cross-entropy between softmax outputs of the teacher and student networks. 
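[Editor's note] Concretely: the teacher logits are centered and sharpened,
p_t = softmax((teacher_output - center) / temp), and the loss is the batch mean of
sum_j -p_t[j] * log_softmax(student_output / student_temp)[j].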
52 | """ 53 | student_out = student_output / self.student_temp 54 | 55 | # teacher centering and sharpening 56 | temp = self.teacher_temp_schedule[epoch] 57 | teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1) 58 | teacher_out = teacher_out.detach() 59 | 60 | loss = torch.sum(-teacher_out * F.log_softmax(student_out, dim=-1), dim=-1) 61 | loss = loss.mean() 62 | self.update_center(teacher_output) 63 | return loss 64 | 65 | # total_loss = 0 66 | # n_loss_terms = 0 67 | # for iq, q in enumerate(teacher_out): 68 | # for v in range(len(student_out)): 69 | # if v == iq: 70 | # # we skip cases where student and teacher operate on the same view 71 | # continue 72 | # loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1) 73 | # total_loss += loss.mean() 74 | # n_loss_terms += 1 75 | # total_loss /= n_loss_terms 76 | # self.update_center(teacher_output) 77 | # return total_loss 78 | 79 | @torch.no_grad() 80 | def update_center(self, teacher_output): 81 | """ 82 | Update center used for teacher output. 83 | """ 84 | batch_center = torch.mean(teacher_output, dim=0, keepdim=True) 85 | 86 | # ema update 87 | self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum) 88 | 89 | 90 | class MLPHead(nn.Module): 91 | def __init__(self, hidden_size, out_dim, num_layers=2, bottleneck_dim=256): 92 | super().__init__() 93 | self._num_layers = num_layers 94 | self.mlp = nn.ModuleList() 95 | for i in range(num_layers): 96 | if i == num_layers - 1: 97 | self.mlp.append( 98 | nn.Linear(hidden_size, bottleneck_dim) 99 | ) 100 | else: 101 | self.mlp.append(nn.Linear(hidden_size, hidden_size)) 102 | # self.mlp.append(nn.LayerNorm(hidden_size)) 103 | self.mlp.append(nn.PReLU()) 104 | 105 | self.apply(self._init_weights) 106 | # self.last_layer = nn.Linear(bottleneck_dim, out_dim, bias=False) 107 | self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) 108 | self.last_layer.weight_g.data.fill_(1) 109 | # self.last_layer.weight_g.requires_grad = False 110 | # self.last_layer = nn.Linear(bottleneck_dim, out_dim, bias=False) 111 | 112 | def _init_weights(self, m): 113 | if isinstance(m, nn.Linear): 114 | trunc_normal_(m.weight, std=.02) 115 | if isinstance(m, nn.Linear) and m.bias is not None: 116 | nn.init.constant_(m.bias, 0) 117 | 118 | def forward(self, x): 119 | num_layers = len(self.mlp) 120 | for i, layer in enumerate(self.mlp): 121 | x = layer(x) 122 | 123 | x = nn.functional.normalize(x, dim=-1, p=2) 124 | x = self.last_layer(x) 125 | return x 126 | 127 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 128 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 129 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 130 | def norm_cdf(x): 131 | # Computes standard normal cumulative distribution function 132 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 133 | 134 | if (mean < a - 2 * std) or (mean > b + 2 * std): 135 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 136 | "The distribution of values may be incorrect.", 137 | stacklevel=2) 138 | 139 | with torch.no_grad(): 140 | # Values are generated by using a truncated uniform distribution and 141 | # then using the inverse CDF for the normal distribution. 
142 | # Get upper and lower cdf values 143 | l = norm_cdf((a - mean) / std) 144 | u = norm_cdf((b - mean) / std) 145 | 146 | # Uniformly fill tensor with values from [l, u], then translate to 147 | # [2l-1, 2u-1]. 148 | tensor.uniform_(2 * l - 1, 2 * u - 1) 149 | 150 | # Use inverse cdf transform for normal distribution to get truncated 151 | # standard normal 152 | tensor.erfinv_() 153 | 154 | # Transform to proper mean, std 155 | tensor.mul_(std * math.sqrt(2.)) 156 | tensor.add_(mean) 157 | 158 | # Clamp to ensure it's in the proper range 159 | tensor.clamp_(min=a, max=b) 160 | return tensor 161 | 162 | 163 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 164 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pyyaml 2 | tqdm 3 | tensorboardX 4 | scikit-learn 5 | ogb 6 | torch 7 | dgl 8 | -------------------------------------------------------------------------------- /run_fullbatch.sh: -------------------------------------------------------------------------------- 1 | dataset=$1 2 | device=$2 3 | 4 | [ -z "${dataset}" ] && dataset="cora" 5 | [ -z "${device}" ] && device=0 6 | 7 | CUDA_VISIBLE_DEVICES=$device \ 8 | python main_full_batch.py \ 9 | --device 0 \ 10 | --dataset $dataset \ 11 | --mask_method "random" \ 12 | --remask_method "fixed" \ 13 | --mask_rate 0.5 \ 14 | --in_drop 0.2 \ 15 | --attn_drop 0.1 \ 16 | --num_layers 2 \ 17 | --num_dec_layers 1 \ 18 | --num_hidden 256 \ 19 | --num_heads 4 \ 20 | --num_out_heads 1 \ 21 | --encoder "gat" \ 22 | --decoder "gat" \ 23 | --max_epoch 1000 \ 24 | --max_epoch_f 300 \ 25 | --lr 0.001 \ 26 | --weight_decay 0.04 \ 27 | --lr_f 0.005 \ 28 | --weight_decay_f 1e-4 \ 29 | --activation "prelu" \ 30 | --loss_fn "sce" \ 31 | --alpha_l 3 \ 32 | --scheduler \ 33 | --seeds 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 \ 34 | --lam 0.5 \ 35 | --linear_prob \ 36 | --use_cfg 37 | -------------------------------------------------------------------------------- /run_minibatch.sh: -------------------------------------------------------------------------------- 1 | dataset=$1 2 | device=$2 3 | 4 | [ -z "${dataset}" ] && dataset="ogbn-arxiv" 5 | [ -z "${device}" ] && device=0 6 | 7 | CUDA_VISIBLE_DEVICES=$device \ 8 | python main_large.py \ 9 | --device 0 \ 10 | --dataset $dataset \ 11 | --mask_type "mask" \ 12 | --mask_rate 0.5 \ 13 | --remask_rate 0.5 \ 14 | --num_remasking 3 \ 15 | --in_drop 0.2 \ 16 | --attn_drop 0.2 \ 17 | --num_layers 4 \ 18 | --num_dec_layers 1 \ 19 | --num_hidden 1024 \ 20 | --num_heads 4 \ 21 | --num_out_heads 1 \ 22 | --encoder "gat" \ 23 | --decoder "gat" \ 24 | --max_epoch 60 \ 25 | --max_epoch_f 1000 \ 26 | --lr 0.002 \ 27 | --weight_decay 0.04 \ 28 | --lr_f 0.005 \ 29 | --weight_decay_f 1e-4 \ 30 | --activation "prelu" \ 31 | --optimizer "adamw" \ 32 | --drop_edge_rate 0.5 \ 33 | --loss_fn "sce" \ 34 | --alpha_l 4 \ 35 | --mask_method "random" \ 36 | --scheduler \ 37 | --batch_size 512 \ 38 | --batch_size_f 256 \ 39 | --seeds 0 \ 40 | --residual \ 41 | --norm "layernorm" \ 42 | --sampling_method "lc" \ 43 | --label_rate 1.0 \ 44 | --lam 1.0 \ 45 | --momentum 0.996 \ 46 | --linear_prob \ 47 | --use_cfg \ 48 | --ego_graph_file_path "./lc_ego_graphs/${dataset}-lc-ego-graphs-256.pt" \ 49 | --data_dir "./dataset" \ 50 | # --logging 51 | -------------------------------------------------------------------------------- /utils.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import random 4 | import psutil 5 | import yaml 6 | import logging 7 | from functools import partial 8 | from tensorboardX import SummaryWriter 9 | import wandb 10 | 11 | import numpy as np 12 | import torch 13 | import torch.nn as nn 14 | from torch import optim as optim 15 | 16 | 17 | import dgl 18 | import dgl.function as fn 19 | from sklearn.decomposition import PCA 20 | from sklearn.manifold import TSNE 21 | import matplotlib.pyplot as plt 22 | 23 | 24 | logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO) 25 | 26 | 27 | def accuracy(y_pred, y_true): 28 | y_true = y_true.squeeze().long() 29 | preds = y_pred.max(1)[1].type_as(y_true) 30 | correct = preds.eq(y_true).double() 31 | correct = correct.sum().item() 32 | return correct / len(y_true) 33 | 34 | 35 | def set_random_seed(seed): 36 | random.seed(seed) 37 | np.random.seed(seed) 38 | torch.manual_seed(seed) 39 | torch.cuda.manual_seed(seed) 40 | torch.cuda.manual_seed_all(seed) 41 | torch.backends.cudnn.determinstic = True 42 | 43 | 44 | def get_current_lr(optimizer): 45 | return optimizer.state_dict()["param_groups"][0]["lr"] 46 | 47 | 48 | def build_args(): 49 | parser = argparse.ArgumentParser(description="GAT") 50 | parser.add_argument("--seeds", type=int, nargs="+", default=[0]) 51 | parser.add_argument("--dataset", type=str, default="cora") 52 | parser.add_argument("--device", type=int, default=0) 53 | parser.add_argument("--max_epoch", type=int, default=500, 54 | help="number of training epochs") 55 | parser.add_argument("--warmup_steps", type=int, default=-1) 56 | 57 | parser.add_argument("--num_heads", type=int, default=4, 58 | help="number of hidden attention heads") 59 | parser.add_argument("--num_out_heads", type=int, default=1, 60 | help="number of output attention heads") 61 | parser.add_argument("--num_layers", type=int, default=2, 62 | help="number of hidden layers") 63 | parser.add_argument("--num_dec_layers", type=int, default=1) 64 | parser.add_argument("--num_remasking", type=int, default=3) 65 | parser.add_argument("--num_hidden", type=int, default=512, 66 | help="number of hidden units") 67 | parser.add_argument("--residual", action="store_true", default=False, 68 | help="use residual connection") 69 | parser.add_argument("--in_drop", type=float, default=.2, 70 | help="input feature dropout") 71 | parser.add_argument("--attn_drop", type=float, default=.1, 72 | help="attention dropout") 73 | parser.add_argument("--norm", type=str, default=None) 74 | parser.add_argument("--lr", type=float, default=0.001, 75 | help="learning rate") 76 | parser.add_argument("--weight_decay", type=float, default=0, 77 | help="weight decay") 78 | parser.add_argument("--negative_slope", type=float, default=0.2, 79 | help="the negative slope of leaky relu") 80 | parser.add_argument("--activation", type=str, default="prelu") 81 | parser.add_argument("--mask_rate", type=float, default=0.5) 82 | parser.add_argument("--remask_rate", type=float, default=0.5) 83 | parser.add_argument("--remask_method", type=str, default="random") 84 | parser.add_argument("--mask_type", type=str, default="mask", 85 | help="`mask` or `drop`") 86 | parser.add_argument("--mask_method", type=str, default="random") 87 | parser.add_argument("--drop_edge_rate", type=float, default=0.0) 88 | parser.add_argument("--drop_edge_rate_f", type=float, default=0.0) 89 | 90 | parser.add_argument("--encoder", type=str, 
default="gat") 91 | parser.add_argument("--decoder", type=str, default="gat") 92 | parser.add_argument("--loss_fn", type=str, default="sce") 93 | parser.add_argument("--alpha_l", type=float, default=2) 94 | parser.add_argument("--optimizer", type=str, default="adam") 95 | 96 | parser.add_argument("--max_epoch_f", type=int, default=300) 97 | parser.add_argument("--lr_f", type=float, default=0.01) 98 | parser.add_argument("--weight_decay_f", type=float, default=0.0) 99 | parser.add_argument("--linear_prob", action="store_true", default=False) 100 | 101 | 102 | parser.add_argument("--no_pretrain", action="store_true") 103 | parser.add_argument("--load_model", action="store_true") 104 | parser.add_argument("--checkpoint_path", type=str, default=None) 105 | parser.add_argument("--use_cfg", action="store_true") 106 | parser.add_argument("--logging", action="store_true") 107 | parser.add_argument("--scheduler", action="store_true", default=False) 108 | 109 | parser.add_argument("--batch_size", type=int, default=256) 110 | parser.add_argument("--batch_size_f", type=int, default=128) 111 | parser.add_argument("--sampling_method", type=str, default="saint", help="sampling method, `lc` or `saint`") 112 | 113 | parser.add_argument("--label_rate", type=float, default=1.0) 114 | parser.add_argument("--ego_graph_file_path", type=str, default=None) 115 | parser.add_argument("--data_dir", type=str, default="data") 116 | 117 | parser.add_argument("--lam", type=float, default=1.0) 118 | parser.add_argument("--full_graph_forward", action="store_true", default=False) 119 | parser.add_argument("--delayed_ema_epoch", type=int, default=0) 120 | parser.add_argument("--replace_rate", type=float, default=0.0) 121 | parser.add_argument("--momentum", type=float, default=0.996) 122 | 123 | args = parser.parse_args() 124 | return args 125 | 126 | def create_activation(name): 127 | if name == "relu": 128 | return nn.ReLU() 129 | elif name == "gelu": 130 | return nn.GELU() 131 | elif name == "prelu": 132 | return nn.PReLU() 133 | elif name == "selu": 134 | return nn.SELU() 135 | elif name == "elu": 136 | return nn.ELU() 137 | elif name == "silu": 138 | return nn.SiLU() 139 | elif name is None: 140 | return nn.Identity() 141 | else: 142 | raise NotImplementedError(f"{name} is not implemented.") 143 | 144 | 145 | def identity_norm(x): 146 | def func(x): 147 | return x 148 | return func 149 | 150 | def create_norm(name): 151 | if name == "layernorm": 152 | return nn.LayerNorm 153 | elif name == "batchnorm": 154 | return nn.BatchNorm1d 155 | elif name == "identity": 156 | return identity_norm 157 | else: 158 | # print("Identity norm") 159 | return None 160 | 161 | 162 | def create_optimizer(opt, model, lr, weight_decay, get_num_layer=None, get_layer_scale=None): 163 | opt_lower = opt.lower() 164 | parameters = model.parameters() 165 | opt_args = dict(lr=lr, weight_decay=weight_decay) 166 | 167 | opt_split = opt_lower.split("_") 168 | opt_lower = opt_split[-1] 169 | 170 | if opt_lower == "adam": 171 | optimizer = optim.Adam(parameters, **opt_args) 172 | elif opt_lower == "adamw": 173 | optimizer = optim.AdamW(parameters, **opt_args) 174 | elif opt_lower == "adadelta": 175 | optimizer = optim.Adadelta(parameters, **opt_args) 176 | elif opt_lower == "sgd": 177 | opt_args["momentum"] = 0.9 178 | return optim.SGD(parameters, **opt_args) 179 | else: 180 | raise NotImplementedError("Invalid optimizer") 181 | 182 | return optimizer 183 | 184 | 185 | def show_occupied_memory(): 186 | process = psutil.Process(os.getpid()) 187 | return 
process.memory_info().rss / 1024**2 188 | 189 | 190 | # ------------------- 191 | def mask_edge(graph, mask_prob): 192 | E = graph.num_edges() 193 | 194 | mask_rates = torch.ones(E) * mask_prob 195 | masks = torch.bernoulli(1 - mask_rates) 196 | mask_idx = masks.nonzero().squeeze(1) 197 | return mask_idx 198 | 199 | 200 | def drop_edge(graph, drop_rate, return_edges=False): 201 | if drop_rate <= 0: 202 | return graph 203 | 204 | graph = graph.remove_self_loop() 205 | 206 | n_node = graph.num_nodes() 207 | edge_mask = mask_edge(graph, drop_rate) 208 | src, dst = graph.edges() 209 | 210 | nsrc = src[edge_mask] 211 | ndst = dst[edge_mask] 212 | 213 | ng = dgl.graph((nsrc, ndst), num_nodes=n_node) 214 | ng = ng.add_self_loop() 215 | 216 | return ng 217 | 218 | 219 | def visualize(x, y, method="tsne"): 220 | if torch.is_tensor(x): 221 | x = x.cpu().numpy() 222 | 223 | if torch.is_tensor(y): 224 | y = y.cpu().numpy() 225 | 226 | if method == "tsne": 227 | func = TSNE(n_components=2) 228 | else: 229 | func = PCA(n_components=2) 230 | out = func.fit_transform(x) 231 | plt.scatter(out[:, 0], out[:, 1], c=y) 232 | plt.savefig("vis.png") 233 | 234 | 235 | def load_best_configs(args): 236 | dataset_name = args.dataset 237 | config_path = os.path.join("configs", f"{dataset_name}.yaml") 238 | with open(config_path, "r") as f: 239 | configs = yaml.load(f, yaml.FullLoader) 240 | 241 | for k, v in configs.items(): 242 | if "lr" in k or "weight_decay" in k: 243 | v = float(v) 244 | setattr(args, k, v) 245 | logging.info(f"----- Using best configs from {config_path} -----") 246 | 247 | return args 248 | 249 | 250 | 251 | def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0): 252 | warmup_schedule = np.array([]) 253 | warmup_iters = warmup_epochs * niter_per_ep 254 | if warmup_epochs > 0: 255 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) 256 | 257 | iters = np.arange(epochs * niter_per_ep - warmup_iters) 258 | schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters))) 259 | 260 | scheduler = np.concatenate((warmup_schedule, schedule)) 261 | assert len(scheduler) == epochs * niter_per_ep 262 | return scheduler 263 | 264 | 265 | 266 | # ------ logging ------ 267 | 268 | class TBLogger(object): 269 | def __init__(self, log_path="./logging_data", name="run"): 270 | super(TBLogger, self).__init__() 271 | 272 | if not os.path.exists(log_path): 273 | os.makedirs(log_path, exist_ok=True) 274 | 275 | self.last_step = 0 276 | self.log_path = log_path 277 | raw_name = os.path.join(log_path, name) 278 | name = raw_name 279 | for i in range(1000): 280 | name = raw_name + str(f"_{i}") 281 | if not os.path.exists(name): 282 | break 283 | self.writer = SummaryWriter(logdir=name) 284 | 285 | def note(self, metrics, step=None): 286 | if step is None: 287 | step = self.last_step 288 | for key, value in metrics.items(): 289 | self.writer.add_scalar(key, value, step) 290 | self.last_step = step 291 | 292 | def finish(self): 293 | self.writer.close() 294 | 295 | 296 | class WandbLogger(object): 297 | def __init__(self, log_path, project, args): 298 | self.log_path = log_path 299 | self.project = project 300 | self.args = args 301 | self.last_step = 0 302 | self.project = project 303 | self.start() 304 | 305 | def start(self): 306 | self.run = wandb.init(config=self.args, project=self.project) 307 | 308 | def log(self, metrics, step=None): 309 | if not hasattr(self, "run"): 310 | self.start() 311 | if step 
is None: 312 | step = self.last_step 313 | self.run.log(metrics) 314 | self.last_step = step 315 | 316 | def finish(self): 317 | self.run.finish() 318 | --------------------------------------------------------------------------------
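
Usage note: the shell entry points above are invoked as "bash run_fullbatch.sh <dataset> <device>" or "bash run_minibatch.sh <dataset> <device>", with both arguments optional (they default to cora / ogbn-arxiv and GPU 0). The snippet below is a minimal, hypothetical sketch of how the helpers defined in utils.py can be wired together, in the spirit of main_full_batch.py / main_large.py. Only the imported helpers and their signatures come from utils.py; the stand-in nn.Linear model, the momentum-to-1.0 EMA schedule, and the one-iteration-per-epoch assumption are illustrative choices, not the repository's actual training loop.

# Illustrative sketch only -- the real pipelines live in main_full_batch.py / main_large.py.
import torch.nn as nn
from utils import (
    build_args, load_best_configs, set_random_seed,
    create_optimizer, cosine_scheduler, TBLogger,
)

args = build_args()
if args.use_cfg:
    # pull tuned hyper-parameters from configs/<dataset>.yaml
    args = load_best_configs(args)

for seed in args.seeds:
    set_random_seed(seed)          # seeds python, numpy and torch (CPU + CUDA)
    model = nn.Linear(16, 8)       # stand-in; the real encoder/decoder is built in models/
    optimizer = create_optimizer(args.optimizer, model, args.lr, args.weight_decay)
    # assumed schedule: anneal the EMA momentum from args.momentum towards 1.0, one step per epoch
    momentum_schedule = cosine_scheduler(args.momentum, 1.0, args.max_epoch, niter_per_ep=1)
    logger = TBLogger(name=f"{args.dataset}_seed{seed}") if args.logging else None
    # ... pre-training and linear-probe evaluation would go here ...
    if logger is not None:
        logger.finish()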