├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── configuration ├── afgrl_cs.yml ├── afgrl_photo.yml ├── afgrl_wikics.yml ├── bgrl_cs.yml ├── bgrl_photo.yml ├── bgrl_wikics.yml ├── dgi_amazon.yml ├── dgi_coauthor.yml ├── dgi_wikics.yml ├── gca_amazon.yml ├── gca_coauthor.yml ├── gca_wikics.yml ├── graphcl.yml ├── graphcl_amazon.yml ├── graphcl_coauthor.yml ├── graphcl_cora.yml ├── graphcl_molecure_imdb_b.yml ├── graphcl_molecure_imdb_m.yml ├── graphcl_molecure_mutag.yml ├── graphcl_wikics.yml ├── graphmae_imdb_b.yml ├── graphmae_imdb_m.yml ├── graphmae_mutag.yml ├── graphregcl_amazon.yml ├── graphregcl_coauthor.yml ├── graphregcl_cora.yml ├── graphregcl_wikics.yml ├── infograph_imdb_b.yml ├── infograph_imdb_m.yml ├── infograph_mutag.yml ├── merit.yml ├── mvgrl_amazon.yml ├── mvgrl_coauthor.yml ├── mvgrl_cora.yml ├── mvgrl_wikics.yml ├── sugrl_amazon.yml ├── sugrl_coauthor.yml └── sugrl_wikics.yml ├── examples ├── example_BGRL.py ├── example_afgrl.py ├── example_dgi.py ├── example_gca.py ├── example_graphcl.py ├── example_graphcl_molecure.py ├── example_graphmae.py ├── example_infograph.py ├── example_merit.py ├── example_mvgrl.py ├── example_regcl.py └── example_sugrl.py ├── figures └── overview.png ├── requirements.txt ├── setup.py └── src ├── __init__.py ├── augment ├── __init__.py ├── augment_afgrl.py ├── augment_subgraph.py ├── base.py ├── collections.py ├── concate_aug.py ├── diff_matrix.py ├── echo.py ├── gca_augments.py ├── negative.py ├── positive.py ├── random_drop_edge.py ├── random_drop_node.py ├── random_mask.py ├── random_mask_channel.py ├── shuffle_node.py └── sum_emb.py ├── config ├── __init__.py └── config.py ├── data └── data_non_contrast.py ├── evaluation ├── __init__.py ├── base.py ├── k_means.py ├── logistic_regression.py ├── rf_regression.py ├── sim_search.py ├── svc_regression.py └── visulization_embeddings.py ├── loader ├── __init__.py ├── augment_data_loader.py └── loader.py ├── losses ├── __init__.py ├── infograph_losses.py ├── 
negative_mi.py ├── sce_loss.py ├── sig_loss.py └── simple_losses.py ├── methods ├── __init__.py ├── adgcl.py ├── afgrl.py ├── base.py ├── bgrl.py ├── dgi.py ├── dgi_bk.py ├── gca.py ├── graphcl.py ├── graphmae.py ├── hdmi.py ├── heco.py ├── infograph.py ├── merit.py ├── mvgrl.py ├── regcl.py ├── sugrl.py ├── utils.py └── utils │ ├── __init__.py │ ├── ema.py │ ├── gnn.py │ └── readout.py ├── similarity_functions ├── __init__.py ├── bilinear.py └── simple_functions.py ├── trainer ├── __init__.py ├── base.py ├── gca_trainer.py ├── non_contrast_trainer.py ├── simple_trainer.py └── utils │ ├── __init__.py │ └── early_stop.py ├── transforms ├── __init__.py ├── base.py ├── edge2adj.py ├── gcn_norm.py └── normalize_features.py ├── typing.py └── utils ├── __init__.py ├── add_adj.py ├── bilinear_sim.py └── create_data.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # pycharm 132 | .idea 133 | __pychache__ 134 | 135 | # check points 136 | *ckpt/ 137 | *ckpt*/ 138 | ckpt*/ 139 | .ckpt 140 | 141 | # others 142 | results/ 143 | .vscode/ 144 | pyg_data/ 145 | 146 | # local datasets downloaded from torch_geometric 147 | DBLP_data 148 | Aminer_data 149 | datasets/acm/processed/data.pt 150 | datasets/acm/processed/pre_filter.pt 151 | datasets/acm/processed/pre_transform.pt 152 | datasets/aminer/processed/data.pt 153 | datasets/aminer/processed/pre_filter.pt 154 | datasets/aminer/processed/pre_transform.pt 155 | datasets/freebase_movies/processed/data.pt 156 | datasets/freebase_movies/processed/pre_filter.pt 157 | datasets/freebase_movies/processed/pre_transform.pt 158 | datasets/dblp 159 | test.ipynb 160 | *.pkl 161 | 162 | datasets/ 163 | wikics/ -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Use a PyTorch 2.3.1 / CUDA 12.1 / Python 3.10.14 image as the parent image 2 | FROM pytorch/pytorch:2.3.1-cuda12.1-cudnn8-runtime 3 | 4 | # Install torch_geometric (https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html) 5 | RUN pip install torch_geometric 6 | RUN pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.3.0+cu121.html 7 | # 
Successfully installed aiohappyeyeballs-2.4.4 aiohttp-3.11.11 aiosignal-1.3.2 async-timeout-5.0.1 frozenlist-1.5.0 multidict-6.1.0 propcache-0.2.1 pyparsing-3.2.0 torch_geometric-2.6.1 yarl-1.18.3 8 | # Successfully installed pyg_lib-0.4.0+pt23cu121 scipy-1.14.1 torch_cluster-1.6.3+pt23cu121 torch_scatter-2.1.2+pt23cu121 torch_sparse-0.6.18+pt23cu121 torch_spline_conv-1.2.2+pt23cu121 9 | 10 | # Install faiss 11 | RUN pip install faiss-gpu 12 | # Successfully installed faiss-gpu-1.7.2 13 | 14 | # Install deep graph library (https://www.dgl.ai/pages/start.html) 15 | RUN pip install dgl -f https://data.dgl.ai/wheels/torch-2.3/cu121/repo.html 16 | # Successfully installed annotated-types-0.7.0 dgl-2.4.0+cu121 pandas-2.2.3 pydantic-2.10.4 pydantic-core-2.27.2 python-dateutil-2.9.0.post0 typing-extensions-4.12.1 tzdata-2024.2 17 | 18 | # other common libraries 19 | RUN pip install scikit-learn matplotlib seaborn yacs -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 iDEA-iSAIL Lab @UIUC 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyG-SSL: A Graph Self-Supervised Learning Toolkit 2 | 3 | **[Documentation](https://pygssl-tutorial.readthedocs.io/en/latest/index.html)** | **[Paper](https://arxiv.org/abs/2412.21151)** 4 | 5 | PyG-SSL is a Python library built upon [PyTorch](https://pytorch.org) and [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/en/latest/). It integrates state-of-the-art algorithms for Self-Supervised Learning on Graphs and offers simple APIs to call the algorithms from a variety of high-quality research papers. 6 | 7 | **Node-level Algorithms for General Graphs**: [DGI](https://arxiv.org/pdf/1809.10341.pdf), [GraphCL](https://proceedings.nips.cc/paper/2020/file/3fe230348e9a12c13120749e3f9fa4cd-Paper.pdf), 8 | [MVGRL](https://proceedings.mlr.press/v119/hassani20a/hassani20a.pdf), [GCA](https://arxiv.org/abs/2010.14945), 9 | [SUGRL](https://ojs.aaai.org/index.php/AAAI/article/view/20748), [ReGCL](https://ojs.aaai.org/index.php/AAAI/article/view/28698), 10 | [BGRL](https://arxiv.org/abs/2102.06514), [AFGRL](https://arxiv.org/abs/2112.02472) 11 | 12 | **Algorithms for Heterogeneous/Multiplex/Multiview Graphs**: [HeCo](https://arxiv.org/pdf/2105.09111.pdf) 13 | 14 | **Graph-level Algorithms for Graph-level Representation Learning or Molecular Graphs**: [InfoGraph](https://openreview.net/pdf?id=r1lfF2NYvH), [GraphCL](https://proceedings.nips.cc/paper/2020/file/3fe230348e9a12c13120749e3f9fa4cd-Paper.pdf), [GraphMAE](https://arxiv.org/pdf/2205.10803.pdf) 15 | 16 | 17 | 
---------------------------------------------------------- 18 | ![Overview](figures/overview.png) 19 | Generally, PyG-SSL has four components: Configuration, Method, Trainer and Evaluator. The user only needs to specify the configuration and write a simple script which calls the other components within this framework. The library will automatically run the method, train the model, and evaluate the results. The user can also customize the method, trainer, and evaluator to fit their own needs. We provide example configurations in the `configuration` folder, and example scripts in the `examples` folder. The user can easily modify these configurations and scripts to run their own experiments. 20 | 21 | ## Installation 22 | While we provide requirements.txt for installing the dependencies with pip, sometimes pip is not able to resolve the dependencies correctly. We recommend using the following ways to build the environment. 23 | ### Directly use the source 24 | One way to use the source code is to clone the repository and install the dependencies. We recommend creating a new conda environment to prevent conflicts with other packages. 25 | ```bash 26 | conda create -n pyg-ssl python=3.9 27 | pip install torch_geometric 28 | pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.3.0+cu121.html 29 | pip install faiss-gpu 30 | pip install dgl -f https://data.dgl.ai/wheels/torch-2.3/cu121/repo.html 31 | pip install scikit-learn matplotlib seaborn yacs 32 | ``` 33 | Then you are ready to run some basic examples. In the root directory of this repository, you can run the following command to test the installation. 34 | ```bash 35 | cp examples/example_dgi.py . 36 | python example_dgi.py 37 | ``` 38 | 39 | ### Use our docker image 40 | We provide a docker image at https://hub.docker.com/r/violet24k/pyg-ssl for users to run the code easily. 
Our docker image is based on the official PyTorch docker image with pytorch 2.3.1, cuda 12.1 and python 3.10.14. 41 | 42 | With Docker installed on your machine, pull our image and start a container with: 43 | ```bash 44 | docker pull violet24k/pyg-ssl 45 | docker run --gpus all -it -v .:/workspace violet24k/pyg-ssl bash 46 | ``` 47 | 48 | ### Build from source 49 | We provide setup.py for users to build the package from source. You can run the following command to build the package. 50 | ```bash 51 | conda create -n pyg-ssl python=3.9 52 | pip install -e . # -e for editable mode so that you can tailor the code to your needs 53 | 54 | 55 | # test the installation 56 | python 57 | >>> import pyg_ssl 58 | >>> print(pyg_ssl.__version__) 59 | ``` 60 | Then you can run the example scripts in the root directory of this repository. 61 | 62 | 64 | 65 | 66 | ## Cite 67 | If you find this repository useful in your research, please consider citing the following papers: 68 | 69 | ``` 70 | @article{zheng2024pyg, 71 | title={PyG-SSL: A Graph Self-Supervised Learning Toolkit}, 72 | author={Zheng, Lecheng and Jing, Baoyu and Li, Zihao and Zeng, Zhichen and Wei, Tianxin and Ai, Mengting and He, Xinrui and Liu, Lihui and Fu, Dongqi and You, Jiaxuan and others}, 73 | journal={arXiv preprint arXiv:2412.21151}, 74 | year={2024} 75 | } 76 | ``` 77 | 78 | ``` 79 | @inproceedings{DBLP:conf/kdd/ZhengJLTH24, 80 | author = {Lecheng Zheng and 81 | Baoyu Jing and 82 | Zihao Li and 83 | Hanghang Tong and 84 | Jingrui He}, 85 | editor = {Ricardo Baeza{-}Yates and 86 | Francesco Bonchi}, 87 | title = {Heterogeneous Contrastive Learning for Foundation Models and Beyond}, 88 | booktitle = {Proceedings of the 30th {ACM} {SIGKDD} Conference on Knowledge Discovery 89 | and Data Mining, {KDD} 2024, Barcelona, Spain, August 25-29, 2024}, 90 | pages = {6666--6676}, 91 | publisher = {{ACM}}, 92 | year = {2024}, 93 | url = {https://doi.org/10.1145/3637528.3671454}, 94 | doi = {10.1145/3637528.3671454}, 95 | 
timestamp = {Sun, 08 Sep 2024 16:05:58 +0200}, 96 | biburl = {https://dblp.org/rec/conf/kdd/ZhengJLTH24.bib}, 97 | bibsource = {dblp computer science bibliography, https://dblp.org} 98 | } 99 | ``` 100 | -------------------------------------------------------------------------------- /configuration/afgrl_cs.yml: -------------------------------------------------------------------------------- 1 | use_cuda : True 2 | gpu_idx : 0 3 | model: 4 | dropout: 0.5 5 | n_layers: 2 6 | in_channels: 1433 7 | hidden_channels: 512 8 | topk: 8 9 | num_centroids: 100 10 | num_kmeans: 5 11 | clus_num_iters: 20 12 | optim: 13 | base_lr: 0.00005 14 | name : adam 15 | max_epoch : 5000 16 | patience : 1000 17 | weight_decay : 0.00001 18 | moving_average_decay: 0.9 19 | dataset: 20 | name : cs 21 | root : pyg_data 22 | output: 23 | verbose : True 24 | save_dir : './saved_model' 25 | interval : 100 26 | -------------------------------------------------------------------------------- /configuration/afgrl_photo.yml: -------------------------------------------------------------------------------- 1 | use_cuda : True 2 | gpu_idx : 0 3 | model: 4 | dropout: 0.5 5 | n_layers: 2 6 | in_channels: 1433 7 | hidden_channels: 512 8 | topk: 4 9 | num_centroids: 100 10 | num_kmeans: 5 11 | clus_num_iters: 20 12 | optim: 13 | base_lr: 0.0001 14 | name : adam 15 | max_epoch : 5000 16 | patience : 1000 17 | weight_decay : 0.00001 18 | moving_average_decay: 0.9 19 | dataset: 20 | name : photo 21 | root : pyg_data 22 | output: 23 | verbose : True 24 | save_dir : './saved_model' 25 | interval : 100 26 | -------------------------------------------------------------------------------- /configuration/afgrl_wikics.yml: -------------------------------------------------------------------------------- 1 | use_cuda : True 2 | gpu_idx : 0 3 | model: 4 | dropout: 0.5 5 | n_layers: 2 6 | in_channels: 1433 7 | hidden_channels: 512 8 | topk: 4 9 | num_centroids: 100 10 | num_kmeans: 5 11 | clus_num_iters: 20 12 | optim: 
13 | base_lr: 0.0001 14 | name : adam 15 | max_epoch : 5000 16 | patience : 1000 17 | weight_decay : 0.00001 18 | moving_average_decay: 0.9 19 | dataset: 20 | name : wikics 21 | root : pyg_data 22 | output: 23 | verbose : True 24 | save_dir : './saved_model' 25 | interval : 100 26 | -------------------------------------------------------------------------------- /configuration/bgrl_cs.yml: -------------------------------------------------------------------------------- 1 | use_cuda : True 2 | gpu_idx : 1 3 | model: 4 | dropout: 0.5 5 | n_layers: 2 6 | in_channels: 1433 7 | hidden_channels: 512 8 | aug_mask_1: 0.1 9 | aug_mask_2: 0.2 10 | aug_edge_1: 0.4 11 | aug_edge_2: 0.1 12 | optim: 13 | base_lr: 0.0001 14 | name : adam 15 | max_epoch : 4000 16 | patience : 1000 17 | weight_decay : 0.00001 18 | moving_average_decay: 0.99 19 | dataset: 20 | name : cs 21 | root : pyg_data 22 | output: 23 | verbose : True 24 | save_dir : './saved_model' 25 | interval : 100 26 | -------------------------------------------------------------------------------- /configuration/bgrl_photo.yml: -------------------------------------------------------------------------------- 1 | use_cuda : True 2 | gpu_idx : 0 3 | model: 4 | dropout: 0.5 5 | n_layers: 2 6 | in_channels: 1433 7 | hidden_channels: 512 8 | aug_mask_1: 0.1 9 | aug_mask_2: 0.2 10 | aug_edge_1: 0.4 11 | aug_edge_2: 0.1 12 | optim: 13 | base_lr: 0.0001 14 | name : adam 15 | max_epoch : 4000 16 | patience : 1000 17 | weight_decay : 0.00001 18 | moving_average_decay: 0.99 19 | dataset: 20 | name : photo 21 | root : pyg_data 22 | output: 23 | verbose : True 24 | save_dir : './saved_model' 25 | interval : 100 26 | -------------------------------------------------------------------------------- /configuration/bgrl_wikics.yml: -------------------------------------------------------------------------------- 1 | use_cuda : True 2 | gpu_idx : 1 3 | model: 4 | dropout: 0.5 5 | n_layers: 2 6 | in_channels: 1433 7 | hidden_channels: 512 8 | 
aug_mask_1: 0.1 9 | aug_mask_2: 0.2 10 | aug_edge_1: 0.4 11 | aug_edge_2: 0.1 12 | optim: 13 | base_lr: 0.0001 14 | name : adam 15 | max_epoch : 4000 16 | patience : 1000 17 | weight_decay : 0.00001 18 | moving_average_decay: 0.99 19 | dataset: 20 | name : wikics 21 | root : pyg_data 22 | output: 23 | verbose : True 24 | save_dir : './saved_model' 25 | interval : 100 26 | -------------------------------------------------------------------------------- /configuration/dgi_amazon.yml: -------------------------------------------------------------------------------- 1 | dataset: "photo" 2 | root: "pyg_data" 3 | black_list: [4,5,6] 4 | lr: 0.01 5 | out_heads: 1 6 | task_type: Node_Transductive 7 | val_interval: 1 8 | num_hidden_features: 8 9 | to_undirected_at_neg: True 10 | epochs: 1500 11 | patience: 100 12 | w_loss1: 100 13 | w_loss2: 100 14 | w_loss3: 1 15 | margin1: 0.9 16 | margin2: 0.9 17 | dim: 128 18 | cfg: [512, 128] 19 | NewATop: 0 20 | dropout: 0.1 21 | NN: 1 22 | num1: 200 23 | wd: 0.0 24 | weight_decay: 0.0001 25 | in_channels: 745 -------------------------------------------------------------------------------- /configuration/dgi_coauthor.yml: -------------------------------------------------------------------------------- 1 | dataset: "cs" 2 | root: "pyg_data" 3 | black_list: [4,5,6] 4 | lr: 0.01 5 | out_heads: 1 6 | task_type: Node_Transductive 7 | val_interval: 1 8 | num_hidden_features: 8 9 | epochs: 500 10 | to_undirected_at_neg: True 11 | epochs: 1000 12 | patience: 100 13 | w_loss1: 100 14 | w_loss2: 100 15 | w_loss3: 1 16 | margin1: 0.9 17 | margin2: 0.9 18 | dim: 128 19 | cfg: [512, 128] 20 | NewATop: 0 21 | dropout: 0.1 22 | NN: 1 23 | num1: 200 24 | wd: 0.0 25 | weight_decay: 0.0001 26 | in_channels: 6805 -------------------------------------------------------------------------------- /configuration/dgi_wikics.yml: -------------------------------------------------------------------------------- 1 | dataset: "wikics" 2 | root: "pyg_data" 3 | 
black_list: [4,5,6] 4 | lr: 0.001 5 | out_heads: 1 6 | task_type: Node_Transductive 7 | val_interval: 1 8 | num_hidden_features: 8 9 | to_undirected_at_neg: True 10 | epochs: 2000 11 | patience: 100 12 | w_loss1: 100 13 | w_loss2: 100 14 | w_loss3: 2 15 | margin1: 0.4 16 | margin2: 0.6 17 | dim: 128 18 | cfg: [512, 128] 19 | NewATop: 0 20 | dropout: 0.1 21 | NN: 5 22 | num1: 300 23 | wd: 0.0 24 | weight_decay: 0.0001 25 | in_channels: 300 -------------------------------------------------------------------------------- /configuration/gca_amazon.yml: -------------------------------------------------------------------------------- 1 | use_cuda : True 2 | gpu_idx : 0 3 | model: 4 | out_channels: 256 5 | activation: prelu 6 | hidden_channels: 256 7 | mode: 4 8 | num_layers: 2 9 | num_proj_hidden: 32 10 | tau: 0.4 11 | drop_edge_rate_1: 0.2 12 | drop_edge_rate_2: 0.3 13 | drop_feature_rate_1: 0.1 14 | drop_feature_rate_2: 0.1 15 | optim: 16 | lr: 0.0001 17 | name : adam 18 | max_epoch : 50 19 | patience : 30 20 | weight_decay : 0.0005 21 | dataset: 22 | name : Amazon-Photo 23 | root : pyg_data 24 | output: 25 | verbose : True 26 | save_dir : './saved_model' 27 | interval : 100 28 | classifier: 29 | base_lr: 0.01 30 | name : adam 31 | max_epoch : 1500 32 | patience : 100 33 | run : 10 34 | weight_decay : 0.0 -------------------------------------------------------------------------------- /configuration/gca_coauthor.yml: -------------------------------------------------------------------------------- 1 | use_cuda : True 2 | gpu_idx : 0 3 | model: 4 | out_channels: 256 5 | activation: prelu 6 | hidden_channels: 256 7 | mode: 4 8 | num_layers: 2 9 | num_proj_hidden: 32 10 | tau: 0.4 11 | drop_edge_rate_1: 0.2 12 | drop_edge_rate_2: 0.3 13 | drop_feature_rate_1: 0.1 14 | drop_feature_rate_2: 0.1 15 | optim: 16 | lr: 0.0001 17 | name : adam 18 | max_epoch : 50 19 | patience : 30 20 | weight_decay : 0.0005 21 | dataset: 22 | name : Coauthor-CS 23 | root : pyg_data 24 | output: 
25 | verbose : True 26 | save_dir : './saved_model' 27 | interval : 100 28 | classifier: 29 | base_lr: 0.01 30 | name : adam 31 | max_epoch : 1500 32 | patience : 100 33 | run : 10 34 | weight_decay : 0.0 -------------------------------------------------------------------------------- /configuration/gca_wikics.yml: -------------------------------------------------------------------------------- 1 | use_cuda : True 2 | gpu_idx : 0 3 | model: 4 | out_channels: 256 5 | activation: prelu 6 | hidden_channels: 256 7 | mode: 4 8 | num_layers: 2 9 | num_proj_hidden: 32 10 | tau: 0.4 11 | drop_edge_rate_1: 0.2 12 | drop_edge_rate_2: 0.3 13 | drop_feature_rate_1: 0.1 14 | drop_feature_rate_2: 0.1 15 | optim: 16 | lr: 0.0001 17 | name : adam 18 | max_epoch : 50 19 | patience : 30 20 | weight_decay : 0.0005 21 | dataset: 22 | name : WikiCS 23 | root : pyg_data 24 | output: 25 | verbose : True 26 | save_dir : './saved_model' 27 | interval : 100 28 | classifier: 29 | base_lr: 0.01 30 | name : adam 31 | max_epoch : 1500 32 | patience : 100 33 | run : 10 34 | weight_decay : 0.0 -------------------------------------------------------------------------------- /configuration/graphcl.yml: -------------------------------------------------------------------------------- 1 | use_cuda : True 2 | gpu_idx : 0 3 | model: 4 | dropout: 0.5 5 | n_layers: 2 6 | in_channels: 3703 7 | hidden_channels: 512 8 | aug_type: 'mask' 9 | optim: 10 | base_lr: 0.001 11 | name : adam 12 | max_epoch : 3000 13 | patience : 1000 14 | weight_decay : 0.0005 15 | dataset: 16 | name : citeseer 17 | root : pyg_data 18 | output: 19 | verbose : True 20 | save_dir : './saved_model' 21 | interval : 100 22 | aaa: 23 | aaaa: 0 24 | bbbb: 0 25 | -------------------------------------------------------------------------------- /configuration/graphcl_amazon.yml: -------------------------------------------------------------------------------- 1 | use_cuda : True 2 | gpu_idx : 0 3 | model: 4 | dropout: 0.5 5 | n_layers: 3 6 | 
in_channels: 745 7 | hidden_channels: 256 8 | aug_type: 'mask' 9 | optim: 10 | base_lr: 0.003 11 | name : adam 12 | max_epoch : 3000 13 | patience : 200 14 | weight_decay : 0.0005 15 | run: 1 16 | classifier: 17 | base_lr: 0.001 18 | name : adam 19 | max_epoch : 30000 20 | patience : 500 21 | run : 10 22 | weight_decay : 0.0005 23 | dataset: 24 | name : Amazon 25 | root : pyg_data 26 | output: 27 | verbose : True 28 | save_dir : './saved_model' 29 | interval : 100 30 | -------------------------------------------------------------------------------- /configuration/graphcl_coauthor.yml: -------------------------------------------------------------------------------- 1 | use_cuda : True 2 | gpu_idx : 0 3 | model: 4 | dropout: 0.5 5 | n_layers: 2 6 | in_channels: 6805 7 | hidden_channels: 256 8 | aug_type: 'mask' 9 | optim: 10 | base_lr: 0.001 11 | name : adam 12 | max_epoch : 3000 13 | patience : 200 14 | weight_decay : 0.0005 15 | classifier: 16 | base_lr: 0.001 17 | name : adam 18 | max_epoch : 30000 19 | patience : 500 20 | run : 10 21 | weight_decay : 0.0005 22 | dataset: 23 | name : coauthor 24 | root : pyg_data 25 | output: 26 | verbose : True 27 | save_dir : './saved_model' 28 | interval : 100 29 | aaa: 30 | aaaa: 0 31 | bbbb: 0 32 | -------------------------------------------------------------------------------- /configuration/graphcl_cora.yml: -------------------------------------------------------------------------------- 1 | use_cuda : True 2 | gpu_idx : 0 3 | model: 4 | dropout: 0.5 5 | n_layers: 2 6 | in_channels: 1433 7 | hidden_channels: 512 8 | aug_type: 'mask' 9 | optim: 10 | base_lr: 0.001 11 | name : adam 12 | max_epoch : 3000 13 | patience : 1000 14 | weight_decay : 0.0005 15 | dataset: 16 | name : cora 17 | root : pyg_data 18 | output: 19 | verbose : True 20 | save_dir : './saved_model' 21 | interval : 100 22 | aaa: 23 | aaaa: 0 24 | bbbb: 0 25 | -------------------------------------------------------------------------------- 
/configuration/graphcl_molecure_imdb_b.yml: -------------------------------------------------------------------------------- 1 | use_cuda : True 2 | gpu_idx : 0 3 | torch_seed: 0 4 | model: 5 | dropout: 0.5 6 | n_layers: 3 7 | in_channels: 3703 8 | hidden_channels: 32 9 | aug_type: 'mask' 10 | optim: 11 | base_lr: 0.0001 12 | name : adam 13 | max_epoch : 1 14 | patience : 10 15 | weight_decay : 0.0001 16 | classifier: 17 | base_lr: 0.01 18 | name : adam 19 | max_epoch : 300 20 | patience : 100 21 | run : 10 22 | weight_decay : 0.0005 23 | dataset: 24 | name : IMDB-BINARY 25 | root : pyg_data 26 | batch_size: 128 27 | -------------------------------------------------------------------------------- /configuration/graphcl_molecure_imdb_m.yml: -------------------------------------------------------------------------------- 1 | use_cuda : True 2 | gpu_idx : 0 3 | torch_seed: 1 4 | model: 5 | dropout: 0.5 6 | n_layers: 3 7 | in_channels: 3703 8 | hidden_channels: 32 9 | aug_type: 'mask' 10 | optim: 11 | base_lr: 0.00015 12 | name : adam 13 | max_epoch : 20 14 | patience : 10 15 | weight_decay : 0.000 16 | classifier: 17 | base_lr: 0.01 18 | name : adam 19 | max_epoch : 300 20 | patience : 100 21 | weight_decay : 0.000 22 | dataset: 23 | name : IMDB-MULTI 24 | root : pyg_data 25 | batch_size: 128 26 | -------------------------------------------------------------------------------- /configuration/graphcl_molecure_mutag.yml: -------------------------------------------------------------------------------- 1 | use_cuda : True 2 | gpu_idx : 0 3 | torch_seed: 0 4 | model: 5 | dropout: 0.5 6 | n_layers: 3 7 | in_channels: 3703 8 | hidden_channels: 32 9 | aug_type: 'mask' 10 | optim: 11 | base_lr: 0.001 12 | name : adam 13 | max_epoch : 30 14 | patience : 10 15 | weight_decay : 0.0 16 | classifier: 17 | base_lr: 0.01 18 | name : adam 19 | max_epoch : 200 20 | patience : 100 21 | run : 10 22 | weight_decay : 0.00 23 | dataset: 24 | name : MUTAG 25 | root : pyg_data 26 | 
batch_size: 128 27 | 28 | -------------------------------------------------------------------------------- /configuration/graphcl_wikics.yml: -------------------------------------------------------------------------------- 1 | use_cuda : True 2 | gpu_idx : 0 3 | model: 4 | dropout: 0.5 5 | n_layers: 2 6 | in_channels: 300 7 | hidden_channels: 256 8 | aug_type: 'mask' 9 | optim: 10 | base_lr: 0.001 11 | name : adam 12 | max_epoch : 3000 13 | patience : 200 14 | weight_decay : 0.0005 15 | classifier: 16 | base_lr: 0.001 17 | name : adam 18 | max_epoch : 3000 19 | patience : 1000 20 | run : 10 21 | weight_decay : 0.0005 22 | dataset: 23 | name : WikiCS 24 | root : pyg_data 25 | output: 26 | verbose : True 27 | save_dir : './saved_model' 28 | interval : 100 29 | aaa: 30 | aaaa: 0 31 | bbbb: 0 32 | -------------------------------------------------------------------------------- /configuration/graphmae_imdb_b.yml: -------------------------------------------------------------------------------- 1 | use_cuda : True 2 | gpu_idx : 0 3 | torch_seed: 0 4 | model: 5 | dropout: 0.5 6 | encoder_layers: 3 7 | in_channels: 3703 8 | hidden_channels: 512 9 | aug_type: 'mask' 10 | num_hidden: 32 11 | decoder_layers: 2 12 | lr: 0.0005 13 | weight_decay: 0 14 | weight_decay_f: 0 15 | max_epoch: 60 16 | max_epoch_f: 500 17 | mask_rate: 0.5 18 | num_layers: 2 19 | encoder_type: gin 20 | decoder_type: gin 21 | activation: prelu 22 | loss_fn: sce 23 | scheduler: False 24 | pooling: sum 25 | batch_size: 32 26 | alpha_l: 1 27 | replace_rate: 0.0 28 | drop_edge_rate: 0.0 29 | norm: batchnorm 30 | in_drop: 0.2 31 | attn_drop: 0.1 32 | optim: 33 | base_lr: 0.00015 34 | name : adam 35 | max_epoch : 100 36 | patience : 10 37 | weight_decay : 0.000 38 | classifier: 39 | base_lr: 0.005 40 | name : adam 41 | max_epoch : 1000 42 | patience : 150 43 | run : 10 44 | weight_decay : 0.00 45 | loss_type : 'mse' 46 | dataset: 47 | name : IMDB-BINARY 48 | root : pyg_data 49 | batch_size: 32 50 | 51 | 
-------------------------------------------------------------------------------- /configuration/graphmae_imdb_m.yml: -------------------------------------------------------------------------------- 1 | use_cuda : True 2 | gpu_idx : 0 3 | torch_seed: 0 4 | model: 5 | dropout: 0.5 6 | encoder_layers: 3 7 | in_channels: 3703 8 | hidden_channels: 512 9 | aug_type: 'mask' 10 | num_hidden: 32 11 | decoder_layers: 2 12 | lr: 0.0005 13 | weight_decay: 0 14 | weight_decay_f: 0 15 | max_epoch: 60 16 | max_epoch_f: 500 17 | mask_rate: 0.5 18 | num_layers: 2 19 | encoder_type: gin 20 | decoder_type: gin 21 | activation: prelu 22 | loss_fn: sce 23 | scheduler: False 24 | pooling: sum 25 | batch_size: 32 26 | alpha_l: 1 27 | replace_rate: 0.0 28 | drop_edge_rate: 0.0 29 | norm: batchnorm 30 | in_drop: 0.2 31 | attn_drop: 0.1 32 | optim: 33 | base_lr: 0.00015 34 | name : adam 35 | max_epoch : 60 36 | patience : 10 37 | weight_decay : 0.000 38 | classifier: 39 | base_lr: 0.005 40 | name : adam 41 | max_epoch : 2000 42 | patience : 150 43 | run : 10 44 | weight_decay : 0.00 45 | loss_type : 'mse' 46 | dataset: 47 | name : IMDB-MULTI 48 | root : pyg_data 49 | batch_size: 32 50 | 51 | -------------------------------------------------------------------------------- /configuration/graphmae_mutag.yml: -------------------------------------------------------------------------------- 1 | use_cuda : True 2 | gpu_idx : 0 3 | torch_seed: 0 4 | model: 5 | dropout: 0.5 6 | encoder_layers: 3 7 | in_channels: 3703 8 | hidden_channels: 32 9 | aug_type: 'mask' 10 | num_hidden: 32 11 | decoder_layers: 2 12 | lr: 0.0005 13 | weight_decay: 0.00 14 | mask_rate: 0.75 15 | drop_edge_rate: 0.0 16 | max_epoch: 20 17 | encoder_type: gin 18 | decoder_type: gin 19 | activation: prelu 20 | loss_fn: sce 21 | scheduler: False 22 | pooling: sum 23 | batch_size: 64 24 | alpha_l: 2 25 | replace_rate: 0.1 26 | norm: batchnorm 27 | in_drop: 0.2 28 | attn_drop: 0.1 29 | optim: 30 | base_lr: 0.00015 31 | name : adam 32 
| max_epoch : 50 33 | patience : 10 34 | weight_decay : 0.0 35 | classifier: 36 | base_lr: 0.005 37 | name : adam 38 | max_epoch : 1000 39 | patience : 150 40 | run : 10 41 | weight_decay : 0.00 42 | loss_type : 'mse' 43 | dataset: 44 | name : mutag 45 | root : pyg_data 46 | batch_size: 128 47 | 48 | -------------------------------------------------------------------------------- /configuration/graphregcl_amazon.yml: -------------------------------------------------------------------------------- 1 | use_cuda : True 2 | gpu_idx : 0 3 | model: 4 | dropout: 0.5 5 | n_layers: 2 6 | in_channels: 1433 7 | hidden_channels: 512 8 | mode: 4 9 | num_layers: 2 10 | num_hidden: 512 11 | num_proj_hidden: 512 12 | optim: 13 | lr: 0.001 14 | lr2: 0.01 15 | tau: 0.5 16 | dfr1: 0.1 17 | dfr2: 0.1 18 | der1: 0.1 19 | der2: 0.1 20 | lv: 1 21 | cutway: 2 22 | cutrate: 1.0 23 | wd: 0 24 | wd2: 0.0004 25 | name : adam 26 | max_epoch : 500 27 | patience : 10 28 | weight_decay : 0.0005 29 | dataset: 30 | name : photo 31 | root : pyg_data 32 | output: 33 | verbose : True 34 | save_dir : './saved_model' 35 | interval : 100 36 | classifier: 37 | base_lr: 0.01 38 | name : adam 39 | max_epoch : 300 40 | patience : 1000 41 | run : 10 42 | weight_decay : 0.0 -------------------------------------------------------------------------------- /configuration/graphregcl_coauthor.yml: -------------------------------------------------------------------------------- 1 | use_cuda : True 2 | gpu_idx : 0 3 | model: 4 | dropout: 0.5 5 | n_layers: 2 6 | in_channels: 1433 7 | hidden_channels: 512 8 | mode: 4 9 | num_layers: 2 10 | num_hidden: 512 11 | num_proj_hidden: 512 12 | optim: 13 | lr: 0.0001 14 | lr2: 0.01 15 | tau: 0.5 16 | dfr1: 0.1 17 | dfr2: 0.1 18 | der1: 0.1 19 | der2: 0.1 20 | lv: 1 21 | cutway: 2 22 | cutrate: 1.0 23 | wd: 0 24 | wd2: 0.0004 25 | name : adam 26 | max_epoch : 10 27 | patience : 30 28 | weight_decay : 0.0005 29 | dataset: 30 | name : cs 31 | root : pyg_data 32 | output: 33 | 
verbose : True 34 | save_dir : './saved_model' 35 | interval : 100 36 | classifier: 37 | base_lr: 0.001 38 | name : adam 39 | max_epoch : 2000 40 | patience : 1000 41 | run : 10 42 | weight_decay : 0.0 -------------------------------------------------------------------------------- /configuration/graphregcl_cora.yml: -------------------------------------------------------------------------------- 1 | use_cuda : True 2 | gpu_idx : 0 3 | model: 4 | dropout: 0.5 5 | n_layers: 2 6 | in_channels: 1433 7 | hidden_channels: 512 8 | mode: 4 9 | num_layers: 2 10 | num_hidden: 512 11 | num_proj_hidden: 512 12 | optim: 13 | lr: 0.0005 14 | lr2: 0.01 15 | tau: 0.2 16 | dfr1: 0.4 17 | dfr2: 0.3 18 | der1: 0.0 19 | der2: 0.4 20 | lv: 1 21 | cutway: 2 22 | cutrate: 1.0 23 | wd: 0 24 | wd2: 0.0004 25 | name : adam 26 | max_epoch : 50 27 | patience : 1000 28 | weight_decay : 0.0005 29 | dataset: 30 | name : cora 31 | root : pyg_data 32 | output: 33 | verbose : True 34 | save_dir : './saved_model' 35 | interval : 100 36 | aaa: 37 | aaaa: 0 38 | bbbb: 0 39 | classifier: 40 | base_lr: 0.001 41 | name : adam 42 | max_epoch : 3000 43 | patience : 1000 44 | run : 10 45 | weight_decay : 0.0005 -------------------------------------------------------------------------------- /configuration/graphregcl_wikics.yml: -------------------------------------------------------------------------------- 1 | use_cuda : True 2 | gpu_idx : 0 3 | model: 4 | dropout: 0.5 5 | n_layers: 2 6 | in_channels: 1433 7 | hidden_channels: 512 8 | mode: 4 9 | num_layers: 2 10 | num_hidden: 512 11 | num_proj_hidden: 512 12 | optim: 13 | lr: 0.0001 14 | lr2: 0.01 15 | tau: 0.5 16 | dfr1: 0.1 17 | dfr2: 0.1 18 | der1: 0.1 19 | der2: 0.1 20 | lv: 1 21 | cutway: 2 22 | cutrate: 1.0 23 | wd: 0 24 | wd2: 0.0004 25 | name : adam 26 | max_epoch : 50 27 | patience : 30 28 | weight_decay : 0.0005 29 | dataset: 30 | name : WikiCS 31 | root : pyg_data 32 | output: 33 | verbose : True 34 | save_dir : './saved_model' 35 | interval : 
100 36 | classifier: 37 | base_lr: 0.01 38 | name : adam 39 | max_epoch : 300 40 | patience : 1000 41 | run : 10 42 | weight_decay : 0.0 -------------------------------------------------------------------------------- /configuration/infograph_imdb_b.yml: -------------------------------------------------------------------------------- 1 | use_cuda : True 2 | gpu_idx : 0 3 | torch_seed: 0 4 | model: 5 | dropout: 0.5 6 | n_layers: 3 7 | in_channels: 3703 8 | hidden_channels: 32 9 | aug_type: 'mask' 10 | optim: 11 | base_lr: 0.001 12 | name : adam 13 | max_epoch : 10 14 | patience : 10 15 | weight_decay : 0.000 16 | classifier: 17 | base_lr: 0.01 18 | name : adam 19 | max_epoch : 300 20 | patience : 100 21 | run : 10 22 | weight_decay : 0.0005 23 | dataset: 24 | name : IMDB-B 25 | root : pyg_data 26 | batch_size: 128 27 | -------------------------------------------------------------------------------- /configuration/infograph_imdb_m.yml: -------------------------------------------------------------------------------- 1 | use_cuda : True 2 | gpu_idx : 0 3 | torch_seed: 1 4 | model: 5 | dropout: 0.5 6 | n_layers: 3 7 | in_channels: 3703 8 | hidden_channels: 32 9 | aug_type: 'mask' 10 | optim: 11 | base_lr: 0.00015 12 | name : adam 13 | max_epoch : 10 14 | patience : 10 15 | weight_decay : 0.000 16 | classifier: 17 | base_lr: 0.01 18 | name : adam 19 | max_epoch : 300 20 | patience : 100 21 | weight_decay : 0.000 22 | dataset: 23 | name : IMDB-M 24 | root : pyg_data 25 | batch_size: 128 26 | -------------------------------------------------------------------------------- /configuration/infograph_mutag.yml: -------------------------------------------------------------------------------- 1 | use_cuda : True 2 | gpu_idx : 0 3 | torch_seed: 0 4 | model: 5 | dropout: 0.5 6 | n_layers: 3 7 | in_channels: 3703 8 | hidden_channels: 32 9 | aug_type: 'mask' 10 | optim: 11 | base_lr: 0.001 12 | name : adam 13 | max_epoch : 20 14 | patience : 10 15 | weight_decay : 0.0 16 | 
classifier: 17 | base_lr: 0.01 18 | name : adam 19 | max_epoch : 200 20 | patience : 100 21 | run : 10 22 | weight_decay : 0.00 23 | dataset: 24 | name : mutag 25 | root : pyg_data 26 | batch_size: 128 27 | 28 | -------------------------------------------------------------------------------- /configuration/merit.yml: -------------------------------------------------------------------------------- 1 | dataset: "wikics" 2 | torch_seed: 0 3 | drop_edge: 0.4 4 | drop_feat1: 0.4 5 | drop_feat2: 0.4 6 | projection_size: 512 7 | prediction_size: 512 8 | prediction_hidden_size: 4096 9 | projection_hidden_size: 4096 10 | beta: 0.6 11 | momentum: 0.8 12 | alpha: 0.05 13 | sample_size: 2000 -------------------------------------------------------------------------------- /configuration/mvgrl_amazon.yml: -------------------------------------------------------------------------------- 1 | use_cuda : True 2 | gpu_idx : 0 3 | model: 4 | dropout: 0.5 5 | n_layers: 3 6 | in_channels: 745 7 | hidden_channels: 512 8 | alpha: 0.2 9 | t: 5 10 | batch_size: 8 11 | aug_type: 'ppr' 12 | optim: 13 | base_lr: 0.0005 14 | name : adam 15 | max_epoch : 3000 16 | patience : 200 17 | weight_decay : 0.0005 18 | run: 1 19 | classifier: 20 | base_lr: 0.001 21 | name : adam 22 | max_epoch : 30000 23 | patience : 500 24 | run : 10 25 | weight_decay : 0.0005 26 | dataset: 27 | name : Amazon 28 | root : pyg_data 29 | output: 30 | verbose : True 31 | save_dir : './saved_model' 32 | interval : 100 33 | -------------------------------------------------------------------------------- /configuration/mvgrl_coauthor.yml: -------------------------------------------------------------------------------- 1 | use_cuda : True 2 | gpu_idx : 0 3 | model: 4 | dropout: 0.5 5 | n_layers: 2 6 | in_channels: 6805 7 | hidden_channels: 512 8 | alpha: 0.2 9 | t: 5 10 | batch_size: 8 11 | aug_type: 'heat' 12 | optim: 13 | base_lr: 0.001 14 | name : adam 15 | max_epoch : 3000 16 | patience : 200 17 | weight_decay : 0.0005 18 | 
classifier: 19 | base_lr: 0.001 20 | name : adam 21 | max_epoch : 30000 22 | patience : 500 23 | run : 10 24 | weight_decay : 0.0005 25 | dataset: 26 | name : coauthor 27 | root : pyg_data 28 | output: 29 | verbose : True 30 | save_dir : './saved_model' 31 | interval : 100 32 | aaa: 33 | aaaa: 0 34 | bbbb: 0 35 | -------------------------------------------------------------------------------- /configuration/mvgrl_cora.yml: -------------------------------------------------------------------------------- 1 | use_cuda : True 2 | gpu_idx : 0 3 | model: 4 | dropout: 0.5 5 | n_layers: 2 6 | in_channels: 1433 7 | hidden_channels: 512 8 | alpha: 0.2 9 | t: 5 10 | batch_size: 8 11 | aug_type: 'ppr' 12 | optim: 13 | base_lr: 0.001 14 | name : adam 15 | max_epoch : 2000 16 | patience : 50 17 | weight_decay : 0.0005 18 | dataset: 19 | name : cora 20 | root : pyg_data 21 | output: 22 | verbose : True 23 | save_dir : './saved_model' 24 | interval : 100 25 | aaa: 26 | aaaa: 0 27 | bbbb: 0 28 | -------------------------------------------------------------------------------- /configuration/mvgrl_wikics.yml: -------------------------------------------------------------------------------- 1 | use_cuda : True 2 | gpu_idx : 0 3 | model: 4 | dropout: 0.5 5 | n_layers: 2 6 | in_channels: 300 7 | hidden_channels: 256 8 | alpha: 0.2 9 | t: 5 10 | batch_size: 8 11 | aug_type: 'ppr' 12 | optim: 13 | base_lr: 0.0001 14 | name : adam 15 | max_epoch : 3000 16 | patience : 200 17 | weight_decay : 0.0005 18 | classifier: 19 | base_lr: 0.001 20 | name : adam 21 | max_epoch : 3000 22 | patience : 1000 23 | run : 10 24 | weight_decay : 0.0005 25 | dataset: 26 | name : WikiCS 27 | root : pyg_data 28 | output: 29 | verbose : True 30 | save_dir : './saved_model' 31 | interval : 100 32 | aaa: 33 | aaaa: 0 34 | bbbb: 0 35 | -------------------------------------------------------------------------------- /configuration/sugrl_amazon.yml: 
-------------------------------------------------------------------------------- 1 | dataset: "photo" 2 | 3 | black_list: [4,5,6] 4 | lr: 0.01 5 | out_heads: 1 6 | task_type: Node_Transductive 7 | val_interval: 1 8 | num_hidden_features: 8 9 | # epochs: 500 (duplicate key; overridden by 'epochs: 1000' below) 10 | to_undirected_at_neg: True 11 | epochs: 1000 12 | w_loss1: 100 13 | w_loss2: 100 14 | w_loss3: 1 15 | margin1: 0.9 16 | margin2: 0.9 17 | dim: 128 18 | cfg: [512, 128] 19 | NewATop: 0 20 | dropout: 0.1 21 | NN: 1 22 | num1: 200 23 | wd: 0.0 24 | weight_decay: 0.0001 -------------------------------------------------------------------------------- /configuration/sugrl_coauthor.yml: -------------------------------------------------------------------------------- 1 | dataset: "coauthor" 2 | 3 | black_list: [4,5,6] 4 | lr: 0.01 5 | out_heads: 1 6 | task_type: Node_Transductive 7 | val_interval: 1 8 | num_hidden_features: 8 9 | # epochs: 500 (duplicate key; overridden by 'epochs: 1000' below) 10 | to_undirected_at_neg: True 11 | epochs: 1000 12 | w_loss1: 100 13 | w_loss2: 100 14 | w_loss3: 1 15 | margin1: 0.9 16 | margin2: 0.9 17 | dim: 128 18 | cfg: [512, 128] 19 | NewATop: 0 20 | dropout: 0.1 21 | NN: 1 22 | num1: 200 23 | wd: 0.0 24 | weight_decay: 0.0001 -------------------------------------------------------------------------------- /configuration/sugrl_wikics.yml: -------------------------------------------------------------------------------- 1 | dataset: "wikics" 2 | 3 | black_list: [4,5,6] 4 | lr: 0.0001 5 | out_heads: 1 6 | task_type: Node_Transductive 7 | val_interval: 1 8 | num_hidden_features: 8 9 | # epochs: 500 (duplicate key; overridden by 'epochs: 1000' below) 10 | to_undirected_at_neg: True 11 | epochs: 1000 12 | w_loss1: 100 13 | w_loss2: 100 14 | w_loss3: 2 15 | margin1: 0.4 16 | margin2: 0.6 17 | dim: 128 18 | cfg: [512, 128] 19 | NewATop: 0 20 | dropout: 0.1 21 | NN: 5 22 | num1: 300 23 | wd: 0.0 24 | weight_decay: 0.0001 -------------------------------------------------------------------------------- /examples/example_BGRL.py: 
-------------------------------------------------------------------------------- 1 | from src.augment import RandomDropEdge, AugmentorList, AugmentorDict, RandomMaskChannel 2 | from src.methods import BGRL, BGRLEncoder 3 | from src.trainer import NonContrastTrainer 4 | from torch_geometric.loader import DataLoader 5 | from src.evaluation import LogisticRegression 6 | from src.data.data_non_contrast import Dataset 7 | import torch, copy 8 | import numpy as np 9 | from src.config import load_yaml 10 | 11 | 12 | torch.manual_seed(0) 13 | np.random.seed(0) 14 | torch.cuda.manual_seed_all(0) 15 | config = load_yaml('./configuration/bgrl_photo.yml') 16 | device = torch.device("cuda:{}".format(config.gpu_idx) if torch.cuda.is_available() and config.use_cuda else "cpu") 17 | # WikiCS, cora, citeseer, pubmed, photo, computers, cs, and physics 18 | data_name = config.dataset.name 19 | root = config.dataset.root 20 | 21 | dataset = Dataset(root=root, name=data_name) 22 | if not hasattr(dataset, "adj_t"): 23 | data = dataset.data 24 | dataset.data.adj_t = torch.sparse_coo_tensor(data.edge_index, torch.ones_like(data.edge_index[0]), [data.x.shape[0], data.x.shape[0]]).coalesce() 25 | data_loader = DataLoader(dataset) 26 | 27 | # Augmentation 28 | augment_1 = AugmentorList([RandomDropEdge(config.model.aug_edge_1), RandomMaskChannel(config.model.aug_mask_1)]) 29 | augment_2 = AugmentorList([RandomDropEdge(config.model.aug_edge_2), RandomMaskChannel(config.model.aug_mask_2)]) 30 | augment = AugmentorDict({"augment_1":augment_1, "augment_2":augment_2}) 31 | 32 | # ------------------- Method ----------------- 33 | if data_name=="cora": 34 | student_encoder = BGRLEncoder(in_channel=dataset.x.shape[1], hidden_channels=[2048]) 35 | elif data_name=="photo": 36 | student_encoder = BGRLEncoder(in_channel=dataset.x.shape[1], hidden_channels=[512, 256]) 37 | elif data_name=="wikics": 38 | student_encoder = BGRLEncoder(in_channel=dataset.x.shape[1], hidden_channels=[512, 256]) 39 | elif 
data_name=="cs": 40 | student_encoder = BGRLEncoder(in_channel=dataset.x.shape[1], hidden_channels=[256]) 41 | 42 | teacher_encoder = copy.deepcopy(student_encoder) 43 | 44 | method = BGRL(student_encoder=student_encoder, teacher_encoder = teacher_encoder, data_augment=augment) 45 | 46 | # ------------------ Trainer -------------------- 47 | trainer = NonContrastTrainer(method=method, data_loader=data_loader, device=device, use_ema=True, \ 48 | moving_average_decay=config.optim.moving_average_decay, lr=config.optim.base_lr, 49 | weight_decay=config.optim.weight_decay, dataset=dataset, n_epochs=config.optim.max_epoch,\ 50 | patience=config.optim.patience) 51 | trainer.train() 52 | 53 | # ------------------ Evaluator ------------------- 54 | method.eval() 55 | data_pyg = dataset.data.to(method.device) 56 | embs = method.get_embs(data_pyg, data_pyg.edge_index).detach() 57 | 58 | lg = LogisticRegression(lr=0.01, weight_decay=0, max_iter=100, n_run=20, device=device) 59 | lg(embs=embs, dataset=data_pyg) 60 | -------------------------------------------------------------------------------- /examples/example_afgrl.py: -------------------------------------------------------------------------------- 1 | from src.augment import NeighborSearch_AFGRL 2 | from src.methods import AFGRLEncoder, AFGRL 3 | from src.trainer import NonContrastTrainer 4 | from torch_geometric.loader import DataLoader 5 | from torch_geometric.datasets import Planetoid, Amazon, WikiCS 6 | from src.evaluation import LogisticRegression 7 | import torch, copy 8 | import torch_geometric 9 | from src.data.data_non_contrast import Dataset 10 | import numpy as np 11 | from src.config import load_yaml 12 | 13 | 14 | torch.manual_seed(0) 15 | np.random.seed(0) 16 | torch.cuda.manual_seed_all(0) 17 | config = load_yaml('./configuration/afgrl_cs.yml') 18 | device = torch.device("cuda:{}".format(config.gpu_idx) if torch.cuda.is_available() and config.use_cuda else "cpu") 19 | 20 | # WikiCS, cora, citeseer, pubmed, 
photo, computers, cs, and physics 21 | data_name = config.dataset.name 22 | root = config.dataset.root 23 | 24 | dataset = Dataset(root=root, name=data_name) 25 | if not hasattr(dataset, "adj_t"): 26 | data = dataset.data 27 | dataset.data.adj_t = torch.sparse_coo_tensor(data.edge_index, torch.ones_like(data.edge_index[0]), [data.x.shape[0], data.x.shape[0]]).coalesce() 28 | data_loader = DataLoader(dataset) 29 | data = dataset.data 30 | 31 | adj_ori_sparse = torch.sparse_coo_tensor(data.edge_index, torch.ones_like(data.edge_index[0]), [data.x.shape[0], data.x.shape[0]]).to(device).coalesce() 32 | 33 | # Augmentation 34 | augment = NeighborSearch_AFGRL(device=device, num_centroids=config.model.num_centroids, num_kmeans=config.model.num_kmeans, clus_num_iters=config.model.clus_num_iters) 35 | 36 | # ------------------- Method ----------------- 37 | if data_name=="cora": 38 | student_encoder = AFGRLEncoder(in_channel=dataset.x.shape[1], hidden_channels=[2048]) 39 | elif data_name=="photo": 40 | student_encoder = AFGRLEncoder(in_channel=dataset.x.shape[1], hidden_channels=[512, 512]) 41 | elif data_name=="wikics": 42 | student_encoder = AFGRLEncoder(in_channel=dataset.x.shape[1], hidden_channels=[512, 256]) 43 | elif data_name=="cs": 44 | student_encoder = AFGRLEncoder(in_channel=dataset.x.shape[1], hidden_channels=[512, 256]) 45 | teacher_encoder = copy.deepcopy(student_encoder) 46 | 47 | method = AFGRL(student_encoder=student_encoder, teacher_encoder = teacher_encoder, data_augment=augment, adj_ori = adj_ori_sparse, topk=config.model.topk) 48 | 49 | 50 | # ------------------ Trainer -------------------- 51 | trainer = NonContrastTrainer(method=method, data_loader=data_loader, device=device, use_ema=True, 52 | moving_average_decay=config.optim.moving_average_decay, lr=config.optim.base_lr, 53 | weight_decay=config.optim.weight_decay, n_epochs=config.optim.max_epoch, dataset=dataset, 54 | patience=config.optim.patience) 55 | trainer.train() 56 | 57 | 58 | # 
------------------ Evaluator ------------------- 59 | data_pyg = dataset.data.to(method.device) 60 | embs = method.get_embs(data_pyg).detach() 61 | 62 | lg = LogisticRegression(lr=0.01, weight_decay=0, max_iter=100, n_run=20, device=device) 63 | lg(embs=embs, dataset=data_pyg) -------------------------------------------------------------------------------- /examples/example_dgi.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.loader import DataLoader 2 | from torch_geometric.datasets import Planetoid 3 | from src.transforms import NormalizeFeatures, GCNNorm, Edge2Adj, Compose 4 | from src.methods import DGI, DGIEncoder 5 | from src.trainer import SimpleTrainer 6 | 7 | from src.evaluation import LogisticRegression 8 | 9 | 10 | # -------------------- Data -------------------- 11 | pre_transforms = Compose([NormalizeFeatures(ord=1), Edge2Adj(norm=GCNNorm(add_self_loops=1))]) 12 | dataset = Planetoid(root="pyg_data", name="cora", pre_transform=pre_transforms) 13 | data_loader = DataLoader(dataset) 14 | 15 | 16 | encoder = DGIEncoder(in_channels=1433, hidden_channels=512) 17 | method = DGI(encoder=encoder, hidden_channels=512) 18 | 19 | # ------------------ Trainer -------------------- 20 | trainer = SimpleTrainer(method=method, data_loader=data_loader, device="cuda:0") 21 | trainer.train() 22 | 23 | # ------------------ Evaluator ------------------- 24 | data_pyg = dataset.data.to(method.device) 25 | embs = method.get_embs(data_pyg) 26 | 27 | lg = LogisticRegression(lr=0.01, weight_decay=0, max_iter=100, n_run=50, device="cuda") 28 | lg(embs=embs, dataset=data_pyg) 29 | -------------------------------------------------------------------------------- /examples/example_gca.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from src.config import load_yaml 4 | 5 | import torch_geometric.transforms as T 6 | from torch_geometric.loader import 
DataLoader 7 | 8 | from src.methods.gca import GCA_Encoder, GRACE 9 | from torch_geometric.datasets import WikiCS, Amazon, Coauthor 10 | from src.augment.gca_augments import get_activation, get_base_model 11 | from src.trainer.gca_trainer import GCATrainer 12 | 13 | from src.evaluation import LogisticRegression 14 | from src.utils import create_data 15 | 16 | dataset_name = 'Amazon-Photo' 17 | dataset = WikiCS(root='pyg_data', transform=T.NormalizeFeatures()) 18 | data_loader = DataLoader(dataset) 19 | if dataset_name == 'Amazon-Photo': 20 | dataset = Amazon(root='pyg_data', name='photo', transform=T.NormalizeFeatures()) 21 | data_loader = DataLoader(dataset) 22 | config = load_yaml('./configuration/gca_amazon.yml') 23 | if dataset_name == 'Coauthor-CS': 24 | dataset = Coauthor(root='pyg_data', name='cs', transform=T.NormalizeFeatures()) 25 | data_loader = DataLoader(dataset) 26 | config = load_yaml('./configuration/gca_coauthor.yml') 27 | 28 | 29 | # ---> load model 30 | encoder = GCA_Encoder(in_channels=dataset.num_features, out_channels=config.model.out_channels, activation=get_activation(config.model.activation), base_model=get_base_model('GCNConv'), k=2) 31 | method = GRACE(encoder=encoder, loss_function=None, num_hidden=config.model.hidden_channels, num_proj_hidden=config.model.num_proj_hidden, tau=config.model.tau) 32 | 33 | # ---> train 34 | trainer = GCATrainer(method=method, data_loader=data_loader, drop_scheme='degree', dataset_name='WikiCS', device="cuda:{}".format(config.gpu_idx), n_epochs=1000) 35 | trainer.train() 36 | 37 | # ---> evaluation 38 | data = dataset.data.to(method.device) 39 | embs = method.get_embs(data) 40 | lg = LogisticRegression(lr=config.classifier.base_lr, weight_decay=0, max_iter=100, n_run=config.classifier.run, device="cuda:{}".format(config.gpu_idx)) 41 | if dataset_name == 'WikiCS': 42 | data.train_mask = torch.transpose(data.train_mask, 0, 1) 43 | data.val_mask = torch.transpose(data.val_mask, 0, 1) 44 | if dataset_name == 
'Amazon-Photo' or dataset_name == 'Coauthor-CS': 45 | data = create_data.create_masks(data.cpu()) 46 | lg(embs=embs, dataset=data) -------------------------------------------------------------------------------- /examples/example_graphcl.py: -------------------------------------------------------------------------------- 1 | from src.augment import * 2 | from src.methods import GraphCL, GraphCLEncoder 3 | from src.trainer import SimpleTrainer 4 | from torch_geometric.loader import DataLoader 5 | from src.evaluation import LogisticRegression 6 | import torch 7 | from src.config import load_yaml 8 | from src.utils.create_data import create_masks 9 | from src.utils.add_adj import add_adj_t 10 | from src.data.data_non_contrast import Dataset 11 | 12 | # load the configuration file 13 | config = load_yaml('./configuration/graphcl_wikics.yml') 14 | torch.manual_seed(config.torch_seed) 15 | device = torch.device("cuda:{}".format(config.gpu_idx) if torch.cuda.is_available() and config.use_cuda else "cpu") 16 | 17 | data_name = config.dataset.name 18 | root = config.dataset.root 19 | dataset = Dataset(root=root, name=data_name) 20 | if config.dataset.name in ['Amazon', 'WikiCS', 'coauthor']: 21 | dataset.data = create_masks(dataset.data, config.dataset.name) 22 | dataset = add_adj_t(dataset) 23 | data_loader = DataLoader(dataset) 24 | 25 | # Augmentation 26 | aug_type = config.model.aug_type 27 | if aug_type == 'edge': 28 | augment_neg = AugmentorList([RandomDropEdge()]) 29 | elif aug_type == 'mask': 30 | augment_neg = AugmentorList([RandomMask()]) 31 | elif aug_type == 'node': 32 | augment_neg = AugmentorList([RandomDropNode()]) 33 | elif aug_type == 'subgraph': 34 | augment_neg = AugmentorList([AugmentSubgraph()]) 35 | else: 36 | assert 'unrecognized augmentation method' 37 | # 38 | # ------------------- Method ----------------- 39 | encoder = GraphCLEncoder(in_channels=config.model.in_channels, hidden_channels=config.model.hidden_channels) 40 | method = 
GraphCL(encoder=encoder, corruption=augment_neg, hidden_channels=config.model.hidden_channels) 41 | method.augment_type = aug_type 42 | 43 | 44 | # ------------------ Trainer -------------------- 45 | trainer = SimpleTrainer(method=method, data_loader=data_loader, 46 | device=device, n_epochs=config.optim.max_epoch, 47 | patience=config.optim.patience) 48 | trainer.train() 49 | 50 | 51 | # ------------------ Evaluator ------------------- 52 | method.eval() 53 | # renew dataset to get back train_mask and value_mask from .T 54 | dataset = Dataset(root=root, name=data_name) 55 | dataset = add_adj_t(dataset) 56 | data_pyg = dataset.data.to(method.device) 57 | embs = method.get_embs(data_pyg) 58 | lg = LogisticRegression(lr=config.classifier.base_lr, weight_decay=config.classifier.weight_decay, 59 | max_iter=config.classifier.max_epoch, n_run=1, device=device) 60 | lg(embs=embs, dataset=data_pyg) 61 | 62 | 63 | 64 | 65 | 66 | 67 | -------------------------------------------------------------------------------- /examples/example_graphcl_molecure.py: -------------------------------------------------------------------------------- 1 | from src.augment import RandomMask, RandomDropEdge, RandomDropNode, AugmentSubgraph, AugmentorList 2 | from src.methods.infograph import InfoGraph, Encoder 3 | from src.trainer import SimpleTrainer 4 | from torch_geometric.loader import DataLoader 5 | from torch_geometric.datasets import TUDataset 6 | from src.evaluation import LogisticRegression 7 | from src.config import load_yaml 8 | from torch_geometric.nn import GINConv 9 | import torch 10 | import os 11 | import numpy as np 12 | 13 | 14 | torch.manual_seed(0) 15 | config = load_yaml('./configuration/infograph_mutag.yml') 16 | torch.manual_seed(config.torch_seed) 17 | np.random.seed(config.torch_seed) 18 | device = torch.device("cuda:{}".format(config.gpu_idx) if torch.cuda.is_available() and config.use_cuda else "cpu") 19 | 20 | # data 21 | # -------------------- Data 
-------------------- 22 | path = os.path.join(os.path.dirname(os.path.realpath(__file__)), config.dataset.root) 23 | if config.dataset.name in ['IMDB-B', 'IMDB-M', 'mutag', 'COLLAB', 'PROTEINS']: 24 | dataset = TUDataset(path, name=config.dataset.name.upper()) 25 | else: 26 | raise NotImplementedError 27 | data_loader = DataLoader(dataset, batch_size=config.dataset.batch_size) 28 | 29 | in_channels = max(dataset.num_features, 1) 30 | 31 | ### to handle molecular graph, graphcl use infograph as the backbone. ### 32 | # ------------------- Method ----------------- 33 | # 34 | encoder = Encoder(in_channels=in_channels, hidden_channels=config.model.hidden_channels, 35 | num_layers=config.model.n_layers, GNN=GINConv) 36 | method = InfoGraph(encoder=encoder, hidden_channels=config.model.hidden_channels, num_layers=config.model.n_layers, 37 | prior=False) 38 | 39 | # Augmentation 40 | aug_type = 'mask' 41 | if aug_type == 'edge': 42 | augment_neg = RandomDropEdge() 43 | elif aug_type == 'mask': 44 | augment_neg = RandomMask() 45 | elif aug_type == 'node': 46 | augment_neg = RandomDropNode() 47 | elif aug_type == 'subgraph': 48 | augment_neg = AugmentSubgraph() 49 | else: 50 | assert False 51 | augment_neg = AugmentorList([RandomDropEdge(), RandomMask()]) 52 | 53 | 54 | # ------------------ Trainer -------------------- 55 | trainer = SimpleTrainer(method=method, data_loader=data_loader, device="cuda:0") 56 | trainer.train() 57 | 58 | 59 | # ------------------ Evaluator ------------------- 60 | method.eval() 61 | data_pyg = dataset.data.to(method.device) 62 | y, embs = method.get_embs(data_loader) 63 | 64 | data_pyg.x = embs 65 | lg = LogisticRegression(lr=config.classifier.base_lr, weight_decay=config.classifier.weight_decay, 66 | max_iter=config.classifier.max_epoch, n_run=1, device=device) 67 | lg(embs=embs, dataset=data_pyg) 68 | 69 | 70 | 71 | 72 | 73 | 74 | -------------------------------------------------------------------------------- /examples/example_graphmae.py: 
-------------------------------------------------------------------------------- 1 | from src.methods.graphmae import GraphMAE, EncoderDecoder, load_graph_classification_dataset, setup_loss_fn, collate_fn 2 | from src.trainer import SimpleTrainer 3 | from dgl.dataloading import GraphDataLoader 4 | from torch.utils.data.sampler import SubsetRandomSampler 5 | from dgl.nn.pytorch.glob import SumPooling, AvgPooling, MaxPooling 6 | import torch 7 | import numpy as np 8 | from src.config import load_yaml 9 | import os 10 | 11 | config = load_yaml('./configuration/graphmae_imdb_m.yml') 12 | torch.manual_seed(config.torch_seed) 13 | np.random.seed(config.torch_seed) 14 | device = torch.device("cuda:{}".format(config.gpu_idx) if torch.cuda.is_available() and config.use_cuda else "cpu") 15 | 16 | current_folder = os.path.abspath('') 17 | print(current_folder) 18 | path = os.path.join(os.path.dirname(os.path.realpath(__file__)), config.dataset.root, config.dataset.name) 19 | # -------------------- Data -------------------- 20 | dataset, num_features = load_graph_classification_dataset(config.dataset.name, raw_dir=path) 21 | train_idx = torch.arange(len(dataset)) 22 | train_sampler = SubsetRandomSampler(train_idx) 23 | eval_loader = GraphDataLoader(dataset, collate_fn=collate_fn, batch_size=config.dataset.batch_size, shuffle=False) 24 | in_channels = max(num_features, 1) 25 | 26 | # ------------------- Method ----------------- 27 | pooling = 'mean' 28 | if pooling == "mean": 29 | pooler = AvgPooling() 30 | elif pooling == "max": 31 | pooler = MaxPooling() 32 | elif pooling == "sum": 33 | pooler = SumPooling() 34 | else: 35 | raise NotImplementedError 36 | encoder = EncoderDecoder(GNN=config.model.encoder_type, enc_dec="encoding", in_channels=in_channels, 37 | hidden_channels=config.model.hidden_channels, num_layers=config.model.encoder_layers) 38 | decoder = EncoderDecoder(GNN=config.model.decoder_type, enc_dec="decoding", in_channels=config.model.hidden_channels, 39 | 
hidden_channels=in_channels, num_layers=config.model.decoder_layers) 40 | loss_function = setup_loss_fn(config.model.loss_fn, alpha_l=config.model.alpha_l) 41 | method = GraphMAE(encoder=encoder, decoder=decoder, hidden_channels=512, argument=config, loss_function=loss_function) 42 | method.device = device 43 | 44 | # ------------------ Trainer -------------------- 45 | trainer = SimpleTrainer(method=method, data_loader=dataset, device=device, n_epochs=config.optim.max_epoch, 46 | lr=config.optim.base_lr) 47 | trainer.train() 48 | 49 | # ------------------ Evaluation ------------------- 50 | method.evaluation(pooler, eval_loader) 51 | 52 | -------------------------------------------------------------------------------- /examples/example_infograph.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.loader import DataLoader 2 | from src.methods.infograph import InfoGraph, Encoder 3 | from src.trainer import SimpleTrainer 4 | from src.evaluation import LogisticRegression 5 | from torch_geometric.datasets import TUDataset 6 | import os 7 | from torch_geometric.nn import GINConv 8 | from src.config import load_yaml 9 | import torch 10 | import numpy as np 11 | 12 | 13 | config = load_yaml('./configuration/infograph_mutag.yml') 14 | torch.manual_seed(config.torch_seed) 15 | np.random.seed(config.torch_seed) 16 | device = torch.device("cuda:{}".format(config.gpu_idx) if torch.cuda.is_available() and config.use_cuda else "cpu") 17 | 18 | # -------------------- Data -------------------- 19 | path = os.path.join(os.path.dirname(os.path.realpath(__file__)), config.dataset.root) 20 | if config.dataset.name in ['IMDB-B', 'IMDB-M', 'mutag', 'COLLAB', 'PROTEINS']: 21 | dataset = TUDataset(path, name=config.dataset.name.upper()) 22 | else: 23 | raise NotImplementedError 24 | 25 | data_loader = DataLoader(dataset, batch_size=config.dataset.batch_size) 26 | 27 | in_channels = max(dataset.num_features, 1) 28 | 29 | # 
------------------- Method ----------------- 30 | # 31 | encoder = Encoder(in_channels=in_channels, hidden_channels=config.model.hidden_channels, 32 | num_layers=config.model.n_layers, GNN=GINConv) 33 | method = InfoGraph(encoder=encoder, hidden_channels=config.model.hidden_channels, num_layers=config.model.n_layers, 34 | prior=False) 35 | 36 | # # # ------------------ Trainer -------------------- 37 | trainer = SimpleTrainer(method=method, data_loader=data_loader, device=device, n_epochs=config.optim.max_epoch, 38 | lr=config.optim.base_lr) 39 | trainer.train() 40 | 41 | # ------------------ Evaluator ------------------- 42 | method.eval() 43 | data_pyg = dataset.data.to(method.device) 44 | y, embs = method.get_embs(data_loader) 45 | 46 | data_pyg.x = embs 47 | lg = LogisticRegression(lr=config.classifier.base_lr, weight_decay=config.classifier.weight_decay, 48 | max_iter=config.classifier.max_epoch, n_run=1, device=device) 49 | lg(embs=embs, dataset=data_pyg) 50 | -------------------------------------------------------------------------------- /examples/example_merit.py: -------------------------------------------------------------------------------- 1 | from src.methods import GCN, Merit 2 | from src.trainer import SimpleTrainer 3 | from torch_geometric.loader import DataLoader 4 | import torch_geometric.transforms as T 5 | from src.transforms import NormalizeFeatures, GCNNorm, Edge2Adj, Compose 6 | from torch_geometric.datasets import Amazon, WikiCS,Coauthor 7 | from src.utils.create_data import create_masks 8 | from src.evaluation import LogisticRegression 9 | import torch 10 | import yaml 11 | from src.utils.add_adj import add_adj_t 12 | from sklearn.impute import SimpleImputer 13 | 14 | 15 | torch.manual_seed(0) 16 | config = yaml.safe_load(open("./configuration/merit.yml", 'r', encoding='utf-8').read()) 17 | # data 18 | pre_transforms = Compose([NormalizeFeatures(ord=1), Edge2Adj(norm=GCNNorm(add_self_loops=1))]) 19 | data_name = config['dataset'] 20 | 21 | 
if data_name=="photo": #91.4101 22 | dataset = Amazon(root="pyg_data", name="photo", pre_transform=pre_transforms) 23 | elif data_name=="coauthor": # 92.0973 24 | dataset = Coauthor(root="pyg_data", name='cs', transform=pre_transforms) 25 | elif data_name=="wikics": #82.0109 26 | dataset = WikiCS(root="pyg_data", transform=T.NormalizeFeatures()) 27 | dataset = add_adj_t(dataset) 28 | nan_mask = torch.isnan(dataset[0].x) 29 | imputer = SimpleImputer() 30 | dataset[0].x = torch.tensor(imputer.fit_transform(dataset[0].x)) 31 | 32 | data_loader = DataLoader(dataset) 33 | data = dataset.data 34 | 35 | 36 | # ------------------- Method ----------------- 37 | encoder = GCN(in_ft=data.x.shape[1], out_ft=512, projection_hidden_size=config["projection_hidden_size"], 38 | projection_size=config["projection_size"]) 39 | method = Merit(encoder=encoder, data = data, config=config,device="cuda:0",is_sparse=True) 40 | 41 | # ------------------ Trainer -------------------- 42 | trainer = SimpleTrainer(method=method, data_loader=data_loader, device="cuda:0") 43 | trainer.train() 44 | 45 | 46 | # ------------------ Evaluator ------------------- 47 | data_pyg = dataset.data.to(method.device) 48 | embs = method.get_embs(data_pyg, data_pyg.adj_t).detach() 49 | 50 | lg = LogisticRegression(lr=0.01, weight_decay=0, max_iter=2000, n_run=50, device="cuda") 51 | create_masks(data=data_pyg.cpu()) 52 | lg(embs=embs, dataset=data_pyg) 53 | -------------------------------------------------------------------------------- /examples/example_mvgrl.py: -------------------------------------------------------------------------------- 1 | from src.augment import ComputePPR, ComputeHeat 2 | from src.methods import MVGRL, MVGRLEncoder 3 | from src.trainer import SimpleTrainer 4 | from torch_geometric.loader import DataLoader 5 | from src.transforms import NormalizeFeatures, GCNNorm, Edge2Adj, Compose 6 | from torch_geometric.datasets import Planetoid, Amazon, WikiCS, Coauthor 7 | from src.evaluation 
"""Train MVGRL on a node-classification dataset and evaluate the embeddings."""
from src.augment import ComputePPR, ComputeHeat
from src.methods import MVGRL, MVGRLEncoder
from src.trainer import SimpleTrainer
from torch_geometric.loader import DataLoader
from src.transforms import NormalizeFeatures, GCNNorm, Edge2Adj, Compose
from torch_geometric.datasets import Planetoid, Amazon, WikiCS, Coauthor
from src.evaluation import LogisticRegression
import torch
from src.config import load_yaml
from src.utils.create_data import create_masks
from src.utils.add_adj import add_adj_t

