├── GNN ├── __init__.py ├── model │ ├── __init__.py │ ├── utils.py │ ├── GNN_arg.py │ ├── Dataloader.py │ └── GNN_library.py ├── readme.md ├── GNN_Link_AUC.py ├── GNN_Node.py ├── GNN_Link_MRR.py ├── GNN_Link_HitsK.py ├── GNN_Link_MRR_loader.py └── GNN_Link_HitsK_loader.py ├── assets ├── logo.png ├── dataset.png ├── example.png └── feature.png ├── run_gnn_node.sh ├── run_gnn_link.sh ├── license ├── run_emb.sh ├── .gitignore ├── requirements.txt ├── data_preprocess ├── generate_llm_emb.py └── dataset.ipynb └── README.md /GNN/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /GNN/model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhuofeng-Li/TEG-Benchmark/HEAD/assets/logo.png -------------------------------------------------------------------------------- /assets/dataset.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhuofeng-Li/TEG-Benchmark/HEAD/assets/dataset.png -------------------------------------------------------------------------------- /assets/example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhuofeng-Li/TEG-Benchmark/HEAD/assets/example.png -------------------------------------------------------------------------------- /assets/feature.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhuofeng-Li/TEG-Benchmark/HEAD/assets/feature.png -------------------------------------------------------------------------------- /run_gnn_node.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python GNN/GNN_Node.py \ 4 | --use_PLM_node Dataset/goodreads_children/emb/children_bert_base_uncased_512_cls_node.pt \ 5 | --use_PLM_edge Dataset/goodreads_children/emb/children_bert_base_uncased_512_cls_edge.pt \ 6 | --graph_path Dataset/goodreads_children/processed/children.pkl 7 | 8 | # python GNN/GNN_Node.py \ 9 | # --use_PLM_node Dataset/twitter/emb/twitter_bert_base_uncased_512_cls_node.pt \ 10 | # --use_PLM_edge Dataset/twitter/emb/twitter_bert_base_uncased_512_cls_edge.pt \ 11 | # --graph_path Dataset/twitter/processed/twitter.pkl 12 | 13 | 14 | -------------------------------------------------------------------------------- /run_gnn_link.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Define datasets 4 | datasets=( 5 | # "goodreads_children" 6 | # "twitter" 7 | # "goodreads_history" 8 | "goodreads_crime" 9 | ) 10 | 11 | # Define models to run 12 | models=( 13 | "GNN_Link_HitsK_loader.py" 14 | # "GNN_Link_MRR_loader.py" 15 | # "GNN_Link_AUC.py" 16 | ) 17 | 18 | # Loop through datasets and models 19 | for dataset in "${datasets[@]}"; do 20 | for model in "${models[@]}"; do 21 | echo "Running $model on $dataset dataset" 22 | python GNN/$model \ 23 | --use_PLM_node Dataset/$dataset/emb/${dataset##*_}_bert_base_uncased_512_cls_node.pt \ 24 | --use_PLM_edge Dataset/$dataset/emb/${dataset##*_}_bert_base_uncased_512_cls_edge.pt \ 25 | --path Dataset/$dataset/LinkPrediction/ \ 26 | 
--graph_path Dataset/$dataset/processed/${dataset##*_}.pkl \ 27 | --batch_size 1024 28 | done 29 | done 30 | -------------------------------------------------------------------------------- /license: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 TAG-Benchmark Team 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /GNN/model/utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from ogb.linkproppred import * 4 | from ogb.nodeproppred import * 5 | 6 | from dgl.data import CitationGraphDataset 7 | 8 | 9 | def load_graph(name): 10 | cite_graphs = ["cora", "citeseer", "pubmed"] 11 | 12 | if name in cite_graphs: 13 | dataset = CitationGraphDataset(name) 14 | graph = dataset[0] 15 | 16 | nodes = graph.nodes() 17 | y = graph.ndata["label"] 18 | train_mask = graph.ndata["train_mask"] 19 | val_mask = graph.ndata["test_mask"] 20 | 21 | nodes_train, y_train = nodes[train_mask], y[train_mask] 22 | nodes_val, y_val = nodes[val_mask], y[val_mask] 23 | eval_set = [(nodes_train, y_train), (nodes_val, y_val)] 24 | 25 | elif name.startswith("ogbn"): 26 | 27 | dataset = DglNodePropPredDataset(name) 28 | graph, y = dataset[0] 29 | split_nodes = dataset.get_idx_split() 30 | nodes = graph.nodes() 31 | 32 | train_idx = split_nodes["train"] 33 | val_idx = split_nodes["valid"] 34 | 35 | nodes_train, y_train = nodes[train_idx], y[train_idx] 36 | nodes_val, y_val = nodes[val_idx], y[val_idx] 37 | eval_set = [(nodes_train, y_train), (nodes_val, y_val)] 38 | 39 | else: 40 | raise ValueError("Dataset name error!") 41 | 42 | return graph, eval_set 43 | 44 | 45 | def parse_arguments(): 46 | """ 47 | Parse arguments 48 | """ 49 | parser = argparse.ArgumentParser(description="Node2vec") 50 | parser.add_argument("--dataset", type=str, default="cora") 51 | # 'train' for training node2vec model, 'time' for testing speed of random walk 52 | parser.add_argument("--task", type=str, default="train") 53 | parser.add_argument("--runs", type=int, default=10) 54 | parser.add_argument("--device", type=str, default="cpu") 55 | parser.add_argument("--embedding_dim", type=int, default=128) 56 | parser.add_argument("--walk_length", type=int, default=50) 57 | parser.add_argument("--p", type=float, default=0.25) 58 | parser.add_argument("--q", type=float, default=4.0) 59 | 
parser.add_argument("--num_walks", type=int, default=10) 60 | parser.add_argument("--epochs", type=int, default=100) 61 | parser.add_argument("--batch_size", type=int, default=128) 62 | 63 | args = parser.parse_args() 64 | 65 | return args -------------------------------------------------------------------------------- /run_emb.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #!/bin/bash 3 | 4 | # # Define an array of dataset configurations 5 | # datasets=( 6 | # "goodreads_children:children" 7 | # "goodreads_history:history" 8 | # "goodreads_comics:comics" 9 | # "goodreads_crime:crime" 10 | # # "twitter:twitter" 11 | # # "amazon_movie:movie" 12 | # # "amazon_baby:baby" 13 | # # "reddit:reddit" 14 | # ) 15 | 16 | # # Loop through each dataset 17 | # for dataset in "${datasets[@]}"; do 18 | # # Split the dataset string into folder and name 19 | # IFS=":" read -r folder name <<< "$dataset" 20 | 21 | # echo "Processing dataset: $folder" 22 | 23 | # python data_preprocess/generate_llm_emb.py \ 24 | # --pkl_file "Dataset/$folder/processed/${name}.pkl" \ 25 | # --path "Dataset/$folder/emb" \ 26 | # --name "$name" \ 27 | # --model_name meta-llama/Meta-Llama-3-8B \ 28 | # --max_length 2048 \ 29 | # --batch_size 25 30 | # done 31 | 32 | 33 | # python data_preprocess/generate_llm_emb.py \ 34 | # --pkl_file Dataset/arxiv/processed/arxiv.pkl \ 35 | # --path Dataset/arxiv/emb \ 36 | # --name arxiv \ 37 | # --model_name bert-base-uncased \ 38 | 39 | # python data_preprocess/generate_llm_emb.py \ 40 | # --pkl_file Dataset/arxiv/processed/arxiv.pkl \ 41 | # --path Dataset/arxiv/emb \ 42 | # --name arxiv \ 43 | # --model_name bert-large-uncased \ 44 | # --batch_size 250 45 | 46 | # python data_preprocess/generate_llm_emb.py \ 47 | # --pkl_file Dataset/twitter/processed/twitter.pkl \ 48 | # --path Dataset/twitter/emb \ 49 | # --name tweets \ 50 | # --model_name bert-base-uncased \ 51 | 52 | python data_preprocess/generate_llm_emb.py \ 53 | --pkl_file Dataset/twitter/processed/twitter.pkl \ 54 | --path Dataset/twitter/emb \ 55 | --name tweets \ 56 | --model_name bert-base-uncased \ 57 | 58 | python data_preprocess/generate_llm_emb.py \ 59 | --pkl_file Dataset/twitter/processed/twitter.pkl \ 60 | --path Dataset/twitter/emb \ 61 | --name tweets \ 62 | --model_name bert-large-uncased \ 63 | --batch_size 250 64 | 65 | # python data_preprocess/generate_llm_emb.py \ 66 | # --pkl_file Dataset/amazon_movie/processed/movie.pkl \ 67 | # --path Dataset/amazon_movie/emb \ 68 | # --name movie \ 69 | # --model_name bert-base-uncased \ 70 | 71 | 72 | # python data_preprocess/generate_llm_emb.py \ 73 | # --pkl_file Dataset/amazon_baby/processed/baby.pkl \ 74 | # --path Dataset/amazon_baby/emb \ 75 | # --name baby \ 76 | # --model_name bert-base-uncased \ 77 | 78 | # python data_preprocess/generate_llm_emb.py \ 79 | # --pkl_file Dataset/reddit/processed/reddit.pkl \ 80 | # --path Dataset/reddit/emb \ 81 | # --name reddit \ 82 | # --model_name bert-large-uncased \ 83 | -------------------------------------------------------------------------------- /GNN/readme.md: -------------------------------------------------------------------------------- 1 | ## Edge-aware GNN Link Prediction 2 | 3 | ### MRR Metric 4 | You can utilize following codes to do full batch training on MRR link prediction. 
5 | ``` 6 | python GNN/GNN_Link_MRR.py \ 7 | --use_PLM_node Dataset/goodreads_children/emb/children_bert_base_uncased_512_cls_node.pt \ 8 | --use_PLM_edge Dataset/goodreads_children/emb/children_bert_base_uncased_512_cls_edge.pt \ 9 | --path Dataset/goodreads_children/LinkPrediction/ \ 10 | --graph_path Dataset/goodreads_children/processed/children.pkl 11 | ``` 12 | 13 | ### MRR Metric using neighbor sampling 14 | You can utilize following codes to handle large-scale graph MRR link prediction. 15 | 16 | ``` 17 | python GNN/GNN_Link_MRR_loader.py \ 18 | --use_PLM_node Dataset/goodreads_children/emb/children_bert_base_uncased_512_cls_node.pt \ 19 | --use_PLM_edge Dataset/goodreads_children/emb/children_bert_base_uncased_512_cls_edge.pt \ 20 | --path Dataset/goodreads_children/LinkPrediction/ \ 21 | --graph_path Dataset/goodreads_children/processed/children.pkl 22 | ``` 23 | 24 | ### HitsK Metric 25 | You can utilize following codes to do full batch training on HitsK link prediction. 26 | ``` 27 | python GNN/GNN_Link_HitsK.py \ 28 | --use_PLM_node Dataset/goodreads_children/emb/children_bert_base_uncased_512_cls_node.pt \ 29 | --use_PLM_edge Dataset/goodreads_children/emb/children_bert_base_uncased_512_cls_edge.pt \ 30 | --path Dataset/goodreads_children/LinkPrediction/ \ 31 | --graph_path Dataset/goodreads_children/processed/children.pkl 32 | ``` 33 | 34 | ### HitsK Metric using neighbor sampling 35 | You can utilize following codes to handle large-scale graph HitsK link prediction. 36 | 37 | ``` 38 | python GNN/GNN_Link_HitsK_loader.py \ 39 | --use_PLM_node Dataset/goodreads_children/emb/children_bert_base_uncased_512_cls_node.pt \ 40 | --use_PLM_edge Dataset/goodreads_children/emb/children_bert_base_uncased_512_cls_edge.pt \ 41 | --path Dataset/goodreads_children/LinkPrediction/ \ 42 | --graph_path Dataset/goodreads_children/processed/children.pkl 43 | ``` 44 | 45 | ### AUC Metric using neighbor sampling 46 | You can utilize following codes to handle large-scale graph AUC link prediction. 47 | 48 | ``` 49 | python GNN/GNN_Link_AUC.py \ 50 | --use_PLM_node Dataset/goodreads_children/emb/children_bert_base_uncased_512_cls_node.pt \ 51 | --use_PLM_edge Dataset/goodreads_children/emb/children_bert_base_uncased_512_cls_edge.pt \ 52 | --graph_path Dataset/goodreads_children/processed/children.pkl 53 | ``` 54 | 55 | ## Edge-aware GNN Node Classification 56 | 57 | ``` 58 | python GNN/GNN_Node.py \ 59 | --use_PLM_node Dataset/goodreads_children/emb/children_bert_base_uncased_512_cls_node.pt \ 60 | --use_PLM_edge Dataset/goodreads_children/emb/children_bert_base_uncased_512_cls_edge.pt \ 61 | --graph_path Dataset/goodreads_children/processed/children.pkl 62 | ``` 63 | 64 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | GNN/__pycache__/ 4 | GNN/wandb/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | env/ 14 | venv/ 15 | ENV/ 16 | env.bak/ 17 | venv.bak/ 18 | bin/ 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Jupyter Notebook 57 | .ipynb_checkpoints 58 | 59 | # IPython 60 | profile_default/ 61 | ipython_config.py 62 | 63 | # pyenv 64 | .python-version 65 | 66 | # Environments 67 | .env 68 | .venv 69 | env/ 70 | venv/ 71 | ENV/ 72 | env.bak/ 73 | venv.bak/ 74 | 75 | # Spyder project settings 76 | .spyderproject 77 | .spyproject 78 | 79 | # Rope project settings 80 | .ropeproject 81 | 82 | # mkdocs documentation 83 | /site 84 | 85 | # mypy 86 | .mypy_cache/ 87 | 88 | # Pyre type checker 89 | .pyre/ 90 | 91 | # pytype static type analyzer 92 | .pytype/ 93 | 94 | # Cython debug symbols 95 | cython_debug/ 96 | 97 | # Django stuff: 98 | *.log 99 | local_settings.py 100 | db.sqlite3 101 | db.sqlite3-journal 102 | 103 | # Flask stuff: 104 | instance/ 105 | .webassets-cache 106 | 107 | # Scrapy stuff: 108 | .scrapy 109 | 110 | # Sphinx documentation 111 | docs/_build/ 112 | 113 | # PyBuilder 114 | target/ 115 | 116 | # VS Code settings 117 | .vscode/ 118 | .vscode/settings.json 119 | 120 | # Python-specific user settings 121 | python_settings.json 122 | 123 | # Data science 124 | *.csv 125 | *.tsv 126 | *.dat 127 | *.np[yz] 128 | 129 | # Notebook checkpoints 130 | .ipynb_checkpoints 131 | 132 | # VS Code 133 | .vscode/ 134 | 135 | # Visual Studio Code 136 | .vscode/ 137 | 138 | # PyCharm 139 | .idea/ 140 | *.iml 141 | 142 | # Project-specific 143 | *.pkl 144 | *.npy 145 | *.pt 146 | *.pkl 147 | nx_graph.pkl 148 | 149 | # Jupyter 150 | .ipynb_checkpoints 151 | 152 | # Temporary files 153 | *.tmp 154 | *.temp 155 | *.swp 156 | *.swo 157 | *.bak 158 | *~ 159 | 160 | example/linkproppred/*/*_dataset/ 161 | example/nodeproppred/*/*_dataset/ 162 | todo.txt 163 | TAG_Benchmark/example/nodeproppred/citation/node_classification_log/tmp.ipynb 164 | example/linkproppred/*/*/*_dataset/ 165 | example/nodeproppred/*/*/*_dataset/ 166 | Dataset 167 | GNN/GNN_Node.py 168 | wandb/ 169 | tmp.ipynb 170 | tmp_link.sh -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==1.0.1 2 | aiohappyeyeballs==2.4.3 3 | aiohttp==3.10.10 4 | aiosignal==1.3.1 5 | annotated-types==0.7.0 6 | anyio==4.6.2.post1 7 | argon2-cffi==23.1.0 8 | argon2-cffi-bindings==21.2.0 9 | arrow==1.3.0 10 | asttokens==2.4.1 11 | async-lru==2.0.4 12 | async-timeout==4.0.3 13 | attrs==24.2.0 14 | Automat==20.2.0 15 | babel==2.16.0 16 | bcrypt==3.2.0 17 | beautifulsoup4==4.12.3 18 | bleach==6.1.0 19 | blinker==1.4 20 | certifi==2020.6.20 21 | cffi==1.17.1 22 | chardet==4.0.0 23 | charset-normalizer==3.4.0 24 | click==8.0.3 25 | cloud-init==24.3.1 26 | colorama==0.4.4 27 | comm==0.2.2 28 | command-not-found==0.3 29 | configobj==5.0.6 30 | constantly==15.1.0 31 | contourpy==1.3.0 32 | cryptography==3.4.8 33 | cycler==0.12.1 34 | datasets==3.0.1 35 | dbus-python==1.2.18 36 | debugpy==1.8.7 37 | decorator==5.1.1 38 | defusedxml==0.7.1 39 | dgl==2.4.0+cu124 40 | dill==0.3.8 41 | distro==1.7.0 42 | distro-info==1.1+ubuntu0.2 43 | docker-pycreds==0.4.0 44 | exceptiongroup==1.2.2 45 | executing==2.1.0 46 | fastjsonschema==2.20.0 47 | filelock==3.16.1 48 | fonttools==4.54.1 49 | fqdn==1.5.1 50 | 
frozenlist==1.4.1 51 | fsspec==2024.6.1 52 | gitdb==4.0.11 53 | GitPython==3.1.43 54 | h11==0.14.0 55 | httpcore==1.0.6 56 | httplib2==0.20.2 57 | httpx==0.27.2 58 | huggingface-hub==0.26.0 59 | hyperlink==21.0.0 60 | idna==3.3 61 | importlib-metadata==4.6.4 62 | incremental==21.3.0 63 | ipykernel==6.29.5 64 | ipython==8.28.0 65 | ipywidgets==8.1.5 66 | isoduration==20.11.0 67 | jedi==0.19.1 68 | jeepney==0.7.1 69 | Jinja2==3.0.3 70 | joblib==1.4.2 71 | json5==0.9.25 72 | jsonpatch==1.32 73 | jsonpointer==2.0 74 | jsonschema==4.23.0 75 | jsonschema-specifications==2024.10.1 76 | jupyter==1.1.1 77 | jupyter-console==6.6.3 78 | jupyter-events==0.10.0 79 | jupyter-lsp==2.2.5 80 | jupyter_client==8.6.3 81 | jupyter_core==5.7.2 82 | jupyter_server==2.14.2 83 | jupyter_server_terminals==0.5.3 84 | jupyterlab==4.2.5 85 | jupyterlab_pygments==0.3.0 86 | jupyterlab_server==2.27.3 87 | jupyterlab_widgets==3.0.13 88 | keyring==23.5.0 89 | kiwisolver==1.4.7 90 | launchpadlib==1.10.16 91 | lazr.restfulclient==0.14.4 92 | lazr.uri==1.0.6 93 | littleutils==0.2.4 94 | MarkupSafe==2.0.1 95 | matplotlib==3.9.2 96 | matplotlib-inline==0.1.7 97 | mistune==3.0.2 98 | more-itertools==8.10.0 99 | mpmath==1.3.0 100 | multidict==6.1.0 101 | multiprocess==0.70.16 102 | nbclient==0.10.0 103 | nbconvert==7.16.4 104 | nbformat==5.10.4 105 | nest-asyncio==1.6.0 106 | netifaces==0.11.0 107 | networkx==3.4.1 108 | notebook==7.2.2 109 | notebook_shim==0.2.4 110 | numpy==2.1.2 111 | nvidia-cublas-cu12==12.1.3.1 112 | nvidia-cuda-cupti-cu12==12.1.105 113 | nvidia-cuda-nvrtc-cu12==12.1.105 114 | nvidia-cuda-runtime-cu12==12.1.105 115 | nvidia-cudnn-cu12==9.1.0.70 116 | nvidia-cufft-cu12==11.0.2.54 117 | nvidia-curand-cu12==10.3.2.106 118 | nvidia-cusolver-cu12==11.4.5.107 119 | nvidia-cusparse-cu12==12.1.0.106 120 | nvidia-nccl-cu12==2.20.5 121 | nvidia-nvjitlink-cu12==12.4.127 122 | nvidia-nvtx-cu12==12.1.105 123 | oauthlib==3.2.0 124 | ogb==1.3.6 125 | outdated==0.2.2 126 | overrides==7.7.0 127 | packaging==24.1 128 | pandas==2.2.3 129 | pandocfilters==1.5.1 130 | parso==0.8.4 131 | pexpect==4.8.0 132 | pillow==11.0.0 133 | platformdirs==4.3.6 134 | prometheus_client==0.21.0 135 | prompt_toolkit==3.0.48 136 | propcache==0.2.0 137 | protobuf==5.28.2 138 | psutil==6.1.0 139 | ptyprocess==0.7.0 140 | pure_eval==0.2.3 141 | pyarrow==17.0.0 142 | pyasn1==0.4.8 143 | pyasn1-modules==0.2.1 144 | pycparser==2.22 145 | pydantic==2.9.2 146 | pydantic_core==2.23.4 147 | pyg-lib==0.4.0+pt24cu124 148 | Pygments==2.18.0 149 | PyGObject==3.42.1 150 | PyHamcrest==2.0.2 151 | PyJWT==2.3.0 152 | pyOpenSSL==21.0.0 153 | pyparsing==2.4.7 154 | pyrsistent==0.18.1 155 | pyserial==3.5 156 | python-apt==2.4.0+ubuntu4 157 | python-dateutil==2.9.0.post0 158 | python-debian==0.1.43+ubuntu1.1 159 | python-json-logger==2.0.7 160 | python-magic==0.4.24 161 | pytz==2022.1 162 | PyYAML==5.4.1 163 | pyzmq==26.2.0 164 | referencing==0.35.1 165 | regex==2024.9.11 166 | requests==2.32.3 167 | rfc3339-validator==0.1.4 168 | rfc3986-validator==0.1.1 169 | rpds-py==0.20.0 170 | safetensors==0.4.5 171 | scikit-learn==1.5.2 172 | scipy==1.14.1 173 | screen-resolution-extra==0.0.0 174 | SecretStorage==3.3.1 175 | Send2Trash==1.8.3 176 | sentry-sdk==2.17.0 177 | service-identity==18.1.0 178 | setproctitle==1.3.3 179 | six==1.16.0 180 | smmap==5.0.1 181 | sniffio==1.3.1 182 | sos==4.5.6 183 | soupsieve==2.6 184 | ssh-import-id==5.11 185 | stack-data==0.6.3 186 | sympy==1.13.1 187 | systemd-python==234 188 | terminado==0.18.1 189 | threadpoolctl==3.5.0 190 | 
tinycss2==1.3.0 191 | tokenizers==0.20.1 192 | tomli==2.0.2 193 | torch==2.4.0 194 | torch-geometric==2.6.1 195 | torch_cluster==1.6.3+pt24cu124 196 | torch_scatter==2.1.2+pt24cu124 197 | torch_sparse==0.6.18+pt24cu124 198 | torch_spline_conv==1.2.2+pt24cu124 199 | tornado==6.4.1 200 | tqdm==4.66.5 201 | traitlets==5.14.3 202 | transformers==4.45.2 203 | triton==3.0.0 204 | Twisted==22.1.0 205 | types-python-dateutil==2.9.0.20241003 206 | typing_extensions==4.12.2 207 | tzdata==2024.2 208 | ubuntu-pro-client==8001 209 | ufw==0.36.1 210 | unattended-upgrades==0.1 211 | uri-template==1.3.0 212 | urllib3==2.2.3 213 | wadllib==1.3.6 214 | wandb==0.18.5 215 | wcwidth==0.2.13 216 | webcolors==24.8.0 217 | webencodings==0.5.1 218 | websocket-client==1.8.0 219 | widgetsnbextension==4.0.13 220 | xkit==0.0.0 221 | xxhash==3.5.0 222 | yarl==1.15.4 223 | zipp==1.0.0 224 | zope.interface==5.4.0 225 | -------------------------------------------------------------------------------- /data_preprocess/generate_llm_emb.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | import numpy as np 4 | import pandas as pd 5 | import torch 6 | import torch.nn.functional as F 7 | import os 8 | from sklearn.decomposition import PCA 9 | from transformers import ( 10 | AutoTokenizer, 11 | AutoModel, 12 | TrainingArguments, 13 | PreTrainedModel, 14 | Trainer, 15 | ) 16 | from transformers.modeling_outputs import TokenClassifierOutput 17 | from datasets import Dataset 18 | 19 | 20 | class CLSEmbInfModel(PreTrainedModel): 21 | def __init__(self, model): 22 | super().__init__(model.config) 23 | self.encoder = model 24 | 25 | @torch.no_grad() 26 | def forward(self, input_ids, attention_mask): 27 | # Extract outputs from the model3 28 | outputs = self.encoder(input_ids, attention_mask, output_hidden_states=True) 29 | node_cls_emb = outputs.last_hidden_state[:, 0, :] # Last layer 30 | return TokenClassifierOutput(logits=node_cls_emb) 31 | 32 | 33 | 34 | def main(): 35 | parser = argparse.ArgumentParser( 36 | description="Process node and edge text data and save the overall representation as .pt files." 
37 | ) 38 | parser.add_argument( 39 | "--pkl_file", 40 | default="twitter/processed/twitter.pkl", 41 | type=str, 42 | help="Path to the Textual-Edge Graph .pkl file", 43 | ) 44 | parser.add_argument( 45 | "--model_name", 46 | type=str, 47 | default="bert-base-uncased", 48 | help="Name or path of the Huggingface model", 49 | ) 50 | parser.add_argument("--tokenizer_name", type=str, default="bert-base-uncased") 51 | parser.add_argument( 52 | "--name", type=str, default="twitter", help="Prefix name for the NPY file" 53 | ) 54 | parser.add_argument( 55 | "--path", type=str, default="twitter/emb", help="Path to the .pt File" 56 | ) 57 | parser.add_argument( 58 | "--pretrain_path", type=str, default=None, help="Path to the NPY File" 59 | ) 60 | parser.add_argument( 61 | "--max_length", 62 | type=int, 63 | default=512, 64 | help="Maximum length of the text for language models", 65 | ) 66 | parser.add_argument( 67 | "--batch_size", 68 | type=int, 69 | default=500, 70 | help="Number of batch size for inference", 71 | ) 72 | parser.add_argument("--fp16", type=bool, default=True, help="if fp16") 73 | parser.add_argument( 74 | "--cls", 75 | action="store_true", 76 | default=True, 77 | help="whether use cls token to represent the whole text", 78 | ) 79 | parser.add_argument( 80 | "--nomask", 81 | action="store_true", 82 | help="whether do not use mask to claculate the mean pooling", 83 | ) 84 | parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use for computation") 85 | 86 | 87 | parser.add_argument("--norm", type=bool, default=False, help="nomic use True") 88 | 89 | # 解析命令行参数 90 | args = parser.parse_args() 91 | pkl_file = args.pkl_file 92 | model_name = args.model_name 93 | name = args.name 94 | max_length = args.max_length 95 | batch_size = args.batch_size 96 | 97 | tokenizer_name = args.tokenizer_name 98 | 99 | root_dir = os.path.dirname(os.path.abspath(__file__)) 100 | base_dir = os.path.dirname(root_dir.rstrip("/")) 101 | Feature_path = os.path.join(base_dir, args.path) 102 | cache_path = f"{Feature_path}cache/" 103 | 104 | if not os.path.exists(Feature_path): 105 | os.makedirs(Feature_path) 106 | if not os.path.exists(cache_path): 107 | os.makedirs(cache_path) 108 | print("Embedding Model:", model_name) 109 | 110 | if args.pretrain_path is not None: 111 | output_file = os.path.join( 112 | Feature_path 113 | , name 114 | + "_" 115 | + model_name.split("/")[-1].replace("-", "_") 116 | + "_" 117 | + str(max_length) 118 | + "_" 119 | + "Tuned" 120 | ) 121 | else: 122 | output_file = os.path.join( 123 | Feature_path 124 | , name 125 | + "_" 126 | + model_name.split("/")[-1].replace("-", "_") 127 | + "_" 128 | + str(max_length) 129 | ) 130 | 131 | print("output_file:", output_file) 132 | 133 | pkl_file = os.path.join(base_dir, args.pkl_file) 134 | with open(pkl_file, "rb") as f: 135 | data = pickle.load(f) 136 | text_data = data.text_nodes + data.text_edges 137 | 138 | # Load tokenizer and Model 139 | tokenizer = AutoTokenizer.from_pretrained( 140 | model_name, use_fast=True, trust_remote_code=True 141 | ) 142 | 143 | if tokenizer.pad_token is None: 144 | tokenizer.pad_token = tokenizer.eos_token 145 | 146 | dataset = Dataset.from_dict({"text": text_data}) 147 | 148 | def tokenize_function(examples): 149 | return tokenizer(examples["text"], padding=True, truncation=True, max_length=max_length, return_tensors="pt") 150 | 151 | dataset = dataset.map(tokenize_function, batched=True, batch_size=batch_size) 152 | 153 | if args.pretrain_path is 
not None: 154 | model = AutoModel.from_pretrained(f"{args.pretrain_path}", attn_implementation="eager") 155 | print("Loading model from the path: {}".format(args.pretrain_path)) 156 | else: 157 | model = AutoModel.from_pretrained(model_name, trust_remote_code=True, attn_implementation="eager") 158 | 159 | model = model.to(args.device) 160 | CLS_Feateres_Extractor = CLSEmbInfModel(model) 161 | CLS_Feateres_Extractor.eval() 162 | 163 | inference_args = TrainingArguments( 164 | output_dir=cache_path, 165 | do_train=False, 166 | do_predict=True, 167 | per_device_eval_batch_size=batch_size, 168 | dataloader_drop_last=False, 169 | dataloader_num_workers=12, 170 | fp16_full_eval=args.fp16, 171 | ) 172 | 173 | # CLS representatoin 174 | if args.cls: 175 | if not os.path.exists(output_file + "_cls_node.pt") or not os.path.exists(output_file + "_cls_edge.pt"): 176 | trainer = Trainer(model=CLS_Feateres_Extractor, args=inference_args) 177 | cls_emb = trainer.predict(dataset) 178 | node_cls_emb = torch.from_numpy(cls_emb.predictions[: len(data.text_nodes)]) 179 | edge_cls_emb = torch.from_numpy(cls_emb.predictions[len(data.text_nodes) :]) 180 | torch.save(node_cls_emb, output_file + "_cls_node.pt") 181 | torch.save(edge_cls_emb, output_file + "_cls_edge.pt") 182 | print("Existing saved to the {}".format(output_file)) 183 | 184 | else: 185 | print("Existing saved CLS") 186 | else: 187 | raise ValueError("others are not defined") 188 | 189 | 190 | if __name__ == "__main__": 191 | main() 192 | -------------------------------------------------------------------------------- /GNN/model/GNN_arg.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import wandb 4 | 5 | 6 | def args_init(): 7 | argparser = argparse.ArgumentParser( 8 | "GNN(GCN,GIN,GraphSAGE,GAT) on OGBN-Arxiv/Amazon-Books/and so on. Node Classification tasks", 9 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 10 | ) 11 | argparser.add_argument( 12 | "--cpu", 13 | action="store_true", 14 | help="CPU mode. This option overrides --gpu.", 15 | ) 16 | argparser.add_argument("--gpu", type=int, default=0, help="GPU device ID.") 17 | argparser.add_argument( 18 | "--n-runs", type=int, default=3, help="running times" 19 | ) 20 | argparser.add_argument( 21 | "--n-epochs", type=int, default=1000, help="number of epochs" 22 | ) 23 | argparser.add_argument( 24 | "--lr", type=float, default=0.005, help="learning rate" 25 | ) 26 | argparser.add_argument( 27 | "--n-layers", type=int, default=3, help="number of layers" 28 | ) 29 | argparser.add_argument( 30 | "--n-hidden", type=int, default=256, help="number of hidden units" 31 | ) 32 | argparser.add_argument( 33 | "--input-drop", type=float, default=0.1, help="input drop rate" 34 | ) 35 | argparser.add_argument( 36 | "--learning-eps", type=bool, default=True, 37 | help="If True, learn epsilon to distinguish center nodes from neighbors;" 38 | "If False, aggregate neighbors and center nodes altogether." 39 | ) 40 | argparser.add_argument("--wd", type=float, default=0, help="weight decay") 41 | argparser.add_argument( 42 | "--num-mlp-layers", type=int, default=2, help="number of mlp layers" 43 | ) 44 | argparser.add_argument( 45 | "--neighbor-pooling-type", type=str, default='mean', help="how to aggregate neighbors (sum, mean, or max)" 46 | ) 47 | # ! GAT 48 | argparser.add_argument( 49 | "--no-attn-dst", type=bool, default=True, help="Don't use attn_dst." 
50 | ) 51 | argparser.add_argument( 52 | "--n-heads", type=int, default=3, help="number of heads" 53 | ) 54 | argparser.add_argument( 55 | "--attn-drop", type=float, default=0.0, help="attention drop rate" 56 | ) 57 | argparser.add_argument( 58 | "--edge-drop", type=float, default=0.0, help="edge drop rate" 59 | ) 60 | # ! SAGE 61 | argparser.add_argument("--aggregator-type", type=str, default="mean", 62 | help="Aggregator type: mean/gcn/pool/lstm") 63 | # ! JKNET 64 | argparser.add_argument( 65 | "--mode", type=str, default='cat', help="the mode of aggregate the feature, 'cat', 'lstm'" 66 | ) 67 | #! Monet 68 | argparser.add_argument( 69 | "--pseudo-dim",type=int,default=2,help="Pseudo coordinate dimensions in GMMConv, 2 and 3", 70 | ) 71 | argparser.add_argument( 72 | "--n-kernels",type=int,default=3,help="Number of kernels in GMMConv layer", 73 | ) 74 | #! APPNP 75 | argparser.add_argument( 76 | "--alpha", type=float, default=0.1, help="Teleport Probability" 77 | ) 78 | argparser.add_argument( 79 | "--k", type=int, default=10, help="Number of propagation steps" 80 | ) 81 | argparser.add_argument( 82 | "--hidden_sizes", type=int, nargs="+", default=[64], help="hidden unit sizes for appnp", 83 | ) 84 | # ! default 85 | argparser.add_argument( 86 | "--log-every", type=int, default=20, help="log every LOG_EVERY epochs" 87 | ) 88 | argparser.add_argument( 89 | "--eval_steps", type=int, default=5, help="eval in every epochs" 90 | ) 91 | argparser.add_argument( 92 | "--use_PLM", type=str, default=None, help="Use LM embedding as feature" 93 | ) 94 | argparser.add_argument( 95 | "--model_name", type=str, default='RevGAT', help="Which GNN be implemented" 96 | ) 97 | argparser.add_argument( 98 | "--dropout", type=float, default=0.5, help="dropout rate" 99 | ) 100 | argparser.add_argument( 101 | "--data_name", type=str, default='ogbn-arxiv', help="The datasets to be implemented." 102 | ) 103 | argparser.add_argument( 104 | "--metric", type=str, default='acc', help="The datasets to be implemented." 105 | ) 106 | # ! 
Split datasets 107 | argparser.add_argument( 108 | "--train_ratio", type=float, default=0.6, help="training ratio" 109 | ) 110 | argparser.add_argument( 111 | "--val_ratio", type=float, default=0.2, help="training ratio" 112 | ) 113 | return argparser 114 | 115 | class Logger(object): 116 | def __init__(self, runs, info=None): 117 | self.info = info 118 | self.results = [[] for _ in range(runs)] 119 | 120 | def add_result(self, run, result): 121 | assert len(result) == 3 122 | assert run >= 0 and run < len(self.results) 123 | self.results[run].append(result) 124 | 125 | def print_statistics(self, run=None, key='Hits@10'): 126 | if run is not None: 127 | result = 100 * torch.tensor(self.results[run]) 128 | argmax = result[:, 1].argmax().item() 129 | print(f'Run {run + 1:02d}:') 130 | print(f'Highest Train: {result[:, 0].max():.2f}') 131 | print(f'Highest Valid: {result[:, 1].max():.2f}') 132 | print(f' Final Train: {result[argmax, 0]:.2f}') 133 | print(f' Final Test: {result[argmax, 2]:.2f}') 134 | else: 135 | result = 100 * torch.tensor(self.results) 136 | 137 | best_results = [] 138 | for r in result: 139 | train1 = r[:, 0].max().item() 140 | valid = r[:, 1].max().item() 141 | train2 = r[r[:, 1].argmax(), 0].item() 142 | test = r[r[:, 1].argmax(), 2].item() 143 | best_results.append((train1, valid, train2, test)) 144 | 145 | best_result = torch.tensor(best_results) 146 | 147 | print(f'All runs:') 148 | r = best_result[:, 0] 149 | print(f'{key} Highest Train: {r.mean():.2f} ± {r.std():.2f}') 150 | wandb.log({f'{key} Highest Train Acc': float(f'{r.mean():.2f}'), 151 | f'{key} Highest Train Std': float(f'{r.std():.2f}')}) 152 | r = best_result[:, 1] 153 | print(f'{key} Highest Valid: {r.mean():.2f} ± {r.std():.2f}') 154 | wandb.log({f'{key} Highest Valid Acc': float(f'{r.mean():.2f}'), 155 | f'{key} Highest Valid Std': float(f'{r.std():.2f}')}) 156 | r = best_result[:, 2] 157 | print(f'{key} Final Train: {r.mean():.2f} ± {r.std():.2f}') 158 | wandb.log( 159 | {f'{key} Final Train Acc': float(f'{r.mean():.2f}'), f'{key} Final Train Std': float(f'{r.std():.2f}')}) 160 | r = best_result[:, 3] 161 | print(f'{key} Final Test: {r.mean():.2f} ± {r.std():.2f}') 162 | wandb.log( 163 | {f'{key} Final Test Acc': float(f'{r.mean():.2f}'), f'{key} Final Test Std': float(f'{r.std():.2f}')}) 164 | 165 | 166 | -------------------------------------------------------------------------------- /data_preprocess/dataset.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stderr", 10 | "output_type": "stream", 11 | "text": [ 12 | "/home/lizhuofeng/.local/lib/python3.10/site-packages/huggingface_hub/hf_api.py:9620: UserWarning: Warnings while validating metadata in README.md:\n", 13 | "- empty or missing yaml metadata in repo card\n", 14 | " warnings.warn(f\"Warnings while validating metadata in README.md:\\n{message}\")\n" 15 | ] 16 | } 17 | ], 18 | "source": [ 19 | "from huggingface_hub import HfApi\n", 20 | "api = HfApi()\n", 21 | "\n", 22 | "api.upload_folder(\n", 23 | " folder_path=\"../Dataset\",\n", 24 | " repo_id=\"ZhuofengLi/TEG-Datasets\",\n", 25 | " repo_type=\"dataset\",\n", 26 | "\tignore_patterns=[\"**/logs/*.txt\", \"**/.ipynb_checkpoints\", \"**/__pycache__\", \"**/.cache\", \"**/embcache\", \"**/LinkPrediction\"]\n", 27 | ")" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 1, 33 | "metadata": {}, 34 | 
"outputs": [ 35 | { 36 | "data": { 37 | "application/vnd.jupyter.widget-view+json": { 38 | "model_id": "c484c0011f2a48d39490aa13140c778d", 39 | "version_major": 2, 40 | "version_minor": 0 41 | }, 42 | "text/plain": [ 43 | "Fetching 85 files: 0%| | 0/85 [00:00= args.threshold, 1, 0) 225 | f1 = f1_score(ground_truth, y_label) 226 | print(f"F1 score: {f1:.4f}") 227 | 228 | # AUC 229 | auc = roc_auc_score(ground_truth, preds) 230 | print(f"Validation AUC: {auc:.4f}") 231 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | ![TEG-DB logo](assets/logo.png) 3 |

4 | 5 | ## Why TEGs instead of TAGs? 🤔 6 | 7 | Textual-Edge Graphs (TEGs) incorporate textual content on **both nodes and edges**, unlike Text-Attributed Graphs (TAGs) featuring textual information **only at the nodes**. **Edge texts are crucial for understanding document meanings and semantic relationships.** For instance, as shown below, to understand the knowledge "Planck endorsed the uncertainty and probabilistic nature of quantum mechanics," **citation edge (Book D - Paper E) text information is essential**. This reveals the comprehensive connections and influences among scholarly works, enabling a deeper analysis of document semantics and knowledge networks. 8 | 9 | ![alt text](assets/example.png) 10 | 11 | ## Overview 📚 12 | 13 | Textual-Edge Graphs Datasets and Benchmark (TEG-DB) is a comprehensive and diverse collection of benchmark textual-edge datasets featuring rich textual descriptions on nodes and edges, data loaders, and performance benchmarks for various baseline models, including pre-trained language models (PLMs), graph neural networks (GNNs), and their combinations. This repository aims to facilitate research in the domain of textual-edge graphs by providing standardized data formats and easy-to-use tools for model evaluation and comparison. 14 | 15 | ## Features ✨ 16 | 17 | + **Unified Data Representation:** All TEG datasets are represented in a unified format. This standardization allows for easy extension of new datasets into our benchmark. 18 | + **Highly Efficient Pipeline:** TEG-Benchmark is highly integrated with PyTorch Geometric (PyG), leveraging its powerful tools and functionalities. Therefore, its code is concise. Specifically, for each paradigm, we provide a small `.py` file with a summary of all relevant models and a `.ssh` file to run all baselines in one click. 19 | + **Comprehensive Benchmark and Analysis:** We conduct extensive benchmark experiments and perform a comprehensive analysis of TEG-based methods, delving deep into various aspects such as the impact of different models, the effect of embeddings generated by Pre-trained Language Models (PLMs) of various scales, and the influence of different domain datasets. The statistics of our TEG datasets are as follows: 20 | 21 |

22 | ![TEG dataset statistics](assets/dataset.png) 23 | ![TEG dataset features](assets/feature.png)

24 | 25 | ## Datasets 📊 26 | 27 | Please click [Huggingface TEG-Benchmark](https://huggingface.co/datasets/ZhuofengLi/TEG-Datasets) to find the TEG datasets we upload! 28 | 29 | We have constructed **10 comprehensive and representative TEG datasets (we will continue to expand)**. These datasets cover domains including Book Recommendation, E-commerce, Academic, and Social networks. They vary in size, ranging from small to large. Each dataset contains rich raw text data on both nodes and edges, providing a diverse range of information for analysis and modeling purposes. 30 | 31 | | Dataset | Nodes | Edges | Graph Domain | Node Classification | Link Prediction | Details | 32 | |--------------------|-----------|------------|---------------------|---------------------|------------------|-------------------------------------| 33 | | Goodreads-History | 540,796 | 2,066,193 | Book Recommendation | ✅ | ✅ | [Link](https://huggingface.co/datasets/ZhuofengLi/TEG-Datasets/blob/main/goodreads_history/goodreads.md) | 34 | | Goodreads-Crime | 422,642 | 1,849,236 | Book Recommendation | ✅ | ✅ | [Link](https://huggingface.co/datasets/ZhuofengLi/TEG-Datasets/blob/main/goodreads_crime/goodreads.md) | 35 | | Goodreads-Children | 216,613 | 734,640 | Book Recommendation | ✅ | ✅ | [Link](https://huggingface.co/datasets/ZhuofengLi/TEG-Datasets/blob/main/goodreads_children/goodreads.md) | 36 | | Goodreads-Comics | 148,658 | 542,338 | Book Recommendation | ✅ | ✅ | [Link](https://huggingface.co/datasets/ZhuofengLi/TEG-Datasets/blob/main/goodreads_comics/goodreads.md) | 37 | | Amazon-Movie | 174,012 | 1,697,533 | E-commerce | ✅ | ✅ | [Link](https://huggingface.co/datasets/ZhuofengLi/TEG-Datasets/blob/main/amazon_movie/movie.md) | 38 | | Amazon-Apps | 100,480 | 752,937 | E-commerce | ✅ | ✅ | [Link](https://huggingface.co/datasets/ZhuofengLi/TEG-Datasets/blob/main/amazon_apps/apps.md) | 39 | | Amazon-Baby | 186,790 | 1,241,083 | E-commerce | ✅ | ✅ | [Link](https://huggingface.co/datasets/ZhuofengLi/TEG-Datasets/blob/main/amazon_baby/baby.md) | 40 | | Reddit | 512,776 | 256,388 | Social Networks | ✅ | ✅ | [Link](https://huggingface.co/datasets/ZhuofengLi/TEG-Datasets/blob/main/reddit/reddit.md) | 41 | | Twitter | 60,785 | 74,666 | Social Networks | ✅ | ✅ | [Link](https://huggingface.co/datasets/ZhuofengLi/TEG-Datasets/blob/main/twitter/twitter.md) | 42 | | Citation | 169,343 | 1,166,243 | Academic | ✅ | ✅ | [Link](https://huggingface.co/datasets/ZhuofengLi/TEG-Datasets/blob/main/arxiv/arxiv.md) | 43 | 44 | 45 | 46 | **TEG-DB is an ongoing effort, and we are planning to increase our coverage in the future.** 47 | 48 | ## Our Experiments 🔬 49 | 50 | Please check the experimental results and analysis from our [paper](https://arxiv.org/abs/2406.10310). 51 | 52 | ## Package Usage ⚙️ 53 | ### Requirements 📋 54 | You can quickly install the corresponding dependencies, 55 | 56 | ```bash 57 | pip install -r requirements.txt 58 | ``` 59 | 60 | ### More details about packages 📦 61 | 62 | The `GNN` folder is intended to house all GNN methods for link prediction and node classification (Please click [here](GNN/readme.md) to see more training details). The `data_preprocess` folder contains various scripts for generating TEG data. For instance, `data_preprocess/generate_llm_emb.py` generates text embeddings from LLM for the TEG in a high speed, while `data_preprocess/dataset.ipynb` facilitates uploading and downloading the [TEG datasets](https://huggingface.co/datasets/ZhuofengLi/TEG-Datasets) we provide from Hugging Face. 
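The processed `.pkl` file is a pickled graph object that carries the raw text directly: the training scripts read its `text_nodes`, `text_edges`, `text_node_labels`, and `edge_index` attributes (see `GNN/GNN_Node.py` and `data_preprocess/generate_llm_emb.py`). The snippet below is a minimal sketch of how you might inspect a downloaded dataset; it reuses the Goodreads-Children paths from `run_gnn_node.sh`, the attribute names come from those scripts, and everything else is illustrative.

```python
import pickle

import torch

# Example paths taken from run_gnn_node.sh; point them at whichever dataset you downloaded.
graph_path = "Dataset/goodreads_children/processed/children.pkl"
node_emb_path = "Dataset/goodreads_children/emb/children_bert_base_uncased_512_cls_node.pt"

with open(graph_path, "rb") as f:
    data = pickle.load(f)

# Raw text on nodes and edges: these are the fields generate_llm_emb.py encodes with a PLM
# (it concatenates text_nodes + text_edges and later splits the embeddings back apart).
print(len(data.text_nodes), "nodes,", len(data.text_edges), "edges")
print(data.text_nodes[0])
print(data.text_edges[0])

# Graph structure and node labels used by GNN_Node.py (a label of -1 is treated as unlabeled).
print(data.edge_index.shape)
print(data.text_node_labels[:5])

# Pre-computed PLM embeddings are stored separately as .pt tensors; the GNN scripts load and squeeze them.
node_emb = torch.load(node_emb_path)
print(node_emb.shape)
```

Because node and edge texts are embedded in a single pass, the `_node.pt` and `_edge.pt` files produced by `generate_llm_emb.py` always share the same dataset, model, and max-length prefix.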
63 | 64 | ### Datasets setup 🛠️ 65 | 66 | You can go to the [Huggingface TEG-Benchmark](https://huggingface.co/datasets/ZhuofengLi/TEG-Datasets) to find the datasets we have uploaded! In each dataset folder, you can find a `.py` file (downloads raw data) and a `.ipynb` file (processes raw data into a TEG) in the `raw` folder, `.pt` files (node and edge text embeddings we extract from the PLM) in the `emb` folder, and a `.pkl` file processed from the raw data (the final textual-edge graph file). 67 | 68 | **Download all datasets:** 69 | ```python 70 | from huggingface_hub import snapshot_download 71 | 72 | snapshot_download(repo_id="ZhuofengLi/TEG-Datasets", repo_type="dataset", local_dir="./Dataset") 73 | ``` 74 | 75 | ### Running Experiments 🚀 76 | **GNN for Link Prediction:** 77 | ```bash 78 | bash run_gnn_link.sh 79 | ``` 80 | 81 | **GNN for Node Classification:** 82 | ```bash 83 | bash run_gnn_node.sh 84 | ``` 85 | 86 | Please click [here](GNN/readme.md) to see more details. 87 | 88 | ## Create Your Own Dataset and Model 🛠️ 89 | You can follow the steps below: 90 | + Align your dataset format with ours. Please check the dataset [README](https://huggingface.co/datasets/ZhuofengLi/TEG-Datasets) for more details. 91 | + Generate your raw text embeddings quickly with [`generate_llm_emb.py`](https://github.com/Zhuofeng-Li/TEG-Benchmark/blob/main/data_preprocess/generate_llm_emb.py). 92 | + Add your model to [`GNN/model/GNN_library.py`](https://github.com/Zhuofeng-Li/TEG-Benchmark/blob/main/GNN/model/GNN_library.py). 93 | + Update the `args` and the `gen_model()` function in the training scripts. 94 | 95 | ## Acknowledgements 96 | We would like to express our gratitude to the following projects and organizations for their contributions: 97 | 98 | + [semantic scholar](https://www.semanticscholar.org/) 99 | + [ogb](https://github.com/snap-stanford/ogb) 100 | 101 | 102 | 103 | ## Star and Cite ⭐ 104 | 105 | Please star our repo 🌟 and cite our [paper](https://arxiv.org/abs/2406.10310) if you find it useful. Feel free to [email us](mailto:zhuofengli12345@gmail.com) (zhuofengli12345@gmail.com) if you have any questions. 106 | 107 | ``` 108 | @misc{li2024tegdb, 109 | title={TEG-DB: A Comprehensive Dataset and Benchmark of Textual-Edge Graphs}, 110 | author={Zhuofeng Li and Zixing Gou and Xiangnan Zhang and Zhongyuan Liu and Sirui Li and Yuntong Hu and Chen Ling and Zheng Zhang and Liang Zhao}, 111 | year={2024}, 112 | eprint={2406.10310}, 113 | archivePrefix={arXiv}, 114 | primaryClass={id='cs.CL' full_name='Computation and Language' is_active=True alt_name='cmp-lg' in_archive='cs' is_general=False description='Covers natural language processing. Roughly includes material in ACM Subject Class I.2.7. Note that work on artificial languages (programming languages, logics, formal systems) that does not explicitly address natural-language issues broadly construed (natural-language processing, computational linguistics, speech, text retrieval, etc.)
is not appropriate for this area.'} 115 | } 116 | ``` 117 | -------------------------------------------------------------------------------- /GNN/GNN_Node.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pickle 3 | from torch_sparse import SparseTensor 4 | import tqdm 5 | from sklearn.metrics import roc_auc_score 6 | from torch_geometric import seed_everything 7 | from torch_geometric.loader import NeighborLoader 8 | from model.GNN_library import GAT, GINE, GeneralGNN, GraphTransformer 9 | from torch.nn import Linear 10 | from sklearn.metrics import f1_score 11 | from sklearn.metrics import accuracy_score 12 | import argparse 13 | from torch_geometric.data import Data 14 | from sklearn.preprocessing import MultiLabelBinarizer 15 | 16 | 17 | class Classifier(torch.nn.Module): 18 | def __init__(self, hidden_channels, out_channels): 19 | super().__init__() 20 | self.lin1 = Linear(hidden_channels, hidden_channels) 21 | self.lin2 = Linear(hidden_channels, out_channels) 22 | 23 | def forward(self, x): 24 | x = self.lin1(x).relu() 25 | x = self.lin2(x) 26 | return x 27 | 28 | 29 | def gen_model(args, x, edge_feature): 30 | if args.gnn_model == "GAT": 31 | model = GAT( 32 | x.size(1), 33 | edge_feature.size(1), 34 | args.hidden_channels, 35 | args.hidden_channels, 36 | args.num_layers, 37 | args.heads, 38 | args.dropout, 39 | ) 40 | elif args.gnn_model == "GraphTransformer": 41 | model = GraphTransformer( 42 | x.size(1), 43 | edge_feature.size(1), 44 | args.hidden_channels, 45 | args.hidden_channels, 46 | args.num_layers, 47 | args.dropout, 48 | ) 49 | elif args.gnn_model == "GINE": 50 | model = GINE( 51 | x.size(1), 52 | edge_feature.size(1), 53 | args.hidden_channels, 54 | args.hidden_channels, 55 | args.num_layers, 56 | args.dropout, 57 | ) 58 | elif args.gnn_model == "GeneralGNN": 59 | model = GeneralGNN( 60 | x.size(1), 61 | edge_feature.size(1), 62 | args.hidden_channels, 63 | args.hidden_channels, 64 | args.num_layers, 65 | args.dropout, 66 | ) 67 | else: 68 | raise ValueError("Not implemented") 69 | return model 70 | 71 | 72 | def train(model, predictor, train_loader, optimizer, criterion): 73 | model.train() 74 | predictor.train() 75 | 76 | total_loss = total_examples = 0 77 | for batch in tqdm.tqdm(train_loader): 78 | optimizer.zero_grad() 79 | batch = batch.to(device) 80 | h = model(batch.x, batch.adj_t)[: batch.batch_size] 81 | pred = predictor(h) 82 | loss = criterion(pred, batch.y[: batch.batch_size].float()) 83 | loss.backward() 84 | optimizer.step() 85 | total_loss += float(loss) * pred.numel() 86 | total_examples += pred.numel() 87 | return total_examples, total_loss 88 | 89 | 90 | def process_data(args, device, data): # TODO: process all data here 91 | num_nodes = len(data.text_nodes) 92 | data.num_nodes = num_nodes 93 | 94 | product_indices = torch.tensor( 95 | [ 96 | i for i, label in enumerate(data.text_node_labels) if label != -1 97 | ] # TODO: check the label in each dataset 98 | ).long() 99 | product_labels = [label for label in data.text_node_labels if label != -1] 100 | mlb = MultiLabelBinarizer() 101 | product_binary_labels = mlb.fit_transform(product_labels) 102 | y = torch.zeros(num_nodes, product_binary_labels.shape[1]).float() 103 | y[product_indices] = torch.tensor(product_binary_labels).float() 104 | y = y.to(device) # TODO: check the label 105 | 106 | train_ratio = 1 - args.test_ratio - args.val_ratio 107 | val_ratio = args.val_ratio 108 | 109 | num_products = product_indices.shape[0] 110 | train_idx 
= product_indices[: int(num_products * train_ratio)] 111 | val_idx = product_indices[ 112 | int(num_products * train_ratio) : int(num_products * (train_ratio + val_ratio)) 113 | ] 114 | test_idx = product_indices[ 115 | int(product_indices.shape[0] * (train_ratio + val_ratio)) : 116 | ] 117 | 118 | x = torch.load(args.use_PLM_node).squeeze().float() 119 | edge_feature = torch.load(args.use_PLM_edge).squeeze().float() 120 | 121 | edge_index = data.edge_index 122 | adj_t = SparseTensor.from_edge_index( 123 | edge_index, edge_feature, sparse_sizes=(data.num_nodes, data.num_nodes) 124 | ).t() 125 | adj_t = adj_t.to_symmetric() 126 | node_split = {"train": train_idx, "val": val_idx, "test": test_idx} 127 | return node_split, x, edge_feature, adj_t, y 128 | 129 | 130 | def test(model, predictor, test_loader, args): 131 | print("Validation begins") 132 | with torch.no_grad(): 133 | preds = [] 134 | ground_truths = [] 135 | for batch in tqdm.tqdm(test_loader): 136 | batch = batch.to(device) 137 | h = model(batch.x, batch.adj_t)[: batch.batch_size] 138 | pred = predictor(h) 139 | ground_truth = batch.y[: batch.batch_size] 140 | preds.append(pred) 141 | ground_truths.append(ground_truth) 142 | 143 | preds = torch.cat(preds, dim=0).cpu().numpy() 144 | ground_truths = torch.cat(ground_truths, dim=0).cpu().numpy() 145 | 146 | y_label = (preds > args.threshold).astype(int) 147 | f1 = f1_score(ground_truths, y_label, average="weighted") 148 | print(f"F1 score: {f1:.4f}") 149 | 150 | data_type = args.graph_path.split("/")[-1] 151 | if data_type not in [ 152 | "twitter.pkl", 153 | "reddit.pkl", 154 | "citation.pkl", 155 | ]: # TODO: check metric in 156 | auc = roc_auc_score(ground_truths, preds, average="micro") 157 | print(f"Validation AUC: {auc:.4f}") 158 | else: 159 | accuracy = accuracy_score(ground_truths, y_label) 160 | print(f"Validation Accuracy: {accuracy:.4f}") 161 | 162 | 163 | if __name__ == "__main__": 164 | # TODO: use wandb as log 165 | seed_everything(66) 166 | 167 | parser = argparse.ArgumentParser() 168 | parser.add_argument("--device", type=int, default=0) 169 | parser.add_argument( 170 | "--gnn_model", 171 | "-gm", 172 | type=str, 173 | default="GAT", 174 | help="Model type for HeteroGNN, options are GraphTransformer, GINE, GraphSAGE, GeneralConv, MLP, EdgeConv,RevGAT", 175 | ) 176 | parser.add_argument( 177 | "--use_PLM_node", 178 | type=str, 179 | default="data/CSTAG/Photo/Feature/children_gpt_node.pt", 180 | help="Use LM embedding as node feature", 181 | ) 182 | parser.add_argument( 183 | "--use_PLM_edge", 184 | type=str, 185 | default="data/CSTAG/Photo/Feature/children_gpt_edge.pt", 186 | help="Use LM embedding as edge feature", 187 | ) 188 | parser.add_argument( 189 | "--graph_path", 190 | type=str, 191 | default="../../TEG-Datasets/goodreads_children/processed/children.pkl", # "data/CSTAG/Photo/children.pkl", 192 | help="Path to load the graph", 193 | ) 194 | parser.add_argument("--eval_steps", type=int, default=1) 195 | parser.add_argument("--hidden_channels", type=int, default=256) 196 | parser.add_argument("--heads", type=int, default=4) 197 | parser.add_argument("--num_layers", type=int, default=2) 198 | parser.add_argument("--dropout", type=float, default=0.0) 199 | parser.add_argument("--epochs", type=int, default=100) 200 | parser.add_argument("--threshold", type=float, default=0.5) 201 | parser.add_argument("--test_ratio", type=float, default=0.1) 202 | parser.add_argument("--val_ratio", type=float, default=0.1) 203 | parser.add_argument("--batch_size", type=int, 
default=1024) 204 | args = parser.parse_args() 205 | 206 | device = f"cuda:{args.device}" if torch.cuda.is_available() else "cpu" 207 | device = torch.device(device) 208 | 209 | with open(f"{args.graph_path}", "rb") as f: 210 | data = pickle.load(f) 211 | 212 | node_split, x, edge_feature, adj_t, y = process_data(args, device, data) 213 | data = Data(x=x, adj_t=adj_t, y=y) 214 | 215 | train_loader = NeighborLoader( 216 | data, 217 | input_nodes=node_split["train"], 218 | num_neighbors=[10, 10], 219 | batch_size=args.batch_size, 220 | shuffle=True, 221 | ) 222 | val_loader = NeighborLoader( 223 | data, 224 | input_nodes=node_split["val"], 225 | num_neighbors=[10, 10], 226 | batch_size=args.batch_size, 227 | shuffle=False, 228 | ) 229 | test_loader = NeighborLoader( 230 | data, 231 | input_nodes=node_split["test"], 232 | num_neighbors=[10, 10], 233 | batch_size=args.batch_size, 234 | shuffle=False, 235 | ) # TODO: move `num_neighbors` to args 236 | 237 | model = gen_model(args, x, edge_feature) 238 | model = model.to(device) 239 | predictor = Classifier(args.hidden_channels, y.shape[1]).to(device) 240 | optimizer = torch.optim.Adam(model.parameters(), lr=0.001) 241 | criterion = torch.nn.CrossEntropyLoss() 242 | 243 | for epoch in range(1, 1 + args.epochs): 244 | total_examples, total_loss = train( 245 | model, predictor, train_loader, optimizer, criterion 246 | ) 247 | print(f"Epoch: {epoch:03d}, Loss: {total_loss / total_examples:.4f}") 248 | 249 | # validation 250 | if epoch % args.eval_steps == 0: 251 | test(model, predictor, test_loader, args) 252 | -------------------------------------------------------------------------------- /GNN/GNN_Link_MRR.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch.utils.data import DataLoader 7 | 8 | from torch_sparse import SparseTensor 9 | 10 | from model.GNN_library import GAT, GINE, GeneralGNN, GraphTransformer 11 | from model.Dataloader import Evaluator, split_edge_mrr 12 | from model.GNN_arg import Logger 13 | import wandb 14 | import os 15 | 16 | 17 | class LinkPredictor(torch.nn.Module): 18 | def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout): 19 | super(LinkPredictor, self).__init__() 20 | 21 | self.lins = torch.nn.ModuleList() 22 | self.lins.append(torch.nn.Linear(in_channels, hidden_channels)) 23 | for _ in range(num_layers - 2): 24 | self.lins.append(torch.nn.Linear(hidden_channels, hidden_channels)) 25 | self.lins.append(torch.nn.Linear(hidden_channels, out_channels)) 26 | 27 | self.dropout = dropout 28 | 29 | def reset_parameters(self): 30 | for lin in self.lins: 31 | lin.reset_parameters() 32 | 33 | def forward(self, x_i, x_j): 34 | x = x_i * x_j 35 | for lin in self.lins[:-1]: 36 | x = lin(x) 37 | x = F.relu(x) 38 | x = F.dropout(x, p=self.dropout, training=self.training) 39 | x = self.lins[-1](x) 40 | return torch.sigmoid(x) 41 | 42 | 43 | def train(model, predictor, x, adj_t, edge_split, optimizer, batch_size): 44 | model.train() 45 | predictor.train() 46 | 47 | source_edge = edge_split["train"]["source_node"].to(x.device) 48 | target_edge = edge_split["train"]["target_node"].to(x.device) 49 | 50 | total_loss = total_examples = 0 51 | for perm in DataLoader(range(source_edge.size(0)), batch_size, shuffle=True): 52 | optimizer.zero_grad() 53 | 54 | h = model(x, adj_t) 55 | src, dst = source_edge[perm], target_edge[perm] 56 | pos_out = predictor(h[src], h[dst]) 57 
| pos_loss = -torch.log(pos_out + 1e-15).mean() 58 | 59 | # Just do some trivial random sampling. 60 | dst_neg = torch.randint( 61 | 0, x.size(0), src.size(), dtype=torch.long, device=h.device 62 | ) 63 | neg_out = predictor(h[src], h[dst_neg]) 64 | neg_loss = -torch.log(1 - neg_out + 1e-15).mean() 65 | 66 | loss = pos_loss + neg_loss 67 | loss.backward() 68 | optimizer.step() 69 | 70 | num_examples = pos_out.size(0) 71 | total_loss += loss.item() * num_examples 72 | total_examples += num_examples 73 | 74 | return total_loss / total_examples 75 | 76 | 77 | @torch.no_grad() 78 | def test(model, predictor, x, adj_t, edge_split, evaluator, batch_size, neg_len): 79 | model.eval() 80 | predictor.eval() 81 | 82 | h = model(x, adj_t) 83 | 84 | def test_split(split, neg_len): 85 | source = edge_split[split]["source_node"].to(h.device) 86 | target = edge_split[split]["target_node"].to(h.device) 87 | target_neg = edge_split[split]["target_node_neg"].to(h.device) 88 | 89 | pos_preds = [] 90 | for perm in DataLoader(range(source.size(0)), batch_size): 91 | src, dst = source[perm], target[perm] 92 | pos_preds += [predictor(h[src], h[dst]).squeeze().cpu()] 93 | pos_pred = torch.cat(pos_preds, dim=0) 94 | 95 | neg_preds = [] 96 | source = source.view(-1, 1).repeat(1, neg_len).view(-1) 97 | target_neg = target_neg.view(-1) 98 | for perm in DataLoader(range(source.size(0)), batch_size): 99 | src, dst_neg = source[perm], target_neg[perm] 100 | neg_preds += [predictor(h[src], h[dst_neg]).squeeze().cpu()] 101 | neg_pred = torch.cat(neg_preds, dim=0).view(-1, neg_len) 102 | 103 | return ( 104 | evaluator.eval( 105 | { 106 | "y_pred_pos": pos_pred, 107 | "y_pred_neg": neg_pred, 108 | } 109 | )["mrr_list"] 110 | .mean() 111 | .item() 112 | ) 113 | 114 | train_mrr = test_split("eval_train", neg_len) 115 | valid_mrr = test_split("valid", neg_len) 116 | test_mrr = test_split("test", neg_len) 117 | 118 | return train_mrr, valid_mrr, test_mrr 119 | 120 | 121 | def main(): 122 | parser = argparse.ArgumentParser(description="Link-Prediction PLM/TCL") 123 | parser.add_argument("--device", type=int, default=0) 124 | parser.add_argument("--log_steps", type=int, default=5) 125 | parser.add_argument("--use_node_embedding", action="store_true") 126 | parser.add_argument("--num_layers", type=int, default=2) 127 | parser.add_argument("--hidden_channels", type=int, default=256) 128 | parser.add_argument("--dropout", type=float, default=0.0) 129 | parser.add_argument("--batch_size", type=int, default=64 * 1024) 130 | parser.add_argument("--lr", type=float, default=0.0005) 131 | parser.add_argument("--epochs", type=int, default=100) 132 | parser.add_argument("--gnn_model", type=str, help="GNN Model", default="GeneralGNN") 133 | parser.add_argument("--heads", type=int, default=4) 134 | parser.add_argument("--eval_steps", type=int, default=1) 135 | parser.add_argument("--runs", type=int, default=5) 136 | parser.add_argument("--test_ratio", type=float, default=0.08) 137 | parser.add_argument("--val_ratio", type=float, default=0.02) 138 | parser.add_argument("--neg_len", type=int, default=50) 139 | parser.add_argument( 140 | "--use_PLM_node", 141 | "-pn", 142 | type=str, 143 | default="data/CSTAG/Photo/Feature/children_gpt_node.pt", 144 | help="Use LM embedding as node feature", 145 | ) 146 | parser.add_argument( 147 | "--use_PLM_edge", 148 | "-pe", 149 | type=str, 150 | default="data/CSTAG/Photo/Feature/children_gpt_edge.pt", 151 | help="Use LM embedding as edge feature", 152 | ) 153 | parser.add_argument( 154 | "--path", 155 | 
type=str, 156 | default="data/CSTAG/Photo/LinkPrediction/", 157 | help="Path to save splitting", 158 | ) 159 | parser.add_argument( 160 | "--graph_path", 161 | "-gp", 162 | type=str, 163 | default="data/CSTAG/Photo/children.pkl", 164 | help="Path to load the graph", 165 | ) 166 | 167 | args = parser.parse_args() 168 | 169 | wandb.config = args 170 | wandb.init(config=args, reinit=True) 171 | print(args) 172 | 173 | if not os.path.exists(f"{args.path}{args.neg_len}/"): 174 | os.makedirs(f"{args.path}{args.neg_len}/") 175 | 176 | device = f"cuda:{args.device}" if torch.cuda.is_available() else "cpu" 177 | device = torch.device(device) 178 | 179 | with open(f"{args.graph_path}", "rb") as file: 180 | graph = pickle.load(file) 181 | 182 | graph.num_nodes = len(graph.text_nodes) 183 | 184 | edge_split = split_edge_mrr( 185 | graph, 186 | test_ratio=args.test_ratio, 187 | val_ratio=args.val_ratio, 188 | path=args.path, 189 | neg_len=args.neg_len, 190 | ) 191 | 192 | x = torch.load(args.use_PLM_node).float().to(device) 193 | edge_feature = torch.load(args.use_PLM_edge)[ 194 | edge_split["train"]["train_edge_feature_index"] 195 | ].float() 196 | 197 | x = x.to(device) 198 | 199 | torch.manual_seed(42) 200 | idx = torch.randperm(edge_split["train"]["source_node"].numel())[ 201 | : len(edge_split["valid"]["source_node"]) 202 | ] 203 | edge_split["eval_train"] = { 204 | "source_node": edge_split["train"]["source_node"][idx], 205 | "target_node": edge_split["train"]["target_node"][idx], 206 | "target_node_neg": edge_split["valid"]["target_node_neg"], 207 | } 208 | 209 | 210 | edge_index = edge_split["train"]["edge"].t() 211 | adj_t = SparseTensor.from_edge_index(edge_index, edge_feature, sparse_sizes=(graph.num_nodes, graph.num_nodes)).t() 212 | adj_t = adj_t.to_symmetric().to(device) 213 | 214 | if args.gnn_model == "GAT": # TODO: use `gen_model` func to simplify the code 215 | model = GAT( 216 | x.size(1), 217 | edge_feature.size(1), 218 | args.hidden_channels, 219 | args.hidden_channels, 220 | args.num_layers, 221 | args.heads, 222 | args.dropout, 223 | ).to(device) 224 | elif args.gnn_model == "GraphTransformer": 225 | model = GraphTransformer( 226 | x.size(1), 227 | edge_feature.size(1), 228 | args.hidden_channels, 229 | args.hidden_channels, 230 | args.num_layers, 231 | args.dropout, 232 | ).to(device) 233 | elif args.gnn_model == "GINE": 234 | model = GINE( 235 | x.size(1), 236 | edge_feature.size(1), 237 | args.hidden_channels, 238 | args.hidden_channels, 239 | args.num_layers, 240 | args.dropout, 241 | ).to(device) 242 | elif args.gnn_model == "GeneralGNN": 243 | model = GeneralGNN( 244 | x.size(1), 245 | edge_feature.size(1), 246 | args.hidden_channels, 247 | args.hidden_channels, 248 | args.num_layers, 249 | args.dropout, 250 | ).to(device) 251 | else: 252 | raise ValueError("Not implemented") 253 | 254 | predictor = LinkPredictor( 255 | args.hidden_channels, args.hidden_channels, 1, args.num_layers, args.dropout 256 | ).to(device) 257 | 258 | evaluator = Evaluator(name="DBLP") 259 | logger = Logger(args.runs, args) 260 | 261 | for run in range(args.runs): 262 | model.reset_parameters() 263 | predictor.reset_parameters() 264 | optimizer = torch.optim.Adam( 265 | list(model.parameters()) + list(predictor.parameters()), lr=args.lr 266 | ) 267 | 268 | for epoch in range(1, 1 + args.epochs): 269 | loss = train( 270 | model, predictor, x, adj_t, edge_split, optimizer, args.batch_size 271 | ) 272 | wandb.log({"Loss": loss}) 273 | if epoch % args.eval_steps == 0: 274 | result = test( 275 | model, 
276 | predictor, 277 | x, 278 | adj_t, 279 | edge_split, 280 | evaluator, 281 | args.batch_size, 282 | args.neg_len, 283 | ) 284 | logger.add_result(run, result) 285 | 286 | if epoch % args.log_steps == 0: 287 | train_mrr, valid_mrr, test_mrr = result 288 | print( 289 | f"Run: {run + 1:02d}, " 290 | f"Epoch: {epoch:02d}, " 291 | f"Loss: {loss:.4f}, " 292 | f"Train: {train_mrr:.4f}, " 293 | f"Valid: {valid_mrr:.4f}, " 294 | f"Test: {test_mrr:.4f}" 295 | ) 296 | 297 | logger.print_statistics(run) 298 | 299 | logger.print_statistics(key="mrr") 300 | 301 | 302 | if __name__ == "__main__": 303 | main() 304 | -------------------------------------------------------------------------------- /GNN/GNN_Link_HitsK.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch.utils.data import DataLoader 7 | 8 | from torch_sparse import SparseTensor 9 | 10 | from model.GNN_library import GAT, GINE, GeneralGNN, GraphTransformer 11 | from model.Dataloader import Evaluator, split_edge 12 | from model.GNN_arg import Logger 13 | import wandb 14 | import os 15 | 16 | 17 | class LinkPredictor(torch.nn.Module): 18 | def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout): 19 | super(LinkPredictor, self).__init__() 20 | 21 | self.lins = torch.nn.ModuleList() 22 | self.lins.append(torch.nn.Linear(in_channels, hidden_channels)) 23 | for _ in range(num_layers - 2): 24 | self.lins.append(torch.nn.Linear(hidden_channels, hidden_channels)) 25 | self.lins.append(torch.nn.Linear(hidden_channels, out_channels)) 26 | 27 | self.dropout = dropout 28 | 29 | def reset_parameters(self): 30 | for lin in self.lins: 31 | lin.reset_parameters() 32 | 33 | def forward(self, x_i, x_j): 34 | x = x_i * x_j 35 | for lin in self.lins[:-1]: 36 | x = lin(x) 37 | x = F.relu(x) 38 | x = F.dropout(x, p=self.dropout, training=self.training) 39 | x = self.lins[-1](x) 40 | return torch.sigmoid(x) 41 | 42 | 43 | def train(model, predictor, x, adj_t, edge_split, optimizer, batch_size): 44 | model.train() 45 | predictor.train() 46 | 47 | pos_train_edge = edge_split["train"]["edge"].to(x.device) 48 | 49 | total_loss = total_examples = 0 50 | for perm in DataLoader(range(pos_train_edge.size(0)), batch_size, shuffle=True): 51 | optimizer.zero_grad() 52 | 53 | h = model(x, adj_t) 54 | 55 | edge = pos_train_edge[perm].t() 56 | 57 | pos_out = predictor(h[edge[0]], h[edge[1]]) 58 | pos_loss = -torch.log(pos_out + 1e-15).mean() 59 | 60 | # Just do some trivial random sampling. 
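# Note (added comment): here the whole negative edge (both endpoints) is drawn uniformly at random over all nodes; the samples are not filtered against existing edges, so a small fraction of them may be false negatives.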
61 | edge = torch.randint( 62 | 0, x.size(0), edge.size(), dtype=torch.long, device=h.device 63 | ) 64 | neg_out = predictor(h[edge[0]], h[edge[1]]) 65 | neg_loss = -torch.log(1 - neg_out + 1e-15).mean() 66 | 67 | loss = pos_loss + neg_loss 68 | loss.backward() 69 | 70 | optimizer.step() 71 | 72 | num_examples = pos_out.size(0) 73 | total_loss += loss.item() * num_examples 74 | total_examples += num_examples 75 | 76 | return total_loss / total_examples 77 | 78 | 79 | @torch.no_grad() 80 | def test(model, predictor, x, adj_t, edge_split, evaluator, batch_size): 81 | model.eval() 82 | predictor.eval() 83 | 84 | h = model(x, adj_t) 85 | 86 | pos_train_edge = edge_split["train"]["edge"].to(h.device) 87 | pos_valid_edge = edge_split["valid"]["edge"].to(h.device) 88 | neg_valid_edge = edge_split["valid"]["edge_neg"].to(h.device) 89 | pos_test_edge = edge_split["test"]["edge"].to(h.device) 90 | neg_test_edge = edge_split["test"]["edge_neg"].to(h.device) 91 | 92 | pos_train_preds = [] 93 | for perm in DataLoader(range(pos_train_edge.size(0)), batch_size): 94 | edge = pos_train_edge[perm].t() 95 | pos_train_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()] 96 | pos_train_pred = torch.cat(pos_train_preds, dim=0) 97 | 98 | pos_valid_preds = [] 99 | for perm in DataLoader(range(pos_valid_edge.size(0)), batch_size): 100 | edge = pos_valid_edge[perm].t() 101 | pos_valid_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()] 102 | pos_valid_pred = torch.cat(pos_valid_preds, dim=0) 103 | 104 | neg_valid_preds = [] 105 | for perm in DataLoader(range(neg_valid_edge.size(0)), batch_size): 106 | edge = neg_valid_edge[perm].t() 107 | neg_valid_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()] 108 | neg_valid_pred = torch.cat(neg_valid_preds, dim=0) 109 | 110 | h = model(x, adj_t) 111 | 112 | pos_test_preds = [] 113 | for perm in DataLoader(range(pos_test_edge.size(0)), batch_size): 114 | edge = pos_test_edge[perm].t() 115 | pos_test_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()] 116 | pos_test_pred = torch.cat(pos_test_preds, dim=0) 117 | 118 | neg_test_preds = [] 119 | for perm in DataLoader(range(neg_test_edge.size(0)), batch_size): 120 | edge = neg_test_edge[perm].t() 121 | neg_test_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()] 122 | neg_test_pred = torch.cat(neg_test_preds, dim=0) 123 | 124 | results = {} 125 | for K in [10, 50, 100]: 126 | evaluator.K = K 127 | train_hits = evaluator.eval( 128 | { 129 | "y_pred_pos": pos_train_pred, 130 | "y_pred_neg": neg_valid_pred, 131 | } 132 | )[f"hits@{K}"] 133 | valid_hits = evaluator.eval( 134 | { 135 | "y_pred_pos": pos_valid_pred, 136 | "y_pred_neg": neg_valid_pred, 137 | } 138 | )[f"hits@{K}"] 139 | test_hits = evaluator.eval( 140 | { 141 | "y_pred_pos": pos_test_pred, 142 | "y_pred_neg": neg_test_pred, 143 | } 144 | )[f"hits@{K}"] 145 | 146 | results[f"Hits@{K}"] = (train_hits, valid_hits, test_hits) 147 | 148 | return results 149 | 150 | 151 | def main(): 152 | parser = argparse.ArgumentParser(description="Link-Prediction PLM/TCL") 153 | parser.add_argument("--device", type=int, default=0) 154 | parser.add_argument("--log_steps", type=int, default=1) 155 | parser.add_argument("--use_node_embedding", action="store_true") 156 | parser.add_argument("--num_layers", type=int, default=3) 157 | parser.add_argument("--hidden_channels", type=int, default=256) 158 | parser.add_argument("--dropout", type=float, default=0.0) 159 | parser.add_argument("--batch_size", type=int, default=64 * 1024) 160 | 
parser.add_argument("--lr", type=float, default=0.001) 161 | parser.add_argument("--epochs", type=int, default=10) 162 | parser.add_argument( 163 | "--gnn_model", type=str, help="GNN Model", default="GraphTransformer" 164 | ) 165 | parser.add_argument("--heads", type=int, default=4) 166 | parser.add_argument("--eval_steps", type=int, default=1) 167 | parser.add_argument("--runs", type=int, default=5) 168 | parser.add_argument("--test_ratio", type=float, default=0.08) 169 | parser.add_argument("--val_ratio", type=float, default=0.02) 170 | parser.add_argument("--neg_len", type=str, default="10000") 171 | parser.add_argument( 172 | "--use_PLM_node", 173 | type=str, 174 | default="data/CSTAG/Photo/Feature/children_gpt_node.pt", 175 | help="Use LM embedding as node feature", 176 | ) 177 | parser.add_argument( 178 | "--use_PLM_edge", 179 | type=str, 180 | default="data/CSTAG/Photo/Feature/children_gpt_edge.pt", 181 | help="Use LM embedding as edge feature", 182 | ) 183 | parser.add_argument( 184 | "--path", 185 | type=str, 186 | default="data/CSTAG/Photo/LinkPrediction/", 187 | help="Path to save splitting", 188 | ) 189 | parser.add_argument( 190 | "--graph_path", 191 | type=str, 192 | default="data/CSTAG/Photo/children.pkl", 193 | help="Path to load the graph", 194 | ) 195 | args = parser.parse_args() 196 | wandb.config = args 197 | wandb.init(config=args, reinit=True) 198 | print(args) 199 | 200 | if not os.path.exists(f"{args.path}{args.neg_len}/"): 201 | os.makedirs(f"{args.path}{args.neg_len}/") 202 | 203 | device = f"cuda:{args.device}" if torch.cuda.is_available() else "cpu" 204 | device = torch.device(device) 205 | 206 | with open(f"{args.graph_path}", "rb") as file: 207 | graph = pickle.load(file) 208 | 209 | graph.num_nodes = len(graph.text_nodes) 210 | edge_split = split_edge( 211 | graph, test_ratio=args.test_ratio, val_ratio=args.val_ratio, path=args.path, neg_len=args.neg_len 212 | ) 213 | 214 | x = torch.load(args.use_PLM_node).float().to(device) 215 | edge_feature = torch.load(args.use_PLM_edge)[ 216 | edge_split["train"]["train_edge_feature_index"] 217 | ].float() 218 | 219 | edge_index = edge_split["train"]["edge"].t() 220 | adj_t = SparseTensor.from_edge_index(edge_index, edge_feature, sparse_sizes=(graph.num_nodes, graph.num_nodes)).t() 221 | adj_t = adj_t.to_symmetric().to(device) 222 | if args.gnn_model == "GAT": 223 | model = GAT( 224 | x.size(1), 225 | edge_feature.size(1), 226 | args.hidden_channels, 227 | args.hidden_channels, 228 | args.num_layers, 229 | args.heads, 230 | args.dropout, 231 | ).to(device) 232 | elif args.gnn_model == "GraphTransformer": 233 | model = GraphTransformer( 234 | x.size(1), 235 | edge_feature.size(1), 236 | args.hidden_channels, 237 | args.hidden_channels, 238 | args.num_layers, 239 | args.dropout, 240 | ).to(device) 241 | elif args.gnn_model == "GINE": 242 | model = GINE( 243 | x.size(1), 244 | edge_feature.size(1), 245 | args.hidden_channels, 246 | args.hidden_channels, 247 | args.num_layers, 248 | args.dropout, 249 | ).to(device) 250 | elif args.gnn_model == "GeneralGNN": 251 | model = GeneralGNN( 252 | x.size(1), 253 | edge_feature.size(1), 254 | args.hidden_channels, 255 | args.hidden_channels, 256 | args.num_layers, 257 | args.dropout, 258 | ).to(device) 259 | else: 260 | raise ValueError("Not implemented") 261 | 262 | predictor = LinkPredictor( 263 | args.hidden_channels, args.hidden_channels, 1, args.num_layers, args.dropout 264 | ).to(device) 265 | 266 | evaluator = Evaluator(name="History") 267 | loggers = { 268 | "Hits@10": 
Logger(args.runs, args), 269 | "Hits@50": Logger(args.runs, args), 270 | "Hits@100": Logger(args.runs, args), 271 | } 272 | 273 | for run in range(args.runs): 274 | model.reset_parameters() 275 | predictor.reset_parameters() 276 | optimizer = torch.optim.Adam( 277 | list(model.parameters()) + list(predictor.parameters()), lr=args.lr 278 | ) 279 | 280 | for epoch in range(1, 1 + args.epochs): 281 | loss = train( 282 | model, predictor, x, adj_t, edge_split, optimizer, args.batch_size 283 | ) 284 | 285 | if epoch % args.eval_steps == 0: 286 | results = test( 287 | model, predictor, x, adj_t, edge_split, evaluator, args.batch_size 288 | ) 289 | for key, result in results.items(): 290 | loggers[key].add_result(run, result) 291 | 292 | if epoch % args.log_steps == 0: 293 | for key, result in results.items(): 294 | train_hits, valid_hits, test_hits = result 295 | print(key) 296 | print( 297 | f"Run: {run + 1:02d}, " 298 | f"Epoch: {epoch:02d}, " 299 | f"Loss: {loss:.4f}, " 300 | f"Train: {100 * train_hits:.2f}%, " 301 | f"Valid: {100 * valid_hits:.2f}%, " 302 | f"Test: {100 * test_hits:.2f}%" 303 | ) 304 | print("---") 305 | 306 | for key in loggers.keys(): 307 | print(key) 308 | loggers[key].print_statistics(run) 309 | 310 | for key in loggers.keys(): 311 | print(key) 312 | loggers[key].print_statistics(key=key) 313 | 314 | 315 | if __name__ == "__main__": 316 | main() 317 | -------------------------------------------------------------------------------- /GNN/GNN_Link_MRR_loader.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.data import Data 2 | from torch_geometric.loader import LinkNeighborLoader 3 | import argparse 4 | import pickle 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | 9 | from torch_sparse import SparseTensor 10 | 11 | from model.GNN_library import GAT, GINE, GeneralGNN, GraphTransformer 12 | from model.Dataloader import Evaluator, split_edge_mrr 13 | from model.GNN_arg import Logger 14 | import wandb 15 | import os 16 | 17 | 18 | class LinkPredictor(torch.nn.Module): 19 | def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout): 20 | super(LinkPredictor, self).__init__() 21 | 22 | self.lins = torch.nn.ModuleList() 23 | self.lins.append(torch.nn.Linear(in_channels, hidden_channels)) 24 | for _ in range(num_layers - 2): 25 | self.lins.append(torch.nn.Linear(hidden_channels, hidden_channels)) 26 | self.lins.append(torch.nn.Linear(hidden_channels, out_channels)) 27 | 28 | self.dropout = dropout 29 | 30 | def reset_parameters(self): 31 | for lin in self.lins: 32 | lin.reset_parameters() 33 | 34 | def forward(self, x_i, x_j): 35 | x = x_i * x_j 36 | for lin in self.lins[:-1]: 37 | x = lin(x) 38 | x = F.relu(x) 39 | x = F.dropout(x, p=self.dropout, training=self.training) 40 | x = self.lins[-1](x) 41 | return torch.sigmoid(x) 42 | 43 | 44 | def gen_loader(args, edge_split, x, edge_index, adj_t): 45 | train_data = Data(x=x, adj_t=adj_t) 46 | train_loader = LinkNeighborLoader( 47 | data=train_data, 48 | num_neighbors=[20, 10], 49 | edge_label_index=(edge_index), 50 | edge_label=torch.ones(edge_index.shape[1]), 51 | batch_size=args.batch_size, 52 | neg_sampling_ratio=0.0, 53 | shuffle=True, 54 | ) 55 | 56 | val_edge_label_index = edge_split["valid"]["edge"].t() 57 | val_loader = LinkNeighborLoader( 58 | data=train_data, 59 | num_neighbors=[20, 10], 60 | edge_label_index=(val_edge_label_index), 61 | edge_label=torch.ones(val_edge_label_index.shape[1]), 62 | batch_size=args.batch_size, 63 
| neg_sampling_ratio=0.0, 64 | shuffle=False, 65 | ) 66 | 67 | test_edge_label_index = edge_split["test"]["edge"].t() 68 | test_loader = LinkNeighborLoader( 69 | data=train_data, 70 | num_neighbors=[20, 10], 71 | edge_label_index=(test_edge_label_index), 72 | edge_label=torch.ones(test_edge_label_index.shape[1]), 73 | batch_size=args.batch_size, 74 | neg_sampling_ratio=0.0, 75 | shuffle=False, 76 | ) 77 | 78 | return train_loader, val_loader, test_loader 79 | 80 | 81 | def gen_model(args, x, edge_feature): 82 | if args.gnn_model == "GAT": 83 | model = GAT( 84 | x.size(1), 85 | edge_feature.size(1), 86 | args.hidden_channels, 87 | args.hidden_channels, 88 | args.num_layers, 89 | args.heads, 90 | args.dropout, 91 | ) 92 | elif args.gnn_model == "GraphTransformer": 93 | model = GraphTransformer( 94 | x.size(1), 95 | edge_feature.size(1), 96 | args.hidden_channels, 97 | args.hidden_channels, 98 | args.num_layers, 99 | args.dropout, 100 | ) 101 | elif args.gnn_model == "GINE": 102 | model = GINE( 103 | x.size(1), 104 | edge_feature.size(1), 105 | args.hidden_channels, 106 | args.hidden_channels, 107 | args.num_layers, 108 | args.dropout, 109 | ) 110 | elif args.gnn_model == "GeneralGNN": 111 | model = GeneralGNN( 112 | x.size(1), 113 | edge_feature.size(1), 114 | args.hidden_channels, 115 | args.hidden_channels, 116 | args.num_layers, 117 | args.dropout, 118 | ) 119 | else: 120 | raise ValueError("Not implemented") 121 | return model 122 | 123 | 124 | def train(model, predictor, train_loader, optimizer, device): 125 | model.train() 126 | predictor.train() 127 | 128 | total_loss = total_examples = 0 129 | for batch in train_loader: 130 | optimizer.zero_grad() 131 | batch = batch.to(device) 132 | h = model(batch.x, batch.adj_t) 133 | 134 | src = batch.edge_label_index.t()[:, 0] 135 | dst = batch.edge_label_index.t()[:, 1] 136 | 137 | pos_out = predictor(h[src], h[dst]) 138 | pos_loss = -torch.log(pos_out + 1e-15).mean() 139 | 140 | # Just do some trivial random sampling. 
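# Note (added comment): negatives keep the source node of each positive edge and only resample the destination uniformly at random; they are not checked against existing edges.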
141 | dst_neg = torch.randint( 142 | 0, h.size(0), src.size(), dtype=torch.long, device=h.device 143 | ) 144 | neg_out = predictor(h[src], h[dst_neg]) 145 | neg_loss = -torch.log(1 - neg_out + 1e-15).mean() 146 | 147 | loss = pos_loss + neg_loss 148 | loss.backward() 149 | optimizer.step() 150 | 151 | num_examples = pos_out.size(0) 152 | total_loss += loss.item() * num_examples 153 | total_examples += num_examples 154 | 155 | return total_loss / total_examples 156 | 157 | 158 | @torch.no_grad() 159 | def test(model, predictor, dataloaders, evaluator, neg_len, device): 160 | model.eval() 161 | predictor.eval() 162 | 163 | def test_split(dataloader, neg_len, device): 164 | pos_preds = [] 165 | for batch in dataloader: 166 | batch = batch.to(device) 167 | h = model(batch.x, batch.adj_t) 168 | src = batch.edge_label_index.t()[:, 0] 169 | dst = batch.edge_label_index.t()[:, 1] 170 | pos_preds += [predictor(h[src], h[dst]).squeeze().cpu()] 171 | pos_pred = torch.cat(pos_preds, dim=0) 172 | 173 | neg_preds = [] 174 | 175 | for batch in dataloader: 176 | batch = batch.to(device) 177 | h = model(batch.x, batch.adj_t) 178 | src = batch.edge_label_index.t()[:, 0] 179 | dst_neg = torch.randint( 180 | 0, h.size(0), [len(src), int(neg_len)], dtype=torch.long 181 | ).view(-1) 182 | src = src.view(-1, 1).repeat(1, neg_len).view(-1) 183 | neg_preds += [predictor(h[src], h[dst_neg]).squeeze().cpu()] 184 | neg_pred = torch.cat(neg_preds, dim=0).view(-1, neg_len) 185 | 186 | return ( 187 | evaluator.eval( 188 | { 189 | "y_pred_pos": pos_pred, 190 | "y_pred_neg": neg_pred, 191 | } 192 | )["mrr_list"] 193 | .mean() 194 | .item() 195 | ) 196 | 197 | train_mrr = test_split(dataloaders["train"], neg_len, device) 198 | valid_mrr = test_split(dataloaders["valid"], neg_len, device) 199 | test_mrr = test_split(dataloaders["test"], neg_len, device) 200 | 201 | return train_mrr, valid_mrr, test_mrr 202 | 203 | 204 | def main(): 205 | parser = argparse.ArgumentParser(description="Link-Prediction PLM/TCL") 206 | parser.add_argument("--device", type=int, default=0) 207 | parser.add_argument("--log_steps", type=int, default=1) # TODO: update latter 208 | parser.add_argument("--use_node_embedding", action="store_true") 209 | parser.add_argument("--num_layers", type=int, default=2) 210 | parser.add_argument("--hidden_channels", type=int, default=256) 211 | parser.add_argument("--dropout", type=float, default=0.0) 212 | parser.add_argument("--batch_size", type=int, default=64 * 1024) 213 | parser.add_argument("--lr", type=float, default=0.0005) 214 | parser.add_argument("--epochs", type=int, default=100) 215 | parser.add_argument("--gnn_model", type=str, help="GNN Model", default="GeneralGNN") 216 | parser.add_argument("--heads", type=int, default=4) 217 | parser.add_argument("--eval_steps", type=int, default=1) 218 | parser.add_argument("--runs", type=int, default=5) 219 | parser.add_argument("--test_ratio", type=float, default=0.08) 220 | parser.add_argument("--val_ratio", type=float, default=0.02) 221 | parser.add_argument("--neg_len", type=int, default=50) 222 | parser.add_argument( 223 | "--use_PLM_node", 224 | "-pn", 225 | type=str, 226 | default="data/CSTAG/Photo/Feature/children_gpt_node.pt", 227 | help="Use LM embedding as node feature", 228 | ) 229 | parser.add_argument( 230 | "--use_PLM_edge", 231 | "-pe", 232 | type=str, 233 | default="data/CSTAG/Photo/Feature/children_gpt_edge.pt", 234 | help="Use LM embedding as edge feature", 235 | ) 236 | parser.add_argument( 237 | "--path", 238 | type=str, 239 | 
default="data/CSTAG/Photo/LinkPrediction/", 240 | help="Path to save splitting", 241 | ) 242 | parser.add_argument( 243 | "--graph_path", 244 | "-gp", 245 | type=str, 246 | default="data/CSTAG/Photo/children.pkl", 247 | help="Path to load the graph", 248 | ) 249 | 250 | args = parser.parse_args() 251 | 252 | wandb.config = args 253 | wandb.init(config=args, reinit=True) 254 | print(args) 255 | 256 | if not os.path.exists(f"{args.path}{args.neg_len}/"): 257 | os.makedirs(f"{args.path}{args.neg_len}/") 258 | 259 | device = f"cuda:{args.device}" if torch.cuda.is_available() else "cpu" 260 | device = torch.device(device) 261 | 262 | with open(f"{args.graph_path}", "rb") as file: 263 | graph = pickle.load(file) 264 | 265 | graph.num_nodes = len(graph.text_nodes) 266 | 267 | edge_split = split_edge_mrr( 268 | graph, 269 | test_ratio=args.test_ratio, 270 | val_ratio=args.val_ratio, 271 | path=args.path, 272 | neg_len=args.neg_len, 273 | ) 274 | 275 | x = torch.load(args.use_PLM_node).float() 276 | edge_feature = torch.load(args.use_PLM_edge)[ 277 | edge_split["train"]["train_edge_feature_index"] 278 | ].float() 279 | 280 | edge_index = edge_split["train"]["edge"].t() 281 | adj_t = SparseTensor.from_edge_index( 282 | edge_index, edge_feature, sparse_sizes=(graph.num_nodes, graph.num_nodes) 283 | ).t() 284 | adj_t = adj_t.to_symmetric() 285 | 286 | train_loader, val_loader, test_loader = gen_loader( 287 | args, edge_split, x, edge_index, adj_t 288 | ) 289 | dataloaders = {"train": train_loader, "valid": val_loader, "test": test_loader} 290 | 291 | model = gen_model(args, x, edge_feature).to(device) 292 | 293 | predictor = LinkPredictor( 294 | args.hidden_channels, args.hidden_channels, 1, args.num_layers, args.dropout 295 | ).to(device) 296 | 297 | evaluator = Evaluator(name="DBLP") 298 | logger = Logger(args.runs, args) 299 | 300 | for run in range(args.runs): 301 | model.reset_parameters() 302 | predictor.reset_parameters() 303 | optimizer = torch.optim.Adam( 304 | list(model.parameters()) + list(predictor.parameters()), lr=args.lr 305 | ) 306 | 307 | for epoch in range(1, 1 + args.epochs): 308 | loss = train(model, predictor, train_loader, optimizer, device) 309 | wandb.log({"Loss": loss}) 310 | if epoch % args.eval_steps == 0: 311 | result = test( 312 | model, predictor, dataloaders, evaluator, args.neg_len, device 313 | ) 314 | logger.add_result(run, result) 315 | 316 | if epoch % args.log_steps == 0: 317 | train_mrr, valid_mrr, test_mrr = result 318 | print( 319 | f"Run: {run + 1:02d}, " 320 | f"Epoch: {epoch:02d}, " 321 | f"Loss: {loss:.4f}, " 322 | f"Train: {train_mrr:.4f}, " 323 | f"Valid: {valid_mrr:.4f}, " 324 | f"Test: {test_mrr:.4f}" 325 | ) 326 | 327 | logger.print_statistics(run) 328 | 329 | logger.print_statistics(key="mrr") 330 | 331 | 332 | if __name__ == "__main__": 333 | main() 334 | -------------------------------------------------------------------------------- /GNN/GNN_Link_HitsK_loader.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | from torch_geometric.data import Data 3 | from torch_geometric.loader import LinkNeighborLoader 4 | import argparse 5 | import pickle 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | from torch_sparse import SparseTensor 11 | 12 | from model.GNN_library import GAT, GINE, GeneralGNN, GraphTransformer 13 | from model.Dataloader import Evaluator, split_edge 14 | from model.GNN_arg import Logger 15 | import wandb 16 | import os 17 | 18 | 19 | class 
LinkPredictor(torch.nn.Module): 20 | def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout): 21 | super(LinkPredictor, self).__init__() 22 | 23 | self.lins = torch.nn.ModuleList() 24 | self.lins.append(torch.nn.Linear(in_channels, hidden_channels)) 25 | for _ in range(num_layers - 2): 26 | self.lins.append(torch.nn.Linear(hidden_channels, hidden_channels)) 27 | self.lins.append(torch.nn.Linear(hidden_channels, out_channels)) 28 | 29 | self.dropout = dropout 30 | 31 | def reset_parameters(self): 32 | for lin in self.lins: 33 | lin.reset_parameters() 34 | 35 | def forward(self, x_i, x_j): 36 | x = x_i * x_j 37 | for lin in self.lins[:-1]: 38 | x = lin(x) 39 | x = F.relu(x) 40 | x = F.dropout(x, p=self.dropout, training=self.training) 41 | x = self.lins[-1](x) 42 | return torch.sigmoid(x) 43 | 44 | 45 | def train(model, predictor, train_loader, optimizer, device): 46 | model.train() 47 | predictor.train() 48 | 49 | total_loss = total_examples = 0 50 | for batch in tqdm.tqdm(train_loader): 51 | optimizer.zero_grad() 52 | batch = batch.to(device) 53 | h = model(batch.x, batch.adj_t) 54 | 55 | src = batch.edge_label_index.t()[:, 0] 56 | dst = batch.edge_label_index.t()[:, 1] 57 | 58 | pos_out = predictor(h[src], h[dst]) 59 | pos_loss = -torch.log(pos_out + 1e-15).mean() 60 | 61 | # Just do some trivial random sampling. 62 | dst_neg = torch.randint( 63 | 0, h.size(0), src.size(), dtype=torch.long, device=h.device 64 | ) 65 | neg_out = predictor(h[src], h[dst_neg]) 66 | neg_loss = -torch.log(1 - neg_out + 1e-15).mean() 67 | 68 | loss = pos_loss + neg_loss 69 | loss.backward() 70 | optimizer.step() 71 | 72 | num_examples = pos_out.size(0) 73 | total_loss += loss.item() * num_examples 74 | total_examples += num_examples 75 | 76 | return total_loss / total_examples 77 | 78 | 79 | @torch.no_grad() 80 | def test(model, predictor, dataloaders, evaluator, device): 81 | model.eval() 82 | predictor.eval() 83 | ( 84 | train_loader, 85 | valid_loader, 86 | valid_negative_loader, 87 | test_loader, 88 | test_negative_loader, 89 | ) = ( 90 | dataloaders["train"], 91 | dataloaders["valid"], 92 | dataloaders["valid_negative"], 93 | dataloaders["test"], 94 | dataloaders["test_negative"], 95 | ) 96 | 97 | def test_split(dataloader, neg_dataloader, device): 98 | pos_preds = [] 99 | for batch in dataloader: 100 | batch = batch.to(device) 101 | h = model(batch.x, batch.adj_t) 102 | src = batch.edge_label_index.t()[:, 0] 103 | dst = batch.edge_label_index.t()[:, 1] 104 | pos_preds += [predictor(h[src], h[dst]).squeeze().cpu()] 105 | pos_pred = torch.cat(pos_preds, dim=0) 106 | 107 | neg_preds = [] 108 | for batch in neg_dataloader: 109 | batch = batch.to(device) 110 | h = model(batch.x, batch.adj_t) 111 | src = batch.edge_label_index.t()[:, 0] 112 | dst = batch.edge_label_index.t()[:, 1] 113 | neg_preds += [predictor(h[src], h[dst]).squeeze().cpu()] 114 | neg_pred = torch.cat(neg_preds, dim=0) 115 | return pos_pred, neg_pred 116 | 117 | pos_train_pred, neg_valid_pred = test_split( 118 | train_loader, valid_negative_loader, device 119 | ) 120 | pos_valid_pred, neg_valid_pred = test_split( 121 | valid_loader, valid_negative_loader, device 122 | ) 123 | pos_test_pred, neg_test_pred = test_split(test_loader, test_negative_loader, device) 124 | 125 | results = {} 126 | for K in [10, 50, 100]: 127 | evaluator.K = K 128 | train_hits = evaluator.eval( 129 | { 130 | "y_pred_pos": pos_train_pred, 131 | "y_pred_neg": neg_valid_pred, 132 | } 133 | )[f"hits@{K}"] 134 | valid_hits = evaluator.eval( 135 | 
{ 136 | "y_pred_pos": pos_valid_pred, 137 | "y_pred_neg": neg_valid_pred, 138 | } 139 | )[f"hits@{K}"] 140 | test_hits = evaluator.eval( 141 | { 142 | "y_pred_pos": pos_test_pred, 143 | "y_pred_neg": neg_test_pred, 144 | } 145 | )[f"hits@{K}"] 146 | 147 | results[f"Hits@{K}"] = (train_hits, valid_hits, test_hits) 148 | 149 | return results 150 | 151 | def gen_model(args, x, edge_feature): 152 | if args.gnn_model == "GAT": 153 | model = GAT( 154 | x.size(1), 155 | edge_feature.size(1), 156 | args.hidden_channels, 157 | args.hidden_channels, 158 | args.num_layers, 159 | args.heads, 160 | args.dropout, 161 | ) 162 | elif args.gnn_model == "GraphTransformer": 163 | model = GraphTransformer( 164 | x.size(1), 165 | edge_feature.size(1), 166 | args.hidden_channels, 167 | args.hidden_channels, 168 | args.num_layers, 169 | args.dropout, 170 | ) 171 | elif args.gnn_model == "GINE": 172 | model = GINE( 173 | x.size(1), 174 | edge_feature.size(1), 175 | args.hidden_channels, 176 | args.hidden_channels, 177 | args.num_layers, 178 | args.dropout, 179 | ) 180 | elif args.gnn_model == "GeneralGNN": 181 | model = GeneralGNN( 182 | x.size(1), 183 | edge_feature.size(1), 184 | args.hidden_channels, 185 | args.hidden_channels, 186 | args.num_layers, 187 | args.dropout, 188 | ) 189 | else: 190 | raise ValueError("Not implemented") 191 | return model 192 | 193 | 194 | def gen_loader(args, edge_split, x, edge_index, adj_t): 195 | train_data = Data(x=x, adj_t=adj_t) 196 | train_loader = LinkNeighborLoader( 197 | data=train_data, 198 | num_neighbors=[20, 10], 199 | edge_label_index=(edge_index), 200 | edge_label=torch.ones(edge_index.shape[1]), 201 | batch_size=args.batch_size, 202 | neg_sampling_ratio=0.0, 203 | shuffle=True, 204 | ) 205 | 206 | val_edge_label_index = edge_split["valid"]["edge"].t() 207 | val_loader = LinkNeighborLoader( 208 | data=train_data, 209 | num_neighbors=[20, 10], 210 | edge_label_index=(val_edge_label_index), 211 | edge_label=torch.ones(val_edge_label_index.shape[1]), 212 | batch_size=args.batch_size, 213 | neg_sampling_ratio=0.0, 214 | shuffle=False, 215 | ) 216 | 217 | val_negative_edge_label_index = edge_split["valid"]["edge_neg"].t() 218 | val_negative_loader = LinkNeighborLoader( 219 | data=train_data, 220 | num_neighbors=[20, 10], 221 | edge_label_index=(val_negative_edge_label_index), 222 | edge_label=torch.zeros(val_negative_edge_label_index.shape[1]), 223 | batch_size=args.batch_size, 224 | neg_sampling_ratio=0.0, 225 | shuffle=False, 226 | ) 227 | 228 | test_edge_label_index = edge_split["test"]["edge"].t() 229 | test_loader = LinkNeighborLoader( 230 | data=train_data, 231 | num_neighbors=[20, 10], 232 | edge_label_index=(test_edge_label_index), 233 | edge_label=torch.ones(test_edge_label_index.shape[1]), 234 | batch_size=args.batch_size, 235 | neg_sampling_ratio=0.0, 236 | shuffle=False, 237 | ) 238 | 239 | test_negative_edge_label_index = edge_split["test"]["edge_neg"].t() 240 | test_negative_loader = LinkNeighborLoader( 241 | data=train_data, 242 | num_neighbors=[20, 10], 243 | edge_label_index=(test_negative_edge_label_index), 244 | edge_label=torch.zeros(test_negative_edge_label_index.shape[1]), 245 | batch_size=args.batch_size, 246 | neg_sampling_ratio=0.0, 247 | shuffle=False, 248 | ) 249 | 250 | return ( 251 | train_loader, 252 | val_loader, 253 | val_negative_loader, 254 | test_loader, 255 | test_negative_loader, 256 | ) 257 | 258 | def main(): 259 | parser = argparse.ArgumentParser(description="Link-Prediction PLM/TCL") 260 | parser.add_argument("--device", type=int, 
default=0) 261 | parser.add_argument("--log_steps", type=int, default=1) 262 | parser.add_argument("--use_node_embedding", action="store_true") 263 | parser.add_argument("--num_layers", type=int, default=3) 264 | parser.add_argument("--hidden_channels", type=int, default=256) 265 | parser.add_argument("--dropout", type=float, default=0.0) 266 | parser.add_argument("--batch_size", type=int, default=64 * 1024) 267 | parser.add_argument("--lr", type=float, default=0.001) 268 | parser.add_argument("--epochs", type=int, default=10) 269 | parser.add_argument( 270 | "--gnn_model", type=str, help="GNN Model", default="GraphTransformer" 271 | ) 272 | parser.add_argument("--heads", type=int, default=4) 273 | parser.add_argument("--eval_steps", type=int, default=1) 274 | parser.add_argument("--runs", type=int, default=1) 275 | parser.add_argument("--test_ratio", type=float, default=0.08) 276 | parser.add_argument("--val_ratio", type=float, default=0.02) 277 | parser.add_argument("--neg_len", type=str, default="5000") 278 | parser.add_argument( 279 | "--use_PLM_node", 280 | type=str, 281 | default="data/CSTAG/Photo/Feature/children_gpt_node.pt", 282 | help="Use LM embedding as node feature", 283 | ) 284 | parser.add_argument( 285 | "--use_PLM_edge", 286 | type=str, 287 | default="data/CSTAG/Photo/Feature/children_gpt_edge.pt", 288 | help="Use LM embedding as edge feature", 289 | ) 290 | parser.add_argument( 291 | "--path", 292 | type=str, 293 | default="data/CSTAG/Photo/LinkPrediction/", 294 | help="Path to save splitting", 295 | ) 296 | parser.add_argument( 297 | "--graph_path", 298 | type=str, 299 | default="data/CSTAG/Photo/children.pkl", 300 | help="Path to load the graph", 301 | ) 302 | args = parser.parse_args() 303 | wandb.config = args 304 | wandb.init(config=args, reinit=True) 305 | print(args) 306 | 307 | if not os.path.exists(f"{args.path}{args.neg_len}/"): 308 | os.makedirs(f"{args.path}{args.neg_len}/") 309 | 310 | device = f"cuda:{args.device}" if torch.cuda.is_available() else "cpu" 311 | device = torch.device(device) 312 | 313 | with open(f"{args.graph_path}", "rb") as file: 314 | graph = pickle.load(file) 315 | 316 | graph.num_nodes = len(graph.text_nodes) 317 | edge_split = split_edge( 318 | graph, 319 | test_ratio=args.test_ratio, 320 | val_ratio=args.val_ratio, 321 | path=args.path, 322 | neg_len=args.neg_len, 323 | ) 324 | 325 | x = torch.load(args.use_PLM_node).float() 326 | edge_feature = torch.load(args.use_PLM_edge)[ 327 | edge_split["train"]["train_edge_feature_index"] 328 | ].float() 329 | 330 | edge_index = edge_split["train"]["edge"].t() 331 | adj_t = SparseTensor.from_edge_index( 332 | edge_index, edge_feature, sparse_sizes=(graph.num_nodes, graph.num_nodes) 333 | ).t() 334 | adj_t = adj_t.to_symmetric() 335 | 336 | train_loader, val_loader, val_negative_loader, test_loader, test_negative_loader = ( 337 | gen_loader(args, edge_split, x, edge_index, adj_t) 338 | ) 339 | dataloaders = { 340 | "train": train_loader, 341 | "valid": val_loader, 342 | "valid_negative": val_negative_loader, 343 | "test": test_loader, 344 | "test_negative": test_negative_loader, 345 | } 346 | 347 | model = gen_model(args, x, edge_feature).to(device) 348 | 349 | predictor = LinkPredictor( 350 | args.hidden_channels, args.hidden_channels, 1, args.num_layers, args.dropout 351 | ).to(device) 352 | 353 | evaluator = Evaluator(name="History") 354 | loggers = { 355 | "Hits@10": Logger(args.runs, args), 356 | "Hits@50": Logger(args.runs, args), 357 | "Hits@100": Logger(args.runs, args), 358 | } 359 | 360 
| for run in range(args.runs): 361 | model.reset_parameters() 362 | predictor.reset_parameters() 363 | optimizer = torch.optim.Adam( 364 | list(model.parameters()) + list(predictor.parameters()), lr=args.lr 365 | ) 366 | 367 | for epoch in range(1, 1 + args.epochs): 368 | loss = train(model, predictor, train_loader, optimizer, device) 369 | 370 | if epoch % args.eval_steps == 0: 371 | results = test( 372 | model, predictor, dataloaders, evaluator, device 373 | ) 374 | for key, result in results.items(): 375 | loggers[key].add_result(run, result) 376 | 377 | if epoch % args.log_steps == 0: 378 | for key, result in results.items(): 379 | train_hits, valid_hits, test_hits = result 380 | print(key) 381 | print( 382 | f"Run: {run + 1:02d}, " 383 | f"Epoch: {epoch:02d}, " 384 | f"Loss: {loss:.4f}, " 385 | f"Train: {100 * train_hits:.2f}%, " 386 | f"Valid: {100 * valid_hits:.2f}%, " 387 | f"Test: {100 * test_hits:.2f}%" 388 | ) 389 | print("---") 390 | 391 | for key in loggers.keys(): 392 | print(key) 393 | loggers[key].print_statistics(run) 394 | 395 | for key in loggers.keys(): 396 | print(key) 397 | loggers[key].print_statistics(key=key) 398 | 399 | 400 | 401 | if __name__ == "__main__": 402 | main() 403 | -------------------------------------------------------------------------------- /GNN/model/Dataloader.py: -------------------------------------------------------------------------------- 1 | from ogb.nodeproppred import DglNodePropPredDataset 2 | import dgl 3 | import numpy as np 4 | import torch as th 5 | from sklearn.model_selection import train_test_split 6 | import scipy.sparse as sp 7 | import os 8 | import random 9 | import pickle 10 | 11 | 12 | def split_graph(nodes_num, train_ratio, val_ratio): 13 | np.random.seed(42) 14 | indices = np.random.permutation(nodes_num) 15 | train_size = int(nodes_num * train_ratio) 16 | val_size = int(nodes_num * val_ratio) 17 | 18 | train_ids = indices[:train_size] 19 | val_ids = indices[train_size : train_size + val_size] 20 | test_ids = indices[train_size + val_size :] 21 | 22 | return train_ids, val_ids, test_ids 23 | 24 | 25 | # def split_time(g, train_year=2016, val_year=2017): 26 | # year = list(np.array(g.ndata['year'])) 27 | # indices = np.arange(g.num_nodes()) 28 | # # 1999-2014 train 29 | # train_ids = indices[:year.index(train_year)] 30 | # val_ids = indices[year.index(train_year):year.index(val_year)] 31 | # test_ids = indices[year.index(val_year):] 32 | # 33 | # return train_ids, val_ids, test_ids 34 | 35 | 36 | def split_time(g, train_year=2016, val_year=2017): 37 | np.random.seed(42) 38 | year = list(np.array(g.ndata["year"])) 39 | indices = np.arange(g.num_nodes()) 40 | # 1999-2014 train 41 | # Filter out nodes with label -1 42 | valid_indices = [i for i in indices if g.ndata["label"][i] != -1] 43 | 44 | # Filter out valid indices based on years 45 | train_ids = [i for i in valid_indices if year[i] < train_year] 46 | val_ids = [i for i in valid_indices if year[i] >= train_year and year[i] < val_year] 47 | test_ids = [i for i in valid_indices if year[i] >= val_year] 48 | 49 | train_length = len(train_ids) 50 | val_length = len(val_ids) 51 | test_length = len(test_ids) 52 | 53 | print("Train set length:", train_length) 54 | print("Validation set length:", val_length) 55 | print("Test set length:", test_length) 56 | 57 | return train_ids, val_ids, test_ids 58 | 59 | 60 | def load_data(name, train_ratio=0.6, val_ratio=0.2): 61 | if name == "ogbn-arxiv": 62 | data = DglNodePropPredDataset(name=name) 63 | splitted_idx = data.get_idx_split() 
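# Note (added comment): ogbn-arxiv comes with an official time-based split from OGB (earlier papers for training, later ones for validation and test).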
64 | train_idx, val_idx, test_idx = ( 65 | splitted_idx["train"], 66 | splitted_idx["valid"], 67 | splitted_idx["test"], 68 | ) 69 | graph, labels = data[0] 70 | labels = labels[:, 0] 71 | elif name == "Children": 72 | graph = dgl.load_graphs("data/CSTAG/Children/Children.pt")[0][0] 73 | labels = graph.ndata["label"] 74 | train_idx, val_idx, test_idx = split_graph(graph.num_nodes(), 0.6, 0.2) 75 | train_idx = th.tensor(train_idx) 76 | val_idx = th.tensor(val_idx) 77 | test_idx = th.tensor(test_idx) 78 | elif name == "History": 79 | graph = dgl.load_graphs("data/CSTAG/History/History.pt")[0][0] 80 | labels = graph.ndata["label"] 81 | train_idx, val_idx, test_idx = split_graph(graph.num_nodes(), 0.6, 0.2) 82 | train_idx = th.tensor(train_idx) 83 | val_idx = th.tensor(val_idx) 84 | test_idx = th.tensor(test_idx) 85 | elif name == "Fitness": 86 | graph = dgl.load_graphs("data/CSTAG/Fitness/Fitness.pt")[0][0] 87 | labels = graph.ndata["label"] 88 | train_idx, val_idx, test_idx = split_graph( 89 | graph.num_nodes(), train_ratio, val_ratio 90 | ) 91 | train_idx = th.tensor(train_idx) 92 | val_idx = th.tensor(val_idx) 93 | test_idx = th.tensor(test_idx) 94 | elif name == "Photo": 95 | graph = dgl.load_graphs("data/CSTAG/Photo/Photo.pt")[0][0] 96 | labels = graph.ndata["label"] 97 | train_idx, val_idx, test_idx = split_time(graph, 2015, 2016) 98 | train_idx = th.tensor(train_idx) 99 | val_idx = th.tensor(val_idx) 100 | test_idx = th.tensor(test_idx) 101 | elif name == "Computers": 102 | graph = dgl.load_graphs("data/CSTAG/Computers/Computers.pt")[0][0] 103 | labels = graph.ndata["label"] 104 | train_idx, val_idx, test_idx = split_time(graph, 2017, 2018) 105 | train_idx = th.tensor(train_idx) 106 | val_idx = th.tensor(val_idx) 107 | test_idx = th.tensor(test_idx) 108 | elif name == "webkb-cornell": 109 | graph = dgl.load_graphs("data/webkb/Cornell/Cornell.pt")[0][0] 110 | labels = graph.ndata["label"] 111 | train_idx, val_idx, test_idx = split_graph( 112 | graph.num_nodes(), train_ratio, val_ratio 113 | ) 114 | train_idx = th.tensor(train_idx) 115 | val_idx = th.tensor(val_idx) 116 | test_idx = th.tensor(test_idx) 117 | elif name == "webkb-texas": 118 | graph = dgl.load_graphs("data/webkb/Texas/Texas.pt")[0][0] 119 | labels = graph.ndata["label"] 120 | train_idx, val_idx, test_idx = split_graph( 121 | graph.num_nodes(), train_ratio, val_ratio 122 | ) 123 | train_idx = th.tensor(train_idx) 124 | val_idx = th.tensor(val_idx) 125 | test_idx = th.tensor(test_idx) 126 | elif name == "webkb-washington": 127 | graph = dgl.load_graphs("data/webkb/Washington/Washington.pt")[0][0] 128 | labels = graph.ndata["label"] 129 | train_idx, val_idx, test_idx = split_graph( 130 | graph.num_nodes(), train_ratio, val_ratio 131 | ) 132 | train_idx = th.tensor(train_idx) 133 | val_idx = th.tensor(val_idx) 134 | test_idx = th.tensor(test_idx) 135 | elif name == "webkb-wisconsin": 136 | graph = dgl.load_graphs("data/webkb/Wisconsin/Wisconsin.pt")[0][0] 137 | labels = graph.ndata["label"] 138 | train_idx, val_idx, test_idx = split_graph( 139 | graph.num_nodes(), train_ratio, val_ratio 140 | ) 141 | train_idx = th.tensor(train_idx) 142 | val_idx = th.tensor(val_idx) 143 | test_idx = th.tensor(test_idx) 144 | else: 145 | raise ValueError("Not implemetned") 146 | return graph, labels, train_idx, val_idx, test_idx 147 | 148 | 149 | def from_dgl(g): 150 | import dgl 151 | 152 | from torch_geometric.data import Data, HeteroData 153 | 154 | if not isinstance(g, dgl.DGLGraph): 155 | raise ValueError(f"Invalid data type (got 
'{type(g)}')") 156 | 157 | if g.is_homogeneous: 158 | data = Data() 159 | data.edge_index = th.stack(g.edges(), dim=0) 160 | 161 | for attr, value in g.ndata.items(): 162 | data[attr] = value 163 | for attr, value in g.edata.items(): 164 | data[attr] = value 165 | 166 | return data 167 | 168 | data = HeteroData() # TODO: Delete this 169 | 170 | for node_type in g.ntypes: 171 | for attr, value in g.nodes[node_type].data.items(): 172 | data[node_type][attr] = value 173 | 174 | for edge_type in g.canonical_etypes: 175 | row, col = g.edges(form="uv", etype=edge_type) 176 | data[edge_type].edge_index = th.stack([row, col], dim=0) 177 | for attr, value in g.edge_attr_schemes(edge_type).items(): 178 | data[edge_type][attr] = value 179 | 180 | return data 181 | 182 | 183 | def split_edge( 184 | dgl_graph, 185 | test_ratio=0.2, 186 | val_ratio=0.1, 187 | random_seed=42, 188 | neg_len="1000", 189 | path=None, 190 | way="random", 191 | ): 192 | if os.path.exists(os.path.join(path, f"{neg_len}/edge_split_hitsk.pt")): 193 | edge_split = th.load(os.path.join(path, f"{neg_len}/edge_split_hitsk.pt")) 194 | 195 | else: 196 | np.random.seed(random_seed) 197 | th.manual_seed(random_seed) 198 | 199 | graph = dgl_graph 200 | 201 | eids = np.arange(graph.num_edges) 202 | eids = np.random.permutation(eids) 203 | 204 | u, v = graph.edge_index 205 | 206 | test_size = int(len(eids) * test_ratio) 207 | val_size = int(len(eids) * val_ratio) 208 | 209 | test_pos_u, test_pos_v = u[eids[:test_size]], v[eids[:test_size]] 210 | val_pos_u, val_pos_v = ( 211 | u[eids[test_size : test_size + val_size]], 212 | v[eids[test_size : test_size + val_size]], 213 | ) 214 | train_pos_u, train_pos_v = ( 215 | u[eids[test_size + val_size :]], 216 | v[eids[test_size + val_size :]], 217 | ) 218 | 219 | train_edge_index = th.stack((train_pos_u, train_pos_v), dim=1) 220 | val_edge_index = th.stack((val_pos_u, val_pos_v), dim=1) 221 | test_edge_index = th.stack((test_pos_u, test_pos_v), dim=1) 222 | 223 | valid_neg_edge_index = th.randint( 224 | 0, graph.num_nodes, [int(neg_len), 2], dtype=th.long 225 | ) # TODO: filter some edges in graph 226 | test_neg_edge_index = th.randint( 227 | 0, graph.num_nodes, [int(neg_len), 2], dtype=th.long 228 | ) 229 | 230 | edge_split = { 231 | "train": { 232 | "edge": train_edge_index, 233 | "train_edge_feature_index": eids[test_size + val_size :], 234 | }, 235 | "valid": {"edge": val_edge_index, "edge_neg": valid_neg_edge_index}, 236 | "test": {"edge": test_edge_index, "edge_neg": test_neg_edge_index}, 237 | } 238 | 239 | th.save(edge_split, os.path.join(path, f"{neg_len}/edge_split_hitsk.pt")) 240 | 241 | return edge_split 242 | 243 | 244 | def split_edge_mrr( 245 | dgl_graph, 246 | test_ratio=0.2, 247 | val_ratio=0.1, 248 | random_seed=42, 249 | neg_len="1000", 250 | path=None, 251 | way="random", 252 | ): 253 | if os.path.exists(os.path.join(path, f"{neg_len}/edge_split_mrr.pt")): 254 | edge_split = th.load(os.path.join(path, f"{neg_len}/edge_split_mrr.pt")) 255 | 256 | else: 257 | np.random.seed(random_seed) 258 | th.manual_seed(random_seed) 259 | 260 | graph = dgl_graph 261 | 262 | eids = np.arange(graph.num_edges) 263 | eids = np.random.permutation(eids) 264 | 265 | u, v = graph.edge_index 266 | 267 | test_size = int(len(eids) * test_ratio) 268 | val_size = int(len(eids) * val_ratio) 269 | 270 | test_pos_u, test_pos_v = u[eids[:test_size]], v[eids[:test_size]] 271 | val_pos_u, val_pos_v = ( 272 | u[eids[test_size : test_size + val_size]], 273 | v[eids[test_size : test_size + val_size]], 274 | ) 275 | 
train_pos_u, train_pos_v = ( 276 | u[eids[test_size + val_size :]], 277 | v[eids[test_size + val_size :]], 278 | ) 279 | 280 | valid_target_neg = th.randint( 281 | 0, graph.num_nodes, [len(val_pos_u), int(neg_len)], dtype=th.long 282 | ) # TODO: filter some edges in graph 283 | test_target_neg = th.randint( 284 | 0, graph.num_nodes, [len(test_pos_u), int(neg_len)], dtype=th.long 285 | ) # TODO: filter some edges in graph 286 | 287 | edge_split = { 288 | "train": { 289 | "source_node": train_pos_u, 290 | "target_node": train_pos_v, 291 | "edge": th.stack((train_pos_u, train_pos_v), dim=1), 292 | "train_edge_feature_index": eids[test_size + val_size :], 293 | }, 294 | "valid": { 295 | "source_node": val_pos_u, 296 | "target_node": val_pos_v, 297 | "target_node_neg": valid_target_neg, 298 | "edge": th.stack((val_pos_u, val_pos_v), dim=1), 299 | }, 300 | "test": { 301 | "source_node": test_pos_u, 302 | "target_node": test_pos_v, 303 | "target_node_neg": test_target_neg, 304 | "edge": th.stack((test_pos_u, test_pos_v), dim=1), 305 | }, 306 | } 307 | 308 | th.save(edge_split, os.path.join(path, f"{neg_len}/edge_split_mrr.pt")) 309 | 310 | return edge_split 311 | 312 | 313 | # TODO: Consider time in split 314 | # def split_edge_MMR(dgl_graph, time=2015, random_seed=42, neg_len=1000, path=None): 315 | # if os.path.exists(os.path.join(path, 'edge_split_mrr.pt')): 316 | # edge_split = th.load(os.path.join(path, 'edge_split_mrr.pt')) 317 | # else: 318 | 319 | # np.random.seed(random_seed) 320 | # th.manual_seed(random_seed) 321 | 322 | # graph = from_dgl(dgl_graph) 323 | 324 | # year = list(np.array(graph.year)) 325 | # indices = np.arange(graph.num_nodes) 326 | # all_source = graph.edge_index[0] 327 | # all_target = graph.edge_index[1] 328 | # import random 329 | # val_list = [] 330 | # test_list = [] 331 | # for i in range(year.index(time), graph.num_nodes): 332 | # indices = th.where(all_source == i)[0] 333 | # if len(indices) >= 2: 334 | # _list = random.sample(list(indices), k=2) 335 | # val_list.append(_list[0]) # TODO: check edge 336 | # test_list.append(_list[1]) 337 | 338 | # val_source = all_source[val_list] 339 | # val_target = all_target[val_list] 340 | 341 | # test_source = all_source[test_list] 342 | # test_target = all_target[test_list] 343 | 344 | # all_index = list(range(0, len(graph.edge_index[0]))) 345 | 346 | # val_list = np.array(val_list).tolist() 347 | # test_list = np.array(test_list).tolist() 348 | 349 | # train_idx = set(all_index) - set(val_list) - set(test_list) 350 | 351 | # tra_source = all_source[list(train_idx)] 352 | # tra_target = all_target[list(train_idx)] 353 | 354 | # val_target_neg = th.randint(low=0, high=year.index(time), size=(len(val_source), neg_len)) 355 | # test_target_neg = th.randint(low=0, high=year.index(time), size=(len(test_source), neg_len)) 356 | 357 | # # ! Store the split as a dict 358 | # edge_split = {'train': {'source_node': tra_source, 'target_node': tra_target}, 359 | # 'valid': {'source_node': val_source, 'target_node': val_target, 360 | # 'target_node_neg': val_target_neg}, 361 | # 'test': {'source_node': test_source, 'target_node': test_target, 362 | # 'target_node_neg': test_target_neg}} 363 | 364 | # th.save(edge_split, os.path.join(path, 'edge_split_mrr.pt')) 365 | # # ! 
Save the training subgraph 366 | # train_g = dgl.remove_edges(dgl_graph, val_list + test_list) 367 | # dgl.save_graphs(os.path.join(path, 'train_G.pt'), train_g) 368 | 369 | # return edge_split, train_g 370 | 371 | 372 | class Evaluator: 373 | def __init__(self, name): 374 | self.name = name 375 | meta_info = { 376 | "History": {"name": "History", "eval_metric": "hits@50"}, 377 | "DBLP": {"name": "DBLP", "eval_metric": "mrr"}, 378 | } 379 | 380 | self.eval_metric = meta_info[self.name]["eval_metric"] 381 | 382 | if "hits@" in self.eval_metric: 383 | ### Hits@K 384 | 385 | self.K = int(self.eval_metric.split("@")[1]) 386 | 387 | def _parse_and_check_input(self, input_dict): 388 | if "hits@" in self.eval_metric: 389 | if not "y_pred_pos" in input_dict: 390 | raise RuntimeError("Missing key of y_pred_pos") 391 | if not "y_pred_neg" in input_dict: 392 | raise RuntimeError("Missing key of y_pred_neg") 393 | 394 | y_pred_pos, y_pred_neg = input_dict["y_pred_pos"], input_dict["y_pred_neg"] 395 | 396 | """ 397 | y_pred_pos: numpy ndarray or torch tensor of shape (num_edge, ) 398 | y_pred_neg: numpy ndarray or torch tensor of shape (num_edge, ) 399 | """ 400 | 401 | # convert y_pred_pos, y_pred_neg into either torch tensor or both numpy array 402 | # type_info stores information whether torch or numpy is used 403 | 404 | type_info = None 405 | 406 | # check the raw type of y_pred_pos 407 | if not ( 408 | isinstance(y_pred_pos, np.ndarray) 409 | or (th is not None and isinstance(y_pred_pos, th.Tensor)) 410 | ): 411 | raise ValueError( 412 | "y_pred_pos needs to be either numpy ndarray or th tensor" 413 | ) 414 | 415 | # check the raw type of y_pred_neg 416 | if not ( 417 | isinstance(y_pred_neg, np.ndarray) 418 | or (th is not None and isinstance(y_pred_neg, th.Tensor)) 419 | ): 420 | raise ValueError( 421 | "y_pred_neg needs to be either numpy ndarray or th tensor" 422 | ) 423 | 424 | # if either y_pred_pos or y_pred_neg is th tensor, use th tensor 425 | if th is not None and ( 426 | isinstance(y_pred_pos, th.Tensor) or isinstance(y_pred_neg, th.Tensor) 427 | ): 428 | # convert any numpy array to th.Tensor 429 | if isinstance(y_pred_pos, np.ndarray): 430 | y_pred_pos = th.from_numpy(y_pred_pos) 431 | 432 | if isinstance(y_pred_neg, np.ndarray): 433 | y_pred_neg = th.from_numpy(y_pred_neg) 434 | 435 | # put both y_pred_pos and y_pred_neg on the same device 436 | y_pred_pos = y_pred_pos.to(y_pred_neg.device) 437 | 438 | type_info = "torch" 439 | 440 | else: 441 | # both y_pred_pos and y_pred_neg are numpy ndarray 442 | 443 | type_info = "numpy" 444 | 445 | if not y_pred_pos.ndim == 1: 446 | raise RuntimeError( 447 | "y_pred_pos must be a 1-dim array, {}-dim array given".format( 448 | y_pred_pos.ndim 449 | ) 450 | ) 451 | 452 | if not y_pred_neg.ndim == 1: 453 | raise RuntimeError( 454 | "y_pred_neg must be a 1-dim array, {}-dim array given".format( 455 | y_pred_neg.ndim 456 | ) 457 | ) 458 | 459 | return y_pred_pos, y_pred_neg, type_info 460 | 461 | elif "mrr" == self.eval_metric: 462 | 463 | if not "y_pred_pos" in input_dict: 464 | raise RuntimeError("Missing key of y_pred_pos") 465 | if not "y_pred_neg" in input_dict: 466 | raise RuntimeError("Missing key of y_pred_neg") 467 | 468 | y_pred_pos, y_pred_neg = input_dict["y_pred_pos"], input_dict["y_pred_neg"] 469 | 470 | """ 471 | y_pred_pos: numpy ndarray or torch tensor of shape (num_edge, ) 472 | y_pred_neg: numpy ndarray or torch tensor of shape (num_edge, num_node_negative) 473 | """ 474 | 475 | # convert y_pred_pos, y_pred_neg into either torch tensor or both numpy 
array 476 | # type_info stores information whether torch or numpy is used 477 | 478 | type_info = None 479 | 480 | # check the raw type of y_pred_pos 481 | if not ( 482 | isinstance(y_pred_pos, np.ndarray) 483 | or (th is not None and isinstance(y_pred_pos, th.Tensor)) 484 | ): 485 | raise ValueError( 486 | "y_pred_pos needs to be either numpy ndarray or th tensor" 487 | ) 488 | 489 | # check the raw type of y_pred_neg 490 | if not ( 491 | isinstance(y_pred_neg, np.ndarray) 492 | or (th is not None and isinstance(y_pred_neg, th.Tensor)) 493 | ): 494 | raise ValueError( 495 | "y_pred_neg needs to be either numpy ndarray or th tensor" 496 | ) 497 | 498 | # if either y_pred_pos or y_pred_neg is th tensor, use th tensor 499 | if th is not None and ( 500 | isinstance(y_pred_pos, th.Tensor) or isinstance(y_pred_neg, th.Tensor) 501 | ): 502 | # convert any numpy array to th.Tensor 503 | if isinstance(y_pred_pos, np.ndarray): 504 | y_pred_pos = th.from_numpy(y_pred_pos) 505 | 506 | if isinstance(y_pred_neg, np.ndarray): 507 | y_pred_neg = th.from_numpy(y_pred_neg) 508 | 509 | # put both y_pred_pos and y_pred_neg on the same device 510 | y_pred_pos = y_pred_pos.to(y_pred_neg.device) 511 | 512 | type_info = "torch" 513 | 514 | else: 515 | # both y_pred_pos and y_pred_neg are numpy ndarray 516 | 517 | type_info = "numpy" 518 | 519 | if not y_pred_pos.ndim == 1: 520 | raise RuntimeError( 521 | "y_pred_pos must be a 1-dim array, {}-dim array given".format( 522 | y_pred_pos.ndim 523 | ) 524 | ) 525 | 526 | if not y_pred_neg.ndim == 2: 527 | raise RuntimeError( 528 | "y_pred_neg must be a 2-dim array, {}-dim array given".format( 529 | y_pred_neg.ndim 530 | ) 531 | ) 532 | 533 | return y_pred_pos, y_pred_neg, type_info 534 | 535 | else: 536 | raise ValueError("Undefined eval metric %s" % (self.eval_metric)) 537 | 538 | def eval(self, input_dict): 539 | 540 | if "hits@" in self.eval_metric: 541 | y_pred_pos, y_pred_neg, type_info = self._parse_and_check_input(input_dict) 542 | return self._eval_hits(y_pred_pos, y_pred_neg, type_info) 543 | elif self.eval_metric == "mrr": 544 | y_pred_pos, y_pred_neg, type_info = self._parse_and_check_input(input_dict) 545 | return self._eval_mrr(y_pred_pos, y_pred_neg, type_info) 546 | 547 | else: 548 | raise ValueError("Undefined eval metric %s" % (self.eval_metric)) 549 | 550 | def _eval_hits(self, y_pred_pos, y_pred_neg, type_info): 551 | """ 552 | compute Hits@K 553 | For each positive target node, the negative target nodes are the same. 554 | 555 | y_pred_neg is an array. 556 | rank y_pred_pos[i] against y_pred_neg for each i 557 | """ 558 | 559 | if len(y_pred_neg) < self.K: 560 | return {"hits@{}".format(self.K): 1.0} 561 | 562 | if type_info == "torch": 563 | kth_score_in_negative_edges = th.topk(y_pred_neg, self.K)[0][-1] 564 | hitsK = float(th.sum(y_pred_pos > kth_score_in_negative_edges).cpu()) / len( 565 | y_pred_pos 566 | ) 567 | 568 | # type_info is numpy 569 | else: 570 | kth_score_in_negative_edges = np.sort(y_pred_neg)[-self.K] 571 | hitsK = float(np.sum(y_pred_pos > kth_score_in_negative_edges)) / len( 572 | y_pred_pos 573 | ) 574 | 575 | return {"hits@{}".format(self.K): hitsK} 576 | 577 | def _eval_mrr(self, y_pred_pos, y_pred_neg, type_info): 578 | """ 579 | compute mrr 580 | y_pred_neg is an array with shape (batch size, num_entities_neg). 
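Each row of y_pred_neg holds the scores of the sampled negatives against which the corresponding positive score is ranked.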
581 | y_pred_pos is an array with shape (batch size, ) 582 | """ 583 | 584 | if type_info == "torch": 585 | y_pred = th.cat([y_pred_pos.view(-1, 1), y_pred_neg], dim=1) 586 | argsort = th.argsort(y_pred, dim=1, descending=True) 587 | ranking_list = th.nonzero(argsort == 0, as_tuple=False) 588 | ranking_list = ranking_list[:, 1] + 1 589 | # hits1_list = (ranking_list <= 1).to(th.float) 590 | # hits3_list = (ranking_list <= 3).to(th.float) 591 | # hits10_list = (ranking_list <= 10).to(th.float) 592 | mrr_list = 1.0 / ranking_list.to(th.float) 593 | 594 | return { 595 | # 'hits@1_list': hits1_list, 596 | # 'hits@3_list': hits3_list, 597 | # 'hits@10_list': hits10_list, 598 | "mrr_list": mrr_list 599 | } 600 | 601 | else: 602 | y_pred = np.concatenate([y_pred_pos.reshape(-1, 1), y_pred_neg], axis=1) 603 | argsort = np.argsort(-y_pred, axis=1) 604 | ranking_list = (argsort == 0).nonzero() 605 | ranking_list = ranking_list[1] + 1 606 | # hits1_list = (ranking_list <= 1).astype(np.float32) 607 | # hits3_list = (ranking_list <= 3).astype(np.float32) 608 | # hits10_list = (ranking_list <= 10).astype(np.float32) 609 | mrr_list = 1.0 / ranking_list.astype(np.float32) 610 | 611 | return { 612 | # 'hits@1_list': hits1_list, 613 | # 'hits@3_list': hits3_list, 614 | # 'hits@10_list': hits10_list, 615 | "mrr_list": mrr_list 616 | } 617 | -------------------------------------------------------------------------------- /GNN/model/GNN_library.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.nn import ( 3 | GCNConv, 4 | SAGEConv, 5 | TransformerConv, 6 | GATConv, 7 | GINEConv, 8 | GeneralConv, 9 | ) 10 | from torch.nn import Linear 11 | import torch.nn.functional as F 12 | import torch 13 | import torch.nn as nn 14 | 15 | import dgl.nn.pytorch as dglnn 16 | from dgl import function as fn 17 | from dgl.ops import edge_softmax 18 | from dgl.utils import expand_as_pair 19 | import torch.nn.functional as F 20 | from dgl.sampling import node2vec_random_walk 21 | from torch.utils.data import DataLoader 22 | from typing import Callable, Optional, Union 23 | 24 | import torch 25 | from torch import Tensor 26 | from torch.nn import Linear 27 | from torch_geometric.nn.conv import MessagePassing 28 | from torch_geometric.nn.dense.linear import Linear 29 | from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size 30 | 31 | 32 | class ElementWiseLinear(nn.Module): 33 | def __init__(self, size, weight=True, bias=True, inplace=False): 34 | super().__init__() 35 | if weight: 36 | self.weight = nn.Parameter(torch.Tensor(size)) 37 | else: 38 | self.weight = None 39 | if bias: 40 | self.bias = nn.Parameter(torch.Tensor(size)) 41 | else: 42 | self.bias = None 43 | self.inplace = inplace 44 | 45 | self.reset_parameters() 46 | 47 | def reset_parameters(self): 48 | if self.weight is not None: 49 | nn.init.ones_(self.weight) 50 | if self.bias is not None: 51 | nn.init.zeros_(self.bias) 52 | 53 | def forward(self, x): 54 | if self.inplace: 55 | if self.weight is not None: 56 | x.mul_(self.weight) 57 | if self.bias is not None: 58 | x.add_(self.bias) 59 | else: 60 | if self.weight is not None: 61 | x = x * self.weight 62 | if self.bias is not None: 63 | x = x + self.bias 64 | return x 65 | 66 | 67 | class APPNP(nn.Module): 68 | def __init__( 69 | self, 70 | in_feats, 71 | n_layers, 72 | n_hidden, 73 | n_classes, 74 | activation, 75 | input_drop, 76 | edge_drop, 77 | alpha, 78 | k, 79 | ): 80 | super(APPNP, self).__init__() 81 | self.n_layers = 
n_layers 82 | self.n_hidden = n_hidden 83 | self.layers = nn.ModuleList() 84 | # input layer 85 | for i in range(n_layers): 86 | in_hidden = n_hidden if i > 0 else in_feats 87 | out_hidden = n_hidden if i < n_layers - 1 else n_classes 88 | self.layers.append(nn.Linear(in_hidden, out_hidden)) 89 | 90 | self.activation = activation 91 | 92 | self.input_drop = nn.Dropout(input_drop) 93 | 94 | self.propagate = dglnn.APPNPConv(k, alpha, edge_drop) 95 | self.reset_parameters() 96 | 97 | def reset_parameters(self): 98 | for layer in self.layers: 99 | layer.reset_parameters() 100 | 101 | def forward(self, graph, features): 102 | # prediction step 103 | h = features 104 | h = self.input_drop(h) 105 | h = self.activation(self.layers[0](h)) 106 | for layer in self.layers[1:-1]: 107 | h = self.activation(layer(h)) 108 | h = self.layers[-1](self.input_drop(h)) 109 | # propagation step 110 | h = self.propagate(graph, h) 111 | return h 112 | 113 | 114 | class GraphSAGE(nn.Module): 115 | def __init__( 116 | self, 117 | in_feats, 118 | n_hidden, 119 | n_classes, 120 | n_layers, 121 | activation, 122 | dropout, 123 | aggregator_type, 124 | input_drop=0.0, 125 | ): 126 | super(GraphSAGE, self).__init__() 127 | self.n_layers = n_layers 128 | self.n_hidden = n_hidden 129 | self.n_classes = n_classes 130 | 131 | self.layers = nn.ModuleList() 132 | self.norms = nn.ModuleList() 133 | 134 | # input layer 135 | for i in range(n_layers): 136 | in_hidden = n_hidden if i > 0 else in_feats 137 | out_hidden = n_hidden if i < n_layers - 1 else n_classes 138 | 139 | self.layers.append(dglnn.SAGEConv(in_hidden, out_hidden, aggregator_type)) 140 | if i < n_layers - 1: 141 | self.norms.append(nn.BatchNorm1d(out_hidden)) 142 | 143 | self.input_drop = nn.Dropout(input_drop) 144 | self.dropout = nn.Dropout(dropout) 145 | self.activation = activation 146 | 147 | def forward(self, graph, feat): 148 | h = feat 149 | h = self.input_drop(h) 150 | 151 | for l, layer in enumerate(self.layers): 152 | conv = layer(graph, h) 153 | h = conv 154 | 155 | if l != len(self.layers) - 1: 156 | h = self.norms[l](h) 157 | h = self.activation(h) 158 | h = self.dropout(h) 159 | 160 | return h 161 | 162 | 163 | class GCN(torch.nn.Module): 164 | def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout): 165 | super(GCN, self).__init__() 166 | 167 | self.convs = torch.nn.ModuleList() 168 | self.convs.append(GCNConv(in_channels, hidden_channels, cached=True)) 169 | for _ in range(num_layers - 2): 170 | self.convs.append(GCNConv(hidden_channels, hidden_channels, cached=True)) 171 | self.convs.append(GCNConv(hidden_channels, out_channels, cached=True)) 172 | 173 | self.dropout = dropout 174 | 175 | def reset_parameters(self): 176 | for conv in self.convs: 177 | conv.reset_parameters() 178 | 179 | def forward(self, x, adj_t): 180 | for conv in self.convs[:-1]: 181 | x = conv(x, adj_t) 182 | x = F.relu(x) 183 | x = F.dropout(x, p=self.dropout, training=self.training) 184 | x = self.convs[-1](x, adj_t) 185 | return x 186 | 187 | 188 | class GCN(torch.nn.Module): 189 | def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout): 190 | super(GCN, self).__init__() 191 | 192 | self.convs = torch.nn.ModuleList() 193 | self.convs.append(GCNConv(in_channels, hidden_channels, cached=True)) 194 | for _ in range(num_layers - 2): 195 | self.convs.append(GCNConv(hidden_channels, hidden_channels, cached=True)) 196 | self.convs.append(GCNConv(hidden_channels, out_channels, cached=True)) 197 | 198 | self.dropout = dropout 
199 | 200 | def reset_parameters(self): 201 | for conv in self.convs: 202 | conv.reset_parameters() 203 | 204 | def forward(self, x, adj_t): 205 | for conv in self.convs[:-1]: 206 | x = conv(x, adj_t) 207 | x = F.relu(x) 208 | x = F.dropout(x, p=self.dropout, training=self.training) 209 | x = self.convs[-1](x, adj_t) 210 | return x 211 | 212 | 213 | class GAT(nn.Module): 214 | def __init__( 215 | self, 216 | in_feats, 217 | n_classes, 218 | n_hidden, 219 | n_layers, 220 | n_heads, 221 | activation, 222 | dropout=0.0, 223 | input_drop=0.0, 224 | attn_drop=0.0, 225 | edge_drop=0.0, 226 | use_attn_dst=True, 227 | use_symmetric_norm=False, 228 | residual=False, 229 | ): 230 | super().__init__() 231 | self.in_feats = in_feats 232 | self.n_hidden = n_hidden 233 | self.n_classes = n_classes 234 | self.n_layers = n_layers 235 | self.num_heads = n_heads 236 | 237 | self.convs = nn.ModuleList() 238 | self.norms = nn.ModuleList() 239 | 240 | for i in range(n_layers): 241 | in_hidden = n_heads * n_hidden if i > 0 else in_feats 242 | out_hidden = n_hidden if i < n_layers - 1 else n_classes 243 | num_heads = n_heads if i < n_layers - 1 else 1 244 | out_channels = n_heads 245 | 246 | self.convs.append( 247 | GATConv( 248 | in_hidden, 249 | out_hidden, 250 | num_heads=num_heads, 251 | attn_drop=attn_drop, 252 | edge_drop=edge_drop, 253 | use_attn_dst=use_attn_dst, 254 | use_symmetric_norm=use_symmetric_norm, 255 | residual=residual, 256 | ) 257 | ) 258 | 259 | if i < n_layers - 1: 260 | self.norms.append(nn.BatchNorm1d(out_channels * out_hidden)) 261 | 262 | self.bias_last = ElementWiseLinear( 263 | n_classes, weight=False, bias=True, inplace=True 264 | ) 265 | 266 | self.input_drop = nn.Dropout(input_drop) 267 | self.dropout = nn.Dropout(dropout) 268 | self.activation = activation 269 | 270 | def forward(self, graph, feat): 271 | h = feat 272 | h = self.input_drop(h) 273 | 274 | for i in range(self.n_layers): 275 | conv = self.convs[i](graph, h) 276 | 277 | h = conv 278 | 279 | if i < self.n_layers - 1: 280 | h = h.flatten(1) 281 | h = self.norms[i](h) 282 | h = self.activation(h, inplace=True) 283 | h = self.dropout(h) 284 | 285 | h = h.mean(1) 286 | h = self.bias_last(h) 287 | 288 | return h 289 | 290 | 291 | class ApplyNodeFunc(nn.Module): 292 | """Update the node feature hv with MLP, BN and ReLU.""" 293 | 294 | def __init__(self, mlp): 295 | super(ApplyNodeFunc, self).__init__() 296 | self.mlp = mlp 297 | self.bn = nn.BatchNorm1d(self.mlp.output_dim) 298 | 299 | def forward(self, h): 300 | h = self.mlp(h) 301 | h = self.bn(h) 302 | h = F.relu(h) 303 | return h 304 | 305 | 306 | class MLP(nn.Module): 307 | """MLP with linear output""" 308 | 309 | def __init__( 310 | self, num_layers, input_dim, hidden_dim, output_dim, input_drop=0.0, dropout=0.2 311 | ): 312 | """MLP layers construction 313 | 314 | Paramters 315 | --------- 316 | num_layers: int 317 | The number of linear layers 318 | input_dim: int 319 | The dimensionality of input features 320 | hidden_dim: int 321 | The dimensionality of hidden units at ALL layers 322 | output_dim: int 323 | The number of classes for prediction 324 | 325 | """ 326 | super(MLP, self).__init__() 327 | self.linear_or_not = True # default is linear model 328 | self.num_layers = num_layers 329 | self.output_dim = output_dim 330 | self.input_drop = nn.Dropout(input_drop) 331 | self.dropout = nn.Dropout(dropout) 332 | 333 | if num_layers < 1: 334 | raise ValueError("number of layers should be positive!") 335 | elif num_layers == 1: 336 | # Linear model 337 | self.linear = 
nn.Linear(input_dim, output_dim) 338 | else: 339 | # Multi-layer model 340 | self.linear_or_not = False 341 | self.linears = torch.nn.ModuleList() 342 | self.batch_norms = torch.nn.ModuleList() 343 | 344 | self.linears.append(nn.Linear(input_dim, hidden_dim)) 345 | for layer in range(num_layers - 2): 346 | self.linears.append(nn.Linear(hidden_dim, hidden_dim)) 347 | self.linears.append(nn.Linear(hidden_dim, output_dim)) 348 | 349 | for layer in range(num_layers - 1): 350 | self.batch_norms.append(nn.BatchNorm1d((hidden_dim))) 351 | 352 | def forward(self, x): 353 | if self.linear_or_not: 354 | # If linear model 355 | return self.linear(x) 356 | else: 357 | # If MLP 358 | x = self.input_drop(x) 359 | h = x 360 | for i in range(self.num_layers - 1): 361 | h = F.relu(self.batch_norms[i](self.linears[i](h))) 362 | h = self.dropout(h) 363 | return self.linears[-1](h) 364 | 365 | 366 | class GIN(nn.Module): 367 | """GIN model""" 368 | 369 | def __init__( 370 | self, 371 | in_feats, 372 | n_hidden, 373 | n_classes, 374 | n_layers, 375 | num_mlp_layers, 376 | input_dropout, 377 | learn_eps, 378 | neighbor_pooling_type, 379 | ): 380 | """model parameters setting 381 | 382 | Paramters 383 | --------- 384 | num_layers: int 385 | The number of linear layers in the neural network 386 | num_mlp_layers: int 387 | The number of linear layers in mlps 388 | input_dim: int 389 | The dimensionality of input features 390 | hidden_dim: int 391 | The dimensionality of hidden units at ALL layers 392 | output_dim: int 393 | The number of classes for prediction 394 | final_dropout: float 395 | dropout ratio on the final linear layer 396 | learn_eps: boolean 397 | If True, learn epsilon to distinguish center nodes from neighbors 398 | If False, aggregate neighbors and center nodes altogether. 
399 | neighbor_pooling_type: str 400 | how to aggregate neighbors (sum, mean, or max) 401 | graph_pooling_type: str 402 | how to aggregate entire nodes in a graph (sum, mean or max) 403 | 404 | """ 405 | super(GIN, self).__init__() 406 | self.n_layers = n_layers 407 | self.n_hidden = n_hidden 408 | self.n_classes = n_classes 409 | self.learn_eps = learn_eps 410 | 411 | # List of MLPs 412 | self.ginlayers = torch.nn.ModuleList() 413 | self.batch_norms = torch.nn.ModuleList() 414 | 415 | for layer in range(self.n_layers - 1): 416 | if layer == 0: 417 | mlp = MLP(num_mlp_layers, in_feats, n_hidden, n_hidden) 418 | else: 419 | mlp = MLP(num_mlp_layers, n_hidden, n_hidden, n_hidden) 420 | 421 | self.ginlayers.append( 422 | dglnn.GINConv( 423 | ApplyNodeFunc(mlp), neighbor_pooling_type, 0, self.learn_eps 424 | ) 425 | ) 426 | self.batch_norms.append(nn.BatchNorm1d(n_hidden)) 427 | 428 | # Linear function for graph poolings of output of each layer 429 | # which maps the output of different layers into a prediction score 430 | self.linears_prediction = torch.nn.ModuleList() 431 | 432 | for layer in range(n_layers): 433 | if layer == 0: 434 | self.linears_prediction.append(nn.Linear(in_feats, n_classes)) 435 | else: 436 | self.linears_prediction.append(nn.Linear(n_hidden, n_classes)) 437 | 438 | self.input_drop = nn.Dropout(input_dropout) 439 | 440 | def forward(self, g, h): 441 | h = self.input_drop(h) 442 | # list of hidden representation at each layer (including input) 443 | hidden_rep = [] 444 | 445 | for i in range(self.n_layers - 1): 446 | h = self.ginlayers[i](g, h) 447 | h = self.batch_norms[i](h) 448 | h = F.relu(h) 449 | hidden_rep.append(h) 450 | 451 | z = torch.cat(hidden_rep, dim=1) 452 | 453 | return z 454 | 455 | 456 | class JKNet(nn.Module): 457 | def __init__( 458 | self, in_feats, n_hidden, n_classes, n_layers=1, mode="cat", dropout=0.0 459 | ): 460 | super(JKNet, self).__init__() 461 | 462 | self.mode = mode 463 | self.dropout = nn.Dropout(dropout) 464 | self.layers = nn.ModuleList() 465 | self.layers.append(dglnn.GraphConv(in_feats, n_hidden, activation=F.relu)) 466 | for _ in range(n_layers): 467 | self.layers.append(dglnn.GraphConv(n_hidden, n_hidden, activation=F.relu)) 468 | 469 | if self.mode == "lstm": 470 | self.jump = dglnn.JumpingKnowledge(mode, n_hidden, n_layers) 471 | else: 472 | self.jump = dglnn.JumpingKnowledge(mode) 473 | 474 | if self.mode == "cat": 475 | n_hidden = n_hidden * (n_layers + 1) 476 | 477 | self.output = nn.Linear(n_hidden, n_classes) 478 | self.reset_params() 479 | 480 | def reset_params(self): 481 | self.output.reset_parameters() 482 | for layers in self.layers: 483 | layers.reset_parameters() 484 | if self.mode == "lstm": 485 | self.lstm.reset_parameters() 486 | self.attn.reset_parameters() 487 | 488 | def forward(self, g, feats): 489 | feat_lst = [] 490 | for layer in self.layers: 491 | feats = self.dropout(layer(g, feats)) 492 | feat_lst.append(feats) 493 | 494 | if self.mode == "cat": 495 | out = torch.cat(feat_lst, dim=-1) 496 | elif self.mode == "max": 497 | out = torch.stack(feat_lst, dim=-1).max(dim=-1)[0] 498 | else: 499 | # lstm 500 | x = torch.stack(feat_lst, dim=1) 501 | alpha, _ = self.lstm(x) 502 | alpha = self.attn(alpha).squeeze(-1) 503 | alpha = torch.softmax(alpha, dim=-1).unsqueeze(-1) 504 | out = (x * alpha).sum(dim=1) 505 | 506 | g.ndata["h"] = out 507 | g.update_all(fn.copy_u("h", "m"), fn.sum("m", "h")) 508 | 509 | return self.output(g.ndata["h"]) 510 | 511 | 512 | class MoNet(nn.Module): 513 | def __init__( 514 | self, 515 
| in_feats, 516 | n_hidden, 517 | n_classes, 518 | n_layers, 519 | dim, 520 | n_kernels, 521 | input_drop, 522 | dropout, 523 | ): 524 | super(MoNet, self).__init__() 525 | self.input_drop = nn.Dropout(input_drop) 526 | self.layers = nn.ModuleList() 527 | self.pseudo_proj = nn.ModuleList() 528 | 529 | # Input layer 530 | self.layers.append(dglnn.GMMConv(in_feats, n_hidden, dim, n_kernels)) 531 | self.pseudo_proj.append(nn.Sequential(nn.Linear(2, dim), nn.Tanh())) 532 | 533 | # Hidden layer 534 | for _ in range(n_layers - 1): 535 | self.layers.append(dglnn.GMMConv(n_hidden, n_hidden, dim, n_kernels)) 536 | self.pseudo_proj.append(nn.Sequential(nn.Linear(2, dim), nn.Tanh())) 537 | 538 | # Output layer 539 | self.layers.append(dglnn.GMMConv(n_hidden, n_classes, dim, n_kernels)) 540 | self.pseudo_proj.append(nn.Sequential(nn.Linear(2, dim), nn.Tanh())) 541 | self.dropout = nn.Dropout(dropout) 542 | 543 | def forward(self, feat, pseudo, graph): 544 | h = feat 545 | h = self.input_drop(h) 546 | for i in range(len(self.layers)): 547 | if i != 0: 548 | h = self.dropout(h) 549 | h = self.layers[i](graph, h, self.pseudo_proj[i](pseudo)) 550 | return h 551 | 552 | 553 | class Node2vec(nn.Module): 554 | """Node2vec model from paper node2vec: Scalable Feature Learning for Networks 555 | Attributes 556 | ---------- 557 | g: DGLGraph 558 | The graph. 559 | embedding_dim: int 560 | Dimension of node embedding. 561 | walk_length: int 562 | Length of each trace. 563 | p: float 564 | Likelihood of immediately revisiting a node in the walk. Same notation as in the paper. 565 | q: float 566 | Control parameter to interpolate between breadth-first strategy and depth-first strategy. 567 | Same notation as in the paper. 568 | num_walks: int 569 | Number of random walks for each node. Default: 10. 570 | window_size: int 571 | Maximum distance between the center node and predicted node. Default: 5. 572 | num_negatives: int 573 | The number of negative samples for each positive sample. Default: 5. 574 | use_sparse: bool 575 | If set to True, use PyTorch's sparse embedding and optimizer. Default: ``True``. 576 | weight_name : str, optional 577 | The name of the edge feature tensor on the graph storing the (unnormalized) 578 | probabilities associated with each edge for choosing the next node. 579 | The feature tensor must be non-negative and the sum of the probabilities 580 | must be positive for the outbound edges of all nodes (although they don't have 581 | to sum up to one). The result will be undefined otherwise. 582 | If omitted, DGL assumes that the neighbors are picked uniformly. 
583 | """ 584 | 585 | def __init__( 586 | self, 587 | g, 588 | embedding_dim, 589 | walk_length, 590 | p, 591 | q, 592 | num_walks=10, 593 | window_size=5, 594 | num_negatives=5, 595 | use_sparse=True, 596 | weight_name=None, 597 | ): 598 | super(Node2vec, self).__init__() 599 | 600 | assert walk_length >= window_size 601 | 602 | self.g = g 603 | self.embedding_dim = embedding_dim 604 | self.walk_length = walk_length 605 | self.p = p 606 | self.q = q 607 | self.num_walks = num_walks 608 | self.window_size = window_size 609 | self.num_negatives = num_negatives 610 | self.N = self.g.num_nodes() 611 | if weight_name is not None: 612 | self.prob = weight_name 613 | else: 614 | self.prob = None 615 | 616 | self.embedding = nn.Embedding(self.N, embedding_dim, sparse=use_sparse) 617 | 618 | def reset_parameters(self): 619 | self.embedding.reset_parameters() 620 | 621 | def sample(self, batch): 622 | """ 623 | Generate positive and negative samples. 624 | Positive samples are generated from random walk 625 | Negative samples are generated from random sampling 626 | """ 627 | if not isinstance(batch, torch.Tensor): 628 | batch = torch.tensor(batch) 629 | 630 | batch = batch.repeat(self.num_walks) 631 | # positive 632 | pos_traces = node2vec_random_walk( 633 | self.g, batch, self.p, self.q, self.walk_length, self.prob 634 | ) 635 | pos_traces = pos_traces.unfold(1, self.window_size, 1) # rolling window 636 | pos_traces = pos_traces.contiguous().view(-1, self.window_size) 637 | 638 | # negative 639 | neg_batch = batch.repeat(self.num_negatives) 640 | neg_traces = torch.randint(self.N, (neg_batch.size(0), self.walk_length)) 641 | neg_traces = torch.cat([neg_batch.view(-1, 1), neg_traces], dim=-1) 642 | neg_traces = neg_traces.unfold(1, self.window_size, 1) # rolling window 643 | neg_traces = neg_traces.contiguous().view(-1, self.window_size) 644 | 645 | return pos_traces, neg_traces 646 | 647 | def forward(self, nodes=None): 648 | """ 649 | Returns the embeddings of the input nodes 650 | Parameters 651 | ---------- 652 | nodes: Tensor, optional 653 | Input nodes, if set `None`, will return all the node embedding. 654 | Returns 655 | ------- 656 | Tensor 657 | Node embedding 658 | """ 659 | emb = self.embedding.weight 660 | if nodes is None: 661 | return emb 662 | else: 663 | return emb[nodes] 664 | 665 | def loss(self, pos_trace, neg_trace): 666 | """ 667 | Computes the loss given positive and negative random walks. 
668 | Parameters 669 | ---------- 670 | pos_trace: Tensor 671 | positive random walk trace 672 | neg_trace: Tensor 673 | negative random walk trace 674 | """ 675 | e = 1e-15 676 | 677 | # Positive 678 | pos_start, pos_rest = ( 679 | pos_trace[:, 0], 680 | pos_trace[:, 1:].contiguous(), 681 | ) # start node and following trace 682 | w_start = self.embedding(pos_start).unsqueeze(dim=1) 683 | w_rest = self.embedding(pos_rest) 684 | pos_out = (w_start * w_rest).sum(dim=-1).view(-1) 685 | 686 | # Negative 687 | neg_start, neg_rest = neg_trace[:, 0], neg_trace[:, 1:].contiguous() 688 | 689 | w_start = self.embedding(neg_start).unsqueeze(dim=1) 690 | w_rest = self.embedding(neg_rest) 691 | neg_out = (w_start * w_rest).sum(dim=-1).view(-1) 692 | 693 | # compute loss 694 | pos_loss = -torch.log(torch.sigmoid(pos_out) + e).mean() 695 | neg_loss = -torch.log(1 - torch.sigmoid(neg_out) + e).mean() 696 | 697 | return pos_loss + neg_loss 698 | 699 | def loader(self, batch_size): 700 | """ 701 | Parameters 702 | ---------- 703 | batch_size: int 704 | batch size 705 | Returns 706 | ------- 707 | DataLoader 708 | Node2vec training data loader 709 | """ 710 | return DataLoader( 711 | torch.arange(self.N), 712 | batch_size=batch_size, 713 | shuffle=True, 714 | collate_fn=self.sample, 715 | ) 716 | 717 | 718 | class Node2vecModel(object): 719 | """ 720 | Wrapper of the ``Node2Vec`` class with a ``train`` method. 721 | Attributes 722 | ---------- 723 | g: DGLGraph 724 | The graph. 725 | embedding_dim: int 726 | Dimension of node embedding. 727 | walk_length: int 728 | Length of each trace. 729 | p: float 730 | Likelihood of immediately revisiting a node in the walk. 731 | q: float 732 | Control parameter to interpolate between breadth-first strategy and depth-first strategy. 733 | num_walks: int 734 | Number of random walks for each node. Default: 10. 735 | window_size: int 736 | Maximum distance between the center node and predicted node. Default: 5. 737 | num_negatives: int 738 | The number of negative samples for each positive sample. Default: 5. 739 | use_sparse: bool 740 | If set to True, uses PyTorch's sparse embedding and optimizer. Default: ``True``. 741 | weight_name : str, optional 742 | The name of the edge feature tensor on the graph storing the (unnormalized) 743 | probabilities associated with each edge for choosing the next node. 744 | The feature tensor must be non-negative and the sum of the probabilities 745 | must be positive for the outbound edges of all nodes (although they don't have 746 | to sum up to one). The result will be undefined otherwise. 747 | If omitted, DGL assumes that the neighbors are picked uniformly. Default: ``None``. 748 | eval_set: list of tuples (Tensor, Tensor) 749 | [(nodes_train,y_train),(nodes_val,y_val)] 750 | If omitted, model will not be evaluated. Default: ``None``. 751 | eval_steps: int 752 | Interval steps of evaluation. 753 | if set <= 0, model will not be evaluated. Default: ``None``. 754 | device: str 755 | device, default 'cpu'. 
756 | """ 757 | 758 | def __init__( 759 | self, 760 | g, 761 | embedding_dim, 762 | walk_length, 763 | p=1.0, 764 | q=1.0, 765 | num_walks=1, 766 | window_size=5, 767 | num_negatives=5, 768 | use_sparse=True, 769 | weight_name=None, 770 | eval_set=None, 771 | eval_steps=-1, 772 | device="cpu", 773 | ): 774 | 775 | self.model = Node2vec( 776 | g, 777 | embedding_dim, 778 | walk_length, 779 | p, 780 | q, 781 | num_walks, 782 | window_size, 783 | num_negatives, 784 | use_sparse, 785 | weight_name, 786 | ) 787 | self.g = g 788 | self.use_sparse = use_sparse 789 | self.eval_steps = eval_steps 790 | self.eval_set = eval_set 791 | 792 | if device == "cpu": 793 | self.device = device 794 | else: 795 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 796 | 797 | def _train_step(self, model, loader, optimizer, device): 798 | model.train() 799 | total_loss = 0 800 | for pos_traces, neg_traces in loader: 801 | pos_traces, neg_traces = pos_traces.to(device), neg_traces.to(device) 802 | optimizer.zero_grad() 803 | loss = model.loss(pos_traces, neg_traces) 804 | loss.backward() 805 | optimizer.step() 806 | total_loss += loss.item() 807 | return total_loss / len(loader) 808 | 809 | @torch.no_grad() 810 | def _evaluate_step(self): 811 | nodes_train, y_train = self.eval_set[0] 812 | nodes_val, y_val = self.eval_set[1] 813 | 814 | acc = self.model.evaluate(nodes_train, y_train, nodes_val, y_val) 815 | return acc 816 | 817 | def train(self, epochs, batch_size, learning_rate=0.01): 818 | """ 819 | Parameters 820 | ---------- 821 | epochs: int 822 | num of train epoch 823 | batch_size: int 824 | batch size 825 | learning_rate: float 826 | learning rate. Default 0.01. 827 | """ 828 | 829 | self.model = self.model.to(self.device) 830 | loader = self.model.loader(batch_size) 831 | if self.use_sparse: 832 | optimizer = torch.optim.SparseAdam( 833 | list(self.model.parameters()), lr=learning_rate 834 | ) 835 | else: 836 | optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate) 837 | for i in range(epochs): 838 | loss = self._train_step(self.model, loader, optimizer, self.device) 839 | if self.eval_steps > 0: 840 | if epochs % self.eval_steps == 0: 841 | acc = self._evaluate_step() 842 | print( 843 | "Epoch: {}, Train Loss: {:.4f}, Val Acc: {:.4f}".format( 844 | i, loss, acc 845 | ) 846 | ) 847 | 848 | def embedding(self, nodes=None): 849 | """ 850 | Returns the embeddings of the input nodes 851 | Parameters 852 | ---------- 853 | nodes: Tensor, optional 854 | Input nodes, if set `None`, will return all the node embedding. 855 | Returns 856 | ------- 857 | Tensor 858 | Node embedding. 
859 | """ 860 | 861 | return self.model(nodes) 862 | 863 | 864 | class GAT(torch.nn.Module): 865 | def __init__( 866 | self, 867 | in_channels, 868 | edge_dim, 869 | hidden_channels, 870 | out_channels, 871 | num_layers, 872 | heads, 873 | dropout, 874 | ): 875 | super().__init__() 876 | self.num_layers = num_layers 877 | self.dropout = dropout 878 | 879 | self.convs = torch.nn.ModuleList() 880 | self.convs.append( 881 | GATConv( 882 | in_channels, 883 | hidden_channels, 884 | heads, 885 | edge_dim=edge_dim, 886 | add_self_loops=False, 887 | ) 888 | ) 889 | for _ in range(num_layers - 2): 890 | self.convs.append( 891 | GATConv( 892 | heads * hidden_channels, 893 | hidden_channels, 894 | heads, 895 | edge_dim=edge_dim, 896 | add_self_loops=False, 897 | ) 898 | ) 899 | self.convs.append( 900 | GATConv( 901 | heads * hidden_channels, 902 | out_channels, 903 | heads, 904 | concat=False, 905 | edge_dim=edge_dim, 906 | add_self_loops=False, 907 | ) 908 | ) 909 | 910 | def reset_parameters(self): 911 | for conv in self.convs: 912 | conv.reset_parameters() 913 | 914 | def forward(self, x, edge_index): 915 | for conv in self.convs[:-1]: 916 | x = conv(x, edge_index) 917 | x = F.elu(x) 918 | x = F.dropout(x, p=self.dropout, training=self.training) 919 | x = self.convs[-1](x, edge_index) 920 | return x 921 | 922 | 923 | class GraphTransformer(torch.nn.Module): 924 | def __init__( 925 | self, in_channels, edge_dim, hidden_channels, out_channels, num_layers, dropout 926 | ): 927 | super().__init__() 928 | self.num_layers = num_layers 929 | self.dropout = dropout 930 | 931 | self.convs = torch.nn.ModuleList() 932 | self.convs.append( 933 | TransformerConv(in_channels, hidden_channels, edge_dim=edge_dim) 934 | ) 935 | for _ in range(num_layers - 2): 936 | self.convs.append( 937 | TransformerConv(hidden_channels, hidden_channels, edge_dim=edge_dim) 938 | ) 939 | self.convs.append( 940 | TransformerConv(hidden_channels, out_channels, edge_dim=edge_dim) 941 | ) 942 | 943 | def reset_parameters(self): 944 | for conv in self.convs: 945 | conv.reset_parameters() 946 | 947 | def forward(self, x, edge_index): 948 | for conv in self.convs[:-1]: 949 | x = conv(x, edge_index) 950 | x = F.relu(x) 951 | x = F.dropout(x, p=self.dropout, training=self.training) 952 | x = self.convs[-1](x, edge_index) 953 | return x 954 | 955 | 956 | class GINE(torch.nn.Module): 957 | def __init__( 958 | self, in_channels, edge_dim, hidden_channels, out_channels, num_layers, dropout 959 | ): 960 | super().__init__() 961 | self.num_layers = num_layers 962 | self.dropout = dropout 963 | 964 | self.convs = torch.nn.ModuleList() 965 | self.convs.append( 966 | GINEConv(Linear(in_channels, hidden_channels), edge_dim=edge_dim) 967 | ) 968 | for _ in range(num_layers - 2): 969 | self.convs.append( 970 | GINEConv(Linear(hidden_channels, hidden_channels), edge_dim=edge_dim) 971 | ) 972 | self.convs.append( 973 | GINEConv(Linear(hidden_channels, out_channels), edge_dim=edge_dim) 974 | ) 975 | 976 | def reset_parameters(self): 977 | for conv in self.convs: 978 | conv.reset_parameters() 979 | 980 | def forward(self, x, edge_index): 981 | for conv in self.convs[:-1]: 982 | x = conv(x, edge_index) 983 | x = F.relu(x) 984 | x = F.dropout(x, p=self.dropout, training=self.training) 985 | x = self.convs[-1](x, edge_index) 986 | return x 987 | 988 | 989 | class EdgeConvConv(MessagePassing): 990 | def __init__( 991 | self, 992 | nn: Callable, 993 | eps: float = 0.0, 994 | train_eps: bool = False, 995 | edge_dim: Optional[int] = None, 996 | **kwargs, 997 | 
): 998 | kwargs.setdefault("aggr", "add") 999 | super().__init__(**kwargs) 1000 | self.nn = nn 1001 | self.initial_eps = eps 1002 | if train_eps: 1003 | self.eps = torch.nn.Parameter(torch.Tensor([eps])) 1004 | else: 1005 | self.register_buffer("eps", torch.Tensor([eps])) 1006 | if edge_dim is not None: 1007 | if hasattr(self.nn, "in_features"): 1008 | in_channels = self.nn.in_features 1009 | else: 1010 | in_channels = self.nn.in_channels 1011 | self.lin = Linear(edge_dim, in_channels) 1012 | else: 1013 | self.lin = None 1014 | 1015 | def forward( 1016 | self, 1017 | x: Union[Tensor, OptPairTensor], 1018 | edge_index: Adj, 1019 | edge_attr: OptTensor = None, 1020 | size: Size = None, 1021 | ) -> Tensor: 1022 | """""" 1023 | if isinstance(x, Tensor): 1024 | x: OptPairTensor = (x, x) 1025 | 1026 | # propagate_type: (x: OptPairTensor, edge_attr: OptTensor) 1027 | out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size) 1028 | 1029 | return out 1030 | 1031 | def message(self, x_i: Tensor, x_j: Tensor, edge_attr: Tensor) -> Tensor: 1032 | 1033 | temp = torch.cat([x_i, x_j, edge_attr], dim=1) 1034 | 1035 | return self.nn(temp) 1036 | 1037 | def __repr__(self) -> str: 1038 | return f"{self.__class__.__name__}(nn={self.nn})" 1039 | 1040 | 1041 | class EdgeConv(torch.nn.Module): 1042 | def __init__( 1043 | self, in_channels, edge_dim, hidden_channels, out_channels, num_layers, dropout 1044 | ): 1045 | super().__init__() 1046 | self.num_layers = num_layers 1047 | self.dropout = dropout 1048 | 1049 | self.convs = torch.nn.ModuleList() 1050 | self.convs.append( 1051 | EdgeConvConv( 1052 | Linear(2 * in_channels + edge_dim, hidden_channels), edge_dim=edge_dim 1053 | ) 1054 | ) 1055 | for _ in range(num_layers - 2): 1056 | self.convs.append( 1057 | EdgeConvConv( 1058 | Linear(2 * hidden_channels + edge_dim, hidden_channels), 1059 | edge_dim=edge_dim, 1060 | ) 1061 | ) 1062 | self.convs.append( 1063 | EdgeConvConv( 1064 | Linear(2 * hidden_channels + edge_dim, out_channels), edge_dim=edge_dim 1065 | ) 1066 | ) 1067 | 1068 | def reset_parameters(self): 1069 | for conv in self.convs: 1070 | conv.reset_parameters() 1071 | 1072 | def forward(self, x, edge_index): 1073 | for conv in self.convs[:-1]: 1074 | x = conv(x, edge_index) 1075 | x = F.relu(x) 1076 | x = F.dropout(x, p=self.dropout, training=self.training) 1077 | x = self.convs[-1](x, edge_index) 1078 | return x 1079 | 1080 | 1081 | class GeneralGNN(torch.nn.Module): 1082 | def __init__( 1083 | self, in_channels, edge_dim, hidden_channels, out_channels, num_layers, dropout 1084 | ): 1085 | super().__init__() 1086 | self.num_layers = num_layers 1087 | self.dropout = dropout 1088 | 1089 | self.convs = torch.nn.ModuleList() 1090 | self.convs.append( 1091 | GeneralConv(in_channels, hidden_channels, in_edge_channels=edge_dim) 1092 | ) 1093 | for _ in range(num_layers - 2): 1094 | self.convs.append( 1095 | GeneralConv(hidden_channels, hidden_channels, in_edge_channels=edge_dim) 1096 | ) 1097 | self.convs.append( 1098 | GeneralConv(hidden_channels, out_channels, in_edge_channels=edge_dim) 1099 | ) 1100 | 1101 | def reset_parameters(self): 1102 | for conv in self.convs: 1103 | conv.reset_parameters() 1104 | 1105 | def forward(self, x, edge_index): 1106 | for conv in self.convs[:-1]: 1107 | x = conv(x, edge_index) 1108 | x = F.relu(x) 1109 | x = F.dropout(x, p=self.dropout, training=self.training) 1110 | x = self.convs[-1](x, edge_index) 1111 | return x 1112 | --------------------------------------------------------------------------------
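The Evaluator in GNN/model/utils.py mirrors the OGB link-prediction evaluators: for hits@K every positive score is ranked against one shared 1-D pool of negative scores, while for MRR each positive edge is ranked against its own row of negatives, so y_pred_neg must be 2-D. The sketch below is a minimal usage example, not repository code; it assumes the package is importable as GNN.model.utils and that np/th in that module are the usual numpy/torch imports. The "History" and "DBLP" names come from the meta_info table above; the tensor sizes and scores are made up for illustration.

import numpy as np
import torch as th
from GNN.model.utils import Evaluator  # assumed import path for this repo layout

# Hits@50 ("History" entry): one shared pool of negatives for all positive edges.
hits_evaluator = Evaluator(name="History")   # eval_metric becomes "hits@50", so K = 50
y_pred_pos = th.rand(1000)                   # positive-edge scores, shape (num_edge,)
y_pred_neg = th.rand(20000)                  # shared negative-edge scores, shape (num_neg,)
print(hits_evaluator.eval({"y_pred_pos": y_pred_pos, "y_pred_neg": y_pred_neg}))  # -> {'hits@50': ...}

# MRR ("DBLP" entry): each positive edge is ranked against its own negatives, so y_pred_neg is 2-D.
mrr_evaluator = Evaluator(name="DBLP")
y_pred_pos = np.random.rand(1000)            # shape (num_edge,)
y_pred_neg = np.random.rand(1000, 100)       # shape (num_edge, num_neg_per_edge)
mrr_list = mrr_evaluator.eval({"y_pred_pos": y_pred_pos, "y_pred_neg": y_pred_neg})["mrr_list"]
print(float(mrr_list.mean()))                # mean reciprocal rank over the batch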
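The DGL-based classes in GNN/model/GNN_library.py follow a common pattern: stack n_layers convolutions, apply batch norm, activation, and dropout between them, and return raw logits. As a quick smoke test, the sketch below instantiates the GraphSAGE class on a random DGL graph; the graph, feature size, and hyperparameters are invented for illustration, and the import path is assumed to match the repository layout rather than guaranteed by it.

import dgl
import torch
import torch.nn.functional as F
from GNN.model.GNN_library import GraphSAGE  # assumed import path

# toy homogeneous graph with self-loops and random node features
g = dgl.add_self_loop(dgl.rand_graph(100, 500))
feat = torch.randn(100, 64)

model = GraphSAGE(
    in_feats=64,            # input feature size
    n_hidden=128,           # hidden size of the intermediate SAGEConv layers
    n_classes=10,           # output logits, one per class
    n_layers=2,             # SAGEConv layers, with BatchNorm/ReLU/dropout in between
    activation=F.relu,
    dropout=0.5,
    aggregator_type="mean",
)

logits = model(g, feat)     # forward(graph, feat)
print(logits.shape)         # torch.Size([100, 10])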