├── GNN
│   ├── __init__.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── utils.py
│   │   ├── GNN_arg.py
│   │   ├── Dataloader.py
│   │   └── GNN_library.py
│   ├── readme.md
│   ├── GNN_Link_AUC.py
│   ├── GNN_Node.py
│   ├── GNN_Link_MRR.py
│   ├── GNN_Link_HitsK.py
│   ├── GNN_Link_MRR_loader.py
│   └── GNN_Link_HitsK_loader.py
├── assets
│   ├── logo.png
│   ├── dataset.png
│   ├── example.png
│   └── feature.png
├── run_gnn_node.sh
├── run_gnn_link.sh
├── license
├── run_emb.sh
├── .gitignore
├── requirements.txt
├── data_preprocess
│   ├── generate_llm_emb.py
│   └── dataset.ipynb
└── README.md
/GNN/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/GNN/model/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/assets/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhuofeng-Li/TEG-Benchmark/HEAD/assets/logo.png
--------------------------------------------------------------------------------
/assets/dataset.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhuofeng-Li/TEG-Benchmark/HEAD/assets/dataset.png
--------------------------------------------------------------------------------
/assets/example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhuofeng-Li/TEG-Benchmark/HEAD/assets/example.png
--------------------------------------------------------------------------------
/assets/feature.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhuofeng-Li/TEG-Benchmark/HEAD/assets/feature.png
--------------------------------------------------------------------------------
/run_gnn_node.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python GNN/GNN_Node.py \
4 | --use_PLM_node Dataset/goodreads_children/emb/children_bert_base_uncased_512_cls_node.pt \
5 | --use_PLM_edge Dataset/goodreads_children/emb/children_bert_base_uncased_512_cls_edge.pt \
6 | --graph_path Dataset/goodreads_children/processed/children.pkl
7 |
8 | # python GNN/GNN_Node.py \
9 | # --use_PLM_node Dataset/twitter/emb/twitter_bert_base_uncased_512_cls_node.pt \
10 | # --use_PLM_edge Dataset/twitter/emb/twitter_bert_base_uncased_512_cls_edge.pt \
11 | # --graph_path Dataset/twitter/processed/twitter.pkl
12 |
13 |
14 |
--------------------------------------------------------------------------------
/run_gnn_link.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Define datasets
4 | datasets=(
5 | # "goodreads_children"
6 | # "twitter"
7 | # "goodreads_history"
8 | "goodreads_crime"
9 | )
10 |
11 | # Define models to run
12 | models=(
13 | "GNN_Link_HitsK_loader.py"
14 | # "GNN_Link_MRR_loader.py"
15 | # "GNN_Link_AUC.py"
16 | )
17 |
18 | # Loop through datasets and models; ${dataset##*_} keeps only the suffix after the last "_" (e.g. goodreads_crime -> crime)
19 | for dataset in "${datasets[@]}"; do
20 | for model in "${models[@]}"; do
21 | echo "Running $model on $dataset dataset"
22 | python GNN/$model \
23 | --use_PLM_node Dataset/$dataset/emb/${dataset##*_}_bert_base_uncased_512_cls_node.pt \
24 | --use_PLM_edge Dataset/$dataset/emb/${dataset##*_}_bert_base_uncased_512_cls_edge.pt \
25 | --path Dataset/$dataset/LinkPrediction/ \
26 | --graph_path Dataset/$dataset/processed/${dataset##*_}.pkl \
27 | --batch_size 1024
28 | done
29 | done
30 |
--------------------------------------------------------------------------------
/license:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 TAG-Benchmark Team
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/GNN/model/utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from ogb.linkproppred import *
4 | from ogb.nodeproppred import *
5 |
6 | from dgl.data import CitationGraphDataset
7 |
8 |
9 | def load_graph(name):
10 | cite_graphs = ["cora", "citeseer", "pubmed"]
11 |
12 | if name in cite_graphs:
13 | dataset = CitationGraphDataset(name)
14 | graph = dataset[0]
15 |
16 | nodes = graph.nodes()
17 | y = graph.ndata["label"]
18 | train_mask = graph.ndata["train_mask"]
19 | val_mask = graph.ndata["test_mask"]
20 |
21 | nodes_train, y_train = nodes[train_mask], y[train_mask]
22 | nodes_val, y_val = nodes[val_mask], y[val_mask]
23 | eval_set = [(nodes_train, y_train), (nodes_val, y_val)]
24 |
25 | elif name.startswith("ogbn"):
26 |
27 | dataset = DglNodePropPredDataset(name)
28 | graph, y = dataset[0]
29 | split_nodes = dataset.get_idx_split()
30 | nodes = graph.nodes()
31 |
32 | train_idx = split_nodes["train"]
33 | val_idx = split_nodes["valid"]
34 |
35 | nodes_train, y_train = nodes[train_idx], y[train_idx]
36 | nodes_val, y_val = nodes[val_idx], y[val_idx]
37 | eval_set = [(nodes_train, y_train), (nodes_val, y_val)]
38 |
39 | else:
40 | raise ValueError("Dataset name error!")
41 |
42 | return graph, eval_set
43 |
44 |
45 | def parse_arguments():
46 | """
47 | Parse arguments
48 | """
49 | parser = argparse.ArgumentParser(description="Node2vec")
50 | parser.add_argument("--dataset", type=str, default="cora")
51 | # 'train' for training node2vec model, 'time' for testing speed of random walk
52 | parser.add_argument("--task", type=str, default="train")
53 | parser.add_argument("--runs", type=int, default=10)
54 | parser.add_argument("--device", type=str, default="cpu")
55 | parser.add_argument("--embedding_dim", type=int, default=128)
56 | parser.add_argument("--walk_length", type=int, default=50)
57 | parser.add_argument("--p", type=float, default=0.25)
58 | parser.add_argument("--q", type=float, default=4.0)
59 | parser.add_argument("--num_walks", type=int, default=10)
60 | parser.add_argument("--epochs", type=int, default=100)
61 | parser.add_argument("--batch_size", type=int, default=128)
62 |
63 | args = parser.parse_args()
64 |
65 | return args
--------------------------------------------------------------------------------
/run_emb.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 | # # Define an array of dataset configurations
5 | # datasets=(
6 | # "goodreads_children:children"
7 | # "goodreads_history:history"
8 | # "goodreads_comics:comics"
9 | # "goodreads_crime:crime"
10 | # # "twitter:twitter"
11 | # # "amazon_movie:movie"
12 | # # "amazon_baby:baby"
13 | # # "reddit:reddit"
14 | # )
15 |
16 | # # Loop through each dataset
17 | # for dataset in "${datasets[@]}"; do
18 | # # Split the dataset string into folder and name
19 | # IFS=":" read -r folder name <<< "$dataset"
20 |
21 | # echo "Processing dataset: $folder"
22 |
23 | # python data_preprocess/generate_llm_emb.py \
24 | # --pkl_file "Dataset/$folder/processed/${name}.pkl" \
25 | # --path "Dataset/$folder/emb" \
26 | # --name "$name" \
27 | # --model_name meta-llama/Meta-Llama-3-8B \
28 | # --max_length 2048 \
29 | # --batch_size 25
30 | # done
31 |
32 |
33 | # python data_preprocess/generate_llm_emb.py \
34 | # --pkl_file Dataset/arxiv/processed/arxiv.pkl \
35 | # --path Dataset/arxiv/emb \
36 | # --name arxiv \
37 | # --model_name bert-base-uncased \
38 |
39 | # python data_preprocess/generate_llm_emb.py \
40 | # --pkl_file Dataset/arxiv/processed/arxiv.pkl \
41 | # --path Dataset/arxiv/emb \
42 | # --name arxiv \
43 | # --model_name bert-large-uncased \
44 | # --batch_size 250
45 |
46 | # python data_preprocess/generate_llm_emb.py \
47 | # --pkl_file Dataset/twitter/processed/twitter.pkl \
48 | # --path Dataset/twitter/emb \
49 | # --name tweets \
50 | # --model_name bert-base-uncased \
51 |
52 | python data_preprocess/generate_llm_emb.py \
53 | --pkl_file Dataset/twitter/processed/twitter.pkl \
54 | --path Dataset/twitter/emb \
55 | --name tweets \
56 | 	--model_name bert-base-uncased
57 |
58 | python data_preprocess/generate_llm_emb.py \
59 | --pkl_file Dataset/twitter/processed/twitter.pkl \
60 | --path Dataset/twitter/emb \
61 | --name tweets \
62 | --model_name bert-large-uncased \
63 | --batch_size 250
64 |
65 | # python data_preprocess/generate_llm_emb.py \
66 | # --pkl_file Dataset/amazon_movie/processed/movie.pkl \
67 | # --path Dataset/amazon_movie/emb \
68 | # --name movie \
69 | # --model_name bert-base-uncased \
70 |
71 |
72 | # python data_preprocess/generate_llm_emb.py \
73 | # --pkl_file Dataset/amazon_baby/processed/baby.pkl \
74 | # --path Dataset/amazon_baby/emb \
75 | # --name baby \
76 | # --model_name bert-base-uncased \
77 |
78 | # python data_preprocess/generate_llm_emb.py \
79 | # --pkl_file Dataset/reddit/processed/reddit.pkl \
80 | # --path Dataset/reddit/emb \
81 | # --name reddit \
82 | # --model_name bert-large-uncased \
83 |
--------------------------------------------------------------------------------
/GNN/readme.md:
--------------------------------------------------------------------------------
1 | ## Edge-aware GNN Link Prediction
2 |
3 | ### MRR Metric
4 | You can use the following command to run full-batch training for link prediction with the MRR metric.
5 | ```
6 | python GNN/GNN_Link_MRR.py \
7 | --use_PLM_node Dataset/goodreads_children/emb/children_bert_base_uncased_512_cls_node.pt \
8 | --use_PLM_edge Dataset/goodreads_children/emb/children_bert_base_uncased_512_cls_edge.pt \
9 | --path Dataset/goodreads_children/LinkPrediction/ \
10 | --graph_path Dataset/goodreads_children/processed/children.pkl
11 | ```
12 |
13 | ### MRR Metric using neighbor sampling
14 | You can use the following command, which trains with neighbor sampling, to handle MRR link prediction on large-scale graphs.
15 |
16 | ```
17 | python GNN/GNN_Link_MRR_loader.py \
18 | --use_PLM_node Dataset/goodreads_children/emb/children_bert_base_uncased_512_cls_node.pt \
19 | --use_PLM_edge Dataset/goodreads_children/emb/children_bert_base_uncased_512_cls_edge.pt \
20 | --path Dataset/goodreads_children/LinkPrediction/ \
21 | --graph_path Dataset/goodreads_children/processed/children.pkl
22 | ```
23 |
24 | ### HitsK Metric
25 | You can use the following command to run full-batch training for link prediction with the Hits@K metric.
26 | ```
27 | python GNN/GNN_Link_HitsK.py \
28 | --use_PLM_node Dataset/goodreads_children/emb/children_bert_base_uncased_512_cls_node.pt \
29 | --use_PLM_edge Dataset/goodreads_children/emb/children_bert_base_uncased_512_cls_edge.pt \
30 | --path Dataset/goodreads_children/LinkPrediction/ \
31 | --graph_path Dataset/goodreads_children/processed/children.pkl
32 | ```
33 |
34 | ### HitsK Metric using neighbor sampling
35 | You can use the following command, which trains with neighbor sampling, to handle Hits@K link prediction on large-scale graphs.
36 |
37 | ```
38 | python GNN/GNN_Link_HitsK_loader.py \
39 | --use_PLM_node Dataset/goodreads_children/emb/children_bert_base_uncased_512_cls_node.pt \
40 | --use_PLM_edge Dataset/goodreads_children/emb/children_bert_base_uncased_512_cls_edge.pt \
41 | --path Dataset/goodreads_children/LinkPrediction/ \
42 | --graph_path Dataset/goodreads_children/processed/children.pkl
43 | ```
44 |
45 | ### AUC Metric using neighbor sampling
46 | You can use the following command, which trains with neighbor sampling, to handle AUC link prediction on large-scale graphs.
47 |
48 | ```
49 | python GNN/GNN_Link_AUC.py \
50 | --use_PLM_node Dataset/goodreads_children/emb/children_bert_base_uncased_512_cls_node.pt \
51 | --use_PLM_edge Dataset/goodreads_children/emb/children_bert_base_uncased_512_cls_edge.pt \
52 | --graph_path Dataset/goodreads_children/processed/children.pkl
53 | ```
54 |
55 | ## Edge-aware GNN Node Classification
56 | You can use the following command to run node classification:
57 | ```
58 | python GNN/GNN_Node.py \
59 | --use_PLM_node Dataset/goodreads_children/emb/children_bert_base_uncased_512_cls_node.pt \
60 | --use_PLM_edge Dataset/goodreads_children/emb/children_bert_base_uncased_512_cls_edge.pt \
61 | --graph_path Dataset/goodreads_children/processed/children.pkl
62 | ```
63 |
64 |
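### Choosing the GNN backbone
The training scripts select their GNN encoder through a `--gnn_model` flag (see `gen_model()` in, e.g., `GNN/GNN_Link_AUC.py` and `GNN/GNN_Node.py`). The options wired up there from `GNN/model/GNN_library.py` are `GAT` (the default), `GraphTransformer`, `GINE`, and `GeneralGNN`. For example, to run the AUC script with `GINE` (a sketch reusing the paths from the commands above):

```
python GNN/GNN_Link_AUC.py \
	--gnn_model GINE \
	--use_PLM_node Dataset/goodreads_children/emb/children_bert_base_uncased_512_cls_node.pt \
	--use_PLM_edge Dataset/goodreads_children/emb/children_bert_base_uncased_512_cls_edge.pt \
	--graph_path Dataset/goodreads_children/processed/children.pkl
```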
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | GNN/__pycache__/
4 | GNN/wandb/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | env/
14 | venv/
15 | ENV/
16 | env.bak/
17 | venv.bak/
18 | bin/
19 | build/
20 | develop-eggs/
21 | dist/
22 | eggs/
23 | lib/
24 | lib64/
25 | parts/
26 | sdist/
27 | var/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | .hypothesis/
54 | .pytest_cache/
55 |
56 | # Jupyter Notebook
57 | .ipynb_checkpoints
58 |
59 | # IPython
60 | profile_default/
61 | ipython_config.py
62 |
63 | # pyenv
64 | .python-version
65 |
66 | # Environments
67 | .env
68 | .venv
69 | env/
70 | venv/
71 | ENV/
72 | env.bak/
73 | venv.bak/
74 |
75 | # Spyder project settings
76 | .spyderproject
77 | .spyproject
78 |
79 | # Rope project settings
80 | .ropeproject
81 |
82 | # mkdocs documentation
83 | /site
84 |
85 | # mypy
86 | .mypy_cache/
87 |
88 | # Pyre type checker
89 | .pyre/
90 |
91 | # pytype static type analyzer
92 | .pytype/
93 |
94 | # Cython debug symbols
95 | cython_debug/
96 |
97 | # Django stuff:
98 | *.log
99 | local_settings.py
100 | db.sqlite3
101 | db.sqlite3-journal
102 |
103 | # Flask stuff:
104 | instance/
105 | .webassets-cache
106 |
107 | # Scrapy stuff:
108 | .scrapy
109 |
110 | # Sphinx documentation
111 | docs/_build/
112 |
113 | # PyBuilder
114 | target/
115 |
116 | # VS Code settings
117 | .vscode/
118 | .vscode/settings.json
119 |
120 | # Python-specific user settings
121 | python_settings.json
122 |
123 | # Data science
124 | *.csv
125 | *.tsv
126 | *.dat
127 | *.np[yz]
128 |
129 | # Notebook checkpoints
130 | .ipynb_checkpoints
131 |
132 | # VS Code
133 | .vscode/
134 |
135 | # Visual Studio Code
136 | .vscode/
137 |
138 | # PyCharm
139 | .idea/
140 | *.iml
141 |
142 | # Project-specific
143 | *.pkl
144 | *.npy
145 | *.pt
146 | *.pkl
147 | nx_graph.pkl
148 |
149 | # Jupyter
150 | .ipynb_checkpoints
151 |
152 | # Temporary files
153 | *.tmp
154 | *.temp
155 | *.swp
156 | *.swo
157 | *.bak
158 | *~
159 |
160 | example/linkproppred/*/*_dataset/
161 | example/nodeproppred/*/*_dataset/
162 | todo.txt
163 | TAG_Benchmark/example/nodeproppred/citation/node_classification_log/tmp.ipynb
164 | example/linkproppred/*/*/*_dataset/
165 | example/nodeproppred/*/*/*_dataset/
166 | Dataset
167 | GNN/GNN_Node.py
168 | wandb/
169 | tmp.ipynb
170 | tmp_link.sh
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==1.0.1
2 | aiohappyeyeballs==2.4.3
3 | aiohttp==3.10.10
4 | aiosignal==1.3.1
5 | annotated-types==0.7.0
6 | anyio==4.6.2.post1
7 | argon2-cffi==23.1.0
8 | argon2-cffi-bindings==21.2.0
9 | arrow==1.3.0
10 | asttokens==2.4.1
11 | async-lru==2.0.4
12 | async-timeout==4.0.3
13 | attrs==24.2.0
14 | Automat==20.2.0
15 | babel==2.16.0
16 | bcrypt==3.2.0
17 | beautifulsoup4==4.12.3
18 | bleach==6.1.0
19 | blinker==1.4
20 | certifi==2020.6.20
21 | cffi==1.17.1
22 | chardet==4.0.0
23 | charset-normalizer==3.4.0
24 | click==8.0.3
25 | cloud-init==24.3.1
26 | colorama==0.4.4
27 | comm==0.2.2
28 | command-not-found==0.3
29 | configobj==5.0.6
30 | constantly==15.1.0
31 | contourpy==1.3.0
32 | cryptography==3.4.8
33 | cycler==0.12.1
34 | datasets==3.0.1
35 | dbus-python==1.2.18
36 | debugpy==1.8.7
37 | decorator==5.1.1
38 | defusedxml==0.7.1
39 | dgl==2.4.0+cu124
40 | dill==0.3.8
41 | distro==1.7.0
42 | distro-info==1.1+ubuntu0.2
43 | docker-pycreds==0.4.0
44 | exceptiongroup==1.2.2
45 | executing==2.1.0
46 | fastjsonschema==2.20.0
47 | filelock==3.16.1
48 | fonttools==4.54.1
49 | fqdn==1.5.1
50 | frozenlist==1.4.1
51 | fsspec==2024.6.1
52 | gitdb==4.0.11
53 | GitPython==3.1.43
54 | h11==0.14.0
55 | httpcore==1.0.6
56 | httplib2==0.20.2
57 | httpx==0.27.2
58 | huggingface-hub==0.26.0
59 | hyperlink==21.0.0
60 | idna==3.3
61 | importlib-metadata==4.6.4
62 | incremental==21.3.0
63 | ipykernel==6.29.5
64 | ipython==8.28.0
65 | ipywidgets==8.1.5
66 | isoduration==20.11.0
67 | jedi==0.19.1
68 | jeepney==0.7.1
69 | Jinja2==3.0.3
70 | joblib==1.4.2
71 | json5==0.9.25
72 | jsonpatch==1.32
73 | jsonpointer==2.0
74 | jsonschema==4.23.0
75 | jsonschema-specifications==2024.10.1
76 | jupyter==1.1.1
77 | jupyter-console==6.6.3
78 | jupyter-events==0.10.0
79 | jupyter-lsp==2.2.5
80 | jupyter_client==8.6.3
81 | jupyter_core==5.7.2
82 | jupyter_server==2.14.2
83 | jupyter_server_terminals==0.5.3
84 | jupyterlab==4.2.5
85 | jupyterlab_pygments==0.3.0
86 | jupyterlab_server==2.27.3
87 | jupyterlab_widgets==3.0.13
88 | keyring==23.5.0
89 | kiwisolver==1.4.7
90 | launchpadlib==1.10.16
91 | lazr.restfulclient==0.14.4
92 | lazr.uri==1.0.6
93 | littleutils==0.2.4
94 | MarkupSafe==2.0.1
95 | matplotlib==3.9.2
96 | matplotlib-inline==0.1.7
97 | mistune==3.0.2
98 | more-itertools==8.10.0
99 | mpmath==1.3.0
100 | multidict==6.1.0
101 | multiprocess==0.70.16
102 | nbclient==0.10.0
103 | nbconvert==7.16.4
104 | nbformat==5.10.4
105 | nest-asyncio==1.6.0
106 | netifaces==0.11.0
107 | networkx==3.4.1
108 | notebook==7.2.2
109 | notebook_shim==0.2.4
110 | numpy==2.1.2
111 | nvidia-cublas-cu12==12.1.3.1
112 | nvidia-cuda-cupti-cu12==12.1.105
113 | nvidia-cuda-nvrtc-cu12==12.1.105
114 | nvidia-cuda-runtime-cu12==12.1.105
115 | nvidia-cudnn-cu12==9.1.0.70
116 | nvidia-cufft-cu12==11.0.2.54
117 | nvidia-curand-cu12==10.3.2.106
118 | nvidia-cusolver-cu12==11.4.5.107
119 | nvidia-cusparse-cu12==12.1.0.106
120 | nvidia-nccl-cu12==2.20.5
121 | nvidia-nvjitlink-cu12==12.4.127
122 | nvidia-nvtx-cu12==12.1.105
123 | oauthlib==3.2.0
124 | ogb==1.3.6
125 | outdated==0.2.2
126 | overrides==7.7.0
127 | packaging==24.1
128 | pandas==2.2.3
129 | pandocfilters==1.5.1
130 | parso==0.8.4
131 | pexpect==4.8.0
132 | pillow==11.0.0
133 | platformdirs==4.3.6
134 | prometheus_client==0.21.0
135 | prompt_toolkit==3.0.48
136 | propcache==0.2.0
137 | protobuf==5.28.2
138 | psutil==6.1.0
139 | ptyprocess==0.7.0
140 | pure_eval==0.2.3
141 | pyarrow==17.0.0
142 | pyasn1==0.4.8
143 | pyasn1-modules==0.2.1
144 | pycparser==2.22
145 | pydantic==2.9.2
146 | pydantic_core==2.23.4
147 | pyg-lib==0.4.0+pt24cu124
148 | Pygments==2.18.0
149 | PyGObject==3.42.1
150 | PyHamcrest==2.0.2
151 | PyJWT==2.3.0
152 | pyOpenSSL==21.0.0
153 | pyparsing==2.4.7
154 | pyrsistent==0.18.1
155 | pyserial==3.5
156 | python-apt==2.4.0+ubuntu4
157 | python-dateutil==2.9.0.post0
158 | python-debian==0.1.43+ubuntu1.1
159 | python-json-logger==2.0.7
160 | python-magic==0.4.24
161 | pytz==2022.1
162 | PyYAML==5.4.1
163 | pyzmq==26.2.0
164 | referencing==0.35.1
165 | regex==2024.9.11
166 | requests==2.32.3
167 | rfc3339-validator==0.1.4
168 | rfc3986-validator==0.1.1
169 | rpds-py==0.20.0
170 | safetensors==0.4.5
171 | scikit-learn==1.5.2
172 | scipy==1.14.1
173 | screen-resolution-extra==0.0.0
174 | SecretStorage==3.3.1
175 | Send2Trash==1.8.3
176 | sentry-sdk==2.17.0
177 | service-identity==18.1.0
178 | setproctitle==1.3.3
179 | six==1.16.0
180 | smmap==5.0.1
181 | sniffio==1.3.1
182 | sos==4.5.6
183 | soupsieve==2.6
184 | ssh-import-id==5.11
185 | stack-data==0.6.3
186 | sympy==1.13.1
187 | systemd-python==234
188 | terminado==0.18.1
189 | threadpoolctl==3.5.0
190 | tinycss2==1.3.0
191 | tokenizers==0.20.1
192 | tomli==2.0.2
193 | torch==2.4.0
194 | torch-geometric==2.6.1
195 | torch_cluster==1.6.3+pt24cu124
196 | torch_scatter==2.1.2+pt24cu124
197 | torch_sparse==0.6.18+pt24cu124
198 | torch_spline_conv==1.2.2+pt24cu124
199 | tornado==6.4.1
200 | tqdm==4.66.5
201 | traitlets==5.14.3
202 | transformers==4.45.2
203 | triton==3.0.0
204 | Twisted==22.1.0
205 | types-python-dateutil==2.9.0.20241003
206 | typing_extensions==4.12.2
207 | tzdata==2024.2
208 | ubuntu-pro-client==8001
209 | ufw==0.36.1
210 | unattended-upgrades==0.1
211 | uri-template==1.3.0
212 | urllib3==2.2.3
213 | wadllib==1.3.6
214 | wandb==0.18.5
215 | wcwidth==0.2.13
216 | webcolors==24.8.0
217 | webencodings==0.5.1
218 | websocket-client==1.8.0
219 | widgetsnbextension==4.0.13
220 | xkit==0.0.0
221 | xxhash==3.5.0
222 | yarl==1.15.4
223 | zipp==1.0.0
224 | zope.interface==5.4.0
225 |
--------------------------------------------------------------------------------
/data_preprocess/generate_llm_emb.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import pickle
3 | import numpy as np
4 | import pandas as pd
5 | import torch
6 | import torch.nn.functional as F
7 | import os
8 | from sklearn.decomposition import PCA
9 | from transformers import (
10 | AutoTokenizer,
11 | AutoModel,
12 | TrainingArguments,
13 | PreTrainedModel,
14 | Trainer,
15 | )
16 | from transformers.modeling_outputs import TokenClassifierOutput
17 | from datasets import Dataset
18 |
19 |
20 | class CLSEmbInfModel(PreTrainedModel):
21 | def __init__(self, model):
22 | super().__init__(model.config)
23 | self.encoder = model
24 |
25 | @torch.no_grad()
26 | def forward(self, input_ids, attention_mask):
27 |         # Extract outputs from the encoder
28 | outputs = self.encoder(input_ids, attention_mask, output_hidden_states=True)
29 | node_cls_emb = outputs.last_hidden_state[:, 0, :] # Last layer
30 | return TokenClassifierOutput(logits=node_cls_emb)
31 |
32 |
33 |
34 | def main():
35 | parser = argparse.ArgumentParser(
36 | description="Process node and edge text data and save the overall representation as .pt files."
37 | )
38 | parser.add_argument(
39 | "--pkl_file",
40 | default="twitter/processed/twitter.pkl",
41 | type=str,
42 | help="Path to the Textual-Edge Graph .pkl file",
43 | )
44 | parser.add_argument(
45 | "--model_name",
46 | type=str,
47 | default="bert-base-uncased",
48 | help="Name or path of the Huggingface model",
49 | )
50 | parser.add_argument("--tokenizer_name", type=str, default="bert-base-uncased")
51 | parser.add_argument(
52 | "--name", type=str, default="twitter", help="Prefix name for the NPY file"
53 | )
54 | parser.add_argument(
55 | "--path", type=str, default="twitter/emb", help="Path to the .pt File"
56 | )
57 | parser.add_argument(
58 | "--pretrain_path", type=str, default=None, help="Path to the NPY File"
59 | )
60 | parser.add_argument(
61 | "--max_length",
62 | type=int,
63 | default=512,
64 | help="Maximum length of the text for language models",
65 | )
66 | parser.add_argument(
67 | "--batch_size",
68 | type=int,
69 | default=500,
70 | help="Number of batch size for inference",
71 | )
72 | parser.add_argument("--fp16", type=bool, default=True, help="if fp16")
73 | parser.add_argument(
74 | "--cls",
75 | action="store_true",
76 | default=True,
77 | help="whether use cls token to represent the whole text",
78 | )
79 | parser.add_argument(
80 | "--nomask",
81 | action="store_true",
82 |         help="whether to skip the attention mask when computing the mean pooling",
83 | )
84 | parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use for computation")
85 |
86 |
87 | parser.add_argument("--norm", type=bool, default=False, help="nomic use True")
88 |
89 |     # Parse command-line arguments
90 | args = parser.parse_args()
91 | pkl_file = args.pkl_file
92 | model_name = args.model_name
93 | name = args.name
94 | max_length = args.max_length
95 | batch_size = args.batch_size
96 |
97 | tokenizer_name = args.tokenizer_name
98 |
99 | root_dir = os.path.dirname(os.path.abspath(__file__))
100 | base_dir = os.path.dirname(root_dir.rstrip("/"))
101 | Feature_path = os.path.join(base_dir, args.path)
102 | cache_path = f"{Feature_path}cache/"
103 |
104 | if not os.path.exists(Feature_path):
105 | os.makedirs(Feature_path)
106 | if not os.path.exists(cache_path):
107 | os.makedirs(cache_path)
108 | print("Embedding Model:", model_name)
109 |
110 | if args.pretrain_path is not None:
111 | output_file = os.path.join(
112 | Feature_path
113 | , name
114 | + "_"
115 | + model_name.split("/")[-1].replace("-", "_")
116 | + "_"
117 | + str(max_length)
118 | + "_"
119 | + "Tuned"
120 | )
121 | else:
122 | output_file = os.path.join(
123 | Feature_path
124 | , name
125 | + "_"
126 | + model_name.split("/")[-1].replace("-", "_")
127 | + "_"
128 | + str(max_length)
129 | )
130 |
131 | print("output_file:", output_file)
132 |
133 | pkl_file = os.path.join(base_dir, args.pkl_file)
134 | with open(pkl_file, "rb") as f:
135 | data = pickle.load(f)
136 |     text_data = data.text_nodes + data.text_edges  # node texts first, then edge texts; split back below by len(data.text_nodes)
137 |
138 | # Load tokenizer and Model
139 | tokenizer = AutoTokenizer.from_pretrained(
140 | model_name, use_fast=True, trust_remote_code=True
141 | )
142 |
143 | if tokenizer.pad_token is None:
144 | tokenizer.pad_token = tokenizer.eos_token
145 |
146 | dataset = Dataset.from_dict({"text": text_data})
147 |
148 | def tokenize_function(examples):
149 | return tokenizer(examples["text"], padding=True, truncation=True, max_length=max_length, return_tensors="pt")
150 |
151 | dataset = dataset.map(tokenize_function, batched=True, batch_size=batch_size)
152 |
153 | if args.pretrain_path is not None:
154 | model = AutoModel.from_pretrained(f"{args.pretrain_path}", attn_implementation="eager")
155 | print("Loading model from the path: {}".format(args.pretrain_path))
156 | else:
157 | model = AutoModel.from_pretrained(model_name, trust_remote_code=True, attn_implementation="eager")
158 |
159 | model = model.to(args.device)
160 | CLS_Feateres_Extractor = CLSEmbInfModel(model)
161 | CLS_Feateres_Extractor.eval()
162 |
163 | inference_args = TrainingArguments(
164 | output_dir=cache_path,
165 | do_train=False,
166 | do_predict=True,
167 | per_device_eval_batch_size=batch_size,
168 | dataloader_drop_last=False,
169 | dataloader_num_workers=12,
170 | fp16_full_eval=args.fp16,
171 | )
172 |
173 |     # CLS representation
174 | if args.cls:
175 | if not os.path.exists(output_file + "_cls_node.pt") or not os.path.exists(output_file + "_cls_edge.pt"):
176 | trainer = Trainer(model=CLS_Feateres_Extractor, args=inference_args)
177 | cls_emb = trainer.predict(dataset)
178 | node_cls_emb = torch.from_numpy(cls_emb.predictions[: len(data.text_nodes)])
179 | edge_cls_emb = torch.from_numpy(cls_emb.predictions[len(data.text_nodes) :])
180 | torch.save(node_cls_emb, output_file + "_cls_node.pt")
181 | torch.save(edge_cls_emb, output_file + "_cls_edge.pt")
182 |             print("Embeddings saved to {}".format(output_file))
183 |
184 | else:
185 |             print("CLS embeddings already exist at {}".format(output_file))
186 | else:
187 |         raise ValueError("Only CLS pooling is implemented")
188 |
189 |
190 | if __name__ == "__main__":
191 | main()
192 |
--------------------------------------------------------------------------------
/GNN/model/GNN_arg.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import wandb
4 |
5 |
6 | def args_init():
7 | argparser = argparse.ArgumentParser(
8 |         "GNN (GCN, GIN, GraphSAGE, GAT) node classification on OGBN-Arxiv, Amazon-Books, and other datasets",
9 | formatter_class=argparse.ArgumentDefaultsHelpFormatter,
10 | )
11 | argparser.add_argument(
12 | "--cpu",
13 | action="store_true",
14 | help="CPU mode. This option overrides --gpu.",
15 | )
16 | argparser.add_argument("--gpu", type=int, default=0, help="GPU device ID.")
17 | argparser.add_argument(
18 | "--n-runs", type=int, default=3, help="running times"
19 | )
20 | argparser.add_argument(
21 | "--n-epochs", type=int, default=1000, help="number of epochs"
22 | )
23 | argparser.add_argument(
24 | "--lr", type=float, default=0.005, help="learning rate"
25 | )
26 | argparser.add_argument(
27 | "--n-layers", type=int, default=3, help="number of layers"
28 | )
29 | argparser.add_argument(
30 | "--n-hidden", type=int, default=256, help="number of hidden units"
31 | )
32 | argparser.add_argument(
33 | "--input-drop", type=float, default=0.1, help="input drop rate"
34 | )
35 | argparser.add_argument(
36 | "--learning-eps", type=bool, default=True,
37 | help="If True, learn epsilon to distinguish center nodes from neighbors;"
38 | "If False, aggregate neighbors and center nodes altogether."
39 | )
40 | argparser.add_argument("--wd", type=float, default=0, help="weight decay")
41 | argparser.add_argument(
42 | "--num-mlp-layers", type=int, default=2, help="number of mlp layers"
43 | )
44 | argparser.add_argument(
45 | "--neighbor-pooling-type", type=str, default='mean', help="how to aggregate neighbors (sum, mean, or max)"
46 | )
47 | # ! GAT
48 | argparser.add_argument(
49 | "--no-attn-dst", type=bool, default=True, help="Don't use attn_dst."
50 | )
51 | argparser.add_argument(
52 | "--n-heads", type=int, default=3, help="number of heads"
53 | )
54 | argparser.add_argument(
55 | "--attn-drop", type=float, default=0.0, help="attention drop rate"
56 | )
57 | argparser.add_argument(
58 | "--edge-drop", type=float, default=0.0, help="edge drop rate"
59 | )
60 | # ! SAGE
61 | argparser.add_argument("--aggregator-type", type=str, default="mean",
62 | help="Aggregator type: mean/gcn/pool/lstm")
63 | # ! JKNET
64 | argparser.add_argument(
65 | "--mode", type=str, default='cat', help="the mode of aggregate the feature, 'cat', 'lstm'"
66 | )
67 | #! Monet
68 | argparser.add_argument(
69 | "--pseudo-dim",type=int,default=2,help="Pseudo coordinate dimensions in GMMConv, 2 and 3",
70 | )
71 | argparser.add_argument(
72 | "--n-kernels",type=int,default=3,help="Number of kernels in GMMConv layer",
73 | )
74 | #! APPNP
75 | argparser.add_argument(
76 | "--alpha", type=float, default=0.1, help="Teleport Probability"
77 | )
78 | argparser.add_argument(
79 | "--k", type=int, default=10, help="Number of propagation steps"
80 | )
81 | argparser.add_argument(
82 | "--hidden_sizes", type=int, nargs="+", default=[64], help="hidden unit sizes for appnp",
83 | )
84 | # ! default
85 | argparser.add_argument(
86 | "--log-every", type=int, default=20, help="log every LOG_EVERY epochs"
87 | )
88 | argparser.add_argument(
89 |         "--eval_steps", type=int, default=5, help="evaluate every EVAL_STEPS epochs"
90 | )
91 | argparser.add_argument(
92 | "--use_PLM", type=str, default=None, help="Use LM embedding as feature"
93 | )
94 | argparser.add_argument(
95 |         "--model_name", type=str, default='RevGAT', help="Which GNN to use"
96 | )
97 | argparser.add_argument(
98 | "--dropout", type=float, default=0.5, help="dropout rate"
99 | )
100 | argparser.add_argument(
101 |         "--data_name", type=str, default='ogbn-arxiv', help="The dataset to be used."
102 | )
103 | argparser.add_argument(
104 |         "--metric", type=str, default='acc', help="The evaluation metric to be used."
105 | )
106 | # ! Split datasets
107 | argparser.add_argument(
108 | "--train_ratio", type=float, default=0.6, help="training ratio"
109 | )
110 | argparser.add_argument(
111 |         "--val_ratio", type=float, default=0.2, help="validation ratio"
112 | )
113 | return argparser
114 |
115 | class Logger(object):
116 | def __init__(self, runs, info=None):
117 | self.info = info
118 | self.results = [[] for _ in range(runs)]
119 |
120 | def add_result(self, run, result):
121 | assert len(result) == 3
122 | assert run >= 0 and run < len(self.results)
123 | self.results[run].append(result)
124 |
125 | def print_statistics(self, run=None, key='Hits@10'):
126 | if run is not None:
127 | result = 100 * torch.tensor(self.results[run])
128 | argmax = result[:, 1].argmax().item()
129 | print(f'Run {run + 1:02d}:')
130 | print(f'Highest Train: {result[:, 0].max():.2f}')
131 | print(f'Highest Valid: {result[:, 1].max():.2f}')
132 | print(f' Final Train: {result[argmax, 0]:.2f}')
133 | print(f' Final Test: {result[argmax, 2]:.2f}')
134 | else:
135 | result = 100 * torch.tensor(self.results)
136 |
137 | best_results = []
138 | for r in result:
139 | train1 = r[:, 0].max().item()
140 | valid = r[:, 1].max().item()
141 | train2 = r[r[:, 1].argmax(), 0].item()
142 | test = r[r[:, 1].argmax(), 2].item()
143 | best_results.append((train1, valid, train2, test))
144 |
145 | best_result = torch.tensor(best_results)
146 |
147 | print(f'All runs:')
148 | r = best_result[:, 0]
149 | print(f'{key} Highest Train: {r.mean():.2f} ± {r.std():.2f}')
150 | wandb.log({f'{key} Highest Train Acc': float(f'{r.mean():.2f}'),
151 | f'{key} Highest Train Std': float(f'{r.std():.2f}')})
152 | r = best_result[:, 1]
153 | print(f'{key} Highest Valid: {r.mean():.2f} ± {r.std():.2f}')
154 | wandb.log({f'{key} Highest Valid Acc': float(f'{r.mean():.2f}'),
155 | f'{key} Highest Valid Std': float(f'{r.std():.2f}')})
156 | r = best_result[:, 2]
157 | print(f'{key} Final Train: {r.mean():.2f} ± {r.std():.2f}')
158 | wandb.log(
159 | {f'{key} Final Train Acc': float(f'{r.mean():.2f}'), f'{key} Final Train Std': float(f'{r.std():.2f}')})
160 | r = best_result[:, 3]
161 | print(f'{key} Final Test: {r.mean():.2f} ± {r.std():.2f}')
162 | wandb.log(
163 | {f'{key} Final Test Acc': float(f'{r.mean():.2f}'), f'{key} Final Test Std': float(f'{r.std():.2f}')})
164 |
165 |
166 |
--------------------------------------------------------------------------------
/data_preprocess/dataset.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [
8 | {
9 | "name": "stderr",
10 | "output_type": "stream",
11 | "text": [
12 | "/home/lizhuofeng/.local/lib/python3.10/site-packages/huggingface_hub/hf_api.py:9620: UserWarning: Warnings while validating metadata in README.md:\n",
13 | "- empty or missing yaml metadata in repo card\n",
14 | " warnings.warn(f\"Warnings while validating metadata in README.md:\\n{message}\")\n"
15 | ]
16 | }
17 | ],
18 | "source": [
19 | "from huggingface_hub import HfApi\n",
20 | "api = HfApi()\n",
21 | "\n",
22 | "api.upload_folder(\n",
23 | " folder_path=\"../Dataset\",\n",
24 | " repo_id=\"ZhuofengLi/TEG-Datasets\",\n",
25 | " repo_type=\"dataset\",\n",
26 | "\tignore_patterns=[\"**/logs/*.txt\", \"**/.ipynb_checkpoints\", \"**/__pycache__\", \"**/.cache\", \"**/embcache\", \"**/LinkPrediction\"]\n",
27 | ")"
28 | ]
29 | },
30 | {
31 | "cell_type": "code",
32 | "execution_count": 1,
33 | "metadata": {},
34 | "outputs": [
35 | {
36 | "data": {
37 | "application/vnd.jupyter.widget-view+json": {
38 | "model_id": "c484c0011f2a48d39490aa13140c778d",
39 | "version_major": 2,
40 | "version_minor": 0
41 | },
42 | "text/plain": [
43 | "Fetching 85 files: 0%| | 0/85 [00:00, ?it/s]"
44 | ]
45 | },
46 | "metadata": {},
47 | "output_type": "display_data"
48 | },
49 | {
50 | "data": {
51 | "application/vnd.jupyter.widget-view+json": {
52 | "model_id": "d6ad2a46c1f84e5cb773b83ad81afec0",
53 | "version_major": 2,
54 | "version_minor": 0
55 | },
56 | "text/plain": [
57 | ".gitattributes: 0%| | 0.00/4.28k [00:00, ?B/s]"
58 | ]
59 | },
60 | "metadata": {},
61 | "output_type": "display_data"
62 | },
63 | {
64 | "data": {
65 | "application/vnd.jupyter.widget-view+json": {
66 | "model_id": "849865916d95499f8aadd8f654a7f9b5",
67 | "version_major": 2,
68 | "version_minor": 0
69 | },
70 | "text/plain": [
71 | "output_requests_apps_edge.jsonl: 0%| | 0.00/342M [00:00, ?B/s]"
72 | ]
73 | },
74 | "metadata": {},
75 | "output_type": "display_data"
76 | },
77 | {
78 | "data": {
79 | "application/vnd.jupyter.widget-view+json": {
80 | "model_id": "50574cf3f1e34595ba081c4aa6938963",
81 | "version_major": 2,
82 | "version_minor": 0
83 | },
84 | "text/plain": [
85 | "arxiv/arxiv.md: 0%| | 0.00/1.19k [00:00, ?B/s]"
86 | ]
87 | },
88 | "metadata": {},
89 | "output_type": "display_data"
90 | },
91 | {
92 | "data": {
93 | "application/vnd.jupyter.widget-view+json": {
94 | "model_id": "db779c2f324f4e6ea4b6e5e6487d9b65",
95 | "version_major": 2,
96 | "version_minor": 0
97 | },
98 | "text/plain": [
99 | "(…)put_requests_children_edge_results.jsonl: 0%| | 0.00/869k [00:00, ?B/s]"
100 | ]
101 | },
102 | "metadata": {},
103 | "output_type": "display_data"
104 | },
105 | {
106 | "data": {
107 | "application/vnd.jupyter.widget-view+json": {
108 | "model_id": "4ccfe3fbed9f456d8aedc6035e94180b",
109 | "version_major": 2,
110 | "version_minor": 0
111 | },
112 | "text/plain": [
113 | "arxiv.pkl: 0%| | 0.00/389M [00:00, ?B/s]"
114 | ]
115 | },
116 | "metadata": {},
117 | "output_type": "display_data"
118 | },
119 | {
120 | "data": {
121 | "application/vnd.jupyter.widget-view+json": {
122 | "model_id": "aadcc1bbeeff4f81a5112e17eb72abe2",
123 | "version_major": 2,
124 | "version_minor": 0
125 | },
126 | "text/plain": [
127 | "output_requests_children_edge.jsonl: 0%| | 0.00/369M [00:00, ?B/s]"
128 | ]
129 | },
130 | "metadata": {},
131 | "output_type": "display_data"
132 | },
133 | {
134 | "data": {
135 | "application/vnd.jupyter.widget-view+json": {
136 | "model_id": "b67034ac11c148848c0cd7ab713aa677",
137 | "version_major": 2,
138 | "version_minor": 0
139 | },
140 | "text/plain": [
141 | "reddit/reddit.md: 0%| | 0.00/1.03k [00:00, ?B/s]"
142 | ]
143 | },
144 | "metadata": {},
145 | "output_type": "display_data"
146 | },
147 | {
148 | "data": {
149 | "application/vnd.jupyter.widget-view+json": {
150 | "model_id": "329cfdc7c7eb4cf891a747d33cb0f013",
151 | "version_major": 2,
152 | "version_minor": 0
153 | },
154 | "text/plain": [
155 | "twitter_openai_edge.pt: 0%| | 0.00/76.5M [00:00, ?B/s]"
156 | ]
157 | },
158 | "metadata": {},
159 | "output_type": "display_data"
160 | },
161 | {
162 | "data": {
163 | "application/vnd.jupyter.widget-view+json": {
164 | "model_id": "941dc4d8e96f46fd8f23883c4cf4f13b",
165 | "version_major": 2,
166 | "version_minor": 0
167 | },
168 | "text/plain": [
169 | "twitter_openai_node.pt: 0%| | 0.00/62.2M [00:00, ?B/s]"
170 | ]
171 | },
172 | "metadata": {},
173 | "output_type": "display_data"
174 | },
175 | {
176 | "data": {
177 | "application/vnd.jupyter.widget-view+json": {
178 | "model_id": "35ee8f5f8b444ae7b2ff65011a2da878",
179 | "version_major": 2,
180 | "version_minor": 0
181 | },
182 | "text/plain": [
183 | "twitter/twitter.md: 0%| | 0.00/634 [00:00, ?B/s]"
184 | ]
185 | },
186 | "metadata": {},
187 | "output_type": "display_data"
188 | },
189 | {
190 | "data": {
191 | "text/plain": [
192 | "'/home/lizhuofeng/TEG-Benchmark/Dataset'"
193 | ]
194 | },
195 | "execution_count": 1,
196 | "metadata": {},
197 | "output_type": "execute_result"
198 | }
199 | ],
200 | "source": [
201 | "from huggingface_hub import snapshot_download\n",
202 | "\n",
203 | "snapshot_download(repo_id=\"ZhuofengLi/TEG-Datasets\", repo_type=\"dataset\", local_dir=\"../Dataset\")"
204 | ]
205 | },
206 | {
207 | "cell_type": "code",
208 | "execution_count": null,
209 | "metadata": {},
210 | "outputs": [],
211 | "source": []
212 | }
213 | ],
214 | "metadata": {
215 | "kernelspec": {
216 | "display_name": "Python 3",
217 | "language": "python",
218 | "name": "python3"
219 | },
220 | "language_info": {
221 | "codemirror_mode": {
222 | "name": "ipython",
223 | "version": 3
224 | },
225 | "file_extension": ".py",
226 | "mimetype": "text/x-python",
227 | "name": "python",
228 | "nbconvert_exporter": "python",
229 | "pygments_lexer": "ipython3",
230 | "version": "3.10.12"
231 | }
232 | },
233 | "nbformat": 4,
234 | "nbformat_minor": 2
235 | }
236 |
--------------------------------------------------------------------------------
/GNN/GNN_Link_AUC.py:
--------------------------------------------------------------------------------
1 | # sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..')))
2 |
3 | import numpy as np
4 | import torch
5 | import pickle
6 | import torch.nn.functional as F
7 | import torch_geometric.transforms as T
8 | from torch_sparse import SparseTensor
9 | import tqdm
10 | from sklearn.metrics import roc_auc_score
11 | from torch_geometric import seed_everything
12 | from torch_geometric.loader import LinkNeighborLoader
13 | from model.GNN_library import GAT, GINE, GeneralGNN, GraphTransformer
14 | from torch.nn import Linear
15 | from sklearn.metrics import f1_score
16 | import argparse
17 | from torch_geometric.data import Data
18 |
19 |
20 | class Classifier(torch.nn.Module):
21 | def __init__(self, hidden_channels):
22 | super().__init__()
23 | self.lin1 = Linear(2 * hidden_channels, hidden_channels)
24 | self.lin2 = Linear(hidden_channels, 1)
25 |
26 | def forward(self, x, edge_label_index):
27 | # Convert node embeddings to edge-level representations:
28 | edge_feat_src = x[edge_label_index[0]]
29 | edge_feat_dst = x[edge_label_index[1]]
30 |
31 | z = torch.cat([edge_feat_src, edge_feat_dst], dim=-1)
32 | z = self.lin1(z).relu()
33 | z = self.lin2(z)
34 | return z.view(-1)
35 |
36 |
37 | def gen_model(args, x, edge_feature):
38 | if args.gnn_model == "GAT":
39 | model = GAT(
40 | x.size(1),
41 | edge_feature.size(1),
42 | args.hidden_channels,
43 | args.hidden_channels,
44 | args.num_layers,
45 | args.heads,
46 | args.dropout,
47 | )
48 | elif args.gnn_model == "GraphTransformer":
49 | model = GraphTransformer(
50 | x.size(1),
51 | edge_feature.size(1),
52 | args.hidden_channels,
53 | args.hidden_channels,
54 | args.num_layers,
55 | args.dropout,
56 | )
57 | elif args.gnn_model == "GINE":
58 | model = GINE(
59 | x.size(1),
60 | edge_feature.size(1),
61 | args.hidden_channels,
62 | args.hidden_channels,
63 | args.num_layers,
64 | args.dropout,
65 | )
66 | elif args.gnn_model == "GeneralGNN":
67 | model = GeneralGNN(
68 | x.size(1),
69 | edge_feature.size(1),
70 | args.hidden_channels,
71 | args.hidden_channels,
72 | args.num_layers,
73 | args.dropout,
74 | )
75 | else:
76 | raise ValueError("Not implemented")
77 | return model
78 |
79 | if __name__ == "__main__":
80 | seed_everything(66)
81 |
82 | parser = argparse.ArgumentParser()
83 | parser.add_argument("--device", type=int, default=0)
84 | parser.add_argument(
85 | "--gnn_model",
86 | "-gm",
87 | type=str,
88 | default="GAT",
89 |         help="GNN model type; options are GAT, GraphTransformer, GINE, GeneralGNN",
90 | )
91 | parser.add_argument(
92 | "--use_PLM_node",
93 | type=str,
94 | default="data/CSTAG/Photo/Feature/children_gpt_node.pt",
95 | help="Use LM embedding as node feature",
96 | )
97 | parser.add_argument(
98 | "--use_PLM_edge",
99 | type=str,
100 | default="data/CSTAG/Photo/Feature/children_gpt_edge.pt",
101 | help="Use LM embedding as edge feature",
102 | )
103 | parser.add_argument(
104 | "--graph_path",
105 | type=str,
106 | default="data/CSTAG/Photo/children.pkl",
107 | help="Path to load the graph",
108 | )
109 | parser.add_argument("--eval_steps", type=int, default=1)
110 | parser.add_argument("--hidden_channels", type=int, default=256)
111 | parser.add_argument("--heads", type=int, default=4)
112 | parser.add_argument("--num_layers", type=int, default=2)
113 | parser.add_argument("--dropout", type=float, default=0.0)
114 | parser.add_argument("--epochs", type=int, default=100)
115 | parser.add_argument("--threshold", type=float, default=0.5)
116 | parser.add_argument("--test_ratio", type=float, default=0.1)
117 | parser.add_argument("--val_ratio", type=float, default=0.1)
118 | parser.add_argument("--batch_size", type=int, default=1024)
119 | args = parser.parse_args()
120 |
121 | device = f"cuda:{args.device}" if torch.cuda.is_available() else "cpu"
122 | device = torch.device(device)
123 |
124 | with open(f"{args.graph_path}", "rb") as f:
125 | data = pickle.load(f)
126 |
127 | num_nodes = len(data.text_nodes)
128 | num_edges = len(data.text_edges)
129 |
130 | x = torch.load(args.use_PLM_node).squeeze().float()
131 | edge_feature = torch.load(args.use_PLM_edge).squeeze().float()
132 | data = Data(x=x, edge_index=data.edge_index, edge_attr=edge_feature)
133 |
134 | print(data)
135 |
136 | train_data, val_data, test_data = T.RandomLinkSplit(
137 | num_val=args.val_ratio,
138 | num_test=args.test_ratio,
139 | disjoint_train_ratio=0.3,
140 | neg_sampling_ratio=1.0,
141 | )(data)
142 |
143 | # Perform a link-level split into training, validation, and test edges:
144 | edge_label_index = train_data.edge_label_index
145 | edge_label = train_data.edge_label
146 | train_loader = LinkNeighborLoader(
147 | data=train_data,
148 | num_neighbors=[20, 10],
149 | edge_label_index=(edge_label_index),
150 | edge_label=edge_label,
151 | batch_size=args.batch_size,
152 | shuffle=True,
153 | )
154 |
155 | edge_label_index = val_data.edge_label_index
156 | edge_label = val_data.edge_label
157 | val_loader = LinkNeighborLoader(
158 | data=val_data,
159 | num_neighbors=[20, 10],
160 | edge_label_index=(edge_label_index),
161 | edge_label=edge_label,
162 | batch_size=args.batch_size,
163 | shuffle=False,
164 | )
165 |
166 | edge_label_index = test_data.edge_label_index
167 | edge_label = test_data.edge_label
168 | test_loader = LinkNeighborLoader(
169 | data=test_data,
170 | num_neighbors=[20, 10],
171 | edge_label_index=(edge_label_index),
172 | edge_label=edge_label,
173 | batch_size=args.batch_size,
174 | shuffle=False,
175 | )
176 |
177 | model = gen_model(args, x, edge_feature).to(device)
178 |     # keep the device selected via --device above instead of resetting it here
179 | print(device)
180 |
181 | model = model.to(device)
182 | predictor = Classifier(args.hidden_channels).to(device)
183 |     optimizer = torch.optim.Adam(list(model.parameters()) + list(predictor.parameters()), lr=0.001)  # train the link classifier as well
184 |
185 | for epoch in range(1, 1 + args.epochs):
186 | total_loss = total_examples = 0
187 | for sampled_data in tqdm.tqdm(train_loader):
188 | optimizer.zero_grad()
189 | sampled_data = sampled_data.to(device)
190 | adj_t = SparseTensor.from_edge_index(
191 | sampled_data.edge_index, sampled_data.edge_attr
192 | ).t()
193 | adj_t = adj_t.to_symmetric()
194 | x = model(sampled_data.x, adj_t)
195 | pred = predictor(x, sampled_data.edge_label_index)
196 | ground_truth = sampled_data.edge_label
197 | loss = F.binary_cross_entropy_with_logits(pred, ground_truth)
198 | loss.backward()
199 | optimizer.step()
200 | total_loss += float(loss) * pred.numel()
201 | total_examples += pred.numel()
202 | print(f"Epoch: {epoch:03d}, Loss: {total_loss / total_examples:.4f}")
203 |
204 | # validation
205 | if epoch % args.eval_steps == 0 and epoch != 0:
206 | print("Validation begins")
207 | with torch.no_grad():
208 | preds = []
209 | ground_truths = []
210 | for sampled_data in tqdm.tqdm(test_loader): # TODO: use val_loader
211 | sampled_data = sampled_data.to(device)
212 | adj_t = SparseTensor.from_edge_index(
213 | sampled_data.edge_index, sampled_data.edge_attr
214 | ).t()
215 | adj_t = adj_t.to_symmetric()
216 | x = model(sampled_data.x, adj_t)
217 | pred = predictor(x, sampled_data.edge_label_index)
218 | preds.append(pred)
219 | ground_truths.append(sampled_data.edge_label)
220 |
221 | preds = torch.cat(preds, dim=0).cpu().numpy()
222 | ground_truth = torch.cat(ground_truths, dim=0).cpu().numpy()
223 |
224 | y_label = np.where(preds >= args.threshold, 1, 0)
225 | f1 = f1_score(ground_truth, y_label)
226 | print(f"F1 score: {f1:.4f}")
227 |
228 | # AUC
229 | auc = roc_auc_score(ground_truth, preds)
230 | print(f"Validation AUC: {auc:.4f}")
231 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | ![TEG-DB logo](./assets/logo.png)
3 |
4 |
5 | ## Why TEGs instead of TAGs? 🤔
6 |
7 | Textual-Edge Graphs (TEGs) incorporate textual content on **both nodes and edges**, unlike Text-Attributed Graphs (TAGs), which carry textual information **only on the nodes**. **Edge texts are crucial for understanding document meanings and semantic relationships.** For instance, as shown below, to understand the knowledge that "Planck endorsed the uncertainty and probabilistic nature of quantum mechanics," **the text on the citation edge (Book D - Paper E) is essential**: it reveals the comprehensive connections and influences among scholarly works, enabling a deeper analysis of document semantics and knowledge networks.
8 |
9 | 
10 |
11 | ## Overview 📚
12 |
13 | Textual-Edge Graphs Datasets and Benchmark (TEG-DB) is a comprehensive and diverse collection of benchmark textual-edge datasets featuring rich textual descriptions on nodes and edges, together with data loaders and performance benchmarks for various baseline models, including pre-trained language models (PLMs), graph neural networks (GNNs), and their combinations. This repository aims to facilitate research on textual-edge graphs by providing standardized data formats and easy-to-use tools for model evaluation and comparison.
14 |
15 | ## Features ✨
16 |
17 | + **Unified Data Representation:** All TEG datasets are represented in a unified format. This standardization allows for easy extension of new datasets into our benchmark.
18 | + **Highly Efficient Pipeline:** TEG-Benchmark is tightly integrated with PyTorch Geometric (PyG), leveraging its powerful tools and functionalities, so the code stays concise. Specifically, for each paradigm we provide a small `.py` file that summarizes all relevant models and a `.sh` script to run all baselines in one click.
19 | + **Comprehensive Benchmark and Analysis:** We conduct extensive benchmark experiments and perform a comprehensive analysis of TEG-based methods, delving deep into various aspects such as the impact of different models, the effect of embeddings generated by Pre-trained Language Models (PLMs) of various scales, and the influence of different domain datasets. The statistics of our TEG datasets are as follows:
20 |
21 |
22 | ![TEG dataset statistics](./assets/dataset.png)
23 |
24 |
25 | ## Datasets 📊
26 |
27 | Please click [Huggingface TEG-Benchmark](https://huggingface.co/datasets/ZhuofengLi/TEG-Datasets) to find the TEG datasets we upload!
28 |
29 | We have constructed **10 comprehensive and representative TEG datasets (we will continue to expand)**. These datasets cover domains including Book Recommendation, E-commerce, Academic, and Social networks. They vary in size, ranging from small to large. Each dataset contains rich raw text data on both nodes and edges, providing a diverse range of information for analysis and modeling purposes.
30 |
31 | | Dataset | Nodes | Edges | Graph Domain | Node Classification | Link Prediction | Details |
32 | |--------------------|-----------|------------|---------------------|---------------------|------------------|-------------------------------------|
33 | | Goodreads-History | 540,796 | 2,066,193 | Book Recommendation | ✅ | ✅ | [Link](https://huggingface.co/datasets/ZhuofengLi/TEG-Datasets/blob/main/goodreads_history/goodreads.md) |
34 | | Goodreads-Crime | 422,642 | 1,849,236 | Book Recommendation | ✅ | ✅ | [Link](https://huggingface.co/datasets/ZhuofengLi/TEG-Datasets/blob/main/goodreads_crime/goodreads.md) |
35 | | Goodreads-Children | 216,613 | 734,640 | Book Recommendation | ✅ | ✅ | [Link](https://huggingface.co/datasets/ZhuofengLi/TEG-Datasets/blob/main/goodreads_children/goodreads.md) |
36 | | Goodreads-Comics | 148,658 | 542,338 | Book Recommendation | ✅ | ✅ | [Link](https://huggingface.co/datasets/ZhuofengLi/TEG-Datasets/blob/main/goodreads_comics/goodreads.md) |
37 | | Amazon-Movie | 174,012 | 1,697,533 | E-commerce | ✅ | ✅ | [Link](https://huggingface.co/datasets/ZhuofengLi/TEG-Datasets/blob/main/amazon_movie/movie.md) |
38 | | Amazon-Apps | 100,480 | 752,937 | E-commerce | ✅ | ✅ | [Link](https://huggingface.co/datasets/ZhuofengLi/TEG-Datasets/blob/main/amazon_apps/apps.md) |
39 | | Amazon-Baby | 186,790 | 1,241,083 | E-commerce | ✅ | ✅ | [Link](https://huggingface.co/datasets/ZhuofengLi/TEG-Datasets/blob/main/amazon_baby/baby.md) |
40 | | Reddit | 512,776 | 256,388 | Social Networks | ✅ | ✅ | [Link](https://huggingface.co/datasets/ZhuofengLi/TEG-Datasets/blob/main/reddit/reddit.md) |
41 | | Twitter | 60,785 | 74,666 | Social Networks | ✅ | ✅ | [Link](https://huggingface.co/datasets/ZhuofengLi/TEG-Datasets/blob/main/twitter/twitter.md) |
42 | | Citation | 169,343 | 1,166,243 | Academic | ✅ | ✅ | [Link](https://huggingface.co/datasets/ZhuofengLi/TEG-Datasets/blob/main/arxiv/arxiv.md) |
43 |
44 |
45 |
46 | **TEG-DB is an ongoing effort, and we are planning to increase our coverage in the future.**
47 |
48 | ## Our Experiments 🔬
49 |
50 | Please check the experimental results and analysis from our [paper](https://arxiv.org/abs/2406.10310).
51 |
52 | ## Package Usage ⚙️
53 | ### Requirements 📋
54 | You can quickly install the required dependencies:
55 |
56 | ```bash
57 | pip install -r requirements.txt
58 | ```
59 |
60 | ### More details about packages 📦
61 |
62 | The `GNN` folder houses all GNN methods for link prediction and node classification (please click [here](GNN/readme.md) for more training details). The `data_preprocess` folder contains scripts for generating TEG data. For instance, `data_preprocess/generate_llm_emb.py` generates node and edge text embeddings for a TEG from a PLM at high speed, while `data_preprocess/dataset.ipynb` facilitates uploading and downloading the [TEG datasets](https://huggingface.co/datasets/ZhuofengLi/TEG-Datasets) we provide on Hugging Face.
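
A typical embedding-generation invocation looks like this (a minimal sketch mirroring the commands in `run_emb.sh`; adjust the paths, prefix name, and model to your dataset):

```bash
python data_preprocess/generate_llm_emb.py \
	--pkl_file Dataset/goodreads_children/processed/children.pkl \
	--path Dataset/goodreads_children/emb \
	--name children \
	--model_name bert-base-uncased \
	--max_length 512 \
	--batch_size 500
```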
63 |
64 | ### Datasets setup 🛠️
65 |
66 | You can go to the [Huggingface TEG-Benchmark](https://huggingface.co/datasets/ZhuofengLi/TEG-Datasets) to find the datasets we upload! In each dataset folder, you will find a `.py` script (downloads the raw data) and a `.ipynb` notebook (processes the raw data into a TEG) in the `raw` folder, `.pt` files (the node and edge text embeddings we extract from the PLM) in the `emb` folder, and a `.pkl` file processed from the raw data (the final textual-edge graph) in the `processed` folder.
67 |
68 | **Download all datasets:**
69 | ```python
70 | from huggingface_hub import snapshot_download
71 |
72 | snapshot_download(repo_id="ZhuofengLi/TEG-Datasets", repo_type="dataset", local_dir="./Dataset")
73 | ```
74 |
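Once downloaded, each processed graph can be loaded directly from its `.pkl` file. A minimal sketch (assuming the Goodreads-Children dataset and the attribute names used by the training scripts):

```python
import pickle
import torch

# Load the processed textual-edge graph.
with open("Dataset/goodreads_children/processed/children.pkl", "rb") as f:
    data = pickle.load(f)

print(len(data.text_nodes), "node texts")  # raw text attached to nodes
print(len(data.text_edges), "edge texts")  # raw text attached to edges
print(data.edge_index.shape)               # graph connectivity

# The pre-computed PLM embeddings are stored separately as .pt tensors.
x = torch.load("Dataset/goodreads_children/emb/children_bert_base_uncased_512_cls_node.pt")
print(x.shape)
```
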
75 | ### Running Experiments 🚀
76 | **GNN for Link Prediction:**
77 | ```bash
78 | bash run_gnn_link.sh
79 | ```
80 |
81 | **GNN for Node Classification:**
82 | ```bash
83 | bash run_gnn_node.sh
84 | ```
85 |
86 | Please click [here](GNN/readme.md) to see more details.
87 |
88 | ## Create Your Own Dataset and Model 🛠️
89 | You can follow the steps below:
90 | + Align your dataset format with ours. Please check the dataset [README](https://huggingface.co/datasets/ZhuofengLi/TEG-Datasets) for more details.
91 | + Generate your raw text embeddings at high speed with [`generate_llm_emb.py`](https://github.com/Zhuofeng-Li/TEG-Benchmark/blob/main/data_preprocess/generate_llm_emb.py).
92 | + Add your model to [`GNN/model/GNN_library.py`](https://github.com/Zhuofeng-Li/TEG-Benchmark/blob/main/GNN/model/GNN_library.py).
93 | + Update the training scripts' `args` and their `gen_model()` functions (see the sketch below).
94 |
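As a rough illustration, a new model only needs to follow the constructor and `forward(x, adj_t)` signatures that `gen_model()` already uses in `GNN/GNN_Link_AUC.py` and `GNN/GNN_Node.py` (the `MyGNN` class below is hypothetical, a minimal sketch rather than part of the benchmark):

```python
import torch
from torch.nn import Linear


class MyGNN(torch.nn.Module):
    # Same constructor arguments that gen_model() passes to the existing models.
    def __init__(self, in_channels, edge_channels, hidden_channels,
                 out_channels, num_layers, dropout):
        super().__init__()
        self.lin_in = Linear(in_channels, hidden_channels)
        self.lin_out = Linear(hidden_channels, out_channels)
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, x, adj_t):
        # Replace this with message passing over adj_t (and the edge features).
        h = self.dropout(self.lin_in(x).relu())
        return self.lin_out(h)


# Then add one more branch to gen_model() in the training script, e.g.:
#     elif args.gnn_model == "MyGNN":
#         model = MyGNN(x.size(1), edge_feature.size(1), args.hidden_channels,
#                       args.hidden_channels, args.num_layers, args.dropout)
```
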
95 | ## Acknowledgements
96 | We would like to express our gratitude to the following projects and organizations for their contributions:
97 |
98 | + [semantic scholar](https://www.semanticscholar.org/)
99 | + [ogb](https://github.com/snap-stanford/ogb)
100 |
101 |
102 |
103 | ## Star and Cite ⭐
104 |
105 | Please star our repo 🌟 and cite our [paper](https://arxiv.org/abs/2406.10310) if you find it useful. Feel free to [email](mailto:zhuofengli12345@gmail.com) us (zhuofengli12345@gmail.com) if you have any questions.
106 |
107 | ```
108 | @misc{li2024tegdb,
109 | title={TEG-DB: A Comprehensive Dataset and Benchmark of Textual-Edge Graphs},
110 | author={Zhuofeng Li and Zixing Gou and Xiangnan Zhang and Zhongyuan Liu and Sirui Li and Yuntong Hu and Chen Ling and Zheng Zhang and Liang Zhao},
111 | year={2024},
112 | eprint={2406.10310},
113 | archivePrefix={arXiv},
114 |       primaryClass={cs.CL}
115 | }
116 | ```
117 |
--------------------------------------------------------------------------------
/GNN/GNN_Node.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pickle
3 | from torch_sparse import SparseTensor
4 | import tqdm
5 | from sklearn.metrics import roc_auc_score
6 | from torch_geometric import seed_everything
7 | from torch_geometric.loader import NeighborLoader
8 | from model.GNN_library import GAT, GINE, GeneralGNN, GraphTransformer
9 | from torch.nn import Linear
10 | from sklearn.metrics import f1_score
11 | from sklearn.metrics import accuracy_score
12 | import argparse
13 | from torch_geometric.data import Data
14 | from sklearn.preprocessing import MultiLabelBinarizer
15 |
16 |
17 | class Classifier(torch.nn.Module):
18 | def __init__(self, hidden_channels, out_channels):
19 | super().__init__()
20 | self.lin1 = Linear(hidden_channels, hidden_channels)
21 | self.lin2 = Linear(hidden_channels, out_channels)
22 |
23 | def forward(self, x):
24 | x = self.lin1(x).relu()
25 | x = self.lin2(x)
26 | return x
27 |
28 |
29 | def gen_model(args, x, edge_feature):
30 | if args.gnn_model == "GAT":
31 | model = GAT(
32 | x.size(1),
33 | edge_feature.size(1),
34 | args.hidden_channels,
35 | args.hidden_channels,
36 | args.num_layers,
37 | args.heads,
38 | args.dropout,
39 | )
40 | elif args.gnn_model == "GraphTransformer":
41 | model = GraphTransformer(
42 | x.size(1),
43 | edge_feature.size(1),
44 | args.hidden_channels,
45 | args.hidden_channels,
46 | args.num_layers,
47 | args.dropout,
48 | )
49 | elif args.gnn_model == "GINE":
50 | model = GINE(
51 | x.size(1),
52 | edge_feature.size(1),
53 | args.hidden_channels,
54 | args.hidden_channels,
55 | args.num_layers,
56 | args.dropout,
57 | )
58 | elif args.gnn_model == "GeneralGNN":
59 | model = GeneralGNN(
60 | x.size(1),
61 | edge_feature.size(1),
62 | args.hidden_channels,
63 | args.hidden_channels,
64 | args.num_layers,
65 | args.dropout,
66 | )
67 | else:
68 | raise ValueError("Not implemented")
69 | return model
70 |
71 |
72 | def train(model, predictor, train_loader, optimizer, criterion):
73 | model.train()
74 | predictor.train()
75 |
76 | total_loss = total_examples = 0
77 | for batch in tqdm.tqdm(train_loader):
78 | optimizer.zero_grad()
79 | batch = batch.to(device)
80 | h = model(batch.x, batch.adj_t)[: batch.batch_size]
81 | pred = predictor(h)
82 | loss = criterion(pred, batch.y[: batch.batch_size].float())
83 | loss.backward()
84 | optimizer.step()
85 | total_loss += float(loss) * pred.numel()
86 | total_examples += pred.numel()
87 | return total_examples, total_loss
88 |
89 |
90 | def process_data(args, device, data): # TODO: process all data here
91 | num_nodes = len(data.text_nodes)
92 | data.num_nodes = num_nodes
93 |
94 | product_indices = torch.tensor(
95 | [
96 | i for i, label in enumerate(data.text_node_labels) if label != -1
97 | ] # TODO: check the label in each dataset
98 | ).long()
99 | product_labels = [label for label in data.text_node_labels if label != -1]
100 | mlb = MultiLabelBinarizer()
101 | product_binary_labels = mlb.fit_transform(product_labels)
102 | y = torch.zeros(num_nodes, product_binary_labels.shape[1]).float()
103 | y[product_indices] = torch.tensor(product_binary_labels).float()
104 | y = y.to(device) # TODO: check the label
105 |
106 | train_ratio = 1 - args.test_ratio - args.val_ratio
107 | val_ratio = args.val_ratio
108 |
109 | num_products = product_indices.shape[0]
110 | train_idx = product_indices[: int(num_products * train_ratio)]
111 | val_idx = product_indices[
112 | int(num_products * train_ratio) : int(num_products * (train_ratio + val_ratio))
113 | ]
114 | test_idx = product_indices[
115 | int(product_indices.shape[0] * (train_ratio + val_ratio)) :
116 | ]
117 |
118 | x = torch.load(args.use_PLM_node).squeeze().float()
119 | edge_feature = torch.load(args.use_PLM_edge).squeeze().float()
120 |
121 | edge_index = data.edge_index
122 | adj_t = SparseTensor.from_edge_index(
123 | edge_index, edge_feature, sparse_sizes=(data.num_nodes, data.num_nodes)
124 | ).t()
125 | adj_t = adj_t.to_symmetric()
126 | node_split = {"train": train_idx, "val": val_idx, "test": test_idx}
127 | return node_split, x, edge_feature, adj_t, y
128 |
129 |
130 | def test(model, predictor, test_loader, args):
131 | print("Validation begins")
132 | with torch.no_grad():
133 | preds = []
134 | ground_truths = []
135 | for batch in tqdm.tqdm(test_loader):
136 | batch = batch.to(device)
137 | h = model(batch.x, batch.adj_t)[: batch.batch_size]
138 | pred = predictor(h)
139 | ground_truth = batch.y[: batch.batch_size]
140 | preds.append(pred)
141 | ground_truths.append(ground_truth)
142 |
143 | preds = torch.cat(preds, dim=0).cpu().numpy()
144 | ground_truths = torch.cat(ground_truths, dim=0).cpu().numpy()
145 |
146 | y_label = (preds > args.threshold).astype(int)
147 | f1 = f1_score(ground_truths, y_label, average="weighted")
148 | print(f"F1 score: {f1:.4f}")
149 |
150 | data_type = args.graph_path.split("/")[-1]
151 | if data_type not in [
152 | "twitter.pkl",
153 | "reddit.pkl",
154 | "citation.pkl",
155 | ]: # TODO: check metric in
156 | auc = roc_auc_score(ground_truths, preds, average="micro")
157 | print(f"Validation AUC: {auc:.4f}")
158 | else:
159 | accuracy = accuracy_score(ground_truths, y_label)
160 | print(f"Validation Accuracy: {accuracy:.4f}")
161 |
162 |
163 | if __name__ == "__main__":
164 | # TODO: use wandb as log
165 | seed_everything(66)
166 |
167 | parser = argparse.ArgumentParser()
168 | parser.add_argument("--device", type=int, default=0)
169 | parser.add_argument(
170 | "--gnn_model",
171 | "-gm",
172 | type=str,
173 | default="GAT",
174 |         help="GNN model type; options are GAT, GraphTransformer, GINE, GeneralGNN",
175 | )
176 | parser.add_argument(
177 | "--use_PLM_node",
178 | type=str,
179 | default="data/CSTAG/Photo/Feature/children_gpt_node.pt",
180 | help="Use LM embedding as node feature",
181 | )
182 | parser.add_argument(
183 | "--use_PLM_edge",
184 | type=str,
185 | default="data/CSTAG/Photo/Feature/children_gpt_edge.pt",
186 | help="Use LM embedding as edge feature",
187 | )
188 | parser.add_argument(
189 | "--graph_path",
190 | type=str,
191 | default="../../TEG-Datasets/goodreads_children/processed/children.pkl", # "data/CSTAG/Photo/children.pkl",
192 | help="Path to load the graph",
193 | )
194 | parser.add_argument("--eval_steps", type=int, default=1)
195 | parser.add_argument("--hidden_channels", type=int, default=256)
196 | parser.add_argument("--heads", type=int, default=4)
197 | parser.add_argument("--num_layers", type=int, default=2)
198 | parser.add_argument("--dropout", type=float, default=0.0)
199 | parser.add_argument("--epochs", type=int, default=100)
200 | parser.add_argument("--threshold", type=float, default=0.5)
201 | parser.add_argument("--test_ratio", type=float, default=0.1)
202 | parser.add_argument("--val_ratio", type=float, default=0.1)
203 | parser.add_argument("--batch_size", type=int, default=1024)
204 | args = parser.parse_args()
205 |
206 | device = f"cuda:{args.device}" if torch.cuda.is_available() else "cpu"
207 | device = torch.device(device)
208 |
209 | with open(f"{args.graph_path}", "rb") as f:
210 | data = pickle.load(f)
211 |
212 | node_split, x, edge_feature, adj_t, y = process_data(args, device, data)
213 | data = Data(x=x, adj_t=adj_t, y=y)
214 |
215 | train_loader = NeighborLoader(
216 | data,
217 | input_nodes=node_split["train"],
218 | num_neighbors=[10, 10],
219 | batch_size=args.batch_size,
220 | shuffle=True,
221 | )
222 | val_loader = NeighborLoader(
223 | data,
224 | input_nodes=node_split["val"],
225 | num_neighbors=[10, 10],
226 | batch_size=args.batch_size,
227 | shuffle=False,
228 | )
229 | test_loader = NeighborLoader(
230 | data,
231 | input_nodes=node_split["test"],
232 | num_neighbors=[10, 10],
233 | batch_size=args.batch_size,
234 | shuffle=False,
235 | ) # TODO: move `num_neighbors` to args
236 |
237 | model = gen_model(args, x, edge_feature)
238 | model = model.to(device)
239 | predictor = Classifier(args.hidden_channels, y.shape[1]).to(device)
240 |     optimizer = torch.optim.Adam(list(model.parameters()) + list(predictor.parameters()), lr=0.001)  # train both the GNN and the classifier head
241 | criterion = torch.nn.CrossEntropyLoss()
242 |
243 | for epoch in range(1, 1 + args.epochs):
244 | total_examples, total_loss = train(
245 | model, predictor, train_loader, optimizer, criterion
246 | )
247 | print(f"Epoch: {epoch:03d}, Loss: {total_loss / total_examples:.4f}")
248 |
249 | # validation
250 | if epoch % args.eval_steps == 0:
251 | test(model, predictor, test_loader, args)
252 |
--------------------------------------------------------------------------------
/GNN/GNN_Link_MRR.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import pickle
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from torch.utils.data import DataLoader
7 |
8 | from torch_sparse import SparseTensor
9 |
10 | from model.GNN_library import GAT, GINE, GeneralGNN, GraphTransformer
11 | from model.Dataloader import Evaluator, split_edge_mrr
12 | from model.GNN_arg import Logger
13 | import wandb
14 | import os
15 |
16 |
17 | class LinkPredictor(torch.nn.Module):
18 | def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout):
19 | super(LinkPredictor, self).__init__()
20 |
21 | self.lins = torch.nn.ModuleList()
22 | self.lins.append(torch.nn.Linear(in_channels, hidden_channels))
23 | for _ in range(num_layers - 2):
24 | self.lins.append(torch.nn.Linear(hidden_channels, hidden_channels))
25 | self.lins.append(torch.nn.Linear(hidden_channels, out_channels))
26 |
27 | self.dropout = dropout
28 |
29 | def reset_parameters(self):
30 | for lin in self.lins:
31 | lin.reset_parameters()
32 |
33 | def forward(self, x_i, x_j):
34 | x = x_i * x_j
35 | for lin in self.lins[:-1]:
36 | x = lin(x)
37 | x = F.relu(x)
38 | x = F.dropout(x, p=self.dropout, training=self.training)
39 | x = self.lins[-1](x)
40 | return torch.sigmoid(x)
41 |
42 |
43 | def train(model, predictor, x, adj_t, edge_split, optimizer, batch_size):
44 | model.train()
45 | predictor.train()
46 |
47 | source_edge = edge_split["train"]["source_node"].to(x.device)
48 | target_edge = edge_split["train"]["target_node"].to(x.device)
49 |
50 | total_loss = total_examples = 0
51 | for perm in DataLoader(range(source_edge.size(0)), batch_size, shuffle=True):
52 | optimizer.zero_grad()
53 |
54 | h = model(x, adj_t)
55 | src, dst = source_edge[perm], target_edge[perm]
56 | pos_out = predictor(h[src], h[dst])
57 | pos_loss = -torch.log(pos_out + 1e-15).mean()
58 |
59 | # Just do some trivial random sampling.
60 | dst_neg = torch.randint(
61 | 0, x.size(0), src.size(), dtype=torch.long, device=h.device
62 | )
63 | neg_out = predictor(h[src], h[dst_neg])
64 | neg_loss = -torch.log(1 - neg_out + 1e-15).mean()
65 |
66 | loss = pos_loss + neg_loss
67 | loss.backward()
68 | optimizer.step()
69 |
70 | num_examples = pos_out.size(0)
71 | total_loss += loss.item() * num_examples
72 | total_examples += num_examples
73 |
74 | return total_loss / total_examples
75 |
76 |
77 | @torch.no_grad()
78 | def test(model, predictor, x, adj_t, edge_split, evaluator, batch_size, neg_len):
79 | model.eval()
80 | predictor.eval()
81 |
82 | h = model(x, adj_t)
83 |
84 | def test_split(split, neg_len):
85 | source = edge_split[split]["source_node"].to(h.device)
86 | target = edge_split[split]["target_node"].to(h.device)
87 | target_neg = edge_split[split]["target_node_neg"].to(h.device)
88 |
89 | pos_preds = []
90 | for perm in DataLoader(range(source.size(0)), batch_size):
91 | src, dst = source[perm], target[perm]
92 | pos_preds += [predictor(h[src], h[dst]).squeeze().cpu()]
93 | pos_pred = torch.cat(pos_preds, dim=0)
94 |
95 | neg_preds = []
96 | source = source.view(-1, 1).repeat(1, neg_len).view(-1)
97 | target_neg = target_neg.view(-1)
98 | for perm in DataLoader(range(source.size(0)), batch_size):
99 | src, dst_neg = source[perm], target_neg[perm]
100 | neg_preds += [predictor(h[src], h[dst_neg]).squeeze().cpu()]
101 | neg_pred = torch.cat(neg_preds, dim=0).view(-1, neg_len)
102 |
103 | return (
104 | evaluator.eval(
105 | {
106 | "y_pred_pos": pos_pred,
107 | "y_pred_neg": neg_pred,
108 | }
109 | )["mrr_list"]
110 | .mean()
111 | .item()
112 | )
113 |
114 | train_mrr = test_split("eval_train", neg_len)
115 | valid_mrr = test_split("valid", neg_len)
116 | test_mrr = test_split("test", neg_len)
117 |
118 | return train_mrr, valid_mrr, test_mrr
119 |
120 |
121 | def main():
122 | parser = argparse.ArgumentParser(description="Link-Prediction PLM/TCL")
123 | parser.add_argument("--device", type=int, default=0)
124 | parser.add_argument("--log_steps", type=int, default=5)
125 | parser.add_argument("--use_node_embedding", action="store_true")
126 | parser.add_argument("--num_layers", type=int, default=2)
127 | parser.add_argument("--hidden_channels", type=int, default=256)
128 | parser.add_argument("--dropout", type=float, default=0.0)
129 | parser.add_argument("--batch_size", type=int, default=64 * 1024)
130 | parser.add_argument("--lr", type=float, default=0.0005)
131 | parser.add_argument("--epochs", type=int, default=100)
132 | parser.add_argument("--gnn_model", type=str, help="GNN Model", default="GeneralGNN")
133 | parser.add_argument("--heads", type=int, default=4)
134 | parser.add_argument("--eval_steps", type=int, default=1)
135 | parser.add_argument("--runs", type=int, default=5)
136 | parser.add_argument("--test_ratio", type=float, default=0.08)
137 | parser.add_argument("--val_ratio", type=float, default=0.02)
138 | parser.add_argument("--neg_len", type=int, default=50)
139 | parser.add_argument(
140 | "--use_PLM_node",
141 | "-pn",
142 | type=str,
143 | default="data/CSTAG/Photo/Feature/children_gpt_node.pt",
144 | help="Use LM embedding as node feature",
145 | )
146 | parser.add_argument(
147 | "--use_PLM_edge",
148 | "-pe",
149 | type=str,
150 | default="data/CSTAG/Photo/Feature/children_gpt_edge.pt",
151 | help="Use LM embedding as edge feature",
152 | )
153 | parser.add_argument(
154 | "--path",
155 | type=str,
156 | default="data/CSTAG/Photo/LinkPrediction/",
157 | help="Path to save splitting",
158 | )
159 | parser.add_argument(
160 | "--graph_path",
161 | "-gp",
162 | type=str,
163 | default="data/CSTAG/Photo/children.pkl",
164 | help="Path to load the graph",
165 | )
166 |
167 | args = parser.parse_args()
168 |
169 | wandb.config = args
170 | wandb.init(config=args, reinit=True)
171 | print(args)
172 |
173 | if not os.path.exists(f"{args.path}{args.neg_len}/"):
174 | os.makedirs(f"{args.path}{args.neg_len}/")
175 |
176 | device = f"cuda:{args.device}" if torch.cuda.is_available() else "cpu"
177 | device = torch.device(device)
178 |
179 | with open(f"{args.graph_path}", "rb") as file:
180 | graph = pickle.load(file)
181 |
182 | graph.num_nodes = len(graph.text_nodes)
183 |
184 | edge_split = split_edge_mrr(
185 | graph,
186 | test_ratio=args.test_ratio,
187 | val_ratio=args.val_ratio,
188 | path=args.path,
189 | neg_len=args.neg_len,
190 | )
191 |
192 | x = torch.load(args.use_PLM_node).float().to(device)
193 | edge_feature = torch.load(args.use_PLM_edge)[
194 | edge_split["train"]["train_edge_feature_index"]
195 | ].float()
196 |
197 | x = x.to(device)
198 |
199 | torch.manual_seed(42)
200 | idx = torch.randperm(edge_split["train"]["source_node"].numel())[
201 | : len(edge_split["valid"]["source_node"])
202 | ]
203 | edge_split["eval_train"] = {
204 | "source_node": edge_split["train"]["source_node"][idx],
205 | "target_node": edge_split["train"]["target_node"][idx],
206 | "target_node_neg": edge_split["valid"]["target_node_neg"],
207 | }
208 |
209 |
210 | edge_index = edge_split["train"]["edge"].t()
211 | adj_t = SparseTensor.from_edge_index(edge_index, edge_feature, sparse_sizes=(graph.num_nodes, graph.num_nodes)).t()
212 | adj_t = adj_t.to_symmetric().to(device)
213 |
214 | if args.gnn_model == "GAT": # TODO: use `gen_model` func to simplify the code
215 | model = GAT(
216 | x.size(1),
217 | edge_feature.size(1),
218 | args.hidden_channels,
219 | args.hidden_channels,
220 | args.num_layers,
221 | args.heads,
222 | args.dropout,
223 | ).to(device)
224 | elif args.gnn_model == "GraphTransformer":
225 | model = GraphTransformer(
226 | x.size(1),
227 | edge_feature.size(1),
228 | args.hidden_channels,
229 | args.hidden_channels,
230 | args.num_layers,
231 | args.dropout,
232 | ).to(device)
233 | elif args.gnn_model == "GINE":
234 | model = GINE(
235 | x.size(1),
236 | edge_feature.size(1),
237 | args.hidden_channels,
238 | args.hidden_channels,
239 | args.num_layers,
240 | args.dropout,
241 | ).to(device)
242 | elif args.gnn_model == "GeneralGNN":
243 | model = GeneralGNN(
244 | x.size(1),
245 | edge_feature.size(1),
246 | args.hidden_channels,
247 | args.hidden_channels,
248 | args.num_layers,
249 | args.dropout,
250 | ).to(device)
251 | else:
252 | raise ValueError("Not implemented")
253 |
254 | predictor = LinkPredictor(
255 | args.hidden_channels, args.hidden_channels, 1, args.num_layers, args.dropout
256 | ).to(device)
257 |
258 | evaluator = Evaluator(name="DBLP")
259 | logger = Logger(args.runs, args)
260 |
261 | for run in range(args.runs):
262 | model.reset_parameters()
263 | predictor.reset_parameters()
264 | optimizer = torch.optim.Adam(
265 | list(model.parameters()) + list(predictor.parameters()), lr=args.lr
266 | )
267 |
268 | for epoch in range(1, 1 + args.epochs):
269 | loss = train(
270 | model, predictor, x, adj_t, edge_split, optimizer, args.batch_size
271 | )
272 | wandb.log({"Loss": loss})
273 | if epoch % args.eval_steps == 0:
274 | result = test(
275 | model,
276 | predictor,
277 | x,
278 | adj_t,
279 | edge_split,
280 | evaluator,
281 | args.batch_size,
282 | args.neg_len,
283 | )
284 | logger.add_result(run, result)
285 |
286 | if epoch % args.log_steps == 0:
287 | train_mrr, valid_mrr, test_mrr = result
288 | print(
289 | f"Run: {run + 1:02d}, "
290 | f"Epoch: {epoch:02d}, "
291 | f"Loss: {loss:.4f}, "
292 | f"Train: {train_mrr:.4f}, "
293 | f"Valid: {valid_mrr:.4f}, "
294 | f"Test: {test_mrr:.4f}"
295 | )
296 |
297 | logger.print_statistics(run)
298 |
299 | logger.print_statistics(key="mrr")
300 |
301 |
302 | if __name__ == "__main__":
303 | main()
304 |
--------------------------------------------------------------------------------
/GNN/GNN_Link_HitsK.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import pickle
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from torch.utils.data import DataLoader
7 |
8 | from torch_sparse import SparseTensor
9 |
10 | from model.GNN_library import GAT, GINE, GeneralGNN, GraphTransformer
11 | from model.Dataloader import Evaluator, split_edge
12 | from model.GNN_arg import Logger
13 | import wandb
14 | import os
15 |
16 |
17 | class LinkPredictor(torch.nn.Module):
18 | def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout):
19 | super(LinkPredictor, self).__init__()
20 |
21 | self.lins = torch.nn.ModuleList()
22 | self.lins.append(torch.nn.Linear(in_channels, hidden_channels))
23 | for _ in range(num_layers - 2):
24 | self.lins.append(torch.nn.Linear(hidden_channels, hidden_channels))
25 | self.lins.append(torch.nn.Linear(hidden_channels, out_channels))
26 |
27 | self.dropout = dropout
28 |
29 | def reset_parameters(self):
30 | for lin in self.lins:
31 | lin.reset_parameters()
32 |
33 | def forward(self, x_i, x_j):
34 | x = x_i * x_j
35 | for lin in self.lins[:-1]:
36 | x = lin(x)
37 | x = F.relu(x)
38 | x = F.dropout(x, p=self.dropout, training=self.training)
39 | x = self.lins[-1](x)
40 | return torch.sigmoid(x)
41 |
42 |
43 | def train(model, predictor, x, adj_t, edge_split, optimizer, batch_size):
44 | model.train()
45 | predictor.train()
46 |
47 | pos_train_edge = edge_split["train"]["edge"].to(x.device)
48 |
49 | total_loss = total_examples = 0
50 | for perm in DataLoader(range(pos_train_edge.size(0)), batch_size, shuffle=True):
51 | optimizer.zero_grad()
52 |
53 | h = model(x, adj_t)
54 |
55 | edge = pos_train_edge[perm].t()
56 |
57 | pos_out = predictor(h[edge[0]], h[edge[1]])
58 | pos_loss = -torch.log(pos_out + 1e-15).mean()
59 |
60 | # Just do some trivial random sampling.
61 | edge = torch.randint(
62 | 0, x.size(0), edge.size(), dtype=torch.long, device=h.device
63 | )
64 | neg_out = predictor(h[edge[0]], h[edge[1]])
65 | neg_loss = -torch.log(1 - neg_out + 1e-15).mean()
66 |
67 | loss = pos_loss + neg_loss
68 | loss.backward()
69 |
70 | optimizer.step()
71 |
72 | num_examples = pos_out.size(0)
73 | total_loss += loss.item() * num_examples
74 | total_examples += num_examples
75 |
76 | return total_loss / total_examples
77 |
78 |
79 | @torch.no_grad()
80 | def test(model, predictor, x, adj_t, edge_split, evaluator, batch_size):
81 | model.eval()
82 | predictor.eval()
83 |
84 | h = model(x, adj_t)
85 |
86 | pos_train_edge = edge_split["train"]["edge"].to(h.device)
87 | pos_valid_edge = edge_split["valid"]["edge"].to(h.device)
88 | neg_valid_edge = edge_split["valid"]["edge_neg"].to(h.device)
89 | pos_test_edge = edge_split["test"]["edge"].to(h.device)
90 | neg_test_edge = edge_split["test"]["edge_neg"].to(h.device)
91 |
92 | pos_train_preds = []
93 | for perm in DataLoader(range(pos_train_edge.size(0)), batch_size):
94 | edge = pos_train_edge[perm].t()
95 | pos_train_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
96 | pos_train_pred = torch.cat(pos_train_preds, dim=0)
97 |
98 | pos_valid_preds = []
99 | for perm in DataLoader(range(pos_valid_edge.size(0)), batch_size):
100 | edge = pos_valid_edge[perm].t()
101 | pos_valid_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
102 | pos_valid_pred = torch.cat(pos_valid_preds, dim=0)
103 |
104 | neg_valid_preds = []
105 | for perm in DataLoader(range(neg_valid_edge.size(0)), batch_size):
106 | edge = neg_valid_edge[perm].t()
107 | neg_valid_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
108 | neg_valid_pred = torch.cat(neg_valid_preds, dim=0)
109 |
110 | h = model(x, adj_t)
111 |
112 | pos_test_preds = []
113 | for perm in DataLoader(range(pos_test_edge.size(0)), batch_size):
114 | edge = pos_test_edge[perm].t()
115 | pos_test_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
116 | pos_test_pred = torch.cat(pos_test_preds, dim=0)
117 |
118 | neg_test_preds = []
119 | for perm in DataLoader(range(neg_test_edge.size(0)), batch_size):
120 | edge = neg_test_edge[perm].t()
121 | neg_test_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
122 | neg_test_pred = torch.cat(neg_test_preds, dim=0)
123 |
124 | results = {}
125 | for K in [10, 50, 100]:
126 | evaluator.K = K
127 | train_hits = evaluator.eval(
128 | {
129 | "y_pred_pos": pos_train_pred,
130 | "y_pred_neg": neg_valid_pred,
131 | }
132 | )[f"hits@{K}"]
133 | valid_hits = evaluator.eval(
134 | {
135 | "y_pred_pos": pos_valid_pred,
136 | "y_pred_neg": neg_valid_pred,
137 | }
138 | )[f"hits@{K}"]
139 | test_hits = evaluator.eval(
140 | {
141 | "y_pred_pos": pos_test_pred,
142 | "y_pred_neg": neg_test_pred,
143 | }
144 | )[f"hits@{K}"]
145 |
146 | results[f"Hits@{K}"] = (train_hits, valid_hits, test_hits)
147 |
148 | return results
149 |
150 |
151 | def main():
152 | parser = argparse.ArgumentParser(description="Link-Prediction PLM/TCL")
153 | parser.add_argument("--device", type=int, default=0)
154 | parser.add_argument("--log_steps", type=int, default=1)
155 | parser.add_argument("--use_node_embedding", action="store_true")
156 | parser.add_argument("--num_layers", type=int, default=3)
157 | parser.add_argument("--hidden_channels", type=int, default=256)
158 | parser.add_argument("--dropout", type=float, default=0.0)
159 | parser.add_argument("--batch_size", type=int, default=64 * 1024)
160 | parser.add_argument("--lr", type=float, default=0.001)
161 | parser.add_argument("--epochs", type=int, default=10)
162 | parser.add_argument(
163 | "--gnn_model", type=str, help="GNN Model", default="GraphTransformer"
164 | )
165 | parser.add_argument("--heads", type=int, default=4)
166 | parser.add_argument("--eval_steps", type=int, default=1)
167 | parser.add_argument("--runs", type=int, default=5)
168 | parser.add_argument("--test_ratio", type=float, default=0.08)
169 | parser.add_argument("--val_ratio", type=float, default=0.02)
170 | parser.add_argument("--neg_len", type=str, default="10000")
171 | parser.add_argument(
172 | "--use_PLM_node",
173 | type=str,
174 | default="data/CSTAG/Photo/Feature/children_gpt_node.pt",
175 | help="Use LM embedding as node feature",
176 | )
177 | parser.add_argument(
178 | "--use_PLM_edge",
179 | type=str,
180 | default="data/CSTAG/Photo/Feature/children_gpt_edge.pt",
181 | help="Use LM embedding as edge feature",
182 | )
183 | parser.add_argument(
184 | "--path",
185 | type=str,
186 | default="data/CSTAG/Photo/LinkPrediction/",
187 | help="Path to save splitting",
188 | )
189 | parser.add_argument(
190 | "--graph_path",
191 | type=str,
192 | default="data/CSTAG/Photo/children.pkl",
193 | help="Path to load the graph",
194 | )
195 | args = parser.parse_args()
196 | wandb.config = args
197 | wandb.init(config=args, reinit=True)
198 | print(args)
199 |
200 | if not os.path.exists(f"{args.path}{args.neg_len}/"):
201 | os.makedirs(f"{args.path}{args.neg_len}/")
202 |
203 | device = f"cuda:{args.device}" if torch.cuda.is_available() else "cpu"
204 | device = torch.device(device)
205 |
206 | with open(f"{args.graph_path}", "rb") as file:
207 | graph = pickle.load(file)
208 |
209 | graph.num_nodes = len(graph.text_nodes)
210 | edge_split = split_edge(
211 | graph, test_ratio=args.test_ratio, val_ratio=args.val_ratio, path=args.path, neg_len=args.neg_len
212 | )
213 |
214 | x = torch.load(args.use_PLM_node).float().to(device)
215 | edge_feature = torch.load(args.use_PLM_edge)[
216 | edge_split["train"]["train_edge_feature_index"]
217 | ].float()
218 |
219 | edge_index = edge_split["train"]["edge"].t()
220 | adj_t = SparseTensor.from_edge_index(edge_index, edge_feature, sparse_sizes=(graph.num_nodes, graph.num_nodes)).t()
221 | adj_t = adj_t.to_symmetric().to(device)
222 | if args.gnn_model == "GAT":
223 | model = GAT(
224 | x.size(1),
225 | edge_feature.size(1),
226 | args.hidden_channels,
227 | args.hidden_channels,
228 | args.num_layers,
229 | args.heads,
230 | args.dropout,
231 | ).to(device)
232 | elif args.gnn_model == "GraphTransformer":
233 | model = GraphTransformer(
234 | x.size(1),
235 | edge_feature.size(1),
236 | args.hidden_channels,
237 | args.hidden_channels,
238 | args.num_layers,
239 | args.dropout,
240 | ).to(device)
241 | elif args.gnn_model == "GINE":
242 | model = GINE(
243 | x.size(1),
244 | edge_feature.size(1),
245 | args.hidden_channels,
246 | args.hidden_channels,
247 | args.num_layers,
248 | args.dropout,
249 | ).to(device)
250 | elif args.gnn_model == "GeneralGNN":
251 | model = GeneralGNN(
252 | x.size(1),
253 | edge_feature.size(1),
254 | args.hidden_channels,
255 | args.hidden_channels,
256 | args.num_layers,
257 | args.dropout,
258 | ).to(device)
259 | else:
260 | raise ValueError("Not implemented")
261 |
262 | predictor = LinkPredictor(
263 | args.hidden_channels, args.hidden_channels, 1, args.num_layers, args.dropout
264 | ).to(device)
265 |
266 | evaluator = Evaluator(name="History")
267 | loggers = {
268 | "Hits@10": Logger(args.runs, args),
269 | "Hits@50": Logger(args.runs, args),
270 | "Hits@100": Logger(args.runs, args),
271 | }
272 |
273 | for run in range(args.runs):
274 | model.reset_parameters()
275 | predictor.reset_parameters()
276 | optimizer = torch.optim.Adam(
277 | list(model.parameters()) + list(predictor.parameters()), lr=args.lr
278 | )
279 |
280 | for epoch in range(1, 1 + args.epochs):
281 | loss = train(
282 | model, predictor, x, adj_t, edge_split, optimizer, args.batch_size
283 | )
284 |
285 | if epoch % args.eval_steps == 0:
286 | results = test(
287 | model, predictor, x, adj_t, edge_split, evaluator, args.batch_size
288 | )
289 | for key, result in results.items():
290 | loggers[key].add_result(run, result)
291 |
292 | if epoch % args.log_steps == 0:
293 | for key, result in results.items():
294 | train_hits, valid_hits, test_hits = result
295 | print(key)
296 | print(
297 | f"Run: {run + 1:02d}, "
298 | f"Epoch: {epoch:02d}, "
299 | f"Loss: {loss:.4f}, "
300 | f"Train: {100 * train_hits:.2f}%, "
301 | f"Valid: {100 * valid_hits:.2f}%, "
302 | f"Test: {100 * test_hits:.2f}%"
303 | )
304 | print("---")
305 |
306 | for key in loggers.keys():
307 | print(key)
308 | loggers[key].print_statistics(run)
309 |
310 | for key in loggers.keys():
311 | print(key)
312 | loggers[key].print_statistics(key=key)
313 |
314 |
315 | if __name__ == "__main__":
316 | main()
317 |
--------------------------------------------------------------------------------
/GNN/GNN_Link_MRR_loader.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.data import Data
2 | from torch_geometric.loader import LinkNeighborLoader
3 | import argparse
4 | import pickle
5 |
6 | import torch
7 | import torch.nn.functional as F
8 |
9 | from torch_sparse import SparseTensor
10 |
11 | from model.GNN_library import GAT, GINE, GeneralGNN, GraphTransformer
12 | from model.Dataloader import Evaluator, split_edge_mrr
13 | from model.GNN_arg import Logger
14 | import wandb
15 | import os
16 |
17 |
18 | class LinkPredictor(torch.nn.Module):
19 | def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout):
20 | super(LinkPredictor, self).__init__()
21 |
22 | self.lins = torch.nn.ModuleList()
23 | self.lins.append(torch.nn.Linear(in_channels, hidden_channels))
24 | for _ in range(num_layers - 2):
25 | self.lins.append(torch.nn.Linear(hidden_channels, hidden_channels))
26 | self.lins.append(torch.nn.Linear(hidden_channels, out_channels))
27 |
28 | self.dropout = dropout
29 |
30 | def reset_parameters(self):
31 | for lin in self.lins:
32 | lin.reset_parameters()
33 |
34 | def forward(self, x_i, x_j):
35 | x = x_i * x_j
36 | for lin in self.lins[:-1]:
37 | x = lin(x)
38 | x = F.relu(x)
39 | x = F.dropout(x, p=self.dropout, training=self.training)
40 | x = self.lins[-1](x)
41 | return torch.sigmoid(x)
42 |
43 |
44 | def gen_loader(args, edge_split, x, edge_index, adj_t):
45 | train_data = Data(x=x, adj_t=adj_t)
46 | train_loader = LinkNeighborLoader(
47 | data=train_data,
48 | num_neighbors=[20, 10],
49 | edge_label_index=(edge_index),
50 | edge_label=torch.ones(edge_index.shape[1]),
51 | batch_size=args.batch_size,
52 | neg_sampling_ratio=0.0,
53 | shuffle=True,
54 | )
55 |
56 | val_edge_label_index = edge_split["valid"]["edge"].t()
57 | val_loader = LinkNeighborLoader(
58 | data=train_data,
59 | num_neighbors=[20, 10],
60 | edge_label_index=(val_edge_label_index),
61 | edge_label=torch.ones(val_edge_label_index.shape[1]),
62 | batch_size=args.batch_size,
63 | neg_sampling_ratio=0.0,
64 | shuffle=False,
65 | )
66 |
67 | test_edge_label_index = edge_split["test"]["edge"].t()
68 | test_loader = LinkNeighborLoader(
69 | data=train_data,
70 | num_neighbors=[20, 10],
71 | edge_label_index=(test_edge_label_index),
72 | edge_label=torch.ones(test_edge_label_index.shape[1]),
73 | batch_size=args.batch_size,
74 | neg_sampling_ratio=0.0,
75 | shuffle=False,
76 | )
77 |
78 | return train_loader, val_loader, test_loader
79 |
80 |
81 | def gen_model(args, x, edge_feature):
82 | if args.gnn_model == "GAT":
83 | model = GAT(
84 | x.size(1),
85 | edge_feature.size(1),
86 | args.hidden_channels,
87 | args.hidden_channels,
88 | args.num_layers,
89 | args.heads,
90 | args.dropout,
91 | )
92 | elif args.gnn_model == "GraphTransformer":
93 | model = GraphTransformer(
94 | x.size(1),
95 | edge_feature.size(1),
96 | args.hidden_channels,
97 | args.hidden_channels,
98 | args.num_layers,
99 | args.dropout,
100 | )
101 | elif args.gnn_model == "GINE":
102 | model = GINE(
103 | x.size(1),
104 | edge_feature.size(1),
105 | args.hidden_channels,
106 | args.hidden_channels,
107 | args.num_layers,
108 | args.dropout,
109 | )
110 | elif args.gnn_model == "GeneralGNN":
111 | model = GeneralGNN(
112 | x.size(1),
113 | edge_feature.size(1),
114 | args.hidden_channels,
115 | args.hidden_channels,
116 | args.num_layers,
117 | args.dropout,
118 | )
119 | else:
120 | raise ValueError("Not implemented")
121 | return model
122 |
123 |
124 | def train(model, predictor, train_loader, optimizer, device):
125 | model.train()
126 | predictor.train()
127 |
128 | total_loss = total_examples = 0
129 | for batch in train_loader:
130 | optimizer.zero_grad()
131 | batch = batch.to(device)
132 | h = model(batch.x, batch.adj_t)
133 |
134 | src = batch.edge_label_index.t()[:, 0]
135 | dst = batch.edge_label_index.t()[:, 1]
136 |
137 | pos_out = predictor(h[src], h[dst])
138 | pos_loss = -torch.log(pos_out + 1e-15).mean()
139 |
140 | # Just do some trivial random sampling.
141 | dst_neg = torch.randint(
142 | 0, h.size(0), src.size(), dtype=torch.long, device=h.device
143 | )
144 | neg_out = predictor(h[src], h[dst_neg])
145 | neg_loss = -torch.log(1 - neg_out + 1e-15).mean()
146 |
147 | loss = pos_loss + neg_loss
148 | loss.backward()
149 | optimizer.step()
150 |
151 | num_examples = pos_out.size(0)
152 | total_loss += loss.item() * num_examples
153 | total_examples += num_examples
154 |
155 | return total_loss / total_examples
156 |
157 |
158 | @torch.no_grad()
159 | def test(model, predictor, dataloaders, evaluator, neg_len, device):
160 | model.eval()
161 | predictor.eval()
162 |
163 | def test_split(dataloader, neg_len, device):
164 | pos_preds = []
165 | for batch in dataloader:
166 | batch = batch.to(device)
167 | h = model(batch.x, batch.adj_t)
168 | src = batch.edge_label_index.t()[:, 0]
169 | dst = batch.edge_label_index.t()[:, 1]
170 | pos_preds += [predictor(h[src], h[dst]).squeeze().cpu()]
171 | pos_pred = torch.cat(pos_preds, dim=0)
172 |
173 | neg_preds = []
174 |
175 | for batch in dataloader:
176 | batch = batch.to(device)
177 | h = model(batch.x, batch.adj_t)
178 | src = batch.edge_label_index.t()[:, 0]
179 | dst_neg = torch.randint(
180 |                 0, h.size(0), [len(src), int(neg_len)], dtype=torch.long, device=h.device
181 | ).view(-1)
182 | src = src.view(-1, 1).repeat(1, neg_len).view(-1)
183 | neg_preds += [predictor(h[src], h[dst_neg]).squeeze().cpu()]
184 | neg_pred = torch.cat(neg_preds, dim=0).view(-1, neg_len)
185 |
186 | return (
187 | evaluator.eval(
188 | {
189 | "y_pred_pos": pos_pred,
190 | "y_pred_neg": neg_pred,
191 | }
192 | )["mrr_list"]
193 | .mean()
194 | .item()
195 | )
196 |
197 | train_mrr = test_split(dataloaders["train"], neg_len, device)
198 | valid_mrr = test_split(dataloaders["valid"], neg_len, device)
199 | test_mrr = test_split(dataloaders["test"], neg_len, device)
200 |
201 | return train_mrr, valid_mrr, test_mrr
202 |
203 |
204 | def main():
205 | parser = argparse.ArgumentParser(description="Link-Prediction PLM/TCL")
206 | parser.add_argument("--device", type=int, default=0)
207 | parser.add_argument("--log_steps", type=int, default=1) # TODO: update latter
208 | parser.add_argument("--use_node_embedding", action="store_true")
209 | parser.add_argument("--num_layers", type=int, default=2)
210 | parser.add_argument("--hidden_channels", type=int, default=256)
211 | parser.add_argument("--dropout", type=float, default=0.0)
212 | parser.add_argument("--batch_size", type=int, default=64 * 1024)
213 | parser.add_argument("--lr", type=float, default=0.0005)
214 | parser.add_argument("--epochs", type=int, default=100)
215 | parser.add_argument("--gnn_model", type=str, help="GNN Model", default="GeneralGNN")
216 | parser.add_argument("--heads", type=int, default=4)
217 | parser.add_argument("--eval_steps", type=int, default=1)
218 | parser.add_argument("--runs", type=int, default=5)
219 | parser.add_argument("--test_ratio", type=float, default=0.08)
220 | parser.add_argument("--val_ratio", type=float, default=0.02)
221 | parser.add_argument("--neg_len", type=int, default=50)
222 | parser.add_argument(
223 | "--use_PLM_node",
224 | "-pn",
225 | type=str,
226 | default="data/CSTAG/Photo/Feature/children_gpt_node.pt",
227 | help="Use LM embedding as node feature",
228 | )
229 | parser.add_argument(
230 | "--use_PLM_edge",
231 | "-pe",
232 | type=str,
233 | default="data/CSTAG/Photo/Feature/children_gpt_edge.pt",
234 | help="Use LM embedding as edge feature",
235 | )
236 | parser.add_argument(
237 | "--path",
238 | type=str,
239 | default="data/CSTAG/Photo/LinkPrediction/",
240 | help="Path to save splitting",
241 | )
242 | parser.add_argument(
243 | "--graph_path",
244 | "-gp",
245 | type=str,
246 | default="data/CSTAG/Photo/children.pkl",
247 | help="Path to load the graph",
248 | )
249 |
250 | args = parser.parse_args()
251 |
252 | wandb.config = args
253 | wandb.init(config=args, reinit=True)
254 | print(args)
255 |
256 | if not os.path.exists(f"{args.path}{args.neg_len}/"):
257 | os.makedirs(f"{args.path}{args.neg_len}/")
258 |
259 | device = f"cuda:{args.device}" if torch.cuda.is_available() else "cpu"
260 | device = torch.device(device)
261 |
262 | with open(f"{args.graph_path}", "rb") as file:
263 | graph = pickle.load(file)
264 |
265 | graph.num_nodes = len(graph.text_nodes)
266 |
267 | edge_split = split_edge_mrr(
268 | graph,
269 | test_ratio=args.test_ratio,
270 | val_ratio=args.val_ratio,
271 | path=args.path,
272 | neg_len=args.neg_len,
273 | )
274 |
275 | x = torch.load(args.use_PLM_node).float()
276 | edge_feature = torch.load(args.use_PLM_edge)[
277 | edge_split["train"]["train_edge_feature_index"]
278 | ].float()
279 |
280 | edge_index = edge_split["train"]["edge"].t()
281 | adj_t = SparseTensor.from_edge_index(
282 | edge_index, edge_feature, sparse_sizes=(graph.num_nodes, graph.num_nodes)
283 | ).t()
284 | adj_t = adj_t.to_symmetric()
285 |
286 | train_loader, val_loader, test_loader = gen_loader(
287 | args, edge_split, x, edge_index, adj_t
288 | )
289 | dataloaders = {"train": train_loader, "valid": val_loader, "test": test_loader}
290 |
291 | model = gen_model(args, x, edge_feature).to(device)
292 |
293 | predictor = LinkPredictor(
294 | args.hidden_channels, args.hidden_channels, 1, args.num_layers, args.dropout
295 | ).to(device)
296 |
297 | evaluator = Evaluator(name="DBLP")
298 | logger = Logger(args.runs, args)
299 |
300 | for run in range(args.runs):
301 | model.reset_parameters()
302 | predictor.reset_parameters()
303 | optimizer = torch.optim.Adam(
304 | list(model.parameters()) + list(predictor.parameters()), lr=args.lr
305 | )
306 |
307 | for epoch in range(1, 1 + args.epochs):
308 | loss = train(model, predictor, train_loader, optimizer, device)
309 | wandb.log({"Loss": loss})
310 | if epoch % args.eval_steps == 0:
311 | result = test(
312 | model, predictor, dataloaders, evaluator, args.neg_len, device
313 | )
314 | logger.add_result(run, result)
315 |
316 | if epoch % args.log_steps == 0:
317 | train_mrr, valid_mrr, test_mrr = result
318 | print(
319 | f"Run: {run + 1:02d}, "
320 | f"Epoch: {epoch:02d}, "
321 | f"Loss: {loss:.4f}, "
322 | f"Train: {train_mrr:.4f}, "
323 | f"Valid: {valid_mrr:.4f}, "
324 | f"Test: {test_mrr:.4f}"
325 | )
326 |
327 | logger.print_statistics(run)
328 |
329 | logger.print_statistics(key="mrr")
330 |
331 |
332 | if __name__ == "__main__":
333 | main()
334 |
--------------------------------------------------------------------------------
/GNN/GNN_Link_HitsK_loader.py:
--------------------------------------------------------------------------------
1 | import tqdm
2 | from torch_geometric.data import Data
3 | from torch_geometric.loader import LinkNeighborLoader
4 | import argparse
5 | import pickle
6 |
7 | import torch
8 | import torch.nn.functional as F
9 |
10 | from torch_sparse import SparseTensor
11 |
12 | from model.GNN_library import GAT, GINE, GeneralGNN, GraphTransformer
13 | from model.Dataloader import Evaluator, split_edge
14 | from model.GNN_arg import Logger
15 | import wandb
16 | import os
17 |
18 |
19 | class LinkPredictor(torch.nn.Module):
20 | def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout):
21 | super(LinkPredictor, self).__init__()
22 |
23 | self.lins = torch.nn.ModuleList()
24 | self.lins.append(torch.nn.Linear(in_channels, hidden_channels))
25 | for _ in range(num_layers - 2):
26 | self.lins.append(torch.nn.Linear(hidden_channels, hidden_channels))
27 | self.lins.append(torch.nn.Linear(hidden_channels, out_channels))
28 |
29 | self.dropout = dropout
30 |
31 | def reset_parameters(self):
32 | for lin in self.lins:
33 | lin.reset_parameters()
34 |
35 | def forward(self, x_i, x_j):
36 | x = x_i * x_j
37 | for lin in self.lins[:-1]:
38 | x = lin(x)
39 | x = F.relu(x)
40 | x = F.dropout(x, p=self.dropout, training=self.training)
41 | x = self.lins[-1](x)
42 | return torch.sigmoid(x)
43 |
44 |
45 | def train(model, predictor, train_loader, optimizer, device):
46 | model.train()
47 | predictor.train()
48 |
49 | total_loss = total_examples = 0
50 | for batch in tqdm.tqdm(train_loader):
51 | optimizer.zero_grad()
52 | batch = batch.to(device)
53 | h = model(batch.x, batch.adj_t)
54 |
55 | src = batch.edge_label_index.t()[:, 0]
56 | dst = batch.edge_label_index.t()[:, 1]
57 |
58 | pos_out = predictor(h[src], h[dst])
59 | pos_loss = -torch.log(pos_out + 1e-15).mean()
60 |
61 | # Just do some trivial random sampling.
62 | dst_neg = torch.randint(
63 | 0, h.size(0), src.size(), dtype=torch.long, device=h.device
64 | )
65 | neg_out = predictor(h[src], h[dst_neg])
66 | neg_loss = -torch.log(1 - neg_out + 1e-15).mean()
67 |
68 | loss = pos_loss + neg_loss
69 | loss.backward()
70 | optimizer.step()
71 |
72 | num_examples = pos_out.size(0)
73 | total_loss += loss.item() * num_examples
74 | total_examples += num_examples
75 |
76 | return total_loss / total_examples
77 |
78 |
79 | @torch.no_grad()
80 | def test(model, predictor, dataloaders, evaluator, device):
81 | model.eval()
82 | predictor.eval()
83 | (
84 | train_loader,
85 | valid_loader,
86 | valid_negative_loader,
87 | test_loader,
88 | test_negative_loader,
89 | ) = (
90 | dataloaders["train"],
91 | dataloaders["valid"],
92 | dataloaders["valid_negative"],
93 | dataloaders["test"],
94 | dataloaders["test_negative"],
95 | )
96 |
97 | def test_split(dataloader, neg_dataloader, device):
98 | pos_preds = []
99 | for batch in dataloader:
100 | batch = batch.to(device)
101 | h = model(batch.x, batch.adj_t)
102 | src = batch.edge_label_index.t()[:, 0]
103 | dst = batch.edge_label_index.t()[:, 1]
104 | pos_preds += [predictor(h[src], h[dst]).squeeze().cpu()]
105 | pos_pred = torch.cat(pos_preds, dim=0)
106 |
107 | neg_preds = []
108 | for batch in neg_dataloader:
109 | batch = batch.to(device)
110 | h = model(batch.x, batch.adj_t)
111 | src = batch.edge_label_index.t()[:, 0]
112 | dst = batch.edge_label_index.t()[:, 1]
113 | neg_preds += [predictor(h[src], h[dst]).squeeze().cpu()]
114 | neg_pred = torch.cat(neg_preds, dim=0)
115 | return pos_pred, neg_pred
116 |
117 | pos_train_pred, neg_valid_pred = test_split(
118 | train_loader, valid_negative_loader, device
119 | )
120 | pos_valid_pred, neg_valid_pred = test_split(
121 | valid_loader, valid_negative_loader, device
122 | )
123 | pos_test_pred, neg_test_pred = test_split(test_loader, test_negative_loader, device)
124 |
125 | results = {}
126 | for K in [10, 50, 100]:
127 | evaluator.K = K
128 | train_hits = evaluator.eval(
129 | {
130 | "y_pred_pos": pos_train_pred,
131 | "y_pred_neg": neg_valid_pred,
132 | }
133 | )[f"hits@{K}"]
134 | valid_hits = evaluator.eval(
135 | {
136 | "y_pred_pos": pos_valid_pred,
137 | "y_pred_neg": neg_valid_pred,
138 | }
139 | )[f"hits@{K}"]
140 | test_hits = evaluator.eval(
141 | {
142 | "y_pred_pos": pos_test_pred,
143 | "y_pred_neg": neg_test_pred,
144 | }
145 | )[f"hits@{K}"]
146 |
147 | results[f"Hits@{K}"] = (train_hits, valid_hits, test_hits)
148 |
149 | return results
150 |
151 | def gen_model(args, x, edge_feature):
152 | if args.gnn_model == "GAT":
153 | model = GAT(
154 | x.size(1),
155 | edge_feature.size(1),
156 | args.hidden_channels,
157 | args.hidden_channels,
158 | args.num_layers,
159 | args.heads,
160 | args.dropout,
161 | )
162 | elif args.gnn_model == "GraphTransformer":
163 | model = GraphTransformer(
164 | x.size(1),
165 | edge_feature.size(1),
166 | args.hidden_channels,
167 | args.hidden_channels,
168 | args.num_layers,
169 | args.dropout,
170 | )
171 | elif args.gnn_model == "GINE":
172 | model = GINE(
173 | x.size(1),
174 | edge_feature.size(1),
175 | args.hidden_channels,
176 | args.hidden_channels,
177 | args.num_layers,
178 | args.dropout,
179 | )
180 | elif args.gnn_model == "GeneralGNN":
181 | model = GeneralGNN(
182 | x.size(1),
183 | edge_feature.size(1),
184 | args.hidden_channels,
185 | args.hidden_channels,
186 | args.num_layers,
187 | args.dropout,
188 | )
189 | else:
190 | raise ValueError("Not implemented")
191 | return model
192 |
193 |
194 | def gen_loader(args, edge_split, x, edge_index, adj_t):
195 | train_data = Data(x=x, adj_t=adj_t)
196 | train_loader = LinkNeighborLoader(
197 | data=train_data,
198 | num_neighbors=[20, 10],
199 | edge_label_index=(edge_index),
200 | edge_label=torch.ones(edge_index.shape[1]),
201 | batch_size=args.batch_size,
202 | neg_sampling_ratio=0.0,
203 | shuffle=True,
204 | )
205 |
206 | val_edge_label_index = edge_split["valid"]["edge"].t()
207 | val_loader = LinkNeighborLoader(
208 | data=train_data,
209 | num_neighbors=[20, 10],
210 | edge_label_index=(val_edge_label_index),
211 | edge_label=torch.ones(val_edge_label_index.shape[1]),
212 | batch_size=args.batch_size,
213 | neg_sampling_ratio=0.0,
214 | shuffle=False,
215 | )
216 |
217 | val_negative_edge_label_index = edge_split["valid"]["edge_neg"].t()
218 | val_negative_loader = LinkNeighborLoader(
219 | data=train_data,
220 | num_neighbors=[20, 10],
221 | edge_label_index=(val_negative_edge_label_index),
222 | edge_label=torch.zeros(val_negative_edge_label_index.shape[1]),
223 | batch_size=args.batch_size,
224 | neg_sampling_ratio=0.0,
225 | shuffle=False,
226 | )
227 |
228 | test_edge_label_index = edge_split["test"]["edge"].t()
229 | test_loader = LinkNeighborLoader(
230 | data=train_data,
231 | num_neighbors=[20, 10],
232 | edge_label_index=(test_edge_label_index),
233 | edge_label=torch.ones(test_edge_label_index.shape[1]),
234 | batch_size=args.batch_size,
235 | neg_sampling_ratio=0.0,
236 | shuffle=False,
237 | )
238 |
239 | test_negative_edge_label_index = edge_split["test"]["edge_neg"].t()
240 | test_negative_loader = LinkNeighborLoader(
241 | data=train_data,
242 | num_neighbors=[20, 10],
243 | edge_label_index=(test_negative_edge_label_index),
244 | edge_label=torch.zeros(test_negative_edge_label_index.shape[1]),
245 | batch_size=args.batch_size,
246 | neg_sampling_ratio=0.0,
247 | shuffle=False,
248 | )
249 |
250 | return (
251 | train_loader,
252 | val_loader,
253 | val_negative_loader,
254 | test_loader,
255 | test_negative_loader,
256 | )
257 |
258 | def main():
259 | parser = argparse.ArgumentParser(description="Link-Prediction PLM/TCL")
260 | parser.add_argument("--device", type=int, default=0)
261 | parser.add_argument("--log_steps", type=int, default=1)
262 | parser.add_argument("--use_node_embedding", action="store_true")
263 | parser.add_argument("--num_layers", type=int, default=3)
264 | parser.add_argument("--hidden_channels", type=int, default=256)
265 | parser.add_argument("--dropout", type=float, default=0.0)
266 | parser.add_argument("--batch_size", type=int, default=64 * 1024)
267 | parser.add_argument("--lr", type=float, default=0.001)
268 | parser.add_argument("--epochs", type=int, default=10)
269 | parser.add_argument(
270 | "--gnn_model", type=str, help="GNN Model", default="GraphTransformer"
271 | )
272 | parser.add_argument("--heads", type=int, default=4)
273 | parser.add_argument("--eval_steps", type=int, default=1)
274 | parser.add_argument("--runs", type=int, default=1)
275 | parser.add_argument("--test_ratio", type=float, default=0.08)
276 | parser.add_argument("--val_ratio", type=float, default=0.02)
277 | parser.add_argument("--neg_len", type=str, default="5000")
278 | parser.add_argument(
279 | "--use_PLM_node",
280 | type=str,
281 | default="data/CSTAG/Photo/Feature/children_gpt_node.pt",
282 | help="Use LM embedding as node feature",
283 | )
284 | parser.add_argument(
285 | "--use_PLM_edge",
286 | type=str,
287 | default="data/CSTAG/Photo/Feature/children_gpt_edge.pt",
288 | help="Use LM embedding as edge feature",
289 | )
290 | parser.add_argument(
291 | "--path",
292 | type=str,
293 | default="data/CSTAG/Photo/LinkPrediction/",
294 | help="Path to save splitting",
295 | )
296 | parser.add_argument(
297 | "--graph_path",
298 | type=str,
299 | default="data/CSTAG/Photo/children.pkl",
300 | help="Path to load the graph",
301 | )
302 | args = parser.parse_args()
303 | wandb.config = args
304 | wandb.init(config=args, reinit=True)
305 | print(args)
306 |
307 | if not os.path.exists(f"{args.path}{args.neg_len}/"):
308 | os.makedirs(f"{args.path}{args.neg_len}/")
309 |
310 | device = f"cuda:{args.device}" if torch.cuda.is_available() else "cpu"
311 | device = torch.device(device)
312 |
313 | with open(f"{args.graph_path}", "rb") as file:
314 | graph = pickle.load(file)
315 |
316 | graph.num_nodes = len(graph.text_nodes)
317 | edge_split = split_edge(
318 | graph,
319 | test_ratio=args.test_ratio,
320 | val_ratio=args.val_ratio,
321 | path=args.path,
322 | neg_len=args.neg_len,
323 | )
324 |
325 | x = torch.load(args.use_PLM_node).float()
326 | edge_feature = torch.load(args.use_PLM_edge)[
327 | edge_split["train"]["train_edge_feature_index"]
328 | ].float()
329 |
330 | edge_index = edge_split["train"]["edge"].t()
331 | adj_t = SparseTensor.from_edge_index(
332 | edge_index, edge_feature, sparse_sizes=(graph.num_nodes, graph.num_nodes)
333 | ).t()
334 | adj_t = adj_t.to_symmetric()
335 |
336 | train_loader, val_loader, val_negative_loader, test_loader, test_negative_loader = (
337 | gen_loader(args, edge_split, x, edge_index, adj_t)
338 | )
339 | dataloaders = {
340 | "train": train_loader,
341 | "valid": val_loader,
342 | "valid_negative": val_negative_loader,
343 | "test": test_loader,
344 | "test_negative": test_negative_loader,
345 | }
346 |
347 | model = gen_model(args, x, edge_feature).to(device)
348 |
349 | predictor = LinkPredictor(
350 | args.hidden_channels, args.hidden_channels, 1, args.num_layers, args.dropout
351 | ).to(device)
352 |
353 | evaluator = Evaluator(name="History")
354 | loggers = {
355 | "Hits@10": Logger(args.runs, args),
356 | "Hits@50": Logger(args.runs, args),
357 | "Hits@100": Logger(args.runs, args),
358 | }
359 |
360 | for run in range(args.runs):
361 | model.reset_parameters()
362 | predictor.reset_parameters()
363 | optimizer = torch.optim.Adam(
364 | list(model.parameters()) + list(predictor.parameters()), lr=args.lr
365 | )
366 |
367 | for epoch in range(1, 1 + args.epochs):
368 | loss = train(model, predictor, train_loader, optimizer, device)
369 |
370 | if epoch % args.eval_steps == 0:
371 | results = test(
372 | model, predictor, dataloaders, evaluator, device
373 | )
374 | for key, result in results.items():
375 | loggers[key].add_result(run, result)
376 |
377 | if epoch % args.log_steps == 0:
378 | for key, result in results.items():
379 | train_hits, valid_hits, test_hits = result
380 | print(key)
381 | print(
382 | f"Run: {run + 1:02d}, "
383 | f"Epoch: {epoch:02d}, "
384 | f"Loss: {loss:.4f}, "
385 | f"Train: {100 * train_hits:.2f}%, "
386 | f"Valid: {100 * valid_hits:.2f}%, "
387 | f"Test: {100 * test_hits:.2f}%"
388 | )
389 | print("---")
390 |
391 | for key in loggers.keys():
392 | print(key)
393 | loggers[key].print_statistics(run)
394 |
395 | for key in loggers.keys():
396 | print(key)
397 | loggers[key].print_statistics(key=key)
398 |
399 |
400 |
401 | if __name__ == "__main__":
402 | main()
403 |
--------------------------------------------------------------------------------
/GNN/model/Dataloader.py:
--------------------------------------------------------------------------------
1 | from ogb.nodeproppred import DglNodePropPredDataset
2 | import dgl
3 | import numpy as np
4 | import torch as th
5 | from sklearn.model_selection import train_test_split
6 | import scipy.sparse as sp
7 | import os
8 | import random
9 | import pickle
10 |
11 |
12 | def split_graph(nodes_num, train_ratio, val_ratio):
13 | np.random.seed(42)
14 | indices = np.random.permutation(nodes_num)
15 | train_size = int(nodes_num * train_ratio)
16 | val_size = int(nodes_num * val_ratio)
17 |
18 | train_ids = indices[:train_size]
19 | val_ids = indices[train_size : train_size + val_size]
20 | test_ids = indices[train_size + val_size :]
21 |
22 | return train_ids, val_ids, test_ids
23 |
24 |
25 | # def split_time(g, train_year=2016, val_year=2017):
26 | # year = list(np.array(g.ndata['year']))
27 | # indices = np.arange(g.num_nodes())
28 | # # 1999-2014 train
29 | # train_ids = indices[:year.index(train_year)]
30 | # val_ids = indices[year.index(train_year):year.index(val_year)]
31 | # test_ids = indices[year.index(val_year):]
32 | #
33 | # return train_ids, val_ids, test_ids
34 |
35 |
36 | def split_time(g, train_year=2016, val_year=2017):
37 | np.random.seed(42)
38 | year = list(np.array(g.ndata["year"]))
39 | indices = np.arange(g.num_nodes())
40 | # 1999-2014 train
41 | # Filter out nodes with label -1
42 | valid_indices = [i for i in indices if g.ndata["label"][i] != -1]
43 |
44 | # Filter out valid indices based on years
45 | train_ids = [i for i in valid_indices if year[i] < train_year]
46 | val_ids = [i for i in valid_indices if year[i] >= train_year and year[i] < val_year]
47 | test_ids = [i for i in valid_indices if year[i] >= val_year]
48 |
49 | train_length = len(train_ids)
50 | val_length = len(val_ids)
51 | test_length = len(test_ids)
52 |
53 | print("Train set length:", train_length)
54 | print("Validation set length:", val_length)
55 | print("Test set length:", test_length)
56 |
57 | return train_ids, val_ids, test_ids
58 |
59 |
60 | def load_data(name, train_ratio=0.6, val_ratio=0.2):
61 | if name == "ogbn-arxiv":
62 | data = DglNodePropPredDataset(name=name)
63 | splitted_idx = data.get_idx_split()
64 | train_idx, val_idx, test_idx = (
65 | splitted_idx["train"],
66 | splitted_idx["valid"],
67 | splitted_idx["test"],
68 | )
69 | graph, labels = data[0]
70 | labels = labels[:, 0]
71 | elif name == "Children":
72 | graph = dgl.load_graphs("data/CSTAG/Children/Children.pt")[0][0]
73 | labels = graph.ndata["label"]
74 | train_idx, val_idx, test_idx = split_graph(graph.num_nodes(), 0.6, 0.2)
75 | train_idx = th.tensor(train_idx)
76 | val_idx = th.tensor(val_idx)
77 | test_idx = th.tensor(test_idx)
78 | elif name == "History":
79 | graph = dgl.load_graphs("data/CSTAG/History/History.pt")[0][0]
80 | labels = graph.ndata["label"]
81 | train_idx, val_idx, test_idx = split_graph(graph.num_nodes(), 0.6, 0.2)
82 | train_idx = th.tensor(train_idx)
83 | val_idx = th.tensor(val_idx)
84 | test_idx = th.tensor(test_idx)
85 | elif name == "Fitness":
86 | graph = dgl.load_graphs("data/CSTAG/Fitness/Fitness.pt")[0][0]
87 | labels = graph.ndata["label"]
88 | train_idx, val_idx, test_idx = split_graph(
89 | graph.num_nodes(), train_ratio, val_ratio
90 | )
91 | train_idx = th.tensor(train_idx)
92 | val_idx = th.tensor(val_idx)
93 | test_idx = th.tensor(test_idx)
94 | elif name == "Photo":
95 | graph = dgl.load_graphs("data/CSTAG/Photo/Photo.pt")[0][0]
96 | labels = graph.ndata["label"]
97 | train_idx, val_idx, test_idx = split_time(graph, 2015, 2016)
98 | train_idx = th.tensor(train_idx)
99 | val_idx = th.tensor(val_idx)
100 | test_idx = th.tensor(test_idx)
101 | elif name == "Computers":
102 | graph = dgl.load_graphs("data/CSTAG/Computers/Computers.pt")[0][0]
103 | labels = graph.ndata["label"]
104 | train_idx, val_idx, test_idx = split_time(graph, 2017, 2018)
105 | train_idx = th.tensor(train_idx)
106 | val_idx = th.tensor(val_idx)
107 | test_idx = th.tensor(test_idx)
108 | elif name == "webkb-cornell":
109 | graph = dgl.load_graphs("data/webkb/Cornell/Cornell.pt")[0][0]
110 | labels = graph.ndata["label"]
111 | train_idx, val_idx, test_idx = split_graph(
112 | graph.num_nodes(), train_ratio, val_ratio
113 | )
114 | train_idx = th.tensor(train_idx)
115 | val_idx = th.tensor(val_idx)
116 | test_idx = th.tensor(test_idx)
117 | elif name == "webkb-texas":
118 | graph = dgl.load_graphs("data/webkb/Texas/Texas.pt")[0][0]
119 | labels = graph.ndata["label"]
120 | train_idx, val_idx, test_idx = split_graph(
121 | graph.num_nodes(), train_ratio, val_ratio
122 | )
123 | train_idx = th.tensor(train_idx)
124 | val_idx = th.tensor(val_idx)
125 | test_idx = th.tensor(test_idx)
126 | elif name == "webkb-washington":
127 | graph = dgl.load_graphs("data/webkb/Washington/Washington.pt")[0][0]
128 | labels = graph.ndata["label"]
129 | train_idx, val_idx, test_idx = split_graph(
130 | graph.num_nodes(), train_ratio, val_ratio
131 | )
132 | train_idx = th.tensor(train_idx)
133 | val_idx = th.tensor(val_idx)
134 | test_idx = th.tensor(test_idx)
135 | elif name == "webkb-wisconsin":
136 | graph = dgl.load_graphs("data/webkb/Wisconsin/Wisconsin.pt")[0][0]
137 | labels = graph.ndata["label"]
138 | train_idx, val_idx, test_idx = split_graph(
139 | graph.num_nodes(), train_ratio, val_ratio
140 | )
141 | train_idx = th.tensor(train_idx)
142 | val_idx = th.tensor(val_idx)
143 | test_idx = th.tensor(test_idx)
144 | else:
145 |         raise ValueError("Not implemented")
146 | return graph, labels, train_idx, val_idx, test_idx
147 |
148 |
149 | def from_dgl(g):
150 | import dgl
151 |
152 | from torch_geometric.data import Data, HeteroData
153 |
154 | if not isinstance(g, dgl.DGLGraph):
155 | raise ValueError(f"Invalid data type (got '{type(g)}')")
156 |
157 | if g.is_homogeneous:
158 | data = Data()
159 | data.edge_index = th.stack(g.edges(), dim=0)
160 |
161 | for attr, value in g.ndata.items():
162 | data[attr] = value
163 | for attr, value in g.edata.items():
164 | data[attr] = value
165 |
166 | return data
167 |
168 | data = HeteroData() # TODO: Delete this
169 |
170 | for node_type in g.ntypes:
171 | for attr, value in g.nodes[node_type].data.items():
172 | data[node_type][attr] = value
173 |
174 | for edge_type in g.canonical_etypes:
175 | row, col = g.edges(form="uv", etype=edge_type)
176 | data[edge_type].edge_index = th.stack([row, col], dim=0)
177 |         for attr, value in g.edges[edge_type].data.items():
178 | data[edge_type][attr] = value
179 |
180 | return data
181 |
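# A minimal, hedged sketch of `from_dgl` above on a toy homogeneous DGL graph; the
# helper name `_from_dgl_demo` and the toy sizes are illustrative, not part of the
# benchmark pipeline.
def _from_dgl_demo():
    import dgl

    src = th.arange(10)
    g = dgl.graph((src, (src + 1) % 10))  # a 10-node directed cycle
    g.ndata["feat"] = th.randn(10, 4)
    data = from_dgl(g)  # -> PyG Data(edge_index=[2, 10], feat=[10, 4])
    return data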
182 |
183 | def split_edge(
184 | dgl_graph,
185 | test_ratio=0.2,
186 | val_ratio=0.1,
187 | random_seed=42,
188 | neg_len="1000",
189 | path=None,
190 | way="random",
191 | ):
192 | if os.path.exists(os.path.join(path, f"{neg_len}/edge_split_hitsk.pt")):
193 | edge_split = th.load(os.path.join(path, f"{neg_len}/edge_split_hitsk.pt"))
194 |
195 | else:
196 | np.random.seed(random_seed)
197 | th.manual_seed(random_seed)
198 |
199 | graph = dgl_graph
200 |
201 | eids = np.arange(graph.num_edges)
202 | eids = np.random.permutation(eids)
203 |
204 | u, v = graph.edge_index
205 |
206 | test_size = int(len(eids) * test_ratio)
207 | val_size = int(len(eids) * val_ratio)
208 |
209 | test_pos_u, test_pos_v = u[eids[:test_size]], v[eids[:test_size]]
210 | val_pos_u, val_pos_v = (
211 | u[eids[test_size : test_size + val_size]],
212 | v[eids[test_size : test_size + val_size]],
213 | )
214 | train_pos_u, train_pos_v = (
215 | u[eids[test_size + val_size :]],
216 | v[eids[test_size + val_size :]],
217 | )
218 |
219 | train_edge_index = th.stack((train_pos_u, train_pos_v), dim=1)
220 | val_edge_index = th.stack((val_pos_u, val_pos_v), dim=1)
221 | test_edge_index = th.stack((test_pos_u, test_pos_v), dim=1)
222 |
223 | valid_neg_edge_index = th.randint(
224 | 0, graph.num_nodes, [int(neg_len), 2], dtype=th.long
225 | ) # TODO: filter some edges in graph
226 | test_neg_edge_index = th.randint(
227 | 0, graph.num_nodes, [int(neg_len), 2], dtype=th.long
228 | )
229 |
230 | edge_split = {
231 | "train": {
232 | "edge": train_edge_index,
233 | "train_edge_feature_index": eids[test_size + val_size :],
234 | },
235 | "valid": {"edge": val_edge_index, "edge_neg": valid_neg_edge_index},
236 | "test": {"edge": test_edge_index, "edge_neg": test_neg_edge_index},
237 | }
238 |
239 | th.save(edge_split, os.path.join(path, f"{neg_len}/edge_split_hitsk.pt"))
240 |
241 | return edge_split
242 |
243 |
244 | def split_edge_mrr(
245 | dgl_graph,
246 | test_ratio=0.2,
247 | val_ratio=0.1,
248 | random_seed=42,
249 | neg_len="1000",
250 | path=None,
251 | way="random",
252 | ):
253 | if os.path.exists(os.path.join(path, f"{neg_len}/edge_split_mrr.pt")):
254 | edge_split = th.load(os.path.join(path, f"{neg_len}/edge_split_mrr.pt"))
255 |
256 | else:
257 | np.random.seed(random_seed)
258 | th.manual_seed(random_seed)
259 |
260 | graph = dgl_graph
261 |
262 | eids = np.arange(graph.num_edges)
263 | eids = np.random.permutation(eids)
264 |
265 | u, v = graph.edge_index
266 |
267 | test_size = int(len(eids) * test_ratio)
268 | val_size = int(len(eids) * val_ratio)
269 |
270 | test_pos_u, test_pos_v = u[eids[:test_size]], v[eids[:test_size]]
271 | val_pos_u, val_pos_v = (
272 | u[eids[test_size : test_size + val_size]],
273 | v[eids[test_size : test_size + val_size]],
274 | )
275 | train_pos_u, train_pos_v = (
276 | u[eids[test_size + val_size :]],
277 | v[eids[test_size + val_size :]],
278 | )
279 |
280 | valid_target_neg = th.randint(
281 | 0, graph.num_nodes, [len(val_pos_u), int(neg_len)], dtype=th.long
282 | ) # TODO: filter some edges in graph
283 | test_target_neg = th.randint(
284 | 0, graph.num_nodes, [len(test_pos_u), int(neg_len)], dtype=th.long
285 | ) # TODO: filter some edges in graph
286 |
287 | edge_split = {
288 | "train": {
289 | "source_node": train_pos_u,
290 | "target_node": train_pos_v,
291 | "edge": th.stack((train_pos_u, train_pos_v), dim=1),
292 | "train_edge_feature_index": eids[test_size + val_size :],
293 | },
294 | "valid": {
295 | "source_node": val_pos_u,
296 | "target_node": val_pos_v,
297 | "target_node_neg": valid_target_neg,
298 | "edge": th.stack((val_pos_u, val_pos_v), dim=1),
299 | },
300 | "test": {
301 | "source_node": test_pos_u,
302 | "target_node": test_pos_v,
303 | "target_node_neg": test_target_neg,
304 | "edge": th.stack((test_pos_u, test_pos_v), dim=1),
305 | },
306 | }
307 |
308 | th.save(edge_split, os.path.join(path, f"{neg_len}/edge_split_mrr.pt"))
309 |
310 | return edge_split
311 |
312 |
313 | # TODO: Consider time in split
314 | # def split_edge_MMR(dgl_graph, time=2015, random_seed=42, neg_len=1000, path=None):
315 | # if os.path.exists(os.path.join(path, 'edge_split_mrr.pt')):
316 | # edge_split = th.load(os.path.join(path, 'edge_split_mrr.pt'))
317 | # else:
318 |
319 | # np.random.seed(random_seed)
320 | # th.manual_seed(random_seed)
321 |
322 | # graph = from_dgl(dgl_graph)
323 |
324 | # year = list(np.array(graph.year))
325 | # indices = np.arange(graph.num_nodes)
326 | # all_source = graph.edge_index[0]
327 | # all_target = graph.edge_index[1]
328 | # import random
329 | # val_list = []
330 | # test_list = []
331 | # for i in range(year.index(time), graph.num_nodes):
332 | # indices = th.where(all_source == i)[0]
333 | # if len(indices) >= 2:
334 | # _list = random.sample(list(indices), k=2)
335 | #             val_list.append(_list[0]) # TODO: check edge
336 | # test_list.append(_list[1])
337 |
338 | # val_source = all_source[val_list]
339 | # val_target = all_target[val_list]
340 |
341 | # test_source = all_source[test_list]
342 | # test_target = all_target[test_list]
343 |
344 | # all_index = list(range(0, len(graph.edge_index[0])))
345 |
346 | # val_list = np.array(val_list).tolist()
347 | # test_list = np.array(test_list).tolist()
348 |
349 | # train_idx = set(all_index) - set(val_list) - set(test_list)
350 |
351 | # tra_source = all_source[list(train_idx)]
352 | # tra_target = all_target[list(train_idx)]
353 |
354 | # val_target_neg = th.randint(low=0, high=year.index(time), size=(len(val_source), neg_len))
355 | # test_target_neg = th.randint(low=0, high=year.index(time), size=(len(test_source), neg_len))
356 |
357 | #         # ! Build the edge_split dict
358 | # edge_split = {'train': {'source_node': tra_source, 'target_node': tra_target},
359 | # 'valid': {'source_node': val_source, 'target_node': val_target,
360 | # 'target_node_neg': val_target_neg},
361 | # 'test': {'source_node': test_source, 'target_node': test_target,
362 | # 'target_node_neg': test_target_neg}}
363 |
364 | # th.save(edge_split, os.path.join(path, 'edge_split_mrr.pt'))
365 | #         # ! Save the training subgraph
366 | # train_g = dgl.remove_edges(dgl_graph, val_list + test_list)
367 | # dgl.save_graphs(os.path.join(path, 'train_G.pt'), train_g)
368 |
369 | # return edge_split, train_g
370 |
371 |
372 | class Evaluator:
373 | def __init__(self, name):
374 | self.name = name
375 | meta_info = {
376 | "History": {"name": "History", "eval_metric": "hits@50"},
377 | "DBLP": {"name": "DBLP", "eval_metric": "mrr"},
378 | }
379 |
380 | self.eval_metric = meta_info[self.name]["eval_metric"]
381 |
382 | if "hits@" in self.eval_metric:
383 | ### Hits@K
384 |
385 | self.K = int(self.eval_metric.split("@")[1])
386 |
387 | def _parse_and_check_input(self, input_dict):
388 | if "hits@" in self.eval_metric:
389 | if not "y_pred_pos" in input_dict:
390 | raise RuntimeError("Missing key of y_pred_pos")
391 | if not "y_pred_neg" in input_dict:
392 | raise RuntimeError("Missing key of y_pred_neg")
393 |
394 | y_pred_pos, y_pred_neg = input_dict["y_pred_pos"], input_dict["y_pred_neg"]
395 |
396 | """
397 | y_pred_pos: numpy ndarray or torch tensor of shape (num_edge, )
398 | y_pred_neg: numpy ndarray or torch tensor of shape (num_edge, )
399 | """
400 |
401 |             # convert y_pred_pos, y_pred_neg into either both torch tensors or both numpy arrays
402 | # type_info stores information whether torch or numpy is used
403 |
404 | type_info = None
405 |
406 |             # check the raw type of y_pred_pos
407 | if not (
408 | isinstance(y_pred_pos, np.ndarray)
409 | or (th is not None and isinstance(y_pred_pos, th.Tensor))
410 | ):
411 | raise ValueError(
412 | "y_pred_pos needs to be either numpy ndarray or th tensor"
413 | )
414 |
415 | # check the raw type of y_pred_neg
416 | if not (
417 | isinstance(y_pred_neg, np.ndarray)
418 | or (th is not None and isinstance(y_pred_neg, th.Tensor))
419 | ):
420 | raise ValueError(
421 | "y_pred_neg needs to be either numpy ndarray or th tensor"
422 | )
423 |
424 | # if either y_pred_pos or y_pred_neg is th tensor, use th tensor
425 | if th is not None and (
426 | isinstance(y_pred_pos, th.Tensor) or isinstance(y_pred_neg, th.Tensor)
427 | ):
428 |                 # convert any numpy input to a torch tensor
429 | if isinstance(y_pred_pos, np.ndarray):
430 | y_pred_pos = th.from_numpy(y_pred_pos)
431 |
432 | if isinstance(y_pred_neg, np.ndarray):
433 | y_pred_neg = th.from_numpy(y_pred_neg)
434 |
435 | # put both y_pred_pos and y_pred_neg on the same device
436 | y_pred_pos = y_pred_pos.to(y_pred_neg.device)
437 |
438 | type_info = "torch"
439 |
440 | else:
441 | # both y_pred_pos and y_pred_neg are numpy ndarray
442 |
443 | type_info = "numpy"
444 |
445 | if not y_pred_pos.ndim == 1:
446 | raise RuntimeError(
447 |                     "y_pred_pos must be a 1-dim array, {}-dim array given".format(
448 | y_pred_pos.ndim
449 | )
450 | )
451 |
452 | if not y_pred_neg.ndim == 1:
453 | raise RuntimeError(
454 |                     "y_pred_neg must be a 1-dim array, {}-dim array given".format(
455 | y_pred_neg.ndim
456 | )
457 | )
458 |
459 | return y_pred_pos, y_pred_neg, type_info
460 |
461 | elif "mrr" == self.eval_metric:
462 |
463 | if not "y_pred_pos" in input_dict:
464 | raise RuntimeError("Missing key of y_pred_pos")
465 | if not "y_pred_neg" in input_dict:
466 | raise RuntimeError("Missing key of y_pred_neg")
467 |
468 | y_pred_pos, y_pred_neg = input_dict["y_pred_pos"], input_dict["y_pred_neg"]
469 |
470 | """
471 | y_pred_pos: numpy ndarray or torch tensor of shape (num_edge, )
472 | y_pred_neg: numpy ndarray or torch tensor of shape (num_edge, num_node_negative)
473 | """
474 |
475 |             # convert y_pred_pos, y_pred_neg into either both torch tensors or both numpy arrays
476 | # type_info stores information whether torch or numpy is used
477 |
478 | type_info = None
479 |
480 |             # check the raw type of y_pred_pos
481 | if not (
482 | isinstance(y_pred_pos, np.ndarray)
483 | or (th is not None and isinstance(y_pred_pos, th.Tensor))
484 | ):
485 | raise ValueError(
486 | "y_pred_pos needs to be either numpy ndarray or th tensor"
487 | )
488 |
489 | # check the raw type of y_pred_neg
490 | if not (
491 | isinstance(y_pred_neg, np.ndarray)
492 | or (th is not None and isinstance(y_pred_neg, th.Tensor))
493 | ):
494 | raise ValueError(
495 | "y_pred_neg needs to be either numpy ndarray or th tensor"
496 | )
497 |
498 | # if either y_pred_pos or y_pred_neg is th tensor, use th tensor
499 | if th is not None and (
500 | isinstance(y_pred_pos, th.Tensor) or isinstance(y_pred_neg, th.Tensor)
501 | ):
502 |                 # convert any numpy input to a torch tensor
503 | if isinstance(y_pred_pos, np.ndarray):
504 | y_pred_pos = th.from_numpy(y_pred_pos)
505 |
506 | if isinstance(y_pred_neg, np.ndarray):
507 | y_pred_neg = th.from_numpy(y_pred_neg)
508 |
509 | # put both y_pred_pos and y_pred_neg on the same device
510 | y_pred_pos = y_pred_pos.to(y_pred_neg.device)
511 |
512 | type_info = "torch"
513 |
514 | else:
515 | # both y_pred_pos and y_pred_neg are numpy ndarray
516 |
517 | type_info = "numpy"
518 |
519 | if not y_pred_pos.ndim == 1:
520 | raise RuntimeError(
521 |                     "y_pred_pos must be a 1-dim array, {}-dim array given".format(
522 | y_pred_pos.ndim
523 | )
524 | )
525 |
526 | if not y_pred_neg.ndim == 2:
527 | raise RuntimeError(
528 |                     "y_pred_neg must be a 2-dim array, {}-dim array given".format(
529 | y_pred_neg.ndim
530 | )
531 | )
532 |
533 | return y_pred_pos, y_pred_neg, type_info
534 |
535 | else:
536 | raise ValueError("Undefined eval metric %s" % (self.eval_metric))
537 |
538 | def eval(self, input_dict):
539 |
540 | if "hits@" in self.eval_metric:
541 | y_pred_pos, y_pred_neg, type_info = self._parse_and_check_input(input_dict)
542 | return self._eval_hits(y_pred_pos, y_pred_neg, type_info)
543 | elif self.eval_metric == "mrr":
544 | y_pred_pos, y_pred_neg, type_info = self._parse_and_check_input(input_dict)
545 | return self._eval_mrr(y_pred_pos, y_pred_neg, type_info)
546 |
547 | else:
548 | raise ValueError("Undefined eval metric %s" % (self.eval_metric))
549 |
550 | def _eval_hits(self, y_pred_pos, y_pred_neg, type_info):
551 | """
552 | compute Hits@K
553 | For each positive target node, the negative target nodes are the same.
554 |
555 | y_pred_neg is an array.
556 | rank y_pred_pos[i] against y_pred_neg for each i
557 | """
558 |
559 | if len(y_pred_neg) < self.K:
560 | return {"hits@{}".format(self.K): 1.0}
561 |
562 | if type_info == "torch":
563 | kth_score_in_negative_edges = th.topk(y_pred_neg, self.K)[0][-1]
564 | hitsK = float(th.sum(y_pred_pos > kth_score_in_negative_edges).cpu()) / len(
565 | y_pred_pos
566 | )
567 |
568 | # type_info is numpy
569 | else:
570 | kth_score_in_negative_edges = np.sort(y_pred_neg)[-self.K]
571 | hitsK = float(np.sum(y_pred_pos > kth_score_in_negative_edges)) / len(
572 | y_pred_pos
573 | )
574 |
575 | return {"hits@{}".format(self.K): hitsK}
576 |
577 | def _eval_mrr(self, y_pred_pos, y_pred_neg, type_info):
578 | """
579 | compute mrr
580 | y_pred_neg is an array with shape (batch size, num_entities_neg).
581 | y_pred_pos is an array with shape (batch size, )
582 | """
583 |
584 | if type_info == "torch":
585 | y_pred = th.cat([y_pred_pos.view(-1, 1), y_pred_neg], dim=1)
586 | argsort = th.argsort(y_pred, dim=1, descending=True)
587 | ranking_list = th.nonzero(argsort == 0, as_tuple=False)
588 | ranking_list = ranking_list[:, 1] + 1
589 | # hits1_list = (ranking_list <= 1).to(th.float)
590 | # hits3_list = (ranking_list <= 3).to(th.float)
591 | # hits10_list = (ranking_list <= 10).to(th.float)
592 | mrr_list = 1.0 / ranking_list.to(th.float)
593 |
594 | return {
595 | # 'hits@1_list': hits1_list,
596 | # 'hits@3_list': hits3_list,
597 | # 'hits@10_list': hits10_list,
598 | "mrr_list": mrr_list
599 | }
600 |
601 | else:
602 | y_pred = np.concatenate([y_pred_pos.reshape(-1, 1), y_pred_neg], axis=1)
603 | argsort = np.argsort(-y_pred, axis=1)
604 | ranking_list = (argsort == 0).nonzero()
605 | ranking_list = ranking_list[1] + 1
606 | # hits1_list = (ranking_list <= 1).astype(np.float32)
607 | # hits3_list = (ranking_list <= 3).astype(np.float32)
608 | # hits10_list = (ranking_list <= 10).astype(np.float32)
609 | mrr_list = 1.0 / ranking_list.astype(np.float32)
610 |
611 | return {
612 | # 'hits@1_list': hits1_list,
613 | # 'hits@3_list': hits3_list,
614 | # 'hits@10_list': hits10_list,
615 | "mrr_list": mrr_list
616 | }
617 |
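# A minimal, hedged usage sketch of the Evaluator above. The scores below are random
# placeholders; only the expected shapes matter ("History" -> hits@50 with 1-dim
# positive/negative scores, "DBLP" -> mrr with a 2-dim negative-score matrix). The
# helper name `_evaluator_demo` is illustrative only.
def _evaluator_demo():
    hits_evaluator = Evaluator(name="History")
    hits = hits_evaluator.eval(
        {"y_pred_pos": th.randn(200), "y_pred_neg": th.randn(1000)}
    )  # -> {"hits@50": ...}

    mrr_evaluator = Evaluator(name="DBLP")
    mrr = mrr_evaluator.eval(
        {"y_pred_pos": th.randn(200), "y_pred_neg": th.randn(200, 1000)}
    )["mrr_list"].mean()  # mean reciprocal rank over the 200 positive edges
    return hits, mrr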
--------------------------------------------------------------------------------
/GNN/model/GNN_library.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.nn import (
3 | GCNConv,
4 | SAGEConv,
5 | TransformerConv,
6 | GATConv,
7 | GINEConv,
8 | GeneralConv,
9 | )
10 | from torch.nn import Linear
11 | import torch.nn.functional as F
12 | import torch
13 | import torch.nn as nn
14 |
15 | import dgl.nn.pytorch as dglnn
16 | from dgl import function as fn
17 | from dgl.ops import edge_softmax
18 | from dgl.utils import expand_as_pair
19 | import torch.nn.functional as F
20 | from dgl.sampling import node2vec_random_walk
21 | from torch.utils.data import DataLoader
22 | from typing import Callable, Optional, Union
23 |
24 | import torch
25 | from torch import Tensor
26 | from torch.nn import Linear
27 | from torch_geometric.nn.conv import MessagePassing
28 | from torch_geometric.nn.dense.linear import Linear
29 | from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size
30 |
31 |
32 | class ElementWiseLinear(nn.Module):
33 | def __init__(self, size, weight=True, bias=True, inplace=False):
34 | super().__init__()
35 | if weight:
36 | self.weight = nn.Parameter(torch.Tensor(size))
37 | else:
38 | self.weight = None
39 | if bias:
40 | self.bias = nn.Parameter(torch.Tensor(size))
41 | else:
42 | self.bias = None
43 | self.inplace = inplace
44 |
45 | self.reset_parameters()
46 |
47 | def reset_parameters(self):
48 | if self.weight is not None:
49 | nn.init.ones_(self.weight)
50 | if self.bias is not None:
51 | nn.init.zeros_(self.bias)
52 |
53 | def forward(self, x):
54 | if self.inplace:
55 | if self.weight is not None:
56 | x.mul_(self.weight)
57 | if self.bias is not None:
58 | x.add_(self.bias)
59 | else:
60 | if self.weight is not None:
61 | x = x * self.weight
62 | if self.bias is not None:
63 | x = x + self.bias
64 | return x
65 |
66 |
67 | class APPNP(nn.Module):
68 | def __init__(
69 | self,
70 | in_feats,
71 | n_layers,
72 | n_hidden,
73 | n_classes,
74 | activation,
75 | input_drop,
76 | edge_drop,
77 | alpha,
78 | k,
79 | ):
80 | super(APPNP, self).__init__()
81 | self.n_layers = n_layers
82 | self.n_hidden = n_hidden
83 | self.layers = nn.ModuleList()
84 | # input layer
85 | for i in range(n_layers):
86 | in_hidden = n_hidden if i > 0 else in_feats
87 | out_hidden = n_hidden if i < n_layers - 1 else n_classes
88 | self.layers.append(nn.Linear(in_hidden, out_hidden))
89 |
90 | self.activation = activation
91 |
92 | self.input_drop = nn.Dropout(input_drop)
93 |
94 | self.propagate = dglnn.APPNPConv(k, alpha, edge_drop)
95 | self.reset_parameters()
96 |
97 | def reset_parameters(self):
98 | for layer in self.layers:
99 | layer.reset_parameters()
100 |
101 | def forward(self, graph, features):
102 | # prediction step
103 | h = features
104 | h = self.input_drop(h)
105 | h = self.activation(self.layers[0](h))
106 | for layer in self.layers[1:-1]:
107 | h = self.activation(layer(h))
108 | h = self.layers[-1](self.input_drop(h))
109 | # propagation step
110 | h = self.propagate(graph, h)
111 | return h
112 |
113 |
114 | class GraphSAGE(nn.Module):
115 | def __init__(
116 | self,
117 | in_feats,
118 | n_hidden,
119 | n_classes,
120 | n_layers,
121 | activation,
122 | dropout,
123 | aggregator_type,
124 | input_drop=0.0,
125 | ):
126 | super(GraphSAGE, self).__init__()
127 | self.n_layers = n_layers
128 | self.n_hidden = n_hidden
129 | self.n_classes = n_classes
130 |
131 | self.layers = nn.ModuleList()
132 | self.norms = nn.ModuleList()
133 |
134 | # input layer
135 | for i in range(n_layers):
136 | in_hidden = n_hidden if i > 0 else in_feats
137 | out_hidden = n_hidden if i < n_layers - 1 else n_classes
138 |
139 | self.layers.append(dglnn.SAGEConv(in_hidden, out_hidden, aggregator_type))
140 | if i < n_layers - 1:
141 | self.norms.append(nn.BatchNorm1d(out_hidden))
142 |
143 | self.input_drop = nn.Dropout(input_drop)
144 | self.dropout = nn.Dropout(dropout)
145 | self.activation = activation
146 |
147 | def forward(self, graph, feat):
148 | h = feat
149 | h = self.input_drop(h)
150 |
151 | for l, layer in enumerate(self.layers):
152 | conv = layer(graph, h)
153 | h = conv
154 |
155 | if l != len(self.layers) - 1:
156 | h = self.norms[l](h)
157 | h = self.activation(h)
158 | h = self.dropout(h)
159 |
160 | return h
161 |
162 |
163 | class GCN(torch.nn.Module):
164 | def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout):
165 | super(GCN, self).__init__()
166 |
167 | self.convs = torch.nn.ModuleList()
168 | self.convs.append(GCNConv(in_channels, hidden_channels, cached=True))
169 | for _ in range(num_layers - 2):
170 | self.convs.append(GCNConv(hidden_channels, hidden_channels, cached=True))
171 | self.convs.append(GCNConv(hidden_channels, out_channels, cached=True))
172 |
173 | self.dropout = dropout
174 |
175 | def reset_parameters(self):
176 | for conv in self.convs:
177 | conv.reset_parameters()
178 |
179 | def forward(self, x, adj_t):
180 | for conv in self.convs[:-1]:
181 | x = conv(x, adj_t)
182 | x = F.relu(x)
183 | x = F.dropout(x, p=self.dropout, training=self.training)
184 | x = self.convs[-1](x, adj_t)
185 | return x
186 |
187 |
213 | class GAT(nn.Module):
214 | def __init__(
215 | self,
216 | in_feats,
217 | n_classes,
218 | n_hidden,
219 | n_layers,
220 | n_heads,
221 | activation,
222 | dropout=0.0,
223 | input_drop=0.0,
224 | attn_drop=0.0,
225 | edge_drop=0.0,
226 | use_attn_dst=True,
227 | use_symmetric_norm=False,
228 | residual=False,
229 | ):
230 | super().__init__()
231 | self.in_feats = in_feats
232 | self.n_hidden = n_hidden
233 | self.n_classes = n_classes
234 | self.n_layers = n_layers
235 | self.num_heads = n_heads
236 |
237 | self.convs = nn.ModuleList()
238 | self.norms = nn.ModuleList()
239 |
240 | for i in range(n_layers):
241 | in_hidden = n_heads * n_hidden if i > 0 else in_feats
242 | out_hidden = n_hidden if i < n_layers - 1 else n_classes
243 | num_heads = n_heads if i < n_layers - 1 else 1
244 | out_channels = n_heads
245 |
246 | self.convs.append(
247 | GATConv(
248 | in_hidden,
249 | out_hidden,
250 | num_heads=num_heads,
251 | attn_drop=attn_drop,
252 | edge_drop=edge_drop,
253 | use_attn_dst=use_attn_dst,
254 | use_symmetric_norm=use_symmetric_norm,
255 | residual=residual,
256 | )
257 | )
258 |
259 | if i < n_layers - 1:
260 | self.norms.append(nn.BatchNorm1d(out_channels * out_hidden))
261 |
262 | self.bias_last = ElementWiseLinear(
263 | n_classes, weight=False, bias=True, inplace=True
264 | )
265 |
266 | self.input_drop = nn.Dropout(input_drop)
267 | self.dropout = nn.Dropout(dropout)
268 | self.activation = activation
269 |
270 | def forward(self, graph, feat):
271 | h = feat
272 | h = self.input_drop(h)
273 |
274 | for i in range(self.n_layers):
275 | conv = self.convs[i](graph, h)
276 |
277 | h = conv
278 |
279 | if i < self.n_layers - 1:
280 | h = h.flatten(1)
281 | h = self.norms[i](h)
282 | h = self.activation(h, inplace=True)
283 | h = self.dropout(h)
284 |
285 | h = h.mean(1)
286 | h = self.bias_last(h)
287 |
288 | return h
289 |
290 |
291 | class ApplyNodeFunc(nn.Module):
292 | """Update the node feature hv with MLP, BN and ReLU."""
293 |
294 | def __init__(self, mlp):
295 | super(ApplyNodeFunc, self).__init__()
296 | self.mlp = mlp
297 | self.bn = nn.BatchNorm1d(self.mlp.output_dim)
298 |
299 | def forward(self, h):
300 | h = self.mlp(h)
301 | h = self.bn(h)
302 | h = F.relu(h)
303 | return h
304 |
305 |
306 | class MLP(nn.Module):
307 | """MLP with linear output"""
308 |
309 | def __init__(
310 | self, num_layers, input_dim, hidden_dim, output_dim, input_drop=0.0, dropout=0.2
311 | ):
312 | """MLP layers construction
313 |
314 |         Parameters
315 |         ----------
316 | num_layers: int
317 | The number of linear layers
318 | input_dim: int
319 | The dimensionality of input features
320 | hidden_dim: int
321 | The dimensionality of hidden units at ALL layers
322 | output_dim: int
323 | The number of classes for prediction
324 |
325 | """
326 | super(MLP, self).__init__()
327 | self.linear_or_not = True # default is linear model
328 | self.num_layers = num_layers
329 | self.output_dim = output_dim
330 | self.input_drop = nn.Dropout(input_drop)
331 | self.dropout = nn.Dropout(dropout)
332 |
333 | if num_layers < 1:
334 | raise ValueError("number of layers should be positive!")
335 | elif num_layers == 1:
336 | # Linear model
337 | self.linear = nn.Linear(input_dim, output_dim)
338 | else:
339 | # Multi-layer model
340 | self.linear_or_not = False
341 | self.linears = torch.nn.ModuleList()
342 | self.batch_norms = torch.nn.ModuleList()
343 |
344 | self.linears.append(nn.Linear(input_dim, hidden_dim))
345 | for layer in range(num_layers - 2):
346 | self.linears.append(nn.Linear(hidden_dim, hidden_dim))
347 | self.linears.append(nn.Linear(hidden_dim, output_dim))
348 |
349 | for layer in range(num_layers - 1):
350 | self.batch_norms.append(nn.BatchNorm1d((hidden_dim)))
351 |
352 | def forward(self, x):
353 | if self.linear_or_not:
354 | # If linear model
355 | return self.linear(x)
356 | else:
357 | # If MLP
358 | x = self.input_drop(x)
359 | h = x
360 | for i in range(self.num_layers - 1):
361 | h = F.relu(self.batch_norms[i](self.linears[i](h)))
362 | h = self.dropout(h)
363 | return self.linears[-1](h)
364 |
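# A small, hedged sketch of the MLP above: three linear layers mapping 128-dim inputs
# to 7 classes. The dimensions and the helper name `_mlp_demo` are illustrative only.
def _mlp_demo():
    mlp = MLP(num_layers=3, input_dim=128, hidden_dim=64, output_dim=7)
    x = torch.randn(32, 128)  # a batch of 32 node features
    return mlp(x)  # logits of shape (32, 7)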
365 |
366 | class GIN(nn.Module):
367 | """GIN model"""
368 |
369 | def __init__(
370 | self,
371 | in_feats,
372 | n_hidden,
373 | n_classes,
374 | n_layers,
375 | num_mlp_layers,
376 | input_dropout,
377 | learn_eps,
378 | neighbor_pooling_type,
379 | ):
380 | """model parameters setting
381 |
382 |         Parameters
383 |         ----------
384 | num_layers: int
385 | The number of linear layers in the neural network
386 | num_mlp_layers: int
387 | The number of linear layers in mlps
388 | input_dim: int
389 | The dimensionality of input features
390 | hidden_dim: int
391 | The dimensionality of hidden units at ALL layers
392 | output_dim: int
393 | The number of classes for prediction
394 |         input_dropout: float
395 |             dropout ratio applied to the input features
396 |         learn_eps: boolean
397 |             If True, learn epsilon to distinguish center nodes from neighbors
398 |             If False, aggregate neighbors and center nodes altogether.
399 |         neighbor_pooling_type: str
400 |             how to aggregate neighbors (sum, mean, or max)
403 |
404 | """
405 | super(GIN, self).__init__()
406 | self.n_layers = n_layers
407 | self.n_hidden = n_hidden
408 | self.n_classes = n_classes
409 | self.learn_eps = learn_eps
410 |
411 | # List of MLPs
412 | self.ginlayers = torch.nn.ModuleList()
413 | self.batch_norms = torch.nn.ModuleList()
414 |
415 | for layer in range(self.n_layers - 1):
416 | if layer == 0:
417 | mlp = MLP(num_mlp_layers, in_feats, n_hidden, n_hidden)
418 | else:
419 | mlp = MLP(num_mlp_layers, n_hidden, n_hidden, n_hidden)
420 |
421 | self.ginlayers.append(
422 | dglnn.GINConv(
423 | ApplyNodeFunc(mlp), neighbor_pooling_type, 0, self.learn_eps
424 | )
425 | )
426 | self.batch_norms.append(nn.BatchNorm1d(n_hidden))
427 |
428 | # Linear function for graph poolings of output of each layer
429 | # which maps the output of different layers into a prediction score
430 | self.linears_prediction = torch.nn.ModuleList()
431 |
432 | for layer in range(n_layers):
433 | if layer == 0:
434 | self.linears_prediction.append(nn.Linear(in_feats, n_classes))
435 | else:
436 | self.linears_prediction.append(nn.Linear(n_hidden, n_classes))
437 |
438 | self.input_drop = nn.Dropout(input_dropout)
439 |
440 | def forward(self, g, h):
441 | h = self.input_drop(h)
442 | # list of hidden representation at each layer (including input)
443 | hidden_rep = []
444 |
445 | for i in range(self.n_layers - 1):
446 | h = self.ginlayers[i](g, h)
447 | h = self.batch_norms[i](h)
448 | h = F.relu(h)
449 | hidden_rep.append(h)
450 |
451 | z = torch.cat(hidden_rep, dim=1)
452 |
453 | return z
454 |
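# A hedged sketch of the GIN above on a toy DGL graph. Its forward returns the
# concatenation of the hidden representations of all n_layers - 1 GIN layers, so the
# output width is n_hidden * (n_layers - 1). The toy graph, sizes, and the helper name
# `_gin_demo` are illustrative only.
def _gin_demo():
    import dgl

    src = torch.arange(100)
    g = dgl.graph((src, (src + 1) % 100))  # a 100-node directed cycle
    model = GIN(
        in_feats=16,
        n_hidden=32,
        n_classes=7,
        n_layers=3,
        num_mlp_layers=2,
        input_dropout=0.1,
        learn_eps=False,
        neighbor_pooling_type="sum",
    )
    return model(g, torch.randn(100, 16))  # node representations of shape (100, 64)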
455 |
456 | class JKNet(nn.Module):
457 | def __init__(
458 | self, in_feats, n_hidden, n_classes, n_layers=1, mode="cat", dropout=0.0
459 | ):
460 | super(JKNet, self).__init__()
461 |
462 | self.mode = mode
463 | self.dropout = nn.Dropout(dropout)
464 | self.layers = nn.ModuleList()
465 | self.layers.append(dglnn.GraphConv(in_feats, n_hidden, activation=F.relu))
466 | for _ in range(n_layers):
467 | self.layers.append(dglnn.GraphConv(n_hidden, n_hidden, activation=F.relu))
468 |
469 | if self.mode == "lstm":
470 | self.jump = dglnn.JumpingKnowledge(mode, n_hidden, n_layers)
471 | else:
472 | self.jump = dglnn.JumpingKnowledge(mode)
473 |
474 | if self.mode == "cat":
475 | n_hidden = n_hidden * (n_layers + 1)
476 |
477 | self.output = nn.Linear(n_hidden, n_classes)
478 | self.reset_params()
479 |
480 | def reset_params(self):
481 | self.output.reset_parameters()
482 | for layers in self.layers:
483 | layers.reset_parameters()
484 | if self.mode == "lstm":
485 |             self.jump.reset_parameters()
487 |
488 | def forward(self, g, feats):
489 | feat_lst = []
490 | for layer in self.layers:
491 | feats = self.dropout(layer(g, feats))
492 | feat_lst.append(feats)
493 |
494 | if self.mode == "cat":
495 | out = torch.cat(feat_lst, dim=-1)
496 | elif self.mode == "max":
497 | out = torch.stack(feat_lst, dim=-1).max(dim=-1)[0]
498 | else:
499 |             # lstm: let the JumpingKnowledge module attend over the per-layer features
500 |             out = self.jump(feat_lst)
505 |
506 | g.ndata["h"] = out
507 | g.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
508 |
509 | return self.output(g.ndata["h"])
510 |
511 |
512 | class MoNet(nn.Module):
513 | def __init__(
514 | self,
515 | in_feats,
516 | n_hidden,
517 | n_classes,
518 | n_layers,
519 | dim,
520 | n_kernels,
521 | input_drop,
522 | dropout,
523 | ):
524 | super(MoNet, self).__init__()
525 | self.input_drop = nn.Dropout(input_drop)
526 | self.layers = nn.ModuleList()
527 | self.pseudo_proj = nn.ModuleList()
528 |
529 | # Input layer
530 | self.layers.append(dglnn.GMMConv(in_feats, n_hidden, dim, n_kernels))
531 | self.pseudo_proj.append(nn.Sequential(nn.Linear(2, dim), nn.Tanh()))
532 |
533 | # Hidden layer
534 | for _ in range(n_layers - 1):
535 | self.layers.append(dglnn.GMMConv(n_hidden, n_hidden, dim, n_kernels))
536 | self.pseudo_proj.append(nn.Sequential(nn.Linear(2, dim), nn.Tanh()))
537 |
538 | # Output layer
539 | self.layers.append(dglnn.GMMConv(n_hidden, n_classes, dim, n_kernels))
540 | self.pseudo_proj.append(nn.Sequential(nn.Linear(2, dim), nn.Tanh()))
541 | self.dropout = nn.Dropout(dropout)
542 |
543 | def forward(self, feat, pseudo, graph):
544 | h = feat
545 | h = self.input_drop(h)
546 | for i in range(len(self.layers)):
547 | if i != 0:
548 | h = self.dropout(h)
549 | h = self.layers[i](graph, h, self.pseudo_proj[i](pseudo))
550 | return h
551 |
552 |
553 | class Node2vec(nn.Module):
554 | """Node2vec model from paper node2vec: Scalable Feature Learning for Networks
555 | Attributes
556 | ----------
557 | g: DGLGraph
558 | The graph.
559 | embedding_dim: int
560 | Dimension of node embedding.
561 | walk_length: int
562 | Length of each trace.
563 | p: float
564 | Likelihood of immediately revisiting a node in the walk. Same notation as in the paper.
565 | q: float
566 | Control parameter to interpolate between breadth-first strategy and depth-first strategy.
567 | Same notation as in the paper.
568 | num_walks: int
569 | Number of random walks for each node. Default: 10.
570 | window_size: int
571 | Maximum distance between the center node and predicted node. Default: 5.
572 | num_negatives: int
573 | The number of negative samples for each positive sample. Default: 5.
574 | use_sparse: bool
575 | If set to True, use PyTorch's sparse embedding and optimizer. Default: ``True``.
576 | weight_name : str, optional
577 | The name of the edge feature tensor on the graph storing the (unnormalized)
578 | probabilities associated with each edge for choosing the next node.
579 | The feature tensor must be non-negative and the sum of the probabilities
580 | must be positive for the outbound edges of all nodes (although they don't have
581 | to sum up to one). The result will be undefined otherwise.
582 | If omitted, DGL assumes that the neighbors are picked uniformly.
583 | """
584 |
585 | def __init__(
586 | self,
587 | g,
588 | embedding_dim,
589 | walk_length,
590 | p,
591 | q,
592 | num_walks=10,
593 | window_size=5,
594 | num_negatives=5,
595 | use_sparse=True,
596 | weight_name=None,
597 | ):
598 | super(Node2vec, self).__init__()
599 |
600 | assert walk_length >= window_size
601 |
602 | self.g = g
603 | self.embedding_dim = embedding_dim
604 | self.walk_length = walk_length
605 | self.p = p
606 | self.q = q
607 | self.num_walks = num_walks
608 | self.window_size = window_size
609 | self.num_negatives = num_negatives
610 | self.N = self.g.num_nodes()
611 | if weight_name is not None:
612 | self.prob = weight_name
613 | else:
614 | self.prob = None
615 |
616 | self.embedding = nn.Embedding(self.N, embedding_dim, sparse=use_sparse)
617 |
618 | def reset_parameters(self):
619 | self.embedding.reset_parameters()
620 |
621 | def sample(self, batch):
622 | """
623 | Generate positive and negative samples.
624 | Positive samples are generated from random walk
625 | Negative samples are generated from random sampling
626 | """
627 | if not isinstance(batch, torch.Tensor):
628 | batch = torch.tensor(batch)
629 |
630 | batch = batch.repeat(self.num_walks)
631 | # positive
632 | pos_traces = node2vec_random_walk(
633 | self.g, batch, self.p, self.q, self.walk_length, self.prob
634 | )
635 | pos_traces = pos_traces.unfold(1, self.window_size, 1) # rolling window
636 | pos_traces = pos_traces.contiguous().view(-1, self.window_size)
637 |
638 | # negative
639 | neg_batch = batch.repeat(self.num_negatives)
640 | neg_traces = torch.randint(self.N, (neg_batch.size(0), self.walk_length))
641 | neg_traces = torch.cat([neg_batch.view(-1, 1), neg_traces], dim=-1)
642 | neg_traces = neg_traces.unfold(1, self.window_size, 1) # rolling window
643 | neg_traces = neg_traces.contiguous().view(-1, self.window_size)
644 |
645 | return pos_traces, neg_traces
646 |
647 | def forward(self, nodes=None):
648 | """
649 | Returns the embeddings of the input nodes
650 | Parameters
651 | ----------
652 | nodes: Tensor, optional
653 | Input nodes, if set `None`, will return all the node embedding.
654 | Returns
655 | -------
656 | Tensor
657 | Node embedding
658 | """
659 | emb = self.embedding.weight
660 | if nodes is None:
661 | return emb
662 | else:
663 | return emb[nodes]
664 |
665 | def loss(self, pos_trace, neg_trace):
666 | """
667 | Computes the loss given positive and negative random walks.
668 | Parameters
669 | ----------
670 | pos_trace: Tensor
671 | positive random walk trace
672 | neg_trace: Tensor
673 | negative random walk trace
674 | """
675 | e = 1e-15
676 |
677 | # Positive
678 | pos_start, pos_rest = (
679 | pos_trace[:, 0],
680 | pos_trace[:, 1:].contiguous(),
681 | ) # start node and following trace
682 | w_start = self.embedding(pos_start).unsqueeze(dim=1)
683 | w_rest = self.embedding(pos_rest)
684 | pos_out = (w_start * w_rest).sum(dim=-1).view(-1)
685 |
686 | # Negative
687 | neg_start, neg_rest = neg_trace[:, 0], neg_trace[:, 1:].contiguous()
688 |
689 | w_start = self.embedding(neg_start).unsqueeze(dim=1)
690 | w_rest = self.embedding(neg_rest)
691 | neg_out = (w_start * w_rest).sum(dim=-1).view(-1)
692 |
693 | # compute loss
694 | pos_loss = -torch.log(torch.sigmoid(pos_out) + e).mean()
695 | neg_loss = -torch.log(1 - torch.sigmoid(neg_out) + e).mean()
696 |
697 | return pos_loss + neg_loss
698 |
699 | def loader(self, batch_size):
700 | """
701 | Parameters
702 | ----------
703 | batch_size: int
704 | batch size
705 | Returns
706 | -------
707 | DataLoader
708 | Node2vec training data loader
709 | """
710 | return DataLoader(
711 | torch.arange(self.N),
712 | batch_size=batch_size,
713 | shuffle=True,
714 | collate_fn=self.sample,
715 | )
716 |
717 |
718 | class Node2vecModel(object):
719 | """
720 | Wrapper of the ``Node2Vec`` class with a ``train`` method.
721 | Attributes
722 | ----------
723 | g: DGLGraph
724 | The graph.
725 | embedding_dim: int
726 | Dimension of node embedding.
727 | walk_length: int
728 | Length of each trace.
729 | p: float
730 | Likelihood of immediately revisiting a node in the walk.
731 | q: float
732 | Control parameter to interpolate between breadth-first strategy and depth-first strategy.
733 | num_walks: int
734 | Number of random walks for each node. Default: 10.
735 | window_size: int
736 | Maximum distance between the center node and predicted node. Default: 5.
737 | num_negatives: int
738 | The number of negative samples for each positive sample. Default: 5.
739 | use_sparse: bool
740 | If set to True, uses PyTorch's sparse embedding and optimizer. Default: ``True``.
741 | weight_name : str, optional
742 | The name of the edge feature tensor on the graph storing the (unnormalized)
743 | probabilities associated with each edge for choosing the next node.
744 | The feature tensor must be non-negative and the sum of the probabilities
745 | must be positive for the outbound edges of all nodes (although they don't have
746 | to sum up to one). The result will be undefined otherwise.
747 | If omitted, DGL assumes that the neighbors are picked uniformly. Default: ``None``.
748 | eval_set: list of tuples (Tensor, Tensor)
749 | [(nodes_train,y_train),(nodes_val,y_val)]
750 | If omitted, model will not be evaluated. Default: ``None``.
751 | eval_steps: int
752 | Interval steps of evaluation.
753 | if set <= 0, model will not be evaluated. Default: ``None``.
754 | device: str
755 | device, default 'cpu'.
756 | """
757 |
758 | def __init__(
759 | self,
760 | g,
761 | embedding_dim,
762 | walk_length,
763 | p=1.0,
764 | q=1.0,
765 | num_walks=1,
766 | window_size=5,
767 | num_negatives=5,
768 | use_sparse=True,
769 | weight_name=None,
770 | eval_set=None,
771 | eval_steps=-1,
772 | device="cpu",
773 | ):
774 |
775 | self.model = Node2vec(
776 | g,
777 | embedding_dim,
778 | walk_length,
779 | p,
780 | q,
781 | num_walks,
782 | window_size,
783 | num_negatives,
784 | use_sparse,
785 | weight_name,
786 | )
787 | self.g = g
788 | self.use_sparse = use_sparse
789 | self.eval_steps = eval_steps
790 | self.eval_set = eval_set
791 |
792 | if device == "cpu":
793 | self.device = device
794 | else:
795 | self.device = "cuda" if torch.cuda.is_available() else "cpu"
796 |
797 | def _train_step(self, model, loader, optimizer, device):
798 | model.train()
799 | total_loss = 0
800 | for pos_traces, neg_traces in loader:
801 | pos_traces, neg_traces = pos_traces.to(device), neg_traces.to(device)
802 | optimizer.zero_grad()
803 | loss = model.loss(pos_traces, neg_traces)
804 | loss.backward()
805 | optimizer.step()
806 | total_loss += loss.item()
807 | return total_loss / len(loader)
808 |
809 | @torch.no_grad()
810 | def _evaluate_step(self):
811 | nodes_train, y_train = self.eval_set[0]
812 | nodes_val, y_val = self.eval_set[1]
813 |
814 | acc = self.model.evaluate(nodes_train, y_train, nodes_val, y_val)
815 | return acc
816 |
817 | def train(self, epochs, batch_size, learning_rate=0.01):
818 | """
819 | Parameters
820 | ----------
821 | epochs: int
822 | num of train epoch
823 | batch_size: int
824 | batch size
825 | learning_rate: float
826 | learning rate. Default 0.01.
827 | """
828 |
829 | self.model = self.model.to(self.device)
830 | loader = self.model.loader(batch_size)
831 | if self.use_sparse:
832 | optimizer = torch.optim.SparseAdam(
833 | list(self.model.parameters()), lr=learning_rate
834 | )
835 | else:
836 | optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)
837 | for i in range(epochs):
838 | loss = self._train_step(self.model, loader, optimizer, self.device)
839 | if self.eval_steps > 0:
840 |                 if i % self.eval_steps == 0:
841 | acc = self._evaluate_step()
842 | print(
843 | "Epoch: {}, Train Loss: {:.4f}, Val Acc: {:.4f}".format(
844 | i, loss, acc
845 | )
846 | )
847 |
848 | def embedding(self, nodes=None):
849 | """
850 | Returns the embeddings of the input nodes
851 | Parameters
852 | ----------
853 | nodes: Tensor, optional
854 | Input nodes, if set `None`, will return all the node embedding.
855 | Returns
856 | -------
857 | Tensor
858 | Node embedding.
859 | """
860 |
861 | return self.model(nodes)
862 |
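# A hedged end-to-end sketch of Node2vecModel above on a toy DGL cycle graph, so every
# node has an outgoing edge for the random walks. Hyperparameters and the helper name
# `_node2vec_demo` are illustrative only.
def _node2vec_demo():
    import dgl

    src = torch.arange(100)
    g = dgl.graph((src, (src + 1) % 100))
    trainer = Node2vecModel(
        g,
        embedding_dim=16,
        walk_length=10,
        p=1.0,
        q=1.0,
        num_walks=2,
        window_size=5,
        num_negatives=2,
    )
    trainer.train(epochs=1, batch_size=32)  # one pass over the skip-gram style loss
    return trainer.embedding()  # (100, 16) node embeddings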
863 |
864 | class GAT(torch.nn.Module):
865 | def __init__(
866 | self,
867 | in_channels,
868 | edge_dim,
869 | hidden_channels,
870 | out_channels,
871 | num_layers,
872 | heads,
873 | dropout,
874 | ):
875 | super().__init__()
876 | self.num_layers = num_layers
877 | self.dropout = dropout
878 |
879 | self.convs = torch.nn.ModuleList()
880 | self.convs.append(
881 | GATConv(
882 | in_channels,
883 | hidden_channels,
884 | heads,
885 | edge_dim=edge_dim,
886 | add_self_loops=False,
887 | )
888 | )
889 | for _ in range(num_layers - 2):
890 | self.convs.append(
891 | GATConv(
892 | heads * hidden_channels,
893 | hidden_channels,
894 | heads,
895 | edge_dim=edge_dim,
896 | add_self_loops=False,
897 | )
898 | )
899 | self.convs.append(
900 | GATConv(
901 | heads * hidden_channels,
902 | out_channels,
903 | heads,
904 | concat=False,
905 | edge_dim=edge_dim,
906 | add_self_loops=False,
907 | )
908 | )
909 |
910 | def reset_parameters(self):
911 | for conv in self.convs:
912 | conv.reset_parameters()
913 |
914 |     def forward(self, x, edge_index, edge_attr=None):  # edge_attr: optional edge features
915 |         for conv in self.convs[:-1]:
916 |             x = conv(x, edge_index, edge_attr)
917 |             x = F.elu(x)
918 |             x = F.dropout(x, p=self.dropout, training=self.training)
919 |         x = self.convs[-1](x, edge_index, edge_attr)
920 |         return x
921 |
922 |
923 | class GraphTransformer(torch.nn.Module):
924 | def __init__(
925 | self, in_channels, edge_dim, hidden_channels, out_channels, num_layers, dropout
926 | ):
927 | super().__init__()
928 | self.num_layers = num_layers
929 | self.dropout = dropout
930 |
931 | self.convs = torch.nn.ModuleList()
932 | self.convs.append(
933 | TransformerConv(in_channels, hidden_channels, edge_dim=edge_dim)
934 | )
935 | for _ in range(num_layers - 2):
936 | self.convs.append(
937 | TransformerConv(hidden_channels, hidden_channels, edge_dim=edge_dim)
938 | )
939 | self.convs.append(
940 | TransformerConv(hidden_channels, out_channels, edge_dim=edge_dim)
941 | )
942 |
943 | def reset_parameters(self):
944 | for conv in self.convs:
945 | conv.reset_parameters()
946 |
947 |     def forward(self, x, edge_index, edge_attr=None):  # edge_attr: optional edge features
948 |         for conv in self.convs[:-1]:
949 |             x = conv(x, edge_index, edge_attr)
950 |             x = F.relu(x)
951 |             x = F.dropout(x, p=self.dropout, training=self.training)
952 |         x = self.convs[-1](x, edge_index, edge_attr)
953 |         return x
954 |
955 |
956 | class GINE(torch.nn.Module):
957 | def __init__(
958 | self, in_channels, edge_dim, hidden_channels, out_channels, num_layers, dropout
959 | ):
960 | super().__init__()
961 | self.num_layers = num_layers
962 | self.dropout = dropout
963 |
964 | self.convs = torch.nn.ModuleList()
965 | self.convs.append(
966 | GINEConv(Linear(in_channels, hidden_channels), edge_dim=edge_dim)
967 | )
968 | for _ in range(num_layers - 2):
969 | self.convs.append(
970 | GINEConv(Linear(hidden_channels, hidden_channels), edge_dim=edge_dim)
971 | )
972 | self.convs.append(
973 | GINEConv(Linear(hidden_channels, out_channels), edge_dim=edge_dim)
974 | )
975 |
976 | def reset_parameters(self):
977 | for conv in self.convs:
978 | conv.reset_parameters()
979 |
980 |     def forward(self, x, edge_index, edge_attr=None):  # GINEConv requires edge features
981 |         for conv in self.convs[:-1]:
982 |             x = conv(x, edge_index, edge_attr)
983 |             x = F.relu(x)
984 |             x = F.dropout(x, p=self.dropout, training=self.training)
985 |         x = self.convs[-1](x, edge_index, edge_attr)
986 |         return x
987 |
988 |
989 | class EdgeConvConv(MessagePassing):
990 | def __init__(
991 | self,
992 | nn: Callable,
993 | eps: float = 0.0,
994 | train_eps: bool = False,
995 | edge_dim: Optional[int] = None,
996 | **kwargs,
997 | ):
998 | kwargs.setdefault("aggr", "add")
999 | super().__init__(**kwargs)
1000 | self.nn = nn
1001 | self.initial_eps = eps
1002 | if train_eps:
1003 | self.eps = torch.nn.Parameter(torch.Tensor([eps]))
1004 | else:
1005 | self.register_buffer("eps", torch.Tensor([eps]))
1006 | if edge_dim is not None:
1007 | if hasattr(self.nn, "in_features"):
1008 | in_channels = self.nn.in_features
1009 | else:
1010 | in_channels = self.nn.in_channels
1011 | self.lin = Linear(edge_dim, in_channels)
1012 | else:
1013 | self.lin = None
1014 |
1015 | def forward(
1016 | self,
1017 | x: Union[Tensor, OptPairTensor],
1018 | edge_index: Adj,
1019 | edge_attr: OptTensor = None,
1020 | size: Size = None,
1021 | ) -> Tensor:
1022 | """"""
1023 | if isinstance(x, Tensor):
1024 | x: OptPairTensor = (x, x)
1025 |
1026 | # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)
1027 | out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)
1028 |
1029 | return out
1030 |
1031 | def message(self, x_i: Tensor, x_j: Tensor, edge_attr: Tensor) -> Tensor:
1032 |
1033 | temp = torch.cat([x_i, x_j, edge_attr], dim=1)
1034 |
1035 | return self.nn(temp)
1036 |
1037 | def __repr__(self) -> str:
1038 | return f"{self.__class__.__name__}(nn={self.nn})"
1039 |
1040 |
1041 | class EdgeConv(torch.nn.Module):
1042 | def __init__(
1043 | self, in_channels, edge_dim, hidden_channels, out_channels, num_layers, dropout
1044 | ):
1045 | super().__init__()
1046 | self.num_layers = num_layers
1047 | self.dropout = dropout
1048 |
1049 | self.convs = torch.nn.ModuleList()
1050 | self.convs.append(
1051 | EdgeConvConv(
1052 | Linear(2 * in_channels + edge_dim, hidden_channels), edge_dim=edge_dim
1053 | )
1054 | )
1055 | for _ in range(num_layers - 2):
1056 | self.convs.append(
1057 | EdgeConvConv(
1058 | Linear(2 * hidden_channels + edge_dim, hidden_channels),
1059 | edge_dim=edge_dim,
1060 | )
1061 | )
1062 | self.convs.append(
1063 | EdgeConvConv(
1064 | Linear(2 * hidden_channels + edge_dim, out_channels), edge_dim=edge_dim
1065 | )
1066 | )
1067 |
1068 | def reset_parameters(self):
1069 | for conv in self.convs:
1070 | conv.reset_parameters()
1071 |
1072 |     def forward(self, x, edge_index, edge_attr=None):  # EdgeConvConv requires edge features
1073 |         for conv in self.convs[:-1]:
1074 |             x = conv(x, edge_index, edge_attr)
1075 |             x = F.relu(x)
1076 |             x = F.dropout(x, p=self.dropout, training=self.training)
1077 |         x = self.convs[-1](x, edge_index, edge_attr)
1078 |         return x
1079 |
1080 |
1081 | class GeneralGNN(torch.nn.Module):
1082 | def __init__(
1083 | self, in_channels, edge_dim, hidden_channels, out_channels, num_layers, dropout
1084 | ):
1085 | super().__init__()
1086 | self.num_layers = num_layers
1087 | self.dropout = dropout
1088 |
1089 | self.convs = torch.nn.ModuleList()
1090 | self.convs.append(
1091 | GeneralConv(in_channels, hidden_channels, in_edge_channels=edge_dim)
1092 | )
1093 | for _ in range(num_layers - 2):
1094 | self.convs.append(
1095 | GeneralConv(hidden_channels, hidden_channels, in_edge_channels=edge_dim)
1096 | )
1097 | self.convs.append(
1098 | GeneralConv(hidden_channels, out_channels, in_edge_channels=edge_dim)
1099 | )
1100 |
1101 | def reset_parameters(self):
1102 | for conv in self.convs:
1103 | conv.reset_parameters()
1104 |
1105 |     def forward(self, x, edge_index, edge_attr=None):  # edge_attr: optional edge features
1106 |         for conv in self.convs[:-1]:
1107 |             x = conv(x, edge_index, edge_attr)
1108 |             x = F.relu(x)
1109 |             x = F.dropout(x, p=self.dropout, training=self.training)
1110 |         x = self.convs[-1](x, edge_index, edge_attr)
1111 |         return x
1112 |
--------------------------------------------------------------------------------