# load the configuration file
config = load_yaml('./configuration/mvgrl_cora.yml')
torch.manual_seed(config.torch_seed)
# Use the configured GPU when CUDA is usable and requested; otherwise fall back
# to the CPU (the previous fallback was "cuda", which fails without a GPU).
use_gpu = torch.cuda.is_available() and config.use_cuda
device = torch.device("cuda:{}".format(config.gpu_idx) if use_gpu else "cpu")


# -------------------- Data --------------------
if config.dataset.name == 'cora':
    pre_transforms = Compose([NormalizeFeatures(ord=1), Edge2Adj(norm=GCNNorm(add_self_loops=1))])
    dataset = Planetoid(root='pyg_data', name='cora', pre_transform=pre_transforms)
elif config.dataset.name == 'Amazon':
    pre_transforms = NormalizeFeatures(ord=1)
    dataset = Amazon(root='pyg_data', name='Photo', pre_transform=pre_transforms)
elif config.dataset.name == 'WikiCS':
    pre_transforms = NormalizeFeatures(ord=1)
    dataset = WikiCS(root='pyg_data', pre_transform=pre_transforms)
elif config.dataset.name == 'coauthor':
    pre_transforms = NormalizeFeatures(ord=1)
    dataset = Coauthor(root='pyg_data', name='CS', pre_transform=pre_transforms)
