├── .gitignore
├── LICENSE
├── README.md
├── asserts
│   └── overview.png
├── configs
│   ├── citeseer.yaml
│   ├── cora.yaml
│   ├── mag-scholar-f.yaml
│   ├── ogbn-arxiv.yaml
│   ├── ogbn-papers100M.yaml
│   ├── ogbn-products.yaml
│   └── pubmed.yaml
├── datasets
│   ├── __init__.py
│   ├── data_proc.py
│   ├── lc_sampler.py
│   ├── localclustering.py
│   └── saint_sampler.py
├── main_full_batch.py
├── main_large.py
├── models
│   ├── __init__.py
│   ├── edcoder.py
│   ├── finetune.py
│   ├── gat.py
│   ├── gcn.py
│   └── loss_func.py
├── requirements.txt
├── run_fullbatch.sh
├── run_minibatch.sh
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 THUDM
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GraphMAE2: A Decoding-Enhanced Masked Self-Supervised Graph Learner
2 |
3 |
4 | Implementation for WWW'23 paper: [GraphMAE2: A Decoding-Enhanced Masked Self-Supervised
5 | Graph Learner](https://arxiv.org/abs/2304.04779).
6 |
7 |
8 | The predecessor of this work, [GraphMAE: Self-Supervised Masked Graph Autoencoders](https://arxiv.org/abs/2205.10803), can be found [here](https://github.com/THUDM/GraphMAE).
9 |
10 | ## ❗ Update
11 |
12 | [2023-04-19] We have made **checkpoints** of pre-trained models on different datasets available - feel free to download them from [Google Drive](https://drive.google.com/drive/folders/1GiuP0PtIZaYlJWIrjvu73ZQCJGr6kGkh).
13 |
14 | ## Dependencies
15 |
16 | * Python >= 3.7
17 | * [PyTorch](https://pytorch.org/) >= 1.9.0
18 | * pyyaml == 5.4.1
19 |
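A minimal environment sketch (illustrative, not authoritative: `requirements.txt` is the source of truth, and the code under `datasets/` and `models/` also imports `dgl`, `ogb`, `scikit-learn`, and `tqdm`):

```bash
# Illustrative only -- choose torch/dgl builds that match your CUDA version.
pip install "torch>=1.9.0" "pyyaml==5.4.1"
pip install dgl ogb scikit-learn tqdm
```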
20 |
21 | ## Quick Start
22 |
23 | For a quick start, you can run the provided scripts:
24 |
25 | **Node classification**
26 |
27 | ```bash
28 | sh run_minibatch.sh # for mini batch node classification
29 | # example: sh run_minibatch.sh ogbn-arxiv 0
30 | sh run_fullbatch.sh # for full batch node classification
31 | # example: sh run_fullbatch.sh cora 0
32 |
33 | # Or you could run the code manually:
34 | # for mini batch node classification
35 | python main_large.py --dataset ogbn-arxiv --encoder gat --decoder gat --seed 0 --device 0
36 | # for full batch node classification
37 | python main_full_batch.py --dataset cora --encoder gat --decoder gat --seed 0 --device 0
38 | ```
39 |
40 | Supported datasets:
41 |
42 | * mini batch node classification: `ogbn-arxiv`, `ogbn-products`, `mag-scholar-f`, `ogbn-papers100M`
43 | * full batch node classification: `cora`, `citeseer`, `pubmed`
44 |
45 | Run the provided scripts, or add `--use_cfg` to the command, to reproduce the reported results.
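For example (a sketch; flags other than `--use_cfg` mirror the Quick Start commands above):

```bash
# --use_cfg loads the per-dataset hyperparameters from ./configs
python main_full_batch.py --dataset cora --device 0 --use_cfg
python main_large.py --dataset ogbn-arxiv --device 0 --use_cfg
```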
46 |
47 | **For large-scale graphs**
48 | Before starting mini-batch training, you need to generate local clusters if you want to use local clustering for training. By default, the program loads the dataset from `./data` and saves the generated local clusters to `./lc_ego_graphs`. To generate local clusters, first install [localclustering](https://github.com/kfoynt/LocalGraphClustering) and then run the following command:
49 |
50 | ```bash
51 | python ./datasets/localclustering.py --dataset <dataset_name> --data_dir <data_dir>
52 | ```
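For example, using the argument names defined in `datasets/localclustering.py` (the values below are illustrative; `--save_dir` and `--ego_size` default to `lc_ego_graphs` and 256):

```bash
python ./datasets/localclustering.py --dataset ogbn-arxiv --data_dir ./data --save_dir ./lc_ego_graphs --ego_size 256
```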
53 | We also provide pre-generated local clusters, which can be downloaded [here](https://cloud.tsinghua.edu.cn/d/64f859f389ca43eda472/) and placed into `lc_ego_graphs` for use.
54 |
55 |
56 |
57 | ## Datasets
58 |
59 | During the code's execution, the OGB and small-scale datasets (Cora, Citeseer, and PubMed) will be downloaded automatically. For the MAG-SCHOLAR dataset, you can download the raw data from [here](https://figshare.com/articles/dataset/mag_scholar/12696653) or use our processed version, which can be found [here](https://cloud.tsinghua.edu.cn/d/776e73d84d47454c958d/) (the four feature files have to be merged into a single `feature_f.npy`; a sketch of this merge is given below the folder layout). Once you have the dataset, place it into the `./data/mag_scholar_f` folder for later use. The folder should contain the following files:
60 | ```
61 | - mag_scholar_f
62 | |--- edge_index_f.npy
63 | |--- split_idx_f.pt
64 | |--- feature_f.npy
65 | |--- label_f.npy
66 | ```
67 |
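The snippet below is only a sketch of how the four downloaded feature chunks could be merged into the single `feature_f.npy` expected above; the chunk file names are hypothetical and the concatenation axis is assumed to be the node dimension, so adjust both to match the actual download:

```python
import numpy as np

# Hypothetical names for the four downloaded feature chunks -- replace with the real file names.
parts = [f"feature_f_part{i}.npy" for i in range(4)]

# Stack the chunks along the node dimension (axis 0 assumed) and save the merged array.
feature = np.concatenate([np.load(p) for p in parts], axis=0)
np.save("data/mag_scholar_f/feature_f.npy", feature)
```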
68 | We will soon provide [SAINTSampler](https://arxiv.org/abs/1907.04931) as a baseline.
69 |
70 |
71 | ## Experimental Results
72 |
73 | Experimental results of node classification on large-scale datasets (Accuracy, %):
74 |
75 | | | Ogbn-arxiv | Ogbn-products | Mag-Scholar-F | Ogbn-papers100M |
76 | | ------------------ | ------------ | ------------ | ------------ | -------------- |
77 | | MLP | 55.50±0.23 | 61.06±0.08 | 39.11±0.21 | 47.24±0.31 |
78 | | SGC | 66.92±0.08 | 74.87±0.25 | 54.68±0.23 | 63.29±0.19 |
79 | | Random-Init | 68.14±0.02 | 74.04±0.06 | 56.57±0.03 | 61.55±0.12 |
80 | | CCA-SSG | 68.57±0.02 | 75.27±0.05 | 51.55±0.03 | 55.67±0.15 |
81 | | GRACE | 69.34±0.01 | 79.47±0.59 | 57.39±0.02 | 61.21±0.12 |
82 | | BGRL | 70.51±0.03 | 78.59±0.02 | 57.57±0.01 | 62.18±0.15 |
83 | | GGD | - | 75.70±0.40 | - | 63.50±0.50 |
84 | | GraphMAE | 71.03±0.02 | 78.89±0.01 | 58.75±0.03 | 62.54±0.09 |
85 | | **GraphMAE2** | **71.89±0.03** | **81.59±0.02** | **59.24±0.01** | **64.89±0.04** |
86 |
87 |
88 |
89 | ## Citing
90 |
91 | If you find this work helpful to your research, please consider citing our paper:
92 |
93 | ```
94 | @inproceedings{hou2023graphmae2,
95 | title={GraphMAE2: A Decoding-Enhanced Masked Self-Supervised Graph Learner},
96 |   author={Zhenyu Hou and Yufei He and Yukuo Cen and Xiao Liu and Yuxiao Dong and Evgeny Kharlamov and Jie Tang},
97 | booktitle={Proceedings of the ACM Web Conference 2023 (WWW’23)},
98 | year={2023}
99 | }
100 | ```
101 |
--------------------------------------------------------------------------------
/asserts/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/GraphMAE2/bc473e8b750d6e90b7cba1c6bb20d858f617d90c/asserts/overview.png
--------------------------------------------------------------------------------
/configs/citeseer.yaml:
--------------------------------------------------------------------------------
1 | lr: 0.0005 # 0.0005
2 | lr_f: 0.025
3 | num_hidden: 1024
4 | num_heads: 4
5 | num_out_heads: 1
6 | num_layers: 2
7 | weight_decay: 1e-4
8 | weight_decay_f: 1e-2
9 | max_epoch: 500
10 | max_epoch_f: 500
11 | mask_rate: 0.5
12 | num_layers: 2
13 | encoder: gat
14 | decoder: gat
15 | activation: prelu
16 | attn_drop: 0.1
17 | linear_prob: True
18 | in_drop: 0.2
19 | loss_fn: sce
20 | drop_edge_rate: 0.0
21 | optimizer: adam
22 | replace_rate: 0.0
23 | alpha_l: 1
24 | scheduler: True
25 | remask_method: fixed
26 | momentum: 1
27 | lam: 0.1
--------------------------------------------------------------------------------
/configs/cora.yaml:
--------------------------------------------------------------------------------
1 | lr: 0.001
2 | lr_f: 0.025
3 | num_hidden: 1024
4 | num_heads: 8
5 | num_out_heads: 1
6 | num_layers: 2
7 | weight_decay: 2e-4
8 | weight_decay_f: 1e-4
9 | max_epoch: 2000
10 | max_epoch_f: 300
11 | mask_rate: 0.5
12 | num_layers: 2
13 | encoder: gat
14 | decoder: gat
15 | activation: prelu
16 | attn_drop: 0.1
17 | linear_prob: True
18 | in_drop: 0.2
19 | loss_fn: sce
20 | drop_edge_rate: 0.0
21 | optimizer: adam
22 | replace_rate: 0.15
23 | alpha_l: 3
24 | scheduler: True
25 | remask_method: fixed
26 | momentum: 0
27 | lam: 0.15
28 |
--------------------------------------------------------------------------------
/configs/mag-scholar-f.yaml:
--------------------------------------------------------------------------------
1 | lr: 0.001
2 | lr_f: 0.001
3 | num_hidden: 1024
4 | num_heads: 8
5 | num_out_heads: 1
6 | num_layers: 4
7 | weight_decay: 0.04
8 | weight_decay_f: 0
9 | max_epoch: 10
10 | max_epoch_f: 1000
11 | batch_size: 512
12 | batch_size_f: 256
13 | mask_rate: 0.5
14 | num_layers: 4
15 | encoder: gat
16 | decoder: gat
17 | activation: prelu
18 | attn_drop: 0.2
19 | linear_prob: True
20 | in_drop: 0.2
21 | loss_fn: sce
22 | drop_edge_rate: 0.5
23 | optimizer: adamw
24 | alpha_l: 2
25 | scheduler: True
26 | remask_method: random
27 | momentum: 0.996
28 | lam: 0.1
29 | delayed_ema_epoch: 0
30 | num_remasking: 3
--------------------------------------------------------------------------------
/configs/ogbn-arxiv.yaml:
--------------------------------------------------------------------------------
1 | lr: 0.0025
2 | lr_f: 0.005
3 | num_hidden: 1024
4 | num_heads: 8
5 | num_out_heads: 1
6 | num_layers: 4
7 | weight_decay: 0.06
8 | weight_decay_f: 1e-4
9 | max_epoch: 60
10 | max_epoch_f: 1000
11 | batch_size: 512
12 | batch_size_f: 256
13 | mask_rate: 0.5
14 | num_layers: 4
15 | encoder: gat
16 | decoder: gat
17 | activation: prelu
18 | attn_drop: 0.1
19 | linear_prob: True
20 | in_drop: 0.2
21 | loss_fn: sce
22 | drop_edge_rate: 0.5
23 | optimizer: adamw
24 | alpha_l: 6
25 | scheduler: True
26 | remask_method: random
27 | momentum: 0.996
28 | lam: 10.0
29 | delayed_ema_epoch: 40
30 | num_remasking: 3
--------------------------------------------------------------------------------
/configs/ogbn-papers100M.yaml:
--------------------------------------------------------------------------------
1 | lr: 0.001
2 | lr_f: 0.001
3 | num_hidden: 1024
4 | num_heads: 4
5 | num_out_heads: 1
6 | num_layers: 4
7 | weight_decay: 0.05
8 | weight_decay_f: 0
9 | max_epoch: 10
10 | max_epoch_f: 1000
11 | batch_size: 512
12 | batch_size_f: 256
13 | mask_rate: 0.5
14 | num_layers: 4
15 | encoder: gat
16 | decoder: gat
17 | activation: prelu
18 | attn_drop: 0.2
19 | linear_prob: True
20 | in_drop: 0.2
21 | loss_fn: sce
22 | drop_edge_rate: 0.5
23 | optimizer: adamw
24 | alpha_l: 2
25 | scheduler: True
26 | remask_method: random
27 | momentum: 0.996
28 | lam: 10.0
29 | delayed_ema_epoch: 0
30 | num_remasking: 3
--------------------------------------------------------------------------------
/configs/ogbn-products.yaml:
--------------------------------------------------------------------------------
1 | lr: 0.002
2 | lr_f: 0.001
3 | num_hidden: 1024
4 | num_heads: 4
5 | num_out_heads: 1
6 | num_layers: 4
7 | weight_decay: 0.04
8 | weight_decay_f: 0
9 | max_epoch: 20
10 | max_epoch_f: 1000
11 | batch_size: 512
12 | batch_size_f: 256
13 | mask_rate: 0.5
14 | num_layers: 4
15 | encoder: gat
16 | decoder: gat
17 | activation: prelu
18 | attn_drop: 0.2
19 | linear_prob: True
20 | in_drop: 0.2
21 | loss_fn: sce
22 | drop_edge_rate: 0.5
23 | optimizer: adamw
24 | alpha_l: 3
25 | scheduler: True
26 | remask_method: random
27 | momentum: 0.996
28 | lam: 5.0
29 | delayed_ema_epoch: 0
30 | num_remasking: 3
--------------------------------------------------------------------------------
/configs/pubmed.yaml:
--------------------------------------------------------------------------------
1 | lr: 0.005
2 | lr_f: 0.025
3 | num_hidden: 512
4 | num_heads: 2
5 | num_out_heads: 1
6 | num_layers: 2
7 | weight_decay: 1e-5
8 | weight_decay_f: 5e-4
9 | max_epoch: 2000
10 | max_epoch_f: 500
11 | mask_rate: 0.9
12 | num_layers: 2
13 | encoder: gat
14 | decoder: gat
15 | activation: prelu
16 | attn_drop: 0.1
17 | linear_prob: True
18 | in_drop: 0.2
19 | loss_fn: sce
20 | drop_edge_rate: 0.0
21 | optimizer: adam
22 | replace_rate: 0.0
23 | alpha_l: 4
24 | scheduler: True
25 | remask_method: fixed
26 | momentum: 0.995
27 | lam: 1
28 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/GraphMAE2/bc473e8b750d6e90b7cba1c6bb20d858f617d90c/datasets/__init__.py
--------------------------------------------------------------------------------
/datasets/data_proc.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import torch
4 |
5 | import dgl
6 | from dgl.data import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset
7 | from ogb.nodeproppred import DglNodePropPredDataset
8 |
9 | from sklearn.preprocessing import StandardScaler
10 |
11 |
12 | GRAPH_DICT = {
13 | "cora": CoraGraphDataset,
14 | "citeseer": CiteseerGraphDataset,
15 | "pubmed": PubmedGraphDataset,
16 | "ogbn-arxiv": DglNodePropPredDataset,
17 | }
18 |
19 | def load_small_dataset(dataset_name):
20 |     assert dataset_name in GRAPH_DICT, f"Unknown dataset: {dataset_name}."
21 | if dataset_name.startswith("ogbn"):
22 | dataset = GRAPH_DICT[dataset_name](dataset_name)
23 | else:
24 | dataset = GRAPH_DICT[dataset_name]()
25 |
26 | if dataset_name == "ogbn-arxiv":
27 | graph, labels = dataset[0]
28 | num_nodes = graph.num_nodes()
29 |
30 | split_idx = dataset.get_idx_split()
31 | train_idx, val_idx, test_idx = split_idx["train"], split_idx["valid"], split_idx["test"]
32 | graph = preprocess(graph)
33 |
34 | if not torch.is_tensor(train_idx):
35 | train_idx = torch.as_tensor(train_idx)
36 | val_idx = torch.as_tensor(val_idx)
37 | test_idx = torch.as_tensor(test_idx)
38 |
39 | feat = graph.ndata["feat"]
40 | feat = scale_feats(feat)
41 | graph.ndata["feat"] = feat
42 |
43 | train_mask = torch.full((num_nodes,), False).index_fill_(0, train_idx, True)
44 | val_mask = torch.full((num_nodes,), False).index_fill_(0, val_idx, True)
45 | test_mask = torch.full((num_nodes,), False).index_fill_(0, test_idx, True)
46 | graph.ndata["label"] = labels.view(-1)
47 | graph.ndata["train_mask"], graph.ndata["val_mask"], graph.ndata["test_mask"] = train_mask, val_mask, test_mask
48 | else:
49 | graph = dataset[0]
50 | graph = graph.remove_self_loop()
51 | graph = graph.add_self_loop()
52 | num_features = graph.ndata["feat"].shape[1]
53 | num_classes = dataset.num_classes
54 | return graph, (num_features, num_classes)
55 |
56 | def preprocess(graph):
57 | # make bidirected
58 | if "feat" in graph.ndata:
59 | feat = graph.ndata["feat"]
60 | else:
61 | feat = None
62 | src, dst = graph.all_edges()
63 | # graph.add_edges(dst, src)
64 | graph = dgl.to_bidirected(graph)
65 | if feat is not None:
66 | graph.ndata["feat"] = feat
67 |
68 | # add self-loop
69 | graph = graph.remove_self_loop().add_self_loop()
70 | # graph.create_formats_()
71 | return graph
72 |
73 |
74 | def scale_feats(x):
75 | logging.info("### scaling features ###")
76 | scaler = StandardScaler()
77 | feats = x.numpy()
78 | scaler.fit(feats)
79 | feats = torch.from_numpy(scaler.transform(feats)).float()
80 | return feats
81 |
--------------------------------------------------------------------------------
/datasets/lc_sampler.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import torch
5 | import dgl
6 | from ogb.nodeproppred import DglNodePropPredDataset
7 |
8 | from .data_proc import preprocess, scale_feats
9 | from utils import mask_edge
10 |
11 | import logging
12 | import torch.multiprocessing
13 | from torch.utils.data import DataLoader
14 | from tqdm import tqdm
15 | torch.multiprocessing.set_sharing_strategy('file_system')
16 |
17 |
18 | logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO)
19 |
20 | # def collect_topk_ppr(graph, nodes, topk, alpha, epsilon):
21 | # if torch.is_tensor(nodes):
22 | # nodes = nodes.numpy()
23 | # row, col = graph.edges()
24 | # row = row.numpy()
25 | # col = col.numpy()
26 | # num_nodes = graph.num_nodes()
27 |
28 | # neighbors = build_topk_ppr((row, col), alpha, epsilon, nodes, topk, num_nodes=num_nodes)
29 | # return neighbors
30 |
31 | # ---------------------------------------------------------------------------------------------------------------------
32 |
33 |
34 | def load_dataset(data_dir, dataset_name):
35 | if dataset_name.startswith("ogbn"):
36 | dataset = DglNodePropPredDataset(dataset_name, root=os.path.join(data_dir, "dataset"))
37 | graph, label = dataset[0]
38 |
39 | if "year" in graph.ndata:
40 | del graph.ndata["year"]
41 | if not graph.is_multigraph:
42 | logging.info("--- to undirected graph ---")
43 | graph = preprocess(graph)
44 | graph = graph.remove_self_loop().add_self_loop()
45 |
46 | split_idx = dataset.get_idx_split()
47 | label = label.view(-1)
48 |
49 | feats = graph.ndata.pop("feat")
50 | if dataset_name in ("ogbn-arxiv","ogbn-papers100M"):
51 | feats = scale_feats(feats)
52 | elif dataset_name == "mag-scholar-f":
53 | edge_index = np.load(os.path.join(data_dir, dataset_name, "edge_index_f.npy"))
54 |         feats = torch.from_numpy(np.load(os.path.join(data_dir, dataset_name, "feature_f.npy"))).float()
55 |
56 | graph = dgl.DGLGraph((edge_index[0], edge_index[1]))
57 |
58 | graph = dgl.remove_self_loop(graph)
59 | graph = dgl.add_self_loop(graph)
60 |
61 |         label = torch.from_numpy(np.load(os.path.join(data_dir, dataset_name, "label_f.npy"))).to(torch.long)
62 |         split_idx = torch.load(os.path.join(data_dir, dataset_name, "split_idx_f.pt"))
63 |
64 | # graph.ndata["feat"] = feats
65 | # graph.ndata["label"] = label
66 |
67 | return feats, graph, label, split_idx
68 |
69 | class LinearProbingDataLoader(DataLoader):
70 | def __init__(self, idx, feats, labels=None, **kwargs):
71 | self.labels = labels
72 | self.feats = feats
73 |
74 | kwargs["collate_fn"] = self.__collate_fn__
75 | super().__init__(dataset=idx, **kwargs)
76 |
77 | def __collate_fn__(self, batch_idx):
78 | feats = self.feats[batch_idx]
79 | label = self.labels[batch_idx]
80 |
81 | return feats, label
82 |
83 | class OnlineLCLoader(DataLoader):
84 | def __init__(self, root_nodes, graph, feats, labels=None, drop_edge_rate=0, **kwargs):
85 | self.graph = graph
86 | self.labels = labels
87 | self._drop_edge_rate = drop_edge_rate
88 | self.ego_graph_nodes = root_nodes
89 | self.feats = feats
90 |
91 | dataset = np.arange(len(root_nodes))
92 | kwargs["collate_fn"] = self.__collate_fn__
93 | super().__init__(dataset, **kwargs)
94 |
95 | def drop_edge(self, g):
96 | if self._drop_edge_rate <= 0:
97 | return g, g
98 |
99 | g = g.remove_self_loop()
100 | mask_index1 = mask_edge(g, self._drop_edge_rate)
101 | mask_index2 = mask_edge(g, self._drop_edge_rate)
102 | g1 = dgl.remove_edges(g, mask_index1).add_self_loop()
103 | g2 = dgl.remove_edges(g, mask_index2).add_self_loop()
104 | return g1, g2
105 |
106 | def __collate_fn__(self, batch_idx):
107 | ego_nodes = [self.ego_graph_nodes[i] for i in batch_idx]
108 | subgs = [self.graph.subgraph(ego_nodes[i]) for i in range(len(ego_nodes))]
109 | sg = dgl.batch(subgs)
110 |
111 | nodes = torch.from_numpy(np.concatenate(ego_nodes)).long()
112 | num_nodes = [x.shape[0] for x in ego_nodes]
113 | cum_num_nodes = np.cumsum([0] + num_nodes)[:-1]
114 |
115 | if self._drop_edge_rate > 0:
116 | drop_g1, drop_g2 = self.drop_edge(sg)
117 |
118 | sg = sg.remove_self_loop().add_self_loop()
119 | sg.ndata["feat"] = self.feats[nodes]
120 | targets = torch.from_numpy(cum_num_nodes)
121 |
122 |         if self.labels is not None:
123 | label = self.labels[batch_idx]
124 | else:
125 | label = None
126 |
127 | if self._drop_edge_rate > 0:
128 | return sg, targets, label, nodes, drop_g1, drop_g2
129 | else:
130 | return sg, targets, label, nodes
131 |
132 |
133 | def setup_training_data(dataset_name, data_dir, ego_graphs_file_path):
134 | feats, graph, labels, split_idx = load_dataset(data_dir, dataset_name)
135 |
136 | train_lbls = labels[split_idx["train"]]
137 | val_lbls = labels[split_idx["valid"]]
138 | test_lbls = labels[split_idx["test"]]
139 |
140 | labels = torch.cat([train_lbls, val_lbls, test_lbls])
141 |
142 | os.makedirs(os.path.dirname(ego_graphs_file_path), exist_ok=True)
143 |
144 | if not os.path.exists(ego_graphs_file_path):
145 | raise FileNotFoundError(f"{ego_graphs_file_path} doesn't exist")
146 | else:
147 | nodes = torch.load(ego_graphs_file_path)
148 |
149 | return feats, graph, labels, split_idx, nodes
150 |
151 |
152 | def setup_training_dataloder(loader_type, training_nodes, graph, feats, batch_size, drop_edge_rate=0, pretrain_clustergcn=False, cluster_iter_data=None):
153 | num_workers = 8
154 |
155 | if loader_type == "lc":
156 | assert training_nodes is not None
157 | else:
158 | raise NotImplementedError(f"{loader_type} is not implemented yet")
159 |
160 | # print(" -------- drop edge rate: {} --------".format(drop_edge_rate))
161 | dataloader = OnlineLCLoader(training_nodes, graph, feats=feats, drop_edge_rate=drop_edge_rate, batch_size=batch_size, shuffle=True, drop_last=False, persistent_workers=True, num_workers=num_workers)
162 | return dataloader
163 |
164 |
165 | def setup_eval_dataloder(loader_type, graph, feats, ego_graph_nodes=None, batch_size=128, shuffle=False):
166 | num_workers = 8
167 | if loader_type == "lc":
168 | assert ego_graph_nodes is not None
169 | else:
170 | raise NotImplementedError(f"{loader_type} is not implemented yet")
171 |
172 | dataloader = OnlineLCLoader(ego_graph_nodes, graph, feats, batch_size=batch_size, shuffle=shuffle, drop_last=False, persistent_workers=True, num_workers=num_workers)
173 | return dataloader
174 |
175 |
176 | def setup_finetune_dataloder(loader_type, graph, feats, ego_graph_nodes, labels, batch_size, shuffle=False):
177 | num_workers = 8
178 |
179 | if loader_type == "lc":
180 | assert ego_graph_nodes is not None
181 | else:
182 | raise NotImplementedError(f"{loader_type} is not implemented yet")
183 |
184 |     dataloader = OnlineLCLoader(ego_graph_nodes, graph, feats, labels=labels, batch_size=batch_size, shuffle=shuffle, drop_last=False, num_workers=num_workers, persistent_workers=True)
185 |
186 | return dataloader
187 |
--------------------------------------------------------------------------------
/datasets/localclustering.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from collections import namedtuple
3 | import multiprocessing
4 | import os
5 |
6 | import numpy as np
7 | from localgraphclustering import *
8 | from scipy.sparse import csr_matrix
9 | from ogb.nodeproppred import DglNodePropPredDataset
10 | import torch
11 | import logging
12 |
13 | import dgl
14 | from dgl.data import load_data
15 |
16 |
17 | def my_sweep_cut(g, node):
18 | vol_sum = 0.0
19 | in_edge = 0.0
20 | conds = np.zeros_like(node, dtype=np.float32)
21 | for i in range(len(node)):
22 | idx = node[i]
23 | vol_sum += g.d[idx]
24 | denominator = min(vol_sum, g.vol_G - vol_sum)
25 | if denominator == 0.0:
26 | denominator = 1.0
27 | in_edge += 2*sum([g.adjacency_matrix[idx,prev] for prev in node[:i+1]])
28 | cut = vol_sum - in_edge
29 | conds[i] = cut/denominator
30 | return conds
31 |
32 |
33 | def calc_local_clustering(args):
34 | i, log_steps, num_iter, ego_size, method = args
35 | if i % log_steps == 0:
36 | print(i)
37 | node, ppr = approximate_PageRank(graphlocal, [i], iterations=num_iter, method=method, normalize=False)
38 | d_inv = graphlocal.dn[node]
39 | d_inv[d_inv > 1.0] = 1.0
40 | ppr_d_inv = ppr * d_inv
41 | output = list(zip(node, ppr_d_inv))[:ego_size]
42 | node, ppr_d_inv = zip(*sorted(output, key=lambda x: x[1], reverse=True))
43 | assert node[0] == i
44 | node = np.array(node, dtype=np.int32)
45 | conds = my_sweep_cut(graphlocal, node)
46 | return node, conds
47 |
48 |
49 | def step1_local_clustering(data, name, idx_split, ego_size=128, num_iter=1000, log_steps=10000, num_workers=16, method='acl', save_dir=None):
50 | if save_dir is None:
51 | save_path = f"{name}-lc-ego-graphs-{ego_size}.pt"
52 | else:
53 | if not os.path.exists(save_dir):
54 | os.makedirs(save_dir, exist_ok=True)
55 |
56 | save_path = os.path.join(save_dir, f"{name}-lc-ego-graphs-{ego_size}.pt")
57 |
58 | N = data.num_nodes()
59 | edge_index = data.edges()
60 | edge_index = (edge_index[0].numpy(), edge_index[1].numpy())
61 | adj = csr_matrix((np.ones(edge_index[0].shape[0]), edge_index), shape=(N, N))
62 |
63 | global graphlocal
64 | graphlocal = GraphLocal.from_sparse_adjacency(adj)
65 | print('graphlocal generated')
66 |
67 | train_idx = idx_split["train"].cpu().numpy()
68 | valid_idx = idx_split["valid"].cpu().numpy()
69 | test_idx = idx_split["test"].cpu().numpy()
70 |
71 | with multiprocessing.Pool(num_workers) as pool:
72 | ego_graphs_train, conds_train = zip(*pool.imap(calc_local_clustering, [(i, log_steps, num_iter, ego_size, method) for i in train_idx], chunksize=512))
73 |
74 | with multiprocessing.Pool(num_workers) as pool:
75 | ego_graphs_valid, conds_valid = zip(*pool.imap(calc_local_clustering, [(i, log_steps, num_iter, ego_size, method) for i in valid_idx], chunksize=512))
76 |
77 | with multiprocessing.Pool(num_workers) as pool:
78 | ego_graphs_test, conds_test = zip(*pool.imap(calc_local_clustering, [(i, log_steps, num_iter, ego_size, method) for i in test_idx], chunksize=512))
79 |
80 | ego_graphs = []
81 | conds = []
82 | ego_graphs.extend(ego_graphs_train)
83 | ego_graphs.extend(ego_graphs_valid)
84 | ego_graphs.extend(ego_graphs_test)
85 | conds.extend(conds_train)
86 | conds.extend(conds_valid)
87 | conds.extend(conds_test)
88 |
89 | ego_graphs = [ego_graphs_train, ego_graphs_valid, ego_graphs_test]
90 |
91 | torch.save(ego_graphs, save_path)
92 |
93 |
94 | def preprocess(graph):
95 | # make bidirected
96 | if "feat" in graph.ndata:
97 | feat = graph.ndata["feat"]
98 | else:
99 | feat = None
100 | # src, dst = graph.all_edges()
101 | # graph.add_edges(dst, src)
102 | graph = dgl.to_bidirected(graph)
103 | if feat is not None:
104 | graph.ndata["feat"] = feat
105 |
106 | # add self-loop
107 | graph = graph.remove_self_loop().add_self_loop()
108 | # graph.create_formats_()
109 | return graph
110 |
111 |
112 | def load_dataset(data_dir, dataset_name):
113 | if dataset_name.startswith("ogbn"):
114 | dataset = DglNodePropPredDataset(dataset_name, root=os.path.join(data_dir, "dataset"))
115 | graph, label = dataset[0]
116 |
117 | if "year" in graph.ndata:
118 | del graph.ndata["year"]
119 | if not graph.is_multigraph:
120 | graph = preprocess(graph)
121 | # graph = graph.remove_self_loop().add_self_loop()
122 |
123 | split_idx = dataset.get_idx_split()
124 | label = label.view(-1)
125 |
126 | elif dataset_name == "mag-scholar-f":
127 | edge_index = np.load(os.path.join(data_dir, dataset_name, "edge_index_f.npy"))
128 | print(len(edge_index[0]))
129 | graph = dgl.DGLGraph((edge_index[0], edge_index[1]))
130 | print(graph)
131 | num_nodes = graph.num_nodes()
132 | assert num_nodes == 12403930
133 | split_idx = torch.load(os.path.join(data_dir, dataset_name, "split_idx_f.pt"))
134 | else:
135 | raise NotImplementedError
136 |
137 | return graph, split_idx
138 |
139 |
140 | if __name__ == "__main__":
141 | parser = argparse.ArgumentParser(description='LCGNN (Preprocessing)')
142 | parser.add_argument('--dataset', type=str, default='flickr')
143 | parser.add_argument("--data_dir", type=str, default="data")
144 | parser.add_argument("--save_dir", type=str, default="lc_ego_graphs")
145 | parser.add_argument('--ego_size', type=int, default=256)
146 | parser.add_argument('--num_iter', type=int, default=1000)
147 | parser.add_argument('--log_steps', type=int, default=10000)
148 | parser.add_argument('--seed', type=int, default=0)
149 | parser.add_argument('--method', type=str, default='acl')
150 | parser.add_argument('--num_workers', type=int, default=16)
151 | args = parser.parse_args()
152 | print(args)
153 |
154 | np.random.seed(args.seed)
155 |
156 | graph, split_idx = load_dataset(args.data_dir, args.dataset)
157 | step1_local_clustering(graph, args.dataset, split_idx, args.ego_size, args.num_iter, args.log_steps, args.num_workers, args.method, args.save_dir)
158 |
--------------------------------------------------------------------------------
/datasets/saint_sampler.py:
--------------------------------------------------------------------------------
1 | """This file is modified from https://github.com/dmlc/dgl/blob/master/python/dgl/dataloading/graphsaint.py"""
2 |
3 | import os
4 | import time
5 | import math
6 | import torch as th
7 | from torch.utils.data import DataLoader
8 | import random
9 | import numpy as np
10 | import dgl.function as fn
11 | import dgl
12 | from dgl.sampling import random_walk, pack_traces
13 | from tqdm import tqdm
14 |
15 |
16 | class SAINTSampler:
17 | """
18 | Description
19 | -----------
20 |     SAINTSampler implements the sampler described in GraphSAINT. It performs offline sampling in the
21 |     pre-sampling phase and supports either fully offline or fully online sampling in the training phase.
22 |     Users can set the sampler's 'online' parameter to choose between these modes.
23 |
24 | Parameters
25 | ----------
26 | node_budget : int
27 |         the expected number of nodes in each subgraph, as explained in the paper. In practice, this
28 |         parameter specifies how many times nodes are sampled from the original graph with replacement
29 |         (e.g. the random-walk loader below passes num_roots * length).
30 | dn : str
31 | name of dataset.
32 | g : DGLGraph
33 | the full graph.
34 |     full : bool, optional
35 |         if `True`, a training epoch iterates over all pre-sampled subgraphs; otherwise its length is num_nodes / node_budget. Defaults: True
36 | num_workers_sampler : int
37 | number of processes to sample subgraphs in pre-sampling procedure using torch.dataloader.
38 | num_subg_sampler : int, optional
39 | the max number of subgraphs sampled in pre-sampling phase for computing normalization coefficients in the beginning.
40 | Actually this param is used as ``__len__`` of sampler in pre-sampling phase.
41 | Please make sure that num_subg_sampler is greater than batch_size_sampler so that we can sample enough subgraphs.
42 | Defaults: 10000
43 | batch_size_sampler : int, optional
44 | the number of subgraphs sampled by each process concurrently in pre-sampling phase.
45 | Defaults: 200
46 | online : bool, optional
47 | If `True`, we employ online sampling in training phase. Otherwise employing offline sampling.
48 | Defaults: True
49 | num_subg : int, optional
50 | the expected number of sampled subgraphs in pre-sampling phase.
51 | It is actually the 'N' in the original paper. Note that this param is different from the num_subg_sampler.
52 | This param is just used to control the number of pre-sampled subgraphs.
53 | Defaults: 50
54 | """
55 |
56 | def __init__(self, node_budget, dn, g, num_workers_sampler, num_subg_sampler=10000,
57 | batch_size_sampler=200, online=True, num_subg=50, full=True):
58 | self.g = g.cpu()
59 | self.node_budget = node_budget
60 | # self.train_g: dgl.graph = g.subgraph(train_nid)
61 | self.train_g = g
62 | self.dn, self.num_subg = dn, num_subg
63 | self.node_counter = th.zeros((self.train_g.num_nodes(),))
64 | self.edge_counter = th.zeros((self.train_g.num_edges(),))
65 | self.prob = None
66 | self.num_subg_sampler = num_subg_sampler
67 | self.batch_size_sampler = batch_size_sampler
68 | self.num_workers_sampler = num_workers_sampler
69 | self.train = False
70 | self.online = online
71 | self.full = full
72 |
73 | assert self.num_subg_sampler >= self.batch_size_sampler, "num_subg_sampler should be greater than batch_size_sampler"
74 | graph_fn, norm_fn = self.__generate_fn__()
75 |
76 | if os.path.exists(graph_fn):
77 | self.subgraphs = np.load(graph_fn, allow_pickle=True)
78 | # aggr_norm, loss_norm = np.load(norm_fn, allow_pickle=True)
79 | else:
80 | os.makedirs('./subgraphs/', exist_ok=True)
81 |
82 | self.subgraphs = []
83 | self.N, sampled_nodes = 0, 0
84 | # N: the number of pre-sampled subgraphs
85 |
86 | # Employ parallelism to speed up the sampling procedure
87 | loader = DataLoader(self, batch_size=self.batch_size_sampler, shuffle=True,
88 | num_workers=self.num_workers_sampler, collate_fn=self.__collate_fn__, drop_last=False)
89 |
90 | t = time.perf_counter()
91 | for num_nodes, subgraphs_nids, subgraphs_eids in tqdm(loader, desc="preprocessing saint subgraphs"):
92 |
93 | self.subgraphs.extend(subgraphs_nids)
94 | sampled_nodes += num_nodes
95 |
96 | _subgraphs, _node_counts = np.unique(np.concatenate(subgraphs_nids), return_counts=True)
97 | sampled_nodes_idx = th.from_numpy(_subgraphs)
98 | _node_counts = th.from_numpy(_node_counts)
99 | self.node_counter[sampled_nodes_idx] += _node_counts
100 |
101 | _subgraphs_eids, _edge_counts = np.unique(np.concatenate(subgraphs_eids), return_counts=True)
102 | sampled_edges_idx = th.from_numpy(_subgraphs_eids)
103 | _edge_counts = th.from_numpy(_edge_counts)
104 | self.edge_counter[sampled_edges_idx] += _edge_counts
105 |
106 | self.N += len(subgraphs_nids) # number of subgraphs
107 |
108 | # print("sampled_nodes: ", sampled_nodes, " --> ", self.train_g.num_nodes(), num_subg)
109 | if sampled_nodes > self.train_g.num_nodes() * num_subg:
110 | break
111 |
112 | print(f'Sampling time: [{time.perf_counter() - t:.2f}s]')
113 | np.save(graph_fn, self.subgraphs)
114 |
115 | # t = time.perf_counter()
116 | # aggr_norm, loss_norm = self.__compute_norm__()
117 | # print(f'Normalization time: [{time.perf_counter() - t:.2f}s]')
118 | # np.save(norm_fn, (aggr_norm, loss_norm))
119 |
120 | # self.train_g.ndata['l_n'] = th.Tensor(loss_norm)
121 | # self.train_g.edata['w'] = th.Tensor(aggr_norm)
122 | # self.__compute_degree_norm() # basically normalizing adjacent matrix
123 |
124 | random.shuffle(self.subgraphs)
125 | self.__clear__()
126 | print("The number of subgraphs is: ", len(self.subgraphs))
127 |
128 | self.train = True
129 |
130 | def __len__(self):
131 | if self.train is False:
132 | return self.num_subg_sampler
133 | else:
134 | if self.full:
135 | return len(self.subgraphs)
136 | else:
137 | return math.ceil(self.train_g.num_nodes() / self.node_budget)
138 |
139 | def __getitem__(self, idx):
140 |         # In the training phase, return a DGL subgraph: sample a new node set online or reuse a
141 |         # pre-sampled one depending on `self.online`; in the pre-sampling phase, return node/edge ids.
142 | if self.train:
143 | if self.online:
144 | subgraph = self.__sample__()
145 | return dgl.node_subgraph(self.train_g, subgraph)
146 | else:
147 | return dgl.node_subgraph(self.train_g, self.subgraphs[idx])
148 | else:
149 | subgraph_nids = self.__sample__()
150 | num_nodes = len(subgraph_nids)
151 | subgraph_eids = dgl.node_subgraph(self.train_g, subgraph_nids).edata[dgl.EID]
152 | return num_nodes, subgraph_nids, subgraph_eids
153 |
154 | def __collate_fn__(self, batch):
155 |         if self.train:  # sample only one graph each epoch, batch_size in training phase is 1
156 | return batch[0]
157 | else:
158 | sum_num_nodes = 0
159 | subgraphs_nids_list = []
160 | subgraphs_eids_list = []
161 | for num_nodes, subgraph_nids, subgraph_eids in batch:
162 | sum_num_nodes += num_nodes
163 | subgraphs_nids_list.append(subgraph_nids)
164 | subgraphs_eids_list.append(subgraph_eids)
165 | return sum_num_nodes, subgraphs_nids_list, subgraphs_eids_list
166 |
167 | def __clear__(self):
168 | self.prob = None
169 | self.node_counter = None
170 | self.edge_counter = None
171 | self.g = None
172 |
173 | def __generate_fn__(self):
174 | raise NotImplementedError
175 |
176 | def __compute_norm__(self):
177 |
178 | self.node_counter[self.node_counter == 0] = 1
179 | self.edge_counter[self.edge_counter == 0] = 1
180 |
181 | loss_norm = self.N / self.node_counter / self.train_g.num_nodes()
182 |
183 | self.train_g.ndata['n_c'] = self.node_counter
184 | self.train_g.edata['e_c'] = self.edge_counter
185 | self.train_g.apply_edges(fn.v_div_e('n_c', 'e_c', 'a_n'))
186 | aggr_norm = self.train_g.edata.pop('a_n')
187 |
188 | self.train_g.ndata.pop('n_c')
189 | self.train_g.edata.pop('e_c')
190 |
191 | return aggr_norm.numpy(), loss_norm.numpy()
192 |
193 | def __compute_degree_norm(self):
194 |
195 | self.train_g.ndata['train_D_norm'] = 1. / self.train_g.in_degrees().float().clamp(min=1).unsqueeze(1)
196 | self.g.ndata['full_D_norm'] = 1. / self.g.in_degrees().float().clamp(min=1).unsqueeze(1)
197 |
198 | def __sample__(self):
199 | raise NotImplementedError
200 |
201 | def get_generated_subgraph_nodes(self):
202 | return self.subgraphs
203 |
204 |
205 |
206 | class SAINTRandomWalkLoader(SAINTSampler):
207 | """
208 | Description
209 | -----------
210 | GraphSAINT with random walk sampler
211 |
212 | Parameters
213 | ----------
214 | num_roots : int
215 | the number of roots to generate random walks.
216 | length : int
217 | the length of each random walk.
218 |
219 | """
220 | def __init__(self, feats, num_roots, length, **kwargs):
221 | self.num_roots, self.length = num_roots, length
222 | self.feats = feats
223 | super(SAINTRandomWalkLoader, self).__init__(node_budget=num_roots * length, **kwargs)
224 |
225 | def __generate_fn__(self):
226 | graph_fn = os.path.join('./subgraphs/{}_RW_{}_{}_{}.npy'.format(self.dn, self.num_roots,
227 | self.length, self.num_subg))
228 | norm_fn = os.path.join('./subgraphs/{}_RW_{}_{}_{}_norm.npy'.format(self.dn, self.num_roots,
229 | self.length, self.num_subg))
230 | return graph_fn, norm_fn
231 |
232 | def __sample__(self):
233 | sampled_roots = th.randint(0, self.train_g.num_nodes(), (self.num_roots,))
234 | traces, types = random_walk(self.train_g, nodes=sampled_roots, length=self.length)
235 | sampled_nodes, _, _, _ = pack_traces(traces, types)
236 | sampled_nodes = sampled_nodes.unique()
237 | return sampled_nodes.numpy()
238 |
239 | def __collate_fn__(self, batch):
240 |         if self.train:  # sample only one graph each epoch, batch_size in training phase is 1
241 | subg = batch[0]
242 | node_ids = subg.ndata["_ID"]
243 | feats = self.feats[node_ids]
244 | return feats, subg
245 | else:
246 | sum_num_nodes = 0
247 | subgraphs_nids_list = []
248 | subgraphs_eids_list = []
249 | for num_nodes, subgraph_nids, subgraph_eids in batch:
250 | sum_num_nodes += num_nodes
251 | subgraphs_nids_list.append(subgraph_nids)
252 | subgraphs_eids_list.append(subgraph_eids)
253 | return sum_num_nodes, subgraphs_nids_list, subgraphs_eids_list
254 |
255 |
256 | def build_saint_dataloader(feats, graph, dataset_name, online=False, **kwargs):
257 | num_nodes = graph.num_nodes()
258 | num_subg_sampler = 20000 # the max num of subgraphs
259 | num_subg = 50 # the max times a node be sampled
260 | full = True
261 | batch_size_sampler = 200 # batch_size of nodes
262 |
263 | num_roots = max(num_nodes // 100, 2000)
264 | # num_roots = 2000 # starting node for each multiple random walk
265 | length = 4
266 | num_workers = 4
267 |
268 | params = {
269 | 'dn': dataset_name, 'g': graph, 'num_workers_sampler': 4,
270 | 'num_subg_sampler': num_subg_sampler,
271 | 'batch_size_sampler': batch_size_sampler,
272 | 'online': online, 'num_subg': num_subg,
273 | 'full': full
274 | }
275 |
276 | saint_sampler = SAINTRandomWalkLoader(feats, num_roots, length, **params)
277 | loader = DataLoader(saint_sampler, collate_fn=saint_sampler.__collate_fn__, batch_size=1, **kwargs)
278 | return loader
279 |
280 |
281 | def get_saint_subgraphs(graph, dataset_name, online=False):
282 | num_nodes = graph.num_nodes()
283 | num_subg_sampler = 20000 # the max num of subgraphs
284 | num_subg = 50 # the max times a node be sampled
285 | full = True
286 | batch_size_sampler = 200 # batch_size of nodes
287 |
288 | num_roots = max(num_nodes // 100, 2000)
289 | # num_roots = 2000 # starting node for each multiple random walk
290 | length = 4
291 | num_workers = 4
292 |
293 | params = {
294 | 'dn': dataset_name, 'g': graph, 'num_workers_sampler': 4,
295 | 'num_subg_sampler': num_subg_sampler,
296 | 'batch_size_sampler': batch_size_sampler,
297 | 'online': online, 'num_subg': num_subg,
298 | 'full': full
299 | }
300 | saint_sampler = SAINTRandomWalkLoader(None, num_roots, length, **params)
301 | subgraph_nodes = saint_sampler.get_generated_subgraph_nodes()
302 | return subgraph_nodes
303 |
--------------------------------------------------------------------------------
/main_full_batch.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import numpy as np
3 | from tqdm import tqdm
4 | import torch
5 |
6 | from utils import (
7 | build_args,
8 | create_optimizer,
9 | set_random_seed,
10 | TBLogger,
11 | get_current_lr,
12 | load_best_configs,
13 | )
14 | from datasets.data_proc import load_small_dataset
15 | from models.finetune import linear_probing_full_batch
16 | from models import build_model
17 |
18 |
19 | logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO)
20 |
21 |
22 | def pretrain(model, graph, feat, optimizer, max_epoch, device, scheduler, num_classes, lr_f, weight_decay_f, max_epoch_f, linear_prob, logger=None):
23 | logging.info("start training..")
24 | graph = graph.to(device)
25 | x = feat.to(device)
26 |
27 | target_nodes = torch.arange(x.shape[0], device=x.device, dtype=torch.long)
28 | epoch_iter = tqdm(range(max_epoch))
29 |
30 | for epoch in epoch_iter:
31 | model.train()
32 |
33 | loss = model(graph, x, targets=target_nodes)
34 |
35 | loss_dict = {"loss": loss.item()}
36 | optimizer.zero_grad()
37 | loss.backward()
38 | optimizer.step()
39 | if scheduler is not None:
40 | scheduler.step()
41 |
42 | epoch_iter.set_description(f"# Epoch {epoch}: train_loss: {loss.item():.4f}")
43 | if logger is not None:
44 | loss_dict["lr"] = get_current_lr(optimizer)
45 | logger.note(loss_dict, step=epoch)
46 |
47 | if (epoch + 1) % 200 == 0:
48 | linear_probing_full_batch(model, graph, x, num_classes, lr_f, weight_decay_f, max_epoch_f, device, linear_prob, mute=True)
49 |
50 | return model
51 |
52 |
53 | def main(args):
54 | device = args.device if args.device >= 0 else "cpu"
55 | seeds = args.seeds
56 | dataset_name = args.dataset
57 | max_epoch = args.max_epoch
58 | max_epoch_f = args.max_epoch_f
59 | num_hidden = args.num_hidden
60 | num_layers = args.num_layers
61 | encoder_type = args.encoder
62 | decoder_type = args.decoder
63 | replace_rate = args.replace_rate
64 |
65 | optim_type = args.optimizer
66 | loss_fn = args.loss_fn
67 |
68 | lr = args.lr
69 | weight_decay = args.weight_decay
70 | lr_f = args.lr_f
71 | weight_decay_f = args.weight_decay_f
72 | linear_prob = args.linear_prob
73 | load_model = args.load_model
74 | logs = args.logging
75 | use_scheduler = args.scheduler
76 |
77 | graph, (num_features, num_classes) = load_small_dataset(dataset_name)
78 | args.num_features = num_features
79 |
80 | acc_list = []
81 | estp_acc_list = []
82 | for i, seed in enumerate(seeds):
83 | print(f"####### Run {i} for seed {seed}")
84 | set_random_seed(seed)
85 |
86 | if logs:
87 | logger = TBLogger(name=f"{dataset_name}_loss_{loss_fn}_rpr_{replace_rate}_nh_{num_hidden}_nl_{num_layers}_lr_{lr}_mp_{max_epoch}_mpf_{max_epoch_f}_wd_{weight_decay}_wdf_{weight_decay_f}_{encoder_type}_{decoder_type}")
88 | else:
89 | logger = None
90 |
91 | model = build_model(args)
92 | model.to(device)
93 | optimizer = create_optimizer(optim_type, model, lr, weight_decay)
94 |
95 | if use_scheduler:
96 |             logging.info("Use scheduler")
97 |             scheduler = lambda epoch: (1 + np.cos(epoch * np.pi / max_epoch)) * 0.5
98 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler)
99 | else:
100 | scheduler = None
101 |
102 | x = graph.ndata["feat"]
103 | if not load_model:
104 | model = pretrain(model, graph, x, optimizer, max_epoch, device, scheduler, num_classes, lr_f, weight_decay_f, max_epoch_f, linear_prob, logger)
105 | model = model.cpu()
106 |
107 | if load_model:
108 | logging.info("Loading Model ... ")
109 | model.load_state_dict(torch.load("checkpoint.pt"))
110 |
111 | model = model.to(device)
112 | model.eval()
113 |
114 | final_acc, estp_acc = linear_probing_full_batch(model, graph, x, num_classes, lr_f, weight_decay_f, max_epoch_f, device, linear_prob)
115 | acc_list.append(final_acc)
116 | estp_acc_list.append(estp_acc)
117 |
118 | if logger is not None:
119 | logger.finish()
120 |
121 | final_acc, final_acc_std = np.mean(acc_list), np.std(acc_list)
122 | estp_acc, estp_acc_std = np.mean(estp_acc_list), np.std(estp_acc_list)
123 | print(f"# final_acc: {final_acc:.4f}±{final_acc_std:.4f}")
124 | print(f"# early-stopping_acc: {estp_acc:.4f}±{estp_acc_std:.4f}")
125 |
126 |
127 | # Press the green button in the gutter to run the script.
128 | if __name__ == "__main__":
129 | args = build_args()
130 | if args.use_cfg:
131 | args = load_best_configs(args)
132 | print(args)
133 | main(args)
134 |
--------------------------------------------------------------------------------
/main_large.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import numpy as np
4 | from tqdm import tqdm
5 |
6 | import torch
7 |
8 | from utils import (
9 | WandbLogger,
10 | build_args,
11 | create_optimizer,
12 | set_random_seed,
13 | load_best_configs,
14 | show_occupied_memory,
15 | )
16 | from models import build_model
17 | from datasets.lc_sampler import (
18 | setup_training_dataloder,
19 | setup_training_data,
20 | )
21 | from models.finetune import linear_probing_minibatch, finetune
22 |
23 | import warnings
24 |
25 | warnings.filterwarnings("ignore")
26 | logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO)
27 |
28 |
29 | def evaluate(
30 | model,
31 | graph, feats, labels,
32 | split_idx,
33 | lr_f, weight_decay_f, max_epoch_f,
34 | linear_prob=True,
35 | device=0,
36 | batch_size=256,
37 | logger=None, ego_graph_nodes=None,
38 | label_rate=1.0,
39 | full_graph_forward=False,
40 | shuffle=True,
41 | ):
42 | logging.info("Using `lc` for evaluation...")
43 | num_train, num_val, num_test = [split_idx[k].shape[0] for k in ["train", "valid", "test"]]
44 | print(num_train,num_val,num_test)
45 |
46 | train_g_idx = np.arange(0, num_train)
47 | val_g_idx = np.arange(num_train, num_train+num_val)
48 | test_g_idx = np.arange(num_train+num_val, num_train+num_val+num_test)
49 |
50 | train_ego_graph_nodes = [ego_graph_nodes[i] for i in train_g_idx]
51 | val_ego_graph_nodes = [ego_graph_nodes[i] for i in val_g_idx]
52 | test_ego_graph_nodes = [ego_graph_nodes[i] for i in test_g_idx]
53 |
54 | train_lbls, val_lbls, test_lbls = labels[train_g_idx], labels[val_g_idx], labels[test_g_idx]
55 |
56 | # labels = [train_lbls, val_lbls, test_lbls]
57 | assert len(train_ego_graph_nodes) == len(train_lbls)
58 | assert len(val_ego_graph_nodes) == len(val_lbls)
59 | assert len(test_ego_graph_nodes) == len(test_lbls)
60 |
61 | print(f"num_train: {num_train}, num_val: {num_val}, num_test: {num_test}")
62 | logging.info(f"-- train_ego_nodes:{len(train_ego_graph_nodes)}, val_ego_nodes:{len(val_ego_graph_nodes)}, test_ego_nodes:{len(test_ego_graph_nodes)} ---")
63 |
64 |
65 | if linear_prob:
66 | result = linear_probing_minibatch(model, graph, feats, [train_ego_graph_nodes, val_ego_graph_nodes, test_ego_graph_nodes], [train_lbls, val_lbls, test_lbls], lr_f=lr_f, weight_decay_f=weight_decay_f, max_epoch_f=max_epoch_f, batch_size=batch_size, device=device, shuffle=shuffle)
67 | else:
68 | max_epoch_f = max_epoch_f // 2
69 |
70 | if label_rate < 1.0:
71 | rand_idx = np.arange(len(train_ego_graph_nodes))
72 | np.random.shuffle(rand_idx)
73 | rand_idx = rand_idx[:int(label_rate * len(train_ego_graph_nodes))]
74 | train_ego_graph_nodes = [train_ego_graph_nodes[i] for i in rand_idx]
75 | train_lbls = train_lbls[rand_idx]
76 |
77 | logging.info(f"-- train_ego_nodes:{len(train_ego_graph_nodes)}, val_ego_nodes:{len(val_ego_graph_nodes)}, test_ego_nodes:{len(test_ego_graph_nodes)} ---")
78 |
79 | # train_lbls = (all_train_lbls, train_lbls)
80 | result = finetune(
81 | model, graph, feats,
82 | [train_ego_graph_nodes, val_ego_graph_nodes, test_ego_graph_nodes],
83 | [train_lbls, val_lbls, test_lbls],
84 | split_idx=split_idx,
85 | lr_f=lr_f, weight_decay_f=weight_decay_f, max_epoch_f=max_epoch_f, use_scheduler=True, batch_size=batch_size, device=device, logger=logger, full_graph_forward=full_graph_forward,
86 | )
87 | return result
88 |
89 |
90 | def pretrain(model, feats, graph, ego_graph_nodes, max_epoch, device, use_scheduler, lr, weight_decay, batch_size=512, sampling_method="lc", optimizer="adam", drop_edge_rate=0):
91 | logging.info("start training..")
92 |
93 | model = model.to(device)
94 | optimizer = create_optimizer(optimizer, model, lr, weight_decay)
95 |
96 | dataloader = setup_training_dataloder(
97 | sampling_method, ego_graph_nodes, graph, feats, batch_size=batch_size, drop_edge_rate=drop_edge_rate)
98 |
99 | logging.info(f"After creating dataloader: Memory: {show_occupied_memory():.2f} MB")
100 | if use_scheduler and max_epoch > 0:
101 | logging.info("Use scheduler")
102 |         scheduler = lambda epoch: (1 + np.cos(epoch * np.pi / max_epoch)) * 0.5
103 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler)
104 | else:
105 | scheduler = None
106 |
107 | for epoch in range(max_epoch):
108 | epoch_iter = tqdm(dataloader)
109 | losses = []
110 | # assert (graph.in_degrees() > 0).all(), "after loading"
111 |
112 | for batch_g in epoch_iter:
113 | model.train()
114 | if drop_edge_rate > 0:
115 | batch_g, targets, _, node_idx, drop_g1, drop_g2 = batch_g
116 | batch_g = batch_g.to(device)
117 | drop_g1 = drop_g1.to(device)
118 | drop_g2 = drop_g2.to(device)
119 | x = batch_g.ndata.pop("feat")
120 | loss = model(batch_g, x, targets, epoch, drop_g1, drop_g2)
121 | else:
122 | batch_g, targets, _, node_idx = batch_g
123 | batch_g = batch_g.to(device)
124 | x = batch_g.ndata.pop("feat")
125 | loss = model(batch_g, x, targets, epoch)
126 |
127 | optimizer.zero_grad()
128 | loss.backward()
129 | torch.nn.utils.clip_grad_norm_(model.parameters(), 3)
130 | optimizer.step()
131 |
132 | epoch_iter.set_description(f"train_loss: {loss.item():.4f}, Memory: {show_occupied_memory():.2f} MB")
133 | losses.append(loss.item())
134 |
135 | if scheduler is not None:
136 | scheduler.step()
137 |
138 | torch.save(model.state_dict(), os.path.join(model_dir, model_name))
139 |
140 | print(f"# Epoch {epoch} | train_loss: {np.mean(losses):.4f}, Memory: {show_occupied_memory():.2f} MB")
141 |
142 | return model
143 |
144 |
145 | if __name__ == "__main__":
146 | args = build_args()
147 | if args.use_cfg:
148 | args = load_best_configs(args)
149 |
150 | if args.device < 0:
151 | device = "cpu"
152 | else:
153 | device = "cuda:{}".format(args.device)
154 | seeds = args.seeds
155 | dataset_name = args.dataset
156 | max_epoch = args.max_epoch
157 | max_epoch_f = args.max_epoch_f
158 | num_hidden = args.num_hidden
159 | num_layers = args.num_layers
160 | encoder_type = args.encoder
161 | decoder_type = args.decoder
162 | encoder = args.encoder
163 | decoder = args.decoder
164 | num_hidden = args.num_hidden
165 | drop_edge_rate = args.drop_edge_rate
166 |
167 | optim_type = args.optimizer
168 | loss_fn = args.loss_fn
169 |
170 | lr = args.lr
171 | weight_decay = args.weight_decay
172 | lr_f = args.lr_f
173 | weight_decay_f = args.weight_decay_f
174 | linear_prob = args.linear_prob
175 | load_model = args.load_model
176 | no_pretrain = args.no_pretrain
177 | logs = args.logging
178 | use_scheduler = args.scheduler
179 | batch_size = args.batch_size
180 | batch_size_f = args.batch_size_f
181 | sampling_method = args.sampling_method
182 | ego_graph_file_path = args.ego_graph_file_path
183 | data_dir = args.data_dir
184 |
185 | n_procs = torch.cuda.device_count()
186 | optimizer_type = args.optimizer
187 | label_rate = args.label_rate
188 | lam = args.lam
189 | full_graph_forward = hasattr(args, "full_graph_forward") and args.full_graph_forward and not linear_prob
190 |
191 | model_dir = "checkpoints"
192 | os.makedirs(model_dir, exist_ok=True)
193 |
194 | set_random_seed(0)
195 | print(args)
196 |
197 | logging.info(f"Before loading data, occupied memory: {show_occupied_memory():.2f} MB") # in MB
198 | feats, graph, labels, split_idx, ego_graph_nodes = setup_training_data(dataset_name, data_dir, ego_graph_file_path)
199 | if dataset_name == "ogbn-papers100M":
200 | pretrain_ego_graph_nodes = ego_graph_nodes[0] + ego_graph_nodes[1] + ego_graph_nodes[2] + ego_graph_nodes[3]
201 | else:
202 | pretrain_ego_graph_nodes = ego_graph_nodes[0] + ego_graph_nodes[1] + ego_graph_nodes[2]
203 | ego_graph_nodes = ego_graph_nodes[0] + ego_graph_nodes[1] + ego_graph_nodes[2] # * merge train/val/test = all
204 |
205 | logging.info(f"After loading data, occupied memory: {show_occupied_memory():.2f} MB") # in MB
206 |
207 | args.num_features = feats.shape[1]
208 |
209 | if logs:
210 | logger = WandbLogger(log_path=f"{dataset_name}_loss_{loss_fn}_nh_{num_hidden}_nl_{num_layers}_lr_{lr}_mp_{max_epoch}_mpf_{max_epoch_f}_wd_{weight_decay}_wdf_{weight_decay_f}_{encoder_type}_{decoder_type}", project="GraphMAE2", args=args)
211 | else:
212 | logger = None
213 | model_name = f"{encoder}_{decoder}_{num_hidden}_{num_layers}_{dataset_name}_{args.mask_rate}_{num_hidden}_checkpoint.pt"
214 |
215 | model = build_model(args)
216 |
217 | if not args.no_pretrain:
218 | # ------------- pretraining starts ----------------
219 | if not load_model:
220 | logging.info("---- start pretraining ----")
221 | model = pretrain(model, feats, graph, pretrain_ego_graph_nodes, max_epoch=max_epoch, device=device, use_scheduler=use_scheduler, lr=lr,
222 | weight_decay=weight_decay, batch_size=batch_size, drop_edge_rate=drop_edge_rate,
223 | sampling_method=sampling_method, optimizer=optimizer_type)
224 |
225 | model = model.cpu()
226 | logging.info(f"saving model to {model_dir}/{model_name}...")
227 | torch.save(model.state_dict(), os.path.join(model_dir, model_name))
228 | # ------------- pretraining ends ----------------
229 |
230 | if load_model:
231 |         logging.info(f"Loading model from {args.checkpoint_path}...")
232 |         model.load_state_dict(torch.load(args.checkpoint_path))
233 | else:
234 | logging.info("--- no pretrain ---")
235 |
236 | model = model.to(device)
237 | model.eval()
238 |
239 | logging.info("---- start finetuning / evaluation ----")
240 |
241 | final_accs = []
242 | for i,_ in enumerate(seeds):
243 | print(f"####### Run seed {seeds[i]}")
244 | set_random_seed(seeds[i])
245 | eval_model = build_model(args)
246 | eval_model.load_state_dict(model.state_dict())
247 | eval_model.to(device)
248 |
249 | print(f"features size : {feats.shape[1]}")
250 | logging.info("start evaluation...")
251 | final_acc = evaluate(
252 | eval_model, graph, feats, labels, split_idx,
253 | lr_f, weight_decay_f, max_epoch_f,
254 | device=device,
255 | batch_size=batch_size_f,
256 | ego_graph_nodes=ego_graph_nodes,
257 | linear_prob=linear_prob,
258 | label_rate=label_rate,
259 | full_graph_forward=full_graph_forward,
260 |             shuffle=(dataset_name != "ogbn-papers100M")
261 | )
262 |
263 | final_accs.append(float(final_acc))
264 |
265 | print(f"Run {seeds[i]} | TestAcc: {final_acc:.4f}")
266 |
267 | print(f"# final_acc: {np.mean(final_accs):.4f}, std: {np.std(final_accs):.4f}")
268 |
269 | if logger is not None:
270 | logger.finish()
271 |
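For reference, the `--load_model` / `--checkpoint_path` branch above only needs the state dict that `torch.save(model.state_dict(), ...)` writes at the end of pretraining. A minimal sketch of restoring such a checkpoint outside the script, assuming the repository root is importable; the feature dimension and the checkpoint filename are placeholders:

```python
import torch

from models import build_model
from utils import build_args

args = build_args()              # same CLI flags as the script above
args.num_features = 128          # placeholder; must match the pretrained input feature dim

model = build_model(args)
state_dict = torch.load("checkpoints/<model_name>.pt", map_location="cpu")  # placeholder path
model.load_state_dict(state_dict)
model.eval()
```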
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .edcoder import PreModel
2 |
3 |
4 | def build_model(args):
5 | num_heads = args.num_heads
6 | num_out_heads = args.num_out_heads
7 | num_hidden = args.num_hidden
8 | num_layers = args.num_layers
9 | residual = args.residual
10 | attn_drop = args.attn_drop
11 | in_drop = args.in_drop
12 | norm = args.norm
13 | negative_slope = args.negative_slope
14 | encoder_type = args.encoder
15 | decoder_type = args.decoder
16 | mask_rate = args.mask_rate
17 | remask_rate = args.remask_rate
18 | mask_method = args.mask_method
19 | drop_edge_rate = args.drop_edge_rate
20 |
21 | activation = args.activation
22 | loss_fn = args.loss_fn
23 | alpha_l = args.alpha_l
24 |
25 | num_features = args.num_features
26 | num_dec_layers = args.num_dec_layers
27 | num_remasking = args.num_remasking
28 | lam = args.lam
29 | delayed_ema_epoch = args.delayed_ema_epoch
30 | replace_rate = args.replace_rate
31 | remask_method = args.remask_method
32 | momentum = args.momentum
33 | zero_init = args.dataset in ("cora", "pubmed", "citeseer")
34 |
35 | model = PreModel(
36 | in_dim=num_features,
37 | num_hidden=num_hidden,
38 | num_layers=num_layers,
39 | num_dec_layers=num_dec_layers,
40 | num_remasking=num_remasking,
41 | nhead=num_heads,
42 | nhead_out=num_out_heads,
43 | activation=activation,
44 | feat_drop=in_drop,
45 | attn_drop=attn_drop,
46 | negative_slope=negative_slope,
47 | residual=residual,
48 | encoder_type=encoder_type,
49 | decoder_type=decoder_type,
50 | mask_rate=mask_rate,
51 | remask_rate=remask_rate,
52 | mask_method=mask_method,
53 | norm=norm,
54 | loss_fn=loss_fn,
55 | drop_edge_rate=drop_edge_rate,
56 | alpha_l=alpha_l,
57 | lam=lam,
58 | delayed_ema_epoch=delayed_ema_epoch,
59 | replace_rate=replace_rate,
60 | remask_method=remask_method,
61 | momentum=momentum,
62 | zero_init=zero_init,
63 | )
64 | return model
65 |
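`build_model` is a thin adapter from the argparse namespace to the `PreModel` constructor, so any object exposing the same attributes works. The sketch below spells out every attribute it reads; the values are illustrative only (1433 is Cora's input feature dimension), and the repository root is assumed to be importable:

```python
from argparse import Namespace

from models import build_model

args = Namespace(
    num_heads=4, num_out_heads=1, num_hidden=256, num_layers=2, residual=True,
    attn_drop=0.1, in_drop=0.2, norm="layernorm", negative_slope=0.2,
    encoder="gat", decoder="gat", mask_rate=0.5, remask_rate=0.5,
    mask_method="random", remask_method="random", drop_edge_rate=0.0,
    activation="prelu", loss_fn="sce", alpha_l=3,
    num_features=1433, num_dec_layers=1, num_remasking=3, lam=1.0,
    delayed_ema_epoch=0, replace_rate=0.0, momentum=0.996, dataset="cora",
)
model = build_model(args)
print(model.output_hidden_dim)   # == args.num_hidden
```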
--------------------------------------------------------------------------------
/models/edcoder.py:
--------------------------------------------------------------------------------
1 | from itertools import chain
2 |
3 | from typing import Optional
4 | import torch
5 | import torch.nn as nn
6 | from functools import partial
7 |
8 | from .gat import GAT
9 |
10 | from .loss_func import sce_loss
11 |
12 |
13 | def setup_module(m_type, enc_dec, in_dim, num_hidden, out_dim, num_layers, dropout, activation, residual, norm, nhead, nhead_out, attn_drop, negative_slope=0.2, concat_out=True, **kwargs) -> nn.Module:
14 | if m_type in ("gat", "tsgat"):
15 | mod = GAT(
16 | in_dim=in_dim,
17 | num_hidden=num_hidden,
18 | out_dim=out_dim,
19 | num_layers=num_layers,
20 | nhead=nhead,
21 | nhead_out=nhead_out,
22 | concat_out=concat_out,
23 | activation=activation,
24 | feat_drop=dropout,
25 | attn_drop=attn_drop,
26 | negative_slope=negative_slope,
27 | residual=residual,
28 | norm=norm,
29 | encoding=(enc_dec == "encoding"),
30 | **kwargs,
31 | )
32 | elif m_type == "mlp":
33 | # * just for decoder
34 | mod = nn.Sequential(
35 | nn.Linear(in_dim, num_hidden * 2),
36 | nn.PReLU(),
37 | nn.Dropout(0.2),
38 | nn.Linear(num_hidden * 2, out_dim)
39 | )
40 | elif m_type == "linear":
41 | mod = nn.Linear(in_dim, out_dim)
42 | else:
43 | raise NotImplementedError
44 |
45 | return mod
46 |
47 | class PreModel(nn.Module):
48 | def __init__(
49 | self,
50 | in_dim: int,
51 | num_hidden: int,
52 | num_layers: int,
53 | num_dec_layers: int,
54 | num_remasking: int,
55 | nhead: int,
56 | nhead_out: int,
57 | activation: str,
58 | feat_drop: float,
59 | attn_drop: float,
60 | negative_slope: float,
61 | residual: bool,
62 | norm: Optional[str],
63 | mask_rate: float = 0.3,
64 | remask_rate: float = 0.5,
65 | remask_method: str = "random",
66 | mask_method: str = "random",
67 | encoder_type: str = "gat",
68 | decoder_type: str = "gat",
69 | loss_fn: str = "byol",
70 | drop_edge_rate: float = 0.0,
71 | alpha_l: float = 2,
72 | lam: float = 1.0,
73 | delayed_ema_epoch: int = 0,
74 | momentum: float = 0.996,
75 | replace_rate: float = 0.0,
76 | zero_init: bool = False,
77 | ):
78 | super(PreModel, self).__init__()
79 | self._mask_rate = mask_rate
80 | self._remask_rate = remask_rate
81 | self._mask_method = mask_method
82 | self._alpha_l = alpha_l
83 | self._delayed_ema_epoch = delayed_ema_epoch
84 |
85 | self.num_remasking = num_remasking
86 | self._encoder_type = encoder_type
87 | self._decoder_type = decoder_type
88 | self._drop_edge_rate = drop_edge_rate
89 | self._output_hidden_size = num_hidden
90 | self._momentum = momentum
91 | self._replace_rate = replace_rate
92 | self._num_remasking = num_remasking
93 | self._remask_method = remask_method
94 |
95 | self._token_rate = 1 - self._replace_rate
96 | self._lam = lam
97 |
98 | assert num_hidden % nhead == 0
99 | assert num_hidden % nhead_out == 0
100 | if encoder_type in ("gat",):
101 | enc_num_hidden = num_hidden // nhead
102 | enc_nhead = nhead
103 | else:
104 | enc_num_hidden = num_hidden
105 | enc_nhead = 1
106 |
107 | dec_in_dim = num_hidden
108 | dec_num_hidden = num_hidden // nhead if decoder_type in ("gat",) else num_hidden
109 |
110 | # build encoder
111 | self.encoder = setup_module(
112 | m_type=encoder_type,
113 | enc_dec="encoding",
114 | in_dim=in_dim,
115 | num_hidden=enc_num_hidden,
116 | out_dim=enc_num_hidden,
117 | num_layers=num_layers,
118 | nhead=enc_nhead,
119 | nhead_out=enc_nhead,
120 | concat_out=True,
121 | activation=activation,
122 | dropout=feat_drop,
123 | attn_drop=attn_drop,
124 | negative_slope=negative_slope,
125 | residual=residual,
126 | norm=norm,
127 | )
128 |
129 | self.decoder = setup_module(
130 | m_type=decoder_type,
131 | enc_dec="decoding",
132 | in_dim=dec_in_dim,
133 | num_hidden=dec_num_hidden,
134 | out_dim=in_dim,
135 | nhead_out=nhead_out,
136 | num_layers=num_dec_layers,
137 | nhead=nhead,
138 | activation=activation,
139 | dropout=feat_drop,
140 | attn_drop=attn_drop,
141 | negative_slope=negative_slope,
142 | residual=residual,
143 | norm=norm,
144 | concat_out=True,
145 | )
146 |
147 | self.enc_mask_token = nn.Parameter(torch.zeros(1, in_dim))
148 | self.dec_mask_token = nn.Parameter(torch.zeros(1, num_hidden))
149 |
150 | self.encoder_to_decoder = nn.Linear(dec_in_dim, dec_in_dim, bias=False)
151 |
152 | if not zero_init:
153 | self.reset_parameters_for_token()
154 |
155 |
156 | # * setup loss function
157 | self.criterion = self.setup_loss_fn(loss_fn, alpha_l)
158 |
159 | self.projector = nn.Sequential(
160 | nn.Linear(num_hidden, 256),
161 | nn.PReLU(),
162 | nn.Linear(256, num_hidden),
163 | )
164 | self.projector_ema = nn.Sequential(
165 | nn.Linear(num_hidden, 256),
166 | nn.PReLU(),
167 | nn.Linear(256, num_hidden),
168 | )
169 | self.predictor = nn.Sequential(
170 | nn.PReLU(),
171 | nn.Linear(num_hidden, num_hidden)
172 | )
173 |
174 | self.encoder_ema = setup_module(
175 | m_type=encoder_type,
176 | enc_dec="encoding",
177 | in_dim=in_dim,
178 | num_hidden=enc_num_hidden,
179 | out_dim=enc_num_hidden,
180 | num_layers=num_layers,
181 | nhead=enc_nhead,
182 | nhead_out=enc_nhead,
183 | concat_out=True,
184 | activation=activation,
185 | dropout=feat_drop,
186 | attn_drop=attn_drop,
187 | negative_slope=negative_slope,
188 | residual=residual,
189 | norm=norm,
190 | )
191 | self.encoder_ema.load_state_dict(self.encoder.state_dict())
192 | self.projector_ema.load_state_dict(self.projector.state_dict())
193 |
194 | for p in self.encoder_ema.parameters():
195 | p.requires_grad = False
196 | p.detach_()
197 | for p in self.projector_ema.parameters():
198 | p.requires_grad = False
199 | p.detach_()
200 |
201 | self.print_num_parameters()
202 |
203 | def print_num_parameters(self):
204 | num_encoder_params = [p.numel() for p in self.encoder.parameters() if p.requires_grad]
205 | num_decoder_params = [p.numel() for p in self.decoder.parameters() if p.requires_grad]
206 | num_params = [p.numel() for p in self.parameters() if p.requires_grad]
207 |
208 | print(f"num_encoder_params: {sum(num_encoder_params)}, num_decoder_params: {sum(num_decoder_params)}, num_params_in_total: {sum(num_params)}")
209 |
210 | def reset_parameters_for_token(self):
211 | nn.init.xavier_normal_(self.enc_mask_token)
212 | nn.init.xavier_normal_(self.dec_mask_token)
213 | nn.init.xavier_normal_(self.encoder_to_decoder.weight, gain=1.414)
214 |
215 | @property
216 | def output_hidden_dim(self):
217 | return self._output_hidden_size
218 |
219 | def setup_loss_fn(self, loss_fn, alpha_l):
220 | if loss_fn == "mse":
221 | print(f"=== Use mse_loss ===")
222 | criterion = nn.MSELoss()
223 | elif loss_fn == "sce":
224 | print(f"=== Use sce_loss and alpha_l={alpha_l} ===")
225 | criterion = partial(sce_loss, alpha=alpha_l)
226 | else:
227 | raise NotImplementedError
228 | return criterion
229 |
230 | def forward(self, g, x, targets=None, epoch=0, drop_g1=None, drop_g2=None): # ---- attribute reconstruction ----
231 | loss = self.mask_attr_prediction(g, x, targets, epoch, drop_g1, drop_g2)
232 |
233 | return loss
234 |
235 | def mask_attr_prediction(self, g, x, targets, epoch, drop_g1=None, drop_g2=None):
236 | pre_use_g, use_x, (mask_nodes, keep_nodes) = self.encoding_mask_noise(g, x, self._mask_rate)
237 | use_g = drop_g1 if drop_g1 is not None else g
238 |
239 | enc_rep = self.encoder(use_g, use_x,)
240 |
241 | with torch.no_grad():
242 | drop_g2 = drop_g2 if drop_g2 is not None else g
243 | latent_target = self.encoder_ema(drop_g2, x,)
244 | if targets is not None:
245 | latent_target = self.projector_ema(latent_target[targets])
246 | else:
247 | latent_target = self.projector_ema(latent_target[keep_nodes])
248 |
249 | if targets is not None:
250 | latent_pred = self.projector(enc_rep[targets])
251 | latent_pred = self.predictor(latent_pred)
252 | loss_latent = sce_loss(latent_pred, latent_target, 1)
253 | else:
254 | latent_pred = self.projector(enc_rep[keep_nodes])
255 | latent_pred = self.predictor(latent_pred)
256 | loss_latent = sce_loss(latent_pred, latent_target, 1)
257 |
258 | # ---- attribute reconstruction ----
259 | origin_rep = self.encoder_to_decoder(enc_rep)
260 |
261 | loss_rec_all = 0
262 | if self._remask_method == "random":
263 | for i in range(self._num_remasking):
264 | rep = origin_rep.clone()
265 | rep, remask_nodes, rekeep_nodes = self.random_remask(use_g, rep, self._remask_rate)
266 | recon = self.decoder(pre_use_g, rep)
267 |
268 | x_init = x[mask_nodes]
269 | x_rec = recon[mask_nodes]
270 | loss_rec = self.criterion(x_init, x_rec)
271 | loss_rec_all += loss_rec
272 | loss_rec = loss_rec_all
273 | elif self._remask_method == "fixed":
274 | rep = self.fixed_remask(g, origin_rep, mask_nodes)
275 | x_rec = self.decoder(pre_use_g, rep)[mask_nodes]
276 | x_init = x[mask_nodes]
277 | loss_rec = self.criterion(x_init, x_rec)
278 | else:
279 | raise NotImplementedError
280 |
281 | loss = loss_rec + self._lam * loss_latent
282 |
283 | if epoch >= self._delayed_ema_epoch:
284 | self.ema_update()
285 | return loss
286 |
287 | def ema_update(self):
288 | def update(student, teacher):
289 | with torch.no_grad():
290 | # m = momentum_schedule[it] # momentum parameter
291 | m = self._momentum
292 | for param_q, param_k in zip(student.parameters(), teacher.parameters()):
293 | param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)
294 | update(self.encoder, self.encoder_ema)
295 | update(self.projector, self.projector_ema)
296 |
297 | def embed(self, g, x):
298 | rep = self.encoder(g, x)
299 | return rep
300 |
301 | def get_encoder(self):
302 | #self.encoder.reset_classifier(out_size)
303 | return self.encoder
304 |
305 | def reset_encoder(self, out_size):
306 | self.encoder.reset_classifier(out_size)
307 |
308 | @property
309 | def enc_params(self):
310 | return self.encoder.parameters()
311 |
312 | @property
313 | def dec_params(self):
314 | return chain(*[self.encoder_to_decoder.parameters(), self.decoder.parameters()])
315 |
316 | def output_grad(self):
317 | grad_dict = {}
318 | for n, p in self.named_parameters():
319 | if p.grad is not None:
320 | grad_dict[n] = p.grad.abs().mean().item()
321 | return grad_dict
322 |
323 | def encoding_mask_noise(self, g, x, mask_rate=0.3):
324 | num_nodes = g.num_nodes()
325 | perm = torch.randperm(num_nodes, device=x.device)
326 | num_mask_nodes = int(mask_rate * num_nodes)
327 |
328 | # exclude isolated nodes
329 | # isolated_nodes = torch.where(g.in_degrees() <= 1)[0]
330 | # mask_nodes = perm[: num_mask_nodes]
331 | # mask_nodes = torch.index_fill(torch.full((num_nodes,), False, device=device), 0, mask_nodes, True)
332 | # mask_nodes[isolated_nodes] = False
333 | # keep_nodes = torch.where(~mask_nodes)[0]
334 | # mask_nodes = torch.where(mask_nodes)[0]
335 | # num_mask_nodes = mask_nodes.shape[0]
336 |
337 | # random masking
338 | num_mask_nodes = int(mask_rate * num_nodes)
339 | mask_nodes = perm[: num_mask_nodes]
340 | keep_nodes = perm[num_mask_nodes: ]
341 |
342 | out_x = x.clone()
343 | token_nodes = mask_nodes
344 | out_x[mask_nodes] = 0.0
345 |
346 | out_x[token_nodes] += self.enc_mask_token
347 | use_g = g.clone()
348 |
349 | return use_g, out_x, (mask_nodes, keep_nodes)
350 |
351 | def random_remask(self,g,rep,remask_rate=0.5):
352 |
353 | num_nodes = g.num_nodes()
354 | perm = torch.randperm(num_nodes, device=rep.device)
355 | num_remask_nodes = int(remask_rate * num_nodes)
356 | remask_nodes = perm[: num_remask_nodes]
357 | rekeep_nodes = perm[num_remask_nodes: ]
358 |
359 | rep = rep.clone()
360 | rep[remask_nodes] = 0
361 | rep[remask_nodes] += self.dec_mask_token
362 |
363 | return rep, remask_nodes, rekeep_nodes
364 |
365 | def fixed_remask(self, g, rep, masked_nodes):
366 | rep[masked_nodes] = 0
367 | return rep
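To make the pretraining interface concrete, here is a minimal, self-contained sketch (not part of the repository) that runs `PreModel` on a random DGL graph; every hyperparameter value is illustrative, and DGL/PyTorch plus an importable repository root are assumed:

```python
import dgl
import torch

from models.edcoder import PreModel

g = dgl.add_self_loop(dgl.rand_graph(100, 500))   # self-loops avoid the 0-in-degree check in GATConv
x = torch.randn(100, 32)                          # toy node features

model = PreModel(
    in_dim=32, num_hidden=64, num_layers=2, num_dec_layers=1, num_remasking=2,
    nhead=4, nhead_out=1, activation="prelu", feat_drop=0.2, attn_drop=0.1,
    negative_slope=0.2, residual=True, norm=None,
    mask_rate=0.5, remask_rate=0.5, remask_method="random", mask_method="random",
    encoder_type="gat", decoder_type="gat", loss_fn="sce", alpha_l=3,
    lam=1.0, delayed_ema_epoch=0, momentum=0.996,
)
loss = model(g, x)    # loss_rec + lam * loss_latent, as composed in mask_attr_prediction
loss.backward()
```

The returned scalar combines the remasked feature-reconstruction loss with the latent-prediction loss against the momentum (EMA) encoder; after each forward call past `delayed_ema_epoch`, `ema_update` moves the EMA encoder and projector towards the student with factor `momentum`.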
--------------------------------------------------------------------------------
/models/finetune.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import logging
3 |
4 | from tqdm import tqdm
5 | import numpy as np
6 |
7 | import torch
8 | import torch.nn as nn
9 | from torch.utils.data import DataLoader
10 |
11 | from datasets.lc_sampler import setup_eval_dataloder, setup_finetune_dataloder, LinearProbingDataLoader
12 | from utils import accuracy, set_random_seed, show_occupied_memory, get_current_lr
13 |
14 | import wandb
15 |
16 |
17 | def linear_probing_minibatch(
18 | model, graph,
19 | feats, ego_graph_nodes, labels,
20 | lr_f, weight_decay_f, max_epoch_f,
21 | device, batch_size=-1, shuffle=True):
22 | logging.info("-- Linear Probing in downstream tasks ---")
23 | train_ego_graph_nodes, val_ego_graph_nodes, test_ego_graph_nodes = ego_graph_nodes
24 | num_train, num_val = len(train_ego_graph_nodes), len(val_ego_graph_nodes)
25 | train_lbls, val_lbls, test_lbls = labels
26 | # if dataset_name in ["ogbn-papers100M", "mag-scholar-f", "mag-scholar-c","ogbn-arxiv","ogbn-products"]:
27 | # if dataset_name in ["ogbn-papers100M", "mag-scholar-f", "mag-scholar-c", "ogbn-arxiv", "ogbn-products"]:
28 | eval_loader = setup_eval_dataloder("lc", graph, feats, train_ego_graph_nodes+val_ego_graph_nodes+test_ego_graph_nodes, 512)
29 |
30 | with torch.no_grad():
31 | model.eval()
32 | embeddings = []
33 |
34 |         for batch in tqdm(eval_loader, desc="Inferring..."):
35 | batch_g, targets, _, node_idx = batch
36 | batch_g = batch_g.to(device)
37 | x = batch_g.ndata.pop("feat")
38 | targets = targets.to(device)
39 |
40 | batch_emb = model.embed(batch_g, x)[targets]
41 | embeddings.append(batch_emb.cpu())
42 | embeddings = torch.cat(embeddings, dim=0)
43 |
44 | train_emb, val_emb, test_emb = embeddings[:num_train], embeddings[num_train:num_train+num_val], embeddings[num_train+num_val:]
45 |
46 | batch_size = 5120
47 | acc = []
48 | seeds = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
49 | for i,_ in enumerate(seeds):
50 | print(f"####### Run seed {seeds[i]} for LinearProbing...")
51 | set_random_seed(seeds[i])
52 | print(f"training sample:{len(train_emb)}")
53 | test_acc = node_classification_linear_probing(
54 | (train_emb, val_emb, test_emb),
55 | (train_lbls, val_lbls, test_lbls),
56 | lr_f, weight_decay_f, max_epoch_f, device, batch_size=batch_size, shuffle=shuffle)
57 | acc.append(test_acc)
58 |
59 | print(f"# final_acc: {np.mean(acc):.4f}, std: {np.std(acc):.4f}")
60 |
61 | return np.mean(acc)
62 |
63 |
64 |
65 | class LogisticRegression(nn.Module):
66 | def __init__(self, num_dim, num_class):
67 | super().__init__()
68 | self.linear = nn.Linear(num_dim, num_class)
69 |
70 | def forward(self, g, x, *args):
71 | logits = self.linear(x)
72 | return logits
73 |
74 |
75 | def node_classification_linear_probing(embeddings, labels, lr, weight_decay, max_epoch, device, mute=False, batch_size=-1, shuffle=True):
76 | criterion = torch.nn.CrossEntropyLoss()
77 |
78 | train_emb, val_emb, test_emb = embeddings
79 | train_label, val_label, test_label = labels
80 | train_label = train_label.to(torch.long)
81 | val_label = val_label.to(torch.long)
82 | test_label = test_label.to(torch.long)
83 |
84 | best_val_acc = 0
85 | best_val_epoch = 0
86 | best_model = None
87 |
88 | if not mute:
89 | epoch_iter = tqdm(range(max_epoch))
90 | else:
91 | epoch_iter = range(max_epoch)
92 |
93 | encoder = LogisticRegression(train_emb.shape[1], int(train_label.max().item() + 1))
94 | encoder = encoder.to(device)
95 | optimizer = torch.optim.Adam(encoder.parameters(), lr=lr, weight_decay=weight_decay)
96 |
97 | if batch_size > 0:
98 | train_loader = LinearProbingDataLoader(np.arange(len(train_emb)), train_emb, train_label, batch_size=batch_size, num_workers=4, persistent_workers=True, shuffle=shuffle)
99 | # train_loader = DataLoader(np.arange(len(train_emb)), batch_size=batch_size, shuffle=False)
100 | val_loader = LinearProbingDataLoader(np.arange(len(val_emb)), val_emb, val_label, batch_size=batch_size, num_workers=4, persistent_workers=True,shuffle=False)
101 | test_loader = LinearProbingDataLoader(np.arange(len(test_emb)), test_emb, test_label, batch_size=batch_size, num_workers=4, persistent_workers=True,shuffle=False)
102 |     else:
103 |         train_loader = [(train_emb, train_label)]  # full-batch: a single (x, y) pair per split
104 |         val_loader = [(val_emb, val_label)]
105 |         test_loader = [(test_emb, test_label)]
106 |
107 | def eval_forward(loader, _label):
108 | pred_all = []
109 | for batch_x, _ in loader:
110 | batch_x = batch_x.to(device)
111 | pred = encoder(None, batch_x)
112 | pred_all.append(pred.cpu())
113 | pred = torch.cat(pred_all, dim=0)
114 | acc = accuracy(pred, _label)
115 | return acc
116 |
117 | for epoch in epoch_iter:
118 | encoder.train()
119 |
120 | for batch_x, batch_label in train_loader:
121 | batch_x = batch_x.to(device)
122 | batch_label = batch_label.to(device)
123 | pred = encoder(None, batch_x)
124 | loss = criterion(pred, batch_label)
125 | optimizer.zero_grad()
126 | loss.backward()
127 | # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=3)
128 | optimizer.step()
129 |
130 | with torch.no_grad():
131 | encoder.eval()
132 | val_acc = eval_forward(val_loader, val_label)
133 |
134 | if val_acc >= best_val_acc:
135 | best_val_acc = val_acc
136 | best_val_epoch = epoch
137 | best_model = copy.deepcopy(encoder)
138 |
139 | if not mute:
140 | epoch_iter.set_description(f"# Epoch: {epoch}, train_loss:{loss.item(): .4f}, val_acc:{val_acc:.4f}")
141 |
142 | best_model.eval()
143 | encoder = best_model
144 | with torch.no_grad():
145 | test_acc = eval_forward(test_loader, test_label)
146 | if mute:
147 | print(f"# IGNORE: --- TestAcc: {test_acc:.4f}, Best ValAcc: {best_val_acc:.4f} in epoch {best_val_epoch} --- ")
148 | else:
149 | print(f"--- TestAcc: {test_acc:.4f}, Best ValAcc: {best_val_acc:.4f} in epoch {best_val_epoch} --- ")
150 |
151 | return test_acc
152 |
153 |
154 | def finetune(
155 | model,
156 | graph,
157 | feats,
158 | ego_graph_nodes,
159 | labels,
160 | split_idx,
161 | lr_f, weight_decay_f, max_epoch_f,
162 | use_scheduler, batch_size,
163 | device,
164 | logger=None,
165 | full_graph_forward=False,
166 | ):
167 | logging.info("-- Finetuning in downstream tasks ---")
168 | train_egs, val_egs, test_egs = ego_graph_nodes
169 | print(f"num of egos:{len(train_egs)},{len(val_egs)},{len(test_egs)}")
170 |
171 | print(graph.num_nodes())
172 |
173 | train_nid = split_idx["train"].numpy()
174 | val_nid = split_idx["valid"].numpy()
175 | test_nid = split_idx["test"].numpy()
176 |
177 | train_lbls, val_lbls, test_lbls = [x.long() for x in labels]
178 | print(f"num of labels:{len(train_lbls)},{len(val_lbls)},{len(test_lbls)}")
179 |
180 | num_classes = max(max(train_lbls.max().item(), val_lbls.max().item()), test_lbls.max().item()) + 1
181 |
182 | model = model.get_encoder()
183 | model.reset_classifier(int(num_classes))
184 | model = model.to(device)
185 | criterion = torch.nn.CrossEntropyLoss()
186 |
187 | train_loader = setup_finetune_dataloder("lc", graph, feats, train_egs, train_lbls, batch_size=batch_size, shuffle=True)
188 | val_loader = setup_finetune_dataloder("lc", graph, feats, val_egs, val_lbls, batch_size=batch_size, shuffle=False)
189 | test_loader = setup_finetune_dataloder("lc", graph, feats, test_egs, test_lbls, batch_size=batch_size, shuffle=False)
190 |
191 | #optimizer = torch.optim.Adam(model.parameters(), lr=lr_f, weight_decay=weight_decay_f)
192 |
193 | optimizer = torch.optim.AdamW(model.parameters(), lr=lr_f, weight_decay=weight_decay_f)
194 |
195 | if use_scheduler and max_epoch_f > 0:
196 | logging.info("Use schedular")
197 | warmup_epochs = int(max_epoch_f * 0.1)
198 | # scheduler = lambda epoch :( 1 + np.cos((epoch) * np.pi / max_epoch_f) ) * 0.5
199 | scheduler = lambda epoch: epoch / warmup_epochs if epoch < warmup_epochs else ( 1 + np.cos((epoch - warmup_epochs) * np.pi / (max_epoch_f - warmup_epochs))) * 0.5
200 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler)
201 | else:
202 | scheduler = None
203 |
204 | def eval_with_lc(model, loader):
205 | pred_counts = []
206 | model.eval()
207 | epoch_iter = tqdm(loader)
208 | with torch.no_grad():
209 | for batch in epoch_iter:
210 | batch_g, targets, batch_lbls, node_idx = batch
211 | batch_g = batch_g.to(device)
212 | batch_lbls = batch_lbls.to(device)
213 | x = batch_g.ndata.pop("feat")
214 |
215 | prediction = model(batch_g, x)
216 | prediction = prediction[targets]
217 | pred_counts.append((prediction.argmax(1) == batch_lbls))
218 | pred_counts = torch.cat(pred_counts)
219 | acc = pred_counts.float().sum() / pred_counts.shape[0]
220 | return acc
221 |
222 | def eval_full_prop(model, g, nfeat, val_nid, test_nid, batch_size, device):
223 | model.eval()
224 |
225 | with torch.no_grad():
226 | pred = model.inference(g, nfeat, batch_size, device)
227 | model.train()
228 |
229 | return accuracy(pred[val_nid], val_lbls.cpu()), accuracy(pred[test_nid], test_lbls.cpu())
230 |
231 | best_val_acc = 0
232 | best_model = None
233 | best_epoch = 0
234 | test_acc = 0
235 | early_stop_cnt = 0
236 |
237 | for epoch in range(max_epoch_f):
238 | if epoch == 0:
239 |             if scheduler is not None: scheduler.step()
240 | continue
241 | if early_stop_cnt >= 10:
242 | break
243 | epoch_iter = tqdm(train_loader)
244 | losses = []
245 | model.train()
246 |
247 | for batch_g, targets, batch_lbls, node_idx in epoch_iter:
248 | batch_g = batch_g.to(device)
249 | targets = targets.to(device)
250 | batch_lbls = batch_lbls.to(device)
251 | x = batch_g.ndata.pop("feat")
252 |
253 | prediction = model(batch_g, x)
254 | prediction = prediction[targets]
255 | loss = criterion(prediction, batch_lbls)
256 |
257 | optimizer.zero_grad()
258 | loss.backward()
259 |
260 | torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
261 | optimizer.step()
262 |
263 | metrics = {"finetune_loss": loss}
264 | wandb.log(metrics)
265 |
266 | if logger is not None:
267 | logger.log(metrics)
268 |
269 | epoch_iter.set_description(f"Finetuning | train_loss: {loss.item():.4f}, Memory: {show_occupied_memory():.2f} MB")
270 | losses.append(loss.item())
271 |
272 | if scheduler is not None:
273 | scheduler.step()
274 |
275 | if not full_graph_forward:
276 | if epoch > 0:
277 | val_acc = eval_with_lc(model, val_loader)
278 | _test_acc = 0
279 | else:
280 | if epoch > 0 and epoch % 1 == 0:
281 | val_acc, _test_acc = eval_full_prop(model, graph, feats, val_nid, test_nid, 10000, device)
282 | model = model.to(device)
283 |
284 | print('val Acc {:.4f}'.format(val_acc))
285 | if val_acc > best_val_acc:
286 | best_model = copy.deepcopy(model)
287 | best_val_acc = val_acc
288 | test_acc = _test_acc
289 | best_epoch = epoch
290 | early_stop_cnt = 0
291 | else:
292 | early_stop_cnt += 1
293 |
294 | if not full_graph_forward:
295 | print("val Acc {:.4f}, Best Val Acc {:.4f}".format(val_acc, best_val_acc))
296 | else:
297 | print("Val Acc {:.4f}, Best Val Acc {:.4f} Test Acc {:.4f}".format(val_acc, best_val_acc, test_acc))
298 |
299 | metrics = {"epoch_val_acc": val_acc,
300 | "test_acc": test_acc,
301 | "epoch": epoch,
302 | "lr_f": get_current_lr(optimizer)}
303 |
304 | wandb.log(metrics)
305 | if logger is not None:
306 | logger.log(metrics)
307 | print(f"# Finetuning - Epoch {epoch} | train_loss: {np.mean(losses):.4f}, ValAcc: {val_acc:.4f}, TestAcc: {test_acc:.4f}, Memory: {show_occupied_memory():.2f} MB")
308 |
309 | model = best_model
310 | if not full_graph_forward:
311 |         test_acc = eval_with_lc(model, test_loader)
312 |
313 | print(f"Finetune | TestAcc: {test_acc:.4f} from Epoch {best_epoch}")
314 | return test_acc
315 |
316 |
317 | def linear_probing_full_batch(model, graph, x, num_classes, lr_f, weight_decay_f, max_epoch_f, device, linear_prob=True, mute=False):
318 | model.eval()
319 | with torch.no_grad():
320 | x = model.embed(graph.to(device), x.to(device))
321 | in_feat = x.shape[1]
322 | encoder = LogisticRegression(in_feat, num_classes)
323 |
324 | num_finetune_params = [p.numel() for p in encoder.parameters() if p.requires_grad]
325 | if not mute:
326 | print(f"num parameters for finetuning: {sum(num_finetune_params)}")
327 |
328 | encoder.to(device)
329 | optimizer_f = torch.optim.Adam(encoder.parameters(), lr=lr_f, weight_decay=weight_decay_f)
330 | final_acc, estp_acc = _linear_probing_full_batch(encoder, graph, x, optimizer_f, max_epoch_f, device, mute)
331 | return final_acc, estp_acc
332 |
333 |
334 | def _linear_probing_full_batch(model, graph, feat, optimizer, max_epoch, device, mute=False):
335 | criterion = torch.nn.CrossEntropyLoss()
336 |
337 | graph = graph.to(device)
338 | x = feat.to(device)
339 |
340 | train_mask = graph.ndata["train_mask"]
341 | val_mask = graph.ndata["val_mask"]
342 | test_mask = graph.ndata["test_mask"]
343 | labels = graph.ndata["label"]
344 |
345 | best_val_acc = 0
346 | best_val_epoch = 0
347 | best_model = None
348 |
349 | if not mute:
350 | epoch_iter = tqdm(range(max_epoch))
351 | else:
352 | epoch_iter = range(max_epoch)
353 |
354 | for epoch in epoch_iter:
355 | model.train()
356 | out = model(graph, x)
357 | loss = criterion(out[train_mask], labels[train_mask])
358 | optimizer.zero_grad()
359 | loss.backward()
360 | # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=3)
361 | optimizer.step()
362 |
363 | with torch.no_grad():
364 | model.eval()
365 | pred = model(graph, x)
366 | val_acc = accuracy(pred[val_mask], labels[val_mask])
367 | val_loss = criterion(pred[val_mask], labels[val_mask])
368 | test_acc = accuracy(pred[test_mask], labels[test_mask])
369 | test_loss = criterion(pred[test_mask], labels[test_mask])
370 |
371 | if val_acc >= best_val_acc:
372 | best_val_acc = val_acc
373 | best_val_epoch = epoch
374 | best_model = copy.deepcopy(model)
375 |
376 | if not mute:
377 | epoch_iter.set_description(f"# Epoch: {epoch}, train_loss:{loss.item(): .4f}, val_loss:{val_loss.item(): .4f}, val_acc:{val_acc}, test_loss:{test_loss.item(): .4f}, test_acc:{test_acc: .4f}")
378 |
379 | best_model.eval()
380 | with torch.no_grad():
381 | pred = best_model(graph, x)
382 | estp_test_acc = accuracy(pred[test_mask], labels[test_mask])
383 | if mute:
384 | print(f"# IGNORE: --- TestAcc: {test_acc:.4f}, early-stopping-TestAcc: {estp_test_acc:.4f}, Best ValAcc: {best_val_acc:.4f} in epoch {best_val_epoch} --- ")
385 | else:
386 | print(f"--- TestAcc: {test_acc:.4f}, early-stopping-TestAcc: {estp_test_acc:.4f}, Best ValAcc: {best_val_acc:.4f} in epoch {best_val_epoch} --- ")
387 |
388 | return test_acc, estp_test_acc
389 |
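As a quick illustration of the linear-probing entry point above, the sketch below feeds random embeddings and labels in place of `model.embed()` outputs; the shapes, class count, and hyperparameters are arbitrary, and the repository's dependencies (DGL, wandb, tqdm, ...) are assumed importable:

```python
import torch

from models.finetune import node_classification_linear_probing

emb = torch.randn(1000, 64)              # stand-in for encoder embeddings
lbl = torch.randint(0, 7, (1000,))       # stand-in labels for a 7-class task

embeddings = (emb[:600], emb[600:800], emb[800:])   # train / val / test splits
labels = (lbl[:600], lbl[600:800], lbl[800:])

test_acc = node_classification_linear_probing(
    embeddings, labels,
    lr=0.01, weight_decay=0.0, max_epoch=20,
    device="cpu", batch_size=256, shuffle=True,
)
print(f"toy test accuracy: {test_acc:.4f}")
```

On platforms that spawn DataLoader workers, wrap the call in an `if __name__ == "__main__":` guard, since `LinearProbingDataLoader` is created with `num_workers=4`.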
--------------------------------------------------------------------------------
/models/gat.py:
--------------------------------------------------------------------------------
1 | import tqdm
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | import dgl
7 | import dgl.function as fn
8 | from dgl.ops import edge_softmax
9 | from dgl.utils import expand_as_pair
10 |
11 |
12 | from utils import create_activation, create_norm
13 |
14 |
15 | class GAT(nn.Module):
16 | def __init__(self,
17 | in_dim,
18 | num_hidden,
19 | out_dim,
20 | num_layers,
21 | nhead,
22 | nhead_out,
23 | activation,
24 | feat_drop,
25 | attn_drop,
26 | negative_slope,
27 | residual,
28 | norm,
29 | concat_out=False,
30 | encoding=False,
31 | ):
32 | super(GAT, self).__init__()
33 | self.out_dim = out_dim
34 | self.num_heads = nhead
35 | self.num_heads_out = nhead_out
36 | self.num_hidden = num_hidden
37 | self.num_layers = num_layers
38 | self.gat_layers = nn.ModuleList()
39 | self.activation = activation
40 | self.concat_out = concat_out
41 |
42 | last_activation = create_activation(activation) if encoding else None
43 | last_residual = (encoding and residual)
44 | last_norm = norm if encoding else None
45 |
46 | hidden_in = in_dim
47 | hidden_out = out_dim
48 |
49 | if num_layers == 1:
50 | self.gat_layers.append(GATConv(
51 | hidden_in, hidden_out, nhead_out,
52 | feat_drop, attn_drop, negative_slope, last_residual, norm=last_norm, concat_out=concat_out))
53 | else:
54 | # input projection (no residual)
55 | self.gat_layers.append(GATConv(
56 | hidden_in, num_hidden, nhead,
57 | feat_drop, attn_drop, negative_slope, residual, create_activation(activation), norm=norm, concat_out=concat_out))
58 | # hidden layers
59 |
60 | for l in range(1, num_layers - 1):
61 | # due to multi-head, the in_dim = num_hidden * num_heads
62 | self.gat_layers.append(GATConv(
63 | num_hidden * nhead, num_hidden, nhead,
64 | feat_drop, attn_drop, negative_slope, residual, create_activation(activation), norm=norm, concat_out=concat_out))
65 |
66 | # output projection
67 | self.gat_layers.append(GATConv(
68 | num_hidden * nhead, hidden_out, nhead_out,
69 | feat_drop, attn_drop, negative_slope, last_residual, activation=last_activation, norm=last_norm, concat_out=concat_out))
70 | self.head = nn.Identity()
71 |
72 | def forward(self, g, inputs):
73 | h = inputs
74 |
75 | for l in range(self.num_layers):
76 | h = self.gat_layers[l](g, h)
77 |
78 | if self.head is not None:
79 | return self.head(h)
80 | else:
81 | return h
82 |
83 | def inference(self, g, x, batch_size, device, emb=False):
84 | """
85 | Inference with the GAT model on full neighbors (i.e. without neighbor sampling).
86 | g : the entire graph.
87 | x : the input of entire node set.
88 |         The inference code is written so that it can handle any number of nodes and
89 |         layers.
90 | """
91 | num_heads = self.num_heads
92 | num_heads_out = self.num_heads_out
93 | for l, layer in enumerate(self.gat_layers):
94 | if l < self.num_layers - 1:
95 | y = torch.zeros(g.num_nodes(), self.num_hidden * num_heads if l != len(self.gat_layers) - 1 else self.num_classes)
96 | else:
97 | if emb == False:
98 | y = torch.zeros(g.num_nodes(), self.num_hidden if l != len(self.gat_layers) - 1 else self.num_classes)
99 | else:
100 | y = torch.zeros(g.num_nodes(), self.out_dim * num_heads_out)
101 | sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
102 | dataloader = dgl.dataloading.DataLoader(
103 | g,
104 | torch.arange(g.num_nodes()),
105 | sampler,
106 | batch_size=batch_size,
107 | shuffle=False,
108 | drop_last=False,
109 | num_workers=8)
110 |
111 | for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
112 | block = blocks[0].int().to(device)
113 | h = x[input_nodes].to(device)
114 | if l < self.num_layers - 1:
115 | h = layer(block, h)
116 | else:
117 | h = layer(block, h)
118 |
119 |                 if l == len(self.gat_layers) - 1 and (emb == False):
120 |                     h = self.head(h)
121 |                 y[output_nodes] = h.cpu()
122 | x = y
123 | return y
124 |
125 | def reset_classifier(self, num_classes):
126 | self.num_classes = num_classes
127 | self.is_pretraining = False
128 | self.head = nn.Linear(self.num_heads * self.out_dim, num_classes)
129 |
130 |
131 |
132 | class GATConv(nn.Module):
133 | def __init__(self,
134 | in_feats,
135 | out_feats,
136 | num_heads,
137 | feat_drop=0.,
138 | attn_drop=0.,
139 | negative_slope=0.2,
140 | residual=False,
141 | activation=None,
142 | allow_zero_in_degree=False,
143 | bias=True,
144 | norm=None,
145 | concat_out=True):
146 | super(GATConv, self).__init__()
147 | self._num_heads = num_heads
148 | self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
149 | self._out_feats = out_feats
150 | self._allow_zero_in_degree = allow_zero_in_degree
151 | self._concat_out = concat_out
152 |
153 | if isinstance(in_feats, tuple):
154 | self.fc_src = nn.Linear(
155 | self._in_src_feats, out_feats * num_heads, bias=False)
156 | self.fc_dst = nn.Linear(
157 | self._in_dst_feats, out_feats * num_heads, bias=False)
158 | else:
159 | self.fc = nn.Linear(
160 | self._in_src_feats, out_feats * num_heads, bias=False)
161 | self.attn_l = nn.Parameter(torch.FloatTensor(size=(1, num_heads, out_feats)))
162 | self.attn_r = nn.Parameter(torch.FloatTensor(size=(1, num_heads, out_feats)))
163 | self.feat_drop = nn.Dropout(feat_drop)
164 | self.attn_drop = nn.Dropout(attn_drop)
165 | self.leaky_relu = nn.LeakyReLU(negative_slope)
166 | if bias:
167 | self.bias = nn.Parameter(torch.FloatTensor(size=(num_heads * out_feats,)))
168 | else:
169 | self.register_buffer('bias', None)
170 | if residual:
171 | if self._in_dst_feats != out_feats * num_heads:
172 | self.res_fc = nn.Linear(
173 | self._in_dst_feats, num_heads * out_feats, bias=False)
174 | else:
175 | self.res_fc = None
176 | else:
177 | self.register_buffer('res_fc', None)
178 | self.reset_parameters()
179 | self.activation = activation
180 |
181 | self.norm = norm
182 | if norm is not None:
183 | self.norm = create_norm(norm)(num_heads * out_feats)
184 | self.set_allow_zero_in_degree(False)
185 |
186 | def reset_parameters(self):
187 | """
188 |
189 | Description
190 | -----------
191 | Reinitialize learnable parameters.
192 |
193 | Note
194 | ----
195 | The fc weights :math:`W^{(l)}` are initialized using Glorot uniform initialization.
196 | The attention weights are using xavier initialization method.
197 | """
198 | gain = nn.init.calculate_gain('relu')
199 | if hasattr(self, 'fc'):
200 | nn.init.xavier_normal_(self.fc.weight, gain=gain)
201 | else:
202 | nn.init.xavier_normal_(self.fc_src.weight, gain=gain)
203 | nn.init.xavier_normal_(self.fc_dst.weight, gain=gain)
204 | nn.init.xavier_normal_(self.attn_l, gain=gain)
205 | nn.init.xavier_normal_(self.attn_r, gain=gain)
206 | if self.bias is not None:
207 | nn.init.constant_(self.bias, 0)
208 | if isinstance(self.res_fc, nn.Linear):
209 | nn.init.xavier_normal_(self.res_fc.weight, gain=gain)
210 |
211 | def set_allow_zero_in_degree(self, set_value):
212 | self._allow_zero_in_degree = set_value
213 |
214 | def forward(self, graph, feat, get_attention=False):
215 | with graph.local_scope():
216 | if not self._allow_zero_in_degree:
217 | if (graph.in_degrees() == 0).any():
218 | raise RuntimeError('There are 0-in-degree nodes in the graph, '
219 | 'output for those nodes will be invalid. '
220 | 'This is harmful for some applications, '
221 | 'causing silent performance regression. '
222 | 'Adding self-loop on the input graph by '
223 | 'calling `g = dgl.add_self_loop(g)` will resolve '
224 | 'the issue. Setting ``allow_zero_in_degree`` '
225 | 'to be `True` when constructing this module will '
226 | 'suppress the check and let the code run.')
227 |
228 | if isinstance(feat, tuple):
229 | src_prefix_shape = feat[0].shape[:-1]
230 | dst_prefix_shape = feat[1].shape[:-1]
231 | h_src = self.feat_drop(feat[0])
232 | # h_dst = self.feat_drop(feat[1])
233 | h_dst = feat[1]
234 |
235 | if not hasattr(self, 'fc_src'):
236 | feat_src = self.fc(h_src).view(
237 | *src_prefix_shape, self._num_heads, self._out_feats)
238 | feat_dst = self.fc(h_dst).view(
239 | *dst_prefix_shape, self._num_heads, self._out_feats)
240 | else:
241 | feat_src = self.fc_src(h_src).view(
242 | *src_prefix_shape, self._num_heads, self._out_feats)
243 | feat_dst = self.fc_dst(h_dst).view(
244 | *dst_prefix_shape, self._num_heads, self._out_feats)
245 | else:
246 | src_prefix_shape = dst_prefix_shape = feat.shape[:-1]
247 | h_src = h_dst = self.feat_drop(feat)
248 | feat_src = feat_dst = self.fc(h_src).view(
249 | *src_prefix_shape, self._num_heads, self._out_feats)
250 | if graph.is_block:
251 | feat_dst = feat_src[:graph.number_of_dst_nodes()]
252 | h_dst = h_dst[:graph.number_of_dst_nodes()]
253 | dst_prefix_shape = (graph.number_of_dst_nodes(),) + dst_prefix_shape[1:]
254 | # NOTE: GAT paper uses "first concatenation then linear projection"
255 | # to compute attention scores, while ours is "first projection then
256 | # addition", the two approaches are mathematically equivalent:
257 | # We decompose the weight vector a mentioned in the paper into
258 | # [a_l || a_r], then
259 | # a^T [Wh_i || Wh_j] = a_l Wh_i + a_r Wh_j
260 |             # Our implementation is much more efficient because we do not need to
261 | # save [Wh_i || Wh_j] on edges, which is not memory-efficient. Plus,
262 | # addition could be optimized with DGL's built-in function u_add_v,
263 | # which further speeds up computation and saves memory footprint.
264 | el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1)
265 | er = (feat_dst * self.attn_r).sum(dim=-1).unsqueeze(-1)
266 | graph.srcdata.update({'ft': feat_src, 'el': el})
267 | graph.dstdata.update({'er': er})
268 | # compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively.
269 | graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
270 | e = self.leaky_relu(graph.edata.pop('e'))
271 | # e[e == 0] = -1e3
272 | # e = graph.edata.pop('e')
273 | # compute softmax
274 | graph.edata['a'] = self.attn_drop(edge_softmax(graph, e))
275 | # message passing
276 | graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
277 | fn.sum('m', 'ft'))
278 | rst = graph.dstdata['ft']
279 |
280 | # bias
281 | if self.bias is not None:
282 | rst = rst + self.bias.view(
283 | *((1,) * len(dst_prefix_shape)), self._num_heads, self._out_feats)
284 |
285 | # residual
286 | if self.res_fc is not None:
287 | # Use -1 rather than self._num_heads to handle broadcasting
288 | resval = self.res_fc(h_dst).view(*dst_prefix_shape, -1, self._out_feats)
289 | rst = rst + resval
290 |
291 | if self._concat_out:
292 | rst = rst.flatten(1)
293 | else:
294 | rst = torch.mean(rst, dim=1)
295 |
296 | if self.norm is not None:
297 | rst = self.norm(rst)
298 |
299 | # activation
300 | if self.activation:
301 | rst = self.activation(rst)
302 |
303 | if get_attention:
304 | return rst, graph.edata['a']
305 | else:
306 | return rst
307 |
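The NOTE inside `GATConv.forward` refers to the usual GAT attention with the weight vector split as $a = [a_l \,\Vert\, a_r]$. Per attention head, for an edge from source $j$ to destination $i$, the code above computes

$$
e_{ij} = \mathrm{LeakyReLU}\!\left(a_l^{\top} W h_j + a_r^{\top} W h_i\right),\qquad
\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}(i)} \exp(e_{ik})},\qquad
h_i' = \sum_{j \in \mathcal{N}(i)} \alpha_{ij}\, W h_j,
$$

which equals the paper's $\mathrm{LeakyReLU}\!\left(a^{\top}[W h_j \,\Vert\, W h_i]\right)$ without materialising the concatenated vector on every edge (`el`, `er`, and `u_add_v` above).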
--------------------------------------------------------------------------------
/models/gcn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import dgl
6 | import dgl.function as fn
7 | from dgl.utils import expand_as_pair
8 |
9 | from utils import create_activation, create_norm
10 |
11 |
12 | class GCN(nn.Module):
13 | def __init__(self,
14 | in_dim,
15 | num_hidden,
16 | out_dim,
17 | num_layers,
18 | dropout,
19 | activation,
20 | residual,
21 | norm,
22 | encoding=False
23 | ):
24 | super(GCN, self).__init__()
25 | self.out_dim = out_dim
26 | self.num_layers = num_layers
27 | self.gcn_layers = nn.ModuleList()
28 | self.activation = activation
29 | self.dropout = dropout
30 |
31 | last_activation = create_activation(activation) if encoding else None
32 | last_residual = encoding and residual
33 | last_norm = norm if encoding else None
34 |
35 | if num_layers == 1:
36 | self.gcn_layers.append(GraphConv(
37 | in_dim, out_dim, residual=last_residual, norm=last_norm, activation=last_activation))
38 | else:
39 | # input projection (no residual)
40 | self.gcn_layers.append(GraphConv(
41 | in_dim, num_hidden, residual=residual, norm=norm, activation=create_activation(activation)))
42 | # hidden layers
43 | for l in range(1, num_layers - 1):
44 | # due to multi-head, the in_dim = num_hidden * num_heads
45 | self.gcn_layers.append(GraphConv(
46 | num_hidden, num_hidden, residual=residual, norm=norm, activation=create_activation(activation)))
47 | # output projection
48 | self.gcn_layers.append(GraphConv(
49 | num_hidden, out_dim, residual=last_residual, activation=last_activation, norm=last_norm))
50 |
51 | # if norm is not None:
52 | # self.norms = nn.ModuleList([
53 | # norm(num_hidden)
54 | # for _ in range(num_layers - 1)
55 | # ])
56 | # if not encoding:
57 | # self.norms.append(norm(out_dim))
58 | # else:
59 | # self.norms = None
60 | self.norms = None
61 | self.head = nn.Identity()
62 |
63 | def forward(self, g, inputs, return_hidden=False):
64 | h = inputs
65 | hidden_list = []
66 | for l in range(self.num_layers):
67 | h = F.dropout(h, p=self.dropout, training=self.training)
68 | h = self.gcn_layers[l](g, h)
69 | if self.norms is not None and l != self.num_layers - 1:
70 | h = self.norms[l](h)
71 | hidden_list.append(h)
72 | # output projection
73 | if self.norms is not None and len(self.norms) == self.num_layers:
74 | h = self.norms[-1](h)
75 | if return_hidden:
76 | return self.head(h), hidden_list
77 | else:
78 | return self.head(h)
79 |
80 | def reset_classifier(self, num_classes):
81 | self.head = nn.Linear(self.out_dim, num_classes)
82 |
83 |
84 | class GraphConv(nn.Module):
85 | def __init__(self,
86 | in_dim,
87 | out_dim,
88 | norm=None,
89 | activation=None,
90 | residual=True,
91 | ):
92 | super().__init__()
93 | self._in_feats = in_dim
94 | self._out_feats = out_dim
95 |
96 | self.fc = nn.Linear(in_dim, out_dim)
97 |
98 | if residual:
99 | if self._in_feats != self._out_feats:
100 | self.res_fc = nn.Linear(
101 | self._in_feats, self._out_feats, bias=False)
102 | print("! Linear Residual !")
103 | else:
104 | print("Identity Residual ")
105 | self.res_fc = nn.Identity()
106 | else:
107 | self.register_buffer('res_fc', None)
108 |
109 | # if norm == "batchnorm":
110 | # self.norm = nn.BatchNorm1d(out_dim)
111 | # elif norm == "layernorm":
112 | # self.norm = nn.LayerNorm(out_dim)
113 | # else:
114 | # self.norm = None
115 |
116 | self.norm = norm
117 | if norm is not None:
118 | self.norm = create_norm(norm)(out_dim)
119 | self._activation = activation
120 |
121 | self.reset_parameters()
122 |
123 | def reset_parameters(self):
124 | self.fc.reset_parameters()
125 |
126 | def forward(self, graph, feat):
127 | with graph.local_scope():
128 |             aggregate_fn = fn.copy_u('h', 'm')  # copy_u is the current DGL name for copy_src
129 | # if edge_weight is not None:
130 | # assert edge_weight.shape[0] == graph.number_of_edges()
131 | # graph.edata['_edge_weight'] = edge_weight
132 | # aggregate_fn = fn.u_mul_e('h', '_edge_weight', 'm')
133 |
134 | # (BarclayII) For RGCN on heterogeneous graphs we need to support GCN on bipartite.
135 | feat_src, feat_dst = expand_as_pair(feat, graph)
136 | # if self._norm in ['left', 'both']:
137 | degs = graph.out_degrees().float().clamp(min=1)
138 | norm = torch.pow(degs, -0.5)
139 | shp = norm.shape + (1,) * (feat_src.dim() - 1)
140 | norm = torch.reshape(norm, shp)
141 | feat_src = feat_src * norm
142 |
143 | # if self._in_feats > self._out_feats:
144 | # # mult W first to reduce the feature size for aggregation.
145 | # # if weight is not None:
146 | # # feat_src = th.matmul(feat_src, weight)
147 | # graph.srcdata['h'] = feat_src
148 | # graph.update_all(aggregate_fn, fn.sum(msg='m', out='h'))
149 | # rst = graph.dstdata['h']
150 | # else:
151 | # aggregate first then mult W
152 | graph.srcdata['h'] = feat_src
153 | graph.update_all(aggregate_fn, fn.sum(msg='m', out='h'))
154 | rst = graph.dstdata['h']
155 |
156 | rst = self.fc(rst)
157 |
158 | # if self._norm in ['right', 'both']:
159 | degs = graph.in_degrees().float().clamp(min=1)
160 | norm = torch.pow(degs, -0.5)
161 | shp = norm.shape + (1,) * (feat_dst.dim() - 1)
162 | norm = torch.reshape(norm, shp)
163 | rst = rst * norm
164 |
165 | if self.res_fc is not None:
166 | rst = rst + self.res_fc(feat_dst)
167 |
168 | if self.norm is not None:
169 | rst = self.norm(rst)
170 |
171 | if self._activation is not None:
172 | rst = self._activation(rst)
173 |
174 | return rst
175 |
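In matrix form, `GraphConv.forward` above scales source features by $d_{\mathrm{out}}^{-1/2}$ before aggregation and the aggregated result by $d_{\mathrm{in}}^{-1/2}$ afterwards (degrees clamped to at least 1), so one layer computes, ignoring dropout and the optional residual/norm/activation,

$$
H' = D_{\mathrm{in}}^{-1/2}\, A\, D_{\mathrm{out}}^{-1/2}\, H\, W + b ,
$$

which on an undirected graph with self-loops added is the familiar symmetrically normalised GCN propagation $\tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2} H W$.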
--------------------------------------------------------------------------------
/models/loss_func.py:
--------------------------------------------------------------------------------
1 | import math, warnings
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | def auc_pair_loss(x, y, z):
9 | x = F.normalize(x, p=2, dim=-1)
10 | y = F.normalize(y, p=2, dim=-1)
11 | z = F.normalize(z, p=2, dim=-1)
12 |
13 | sim = (x * y).sum(dim=-1)
14 | dissim = (x * z).sum(dim=-1)
15 | loss = (1 - sim + dissim).mean()
16 | # loss = (1 - sim).mean()
17 | return loss
18 |
19 |
20 | def sce_loss(x, y, alpha=3):
21 | x = F.normalize(x, p=2, dim=-1)
22 | y = F.normalize(y, p=2, dim=-1)
23 |
24 | # loss = - (x * y).sum(dim=-1)
25 | # loss = (x_h - y_h).norm(dim=1).pow(alpha)
26 |
27 | loss = (1 - (x * y).sum(dim=-1)).pow_(alpha)
28 |
29 | loss = loss.mean()
30 | return loss
31 |
32 |
33 | class DINOLoss(nn.Module):
34 | def __init__(self, out_dim, warmup_teacher_temp, teacher_temp,
35 | warmup_teacher_temp_epochs, nepochs, student_temp=0.1,
36 | center_momentum=0.9):
37 | super().__init__()
38 | self.student_temp = student_temp
39 | self.center_momentum = center_momentum
40 | self.register_buffer("center", torch.zeros(1, out_dim))
41 | # we apply a warm up for the teacher temperature because
42 | # a too high temperature makes the training instable at the beginning
43 | self.teacher_temp_schedule = np.concatenate((
44 | np.linspace(warmup_teacher_temp,
45 | teacher_temp, warmup_teacher_temp_epochs),
46 | np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp
47 | ))
48 |
49 | def forward(self, student_output, teacher_output, epoch):
50 | """
51 | Cross-entropy between softmax outputs of the teacher and student networks.
52 | """
53 | student_out = student_output / self.student_temp
54 |
55 | # teacher centering and sharpening
56 | temp = self.teacher_temp_schedule[epoch]
57 | teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1)
58 | teacher_out = teacher_out.detach()
59 |
60 | loss = torch.sum(-teacher_out * F.log_softmax(student_out, dim=-1), dim=-1)
61 | loss = loss.mean()
62 | self.update_center(teacher_output)
63 | return loss
64 |
65 | # total_loss = 0
66 | # n_loss_terms = 0
67 | # for iq, q in enumerate(teacher_out):
68 | # for v in range(len(student_out)):
69 | # if v == iq:
70 | # # we skip cases where student and teacher operate on the same view
71 | # continue
72 | # loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1)
73 | # total_loss += loss.mean()
74 | # n_loss_terms += 1
75 | # total_loss /= n_loss_terms
76 | # self.update_center(teacher_output)
77 | # return total_loss
78 |
79 | @torch.no_grad()
80 | def update_center(self, teacher_output):
81 | """
82 | Update center used for teacher output.
83 | """
84 | batch_center = torch.mean(teacher_output, dim=0, keepdim=True)
85 |
86 | # ema update
87 | self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)
88 |
89 |
90 | class MLPHead(nn.Module):
91 | def __init__(self, hidden_size, out_dim, num_layers=2, bottleneck_dim=256):
92 | super().__init__()
93 | self._num_layers = num_layers
94 | self.mlp = nn.ModuleList()
95 | for i in range(num_layers):
96 | if i == num_layers - 1:
97 | self.mlp.append(
98 | nn.Linear(hidden_size, bottleneck_dim)
99 | )
100 | else:
101 | self.mlp.append(nn.Linear(hidden_size, hidden_size))
102 | # self.mlp.append(nn.LayerNorm(hidden_size))
103 | self.mlp.append(nn.PReLU())
104 |
105 | self.apply(self._init_weights)
106 | # self.last_layer = nn.Linear(bottleneck_dim, out_dim, bias=False)
107 | self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
108 | self.last_layer.weight_g.data.fill_(1)
109 | # self.last_layer.weight_g.requires_grad = False
110 | # self.last_layer = nn.Linear(bottleneck_dim, out_dim, bias=False)
111 |
112 | def _init_weights(self, m):
113 | if isinstance(m, nn.Linear):
114 | trunc_normal_(m.weight, std=.02)
115 | if isinstance(m, nn.Linear) and m.bias is not None:
116 | nn.init.constant_(m.bias, 0)
117 |
118 | def forward(self, x):
119 |         # run the MLP stack, then L2-normalise before the weight-normed last layer
120 |         for layer in self.mlp:
121 | x = layer(x)
122 |
123 | x = nn.functional.normalize(x, dim=-1, p=2)
124 | x = self.last_layer(x)
125 | return x
126 |
127 | def _no_grad_trunc_normal_(tensor, mean, std, a, b):
128 | # Cut & paste from PyTorch official master until it's in a few official releases - RW
129 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
130 | def norm_cdf(x):
131 | # Computes standard normal cumulative distribution function
132 | return (1. + math.erf(x / math.sqrt(2.))) / 2.
133 |
134 | if (mean < a - 2 * std) or (mean > b + 2 * std):
135 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
136 | "The distribution of values may be incorrect.",
137 | stacklevel=2)
138 |
139 | with torch.no_grad():
140 | # Values are generated by using a truncated uniform distribution and
141 | # then using the inverse CDF for the normal distribution.
142 | # Get upper and lower cdf values
143 | l = norm_cdf((a - mean) / std)
144 | u = norm_cdf((b - mean) / std)
145 |
146 | # Uniformly fill tensor with values from [l, u], then translate to
147 | # [2l-1, 2u-1].
148 | tensor.uniform_(2 * l - 1, 2 * u - 1)
149 |
150 | # Use inverse cdf transform for normal distribution to get truncated
151 | # standard normal
152 | tensor.erfinv_()
153 |
154 | # Transform to proper mean, std
155 | tensor.mul_(std * math.sqrt(2.))
156 | tensor.add_(mean)
157 |
158 | # Clamp to ensure it's in the proper range
159 | tensor.clamp_(min=a, max=b)
160 | return tensor
161 |
162 |
163 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
164 | return _no_grad_trunc_normal_(tensor, mean, std, a, b)
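Written out, the scaled cosine error implemented by `sce_loss` is, for $\ell_2$-normalised rows $\hat{x}_i = x_i / \lVert x_i \rVert_2$ and $\hat{y}_i = y_i / \lVert y_i \rVert_2$,

$$
\mathcal{L}_{\mathrm{SCE}} = \frac{1}{N} \sum_{i=1}^{N} \left(1 - \hat{x}_i^{\top} \hat{y}_i\right)^{\alpha},
$$

where $\alpha$ is the `--alpha_l` flag for the reconstruction term and is fixed to 1 for the latent-prediction term in `PreModel.mask_attr_prediction`.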
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pyyaml
2 | tqdm
3 | tensorboardX
4 | scikit-learn
5 | ogb
6 | torch
7 | dgl
8 | wandb
9 | psutil
10 | matplotlib
11 |
--------------------------------------------------------------------------------
/run_fullbatch.sh:
--------------------------------------------------------------------------------
1 | dataset=$1
2 | device=$2
3 |
4 | [ -z "${dataset}" ] && dataset="cora"
5 | [ -z "${device}" ] && device=0
6 |
7 | CUDA_VISIBLE_DEVICES=$device \
8 | python main_full_batch.py \
9 | --device 0 \
10 | --dataset $dataset \
11 | --mask_method "random" \
12 | --remask_method "fixed" \
13 | --mask_rate 0.5 \
14 | --in_drop 0.2 \
15 | --attn_drop 0.1 \
16 | --num_layers 2 \
17 | --num_dec_layers 1 \
18 | --num_hidden 256 \
19 | --num_heads 4 \
20 | --num_out_heads 1 \
21 | --encoder "gat" \
22 | --decoder "gat" \
23 | --max_epoch 1000 \
24 | --max_epoch_f 300 \
25 | --lr 0.001 \
26 | --weight_decay 0.04 \
27 | --lr_f 0.005 \
28 | --weight_decay_f 1e-4 \
29 | --activation "prelu" \
30 | --loss_fn "sce" \
31 | --alpha_l 3 \
32 | --scheduler \
33 | --seeds 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 \
34 | --lam 0.5 \
35 | --linear_prob \
36 | --use_cfg
37 |
--------------------------------------------------------------------------------
/run_minibatch.sh:
--------------------------------------------------------------------------------
1 | dataset=$1
2 | device=$2
3 |
4 | [ -z "${dataset}" ] && dataset="ogbn-arxiv"
5 | [ -z "${device}" ] && device=0
6 |
7 | CUDA_VISIBLE_DEVICES=$device \
8 | python main_large.py \
9 | --device 0 \
10 | --dataset $dataset \
11 | --mask_type "mask" \
12 | --mask_rate 0.5 \
13 | --remask_rate 0.5 \
14 | --num_remasking 3 \
15 | --in_drop 0.2 \
16 | --attn_drop 0.2 \
17 | --num_layers 4 \
18 | --num_dec_layers 1 \
19 | --num_hidden 1024 \
20 | --num_heads 4 \
21 | --num_out_heads 1 \
22 | --encoder "gat" \
23 | --decoder "gat" \
24 | --max_epoch 60 \
25 | --max_epoch_f 1000 \
26 | --lr 0.002 \
27 | --weight_decay 0.04 \
28 | --lr_f 0.005 \
29 | --weight_decay_f 1e-4 \
30 | --activation "prelu" \
31 | --optimizer "adamw" \
32 | --drop_edge_rate 0.5 \
33 | --loss_fn "sce" \
34 | --alpha_l 4 \
35 | --mask_method "random" \
36 | --scheduler \
37 | --batch_size 512 \
38 | --batch_size_f 256 \
39 | --seeds 0 \
40 | --residual \
41 | --norm "layernorm" \
42 | --sampling_method "lc" \
43 | --label_rate 1.0 \
44 | --lam 1.0 \
45 | --momentum 0.996 \
46 | --linear_prob \
47 | --use_cfg \
48 | --ego_graph_file_path "./lc_ego_graphs/${dataset}-lc-ego-graphs-256.pt" \
49 | --data_dir "./dataset" \
50 | # --logging
51 |
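For what it's worth, the script takes the dataset name and GPU index positionally, defaulting to `ogbn-arxiv` on device 0, e.g. `bash run_minibatch.sh ogbn-products 1`. It also assumes that the file passed via `--ego_graph_file_path` (here `./lc_ego_graphs/${dataset}-lc-ego-graphs-256.pt`) already exists; that file holds the precomputed local-clustering ego-graph node lists consumed by `setup_training_data` and the `lc` samplers in the `datasets` package, so it has to be generated (or downloaded) before running the script.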
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import random
4 | import psutil
5 | import yaml
6 | import logging
7 | from functools import partial
8 | from tensorboardX import SummaryWriter
9 | import wandb
10 |
11 | import numpy as np
12 | import torch
13 | import torch.nn as nn
14 | from torch import optim as optim
15 |
16 |
17 | import dgl
18 | import dgl.function as fn
19 | from sklearn.decomposition import PCA
20 | from sklearn.manifold import TSNE
21 | import matplotlib.pyplot as plt
22 |
23 |
24 | logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO)
25 |
26 |
27 | def accuracy(y_pred, y_true):
28 | y_true = y_true.squeeze().long()
29 | preds = y_pred.max(1)[1].type_as(y_true)
30 | correct = preds.eq(y_true).double()
31 | correct = correct.sum().item()
32 | return correct / len(y_true)
33 |
34 |
35 | def set_random_seed(seed):
36 | random.seed(seed)
37 | np.random.seed(seed)
38 | torch.manual_seed(seed)
39 | torch.cuda.manual_seed(seed)
40 | torch.cuda.manual_seed_all(seed)
41 |     torch.backends.cudnn.deterministic = True
42 |
43 |
44 | def get_current_lr(optimizer):
45 | return optimizer.state_dict()["param_groups"][0]["lr"]
46 |
47 |
48 | def build_args():
49 | parser = argparse.ArgumentParser(description="GAT")
50 | parser.add_argument("--seeds", type=int, nargs="+", default=[0])
51 | parser.add_argument("--dataset", type=str, default="cora")
52 | parser.add_argument("--device", type=int, default=0)
53 | parser.add_argument("--max_epoch", type=int, default=500,
54 | help="number of training epochs")
55 | parser.add_argument("--warmup_steps", type=int, default=-1)
56 |
57 | parser.add_argument("--num_heads", type=int, default=4,
58 | help="number of hidden attention heads")
59 | parser.add_argument("--num_out_heads", type=int, default=1,
60 | help="number of output attention heads")
61 | parser.add_argument("--num_layers", type=int, default=2,
62 | help="number of hidden layers")
63 | parser.add_argument("--num_dec_layers", type=int, default=1)
64 | parser.add_argument("--num_remasking", type=int, default=3)
65 | parser.add_argument("--num_hidden", type=int, default=512,
66 | help="number of hidden units")
67 | parser.add_argument("--residual", action="store_true", default=False,
68 | help="use residual connection")
69 | parser.add_argument("--in_drop", type=float, default=.2,
70 | help="input feature dropout")
71 | parser.add_argument("--attn_drop", type=float, default=.1,
72 | help="attention dropout")
73 | parser.add_argument("--norm", type=str, default=None)
74 | parser.add_argument("--lr", type=float, default=0.001,
75 | help="learning rate")
76 | parser.add_argument("--weight_decay", type=float, default=0,
77 | help="weight decay")
78 | parser.add_argument("--negative_slope", type=float, default=0.2,
79 | help="the negative slope of leaky relu")
80 | parser.add_argument("--activation", type=str, default="prelu")
81 | parser.add_argument("--mask_rate", type=float, default=0.5)
82 | parser.add_argument("--remask_rate", type=float, default=0.5)
83 | parser.add_argument("--remask_method", type=str, default="random")
84 | parser.add_argument("--mask_type", type=str, default="mask",
85 | help="`mask` or `drop`")
86 | parser.add_argument("--mask_method", type=str, default="random")
87 | parser.add_argument("--drop_edge_rate", type=float, default=0.0)
88 | parser.add_argument("--drop_edge_rate_f", type=float, default=0.0)
89 |
90 | parser.add_argument("--encoder", type=str, default="gat")
91 | parser.add_argument("--decoder", type=str, default="gat")
92 | parser.add_argument("--loss_fn", type=str, default="sce")
93 | parser.add_argument("--alpha_l", type=float, default=2)
94 | parser.add_argument("--optimizer", type=str, default="adam")
95 |
96 | parser.add_argument("--max_epoch_f", type=int, default=300)
97 | parser.add_argument("--lr_f", type=float, default=0.01)
98 | parser.add_argument("--weight_decay_f", type=float, default=0.0)
99 | parser.add_argument("--linear_prob", action="store_true", default=False)
100 |
101 |
102 | parser.add_argument("--no_pretrain", action="store_true")
103 | parser.add_argument("--load_model", action="store_true")
104 | parser.add_argument("--checkpoint_path", type=str, default=None)
105 | parser.add_argument("--use_cfg", action="store_true")
106 | parser.add_argument("--logging", action="store_true")
107 | parser.add_argument("--scheduler", action="store_true", default=False)
108 |
109 | parser.add_argument("--batch_size", type=int, default=256)
110 | parser.add_argument("--batch_size_f", type=int, default=128)
111 | parser.add_argument("--sampling_method", type=str, default="saint", help="sampling method, `lc` or `saint`")
112 |
113 | parser.add_argument("--label_rate", type=float, default=1.0)
114 | parser.add_argument("--ego_graph_file_path", type=str, default=None)
115 | parser.add_argument("--data_dir", type=str, default="data")
116 |
117 | parser.add_argument("--lam", type=float, default=1.0)
118 | parser.add_argument("--full_graph_forward", action="store_true", default=False)
119 | parser.add_argument("--delayed_ema_epoch", type=int, default=0)
120 | parser.add_argument("--replace_rate", type=float, default=0.0)
121 | parser.add_argument("--momentum", type=float, default=0.996)
122 |
123 | args = parser.parse_args()
124 | return args
125 |
126 | def create_activation(name):
127 | if name == "relu":
128 | return nn.ReLU()
129 | elif name == "gelu":
130 | return nn.GELU()
131 | elif name == "prelu":
132 | return nn.PReLU()
133 | elif name == "selu":
134 | return nn.SELU()
135 | elif name == "elu":
136 | return nn.ELU()
137 | elif name == "silu":
138 | return nn.SiLU()
139 | elif name is None:
140 | return nn.Identity()
141 | else:
142 | raise NotImplementedError(f"{name} is not implemented.")
143 |
144 |
145 | def identity_norm(x):  # `x` is the (ignored) dimension argument passed by create_norm's callers
146 |     def func(x):
147 |         return x
148 |     return func  # no-op "norm" layer
149 |
150 | def create_norm(name):
151 | if name == "layernorm":
152 | return nn.LayerNorm
153 | elif name == "batchnorm":
154 | return nn.BatchNorm1d
155 | elif name == "identity":
156 | return identity_norm
157 | else:
158 | # print("Identity norm")
159 | return None
160 |
161 |
162 | def create_optimizer(opt, model, lr, weight_decay, get_num_layer=None, get_layer_scale=None):
163 | opt_lower = opt.lower()
164 | parameters = model.parameters()
165 | opt_args = dict(lr=lr, weight_decay=weight_decay)
166 |
167 | opt_split = opt_lower.split("_")
168 | opt_lower = opt_split[-1]
169 |
170 | if opt_lower == "adam":
171 | optimizer = optim.Adam(parameters, **opt_args)
172 | elif opt_lower == "adamw":
173 | optimizer = optim.AdamW(parameters, **opt_args)
174 | elif opt_lower == "adadelta":
175 | optimizer = optim.Adadelta(parameters, **opt_args)
176 | elif opt_lower == "sgd":
177 | opt_args["momentum"] = 0.9
178 | return optim.SGD(parameters, **opt_args)
179 | else:
180 | raise NotImplementedError("Invalid optimizer")
181 |
182 | return optimizer
183 |
184 |
185 | def show_occupied_memory():
186 | process = psutil.Process(os.getpid())
187 | return process.memory_info().rss / 1024**2
188 |
189 |
190 | # -------------------
191 | def mask_edge(graph, mask_prob):
192 | E = graph.num_edges()
193 |
194 | mask_rates = torch.ones(E) * mask_prob
195 |     masks = torch.bernoulli(1 - mask_rates)   # 1 = keep the edge, 0 = drop it
196 |     mask_idx = masks.nonzero().squeeze(1)      # indices of the edges to keep
197 | return mask_idx
198 |
199 |
200 | def drop_edge(graph, drop_rate, return_edges=False):
201 | if drop_rate <= 0:
202 | return graph
203 |
204 | graph = graph.remove_self_loop()
205 |
206 | n_node = graph.num_nodes()
207 | edge_mask = mask_edge(graph, drop_rate)
208 | src, dst = graph.edges()
209 |
210 | nsrc = src[edge_mask]
211 | ndst = dst[edge_mask]
212 |
213 | ng = dgl.graph((nsrc, ndst), num_nodes=n_node)
214 | ng = ng.add_self_loop()
215 |
216 | return ng
217 |
218 |
219 | def visualize(x, y, method="tsne"):
220 | if torch.is_tensor(x):
221 | x = x.cpu().numpy()
222 |
223 | if torch.is_tensor(y):
224 | y = y.cpu().numpy()
225 |
226 | if method == "tsne":
227 | func = TSNE(n_components=2)
228 | else:
229 | func = PCA(n_components=2)
230 | out = func.fit_transform(x)
231 | plt.scatter(out[:, 0], out[:, 1], c=y)
232 | plt.savefig("vis.png")
233 |
234 |
235 | def load_best_configs(args):
236 | dataset_name = args.dataset
237 | config_path = os.path.join("configs", f"{dataset_name}.yaml")
238 | with open(config_path, "r") as f:
239 | configs = yaml.load(f, yaml.FullLoader)
240 |
241 | for k, v in configs.items():
242 | if "lr" in k or "weight_decay" in k:
243 | v = float(v)
244 | setattr(args, k, v)
245 | logging.info(f"----- Using best configs from {config_path} -----")
246 |
247 | return args
248 |
249 |
250 |
251 | def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0):
252 | warmup_schedule = np.array([])
253 | warmup_iters = warmup_epochs * niter_per_ep
254 | if warmup_epochs > 0:
255 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
256 |
257 | iters = np.arange(epochs * niter_per_ep - warmup_iters)
258 | schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))
259 |
260 | scheduler = np.concatenate((warmup_schedule, schedule))
261 | assert len(scheduler) == epochs * niter_per_ep
262 | return scheduler
263 |
264 |
265 |
266 | # ------ logging ------
267 |
268 | class TBLogger(object):
269 | def __init__(self, log_path="./logging_data", name="run"):
270 | super(TBLogger, self).__init__()
271 |
272 | if not os.path.exists(log_path):
273 | os.makedirs(log_path, exist_ok=True)
274 |
275 | self.last_step = 0
276 | self.log_path = log_path
277 | raw_name = os.path.join(log_path, name)
278 |         name = raw_name
279 |         for i in range(1000):
280 |             name = raw_name + f"_{i}"  # pick the first run-directory suffix not already in use
281 | if not os.path.exists(name):
282 | break
283 | self.writer = SummaryWriter(logdir=name)
284 |
285 | def note(self, metrics, step=None):
286 | if step is None:
287 | step = self.last_step
288 | for key, value in metrics.items():
289 | self.writer.add_scalar(key, value, step)
290 | self.last_step = step
291 |
292 | def finish(self):
293 | self.writer.close()
294 |
295 |
296 | class WandbLogger(object):
297 | def __init__(self, log_path, project, args):
298 | self.log_path = log_path
299 | self.project = project
300 | self.args = args
301 | self.last_step = 0
302 | self.project = project
303 | self.start()
304 |
305 | def start(self):
306 | self.run = wandb.init(config=self.args, project=self.project)
307 |
308 | def log(self, metrics, step=None):
309 | if not hasattr(self, "run"):
310 | self.start()
311 | if step is None:
312 | step = self.last_step
313 |         self.run.log(metrics, step=step)  # pass the tracked step explicitly
314 | self.last_step = step
315 |
316 | def finish(self):
317 | self.run.finish()
318 |
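319 |
320 | # Illustrative usage sketch: one way the helpers above can be chained by an entry point
321 | # such as main_large.py; the Linear model below is only a stand-in for the real encoder/decoder.
322 | if __name__ == "__main__":
323 |     demo_args = build_args()
324 |     if demo_args.use_cfg:
325 |         demo_args = load_best_configs(demo_args)   # hyper-parameters from configs/<dataset>.yaml
326 |     set_random_seed(demo_args.seeds[0])
327 |     toy_model = nn.Linear(8, 2)                    # stand-in model, just to exercise the optimizer
328 |     optimizer = create_optimizer(demo_args.optimizer, toy_model, demo_args.lr, demo_args.weight_decay)
329 |     momentum_schedule = cosine_scheduler(demo_args.momentum, 1.0, demo_args.max_epoch, niter_per_ep=1)
330 |     logging.info(f"lr={get_current_lr(optimizer)}, schedule length={len(momentum_schedule)}")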
--------------------------------------------------------------------------------