else:
    # Raising a bare string is itself a TypeError in Python 3; raise a real exception.
    raise ValueError('please specify the correct dataset root')
if config.dataset.name in ['Amazon', 'WikiCS', 'coauthor']:
    dataset.data = create_masks(dataset.data, config.dataset.name)
dataset = add_adj_t(dataset)
data_loader = DataLoader(dataset, batch_size=config.model.batch_size)


# ------------------- Method -----------------
encoder = MVGRLEncoder(in_channels=config.model.in_channels, hidden_channels=config.model.hidden_channels)
# The second view is a diffusion of the adjacency: heat kernel or personalised PageRank.
if config.model.aug_type == 'heat':
    diff = ComputeHeat(t=config.model.t)
else:
    diff = ComputePPR(alpha=config.model.alpha)
method = MVGRL(encoder=encoder, diff=diff, hidden_channels=config.model.hidden_channels)


# ------------------ Trainer --------------------
trainer = SimpleTrainer(method=method, data_loader=data_loader, device=device,
                        n_epochs=config.optim.max_epoch, patience=config.optim.patience)
trainer.train()


# ------------------ Evaluator -------------------
data_pyg = dataset.data.to(method.device)
data_neg = method.corrput(data_pyg).to(method.device)
embs = method.get_embs(data_pyg, data_neg)

lg = LogisticRegression(lr=0.01, weight_decay=0, max_iter=100, n_run=50, device=device)
lg(embs=embs, dataset=data_pyg)
cutway = config.optim.cutway 35 | cutrate = config.optim.cutrate 36 | num_hidden = config.model.num_hidden 37 | num_proj_hidden = config.model.num_proj_hidden 38 | activation = F.relu 39 | base_model = GCNConv 40 | num_layers = config.model.num_layers 41 | num_epochs = config.optim.max_epoch 42 | weight_decay = config.optim.wd 43 | weight_decay2 = config.optim.wd2 44 | 45 | torch.manual_seed(config.torch_seed) 46 | device = torch.device("cuda:{}".format(config.gpu_idx) if torch.cuda.is_available() and config.use_cuda else "cpu") 47 | 48 | data_name = config.dataset.name 49 | root = config.dataset.root 50 | 51 | 52 | dataset = Dataset(root=root, name=data_name, transform=NormalizeFeatures()) 53 | 54 | data = dataset.data.to(device) 55 | data_loader = DataLoader(dataset, ) 56 | 57 | # 58 | # ------------------- Method ----------------- 59 | encoder = ReGCLEncoder(dataset.num_features, num_hidden, activation, mode, base_model=base_model, k=num_layers, cutway=cutway, cutrate=cutrate, tau=tau).to(device) 60 | method = ReGCL(config, encoder, num_hidden, num_proj_hidden, mode, tau).to(device) 61 | 62 | 63 | # ------------------ Trainer -------------------- 64 | trainer = SimpleTrainer(method=method, data_loader=data_loader, 65 | device=device, 66 | n_epochs=config.optim.max_epoch, 67 | patience=config.optim.patience) 68 | trainer.train() 69 | 70 | 71 | # ------------------ Evaluator ------------------- 72 | method.eval() 73 | data_pyg = dataset.data.to(method.device) 74 | y, embs = method.get_embs(data) 75 | 76 | lg = LogisticRegression(lr=config.classifier.base_lr, weight_decay=config.classifier.weight_decay, 77 | max_iter=config.classifier.max_epoch, 78 | n_run=1, device=device) 79 | lg(embs=embs, dataset=data_pyg) -------------------------------------------------------------------------------- /examples/example_sugrl.py: -------------------------------------------------------------------------------- 1 | from src.methods import SugrlMLP, SugrlGCN 2 | from src.methods 
import SUGRL 3 | from src.trainer import SimpleTrainer 4 | from torch_geometric.loader import DataLoader 5 | import torch_geometric.transforms as T 6 | from src.transforms import NormalizeFeatures, GCNNorm, Edge2Adj, Compose 7 | from torch_geometric.datasets import Planetoid, Amazon, WikiCS,Coauthor 8 | from src.utils.create_data import create_masks 9 | from src.evaluation import LogisticRegression 10 | import torch 11 | import yaml 12 | from src.utils.add_adj import add_adj_t 13 | from sklearn.impute import SimpleImputer 14 | 15 | torch.manual_seed(0) 16 | 17 | pre_transforms = Compose([NormalizeFeatures(ord=1), Edge2Adj(norm=GCNNorm(add_self_loops=1))]) 18 | 19 | # load the configuration file 20 | config = yaml.safe_load(open("./configuration/sugrl_amazon.yml", 'r', encoding='utf-8').read()) 21 | print(config) 22 | data_name = config['dataset'] 23 | 24 | if data_name=="cora": 25 | dataset = Planetoid(root="pyg_data", name="cora", pre_transform=pre_transforms) 26 | if data_name=="photo": #92.9267 27 | dataset = Amazon(root="pyg_data", name="photo", pre_transform=pre_transforms) 28 | elif data_name=="coauthor": # 92.0973 29 | dataset = Coauthor(root="pyg_data", name='cs', transform=pre_transforms) 30 | elif data_name=="wikics": #82.0109 31 | dataset = WikiCS(root="pyg_data", transform=T.NormalizeFeatures()) 32 | dataset = add_adj_t(dataset) 33 | nan_mask = torch.isnan(dataset[0].x) 34 | imputer = SimpleImputer() 35 | dataset[0].x = torch.tensor(imputer.fit_transform(dataset[0].x)) 36 | 37 | if not hasattr(dataset.data, 'adj_t'): 38 | dataset = add_adj_t(dataset) 39 | 40 | data_loader = DataLoader(dataset) 41 | data = dataset.data 42 | 43 | # ------------------- Method ----------------- 44 | encoder_1 = SugrlMLP(in_channels=data.x.shape[1]) 45 | encoder_2 = SugrlGCN(in_channels=data.x.shape[1]) 46 | method = SUGRL(encoder=[encoder_1,encoder_2],data = data, config=config,device="cuda:0") 47 | 48 | 49 | # ------------------ Trainer -------------------- 50 | trainer = 
"""Packaging script for pyg-ssl."""
from setuptools import setup, find_packages
from pathlib import Path

# Read the contents of the README file.
# The encoding is explicit so the build does not depend on the platform's
# default locale encoding.
this_directory = Path(__file__).parent
long_description = (this_directory / "README.md").read_text(encoding="utf-8")

setup(
    name="pyg-ssl",
    version="0.1.0",
    description="A PyTorch Geometric-based package for self-supervised learning on graphs",
    long_description=long_description,
    long_description_content_type="text/markdown",
    author="iDEA-iSAIL LAB @ UIUC",
    author_email="violet24k@outlook.com",
    url="https://github.com/iDEA-iSAIL-Lab-UIUC/pyg-ssl",
    packages=find_packages(where="src"),  # Ensure packages are found in the src directory
    package_dir={"": "src"},  # Map the source directory to the package root
    classifiers=[
        "Development Status :: 3 - Alpha",
        "Intended Audience :: Developers",
        "Intended Audience :: Science/Research",
        "License :: OSI Approved :: MIT License",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
    ],
    python_requires=">=3.9",
    install_requires=[
        "torch==2.3.1",
        "torch_geometric==2.6.1",
        "faiss-gpu==1.7.2",
        "scikit-learn",
        "matplotlib",
        "seaborn",
        "yacs",
    ],
    # NOTE(review): recent pip releases ignore ``dependency_links``; confirm
    # how dgl is expected to be installed.
    dependency_links=[
        "https://data.dgl.ai/wheels/torch-2.3/cu121/repo.html#egg=dgl"
    ],
    extras_require={
        "dev": [
            "pytest",
            "flake8",
            "black",
        ],
    },
    # NOTE(review): no ``pyg_ssl.cli`` module is visible in the listed source
    # tree; verify this entry point resolves after installation.
    entry_points={
        "console_scripts": [
            "pyg-ssl-cli=pyg_ssl.cli:main",
        ],
    },
    include_package_data=True,
    zip_safe=False,
)
""" -------------------------------------------------------------------------------- /src/augment/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import Augmentor, AugmentorList, AugmentorDict 2 | from .collections import * 3 | from .echo import Echo 4 | from .shuffle_node import ShuffleNode 5 | from .augment_subgraph import AugmentSubgraph 6 | from .random_drop_node import RandomDropNode 7 | from .random_mask import RandomMask 8 | from .random_mask_channel import RandomMaskChannel 9 | from .random_drop_edge import RandomDropEdge 10 | from .concate_aug import concate_aug_type 11 | from .diff_matrix import ComputeHeat, ComputePPR 12 | from .augment_afgrl import NeighborSearch_AFGRL 13 | from .sum_emb import SumEmb 14 | -------------------------------------------------------------------------------- /src/augment/augment_afgrl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | # from kmeans_pytorch import kmeans 4 | import faiss 5 | from .base import Augmentor 6 | class NeighborSearch_AFGRL(Augmentor): 7 | def __init__(self, device="cuda", num_centroids=100, num_kmeans=5, clus_num_iters=20): 8 | super(NeighborSearch_AFGRL, self).__init__() 9 | self.device = device 10 | self.num_centroids = num_centroids 11 | self.num_kmeans = num_kmeans 12 | self.clus_num_iters = clus_num_iters 13 | 14 | def __get_close_nei_in_back(self, indices, each_k_idx, cluster_labels, back_nei_idxs, k): 15 | # get which neighbors are close in the background set 16 | batch_labels = cluster_labels[each_k_idx][indices] 17 | top_cluster_labels = cluster_labels[each_k_idx][back_nei_idxs] 18 | batch_labels = self.repeat_1d_tensor(batch_labels, k) 19 | 20 | curr_close_nei = torch.eq(batch_labels, top_cluster_labels) 21 | return curr_close_nei 22 | 23 | def repeat_1d_tensor(self, t, num_reps): 24 | return t.unsqueeze(1).expand(-1, num_reps) 25 | 26 | def 
__call__(self, adj, student, teacher, top_k): 27 | n_data, d = student.shape 28 | similarity = torch.matmul(student, torch.transpose(teacher, 1, 0).detach()) 29 | similarity += torch.eye(n_data, device=self.device) * 10 30 | 31 | _, I_knn = similarity.topk(k=top_k, dim=1, largest=True, sorted=True) 32 | 33 | knn_neighbor = self.create_sparse(I_knn) 34 | locality = knn_neighbor * adj 35 | 36 | 37 | ncentroids = self.num_centroids 38 | niter = self.clus_num_iters 39 | tmp = torch.LongTensor(np.arange(n_data)).unsqueeze(-1).to(self.device) 40 | pred_labels = [] 41 | 42 | # import pdb; pdb.set_trace() 43 | for seed in range(self.num_kmeans): 44 | kmeans = faiss.Kmeans(d, ncentroids, niter=niter, gpu=False, seed=seed + 1234) 45 | kmeans.train(teacher.cpu().numpy()) 46 | _, I_kmeans = kmeans.index.search(teacher.cpu().numpy(), 1) 47 | # scikit kmeans. too slow 48 | # kmeans = MiniBatchKMeans(n_clusters=ncentroids, max_iter=niter, random_state=seed, n_init=10) 49 | # kmeans.fit(teacher.cpu().numpy()) 50 | 51 | # # Get the cluster assignments 52 | # I_kmeans = kmeans.predict(teacher.cpu().numpy()).reshape(-1, 1) 53 | 54 | # cluster_ids_x, cluster_centers = kmeans( 55 | # X=teacher, 56 | # num_clusters=ncentroids, 57 | # distance='euclidean', 58 | # device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), 59 | # tol=1, 60 | # # num_iters=niter 61 | # ) 62 | 63 | # I_kmeans = cluster_ids_x.reshape(-1, 1) 64 | 65 | clust_labels = I_kmeans[:,0] 66 | 67 | pred_labels.append(clust_labels) 68 | 69 | pred_labels = np.stack(pred_labels, axis=0) 70 | cluster_labels = torch.from_numpy(pred_labels).long().to(self.device) 71 | 72 | all_close_nei_in_back = None 73 | with torch.no_grad(): 74 | for each_k_idx in range(self.num_kmeans): 75 | curr_close_nei = self.__get_close_nei_in_back(tmp.squeeze(-1), each_k_idx, cluster_labels, I_knn, I_knn.shape[1]) 76 | 77 | if all_close_nei_in_back is None: 78 | all_close_nei_in_back = curr_close_nei 79 | else: 80 | all_close_nei_in_back 
def create_sparse_revised(self, I, all_close_nei_in_back):
    """Build a sparse 0/1 matrix from the neighbour indices kept by a mask.

    Args:
        I: 2-D integer tensor of neighbour indices; entry ``I[j, i]`` is a
            column index paired with row ``j``.
        all_close_nei_in_back: boolean tensor with the same number of
            elements as ``I``; only pairs whose mask entry is true are kept.

    Returns:
        A float sparse COO tensor of size ``[I.shape[0], I.shape[0]]`` on
        ``self.device`` with value 1 at every kept ``(j, I[j, i])`` position.
    """
    n_data, k = I.shape[0], I.shape[1]

    keep = all_close_nei_in_back.reshape(-1).to(self.device)
    # Row j is repeated k times so it lines up with the row-major flattening of I.
    # This replaces the per-element Python loop and its .item() calls.
    index = torch.arange(n_data, device=self.device).repeat_interleave(k)
    similar = I.reshape(-1).to(device=self.device, dtype=torch.long)

    index = torch.masked_select(index, keep)
    similar = torch.masked_select(similar, keep)

    # Stacking integer tensors keeps the indices int64 even when nothing is
    # selected, so an all-false mask yields an empty sparse matrix.
    indices = torch.stack([index, similar])
    values = torch.ones(indices.shape[1], device=self.device)
    result = torch.sparse_coo_tensor(indices, values, [n_data, n_data], dtype=torch.float)

    return result
-------------------------------------------------------------------------------- /src/augment/augment_subgraph.py: -------------------------------------------------------------------------------- 1 | from .base import Augmentor 2 | import copy 3 | import random 4 | import torch 5 | from src.augment.random_drop_node import delete_row_col 6 | # from torch_geometric.utils.convert import to_scipy_sparse_matrix 7 | 8 | class AugmentSubgraph(Augmentor): 9 | def __init__(self, is_x: bool = True, is_adj: bool = False, drop_percent=0.2): 10 | super().__init__() 11 | self.is_x = is_x 12 | self.is_adj = is_adj 13 | self.drop_percent = drop_percent 14 | 15 | def __call__(self, data, drop_percent=None): 16 | drop_percent = drop_percent if drop_percent else self.drop_percent 17 | data_tmp = copy.deepcopy(data) 18 | input_adj = data_tmp.adj_t.to_dense() 19 | # input_adj = from_scipy_sparse_matrix(data_tmp.edge_index) 20 | input_fea = data_tmp.x.squeeze(0) 21 | node_num = input_fea.shape[0] 22 | 23 | all_node_list = [i for i in range(node_num)] 24 | s_node_num = int(node_num * (1 - drop_percent)) 25 | center_node_id = random.randint(0, node_num - 1) 26 | sub_node_id_list = [center_node_id] 27 | all_neighbor_list = [] 28 | 29 | for i in range(s_node_num - 1): 30 | 31 | all_neighbor_list += torch.nonzero(input_adj[sub_node_id_list[i]], as_tuple=False).squeeze(1).tolist() 32 | 33 | all_neighbor_list = list(set(all_neighbor_list)) 34 | new_neighbor_list = [n for n in all_neighbor_list if not n in sub_node_id_list] 35 | if len(new_neighbor_list) != 0: 36 | new_node = random.sample(new_neighbor_list, 1)[0] 37 | sub_node_id_list.append(new_node) 38 | else: 39 | break 40 | 41 | drop_node_list = sorted([i for i in all_node_list if i not in sub_node_id_list]) 42 | 43 | aug_input_fea = delete_row_col(input_fea, drop_node_list, only_row=True) 44 | aug_input_adj = delete_row_col(input_adj, drop_node_list) 45 | 46 | data_tmp.x = aug_input_fea 47 | data_tmp.adj_t = aug_input_adj.to_sparse() 48 | 
class Augmentor:
    r"""Base class for augmentor."""
    def __init__(self):
        pass

    def __call__(self, *args, **kwargs):
        """Apply the augmentation; subclasses must override this."""
        raise NotImplementedError


class AugmentorList:
    r"""A wrapper for a list of augmentors. Sequentially apply each augmentor to inputs."""
    def __init__(self, augmentors: Union[Augmentor, List[Augmentor]]):
        # A single augmentor is wrapped so that ``self.augmentors`` is always a list.
        if isinstance(augmentors, list):
            self.augmentors = augmentors
        else:
            self.augmentors = [augmentors]

    def __call__(self, inputs):
        """Feed ``inputs`` through every augmentor in order and return the result."""
        for augmentor in self.augmentors:
            inputs = augmentor(inputs)
        return inputs

    def append(self, item):
        """Add ``item`` to the end of the augmentation pipeline."""
        self.augmentors.append(item)


class AugmentorDict(UserDict):
    r"""Base class for augmentation. A dictionary of augmentors."""
    def __setitem__(self, key, value):
        """Store ``value`` under ``key``.

        Raises:
            TypeError: if ``value`` is neither an ``Augmentor`` nor an
                ``AugmentorList``.
        """
        # An explicit check is used because ``assert`` is stripped under -O and
        # a plain string cannot be raised in Python 3.
        if not isinstance(value, (Augmentor, AugmentorList)):
            raise TypeError(
                "The value should be an instance of Augmentor or AugmentorList, "
                "got {}.".format(type(value).__name__)
            )
        super().__setitem__(key, value)
class ComputePPR(Augmentor):
    """Replace a graph's adjacency with its personalised-PageRank diffusion.

    Args:
        alpha: teleport probability used in the closed-form diffusion.
        self_loop: when true, the identity is added to the adjacency before
            normalisation.
    """
    def __init__(self, alpha=0.2, self_loop=True):
        super().__init__()
        self.alpha = alpha
        self.self_loop = self_loop

    def __call__(self, data):
        """Return a deep copy of ``data`` whose ``adj_t`` is the diffusion matrix.

        Computes ``alpha * inv(I - (1 - alpha) * D^-1/2 A D^-1/2)`` from the
        dense form of ``data.adj_t`` and stores it back as a sparse tensor.
        The input object is not modified.
        """
        data_tmp = copy.deepcopy(data)
        a = data_tmp.adj_t.to_dense()
        device = a.device
        if self.self_loop:
            a = torch.eye(a.shape[0], device=device) + a
        # The degree matrix is diagonal, so its inverse square root is the
        # element-wise power of the degrees. This avoids a dense matrix
        # function and a host round trip.
        dinv_sqrt = torch.sum(a, 1).pow(-0.5)
        # Equivalent to D^-1/2 @ A @ D^-1/2: scale rows, then columns.
        at = dinv_sqrt.unsqueeze(1) * a * dinv_sqrt.unsqueeze(0)
        aug_adj = self.alpha * inv((torch.eye(a.shape[0], device=device) - (1 - self.alpha) * at))
        data_tmp.adj_t = aug_adj.to_sparse()
        return data_tmp
def get_activation(name: str):
    """Return the activation callable registered under ``name``.

    Supported names are ``'relu'``, ``'hardtanh'``, ``'elu'``,
    ``'leakyrelu'``, ``'prelu'`` and ``'rrelu'``. All but ``'prelu'`` map to
    functions from ``torch.nn.functional``; ``'prelu'`` returns a new
    ``torch.nn.PReLU`` module on each call.

    Raises:
        KeyError: if ``name`` is not one of the supported names.
    """
    if name == 'prelu':
        # PReLU is a module with its own parameters; build it only when it is
        # actually requested instead of on every lookup.
        return torch.nn.PReLU()

    activations = {
        'relu': F.relu,
        'hardtanh': F.hardtanh,
        'elu': F.elu,
        'leakyrelu': F.leaky_relu,
        'rrelu': F.rrelu
    }

    try:
        return activations[name]
    except KeyError:
        supported = sorted([*activations, 'prelu'])
        raise KeyError(
            f"unknown activation {name!r}; expected one of {supported}"
        ) from None
def drop_feature(x, drop_prob):
    """Return a copy of ``x`` with randomly chosen feature columns zeroed.

    One uniform sample in [0, 1) is drawn per index along dimension 1 of
    ``x``; every column whose sample is below ``drop_prob`` is set to zero.
    The input tensor is left untouched.
    """
    num_features = x.size(1)
    samples = torch.empty((num_features,), dtype=torch.float32, device=x.device).uniform_(0, 1)
    dropped_columns = torch.nonzero(samples < drop_prob, as_tuple=True)[0]
    # index_fill is out of place, so it performs the copy and the zeroing in one step.
    return x.index_fill(1, dropped_columns, 0)
def pr_drop_weights(edge_index, aggr: str = 'sink', k: int = 10):
    """Compute per-edge weights from the scores returned by ``compute_pr``.

    Args:
        edge_index: tensor whose row 0 holds source nodes and row 1 holds
            target nodes of each edge.
        aggr: which endpoint's log-score is used: ``'sink'`` (target),
            ``'source'`` or ``'mean'`` (average of both). Any other value
            behaves like ``'sink'``.
        k: iteration count forwarded to ``compute_pr``.

    Returns:
        One weight per edge, ``(max - s) / (max - mean)`` of the selected
        log-scores ``s``.
    """
    scores = compute_pr(edge_index, k=k)
    log_source = torch.log(scores[edge_index[0]].to(torch.float32))
    log_sink = torch.log(scores[edge_index[1]].to(torch.float32))

    if aggr == 'source':
        selected = log_source
    elif aggr == 'mean':
        selected = (log_sink + log_source) * 0.5
    else:
        # 'sink' and every unrecognised mode use the target node's score.
        selected = log_sink

    return (selected.max() - selected) / (selected.max() - selected.mean())
class ComputeHeat(Augmentor):
    """Replace a graph's dense adjacency ``data.adj`` with a heat-style diffusion.

    Args:
        t: scaling factor applied inside the exponential.
        self_loop: when true, the identity is added to the adjacency first.
    """
    def __init__(self, t=5, self_loop=True):
        super().__init__()
        self.t = t
        self.self_loop = self_loop

    def __call__(self, data):
        """Return a deep copy of ``data`` with ``adj`` set to ``exp(t * (A D^-1 - 1))``.

        The exponential and the subtraction of 1 are element-wise. The input
        object is not modified.
        """
        data_tmp = copy.deepcopy(data)
        a = data_tmp.adj
        if self.self_loop:
            # Create the identity on the adjacency's device so the sum also
            # works when the graph is not on the CPU.
            a = torch.eye(a.shape[0], device=a.device) + a
        d = torch.diag(torch.sum(a, 1))
        data_tmp.adj = torch.exp(self.t * (torch.matmul(a, inv(d)) - 1))
        return data_tmp
class RandomDropEdge(Augmentor):
    """Randomly remove existing edges and add the same number of random edges.

    Args:
        is_x: stored on the instance; not read by ``__call__``.
        is_adj: stored on the instance; not read by ``__call__``.
        drop_percent: default perturbation ratio used when ``__call__`` is
            not given one.
    """
    def __init__(self, is_x: bool = True, is_adj: bool = False, drop_percent=0.2):
        super().__init__()
        self.is_x = is_x
        self.is_adj = is_adj
        self.drop_percent = drop_percent

    def __call__(self, data, drop_percent=None):
        """Return a deep copy of ``data`` carrying a perturbed adjacency.

        Args:
            data: graph object exposing ``x`` and a sparse ``adj_t``.
            drop_percent: perturbation ratio for this call; ``None`` selects
                the ratio given at construction time.

        Returns:
            The copy, with the perturbed adjacency stored sparsely in ``adj``.
        """
        # ``is None`` (not truthiness) so an explicit 0 is honoured. The
        # resolved value is now actually used; previously the per-call
        # argument was ignored.
        if drop_percent is None:
            drop_percent = self.drop_percent
        data_tmp = copy.deepcopy(data)
        percent = drop_percent / 2

        aug_adj = data_tmp.adj_t.to_dense()
        # Edge positions come from the adjacency, not the feature matrix:
        # they are used below to index ``aug_adj``.
        row_idx, col_idx = aug_adj.nonzero().T
        index_list = list(zip(row_idx.tolist(), col_idx.tolist()))

        edge_num = int(len(row_idx) / 2)  # 9228 / 2
        add_drop_num = int(edge_num * percent / 2)

        drop_idx = random.sample(range(edge_num), add_drop_num)

        # Clear both directions so the adjacency stays symmetric.
        for i in drop_idx:
            row, col = index_list[i]
            aug_adj[row][col] = 0
            aug_adj[col][row] = 0

        node_num = data_tmp.x.shape[0]
        candidate_pairs = [(i, j) for i in range(node_num) for j in range(i)]
        add_list = random.sample(candidate_pairs, add_drop_num)

        for row, col in add_list:
            aug_adj[row][col] = 1
            aug_adj[col][row] = 1

        # NOTE(review): the result is written to ``adj`` although the input is
        # read from ``adj_t``; confirm which attribute consumers expect.
        data_tmp.adj = aug_adj.to_sparse()
        return data_tmp
def delete_row_col(input_matrix, drop_list, only_row=False):
    """Remove the rows (and optionally the matching columns) listed in ``drop_list``.

    Args:
        input_matrix: 2-D array-like supporting ``shape`` and list-based
            fancy indexing (e.g. a ``torch.Tensor``).
        drop_list: Iterable of integer indices to remove.
        only_row: When true only rows are removed; otherwise the same indices
            are removed from the columns as well.

    Returns:
        The reduced matrix, with surviving rows/columns in their original order.
    """
    # A set of plain ints gives O(1) membership tests and also copes with
    # tensor / numpy-integer entries in ``drop_list``.
    dropped = {int(i) for i in drop_list}
    remain_list = [i for i in range(input_matrix.shape[0]) if i not in dropped]
    out = input_matrix[remain_list, :]
    if only_row:
        return out
    return out[:, remain_list]
class RandomMaskChannel(Augmentor):
    """Zero out randomly chosen feature channels (columns of ``x``).

    Args:
        is_x: Apply the mask to the node features.
        is_adj: Adjacency masking is not implemented; setting this raises
            ``NotImplementedError`` when the augmentor is called.
        drop_percent: Default probability of masking each channel.
    """

    def __init__(self, is_x: bool = True, is_adj: bool = False, drop_percent=0.2):
        super().__init__()
        self.is_x = is_x
        self.is_adj = is_adj
        self.drop_percent = drop_percent

    def __call__(self, data, drop_percent=None):
        """Return a deep copy of ``data`` with some feature channels zeroed.

        Args:
            data: Object exposing a 2-D feature tensor ``x``.
            drop_percent: Optional per-call probability; ``self.drop_percent``
                is used only when this is ``None``.

        Raises:
            NotImplementedError: If ``is_adj`` was requested.
        """
        # ``is None`` so an explicit 0 disables masking instead of silently
        # falling back to the instance default.
        drop_percent = self.drop_percent if drop_percent is None else drop_percent
        data_tmp = copy.deepcopy(data)
        device = data_tmp.x.device
        if self.is_x:
            # One uniform draw per channel; a channel is kept when its draw exceeds
            # the drop probability.  Sampled on CPU, then moved next to ``x``.
            feat_mask = torch.rand(data_tmp.x.shape[1]) > drop_percent
            feat_mask = feat_mask.to(device)
            data_tmp.x = feat_mask * data_tmp.x
        if self.is_adj:
            raise NotImplementedError
        return data_tmp
class SumEmb(Augmentor):
    """Summarise a batch of embeddings by pooling along one dimension.

    Args:
        pooler: ``"avg"`` for mean pooling or ``"max"`` for max pooling.
        act: Optional callable applied to the pooled tensor.
    """

    def __init__(self,
                 pooler: str = "avg",
                 act: Optional[Callable] = None):
        super().__init__()
        assert pooler in ["avg", "max"], 'Pooler should be either "avg" or "max"'
        self.pooler = pooler
        self.act = act

    def __call__(self,
                 x: Tensor,
                 dim: int = 1,
                 keepdim: bool = False):
        """Pool ``x`` along ``dim`` and return the summary tensor.

        Args:
            x: Tensor of embeddings.
            dim: Dimension to reduce.
            keepdim: Whether the reduced dimension is retained with size 1.

        Returns:
            The pooled tensor, passed through ``act`` when one was supplied.
        """
        if self.pooler == "avg":
            s = torch.mean(x, dim, keepdim=keepdim)
        else:
            # torch.max with a ``dim`` argument returns a (values, indices) pair;
            # only the values form the summary.
            s = torch.max(x, dim, keepdim=keepdim).values

        # Explicit None check: a callable whose truth value is False (e.g. an
        # empty container module) must still be applied.
        if self.act is not None:
            s = self.act(s)
        return s
Please install " 22 | # "'yacs' via 'pip install yacs' in order to use GraphGym") 23 | 24 | 25 | def set_cfg(cfg): 26 | r''' 27 | This function sets the default config value. 28 | 1) Note that for an experiment, only part of the arguments will be used 29 | The remaining unused arguments won't affect anything. 30 | So feel free to register any argument in graphgym.contrib.config 31 | 2) We support *at most* two levels of configs, e.g., cfg.dataset.name 32 | 33 | :return: configuration use by the experiment. 34 | ''' 35 | if cfg is None: 36 | return cfg 37 | 38 | # ----------------------------------------------------------------------- # 39 | # Basic options 40 | # ----------------------------------------------------------------------- # 41 | 42 | # Use cuda or not 43 | cfg.use_cuda = False 44 | 45 | cfg.gpu_idx = 0 46 | 47 | # Output directory 48 | cfg.out_dir = 'results' 49 | 50 | # Config name (in out_dir) 51 | cfg.cfg_dest = 'config.yaml' 52 | 53 | # Names of registered custom metric funcs to be used (use defaults if none) 54 | cfg.custom_metrics = [] 55 | 56 | # Numpy random seed 57 | cfg.seed = 0 58 | 59 | # Pytorch random seed 60 | cfg.torch_seed = 0 61 | 62 | # Print rounding 63 | cfg.round = 4 64 | 65 | # If visualize embedding. 
66 | cfg.view_emb = False 67 | 68 | # ----------------------------------------------------------------------- # 69 | # Globally shared variables: 70 | # These variables will be set dynamically based on the input dataset 71 | # Do not directly set them here or in .yaml files 72 | # ----------------------------------------------------------------------- # 73 | 74 | cfg.share = CN(new_allowed=True) 75 | 76 | # Size of input dimension 77 | cfg.share.dim_in = 1 78 | 79 | # Size of out dimension, i.e., number of labels to be predicted 80 | cfg.share.dim_out = 1 81 | 82 | # Number of dataset splits: train/val/test 83 | cfg.share.num_splits = 1 84 | 85 | # ----------------------------------------------------------------------- # 86 | # Dataset options 87 | # ----------------------------------------------------------------------- # 88 | cfg.dataset = CN(new_allowed=True) 89 | 90 | # Name of the dataset 91 | cfg.dataset.name = 'Cora' 92 | 93 | # Dir to load the dataset. If the dataset is downloaded, this is the 94 | # cache dir 95 | cfg.dataset.dir = './datasets' 96 | 97 | # use pyg_data to load dataset 98 | cfg.dataset.root = 'pyg_data' 99 | 100 | # ----------------------------------------------------------------------- # 101 | # Model options 102 | # ----------------------------------------------------------------------- # 103 | cfg.model = CN(new_allowed=True) 104 | 105 | # Model type to use 106 | cfg.model.type = 'gnn' 107 | 108 | # Input channels 109 | cfg.model.in_channels = 1433 110 | 111 | # Hidden channels 112 | cfg.model.hidden_channels = 512 113 | 114 | # Number of layers in the Model 115 | cfg.model.num_layers = 1 116 | 117 | # Augmentation type 118 | cfg.aug_type = 'mask' 119 | 120 | # ----------------------------------------------------------------------- # 121 | # Optimizer options 122 | # ----------------------------------------------------------------------- # 123 | cfg.optim = CN(new_allowed=True) 124 | 125 | # optimizer: sgd, adam 126 | cfg.optim.optimizer 
def load_yaml(filename):
    """Build a config pre-populated with defaults and overlay a YAML file on it.

    Args:
        filename: Path of the YAML file passed to ``merge_from_file``.

    Returns:
        The ``CfgNode`` produced by ``set_cfg`` after merging ``filename``.
    """
    # ``CN`` is imported unconditionally at module level, so constructing it
    # cannot raise ImportError; the former try/except was dead code whose
    # ``cfg = None`` fallback would only have crashed on the merge below.
    cfg = set_cfg(CN(new_allowed=True))
    cfg.merge_from_file(filename)
    return cfg
def download_data(root, name):
    """Download a dataset, dispatching on the source named in its config.

    Only the ``"pyg"`` source is handled.

    :param root: The root directory of the dataset
    :param name: The name of the dataset
    :return: Whatever ``download_pyg_data`` returns for the resolved config
    :raises ValueError: If the resolved config names any other source
    """
    config = decide_config(root=root, name=name)
    if config["src"] == "pyg":
        return download_pyg_data(config)
    # Fail here with a clear message instead of returning None implicitly,
    # which surfaces later as an opaque unpacking error in the caller.
    raise ValueError(
        f"Unsupported data source {config['src']!r} for dataset {name!r}; "
        "only 'pyg' is supported"
    )
else: 70 | return [f'byg.data.aug.ip.{self.num_parts}.fp.{self.final_parts}.pt'] 71 | 72 | @property 73 | def raw_dir(self): 74 | return osp.join(self.data_dir, "raw") 75 | 76 | @property 77 | def processed_dir(self): 78 | return osp.join(self.data_dir, "processed") 79 | 80 | @property 81 | def model_dir(self): 82 | return osp.join(self.data_dir, "model") 83 | 84 | @property 85 | def result_dir(self): 86 | return osp.join(self.data_dir, "result") 87 | 88 | @property 89 | def dirs(self): 90 | return [self.raw_dir, self.processed_dir, self.model_dir, self.result_dir] 91 | 92 | 93 | def process_full_batch_data(self, data): 94 | """ 95 | Augmented view data generation using the full-batch data. 96 | :param view1data: 97 | :return: 98 | """ 99 | print("Processing full batch data") 100 | 101 | data = Data(edge_index=data.edge_index, edge_attr= data.edge_attr, 102 | x = data.x, y = data.y, 103 | train_mask=data.train_mask, val_mask=data.val_mask, test_mask=data.test_mask, 104 | num_nodes=data.num_nodes) 105 | return [data] 106 | 107 | def download(self): 108 | pass 109 | 110 | def process(self): 111 | """ 112 | Process either a full batch or cluster data. 
class K_Means(BaseEvaluator):
    """Clustering evaluator: repeated k-means runs scored by NMI.

    All estimator arguments are stored and forwarded unchanged to
    ``sklearn.cluster.KMeans``; ``average_method`` is forwarded to
    ``normalized_mutual_info_score``.

    Args:
        k: Number of clusters.
        n_run: Number of independent k-means runs to average over.
        device: Stored on the instance; not read by this class.
    """

    def __init__(self,
                 k: int = 8,
                 average_method: str = "arithmetic",
                 init = "k-means++",
                 n_init = 10,
                 max_iter: int = 300,
                 tol: float = 1e-4,
                 verbose: int = 0,
                 random_state = None,
                 copy_x: bool = True,
                 algorithm: str = "auto",
                 n_run: int = 50,
                 device: Union[str, int] = "cuda") -> None:
        self.k = k
        self.average_method = average_method
        self.init = init
        self.n_init = n_init
        self.max_iter = max_iter
        self.tol = tol
        self.verbose = verbose
        self.random_state = random_state
        self.copy_x = copy_x
        # NOTE(review): recent scikit-learn releases may reject algorithm="auto";
        # confirm against the pinned version.
        self.algorithm = algorithm
        self.n_run = n_run
        self.device = device

    def __call__(self, embs, dataset):
        """Cluster ``embs`` ``n_run`` times and report the NMI against ``dataset.y``.

        Args:
            embs: Tensor of embeddings.
            dataset: Object exposing the ground-truth label tensor ``y``.

        Returns:
            ``(mean, std)`` of the NMI over all runs.  The same numbers are
            also printed.
        """
        embs, labels = embs.detach().cpu().numpy(), dataset.y.detach().cpu().numpy()
        nmi_scores = [self.single_run(embs=embs, labels=labels)
                      for _ in range(self.n_run)]

        mean = np.mean(nmi_scores)
        std = np.std(nmi_scores)

        print('Evaluate node classification results')
        print('\t[Clustering] NMI: {:.4f} | {:.4f}'.format(mean, std))
        return mean, std

    def single_run(self, embs, labels) -> float:
        """Fit one k-means model on ``embs`` and return its NMI against ``labels``."""
        estimator = KMeans(n_clusters=self.k, init=self.init, n_init=self.n_init,
                           max_iter=self.max_iter, tol=self.tol, verbose=self.verbose,
                           random_state=self.random_state, copy_x=self.copy_x,
                           algorithm=self.algorithm)
        estimator.fit(embs)
        y_pred = estimator.predict(embs)
        return normalized_mutual_info_score(labels, y_pred,
                                            average_method=self.average_method)
class LogisticRegression(BaseEvaluator):
    """Linear-probe evaluator: trains a logistic-regression head on fixed embeddings.

    Args:
        lr: Learning rate of the Adam optimizer.
        weight_decay: Weight decay of the Adam optimizer.
        max_iter: Number of full-batch optimisation steps per run.
        n_run: Number of repetitions over all splits.
        device: Device on which the classifier is trained.
    """

    def __init__(self,
                 lr: float = 0.01,
                 weight_decay: float = 0.,
                 max_iter: int = 100,
                 n_run: int = 50,
                 device: Union[str, int] = "cuda") -> None:
        self.lr = lr
        self.weight_decay = weight_decay
        self.max_iter = max_iter
        self.n_run = n_run
        self.device = device

    def __call__(self, embs, dataset):
        """Evaluate ``embs`` on every split of ``dataset``, repeated ``n_run`` times.

        Args:
            embs: Tensor of embeddings.
            dataset: Object exposing labels ``y`` and the masks consumed by
                ``process_split``.

        Returns:
            ``(val_mean, val_std, test_mean, test_std)`` in percent.  The same
            numbers are also printed.
        """
        val_accs, test_accs = [], []

        for _ in tqdm(range(self.n_run)):
            # process_split must run before ``dataset.test_mask`` is read: it may
            # create the masks when the dataset has none.
            n_splits, train_mask, val_mask, _ = process_split(dataset=dataset)
            test_mask = dataset.test_mask
            # some datasets (e.g., wikics) with multiple splits only has one test mask
            shared_test_mask = isinstance(test_mask, torch.Tensor) and test_mask.dim() == 1
            for i in range(n_splits):
                test_msk = test_mask if shared_test_mask else test_mask[i]
                val_acc, test_acc = self.single_run(
                    embs=embs, labels=dataset.y, train_mask=train_mask[i],
                    val_mask=val_mask[i], test_mask=test_msk)

                val_accs.append(val_acc * 100)
                test_accs.append(test_acc * 100)

        val_accs = np.stack(val_accs)
        test_accs = np.stack(test_accs)

        val_acc, val_std = val_accs.mean(), val_accs.std()
        test_acc, test_std = test_accs.mean(), test_accs.std()

        print('Evaluate node classification results')
        print('** Val: {:.4f} ({:.4f}) | Test: {:.4f} ({:.4f}) **'.format(val_acc, val_std, test_acc, test_std))
        return val_acc, val_std, test_acc, test_std

    def single_run(self, embs, labels, train_mask, val_mask, test_mask):
        """Train one classifier on the training split and score it on val/test.

        Returns:
            ``(val_acc, test_acc)`` as 0-d numpy arrays in ``[0, 1]``.
        """
        emb_dim, num_class = embs.shape[1], labels.unique().shape[0]

        embs, labels = embs.to(self.device), labels.to(self.device)
        classifier = LogisticRegressionClassifier(emb_dim, num_class).to(self.device)
        optimizer = torch.optim.Adam(classifier.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        # The training slice never changes, so index it once instead of every step.
        train_x, train_y = embs[train_mask], labels[train_mask]
        classifier.train()
        for _ in range(self.max_iter):
            _, loss = classifier(train_x, train_y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Scoring needs no gradients: switch to eval mode and skip graph building.
        classifier.eval()
        with torch.no_grad():
            val_y, test_y = labels[val_mask], labels[test_mask]
            val_logits, _ = classifier(embs[val_mask], val_y)
            test_logits, _ = classifier(embs[test_mask], test_y)
            val_preds = torch.argmax(val_logits, dim=1)
            test_preds = torch.argmax(test_logits, dim=1)

            val_acc = (torch.sum(val_preds == val_y).float() /
                       val_y.shape[0]).detach().cpu().numpy()
            test_acc = (torch.sum(test_preds == test_y).float() /
                        test_y.shape[0]).detach().cpu().numpy()
        return val_acc, test_acc
class RandomForestClassifier(BaseEvaluator):
    """Evaluator that fits a scikit-learn random forest on fixed embeddings.

    Apart from the arguments below, every constructor argument is stored and
    forwarded to ``sklearn.ensemble.RandomForestClassifier`` (``obb_score`` is
    forwarded as ``oob_score``).

    Args:
        search: When true, wrap the forest in a 5-fold ``GridSearchCV``.
        n_run: Number of repetitions over all splits.
        device: Stored on the instance; not read by this class.
    """

    def __init__(self,
                 search: bool = False,
                 n_estimators: int = 1,
                 criterion: str = 'gini',
                 max_depth: int = None,
                 min_samples_split: int = 2,
                 min_samples_leaf: int = 1,
                 min_weight_fraction_leaf: float = 0.0,
                 max_features: Union[int, float, str, None] = 'auto',
                 max_leaf_nodes: int = None,
                 min_impurity_decrease: float = 0.0,
                 bootstrap: bool = True,
                 obb_score: bool = False,
                 n_jobs: int = 1,
                 random_state: Union[int, None] = None,
                 verbose: int = 0,
                 warm_start: bool = False,
                 class_weight: dict = None,
                 n_run: int = 50,
                 device: Union[str, int] = "cuda") -> None:

        self.search = search
        self.n_estimators = n_estimators
        self.criterion = criterion
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.min_samples_leaf = min_samples_leaf
        self.min_weight_fraction_leaf = min_weight_fraction_leaf
        self.max_features = max_features
        self.max_leaf_nodes = max_leaf_nodes
        self.min_impurity_decrease = min_impurity_decrease
        self.bootstrap = bootstrap
        self.verbose = verbose
        self.obb_score = obb_score
        self.n_jobs = n_jobs
        self.warm_start = warm_start
        self.class_weight = class_weight
        self.random_state = random_state
        self.n_run = n_run
        self.device = device

    def __call__(self, embs, dataset):
        """Evaluate ``embs`` on every split of ``dataset``, repeated ``n_run`` times.

        Args:
            embs: Tensor of embeddings.
            dataset: Object exposing labels ``y`` and tensor-valued
                ``train_mask`` / ``val_mask`` / ``test_mask``.

        Returns:
            ``(val_mean, val_std, test_mean, test_std)`` in percent.  The same
            numbers are also printed.
        """
        val_accs, test_accs = [], []
        embs, labels = embs.detach().cpu().numpy(), dataset.y.detach().cpu().numpy()

        for _ in range(self.n_run):
            n_splits, train_mask, val_mask, _ = process_split(dataset=dataset)
            test_mask = dataset.test_mask
            train_mask, val_mask, test_mask = (
                mask.detach().cpu().numpy() for mask in (train_mask, val_mask, test_mask))
            for i in range(n_splits):
                # some datasets (e.g., wikics) with multiple splits only has one test mask
                test_msk = test_mask if len(test_mask.shape) == 1 else test_mask[i]

                val_acc, test_acc = self.single_run(
                    embs=embs, labels=labels, train_mask=train_mask[i],
                    val_mask=val_mask[i], test_mask=test_msk)

                val_accs.append(val_acc * 100)
                test_accs.append(test_acc * 100)

        val_accs = np.stack(val_accs)
        test_accs = np.stack(test_accs)

        val_acc, val_std = val_accs.mean(), val_accs.std()
        test_acc, test_std = test_accs.mean(), test_accs.std()

        print('Evaluate node classification results')
        print('** Val: {:.4f} ({:.4f}) | Test: {:.4f} ({:.4f}) **'.format(val_acc, val_std, test_acc, test_std))
        return val_acc, val_std, test_acc, test_std

    def _build_estimator(self):
        """Create the scikit-learn forest from the stored hyper-parameters."""
        # 'auto' meant sqrt(n_features) for classifiers and is rejected by recent
        # scikit-learn releases; 'sqrt' is the equivalent spelling everywhere.
        max_features = 'sqrt' if self.max_features == 'auto' else self.max_features
        return rf_cl(
            n_estimators=self.n_estimators,
            criterion=self.criterion,
            max_depth=self.max_depth,
            min_samples_split=self.min_samples_split,
            min_samples_leaf=self.min_samples_leaf,
            min_weight_fraction_leaf=self.min_weight_fraction_leaf,
            max_features=max_features,
            max_leaf_nodes=self.max_leaf_nodes,
            min_impurity_decrease=self.min_impurity_decrease,
            bootstrap=self.bootstrap,
            oob_score=self.obb_score,
            n_jobs=self.n_jobs,
            random_state=self.random_state,
            verbose=self.verbose,
            warm_start=self.warm_start,
            class_weight=self.class_weight,
        )

    def single_run(self, embs, labels, train_mask, val_mask, test_mask) -> (np.ndarray, np.ndarray):
        """Fit on the training split and return ``(val_acc, test_acc)`` in ``[0, 1]``."""
        classifier = self._build_estimator()
        if self.search:
            # The previous grid searched over 'C', which is an SVM parameter and
            # makes GridSearchCV fail for a random forest.  Search the ensemble
            # size instead.
            # NOTE(review): grid values are a reasonable default, not taken from
            # any reference configuration - adjust as needed.
            params = {'n_estimators': [10, 50, 100, 200]}
            classifier = GridSearchCV(classifier, params, cv=5, scoring='accuracy', verbose=0)
        classifier.fit(embs[train_mask], labels[train_mask])

        val_acc = accuracy_score(labels[val_mask], classifier.predict(embs[val_mask]))
        test_acc = accuracy_score(labels[test_mask], classifier.predict(embs[test_mask]))
        return val_acc, test_acc
class SimSearch(BaseEvaluator):
    """Similarity-search evaluator: label agreement among cosine nearest neighbours.

    Args:
        sim_list: Neighbourhood sizes ``N`` to evaluate.
        n_run: Stored on the instance; not read by this class.
        device: Stored on the instance; not read by this class.
    """

    def __init__(self,
                 sim_list: Union[list, tuple] = (5, 10, 20, 50, 100),
                 n_run: int = 50,
                 device: Union[str, int] = "cuda") -> None:
        # Immutable default plus a private copy: instances never share (or
        # accidentally mutate) one list object.
        self.sim_list = list(sim_list)
        self.n_run = n_run
        self.device = device

    def __call__(self, embs, dataset):
        """Score ``embs`` against ``dataset.y`` for every size in ``sim_list``.

        Returns:
            The list of scores (as strings rounded to 4 decimals), in the
            order of ``sim_list``.  The same values are also printed.
        """
        embs, labels = embs.detach().cpu().numpy(), dataset.y.detach().cpu().numpy()
        scores = self.single_run(embs=embs, labels=labels)
        st = ','.join(scores)

        print('Evaluate node classification results')
        print("\t[Similarity]" + " : [{}]".format(st))
        return scores

    def single_run(self, embs, labels) -> (np.ndarray):
        """Return, per ``N``, the mean fraction of top-``N`` neighbours sharing the node's label.

        Args:
            embs: 2-D numpy array of embeddings, one row per node.
            labels: numpy array of labels, one per node (assumed 1-D - TODO confirm).
        """
        cos_sim_array = pairwise.cosine_similarity(embs)
        # Exclude each node from its own neighbour list.  Subtracting the identity
        # only lowers self-similarity to about 0, which still outranks negatively
        # correlated neighbours.
        np.fill_diagonal(cos_sim_array, -np.inf)

        # The ranking does not depend on N, so sort once.
        ranking = np.argsort(cos_sim_array, axis=1)
        own_label = labels[:, None]

        st = []
        for N in self.sim_list:
            # Labels of the N most similar nodes, gathered directly instead of
            # through an n-by-n tiled label matrix.
            neighbour_labels = labels[ranking[:, -N:]]
            hit_rate = np.mean(np.sum(neighbour_labels == own_label, 1) / N)
            st.append(str(np.round(hit_rate, 4)))

        return st
= kernel 37 | self.degree = degree 38 | self.gamma = gamma 39 | self.coef0 = coef0 40 | self.shirinking = shrinking 41 | self.probability = probability 42 | self.tol = tol 43 | self.cache_size = cache_size 44 | self.class_weight = class_weight 45 | self.verbose = verbose 46 | self.max_iter = max_iter 47 | self.decision_function_shape = decision_function_shape 48 | self.random_state = random_state 49 | self.n_run = n_run 50 | self.device = device 51 | 52 | def __call__(self, embs, dataset): 53 | r""" 54 | TODO: maybe we need to return something. 55 | """ 56 | val_accs, test_accs = [], [] 57 | embs, labels = embs.detach().cpu().numpy(), dataset.y.detach().cpu().numpy() 58 | 59 | for n in range(self.n_run): 60 | n_splits, train_mask, val_mask, test_mask = process_split(dataset=dataset) 61 | test_mask = dataset.test_mask 62 | train_mask, val_mask, test_mask = train_mask.detach().cpu().numpy(), val_mask.detach().cpu().numpy(), test_mask.detach().cpu().numpy() 63 | for i in range(n_splits): 64 | train_msk = train_mask[i] 65 | val_msk = val_mask[i] 66 | # some datasets (e.g., wikics) with multiple splits only has one test mask 67 | test_msk = test_mask if len(test_mask.shape) == 1 else test_mask[i] 68 | 69 | val_acc, test_acc = self.single_run( 70 | embs=embs, labels=labels, train_mask=train_msk, val_mask=val_msk, test_mask=test_msk) 71 | 72 | val_accs.append(val_acc * 100) 73 | test_accs.append(test_acc * 100) 74 | 75 | val_accs = np.stack(val_accs) 76 | test_accs = np.stack(test_accs) 77 | 78 | val_acc, val_std = val_accs.mean(), val_accs.std() 79 | test_acc, test_std = test_accs.mean(), test_accs.std() 80 | 81 | print('Evaluate node classification results') 82 | print('** Val: {:.4f} ({:.4f}) | Test: {:.4f} ({:.4f}) **'.format(val_acc, val_std, test_acc, test_std)) 83 | 84 | def single_run(self, embs, labels, train_mask, val_mask, test_mask) -> (np.ndarray, np.ndarray): 85 | val_accs, test_accs = [], [] 86 | # emb_dim, num_class = embs.shape[1], 
labels.unique().shape[0] 87 | 88 | # embs, labels = embs.detach().cpu().numpy(), labels.detach().cpu().numpy() 89 | if self.search: 90 | params = {'C':[0.001, 0.01,0.1,1,10,100,1000]} 91 | classifier = GridSearchCV(SVC(C=self.C, kernel=self.kernel, degree=self.degree, gamma=self.gamma, coef0=self.coef0, shrinking=self.shirinking, probability=self.probability, 92 | tol=self.tol, cache_size=self.cache_size, class_weight=self.class_weight, verbose=self.verbose, max_iter=self.max_iter, 93 | decision_function_shape=self.decision_function_shape,random_state=self.random_state), params, cv=5, scoring='accuracy', verbose=0) 94 | else: 95 | classifier = SVC(C=self.C, kernel=self.kernel, degree=self.degree, gamma=self.gamma, coef0=self.coef0, shrinking=self.shirinking, probability=self.probability, 96 | tol=self.tol, cache_size=self.cache_size, class_weight=self.class_weight, verbose=self.verbose, max_iter=self.max_iter, 97 | decision_function_shape=self.decision_function_shape,random_state=self.random_state) 98 | classifier.fit(embs[train_mask], labels[train_mask]) 99 | 100 | val_accs.append(accuracy_score(labels[val_mask], classifier.predict(embs[val_mask]))) 101 | test_accs.append(accuracy_score(labels[test_mask], classifier.predict(embs[test_mask]))) 102 | 103 | return np.mean(val_accs), np.mean(test_accs) 104 | 105 | def process_split(dataset): 106 | r"""If the dataset only has one split, then add an additional dimension to the train_mask and val_mask. 107 | If the dataset has multiple splits, then return the number of splits, original train_mask and val_mask. 
class TSNEVisulization(BaseEvaluator):
    r"""Project embeddings with t-SNE and draw a per-class scatter plot.

    The t-SNE hyper-parameters are stored on the instance and forwarded to
    ``sklearn.manifold.TSNE``.

    Args:
        n_colors (int): minimum size of the colour palette. The palette grows
            automatically when the labels contain more classes than this.
        n_components (int): embedding dimension; values above 2 produce a 3D plot
            of the first three components.
        save_path (str): file the figure is written to. (default: ``"test.png"``)
        device (Union[str, int]): stored on the instance; not used by the code in this class.
        (all remaining arguments are passed to ``TSNE``.)
    """

    def __init__(self,
                 n_colors: int = 8,
                 n_components: int = 2,
                 perplexity: float = 30.0,
                 early_exaggeration: float = 12.0,
                 learning_rate: Union[float, str] = "auto",
                 n_iter: int = 1000,
                 n_iter_without_progress: int = 300,
                 min_grad_norm: float = 1e-7,
                 metric: str = 'euclidean',
                 metric_params: dict = None,
                 init: Union[str, np.ndarray] = "pca",
                 verbose: int = 0,
                 random_state: int = None,
                 method: str = 'barnes_hut',
                 angle: float = 0.5,
                 n_jobs: int = None,
                 device: Union[str, int] = "cuda",
                 save_path: str = "test.png") -> None:

        self.n_components = n_components
        self.perplexity = perplexity
        self.early_exaggeration = early_exaggeration
        self.learning_rate = learning_rate
        self.n_iter = n_iter
        self.n_iter_without_progress = n_iter_without_progress
        self.min_grad_norm = min_grad_norm
        self.metric = metric
        self.metric_params = metric_params
        self.init = init
        self.verbose = verbose
        self.random_state = random_state
        self.method = method
        self.angle = angle
        self.n_jobs = n_jobs
        self.n_colors = n_colors
        self.device = device
        self.save_path = save_path

    def __call__(self, embs, dataset):
        r"""Fit t-SNE on ``embs``, colour points by ``dataset.y``, save and show the figure.

        Nothing is returned.
        """
        print(self.method)

        # NOTE(review): newer scikit-learn releases rename ``n_iter`` to ``max_iter``;
        # the keyword is left as the original wrote it - confirm against the pinned version.
        tsne_kwargs = dict(
            n_components=self.n_components, perplexity=self.perplexity,
            early_exaggeration=self.early_exaggeration, learning_rate=self.learning_rate,
            n_iter=self.n_iter, n_iter_without_progress=self.n_iter_without_progress,
            min_grad_norm=self.min_grad_norm, metric=self.metric, init=self.init,
            verbose=self.verbose, random_state=self.random_state, method=self.method,
            angle=self.angle, n_jobs=self.n_jobs)
        # Only forwarded when set, so the call is unchanged for the default configuration.
        if self.metric_params is not None:
            tsne_kwargs["metric_params"] = self.metric_params
        tsne = TSNE(**tsne_kwargs)

        if isinstance(embs, Tensor):
            embs = embs.detach().cpu().numpy()
            labels = np.reshape(dataset.y.detach().cpu().numpy(), -1)
        else:
            labels = np.reshape(dataset.y, -1)

        tsne_results = tsne.fit_transform(embs)
        x = tsne_results[:, 0]
        y = tsne_results[:, 1]

        n_classes = int(np.max(labels)) + 1
        # Never index past the palette when there are more classes than n_colors.
        color = sns.hls_palette(max(self.n_colors, n_classes))

        sns.set()
        fig = plt.figure()
        is_3d = self.n_components > 2
        # A 3D axes is only created when there is a third component to draw.
        ax = fig.add_subplot(projection='3d') if is_3d else fig.add_subplot()
        fig.set_size_inches(10, 10)

        for i in range(n_classes):
            idx = labels == i
            if is_3d:
                z0 = tsne_results[:, 2][idx]
                ax.scatter3D(x[idx], y[idx], z0, color=color[i], marker=".", linewidths=1)
            else:
                ax.scatter(x[idx], y[idx], color=color[i], marker=".", linewidths=1)
        if not is_3d:
            ax.axis("off")

        fig.tight_layout()
        # Save before show(): with interactive backends show() may release the
        # figure, and a later savefig would write an empty canvas.
        fig.savefig(self.save_path)
        plt.show()
class AugmentDataLoader:
    """Cycle through a list of pre-augmented batches, one batch per iteration pass.

    Iterating over the loader yields exactly one batch; each pass advances an
    internal cursor, so successive passes walk through ``batch_list`` cyclically.
    When the cursor wraps back to position 0 and ``shuffle`` is set, the visiting
    order is reshuffled with ``np.random.shuffle``.

    Args:
        batch_list: indexable collection of batches.
        shuffle: reshuffle the visiting order every time the cursor wraps around.
    """

    def __init__(self, batch_list, shuffle=True):
        self.shuffle = shuffle
        self.batch_list = batch_list

        self._cnt = 0
        self._nun_samples = len(batch_list)
        self._id_list = np.arange(self._nun_samples)

    def __len__(self):
        """Return the number of stored batches."""
        return self._nun_samples

    def __iter__(self):
        """Yield a single batch: the next one in the cycle."""
        yield next(self)

    def __next__(self):
        """Advance the cursor (wrapping around) and return the batch it points to."""
        position = (self._cnt + 1) % self._nun_samples
        self._cnt = position
        wrapped_around = position == 0
        if wrapped_around and self.shuffle:
            np.random.shuffle(self._id_list)
        return self.batch_list[self._id_list[position]]
19 | """ 20 | def __init__(self, data): 21 | super().__init__(batch_size=BATCH_SIZE, data=data) 22 | 23 | def __iter__(self): 24 | return self.data 25 | -------------------------------------------------------------------------------- /src/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .negative_mi import NegativeMI 2 | from .infograph_losses import LocalGlobalLoss, AdjLoss 3 | from .simple_losses import cosine_similarity 4 | -------------------------------------------------------------------------------- /src/losses/infograph_losses.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Par of code are adapted from cortex_DIM 3 | ''' 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import math 9 | from typing import Optional 10 | 11 | 12 | class LocalGlobalLoss(torch.nn.Module): 13 | r"""Negative Mutual Information. 14 | Estimate the negative mutual information loss for the specified neural similarity function (sim_function) and the 15 | inputs x, y, x_ind, y_ind. (x, y) is sampled from the joint distribution (x, y)~p(x, y). x_ind and y_ind are 16 | independently sampled from their marginal distributions x_ind~p(x), y_ind~p(y). 17 | Reference: Deep Graph Infomax. 18 | 19 | Args: 20 | in_channels (Optional[int]): input channels to the neural mutual information estimator. 21 | 22 | sim_function (Optional[torch.nn.Module]): the neural similarity measuring function. 23 | The default is the bilinear similarity function used by DGI. 24 | """ 25 | def __init__(self, ): 26 | super().__init__() 27 | 28 | def forward(self, l_enc, g_enc, batch, measure): 29 | ''' 30 | Args: 31 | l: Local feature map. 32 | g: Global features. 33 | measure: Type of f-divergence. For use with mode `fd` 34 | mode: Loss mode. Fenchel-dual `fd`, NCE `nce`, or Donsker-Vadadhan `dv`. 35 | Returns: 36 | torch.Tensor: Loss. 
class AdjLoss(torch.nn.Module):
    """Binary cross-entropy between a reconstructed and the observed adjacency matrix.

    The reconstruction is ``sigmoid(l_enc @ l_enc.T)`` with its diagonal zeroed;
    the target is the symmetric dense adjacency built from ``edge_index``.
    """

    def __init__(self,):
        super().__init__()
        self.loss = nn.BCELoss()

    def forward(self, l_enc, edge_index):
        """Compute the adjacency reconstruction loss.

        Args:
            l_enc: node embeddings; ``l_enc.shape[0]`` is taken as the node count.
            edge_index: two rows of node indices (sources, targets). Every edge
                is inserted in both directions.

        Returns:
            torch.Tensor: scalar BCE loss, on the same device as ``l_enc``.
        """
        num_nodes = l_enc.shape[0]
        # Follow the input's device/dtype instead of forcing CUDA, so the loss
        # also works on CPU-only hosts and with non-default float types.
        device, dtype = l_enc.device, l_enc.dtype

        src = edge_index[0].to(device).long()
        dst = edge_index[1].to(device).long()

        adj = torch.zeros((num_nodes, num_nodes), device=device, dtype=dtype)
        # Vectorised symmetric fill; replaces a per-edge Python loop with .item().
        adj[src, dst] = 1.
        adj[dst, src] = 1.

        off_diagonal = 1 - torch.eye(num_nodes, device=device, dtype=dtype)
        res = off_diagonal * torch.sigmoid(torch.mm(l_enc, l_enc.t()))

        return self.loss(res, adj)
def get_negative_expectation(q_samples, measure, average=True):
    """Compute the negative-sample term of a divergence / difference.

    Args:
        q_samples: Negative samples.
        measure: Name of the measure; one of 'GAN', 'JSD', 'X2', 'KL', 'RKL',
            'DV', 'H2', 'W1'. Any other value is reported through
            ``raise_measure_error``.
        average: Return the mean over samples instead of the per-sample values.

    Returns:
        torch.Tensor

    """
    log_2 = math.log(2.)

    # One lazily-evaluated formula per supported measure.
    formulas = {
        'GAN': lambda q: F.softplus(-q) + q,
        'JSD': lambda q: F.softplus(-q) + q - log_2,
        'X2': lambda q: -0.5 * ((torch.sqrt(q ** 2) + 1.) ** 2),
        'KL': lambda q: torch.exp(q),
        'RKL': lambda q: q - 1.,
        'DV': lambda q: log_sum_exp(q, 0) - math.log(q.size(0)),
        'H2': lambda q: torch.exp(q) - 1.,
        'W1': lambda q: q,
    }

    formula = formulas.get(measure)
    if formula is None:
        raise_measure_error(measure)

    Eq = formula(q_samples)
    return Eq.mean() if average else Eq
class SIGLoss(torch.nn.Module):
    """Mean of ``sigmoid(-cos(x, y))`` over all paired vectors.

    The loss decreases as the cosine similarity between matching rows of
    ``x`` and ``y`` increases.
    """

    def __init__(self, ):
        super().__init__()

    def forward(self, x: Tensor, y: Tensor):
        """Compute the loss.

        Args:
            x: first set of vectors; the feature axis is the last one.
            y: second set of vectors, same shape as ``x``.

        Returns:
            torch.Tensor: scalar loss.
        """
        x = F.normalize(x, p=2, dim=-1)
        y = F.normalize(y, p=2, dim=-1)

        # Reduce over the same (last) axis that was normalised. The previous
        # hard-coded axis 1 failed on 1-D inputs and mixed axes for rank > 2.
        cosine = (x * y).sum(dim=-1)
        return torch.sigmoid(-cosine).mean()
* (x * y).sum(dim=-1) 8 | -------------------------------------------------------------------------------- /src/methods/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseMethod 2 | 3 | from .dgi import DGI, DGIEncoder 4 | from .graphcl import GraphCL, GraphCLEncoder 5 | from .infograph import InfoGraph 6 | from .bgrl import BGRL, BGRLEncoder 7 | from .afgrl import AFGRL, AFGRLEncoder 8 | from .merit import GCN, Merit 9 | from .sugrl import SUGRL, SugrlGCN, SugrlMLP 10 | from .mvgrl import MVGRL, MVGRLEncoder 11 | 12 | from .hdmi import HDMI, HDMIEncoder 13 | from .heco import HeCo, Sc_encoder, Mp_encoder, HeCoDBLPTransform 14 | 15 | __all__ = [ 16 | "BaseMethod", 17 | "DGI", 18 | "DGIEncoder", 19 | "GraphCL", 20 | "GraphCLEncoder", 21 | ] 22 | -------------------------------------------------------------------------------- /src/methods/adgcl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import torch 6 | 7 | from .base import Model 8 | from nn.utils import AvgReadout, DiscriminatorMVGRL 9 | 10 | from torch_geometric.typing import Tensor, Adj 11 | 12 | 13 | class SUGRL(Model): 14 | r"""The full model to train the encoder. 15 | 16 | Args: 17 | encoder (torch.nn.Module): the encoder to be trained. 18 | discriminator (torch.nn.Module): the discriminator for contrastive learning. 
def make_mlplayers(in_channel, cfg, batch_norm=False, out_layer=None):
    """Build an MLP as an ``nn.Sequential``.

    One ``nn.Linear`` is created per entry of ``cfg``. With ``batch_norm`` every
    linear layer is followed by a non-affine ``BatchNorm1d`` and a ``ReLU``;
    without it a ``ReLU`` follows every linear layer except the last one in
    ``cfg``. If ``out_layer`` is given, a final ``nn.Linear`` to that width is
    appended.

    Args:
        in_channel: input feature width.
        cfg: sequence of layer widths.
        batch_norm: insert BatchNorm1d + ReLU after each linear layer.
        out_layer: optional width of an extra output linear layer.

    Returns:
        nn.Sequential
    """
    modules = []
    final_index = len(cfg) - 1
    width = in_channel

    for index, next_width in enumerate(cfg):
        modules.append(nn.Linear(width, next_width))
        if batch_norm:
            modules.extend([nn.BatchNorm1d(next_width, affine=False), nn.ReLU()])
        elif index != final_index:
            modules.append(nn.ReLU())
        width = next_width

    if out_layer is not None:
        modules.append(nn.Linear(width, out_layer))

    return nn.Sequential(*modules)
class AFGRL(BaseMethod):
    r"""Student/teacher model that trains the student encoder without augmented views.

    Args:
        student_encoder (torch.nn.Module): the encoder to be trained; must expose
            ``hidden_channels`` (its output width).
        teacher_encoder (torch.nn.Module): target encoder; its parameters are frozen here.
        data_augment: callable invoked as
            ``data_augment(adj, student, teacher, topk)`` that returns ``(ind, k)``,
            where ``ind[0]`` / ``ind[1]`` index the paired rows.
        adj_ori: adjacency handed to ``data_augment`` on every forward pass.
        topk (int): forwarded to ``data_augment``.
    """

    def __init__(self, student_encoder: torch.nn.Module, teacher_encoder: torch.nn.Module,
                 data_augment=None, adj_ori=None, topk=8):
        super().__init__(encoder=student_encoder, data_augment=data_augment, loss_function=loss_fn)
        self.encoder = student_encoder
        self.teacher_encoder = teacher_encoder
        # The teacher is never updated by back-propagation.
        set_requires_grad(self.teacher_encoder, False)

        rep_dim = self.encoder.hidden_channels
        pred_hid = rep_dim * 2
        predictor_layers = [nn.Linear(rep_dim, pred_hid), nn.PReLU(), nn.Linear(pred_hid, rep_dim)]
        self.student_predictor = nn.Sequential(*predictor_layers)
        self.student_predictor.apply(init_weights)

        self.data_augment = data_augment
        self.topk = topk
        self.adj_ori = adj_ori

    def forward(self, batch):
        r"""Return the mean symmetric prediction loss for ``batch``."""
        edge_index = batch.edge_index
        student = self.encoder(batch, edge_index)
        pred = self.student_predictor(student)

        with torch.no_grad():
            teacher = self.teacher_encoder(batch, batch.edge_index)

        student = torch.squeeze(student)
        teacher = torch.squeeze(teacher)
        pred = torch.squeeze(pred)

        student_unit = F.normalize(student, dim=-1, p=2)
        teacher_unit = F.normalize(teacher, dim=-1, p=2)
        ind, k = self.data_augment(self.adj_ori, student_unit, teacher_unit, self.topk)

        first, second = ind[0], ind[1]
        forward_loss = loss_fn(pred[first], teacher[second].detach())
        backward_loss = loss_fn(pred[second], teacher[first].detach())

        return (forward_loss + backward_loss).mean()

    def apply_data_augment_offline(self, dataloader):
        r"""Move every batch to the model device and wrap them in an AugmentDataLoader."""
        moved = [batch.to(self._device) for batch in dataloader]
        return AugmentDataLoader(batch_list=moved)

    def get_embs(self, data):
        r"""Return detached student-encoder embeddings for ``data``."""
        embeddings = self.encoder(data, data.edge_index)
        return embeddings.detach()
class BaseMethod(torch.nn.Module):
    r"""Base class for self-supervised learning methods.

    Args:
        encoder (torch.nn.Module): the encoder to be trained.
        device (str): device name stored on the instance. (default: "cuda")
        save_root (str): directory prefixed to the checkpoint file name by
            :meth:`save`. (default: "./")
        *args, **kwargs: accepted and ignored, so subclasses can forward extra
            constructor arguments without breaking.
    """

    def __init__(self,
                 encoder: torch.nn.Module,
                 device: str = "cuda",
                 save_root: str = "./",
                 *args, **kwargs):
        super().__init__()

        self.encoder = encoder
        self._device = device  # device: cuda or cpu
        self.save_root = save_root  # directory used by save()

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}()'

    def forward(self, *args, **kwargs) -> Tensor:
        r"""Perform the forward pass."""
        raise NotImplementedError

    def apply_data_augment(self, *args, **kwargs) -> Any:
        r"""Apply online data augmentation.
        These augmentations are used online within methods."""
        raise NotImplementedError

    def apply_data_augment_offline(self, *args, **kwargs):
        r"""Apply offline data augmentation.
        These augmentations are used offline in the trainer."""
        return None

    def apply_emb_augment(self, *args, **kwargs) -> Any:
        r"""Apply online embedding augmentation."""
        raise NotImplementedError

    def get_loss(self, *args, **kwargs) -> Tensor:
        r"""Get loss."""
        raise NotImplementedError

    def save(self, path: Optional[str] = None) -> None:
        r"""Save the parameters of the entire model.

        Args:
            path: checkpoint file name, joined onto ``save_root``.
                Defaults to ``"<ClassName>.ckpt"``.
        """
        if path is None:
            path = self.__class__.__name__ + ".ckpt"
        path = os.path.join(self.save_root, path)
        torch.save(self.state_dict(), path)

    def load(self, path: Optional[str] = None) -> None:
        r"""Load the parameters of the entire model from ``path``.

        Unlike :meth:`save`, ``path`` is used exactly as given (``save_root`` is
        not prefixed).

        Raises:
            FileNotFoundError: if ``path`` is None.
        """
        if path is None:
            raise FileNotFoundError("No checkpoint path was given.")
        # Deserialize to CPU so checkpoints written on a GPU host load anywhere;
        # load_state_dict then copies the values onto the parameters' own device.
        state_dict = torch.load(path, map_location="cpu")
        # The state dict was produced by self.state_dict(), so it must be loaded
        # into this module itself (there is no ``model`` attribute).
        self.load_state_dict(state_dict)

    @property
    def device(self):
        return self._device

    @device.setter
    def device(self, value):
        # Setting the device also moves all parameters and buffers to it.
        self._device = value
        self.to(self.device)
class BGRL(BaseMethod):
    r"""Student/teacher model that trains the student encoder on two augmented views.

    Args:
        student_encoder (torch.nn.Module): the encoder to be trained; must expose
            ``hidden_channels`` (its output width).
        teacher_encoder (torch.nn.Module): target encoder; its parameters are frozen here.
        data_augment: mapping with the callables ``"augment_1"`` and ``"augment_2"``
            used by :meth:`apply_data_augment_offline`.
        pred_dim: accepted for compatibility; the current single-layer predictor
            does not use it.
    """

    def __init__(self, student_encoder: torch.nn.Module, teacher_encoder: torch.nn.Module,
                 data_augment=None, pred_dim=None):
        super().__init__(encoder=student_encoder, data_augment=data_augment, loss_function=loss_fn)
        self.encoder = student_encoder
        self.teacher_encoder = teacher_encoder
        self.data_augment = data_augment
        self.loss_function = loss_fn
        # The teacher is never updated by back-propagation.
        set_requires_grad(self.teacher_encoder, False)

        rep_dim = self.encoder.hidden_channels
        if pred_dim == None:
            pred_dim = rep_dim * 2
        # self.student_predictor = nn.Sequential(nn.Linear(rep_dim, pred_dim), nn.PReLU(), nn.Linear(pred_dim, rep_dim))
        predictor = nn.Sequential(nn.Linear(rep_dim, rep_dim))
        self.student_predictor = predictor
        self.student_predictor.apply(init_weights)

    def forward(self, batch):
        r"""Return the mean symmetric loss for a pair of augmented views.

        Args:
            batch: a pair ``(view_1, view_2)``.
        """
        view_1, view_2 = batch

        student_1 = self.encoder(view_1, view_1.edge_index)
        student_2 = self.encoder(view_2, view_2.edge_index)

        pred_1 = self.student_predictor(student_1)
        pred_2 = self.student_predictor(student_2)

        with torch.no_grad():
            target_1 = self.teacher_encoder(view_1, view_1.edge_index)
            target_2 = self.teacher_encoder(view_2, view_2.edge_index)

        # Each view's prediction is matched against the teacher output of the other view.
        cross_12 = self.loss_function(pred_1, target_2.detach())
        cross_21 = self.loss_function(pred_2, target_1.detach())

        return (cross_12 + cross_21).mean()

    def apply_data_augment_offline(self, dataloader):
        r"""Pre-compute both augmented views of every batch and wrap them in an AugmentDataLoader."""
        view_pairs = []
        for batch in dataloader:
            batch = batch.to(self._device)
            first_view = self.data_augment["augment_1"](batch)
            second_view = self.data_augment["augment_2"](batch)
            view_pairs.append((first_view, second_view))
        return AugmentDataLoader(batch_list=view_pairs)

    def get_embs(self, data):
        r"""Return detached student-encoder embeddings for ``data``."""
        embeddings = self.encoder(data, data.edge_index)
        return embeddings.detach()
class DGI(BaseMethod):
    r""" Deep Graph Infomax (DGI).

    Node embeddings of the input batch are paired with a summary vector
    (positive pairs), while node embeddings of a node-shuffled copy are paired
    with the same summary (negative pairs).

    Args:
        encoder (Optional[torch.nn.Module]): the encoder to be trained.
        hidden_channels (int): output dimension of the encoder.
        readout (str): "avg" or "max": generate global embeddings.
            (default: "avg")
        readout_act (Callable): activation handed to the readout.
            (default: torch.nn.Sigmoid())
    """
    def __init__(self,
                 encoder: torch.nn.Module,
                 hidden_channels: int,
                 readout: str="avg",
                 readout_act: Callable=torch.nn.Sigmoid()) -> None:
        super().__init__(encoder=encoder)

        self.readout = SumEmb(readout, readout_act)
        self.corrupt = ShuffleNode()
        self.loss_func = NegativeMI(in_channels=hidden_channels)

    def apply_data_augment(self, batch):
        """Return ``(batch, corrupted_batch)``, both moved to ``self.device``."""
        original = batch.to(self.device)
        corrupted = self.corrupt(original).to(self.device)
        return original, corrupted

    def apply_emb_augment(self, h_pos):
        """Summarise the positive embeddings with the readout (dims kept)."""
        return self.readout(h_pos, keepdim=True)

    def get_loss(self, h_pos, h_neg, s):
        """Score the summary against positive and negative node embeddings."""
        summary = s.expand_as(h_pos)
        return self.loss_func(x=summary, y=h_pos, x_ind=summary, y_ind=h_neg)

    def forward(self, batch):
        # Step 1: build the clean / corrupted inputs.
        pos_batch, neg_batch = self.apply_data_augment(batch)

        # Step 2: encode both of them with the shared encoder.
        h_pos = self.encoder(pos_batch)
        h_neg = self.encoder(neg_batch)

        # Steps 3 and 4: summary embedding, then the loss.
        summary = self.apply_emb_augment(h_pos)
        return self.get_loss(h_pos, h_neg, summary)

    def get_embs(self, data):
        """Return detached encoder embeddings for ``data``."""
        embeddings = self.encoder(data)
        return embeddings.detach()
class GCA_Encoder(torch.nn.Module):
    """Stack of ``k`` graph convolutions, optionally with skip connections.

    Without ``skip`` the hidden layers are ``2 * out_channels`` wide and the
    last layer projects down to ``out_channels``. With ``skip`` every layer is
    ``out_channels`` wide and each layer consumes the sum of all previous
    outputs, including a linear projection of the input.

    Args:
        in_channels (int): input feature dimension.
        out_channels (int): output embedding dimension.
        activation: callable applied after every convolution.
        base_model: convolution class, called as ``base_model(in, out)``.
        k (int): number of convolution layers; must be at least 2.
        skip (bool): whether to use the skip-connection variant.
    """
    def __init__(self, in_channels: int, out_channels: int, activation, base_model=GCNConv, k: int = 2, skip=False):
        super(GCA_Encoder, self).__init__()
        self.base_model = base_model

        assert k >= 2
        self.k = k
        self.skip = skip
        if self.skip:
            self.fc_skip = nn.Linear(in_channels, out_channels)
            layers = [base_model(in_channels, out_channels)]
            layers.extend(base_model(out_channels, out_channels) for _ in range(1, k))
        else:
            width = 2 * out_channels
            layers = [base_model(in_channels, width).jittable()]
            layers.extend(base_model(width, width) for _ in range(1, k - 1))
            layers.append(base_model(width, out_channels))
        self.conv = nn.ModuleList(layers)

        self.activation = activation

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor):
        """Encode node features ``x`` over the graph given by ``edge_index``."""
        if self.skip:
            first = self.activation(self.conv[0](x, edge_index))
            outputs = [self.fc_skip(x), first]
            for layer_idx in range(1, self.k):
                # Every later layer sees the sum of everything produced so far.
                aggregated = sum(outputs)
                outputs.append(self.activation(self.conv[layer_idx](aggregated, edge_index)))
            return outputs[-1]

        for layer_idx in range(self.k):
            x = self.activation(self.conv[layer_idx](x, edge_index))
        return x
class LogReg(nn.Module):
    """Single linear layer mapping ``ft_in`` features to ``nb_classes`` logits.

    The weight is Xavier-uniform initialised and the bias is set to zero.
    """
    def __init__(self, ft_in, nb_classes):
        super(LogReg, self).__init__()
        self.fc = nn.Linear(ft_in, nb_classes)

        for module in self.modules():
            self.weights_init(module)

    def weights_init(self, m):
        """Initialise ``m`` if it is a linear layer; leave other modules alone."""
        if not isinstance(m, nn.Linear):
            return
        torch.nn.init.xavier_uniform_(m.weight.data)
        if m.bias is not None:
            m.bias.data.fill_(0.0)

    def forward(self, seq):
        """Return the logits for ``seq``."""
        return self.fc(seq)
-------------------------------------------------------------------------------- /src/methods/graphcl.py: -------------------------------------------------------------------------------- 1 | from .base import BaseMethod 2 | from torch_geometric.typing import * 3 | from src.augment import RandomMask, ShuffleNode 4 | from typing import Optional, Callable, Union 5 | from src.typing import AugmentType 6 | from .utils import AvgReadout 7 | from src.losses import NegativeMI 8 | from src.loader import AugmentDataLoader 9 | from torch_geometric.nn.models import GCN 10 | from torch_geometric.nn import GCNConv 11 | import torch 12 | 13 | 14 | class GraphCL(BaseMethod): 15 | r""" 16 | TODO: add descriptions 17 | """ 18 | def __init__(self, 19 | encoder: torch.nn.Module, 20 | hidden_channels: int, 21 | readout: Union[Callable, torch.nn.Module] = AvgReadout(), 22 | corruption: AugmentType = RandomMask(), 23 | loss_function: Optional[torch.nn.Module] = None) -> None: 24 | super().__init__(encoder=encoder, data_augment=corruption, loss_function=loss_function) 25 | 26 | self.readout = readout 27 | self.sigmoid = torch.nn.Sigmoid() 28 | self.data_augment = corruption 29 | self.loss_function = loss_function if loss_function else NegativeMI(hidden_channels) 30 | 31 | def forward(self, batch): 32 | pos_batch, neg_batch, neg_batch2 = batch 33 | h_pos = self.encoder(pos_batch, pos_batch.adj_t) 34 | if self.augment_type == 'edge': 35 | # h_neg1 = self.encoder(pos_batch, neg_batch.adj_t.to(pos_batch.batch.device)) 36 | # h_neg2 = self.encoder(pos_batch, neg_batch2.adj_t.to(pos_batch.batch.device)) 37 | h_neg1 = self.encoder(pos_batch, neg_batch.edge_index) 38 | h_neg2 = self.encoder(pos_batch, neg_batch2.edge_index) 39 | elif self.augment_type == 'mask': 40 | # h_neg1 = self.encoder(neg_batch, pos_batch.adj_t) 41 | # h_neg2 = self.encoder(neg_batch2, pos_batch.adj_t) 42 | h_neg1 = self.encoder(neg_batch, pos_batch.edge_index) 43 | h_neg2 = self.encoder(neg_batch2, pos_batch.edge_index) 44 
class GraphCLEncoder(torch.nn.Module):
    r"""One-layer dense encoder computing ``act(adj @ (x W) + bias)``.

    The adjacency is read from ``batch.adj_t`` and densified on every call.

    NOTE(review): ``forward`` ignores its ``edge_index`` argument, and
    ``self.gcn`` is created but never used in ``forward``. Callers that pass
    an augmented ``edge_index`` will therefore see no effect from it. Confirm
    whether ``self.gcn`` was meant to replace the dense product.

    Args:
        in_channels (int): input feature dimension.
        hidden_channels (int): output embedding dimension.
        act (Optional[torch.nn.Module]): activation module. When ``None`` a
            fresh ``torch.nn.PReLU()`` is created for each encoder, so its
            learnable slope is not shared between instances.
        num_layers (int): accepted for interface compatibility; not used.
        bias (bool): whether to add a learnable bias after aggregation.
    """
    def __init__(self,
                 in_channels: int,
                 hidden_channels: int = 512,
                 act: Optional[torch.nn.Module] = None,
                 num_layers=1,
                 bias=True):
        super(GraphCLEncoder, self).__init__()
        self.dim_out = hidden_channels
        self.act = torch.nn.PReLU() if act is None else act

        if bias:
            self.bias = torch.nn.Parameter(torch.zeros(hidden_channels))
        else:
            self.register_parameter('bias', None)

        self.fc = torch.nn.Linear(in_channels, hidden_channels, bias=False)
        # Initialise only after ``fc`` is registered; running this loop any
        # earlier visits no linear layer and leaves ``fc`` at PyTorch defaults.
        for m in self.modules():
            self._weights_init(m)
        # Kept so existing checkpoints still contain the same state_dict keys.
        self.gcn = GCNConv(in_channels=in_channels, out_channels=hidden_channels, bias=True)

    def _weights_init(self, m):
        """Xavier-initialise ``torch.nn.Linear`` weights and zero their bias."""
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight.data)
            if m.bias is not None:
                m.bias.data.fill_(0.0)

    def forward(self, batch, edge_index, is_sparse=True):
        """Encode ``batch.x`` using the dense form of ``batch.adj_t``.

        Args:
            batch: object providing ``x`` and ``adj_t`` (with ``to_dense()``).
            edge_index: unused; see the class-level note.
            is_sparse (bool): when True use a 2-D ``torch.mm`` after squeezing
                a leading singleton dimension of the features; otherwise use
                a batched ``torch.bmm``.
        """
        x = batch.x
        adj = batch.adj_t.to_dense().to(x.device)
        seq_fts = self.fc(x)
        if is_sparse:
            out = torch.mm(adj, torch.squeeze(seq_fts, 0))
        else:
            out = torch.bmm(adj, seq_fts)
        if self.bias is not None:
            out += self.bias

        return self.act(out)
class EMA:
    """Moving-average helper whose decay follows a cosine schedule.

    The decay starts at ``beta`` when ``step == 0`` and reaches 1 when
    ``step == total_steps``.

    NOTE(review): ``step`` advances on every ``update_average`` call that has
    a non-None ``old``. ``update_moving_average`` calls it once per parameter
    tensor, so the schedule moves faster than one step per model update.
    Confirm that this is intended.
    """
    def __init__(self, beta, epochs):
        super().__init__()
        self.beta = beta
        self.step = 0
        self.total_steps = epochs

    def update_average(self, old, new):
        """Blend ``old`` and ``new`` using the current decay, then advance ``step``.

        Returns ``new`` unchanged (without advancing) when ``old`` is None.
        """
        if old is None:
            return new
        cosine_term = np.cos(np.pi * self.step / self.total_steps) + 1
        decay = 1 - (1 - self.beta) * cosine_term / 2.0
        self.step += 1
        return old * decay + (1 - decay) * new


def update_moving_average(ema_updater, ma_model, current_model):
    """Overwrite each parameter of ``ma_model`` with its moving average.

    Parameters of the two models are paired positionally.
    """
    param_pairs = zip(current_model.parameters(), ma_model.parameters())
    for online_param, averaged_param in param_pairs:
        previous, incoming = averaged_param.data, online_param.data
        averaged_param.data = ema_updater.update_average(previous, incoming)
class AvgReadout(object):
    """Callable that averages a tensor over its second-to-last dimension.

    The file marks this helper as "To be removed."
    """
    def __init__(self) -> None:
        pass

    def __call__(self, x: Tensor, keepdim: bool = False):
        """Return the mean of ``x`` along dim ``-2``; keep that dim if asked."""
        return x.mean(dim=-2, keepdim=keepdim)
def dot(x, y):
    """
    Batched dot-product similarity between every row of x and every row of y.

    Args:
        x: [batch, n1, dim]
        y: [batch, n2, dim]

    Returns:
        [batch, n1, n2]

    Raises:
        AssertionError: if either x or y is not a 3-D tensor.
    """
    # Both operands must be validated: torch.bmm and the (1, 2) transpose
    # below only make sense for 3-D inputs.
    assert len(x.shape) == 3 and len(y.shape) == 3, \
        "The shape of x and y should be [batch, n, dim], got {} and {}".format(
            tuple(x.shape), tuple(y.shape))
    return torch.bmm(x, torch.transpose(y, 1, 2))
data_loader: DataLoader, 16 | save_root: str = "./ckpt", 17 | device: Union[str, int] = "cpu") -> None: 18 | r"""Base class for self-supervised learning methods. 19 | 20 | Args: 21 | method (torch.nn.Module): the entire method, including encoders and other components (e.g. discriminators). 22 | data_loader (Loader): 23 | save_root (str): the root to save the method/encoder. 24 | device (str): set the device. 25 | """ 26 | self.method = method 27 | self.data_loader = data_loader 28 | 29 | self.save_root = save_root 30 | self.device = device 31 | self.method.device = self.device 32 | 33 | if not os.path.exists(self.save_root): 34 | os.makedirs(self.save_root) 35 | 36 | def train(self, *args, **kwargs) -> Any: 37 | r"""Train the encoder.""" 38 | raise NotImplementedError 39 | 40 | def save(self, path: Optional[str] = None) -> None: 41 | r"""Save the parameters of the method.""" 42 | if path is None: 43 | path = os.path.join(self.save_root, METHOD_NAME) 44 | self.method.save(path) 45 | 46 | def load(self, path: Optional[str] = None) -> None: 47 | r"""Load the parameters of the method.""" 48 | if path is None: 49 | path = os.path.join(self.save_root, METHOD_NAME) 50 | self.method.load(path) 51 | -------------------------------------------------------------------------------- /src/trainer/gca_trainer.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | from torch_geometric.loader import DataLoader 4 | from src.methods import BaseMethod 5 | from .base import BaseTrainer 6 | from .utils import EarlyStopper 7 | from typing import Union 8 | 9 | from src.augment.gca_augments import drop_edge_weighted, drop_feature, drop_feature_weighted_2 10 | from src.augment.gca_augments import degree_drop_weights, pr_drop_weights, evc_drop_weights 11 | from src.augment.gca_augments import feature_drop_weights, feature_drop_weights_dense 12 | from src.augment.gca_augments import compute_pr, eigenvector_centrality 13 | 14 | 
from torch_geometric.utils import dropout_adj, degree, to_undirected 15 | 16 | 17 | class GCATrainer(BaseTrainer): 18 | r""" 19 | TODO: 1. Add descriptions. 20 | 2. Do we need to support more arguments? 21 | """ 22 | def __init__(self, 23 | method: BaseMethod, 24 | data_loader: DataLoader, 25 | lr: float = 0.001, 26 | weight_decay: float = 0.0, 27 | n_epochs: int = 5000, 28 | patience: int = 50, 29 | drop_scheme: str = 'degree', 30 | dataset_name: str = 'WikiCS', 31 | device: Union[str, int] = "cuda:1", 32 | save_root: str = "./ckpt"): 33 | super().__init__(method=method, 34 | data_loader=data_loader, 35 | save_root=save_root, 36 | device=device) 37 | self.device = device 38 | 39 | self.drop_scheme = drop_scheme 40 | 41 | self.dataset_name = dataset_name 42 | 43 | self.optimizer = torch.optim.Adam(self.method.parameters(), lr, weight_decay=weight_decay) 44 | 45 | self.n_epochs = n_epochs 46 | self.patience = patience 47 | 48 | self.early_stopper = EarlyStopper(patience=self.patience) 49 | 50 | def train(self): 51 | self.method = self.method.to(self.device) 52 | # self.data_loader = self.data_loader.to(self.device) 53 | 54 | 55 | for epoch in range(self.n_epochs): 56 | start_time = time.time() 57 | 58 | self.method.train() 59 | self.optimizer.zero_grad() 60 | 61 | for data in self.data_loader: 62 | # ---> 2. 
augment 63 | 64 | # ------> 2.1 drop_weights 65 | data = data.to(self.device) 66 | 67 | if self.drop_scheme == 'degree': 68 | drop_weights = degree_drop_weights(data.edge_index).to(self.device) 69 | elif self.drop_scheme == 'pr': 70 | drop_weights = pr_drop_weights(data.edge_index, aggr='sink', k=200).to(self.device) 71 | elif self.drop_scheme == 'evc': 72 | drop_weights = evc_drop_weights(data).to(self.device) 73 | else: 74 | drop_weights = None 75 | 76 | # ------> 2.2 feature_weights 77 | if self.drop_scheme == 'degree': 78 | edge_index_ = to_undirected(data.edge_index) 79 | node_deg = degree(edge_index_[1]) 80 | if self.dataset_name == 'WikiCS': 81 | feature_weights = feature_drop_weights_dense(data.x, node_c=node_deg).to(self.device) 82 | else: 83 | feature_weights = feature_drop_weights(data.x, node_c=node_deg).to(self.device) 84 | elif self.drop_scheme == 'pr': 85 | node_pr = compute_pr(data.edge_index) 86 | if self.dataset_name == 'WikiCS': 87 | feature_weights = feature_drop_weights_dense(data.x, node_c=node_pr).to(self.device) 88 | else: 89 | feature_weights = feature_drop_weights(data.x, node_c=node_pr).to(self.device) 90 | elif self.drop_scheme == 'evc': 91 | node_evc = eigenvector_centrality(data) 92 | if self.dataset_name == 'WikiCS': 93 | feature_weights = feature_drop_weights_dense(data.x, node_c=node_evc).to(self.device) 94 | else: 95 | feature_weights = feature_drop_weights(data.x, node_c=node_evc).to(self.device) 96 | else: 97 | feature_weights = torch.ones((data.x.size(1),)).to(self.device) 98 | 99 | 100 | def drop_edge(idx: int): 101 | # global drop_weights 102 | 103 | if idx == 1: 104 | drop_edge_rate = 0.3 105 | else: 106 | drop_edge_rate = 0.4 107 | 108 | if self.drop_scheme == 'uniform': 109 | return dropout_adj(data.edge_index, p=drop_edge_rate)[0] 110 | elif self.drop_scheme in ['degree', 'evc', 'pr']: 111 | return drop_edge_weighted(data.edge_index, drop_weights, p=drop_edge_rate, threshold=0.7) 112 | else: 113 | raise 
class NonContrastTrainer(BaseTrainer):
    r"""Trainer for methods that take pre-augmented batches and return a loss.

    Before training, ``method.apply_data_augment_offline`` is given the chance
    to replace the data loader. Every batch then gets one AdamW step. When
    ``use_ema`` is set, ``method.teacher_encoder`` is updated as a moving
    average of ``method.encoder`` after each step.

    Args:
        method (BaseMethod): the method to train; calling it on a batch must
            return a scalar loss tensor.
        data_loader (DataLoader): loader yielding training batches.
        lr (float): learning rate for AdamW.
        weight_decay (float): weight decay for AdamW.
        n_epochs (int): maximum number of epochs.
        patience (int): patience handed to the early stopper.
        device (Union[str, int]): device used for the method and evaluation.
        use_ema (bool): whether to update the teacher with a moving average.
        moving_average_decay (float): base decay of the moving average.
        save_root (str): directory for checkpoints.
        dataset: optional dataset exposing ``.data``. When given, a logistic
            regression probe is run on the embeddings every 20 epochs; when
            ``None`` the probe is skipped.
    """
    def __init__(self,
                 method: BaseMethod,
                 data_loader: DataLoader,
                 lr: float = 0.001,
                 weight_decay: float = 0.0,
                 n_epochs: int = 10000,
                 patience: int = 50,
                 device: Union[str, int] = "cuda:0",
                 use_ema: bool = False,
                 moving_average_decay: float = 0.9,
                 save_root: str = "./ckpt",
                 dataset=None):
        super().__init__(method=method,
                         data_loader=data_loader,
                         save_root=save_root,
                         device=device)

        self.optimizer = torch.optim.AdamW(self.method.parameters(), lr, weight_decay=weight_decay)
        self.dataset = dataset
        self.n_epochs = n_epochs
        self.patience = patience
        self.device = device
        self.use_ema = use_ema
        if self.use_ema:
            self.ema_updater = EMA(moving_average_decay, n_epochs)
        self.early_stopper = EarlyStopper(patience=self.patience)

    def train(self):
        """Run the training loop until ``n_epochs`` or early stopping.

        Raises:
            ValueError: if the data loader yields no batch in an epoch.
        """
        self.method = self.method.to(self.device)
        new_loader = self.method.apply_data_augment_offline(self.data_loader)
        if new_loader is not None:
            self.data_loader = new_loader
        for epoch in tqdm(range(self.n_epochs)):
            start_time = time.time()

            loss = None
            for data in self.data_loader:
                # Re-enter train mode each batch: the periodic evaluation
                # below switches the method to eval mode.
                self.method.train()
                self.optimizer.zero_grad()

                loss = self.method(data)

                loss.backward()
                self.optimizer.step()
                if self.use_ema:
                    update_moving_average(self.ema_updater, self.method.teacher_encoder, self.method.encoder)

            if loss is None:
                raise ValueError("data_loader yielded no batches; nothing to train on")

            end_time = time.time()
            info = "Epoch {}: loss: {:.4f}, time: {:.4f}s".format(epoch, loss.detach().cpu().numpy(), end_time-start_time)

            if (epoch+1)%20==0:
                print(info)
                # The probe needs a dataset; skip it when none was supplied.
                if self.dataset is not None:
                    self.method.eval()
                    data_pyg = self.dataset.data.to(self.method.device)
                    embs = self.method.get_embs(data_pyg).detach()
                    lg = LogisticRegression(lr=0.01, weight_decay=0, max_iter=100, n_run=5, device=self.device)
                    lg(embs=embs, dataset=data_pyg)
            self.early_stopper.update(loss)  # update the status
            if self.early_stopper.save:
                self.save()
            if self.early_stopper.stop:
                return
17 | """ 18 | def __init__(self, 19 | method: BaseMethod, 20 | data_loader: DataLoader, 21 | lr: float = 0.001, 22 | weight_decay: float = 0.0, 23 | n_epochs: int = 10000, 24 | patience: int = 50, 25 | device: Union[str, int] = "cuda:0", 26 | save_root: str = "./ckpt", 27 | dataset=None): 28 | super().__init__(method=method, 29 | data_loader=data_loader, 30 | save_root=save_root, 31 | device=device) 32 | # if config: 33 | # self.optimizer = torch.optim.Adam(self.method.parameters(), lr, weight_decay=config.optim.weight_decay) 34 | 35 | self.optimizer = torch.optim.AdamW(self.method.parameters(), lr, weight_decay=weight_decay) 36 | self.dataset = dataset 37 | self.n_epochs = n_epochs 38 | self.patience = patience 39 | self.device = device 40 | # scheduler = lambda epoch: epoch / 1000 if epoch < 1000 \ 41 | # else ( 1 + np.cos((epoch-1000) * np.pi / (n_epochs - 1000))) * 0.5 42 | # self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda = scheduler) 43 | self.early_stopper = EarlyStopper(patience=self.patience) 44 | 45 | def train(self): 46 | self.method = self.method.to(self.device) 47 | new_loader = self.method.apply_data_augment_offline(self.data_loader) 48 | if new_loader != None: 49 | self.data_loader = new_loader 50 | for epoch in range(self.n_epochs): 51 | start_time = time.time() 52 | 53 | for data in self.data_loader: 54 | self.method.train() 55 | self.optimizer.zero_grad() 56 | 57 | data = self.push_batch_to_device(data) 58 | loss = self.method(data) 59 | 60 | 61 | loss.backward() 62 | self.optimizer.step() 63 | # self.scheduler.step() 64 | 65 | end_time = time.time() 66 | info = "Epoch {}: loss: {:.4f}, time: {:.4f}s".format(epoch, loss.detach().cpu().numpy(), end_time-start_time) 67 | print(info) 68 | 69 | # # ------------------ Evaluator ------------------- 70 | # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 71 | # data_pyg = self.data_loader.dataset.data.to(device) 72 | # y, embs = 
class EarlyStopper:
    r"""Early stopping. This class has two status flags, 'stop' and 'save'.
    Use :meth:`update` after every evaluation to refresh them.

    stop: Whether to stop the training. Becomes True once the monitored value
        has failed to improve for ``patience`` consecutive updates.
    save: Whether to save the model. True when the value passed to the latest
        :meth:`update` call is the best one seen so far (including the very
        first value), False otherwise.

    Args:
        patience (int): The maximum number of consecutive non-improving
            updates that is tolerated.
            (default: 20)
        type (str): Whether a better value is a larger one ('max') or a
            smaller one ('min').
            (default: 'min')

    Raises:
        ValueError: If ``type`` is neither 'max' nor 'min'.
    """
    def __init__(self, patience: int = 20, type: str = "min") -> None:
        # An explicit exception (instead of `assert`) survives `python -O`.
        if type not in {'max', 'min'}:
            raise ValueError(f"type must be 'max' or 'min', got {type!r}")
        self.patience = patience
        self.type = type

        self.stop = False
        self.save = False

        self._best_value = None
        self._n_wait = 0

    def _is_improvement(self, current_value) -> bool:
        r"""Return True if `current_value` beats the best value so far."""
        if self.type == "min":
            return bool(current_value < self._best_value)
        return bool(current_value > self._best_value)

    def update(self, current_value) -> None:
        r"""Update the status `save` and `stop`."""
        # The first observation is the best one by definition, so it is
        # flagged for saving as well; otherwise a run that never improves
        # would end without any checkpoint.
        if self._best_value is None or self._is_improvement(current_value):
            self._best_value = current_value
            self._n_wait = 0
            self.save = True
            return

        self._n_wait += 1
        self.save = False

        # `>=` rather than `==` so that patience=0 also triggers.
        if self._n_wait >= self.patience:
            if not self.stop:
                print('Early stopped!')
            self.stop = True
@functional_transform('edge_to_adj_ssl')
class Edge2Adj(BaseTransform):
    r"""Convert edges to a sparse adjacency matrix stored in ``data.adj_t``.

    The edge values are taken from ``data.edge_weight`` when that attribute is
    present, otherwise from ``data.edge_attr``; if neither provides values,
    every edge gets the value 1.

    Args:
        norm (Optional[Callable]): Normalization applied to the data object
            after ``adj_t`` has been attached.
            (default: None)
            None: Generate 0/1 adjacency matrix from edge_index & edge_attr/edge_weight.
    """
    def __init__(self, norm: Optional[Callable] = None):
        self.norm = norm

    def __call__(self, data):
        assert 'edge_index' in data

        # `edge_weight` takes precedence over `edge_attr`.
        edge_weight = data.edge_weight if 'edge_weight' in data else data.edge_attr

        if edge_weight is None:
            # Unweighted graph: one unit entry per edge.
            edge_weight = torch.ones_like(data.edge_index[0])

        # Use the resolved `edge_weight` here. The previous code passed
        # `data.edge_weight`, which is missing when only `edge_attr` is set.
        # NOTE(review): assumes one scalar value per edge -- confirm for
        # datasets with multi-dimensional edge_attr.
        adj_t = torch.sparse_coo_tensor(data.edge_index, edge_weight,
                                        [data.num_nodes, data.num_nodes], dtype=torch.float)
        data.adj_t = SparseTensor.from_torch_sparse_coo_tensor(adj_t)
        if self.norm is not None:
            data = self.norm(data)
        return data
@functional_transform('gcn_norm_ssl')
class GCNNorm(BaseTransform):
    r"""Applies the GCN normalization from the `"Semi-supervised Classification
    with Graph Convolutional Networks" <https://arxiv.org/abs/1609.02907>`_
    paper (functional name: :obj:`gcn_norm_ssl`).

    .. math::
        \mathbf{\hat{A}} = \mathbf{\hat{D}}^{-1/2} (\mathbf{A} + \mathbf{I})
        \mathbf{\hat{D}}^{-1/2}

    where :math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij} + 1`.

    Args:
        add_self_loops (int): Value forwarded to :func:`gcn_norm` as its
            ``add_self_loops`` argument.
            (default: 1)
    """
    def __init__(self, add_self_loops: int = 1):
        self.add_self_loops = add_self_loops

    def __call__(self, data):
        has_edge_index = 'edge_index' in data
        has_adj_t = 'adj_t' in data
        assert has_edge_index or has_adj_t

        if has_edge_index:
            # Prefer an explicit edge_weight; fall back to edge_attr.
            weight = data.edge_weight if 'edge_weight' in data else data.edge_attr
            normalized = gcn_norm(
                data.edge_index, weight, data.num_nodes,
                add_self_loops=self.add_self_loops)
            data.edge_index, data.edge_weight = normalized

        # The data may carry both edge_index and adj_t; normalize each.
        if has_adj_t:
            data.adj_t = gcn_norm(data.adj_t, add_self_loops=self.add_self_loops)

        return data
@functional_transform('normalize_features_ssl')
class NormalizeFeatures(BaseTransform):
    r"""Row-normalizes the attributes given in :obj:`attrs`
    (functional name: :obj:`normalize_features_ssl`).

    Args:
        attrs (List[str], optional): The names of attributes to normalize.
            (default: :obj:`["x"]`)
        ord (str, int or float, optional): The order of normalization.
            (default: "sum1")
            sum1: x = x - min(x), x / sum(x)
            other supported normalization:
            https://pytorch.org/docs/stable/generated/torch.linalg.norm.html#torch.linalg.norm
    """

    ord = "sum1"

    def __init__(self, attrs: Optional[List[str]] = None, ord: Optional[Any] = None):
        # A fresh list per instance avoids sharing one mutable default.
        self.attrs = ["x"] if attrs is None else list(attrs)
        if ord is not None:
            self.ord = ord

    def __call__(self, data: Union[Data, HeteroData]):
        for store in data.stores:
            for key, value in store.items(*self.attrs):
                if self.ord == "sum1":
                    store[key] = sum1(value)
                else:
                    denominator = torch.linalg.norm(value, ord=self.ord, dim=-1, keepdim=True)
                    # Out-of-place division leaves the caller's tensor intact.
                    # A zero norm yields NaN (0/0) as well as inf (x/0), so
                    # every non-finite result is reset to 0 -- the previous
                    # isinf-only check let the NaN rows through.
                    value = torch.nan_to_num(value / denominator, nan=0., posinf=0., neginf=0.)
                    store[key] = value
        return data

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}()'


def sum1(value: Tensor):
    r"""Shift `value` by its global minimum and divide each row by its sum.

    Row sums below 1 are clamped to 1 before the division, so such rows are
    only shifted, not rescaled. An empty tensor is returned unchanged.
    """
    if value.numel() == 0:
        # `min()` is undefined for empty tensors.
        return value
    value = value - value.min()
    value.div_(value.sum(dim=-1, keepdim=True).clamp_(min=1.))
    return value
class BilinearSim(torch.nn.Module):
    r"""Scores pairs of vectors with a learnable bilinear form."""
    def __init__(self, in_channels: int = 512):
        super().__init__()
        # Both inputs share the same width; one score is produced per pair.
        bilinear = torch.nn.Bilinear(in1_features=in_channels,
                                     in2_features=in_channels,
                                     out_features=1)
        self.sim = bilinear

    def forward(self, x: Tensor, y: Tensor):
        r"""
        Args:
            x: [batch_size, dim]
            y: [batch_size, dim]

        Returns:
            similarity scores of (x, y): [batch_size, 1]
        """
        return self.sim(x, y)
import os
import os.path as osp


def create_dirs(dirs):
    """
    Create every directory in ``dirs``, including missing parent directories.
    Directories that already exist are left untouched.
    :param dirs: An iterable of directory paths (a single path is accepted too)
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(dirs, (str, os.PathLike)):
        dirs = [dirs]
    for dir_tree in dirs:
        # makedirs builds the whole chain itself. The former manual split on
        # "/" started absolute paths with an empty component and called
        # os.makedirs(""), which fails.
        os.makedirs(dir_tree, exist_ok=True)


def decide_config(root, name):
    """
    Create a configuration to download datasets
    :param root: A path to a root directory where data will be stored
    :param name: The name of the dataset to be downloaded (case-insensitive)
    :return: A dict with the keys "kwargs" (constructor arguments, including the
        adjusted root dir), "name" (dataset name), "class" (dataset class) and "src"
    :raises ValueError: If the dataset name is not supported
    """
    name = name.lower()
    # Spelling expected by the dataset classes for the non-Planetoid datasets.
    canonical_names = {"computers": "Computers", "photo": "Photo", "cs": "CS",
                       "physics": "Physics", "wikics": "WikiCS"}

    if name in ("cora", "citeseer", "pubmed"):
        root = osp.join(root, "pyg", "planetoid")
        dataset_class = Planetoid
    elif name in ("computers", "photo"):
        root = osp.join(root, "pyg")
        dataset_class = Amazon
    elif name in ("cs", "physics"):
        root = osp.join(root, "pyg")
        dataset_class = Coauthor
    elif name == "wikics":
        root = osp.join(root, "pyg")
        dataset_class = WikiCS
    else:
        # ValueError is still caught by callers that catch Exception.
        raise ValueError(
            f"Unknown dataset name {name}, name has to be one of the following 'cora', 'citeseer', 'pubmed', 'photo', 'computers', 'cs', 'physics', 'wikics'")

    name = canonical_names.get(name, name)
    kwargs = {"root": root}
    if dataset_class is not WikiCS:
        # The WikiCS constructor is given only the root directory.
        kwargs["name"] = name
    params = {"kwargs": kwargs,
              "name": name, "class": dataset_class, "src": "pyg"}
    return params