├── .gitignore
├── LICENSE
├── README.md
├── checkpoints
│   └── readme.md
├── configs
│   ├── dblp.json
│   └── pubmed.json
├── data
│   └── readme.md
├── data_processor.py
├── layers
│   ├── GCN
│   │   ├── CL_layers.py
│   │   ├── DTGCN_layers.py
│   │   ├── RGCN.py
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── CL_layers.cpython-39.pyc
│   │   │   ├── DTGCN_layers.cpython-39.pyc
│   │   │   ├── RGCN.cpython-39.pyc
│   │   │   ├── __init__.cpython-39.pyc
│   │   │   └── custom_gcn.cpython-39.pyc
│   │   └── custom_gcn.py
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-39.pyc
│   │   └── common_layers.cpython-39.pyc
│   └── common_layers.py
├── main.py
├── models
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-39.pyc
│   │   └── base_model.cpython-39.pyc
│   └── base_model.py
├── our_models
│   ├── CL.py
│   ├── CTSGCN.py
│   ├── DCTSGCN.py
│   ├── EncoderCL.py
│   ├── FCL.py
│   ├── HCL.py
│   ├── TSGCN.py
│   ├── __init__.py
│   └── __pycache__
│       ├── CL.cpython-39.pyc
│       ├── CTSGCN.cpython-39.pyc
│       ├── DCTSGCN.cpython-39.pyc
│       ├── FCL.cpython-39.pyc
│       ├── HCL.cpython-39.pyc
│       ├── TSGCN.cpython-39.pyc
│       └── __init__.cpython-39.pyc
├── results
│   └── readme.md
└── utilis
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-39.pyc
    │   ├── log_bar.cpython-39.pyc
    │   └── scripts.cpython-39.pyc
    ├── dblp.py
    ├── log_bar.py
    ├── pubmed.py
    └── scripts.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Standard Python ignores
2 | # Byte-compiled / optimized / DLL files
3 | __pycache__/
4 | *.py[cod]
5 | *$py.class
6 |
7 | # C extensions
8 | *.so
9 |
10 | # Distribution / packaging
11 | .Python
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | pip-wheel-metadata/
25 | share/python-wheels/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 | MANIFEST
30 |
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 |
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 |
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .nox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | *.py,cover
52 | .hypothesis/
53 | .pytest_cache/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 | db.sqlite3
63 | db.sqlite3-journal
64 |
65 | # Flask stuff:
66 | instance/
67 | .webassets-cache
68 |
69 | # Scrapy stuff:
70 | .scrapy
71 |
72 | # Sphinx documentation
73 | docs/_build/
74 |
75 | # PyBuilder
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | .python-version
87 |
88 | # pipenv
89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
92 | # install all needed dependencies.
93 | #Pipfile.lock
94 |
95 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
96 | __pypackages__/
97 |
98 | # Celery stuff
99 | celerybeat-schedule
100 | celerybeat.pid
101 |
102 | # SageMath parsed files
103 | *.sage.py
104 |
105 | # Environments
106 | .env
107 | .venv
108 | env/
109 | venv/
110 | ENV/
111 | env.bak/
112 | venv.bak/
113 |
114 | # Spyder project settings
115 | .spyderproject
116 | .spyproject
117 |
118 | # Rope project settings
119 | .ropeproject
120 |
121 | # mkdocs documentation
122 | /site
123 |
124 | # mypy
125 | .mypy_cache/
126 | .dmypy.json
127 | dmypy.json
128 |
129 | # Pyre type checker
130 | .pyre/
131 |
132 | # Project-specific: data, results and checkpoints
133 | /results/*
134 | /data/*
135 | /checkpoints/*
136 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # H2CGL
2 | H2CGL: Modeling Dynamics of Citation Network for Impact Prediction
3 | https://arxiv.org/abs/2305.01572
4 |
5 | ## Requirements
6 |
7 | * pytorch >= 1.10.2
8 | * numpy >= 1.13.3
9 | * sklearn
10 | * python 3.9
11 | * transformers
12 | * dgl == 0.91
13 |
14 | ## Usage
15 | Original datasets:
16 | * PMC (pubmed): https://www.ncbi.nlm.nih.gov/pmc/tools/openftlist/
17 | * DBLPv13 (dblp): https://www.aminer.org/citation
18 |
19 | Download the preprocessed dataset from https://pan.baidu.com/s/1BkcIUiyMbdn9FR3hLdws1w&pwd=kr7h
20 |
21 | You can get the trained H2CGL model here: https://drive.google.com/file/d/1nJEgQdTenTssbZfjywe6poiOQKYvJpBo/view?usp=sharing
22 |
23 | ### Training
24 | ```sh
25 | # H2CGL
26 | python main.py --phase H2CGL --data_source h_pubmed/h_dblp --cl_type label_aug_hard_negative --aug_type cg --encoder_type 'CGIN+RGAT' --n_layers 4 --hn 2 --hn_method co_cite
27 | ```
28 |
29 | ### Testing
30 |
31 | ```sh
32 | # H2CGL
33 | python main.py --phase test_results --model H2CGL --data_source h_pubmed/h_dblp --cl_type label_aug_hard_negative --aug_type cg --encoder_type 'CGIN+RGAT' --n_layers 4 --hn 2 --hn_method co_cite
34 | ```
35 |
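36 | A minimal environment setup might look like the following (a sketch only; the package versions are assumptions taken from the requirement list above, not pinned by the authors):
37 |
38 | ```sh
39 | conda create -n h2cgl python=3.9 -y
40 | conda activate h2cgl
41 | pip install "torch>=1.10.2" "numpy>=1.13.3" scikit-learn transformers dgl
42 | ```
43 |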
--------------------------------------------------------------------------------
/checkpoints/readme.md:
--------------------------------------------------------------------------------
1 | Place the input graphs and models here.
--------------------------------------------------------------------------------
/configs/dblp.json:
--------------------------------------------------------------------------------
1 | {
2 | "default": {
3 | "tokenizer_type": "basic",
4 | "tokenizer_path": null,
5 | "embed_dim": 300,
6 | "batch_size": 32,
7 | "num_classes": 3,
8 | "max_len": 256,
9 | "epochs": 20,
10 |     "rate": 0.8,
11 | "fixed_num": 120000,
12 | "optimizer": "ADAM",
13 | "criterion": "MSE",
14 | "lr": 1e-3,
15 | "weight_decay": 1e-4,
16 | "scheduler": false,
17 | "best_metric": "male",
18 | "dropout": 0.3,
19 | "saved_data": false,
20 | "fixed_data": false,
21 | "use_graph": false,
22 | "time": [2009, 2010, 2011],
23 | "cut_time": 2013,
24 | "type": "normal",
25 | "split_abstract": true,
26 | "save_model": true,
27 | "num_workers": 0,
28 | "batch_graph": false,
29 | "test_interval": 0,
30 | "graph_name": "graph_sample_feature_vector",
31 | "time_length": 5,
32 | "cut_threshold": [10, 100],
33 | "cl_weight": 0.5,
34 | "aux_weight": 0.5
35 | },
36 | "TSGCN": {
37 | "embed_dim": 300,
38 | "hidden_dim": 256,
39 | "out_dim": 256,
40 | "n_layers": 4,
41 | "time_encoder": "self-attention",
42 | "time_encoder_layers": 4,
43 | "batch_size": 32,
44 | "use_graph": "repeatedly",
45 | "graph_name": "graph_sample_feature_vector",
46 | "etypes": [
47 | "cites", "is cited by", "publishes", "is published by",
48 | "writes", "is writen by"
49 | ],
50 | "ntypes": ["author", "paper", "journal"],
51 | "time_length": 5,
52 | "hop": 1,
53 | "num_workers": 0,
54 | "lr": 1e-4
55 | },
56 | "DCTSGCN": {
57 | "lr": 1e-4,
58 | "num_classes": 3,
59 | "embed_dim": 300,
60 | "hidden_dim": 256,
61 | "out_dim": 256,
62 | "n_layers": 3,
63 | "batch_size": 32,
64 | "use_graph": "repeatedly",
65 | "graph_name": "graph_sample_feature_vector",
66 | "linear_before_gcn": false,
67 | "time_length": 5,
68 | "hop": 1,
69 | "num_workers": 0,
70 | "etypes": [
71 | "cites", "is cited by", "writes", "is writen by", "publishes", "is published by",
72 | "shows", "is shown in", "is in", "has", "t_cites", "is t_cited by"
73 | ],
74 | "ntypes": ["paper", "author", "journal", "time", "snapshot"],
75 | "pred_type": "snapshot",
76 | "edge_sequences": "basic"
77 | },
78 | "H2CGL": {
79 | "lr": 1e-4,
80 | "num_classes": 3,
81 | "embed_dim": 300,
82 | "hidden_dim": 256,
83 | "out_dim": 256,
84 | "n_layers": 3,
85 | "batch_size": 32,
86 | "use_graph": "repeatedly",
87 | "graph_name": "graph_sample_feature_vector",
88 | "linear_before_gcn": false,
89 | "time_length": 5,
90 | "hop": 1,
91 | "num_workers": 0,
92 | "etypes": [
93 | "cites", "is cited by", "writes", "is writen by", "publishes", "is published by",
94 | "shows", "is shown in", "is in", "has", "t_cites", "is t_cited by"
95 | ],
96 | "ntypes": ["paper", "author", "journal", "time", "snapshot"],
97 | "pred_type": "snapshot",
98 | "edge_sequences": "basic",
99 | "tau": 0.4
100 | }
101 | }
--------------------------------------------------------------------------------
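Note: in both config files, the "default" block holds shared hyperparameters and each model block (e.g. "H2CGL", "DCTSGCN") holds model-specific overrides. The actual loading logic lives in files not shown here (e.g. main.py); the sketch below only illustrates the assumed overlay behaviour.

```python
import json

def load_config(path, model_name):
    # A sketch, not the repo's real loader: overlay the model-specific block
    # on top of the shared "default" hyperparameters.
    with open(path) as f:
        cfg = json.load(f)
    merged = dict(cfg["default"])
    merged.update(cfg.get(model_name, {}))  # model-specific keys win
    return merged

if __name__ == "__main__":
    config = load_config("configs/dblp.json", "H2CGL")
    print(config["lr"], config["n_layers"], config["tau"])  # 0.0001 3 0.4
```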
/configs/pubmed.json:
--------------------------------------------------------------------------------
1 | {
2 | "default": {
3 | "tokenizer_type": "basic",
4 | "tokenizer_path": null,
5 | "embed_dim": 300,
6 | "batch_size": 32,
7 | "num_classes": 3,
8 | "max_len": 256,
9 | "epochs": 30,
10 |     "rate": 0.8,
11 | "fixed_num": 150000,
12 | "optimizer": "ADAM",
13 | "criterion": "MSE",
14 | "lr": 1e-3,
15 | "weight_decay": 1e-5,
16 | "scheduler": false,
17 | "best_metric": "male",
18 | "dropout": 0.3,
19 | "saved_data": false,
20 | "fixed_data": false,
21 | "use_graph": false,
22 | "time": [2011, 2012, 2013],
23 | "type": "normal",
24 | "split_abstract": false,
25 | "save_model": true,
26 | "num_workers": 0,
27 | "batch_graph": false,
28 | "test_interval": 0,
29 | "graph_name": "graph_sample_feature_vector",
30 | "time_length": 5,
31 | "cut_threshold": [10, 100],
32 | "cut_time": 2015,
33 | "cl_weight": 0.5,
34 | "aux_weight": 0.5
35 | },
36 | "TSGCN": {
37 | "embed_dim": 300,
38 | "hidden_dim": 256,
39 | "out_dim": 256,
40 | "n_layers": 4,
41 | "time_encoder": "self-attention",
42 | "time_encoder_layers": 4,
43 | "batch_size": 32,
44 | "use_graph": "repeatedly",
45 | "graph_name": "graph_sample_feature_vector",
46 | "etypes": [
47 | "cites", "is cited by", "publishes", "is published by",
48 | "writes", "is writen by"
49 | ],
50 | "ntypes": ["author", "paper", "journal"],
51 | "time_length": 5,
52 | "hop": 1,
53 | "num_workers": 0,
54 | "lr": 1e-4
55 | },
56 | "DCTSGCN": {
57 | "lr": 1e-4,
58 | "num_classes": 3,
59 | "embed_dim": 300,
60 | "hidden_dim": 256,
61 | "out_dim": 256,
62 | "n_layers": 3,
63 | "batch_size": 32,
64 | "use_graph": "repeatedly",
65 | "graph_name": "graph_sample_feature_vector",
66 | "time_length": 5,
67 | "hop": 1,
68 | "num_workers": 0,
69 | "linear_before_gcn": false,
70 | "etypes": [
71 | "cites", "is cited by", "publishes", "is published by",
72 | "shows", "is shown in", "is in", "has", "t_cites", "is t_cited by"
73 | ],
74 | "ntypes": ["paper", "author", "journal", "time", "snapshot"],
75 | "pred_type": "snapshot",
76 | "edge_sequences": "basic"
77 | },
78 | "H2CGL": {
79 | "lr": 1e-4,
80 | "embed_dim": 300,
81 | "hidden_dim": 256,
82 | "out_dim": 256,
83 | "n_layers": 3,
84 | "batch_size": 32,
85 | "use_graph": "repeatedly",
86 | "graph_name": "graph_sample_feature_vector",
87 | "time_length": 5,
88 | "hop": 1,
89 | "num_workers": 0,
90 | "linear_before_gcn": false,
91 | "etypes": [
92 | "cites", "is cited by", "publishes", "is published by",
93 | "shows", "is shown in", "is in", "has", "t_cites", "is t_cited by"
94 | ],
95 | "ntypes": ["paper", "author", "journal", "time", "snapshot"],
96 | "pred_type": "snapshot",
97 | "edge_sequences": "basic",
98 | "tau": 0.4
99 | }
100 | }
--------------------------------------------------------------------------------
/data/readme.md:
--------------------------------------------------------------------------------
1 | Place the data here.
--------------------------------------------------------------------------------
/layers/GCN/CL_layers.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import dgl.nn.pytorch as dglnn
8 | import dgl
9 | from tqdm import tqdm
10 |
11 | from layers.common_layers import MLP
12 | from utilis.scripts import add_new_elements
13 |
14 |
15 | class CommonCLLayer(nn.Module):
16 | def __init__(self, in_feats, out_feats, hidden_feats=None, tau=0.4):
17 | super(CommonCLLayer, self).__init__()
18 | if hidden_feats is None:
19 | hidden_feats = in_feats
20 | # print(in_feats, out_feats, hidden_feats)
21 | self.proj_head = MLP(in_feats, out_feats, dim_inner=hidden_feats, activation=nn.ELU())
22 | self.tau = tau
23 |
24 | def forward(self, z1, z2):
25 | h1, h2 = F.normalize(self.proj_head(z1)), F.normalize(self.proj_head(z2))
26 | s12 = cos_sim(h1, h2, self.tau, norm=False)
27 | if s12.shape[0] == 0:
28 | s12 = torch.tensor([1]).reshape(1, 1).to(h1.device)
29 | s21 = s12
30 |
31 | # compute InfoNCE
32 | loss12 = -torch.log(s12.diag()) + torch.log(s12.sum(1))
33 | loss21 = -torch.log(s21.diag()) + torch.log(s21.sum(1))
34 | L_node = (loss12 + loss21) / 2
35 |
36 |
37 | return L_node.mean()
38 |
39 |
40 | class ProjectionHead(nn.Module):
41 | def __init__(self, in_feats, out_feats, hidden_feats=None):
42 | super(ProjectionHead, self).__init__()
43 | if hidden_feats is None:
44 | hidden_feats = in_feats
45 | self.proj_head = MLP(in_feats, out_feats, dim_inner=hidden_feats, activation=nn.ELU())
46 |
47 | def forward(self, z):
48 | return F.normalize(self.proj_head(z))
49 |
50 |
51 | def simple_cl(z1, z2, tau, valid_lens=None):
52 | h1, h2 = z1, z2
53 | s12 = cos_sim(h1, h2, tau, norm=False)
54 | if s12.shape[0] == 0:
55 | s12 = torch.tensor([1]).reshape(1, 1).to(h1.device)
56 | s21 = s12
57 |
58 | # compute InfoNCE
59 | if s12.dim() > 2:
60 | # print(s12.shape)
61 | num_node = s12.shape[-1]
62 | # loss12 = -torch.log(torch.flatten(torch.diagonal(s12, dim2=-2, dim1=-1))) + \
63 | # torch.log(s12.reshape(-1, num_node).sum(1))
64 | # loss21 = -torch.log(torch.flatten(torch.diagonal(s21, dim2=-2, dim1=-1))) + \
65 | # torch.log(s21.reshape(-1, num_node).sum(1))
66 | # print(loss21.shape)
67 | loss12 = -torch.log(torch.diagonal(s12, dim2=-2, dim1=-1)) + \
68 | torch.log(s12.sum(-1))
69 | loss21 = -torch.log(torch.diagonal(s21, dim2=-2, dim1=-1)) + \
70 | torch.log(s21.sum(-1))
71 | else:
72 | loss12 = -torch.log(s12.diag()) + torch.log(s12.sum(1))
73 | loss21 = -torch.log(s21.diag()) + torch.log(s21.sum(1))
74 | L_node = (loss12 + loss21) / 2
75 |
76 | if type(valid_lens) == torch.Tensor:
77 | # print(valid_lens)
78 | # print(L_node.sum(dim=-1).shape)
79 | L_node = L_node.sum(dim=-1) / valid_lens
80 | # print(L_node.mean())
81 | return L_node.mean()
82 |
83 | def combined_sparse(matrices):
84 | indices = []
85 | values = torch.cat([torch.flatten(matrix) for matrix in matrices])
86 | # print(values)
87 | x, y = 0, 0
88 | for matrix in matrices:
89 | cur_x, cur_y = matrix.shape
90 | for i in range(x, x + cur_x):
91 | for j in range(y, y + cur_y):
92 | indices.append([i, j])
93 | # values = torch.cat((values, torch.flatten(matrix)))
94 | x += cur_x
95 | y += cur_y
96 | # print(len(indices))
97 | # print(indices)
98 | # print(values.shape)
99 | return torch.sparse_coo_tensor(indices=torch.tensor(indices, device=matrices[0].device).T,
100 | values=values, size=(x, y), requires_grad=True)
101 |
102 |
103 | def cos_sim(x1, x2, tau: float, norm: bool = False):
104 | if x1.is_sparse:
105 | if norm:
106 | return torch.softmax(torch.sparse.mm(x1, x2.transpose(0, 1)).to_dense() / tau, dim=1)
107 | else:
108 | return torch.exp(torch.sparse.mm(x1, x2.transpose(0, 1)).to_dense() / tau)
109 | elif x1.dim() > 2:
110 | result = x1 @ x2.transpose(-2, -1) / tau
111 | if norm:
112 | return torch.softmax(result, dim=-1)
113 | else:
114 | return torch.exp(result)
115 | else:
116 | if norm:
117 | return torch.softmax(x1 @ x2.T / tau, dim=1)
118 | else:
119 | return torch.exp(x1 @ x2.T / tau)
120 |
121 |
122 | def RBF_sim(x1, x2, tau: float, norm: bool = False):
123 | xx1, xx2 = torch.stack(len(x2) * [x1]), torch.stack(len(x1) * [x2])
124 | sub = xx1.transpose(0, 1) - xx2
125 | R = (sub * sub).sum(dim=2)
126 | if norm:
127 | return F.softmax(-R / (tau * tau), dim=1)
128 | else:
129 | return torch.exp(-R / (tau * tau))
130 |
131 |
132 | def augmenting_graphs(graphs, aug_configs, aug_methods=(('attr_drop',), ('edge_drop',))):
133 | '''
134 | :param aug_configs:
135 | :param graphs:
136 | :param aug_methods:
137 | :return:
138 | '''
139 | g1_list = []
140 | g2_list = []
141 | for graph in graphs:
142 | # g1, g2 = copy.deepcopy(graph), copy.deepcopy(graph)
143 | g1, g2 = add_new_elements(graph), add_new_elements(graph)
144 | for aug_method in aug_methods[0]:
145 | cur_method = eval(aug_method)
146 | g1 = cur_method(g1, aug_configs[0][aug_method]['prob'], aug_configs[0][aug_method].get('types', None))
147 | for aug_method in aug_methods[1]:
148 | cur_method = eval(aug_method)
149 | g2 = cur_method(g2, aug_configs[1][aug_method]['prob'], aug_configs[1][aug_method].get('types', None))
150 | g1_list.append(g1)
151 | g2_list.append(g2)
152 |
153 | del graphs
154 | return g1_list, g2_list
155 |
156 |
157 | def augmenting_graphs_ctsgcn(graphs, aug_configs, aug_methods=(('attr_drop',), ('edge_drop',)), time_length=10):
158 | '''
159 | :param time_length:
160 | :param aug_configs:
161 | :param graphs:
162 | :param aug_methods:
163 | :return:
164 | '''
165 | g1_list = []
166 | g2_list = []
167 | for graph in graphs:
168 | # g1, g2 = copy.deepcopy(graph), copy.deepcopy(graph)
169 | g1, g2 = add_new_elements(graph), add_new_elements(graph)
170 | for t in range(time_length):
171 | # g1
172 | for aug_method in aug_methods[0]:
173 | selected_nodes = {}
174 | for ntype in graph.ntypes:
175 | selected_nodes[ntype] = g1.nodes[ntype].data['snapshot_idx'] == t
176 | if torch.sum(selected_nodes['paper']).item() > 0:
177 | cur_method = eval(aug_method)
178 | g1 = cur_method(g1, aug_configs[0][aug_method]['prob'], aug_configs[0][aug_method].get('types', None),
179 | selected_nodes)
180 | # g2
181 | for aug_method in aug_methods[1]:
182 | selected_nodes = {}
183 | for ntype in graph.ntypes:
184 | selected_nodes[ntype] = g2.nodes[ntype].data['snapshot_idx'] == t
185 | if torch.sum(selected_nodes['paper']).item() > 0:
186 | cur_method = eval(aug_method)
187 | g2 = cur_method(g2, aug_configs[1][aug_method]['prob'], aug_configs[1][aug_method].get('types', None),
188 | selected_nodes)
189 | g1_list.append(g1)
190 | g2_list.append(g2)
191 |
192 | del graphs
193 | return g1_list, g2_list
194 |
195 |
196 | def augmenting_graphs_single(graphs, aug_config, aug_methods=(('attr_drop',), ('edge_drop',))):
197 |     '''
198 |     :param graphs: list of dgl heterographs to augment
199 |     :param aug_config: augmentation config, one entry per method in aug_methods
200 |     :param aug_methods: augmentation method names applied to every graph
201 |     :return: list of augmented graphs
202 |     '''
203 | graph_list = []
204 | for graph in graphs:
205 | # dealt_graph = copy.deepcopy(graph)
206 | dealt_graph = add_new_elements(graph)
207 | for aug_method in aug_methods:
208 | cur_method = eval(aug_method)
209 | dealt_graph = cur_method(dealt_graph, aug_config[aug_method]['prob'],
210 | aug_config[aug_method].get('types', None))
211 | graph_list.append(dealt_graph)
212 | return graph_list
213 |
214 |
215 | def augmenting_graphs_single_ctsgcn(graphs, aug_config, aug_methods=(('attr_drop',), ('edge_drop',)), time_length=10):
216 |     '''
217 |     :param graphs: list of dgl heterographs to augment
218 |     :param aug_config: augmentation config, one entry per method in aug_methods
219 |     :param aug_methods: augmentation methods applied snapshot by snapshot (see time_length)
220 |     :return: list of augmented graphs
221 |     '''
222 | graph_list = []
223 | for graph in graphs:
224 | # dealt_graph = copy.deepcopy(graph)
225 | dealt_graph = add_new_elements(graph)
226 | for aug_method in aug_methods:
227 | cur_method = eval(aug_method)
228 | for t in range(time_length):
229 | selected_nodes = {}
230 | for ntype in graph.ntypes:
231 | selected_nodes[ntype] = graph.nodes[ntype].data['snapshot_idx'] == t
232 | if torch.sum(selected_nodes['paper']).item() == 0:
233 | continue
234 | dealt_graph = cur_method(dealt_graph, aug_config[aug_method]['prob'],
235 | aug_config[aug_method].get('types', None), selected_nodes)
236 | graph_list.append(dealt_graph)
237 | return graph_list
238 |
239 |
240 | def attr_mask(graph, prob, ntypes=None, selected_nodes=None):
241 | if ntypes is None:
242 | ntypes = graph.ntypes
243 | if selected_nodes:
244 | for ntype in ntypes:
245 | emb_dim = graph.nodes[ntype].data['h'].shape[-1]
246 | p = torch.bernoulli(torch.tensor([1 - prob] * emb_dim)).to(torch.bool)
247 | graph.nodes[ntype].data['h'][selected_nodes[ntype]][:, p] = 0.
248 | else:
249 | for ntype in ntypes:
250 | emb_dim = graph.nodes[ntype].data['h'].shape[-1]
251 | p = torch.bernoulli(torch.tensor([1 - prob] * emb_dim)).to(torch.bool)
252 | graph.nodes[ntype].data['h'][:, p] = 0.
253 | return graph
254 |
255 |
256 | def attr_drop(graph, prob, ntypes=None, selected_nodes=None):
257 | if ntypes is None:
258 | ntypes = graph.ntypes
259 | if selected_nodes:
260 | for ntype in ntypes:
261 | node_count = graph.nodes[ntype].data['h'][selected_nodes[ntype]].shape[0]
262 | emb_dim = graph.nodes[ntype].data['h'].shape[-1]
263 | p = np.random.rand(node_count)
264 | drop_nodes = torch.where(selected_nodes[ntype])[0][np.where(p <= prob)[0]]
265 | graph.nodes[ntype].data['h'][drop_nodes, :] = torch.zeros(len(drop_nodes), emb_dim)
266 | else:
267 | for ntype in ntypes:
268 | node_count = graph.nodes[ntype].data['h'].shape[0]
269 | emb_dim = graph.nodes[ntype].data['h'].shape[-1]
270 | p = np.random.rand(node_count)
271 | drop_nodes = np.where(p <= prob)[0]
272 | graph.nodes[ntype].data['h'][drop_nodes, :] = torch.zeros(len(drop_nodes), emb_dim)
273 | return graph
274 |
275 |
276 | def edge_drop(graph, prob, etypes=None, selected_nodes=None):
277 | # graph = copy.deepcopy(graph)
278 | if etypes is None:
279 | etypes = graph.canonical_etypes
280 | if selected_nodes:
281 | # print('wip')
282 | for etype in etypes:
283 | dst_type = etype[-1]
284 | dst_nodes = set(graph.nodes(dst_type)[selected_nodes[dst_type]].numpy().tolist())
285 | # print(graph.edges(form='all', etype=etype, order='eid'))
286 | src, dst, edges = graph.edges(form='all', etype=etype, order='eid')
287 | dst_index = []
288 | # print(edges)
289 | for i in range(len(dst)):
290 | if dst[i].item() in dst_nodes:
291 | dst_index.append(i)
292 | edges = edges[dst_index]
293 | # print(dst_nodes, dst_index)
294 |
295 | p = np.random.rand(edges.shape[0])
296 | # print(edges[np.where(p < 0.2)])
297 | drop_edges = edges[np.where(p <= prob)]
298 | graph.remove_edges(drop_edges, etype=etype)
299 | else:
300 | for etype in etypes:
301 | edges = graph.edges(form='eid', etype=etype)
302 | p = np.random.rand(edges.shape[0])
303 | # print(edges[np.where(p < 0.2)])
304 | drop_edges = edges[np.where(p <= prob)]
305 | graph.remove_edges(drop_edges, etype=etype)
306 | # print(graph.edges(form='eid', etype='cites'))
307 | return graph
308 |
309 | def node_drop(graph, prob, ntypes=None, selected_nodes=None):
310 | # drop nodes according to the citation, so just drop the paper, not including other nodes and the target node
311 | selected_index = (1 - graph.nodes['paper'].data['is_target']).to(torch.bool)
312 | if selected_nodes:
313 | selected_nodes = selected_nodes['paper']
314 | selected_index = selected_index & selected_nodes
315 |
316 | count = torch.sum(selected_index).item()
317 | # print(count)
318 | if count > 0:
319 | drop_rate = count * prob
320 | cur_p = torch.softmax(-graph.nodes['paper'].data['citations'][selected_index].float(), dim=-1).numpy() * drop_rate
321 | drop_indicator = np.random.rand(count) < cur_p
322 | drop_nodes = torch.where(selected_index)[0][drop_indicator]
323 | graph.remove_nodes(drop_nodes, ntype='paper')
324 | return graph
325 |
326 |
327 | if __name__ == '__main__':
328 | cl = CommonCLLayer(128, 64)
329 | z1 = torch.randn(10, 128)
330 | z2 = torch.randn(10, 128)
331 | print(cl(z1, z2))
332 |
333 | print(eval('attr_drop'))
334 |
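335 |     # A sketch of the config shape assumed by augmenting_graphs above (inferred from
336 |     # the dict lookups in this file; the real configs are built elsewhere and not shown here):
337 |     #   aug_configs = (
338 |     #       {'attr_drop': {'prob': 0.1}},   # view 1: zero some node features
339 |     #       {'edge_drop': {'prob': 0.1}},   # view 2: drop some edges
340 |     #   )
341 |     #   g1_list, g2_list = augmenting_graphs(hetero_graphs, aug_configs)
342 |     # simple_cl is the functional InfoNCE used with projected, normalized embeddings:
343 |     #   loss = simple_cl(F.normalize(z1), F.normalize(z2), tau=0.4)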
--------------------------------------------------------------------------------
/layers/GCN/DTGCN_layers.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | from collections import defaultdict
4 |
5 | import dgl
6 | import torch
7 | import torch.nn as nn
8 | import dgl.nn.pytorch as dglnn
9 | from dgl.nn.pytorch.hetero import get_aggregate_fn
10 |
11 | from layers.GCN.RGCN import StochasticKLayerRGCN, CustomGCN
12 |
13 |
14 | class DCTSGCNLayer(nn.Module):
15 | def __init__(self, in_feat, hidden_feat, out_feat=None, ntypes=None, edge_sequences=None, k=3, time_length=10):
16 | super().__init__()
17 | # T * k layers
18 | self.k = k
19 | self.time_length = time_length
20 | self.ntypes = ntypes
21 | if out_feat is None:
22 | out_feat = hidden_feat
23 | if edge_sequences:
24 | forward_sequences, backward_sequences = edge_sequences
25 | self.concat_fc = nn.ModuleList(
26 | [dglnn.HeteroLinear({ntype: hidden_feat * 2 for ntype in ntypes}, hidden_feat)])
27 | # print('fc', self.concat_fc[0].linears['author'].weight.device)
28 | self.forward_layers = nn.ModuleList([DirectedCTSGCNLayer(ntypes, time_length, forward_sequences,
29 | in_feat, hidden_feat)])
30 | self.backward_layers = nn.ModuleList([DirectedCTSGCNLayer(ntypes, time_length, backward_sequences,
31 | in_feat, hidden_feat)])
32 | for i in range(k - 2):
33 | self.forward_layers.append(DirectedCTSGCNLayer(ntypes, time_length, forward_sequences,
34 | hidden_feat, hidden_feat))
35 | self.backward_layers.append(DirectedCTSGCNLayer(ntypes, time_length, backward_sequences,
36 | hidden_feat, hidden_feat))
37 | self.concat_fc.append(dglnn.HeteroLinear({ntype: hidden_feat * 2 for ntype in ntypes}, hidden_feat))
38 | self.forward_layers.append(DirectedCTSGCNLayer(ntypes, time_length, forward_sequences,
39 | hidden_feat, out_feat))
40 | self.backward_layers.append(DirectedCTSGCNLayer(ntypes, time_length, backward_sequences,
41 | hidden_feat, out_feat))
42 | self.concat_fc.append(dglnn.HeteroLinear({ntype: out_feat * 2 for ntype in ntypes}, out_feat))
43 |
44 | # self.concat_fc = self.concat_fc.to(device)
45 | # print('fc', self.concat_fc[0].linears['author'].weight.device)
46 | # self.forward_layers = self.forward_layers.to(device)
47 | # self.backward_layers = self.backward_layers.to(device)
48 |
49 | def forward(self, snapshots, input_feats):
50 | hidden_embs = input_feats
51 | # hidden_embs = snapshots.srcdata['h']
52 | for i in range(self.k):
53 | # print('-'*30, i, '-'*30)
54 | # print([emb.shape for emb in hidden_embs[0].values()])
55 | forward_embs = self.forward_layers[i](snapshots, hidden_embs)
56 | # print('=' * 61)
57 | # print([emb.shape for emb in hidden_embs[0].values()])
58 | backward_embs = self.backward_layers[i](snapshots, hidden_embs)
59 | # out_embs = [
60 | # {ntype: torch.cat((forward_embs[t][ntype], backward_embs[t][ntype]), dim=-1) for ntype in self.ntypes}
61 | # for t in range(self.time_length)]
62 | out_embs = {ntype: torch.cat((forward_embs[ntype], backward_embs[ntype]), dim=-1) for ntype in self.ntypes}
63 | # print('here', out_embs[0])
64 | # hidden_embs = [self.concat_fc[i](out_embs[t]) for t in range(self.time_length)]
65 | hidden_embs = self.concat_fc[i](out_embs)
66 |
67 | return hidden_embs
68 |
69 | def to(self, *args, **kwargs):
70 | self = super().to(*args, **kwargs)
71 | print('whole graph_encoder to device')
72 | if self.concat_fc:
73 | self.concat_fc = self.concat_fc.to(*args, **kwargs)
74 | for i in range(len(self.forward_layers)):
75 | self.forward_layers[i] = self.forward_layers[i].to(*args, **kwargs)
76 | # print(self.forward_layers)
77 | for i in range(len(self.backward_layers)):
78 | self.backward_layers[i] = self.backward_layers[i].to(*args, **kwargs)
79 | return self
80 |
81 |
82 | class DirectedCTSGCNLayer(nn.Module):
83 | def __init__(self, ntypes, time_length, edge_sequences, in_feat, hidden_feat, out_feat=None, agg_method='sum',
84 | activation=torch.relu):
85 | super(DirectedCTSGCNLayer, self).__init__()
86 | if not out_feat:
87 | out_feat = hidden_feat
88 | self.time_length = time_length
89 | # meta, paper, to_snapshot, snapshot
90 | self.sequence = [edges_type for edges_type, _ in edge_sequences]
91 | self.edges_type = {edges_type: etypes for edges_type, etypes in edge_sequences}
92 | self.agg_method = get_aggregate_fn(agg_method)
93 | self.snapshot_weight = 'none'
94 | self.convs = nn.ModuleDict({})
95 | self.activation = activation
96 | self.skip_connect_fc = nn.ModuleDict({edges_type: nn.ModuleDict({}) for edges_type in self.edges_type})
97 |
98 | trans_nodes = {}
99 | input_dim = {ntype: in_feat for ntype in ntypes}
100 | # if self.sequence[0] != 'snapshot':
101 | # ntypes = [stype for stype, _, _ in edge_sequences[0][-1]]
102 | # trans_nodes = {ntype: in_feat for ntype in ntypes}
103 | # for ntype in ntypes:
104 | # if ntype not in trans_nodes:
105 | # trans_nodes[ntype] = out_feat
106 | for i in range(len(edge_sequences)):
107 | edges_type, etypes = edge_sequences[i]
108 | if edges_type != 'intra_snapshots':
109 | # if i == 0:
110 | # self.convs[edges_type] = nn.ModuleList([dglnn.HeteroGraphConv(
111 | # {etype: dglnn.GraphConv(in_feat, hidden_feat, norm='right')
112 | # for etype in etypes}) for t in range(time_length)])
113 | # elif i == len(edge_sequences) - 1:
114 | # self.convs[edges_type] = nn.ModuleList([dglnn.HeteroGraphConv(
115 | # {etype: dglnn.GraphConv(hidden_feat, out_feat, norm='right')
116 | # for etype in etypes}) for t in range(time_length)])
117 | # else:
118 | # self.convs[edges_type] = nn.ModuleList([dglnn.HeteroGraphConv(
119 | # {etype: dglnn.GraphConv(hidden_feat, hidden_feat, norm='right')
120 | # for etype in etypes}) for t in range(time_length)])
121 | if i == 0:
122 | self.convs[edges_type] = nn.ModuleDict({str(etype):
123 | dglnn.GraphConv(in_feat, hidden_feat, norm='right',
124 | allow_zero_in_degree=True)
125 | for etype in etypes})
126 | cur_out = hidden_feat
127 | # print(list(self.convs[edges_type].values())[0][0].weight.device)
128 | elif i == len(edge_sequences) - 1:
129 | self.convs[edges_type] = nn.ModuleDict({str(etype):
130 | dglnn.GraphConv(hidden_feat, out_feat, norm='right',
131 | allow_zero_in_degree=True)
132 | for etype in etypes})
133 | cur_out = out_feat
134 | else:
135 | self.convs[edges_type] = nn.ModuleDict({
136 | str(etype):
137 | dglnn.GraphConv(hidden_feat, hidden_feat, norm='right', allow_zero_in_degree=True)
138 | for etype in etypes})
139 | cur_out = hidden_feat
140 | for etype in etypes:
141 | trans_nodes[etype[-1]] = cur_out
142 | self.skip_connect_fc[edges_type][etype[-1]] = nn.Linear(input_dim[etype[-1]], cur_out)
143 | for etype in etypes:
144 | input_dim[etype[-1]] = cur_out
145 | else:
146 | if self.snapshot_weight:
147 | snapshot_norm = 'none'
148 | else:
149 | snapshot_norm = 'right'
150 | if i == 0:
151 | self.convs[edges_type] = dglnn.GraphConv(in_feat, hidden_feat, norm=snapshot_norm)
152 | cur_out = hidden_feat
153 | else:
154 | self.convs[edges_type] = dglnn.GraphConv(hidden_feat, out_feat, norm=snapshot_norm)
155 | cur_out = out_feat
156 | trans_nodes['snapshot'] = cur_out
157 | input_dim['snapshot'] = cur_out
158 | # self.skip_connect_fc[edges_type]['snapshot'] = nn.Linear(cur_in, cur_out)
159 | for ntype in ntypes:
160 | if ntype not in trans_nodes:
161 | trans_nodes[ntype] = in_feat
162 | # self.trans_fc = None
163 | self.trans_fc = dglnn.HeteroLinear(trans_nodes, out_feat)
164 |
165 | # def to(self, *args, **kwargs):
166 | # self = super().to(*args, **kwargs)
167 | # # for edges_type in self.convs:
168 | # # print(edges_type + ' to device')
169 | # # self.convs[edges_type] = self.convs[edges_type].to(*args, **kwargs)
170 | # # self.convs = self.convs.to(*args, **kwargs)
171 | # return self
172 |
173 | def normal_forward(self, snapshots, input_embs, edges_type):
174 | # conv_embs = []
175 | outputs = defaultdict(list)
176 | all_embs = input_embs
177 | for etype in self.edges_type[edges_type]:
178 | # print('etype', etype)
179 | stype, _, dtype = etype
180 | subgraph = snapshots[tuple(etype)]
181 | # subgraph = dgl.add_self_loop(subgraph)
182 | # print(etype, subgraph.in_degrees())
183 | srcdata = all_embs[stype]
184 | dstdata = all_embs[dtype]
185 | outputs[dtype].append(self.convs[edges_type][str(etype)](subgraph, (srcdata, dstdata)))
186 | for dtype in outputs:
187 | if len(outputs[dtype]) > 0:
188 | all_embs[dtype] = self.skip_connect_fc[edges_type][dtype](all_embs[dtype]) \
189 | + self.agg_method(outputs[dtype], dtype)
190 | # conv_embs.append(all_embs)
191 |
192 | return all_embs
193 |
194 | def snapshot_forward(self, input_embs, snapshot_graphs):
195 | ndata = input_embs['snapshot']
196 | subgraph = snapshot_graphs[tuple(self.edges_type['intra_snapshots'])]
197 | if self.snapshot_weight:
198 | norm = dglnn.EdgeWeightNorm(norm=self.snapshot_weight)
199 | norm_edge_weight = norm(subgraph, subgraph.edata['w'])
200 | # print(ndata.shape)
201 | # print(self.convs['intra_snapshots'].weight.shape)
202 | # print(ndata)
203 | # print(self.edges_type)
204 | conv_emb = self.convs['intra_snapshots'](subgraph, (ndata, ndata), edge_weight=norm_edge_weight)
205 | else:
206 | conv_emb = self.convs['intra_snapshots'](subgraph, (ndata, ndata))
207 | # dstdata = torch.split(conv_emb, batch_size, dim=0) # T[B, d]
208 | # for t in range(self.time_length):
209 | # input_embs[t]['snapshot'] = dstdata[t]
210 | input_embs['snapshot'] = conv_emb
211 | return input_embs
212 |
213 | def forward(self, snapshots, input_embs):
214 | # print('-'*59)
215 | # hidden_embs = [snapshot.srcdata['h'] for snapshot in snapshots]
216 | # copy for dual-direction
217 | input_embs = input_embs.copy()
218 | hidden_embs = input_embs
219 | # print(input_embs[0]['snapshot'])
220 | for edges_type in self.sequence:
221 | if edges_type != 'intra_snapshots':
222 | hidden_embs = self.normal_forward(snapshots, hidden_embs, edges_type)
223 | else:
224 | hidden_embs = self.snapshot_forward(hidden_embs, snapshots)
225 | if self.activation:
226 | for ntype in hidden_embs.keys():
227 | hidden_embs[ntype] = self.activation(hidden_embs[ntype])
228 |         # all node types share one MLP (trans_fc)
229 | hidden_embs = self.trans_fc(hidden_embs)
230 | # T[nodes.data]
231 | return hidden_embs
232 |
233 |
234 | class SimpleDCTSGCNLayer(DCTSGCNLayer):
235 | def __init__(self, in_feat, hidden_feat, out_feat=None, ntypes=None, edge_sequences=None, k=3, time_length=10,
236 | combine_method='concat', encoder_type='GCN', encoder_dict=None, return_list=False, snapshot_weight=None,
237 | **kwargs):
238 | super().__init__(in_feat, hidden_feat, out_feat, ntypes, None, k, time_length)
239 | print(encoder_type, encoder_dict)
240 | # T * k layers
241 | self.k = k
242 | self.time_length = time_length
243 | self.ntypes = ntypes
244 | if out_feat is None:
245 | out_feat = hidden_feat
246 | self.fc_in = dglnn.HeteroLinear({ntype: in_feat for ntype in ntypes}, hidden_feat)
247 | # self.fc_out = dglnn.HeteroLinear({ntype: hidden_feat for ntype in ntypes}, out_feat)
248 | self.fc_out = nn.ModuleList([dglnn.HeteroLinear({ntype: hidden_feat for ntype in ntypes}, out_feat)
249 | for i in range(k)])
250 | self.return_list = return_list
251 |
252 | print(combine_method)
253 | self.concat_fc = None
254 | if combine_method == 'concat':
255 | self.concat_fc = nn.ModuleList(
256 | [dglnn.HeteroLinear({ntype: hidden_feat * 2 for ntype in ntypes}, hidden_feat)
257 | for i in range(k)])
258 | # print('fc', self.concat_fc[0].linears['author'].weight.device)
259 | if len(edge_sequences) > 1:
260 | self.dual_view = True
261 | forward_sequences, backward_sequences = edge_sequences
262 | self.forward_layers = nn.ModuleList([SimpleDirectedCTSGCNLayer(time_length, forward_sequences, hidden_feat,
263 | encoder_type=encoder_type,
264 | encoder_dict=encoder_dict,
265 | snapshot_weight=snapshot_weight, **kwargs)
266 | for i in range(k)])
267 | self.backward_layers = nn.ModuleList([SimpleDirectedCTSGCNLayer(time_length, forward_sequences, hidden_feat,
268 | encoder_type=encoder_type,
269 | encoder_dict=encoder_dict,
270 | snapshot_weight=snapshot_weight, **kwargs)
271 | for i in range(k)])
272 | else:
273 | self.dual_view = False
274 | single_sequences = edge_sequences[0]
275 | self.forward_layers = []
276 | self.backward_layers = []
277 | self.single_layers = nn.ModuleList([SimpleDirectedCTSGCNLayer(time_length, single_sequences, hidden_feat,
278 | encoder_type=encoder_type,
279 | encoder_dict=encoder_dict,
280 | snapshot_weight=snapshot_weight, **kwargs)
281 | for i in range(k)])
282 | print('dual view:', self.dual_view)
283 |
284 | def forward(self, snapshots, hidden_embs, get_attention=False):
285 | # hidden_embs = snapshots.srcdata['h']
286 | output_list = []
287 | attn_list = []
288 | hidden_embs = self.fc_in(hidden_embs)
289 | for i in range(self.k):
290 | if self.dual_view:
291 | forward_embs = self.forward_layers[i](snapshots, hidden_embs)
292 | backward_embs = self.backward_layers[i](snapshots, hidden_embs)
293 | if self.concat_fc:
294 | out_embs = {ntype: torch.cat((forward_embs[ntype], backward_embs[ntype]), dim=-1) for ntype in
295 | self.ntypes}
296 | hidden_embs = self.concat_fc[i](out_embs)
297 | else:
298 | hidden_embs = {ntype: forward_embs[ntype] + backward_embs[ntype] for ntype in
299 | self.ntypes}
300 | else:
301 | if get_attention:
302 | hidden_embs, attn = self.single_layers[i](snapshots, hidden_embs, get_attention=True)
303 | attn_list.append(attn)
304 | else:
305 | hidden_embs = self.single_layers[i](snapshots, hidden_embs)
306 | if self.return_list:
307 | output_list.append(self.fc_out[i](hidden_embs))
308 | hidden_embs = self.fc_out[-1](hidden_embs)
309 | # for key in hidden_embs:
310 | # print(hidden_embs[key].shape)
311 | if get_attention:
312 | return hidden_embs, output_list, attn_list
313 | else:
314 | return hidden_embs, output_list
315 |
316 |
317 | class SimpleDirectedCTSGCNLayer(nn.Module):
318 | def __init__(self, time_length, edge_sequences, hidden_feat, agg_method='sum', encoder_type='GCN',
319 | encoder_dict=None, snapshot_weight=None,
320 | activation=nn.LeakyReLU(), **kwargs):
321 | # simplified version
322 | super(SimpleDirectedCTSGCNLayer, self).__init__()
323 | self.time_length = time_length
324 | # meta, paper, to_snapshot, snapshot
325 | print(len(edge_sequences))
326 | self.sequence = [edges_type for edges_type, _ in edge_sequences]
327 | self.edges_type = {edges_type: etypes for edges_type, etypes in edge_sequences}
328 | self.agg_method = get_aggregate_fn(agg_method)
329 | self.snapshot_weight = snapshot_weight
330 | # self.snapshot_weight = None
331 | print('snapshot_weight', self.snapshot_weight)
332 | self.convs = nn.ModuleDict({})
333 | self.activation = activation
334 |
335 | for i in range(len(edge_sequences)):
336 | edges_type, etypes = edge_sequences[i]
337 | if edges_type != 'intra_snapshots':
338 | # self.convs[edges_type] = nn.ModuleDict({str(etype):
339 | # dglnn.GraphConv(hidden_feat, hidden_feat, norm='right',
340 | # allow_zero_in_degree=True)
341 | # for etype in etypes})
342 | # self.convs[edges_type] = nn.ModuleDict({str(etype):
343 | # CustomGCN(encoder_type, hidden_feat, hidden_feat, layer=i+1,
344 | # activation=activation, **kwargs)
345 | # for etype in etypes})
346 | self.convs[edges_type] = nn.ModuleDict({})
347 | for etype in etypes:
348 | if etype not in encoder_dict:
349 | cur_encoder_type = encoder_type
350 | encoder_args = {}
351 | else:
352 | cur_encoder_type = encoder_dict[etype]['cur_type']
353 | encoder_args = encoder_dict[etype]
354 | self.convs[edges_type][str(etype)] = CustomGCN(cur_encoder_type, hidden_feat, hidden_feat,
355 | layer=i + 1, activation=activation, **encoder_args)
356 |
357 | else:
358 | if self.snapshot_weight:
359 | self.snapshot_weight = 'none'
360 | else:
361 | self.snapshot_weight = 'right'
362 | self.convs[edges_type] = nn.ModuleDict({})
363 | for etype in etypes:
364 | if etype not in encoder_dict:
365 | cur_encoder_type = encoder_type
366 | encoder_args = {}
367 | else:
368 | cur_encoder_type = encoder_dict[etype]['cur_type']
369 | encoder_args = encoder_dict[etype]
370 | print(encoder_args)
371 | self.convs[edges_type][str(etype)] = CustomGCN(cur_encoder_type, hidden_feat, hidden_feat,
372 | layer=i + 1, activation=activation,
373 | **encoder_args)
374 |
375 | def normal_forward(self, snapshots, input_embs, edges_type):
376 | # conv_embs = []
377 | outputs = defaultdict(list)
378 | all_embs = input_embs
379 | for etype in self.edges_type[edges_type]:
380 | # print('etype', etype)
381 | stype, _, dtype = etype
382 | subgraph = snapshots[tuple(etype)]
383 | # subgraph = dgl.add_self_loop(subgraph)
384 | # print(etype, subgraph.in_degrees())
385 | srcdata = all_embs[stype]
386 | dstdata = all_embs[dtype]
387 | outputs[dtype].append(self.convs[edges_type][str(etype)](subgraph, (srcdata, dstdata)))
388 | for dtype in outputs:
389 | if len(outputs[dtype]) > 0:
390 | all_embs[dtype] = self.agg_method(outputs[dtype], dtype)
391 | # conv_embs.append(all_embs)
392 |
393 | return all_embs
394 |
395 | def snapshot_forward(self, input_embs, snapshot_graphs):
396 | outputs = defaultdict(list)
397 | all_embs = input_embs
398 | for etype in self.edges_type['intra_snapshots']:
399 | # print('etype', etype)
400 | stype, _, dtype = etype
401 | subgraph = snapshot_graphs[tuple(etype)]
402 | # subgraph = dgl.add_self_loop(subgraph)
403 | # print(etype, subgraph.in_degrees())
404 | srcdata = all_embs[stype]
405 | dstdata = all_embs[dtype]
406 | norm_edge_weight = None
407 | if self.snapshot_weight:
408 | norm = dglnn.EdgeWeightNorm(norm=self.snapshot_weight)
409 | norm_edge_weight = norm(subgraph, subgraph.edata['w'])
410 | outputs[dtype].append(self.convs['intra_snapshots'][str(etype)](subgraph, (srcdata, dstdata),
411 | edge_weight=norm_edge_weight))
412 | for dtype in outputs:
413 | if len(outputs[dtype]) > 0:
414 | all_embs[dtype] = self.agg_method(outputs[dtype], dtype)
415 | return all_embs
416 |
417 | def get_attn(self, snapshots, input_embs, edges_type, etype):
418 | # conv_embs = []
419 | outputs = defaultdict(list)
420 | all_embs = input_embs
421 | # print('etype', etype)
422 | # etype = ('paper', 'is in', 'snapshot')
423 | stype, _, dtype = etype
424 | subgraph = snapshots[tuple(etype)]
425 | # subgraph = dgl.add_self_loop(subgraph)
426 | # print(etype, subgraph.in_degrees())
427 | srcdata = all_embs[stype]
428 | dstdata = all_embs[dtype]
429 | _, attn = self.convs[edges_type][str(etype)](subgraph, (srcdata, dstdata), get_attention=True)
430 | return attn
431 |
432 | def forward(self, snapshots, input_embs, get_attention=False):
433 | input_embs = input_embs.copy()
434 | hidden_embs = input_embs
435 | for edges_type in self.sequence:
436 | if edges_type != 'intra_snapshots':
437 | if get_attention:
438 | attn_rgat = self.get_attn(snapshots, hidden_embs, edges_type, ('paper', 'is in', 'snapshot'))
439 | attn_cgin = self.get_attn(snapshots, hidden_embs, edges_type, ('paper', 'cites', 'paper'))
440 | attn = [attn_rgat, attn_cgin]
441 | hidden_embs = self.normal_forward(snapshots, hidden_embs, edges_type)
442 | else:
443 | hidden_embs = self.snapshot_forward(hidden_embs, snapshots)
444 | if self.activation:
445 | for ntype in hidden_embs.keys():
446 | hidden_embs[ntype] = self.activation(hidden_embs[ntype])
447 | if get_attention:
448 | return hidden_embs, attn
449 | else:
450 | return hidden_embs
451 |
452 |
453 | if __name__ == '__main__':
454 | start_time = datetime.datetime.now()
455 | parser = argparse.ArgumentParser(description='Process some description.')
456 | parser.add_argument('--phase', default='test', help='the function name.')
457 |
458 | args = parser.parse_args()
459 | rel_names = ['cites', 'is cited by', 'writes', 'is writen by', 'publishes', 'is published by']
460 | model = StochasticKLayerRGCN(768, 300, 300, rel_names)
461 | print(model)
462 |
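463 |     # The edge_sequences argument of DCTSGCNLayer / SimpleDCTSGCNLayer is assumed
464 |     # (from the __init__ and forward loops above) to be a tuple of per-direction
465 |     # sequences, each a list of (edges_type_name, etypes) steps, where etypes are
466 |     # canonical (src_type, relation, dst_type) triples of the heterograph. The step
467 |     # named 'intra_snapshots' is routed to snapshot_forward (optionally edge-weighted
468 |     # via EdgeWeightNorm); every other step goes through normal_forward with one
469 |     # convolution per etype. The concrete sequences used for H2CGL are defined
470 |     # elsewhere in the repo and are not reproduced here.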
--------------------------------------------------------------------------------
/layers/GCN/RGCN.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 |
4 | import dgl
5 | import torch
6 | import torch.nn as nn
7 | import dgl.nn.pytorch as dglnn
8 | from layers.common_layers import MLP, CapsuleNetwork
9 | from layers.GCN.custom_gcn import DAGIN, RGIN, RGIN_N, RGAT, DGIN, DGINM
10 |
11 |
12 | def get_gcn_layer(encoder_type, in_feat, hidden_feat, layer, activation=None, **kwargs):
13 | if encoder_type == 'GCN':
14 | return dglnn.GraphConv(in_feat, hidden_feat, norm='right', allow_zero_in_degree=True)
15 | elif encoder_type == 'GIN':
16 | return dglnn.GINConv(MLP(in_feat, hidden_feat, activation=activation))
17 | elif encoder_type == 'GAT':
18 | nheads = kwargs.get('nheads', 4)
19 | out_feat = hidden_feat // nheads
20 | return dglnn.GATv2Conv(in_feat, out_feat, num_heads=nheads, activation=activation, allow_zero_in_degree=True)
21 | elif encoder_type == 'APPNP':
22 | k = kwargs.get('k', 5)
23 | alpha = kwargs.get('alpha', 0.5)
24 | return dglnn.APPNPConv(k=k, alpha=alpha)
25 | elif encoder_type == 'GCN2':
26 | return dglnn.GCN2Conv(in_feat, layer, allow_zero_in_degree=True, activation=activation)
27 | elif encoder_type == 'SAGE':
28 | return dglnn.SAGEConv(in_feat, hidden_feat, 'mean', activation=activation)
29 | elif encoder_type == 'PNA':
30 | return dglnn.PNAConv(in_feat, hidden_feat, ['mean', 'max', 'sum'], ['identity', 'amplification'], 2.5)
31 | elif encoder_type == 'DGN':
32 | return dglnn.DGNConv(in_feat, hidden_feat, ['dir1-av', 'dir1-dx', 'sum'], ['identity', 'amplification'], 2.5)
33 | elif encoder_type == 'DAGIN':
34 | distance_interval = kwargs.get('distance_interval', 5)
35 | return DAGIN(in_feat, hidden_feat, distance_interval=distance_interval, activation=activation)
36 | elif encoder_type == 'DGIN':
37 | attr_indicator = kwargs.get('attr_indicator', 'citations')
38 | return DGIN(in_feat, hidden_feat, attr_indicator=attr_indicator, activation=activation)
39 | elif encoder_type == 'DGINM':
40 | attr_indicator = kwargs.get('attr_indicator', 'citations')
41 | return DGINM(in_feat, hidden_feat, attr_indicator=attr_indicator, activation=activation)
42 | elif encoder_type == 'DGINP':
43 | attr_indicator = kwargs.get('attr_indicator', 'citations')
44 | return DGIN(in_feat, hidden_feat, degree_trans_method='pos',
45 | attr_indicator=attr_indicator, activation=activation)
46 | elif encoder_type == 'DGINPM':
47 | attr_indicator = kwargs.get('attr_indicator', 'citations')
48 | return DGINM(in_feat, hidden_feat, degree_trans_method='pos',
49 | attr_indicator=attr_indicator, activation=activation)
50 | elif encoder_type == 'RGIN':
51 | # print(kwargs)
52 | snapshot_types = kwargs.get('snapshot_types', 3)
53 | node_attr_types = kwargs.get('node_attr_types', [])
54 | return RGIN_N(in_feat, hidden_feat, snapshot_types, node_attr_types, activation=activation)
55 | elif encoder_type == 'CGIN':
56 | return dglnn.GINConv(CapsuleNetwork(in_feat, hidden_feat, activation=activation))
57 | elif encoder_type == 'RGAT':
58 | snapshot_types = kwargs.get('snapshot_types', 3)
59 | node_attr_types = kwargs.get('node_attr_types', [])
60 | print(node_attr_types)
61 | nheads = kwargs.get('nheads', 4)
62 | out_feat = hidden_feat // nheads
63 | return RGAT(in_feat, out_feat, nheads, snapshot_types, node_attr_types,
64 | activation=activation, allow_zero_in_degree=True)
65 |
66 |
67 | class CustomGCN(nn.Module):
68 | # homo single GCN encoder
69 | def __init__(self, encoder_type, in_feat, hidden_feat, layer, activation=None, **kwargs):
70 | super(CustomGCN, self).__init__()
71 | self.encoder_type = encoder_type
72 | self.fc_out = None
73 | # self.conv = get_gcn_layer(encoder_type, in_feat, hidden_feat, layer, activation=activation)
74 | if encoder_type in ['DAGIN', 'CGIN', 'DGIN']:
75 | self.conv = get_gcn_layer(encoder_type, in_feat, hidden_feat, layer, activation=activation)
76 | elif encoder_type == 'RGIN':
77 | snapshot_types = kwargs['snapshot_types']
78 | node_attr_types = kwargs['node_attr_types']
79 | self.conv = get_gcn_layer(encoder_type, in_feat, hidden_feat, layer, activation=activation,
80 | snapshot_types=snapshot_types, node_attr_types=node_attr_types)
81 | elif encoder_type == 'RGAT':
82 | snapshot_types = kwargs['snapshot_types']
83 | node_attr_types = kwargs['node_attr_types']
84 | self.conv = get_gcn_layer(encoder_type, in_feat, hidden_feat, layer, activation=activation,
85 | snapshot_types=snapshot_types, node_attr_types=node_attr_types)
86 | else:
87 | self.conv = get_gcn_layer(encoder_type, in_feat, hidden_feat, layer, activation=activation)
88 |
89 | if 'GAT' in self.encoder_type:
90 | self.fc_out = nn.Linear(hidden_feat, hidden_feat)
91 | elif self.encoder_type in ['APPNP', 'GCN2']:
92 | self.fc_out = nn.Linear(in_feat, hidden_feat)
93 |
94 | def forward(self, graph, feat, edge_weight=None, **kwargs):
95 | '''
96 | :param graph:
97 | :param feat: tuple of src and dst
98 | :param edge_weight:
99 | :param kwargs:
100 |         :return: feat: [N, out_feat] (and the attention weights when get_attention=True)
101 | '''
102 | # print('wip')
103 | # print(self.encoder_type)
104 | attn = None
105 | if self.encoder_type == 'GAT':
106 | feat = self.conv(graph, feat, **kwargs)
107 | else:
108 | feat = self.conv(graph, feat, edge_weight=edge_weight, **kwargs)
109 | if 'GAT' in self.encoder_type:
110 | if kwargs.get('get_attention', False):
111 | feat, attn = feat
112 | num_nodes = feat.shape[0]
113 | feat = feat.reshape(num_nodes, -1)
114 | if self.fc_out:
115 | feat = self.fc_out(feat)
116 | if attn is not None:
117 | return feat, attn
118 | else:
119 | return feat
120 |
121 |
122 | class BasicHeteroGCN(nn.Module):
123 | def __init__(self, encoder_type, in_feat, out_feat, ntypes, etypes, layer, activation=None, **kwargs):
124 | super(BasicHeteroGCN, self).__init__()
125 |
126 | if encoder_type == 'HGT':
127 | nheads = kwargs.get('nheads', 4)
128 | self.conv = dglnn.HGTConv(in_feat, out_feat // nheads, nheads,
129 | num_ntypes=len(ntypes), num_etypes=len(etypes))
130 | else:
131 | self.conv = dglnn.HeteroGraphConv({etype: get_gcn_layer(encoder_type, in_feat, out_feat, layer,
132 | activation=activation) for etype in etypes})
133 | self.encoder_type = encoder_type
134 | print(self.encoder_type)
135 | self.fc_out = None
136 | if self.encoder_type == 'GAT':
137 | self.fc_out = dglnn.HeteroLinear({ntype: out_feat for ntype in ntypes}, out_feat)
138 | elif self.encoder_type in ['APPNP', 'GCN2']:
139 | self.fc_out = dglnn.HeteroLinear({ntype: in_feat for ntype in ntypes}, out_feat)
140 |
141 | def forward(self, graph, feat, edge_weight=None, **kwargs):
142 | '''
143 | :param graph:
144 | :param feat:
145 | :param edge_weight:
146 | :param kwargs:
147 | :return: hidden_embs: {ntype: [N, out_feat]}
148 | '''
149 | # print('wip')
150 | # print('start', [torch.isnan(feat[x]).sum() for x in feat])
151 | # print(self.encoder_type)
152 | # print(self.conv)
153 | if self.encoder_type == 'HGT':
154 | # print(feat.shape)
155 | feat = self.conv(graph, feat, graph.ndata[dgl.NTYPE], graph.edata[dgl.ETYPE], presorted=True)
156 |         elif self.encoder_type == 'GAT':  # elif: otherwise the HGT output would be passed through conv again
157 | feat = self.conv(graph, feat)
158 | else:
159 | # feat = self.conv(graph, feat, edge_weight=edge_weight)
160 | feat = self.conv(graph, feat, mod_kwargs={'edge_weight': edge_weight})
161 | # print('end', [torch.isnan(feat[x]).sum() for x in feat])
162 | if self.encoder_type == 'GAT':
163 | for ntype in feat:
164 | num_nodes = feat[ntype].shape[0]
165 | feat[ntype] = feat[ntype].reshape(num_nodes, -1)
166 | if self.fc_out:
167 | feat = self.fc_out(feat)
168 | return feat
169 |
170 |
171 | class CustomHeteroGCN(nn.Module):
172 | def __init__(self, encoder_type, in_feat, out_feat, ntypes, etypes, layer, activation=None, **kwargs):
173 | super(CustomHeteroGCN, self).__init__()
174 | etype_dict = None
175 |
176 | if encoder_type == 'HGT':
177 | nheads = kwargs.get('nheads', 4)
178 | self.conv = dglnn.HGTConv(in_feat, out_feat // nheads, nheads,
179 | num_ntypes=len(ntypes), num_etypes=len(etypes))
180 | elif encoder_type in ['DAGIN', 'CGIN']:
181 | # print(kwargs)
182 | etype_dict = {}
183 | for etype in kwargs.get('special_edges', []):
184 | etype_dict[etype] = get_gcn_layer(encoder_type, in_feat, out_feat, layer, activation=activation)
185 | for etype in etypes:
186 | if etype not in etype_dict:
187 | etype_dict[etype] = get_gcn_layer('GIN', in_feat, out_feat, layer, activation=activation)
188 | elif encoder_type == 'RGIN':
189 | etype_dict = {}
190 | for etype in kwargs.get('special_edges', []):
191 | snapshot_types = kwargs['snapshot_types']
192 | node_attr_types = kwargs['node_attr_types']
193 | etype_dict[etype] = get_gcn_layer(encoder_type, in_feat, out_feat, layer, activation=activation,
194 | snapshot_types=snapshot_types, node_attr_types=node_attr_types)
195 | # print(etype_dict)
196 | for etype in etypes:
197 | if etype not in etype_dict:
198 | etype_dict[etype] = get_gcn_layer('GIN', in_feat, out_feat, layer, activation=activation)
199 | elif encoder_type == 'AGIN':
200 | etype_dict = {}
201 | # DAGIN
202 | if 'special_edges' in kwargs:
203 | for etype in kwargs['special_edges']['DAGIN']:
204 | etype_dict[etype] = get_gcn_layer('DAGIN', in_feat, out_feat, layer, activation=activation)
205 | # RGIN
206 | for etype in kwargs['special_edges']['RGIN']:
207 | snapshot_types = kwargs['snapshot_types']
208 | node_attr_types = kwargs['node_attr_types']
209 | etype_dict[etype] = get_gcn_layer('RGIN', in_feat, out_feat, layer, activation=activation,
210 | snapshot_types=snapshot_types, node_attr_types=node_attr_types)
211 | # print(etype_dict)
212 | for etype in etypes:
213 | if etype not in etype_dict:
214 | etype_dict[etype] = get_gcn_layer('GIN', in_feat, out_feat, layer, activation=activation)
215 | elif 'encoder_dict' in kwargs:
216 | etype_dict = {}
217 | encoder_dict = kwargs['encoder_dict']
218 | for special_encoder_type in encoder_dict:
219 | for etype in encoder_dict[special_encoder_type]:
220 | etype_dict[etype] = get_gcn_layer(special_encoder_type, in_feat, out_feat, layer,
221 | activation=activation)
222 | print(etype_dict)
223 | for etype in etypes:
224 | if etype not in etype_dict:
225 | etype_dict[etype] = get_gcn_layer(encoder_type, in_feat, out_feat, layer, activation=activation)
226 | else:
227 | etype_dict = {etype: get_gcn_layer(encoder_type, in_feat, out_feat, layer,
228 | activation=activation) for etype in etypes}
229 | if etype_dict:
230 | self.conv = dglnn.HeteroGraphConv(etype_dict)
231 | self.encoder_type = encoder_type
232 | # print(self.encoder_type)
233 | self.fc_out = None
234 | if self.encoder_type == 'GAT':
235 | self.fc_out = dglnn.HeteroLinear({ntype: out_feat for ntype in ntypes}, out_feat)
236 | elif self.encoder_type in ['APPNP', 'GCN2']:
237 | self.fc_out = dglnn.HeteroLinear({ntype: in_feat for ntype in ntypes}, out_feat)
238 |
239 | def forward(self, graph, feat, edge_weight=None, **kwargs):
240 | '''
241 | :param graph:
242 | :param feat:
243 | :param edge_weight:
244 | :param kwargs:
245 | :return: hidden_embs: {ntype: [N, out_feat]}
246 | '''
247 | # print('wip')
248 | if self.encoder_type == 'HGT':
249 | # print(feat.shape)
250 | feat = self.conv(graph, feat, graph.ndata[dgl.NTYPE], graph.edata[dgl.ETYPE], presorted=True)
251 | else:
252 | feat = self.conv(graph, feat, edge_weight=edge_weight)
253 | if self.encoder_type == 'GAT':
254 | for ntype in feat:
255 | num_nodes = feat[ntype].shape[0]
256 | feat[ntype] = feat[ntype].reshape(num_nodes, -1)
257 | if self.fc_out:
258 | feat = self.fc_out(feat)
259 | return feat
260 |
261 |
262 | class StochasticKLayerRGCN(nn.Module):
263 | def __init__(self, in_feat, hidden_feat, out_feat, ntypes, etypes, k=2, residual=False, encoder_type='GCN'):
264 | super().__init__()
265 | self.residual = residual
266 | self.convs = nn.ModuleList()
267 | # print(encoder_type, in_feat, hidden_feat)
268 | # self.conv_start = dglnn.HeteroGraphConv({
269 | # rel: get_gcn_layer(encoder_type, in_feat, hidden_feat, activation=nn.LeakyReLU())
270 | # for rel in etypes
271 | # })
272 | # self.conv_end = dglnn.HeteroGraphConv({
273 | # rel: get_gcn_layer(encoder_type, hidden_feat, out_feat, activation=nn.LeakyReLU())
274 | # for rel in etypes
275 | # })
276 | # self.convs.append(self.conv_start)
277 | # for i in range(k - 2):
278 | # self.convs.append(dglnn.HeteroGraphConv({
279 | # rel: get_gcn_layer(encoder_type, hidden_feat, hidden_feat, activation=nn.LeakyReLU())
280 | # for rel in etypes
281 | # }))
282 | # self.convs.append(self.conv_end)
283 | print(ntypes)
284 | self.skip_fcs = nn.ModuleList()
285 | self.conv_start = BasicHeteroGCN(encoder_type, in_feat, hidden_feat, ntypes, etypes, layer=1,
286 | activation=nn.LeakyReLU())
287 | self.skip_fc_start = dglnn.HeteroLinear({key: in_feat for key in ntypes}, hidden_feat)
288 | self.conv_end = BasicHeteroGCN(encoder_type, hidden_feat, out_feat, ntypes, etypes, layer=k,
289 | activation=nn.LeakyReLU())
290 | self.skip_fc_end = dglnn.HeteroLinear({key: hidden_feat for key in ntypes}, out_feat)
291 | self.convs.append(self.conv_start)
292 | self.skip_fcs.append(self.skip_fc_start)
293 | for i in range(k - 2):
294 | self.convs.append(BasicHeteroGCN(encoder_type, hidden_feat, hidden_feat, ntypes, etypes, layer=i + 2,
295 | activation=nn.LeakyReLU()))
296 | self.skip_fcs.append(dglnn.HeteroLinear({key: hidden_feat for key in ntypes}, hidden_feat))
297 | self.convs.append(self.conv_end)
298 | self.skip_fcs.append(self.skip_fc_end)
299 |
300 | # def forward(self, graph, x):
301 | # x = self.conv_start(graph, x)
302 | # x = self.conv_end(graph, x)
303 | # return x
304 | def forward(self, graph, x):
305 | # print(graph.device)
306 | # print(x.device)
307 | # count = 0
308 | for i in range(len(self.convs)):
309 | # print('-' * 30 + str(count) + '-' * 30)
310 | # print(x)
311 | if self.residual:
312 | # x = x + conv(graph, x)
313 | out = self.convs[i](graph, x)
314 | x = self.skip_fcs[i](x)
315 | for key in x:
316 | x[key] = x[key] + out[key]
317 | else:
318 | x = self.convs[i](graph, x)
319 | # print([torch.isnan(x[e]).sum() for e in x])
320 | # print(x)
321 | # count += 1
322 | return x
323 |
324 | # def forward_block(self, blocks, x):
325 | # x = self.conv_start(blocks[0], x)
326 | # x = self.conv_end(blocks[1], x)
327 | # return x
328 | def forward_block(self, blocks, x):
329 | for i in range(len(self.convs)):
330 | x = self.convs[i](blocks[i], x)
331 | return x
332 |
333 |
334 | class CustomRGCN(nn.Module):
335 | def __init__(self, in_feat, hidden_feat, out_feat, ntypes, etypes, k=2, residual=False, encoder_type='GCN',
336 | return_list=False, **kwargs):
337 | super().__init__()
338 | self.residual = residual
339 | self.convs = nn.ModuleList()
340 | self.conv_start = CustomHeteroGCN(encoder_type, in_feat, hidden_feat, ntypes, etypes, layer=1,
341 | activation=nn.LeakyReLU(), **kwargs)
342 | self.conv_end = CustomHeteroGCN(encoder_type, hidden_feat, out_feat, ntypes, etypes, layer=k,
343 | activation=nn.LeakyReLU(), **kwargs)
344 | self.convs.append(self.conv_start)
345 | for i in range(k - 2):
346 | self.convs.append(CustomHeteroGCN(encoder_type, hidden_feat, hidden_feat, ntypes, etypes, layer=i + 2,
347 | activation=nn.LeakyReLU(), **kwargs))
348 | self.convs.append(self.conv_end)
349 | self.return_list = return_list
350 |
351 | def forward(self, graph, x):
352 | output_list = []
353 | for conv in self.convs:
354 | if self.residual:
355 | x = x + conv(graph, x)
356 | else:
357 | x = conv(graph, x)
358 | if self.return_list:
359 | output_list.append(x)
360 | return x, output_list
361 |
362 |
363 | class HTSGCNLayer(nn.Module):
364 | def __init__(self, in_feat, hidden_feat, out_feat, ntypes, etypes, k=2, time_length=10, encoder_type='GCN',
365 | return_list=False):
366 | super().__init__()
367 | # T * k layers
368 | self.k = k
369 | self.time_length = time_length
370 | # self.spatial_encoders = nn.ModuleList(
371 | # [StochasticKLayerRGCN(in_feat, hidden_feat, out_feat, rel_names, k=k) for i in range(time_length)])
372 | self.spatial_encoders = nn.ModuleList(
373 | [CustomRGCN(in_feat, hidden_feat, out_feat, ntypes, etypes, k=k,
374 | encoder_type=encoder_type, return_list=return_list) for i in range(time_length)])
375 | self.time_encoders = nn.ModuleList()
376 | for i in range(k - 1):
377 | self.time_encoders.append(HTEncoder(hidden_feat, hidden_feat, activation=torch.relu))
378 | self.time_encoders.append(HTEncoder(out_feat, out_feat, activation=torch.relu))
379 |
380 | # def forward(self, graph, x):
381 | # x = self.conv_start(graph, x)
382 | # x = self.conv_end(graph, x)
383 | # return x
384 | def forward(self, graph_list, time_adj):
385 | for i in range(self.k):
386 | # print(i)
387 | for t in range(self.time_length):
388 |                 # note: these are the graphs already batched per year
389 | # try:
390 | graph_list[t].dstdata['h'] = self.spatial_encoders[t].convs[i](graph_list[t],
391 | graph_list[t].srcdata['h'])
392 | # except Exception as e:
393 | # print(e)
394 | # print(graph_list[t])
395 | # print(graph_list[t].batch_size)
396 | # raise Exception
397 | # print(graph_list[t].dstdata['h']['snapshot'].shape)
398 | graph_list = self.time_encoders[i](graph_list, time_adj)
399 | return graph_list
400 |
401 |
402 | class HTEncoder(nn.Module):
403 | def __init__(self, in_feats, out_feats, bias=False, activation=None):
404 | super(HTEncoder, self).__init__()
405 | self.in_feats = in_feats
406 | self.out_feats = out_feats
407 | self.l_weight = nn.Parameter(torch.Tensor(in_feats, out_feats))
408 | self.r_weight = nn.Parameter(torch.Tensor(in_feats, out_feats))
409 | self.l_bias = None
410 | self.r_bias = None
411 | self.fc_out = nn.Linear(2 * out_feats, out_feats, bias=bias)
412 | self.activation = activation
413 | if bias:
414 | self.l_bias = nn.Parameter(torch.Tensor(out_feats))
415 | self.r_bias = nn.Parameter(torch.Tensor(out_feats))
416 | self.init_weights()
417 |
418 | def init_weights(self):
419 | if self.l_weight is not None:
420 | nn.init.xavier_uniform_(self.l_weight)
421 | nn.init.xavier_uniform_(self.r_weight)
422 | if self.l_bias is not None:
423 | nn.init.zeros_(self.l_bias)
424 | nn.init.zeros_(self.r_bias)
425 |
426 | def forward(self, snapshots, time_adj):
427 |         # snapshots: list of T batched graphs; time_adj: [B, 2, T, T]
428 | # print('wip')
429 | emb_list = []
430 | for snapshot in snapshots:
431 | # print(snapshot.nodes['snapshot'].data['h'].shape)
432 | emb_list.append(snapshot.nodes['snapshot'].data['h'])
433 | time_emb = torch.stack(emb_list, dim=1) # [B, T, d_in]
434 | l_emb = torch.matmul(torch.matmul(time_adj[:, 0], time_emb), self.l_weight)
435 | r_emb = torch.matmul(torch.matmul(time_adj[:, 1], time_emb), self.r_weight)
436 | out_emb = torch.cat((l_emb, r_emb), dim=-1)
437 | if self.activation:
438 | out_emb = self.activation(out_emb)
439 | time_out = self.fc_out(out_emb) # [B, T, d_out]
440 | for i in range(len(snapshots)):
441 | snapshots[i].nodes['snapshot'].data['h'] = time_out[:, i]
442 | return snapshots
443 |
444 |
445 | if __name__ == '__main__':
446 | start_time = datetime.datetime.now()
447 | parser = argparse.ArgumentParser(description='Process some description.')
448 | parser.add_argument('--phase', default='test', help='the function name.')
449 |
450 | args = parser.parse_args()
451 | rel_names = ['cites', 'is cited by', 'writes', 'is writen by', 'publishes', 'is published by']
452 |     model = StochasticKLayerRGCN(768, 300, 300, ['paper', 'author', 'journal'], rel_names)  # ntypes added so the smoke test matches the constructor signature
453 | print(model)
454 |
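455 |     # A minimal, hypothetical smoke test for the homogeneous CustomGCN wrapper defined above
456 |     # (toy graph and feature sizes chosen for illustration only).
457 |     toy_g = dgl.graph(([0, 1, 2], [1, 2, 0]), num_nodes=3)
458 |     toy_feat = torch.randn(3, 768)
459 |     toy_gcn = CustomGCN('GCN', 768, 300, layer=1)
460 |     print(toy_gcn(toy_g, toy_feat).shape)  # expected: torch.Size([3, 300])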
--------------------------------------------------------------------------------
/layers/GCN/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ECNU-Text-Computing/H2CGL/f3a3af4b8af16e24f20236da62d219c9c39594dd/layers/GCN/__init__.py
--------------------------------------------------------------------------------
/layers/GCN/__pycache__/CL_layers.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ECNU-Text-Computing/H2CGL/f3a3af4b8af16e24f20236da62d219c9c39594dd/layers/GCN/__pycache__/CL_layers.cpython-39.pyc
--------------------------------------------------------------------------------
/layers/GCN/__pycache__/DTGCN_layers.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ECNU-Text-Computing/H2CGL/f3a3af4b8af16e24f20236da62d219c9c39594dd/layers/GCN/__pycache__/DTGCN_layers.cpython-39.pyc
--------------------------------------------------------------------------------
/layers/GCN/__pycache__/RGCN.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ECNU-Text-Computing/H2CGL/f3a3af4b8af16e24f20236da62d219c9c39594dd/layers/GCN/__pycache__/RGCN.cpython-39.pyc
--------------------------------------------------------------------------------
/layers/GCN/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ECNU-Text-Computing/H2CGL/f3a3af4b8af16e24f20236da62d219c9c39594dd/layers/GCN/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/layers/GCN/__pycache__/custom_gcn.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ECNU-Text-Computing/H2CGL/f3a3af4b8af16e24f20236da62d219c9c39594dd/layers/GCN/__pycache__/custom_gcn.cpython-39.pyc
--------------------------------------------------------------------------------
/layers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ECNU-Text-Computing/H2CGL/f3a3af4b8af16e24f20236da62d219c9c39594dd/layers/__init__.py
--------------------------------------------------------------------------------
/layers/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ECNU-Text-Computing/H2CGL/f3a3af4b8af16e24f20236da62d219c9c39594dd/layers/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/layers/__pycache__/common_layers.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ECNU-Text-Computing/H2CGL/f3a3af4b8af16e24f20236da62d219c9c39594dd/layers/__pycache__/common_layers.cpython-39.pyc
--------------------------------------------------------------------------------
/layers/common_layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 |
5 |
6 | class PositionalEncodingSC(nn.Module):
7 | def __init__(self, d_model, dropout=0.5, max_len=5000):
8 | super(PositionalEncodingSC, self).__init__()
9 | self.dropout = nn.Dropout(p=dropout)
10 |
11 | position = torch.arange(max_len).unsqueeze(1)
12 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
13 | pe = torch.zeros(max_len, 1, d_model)
14 | pe[:, 0, 0::2] = torch.sin(position * div_term)
15 | pe[:, 0, 1::2] = torch.cos(position * div_term)
16 | self.register_buffer('pe', pe)
17 |
18 | def forward(self, x):
19 | """
20 | :param x: Tensor, shape [seq_len, batch_size, embedding_dim]
21 | :return Tensor:
22 | """
23 | x = x + self.pe[:x.size(0)]
24 | out = self.dropout(x)
25 | return out
26 |
27 |
28 | class PositionalEncoding(nn.Module):
29 | def __init__(self, d_model, dropout=0.5, max_len=5000):
30 | super(PositionalEncoding, self).__init__()
31 | self.dropout = nn.Dropout(p=dropout)
32 |
33 | position = torch.arange(max_len).unsqueeze(1)
34 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
35 | pe = torch.zeros(max_len, 1, d_model)
36 | pe[:, 0, 0::2] = torch.sin(position * div_term)
37 | pe[:, 0, 1::2] = torch.cos(position * div_term)
38 | pe = pe.squeeze(dim=1)
39 | self.register_buffer('pe', pe)
40 |
41 | def forward(self, x):
42 | """
43 | :param x: Tensor, shape [batch_size, seq_len]
44 | :return Tensor:
45 | """
46 | # x = x + self.pe[:x.size(0)]
47 | # out = self.dropout(x)
48 | out = self.pe[x]
49 | return out
50 |
51 |
52 | class TemporalEncoding(nn.Module):
53 | def __init__(self, d_model, dropout=0, max_len=5000, random=True):
54 | super(TemporalEncoding, self).__init__()
55 | self.dropout = nn.Dropout(p=dropout)
56 |
57 | position = torch.arange(max_len).unsqueeze(1) # t
58 | div_term = 1 / math.sqrt(d_model)
59 | if random:
60 | w = torch.randn((1, d_model))
61 | else:
62 | w = torch.arange(0, d_model, 1).unsqueeze(0) / d_model
63 | # temp = nn.Parameter(position * w)
64 | # print(position.shape, w.shape)
65 | temp = position * w
66 | te = torch.cos(temp) * torch.cos(temp) * div_term
67 | # print(te)
68 | self.register_buffer('te', te)
69 |
70 | def forward(self, x):
71 | """
72 | :param x: Tensor, shape [batch_size, seq_len]
73 | :return Tensor:
74 | """
75 | # print(self.te.device)
76 | out = self.te[x]
77 | return out
78 |
79 |
80 | class AttentionModule(nn.Module):
81 | def __init__(self, dim_in, dim_out=None, activation=nn.Tanh()):
82 | super(AttentionModule, self).__init__()
83 | if dim_out is None:
84 | dim_out = dim_in
85 | # self.trans = nn.Linear(dim_in, dim_out)
86 | # self.seq_trans = nn.Linear(dim_in, dim_out)
87 | self.target_trans = MLP(dim_in, dim_out)
88 | self.seq_trans = MLP(dim_in, dim_out)
89 | # self.w = nn.Linear(dim_out, 1, bias=False)
90 | self.w = nn.Linear(dim_out, dim_out, bias=False)
91 | self.u = nn.Linear(dim_out, dim_out, bias=False)
92 | self.activation = activation
93 |
94 | def forward(self, target, sequence, mask):
95 | # batch_size = target.shape[0]
96 | target = self.target_trans(target)
97 | # print(sequence.shape)
98 | sequence = self.seq_trans(sequence)
99 |
100 | # attn = self.activation(sequence @ self.w(target).unsqueeze(dim=-1)) + mask.unsqueeze(dim=-1).detach()
101 | # print((self.u(sequence) @ self.w(target).unsqueeze(dim=-1)).shape)
102 | # attn = self.activation(self.u(sequence) @ self.w(target).unsqueeze(dim=-1)) + mask.unsqueeze(dim=-1).detach()
103 | attn = self.u(sequence) @ self.w(target).unsqueeze(dim=-1) + mask.unsqueeze(dim=-1).detach()
104 | # attn = attn.view(batch_size, -1)
105 | # attn[mask] = -1e8
106 | attn = torch.softmax(attn, dim=-2)
107 |
108 | out = (sequence * attn).sum(1)
109 | # print(out.shape)
110 | return out
111 |
112 |
113 | class GateModule(nn.Module):
114 | def __init__(self, input_dim, out_dim):
115 | super(GateModule, self).__init__()
116 | self.input_trans = MLP(input_dim, input_dim)
117 | self.target_trans = MLP(input_dim, input_dim)
118 | self.mix_trans = MLP(input_dim * 2, out_dim)
119 | self.sigmoid = nn.Sigmoid()
120 |
121 | def forward(self, input_emb, target_emb):
122 | alpha = self.sigmoid(self.input_trans(input_emb) + self.target_trans(target_emb))
123 | input_emb = alpha * input_emb
124 | mix_out = self.mix_trans(torch.cat((input_emb, target_emb), dim=-1))
125 | return mix_out
126 |
127 |
128 | class MLP(nn.Module):
129 | def __init__(self, dim_in, dim_out, bias=True, dim_inner=None,
130 | num_layers=2, activation=nn.ReLU(inplace=False), **kwargs):
131 | super(MLP, self).__init__()
132 | layers = []
133 | dim_inner = dim_in if dim_inner is None else dim_inner
134 | if num_layers > 1:
135 | for i in range(num_layers - 1):
136 | # print(dim_in, dim_inner, bias)
137 |                 layers.append(nn.Linear(dim_in if i == 0 else dim_inner, dim_inner, bias))  # only the first layer takes dim_in
138 | if activation:
139 | layers.append(activation)
140 | layers.append(nn.Linear(dim_inner, dim_out, bias))
141 | else:
142 | layers.append(nn.Linear(dim_in, dim_out, bias))
143 | self.model = nn.Sequential(*layers)
144 |
145 | def forward(self, batch):
146 | batch = self.model(batch)
147 | return batch
148 |
149 |
150 | class CapsuleNetwork(nn.Module):
151 | def __init__(self, dim_in, dim_out, bias=True, num_routing=3, num_caps=4, return_prob=False,
152 | activation=nn.ReLU(), **kwargs):
153 | # projection and routing
154 | super(CapsuleNetwork, self).__init__()
155 | dim_capsule = dim_out // num_caps
156 | self.dim_out = dim_out
157 | self.model = OriginCaps(in_dim=dim_in, in_caps=1, num_caps=num_caps, dim_caps=dim_capsule,
158 | num_routing=num_routing)
159 | self.bn_out = nn.BatchNorm1d(dim_capsule) # bn is important for capsule
160 | self.fc_out = MLP(dim_out, dim_out, bias=bias, activation=activation)
161 | self.return_prob = return_prob
162 |
163 | def forward(self, x):
164 | # [N, dim_in]
165 | # x = self.bn_in(x)
166 | x = x.unsqueeze(dim=1)
167 |         x, p = self.model(x)  # [N, Nc, dc]; p is the capsule probability
168 | # print(p)
169 | x = self.bn_out(x.transpose(1, 2)).transpose(1, 2)
170 | x = self.fc_out(x.reshape(-1, self.dim_out))
171 | # x = self.bn_out(x)
172 | if self.return_prob:
173 | return x, p
174 | else:
175 | return x
176 |
177 |
178 | def squash(x, dim=-1):
179 | # squared_norm = (x ** 2).sum(dim=dim, keepdim=True)
180 | squared_norm = torch.square(x).sum(dim=dim, keepdim=True)
181 | scale = squared_norm / (1 + squared_norm)
182 | # vector = scale * x / (squared_norm.sqrt() + 1e-8)
183 | vector = scale * x / torch.sqrt(squared_norm + 1e-11)
184 | return vector, scale
185 |
186 |
187 | class OriginCaps(nn.Module):
188 |     """Original capsule (routing) layer."""
189 |
190 | def __init__(self, in_dim, in_caps, num_caps, dim_caps, num_routing):
191 | """
192 |         Initialize the capsule layer.
193 |
194 |         Args:
195 |             in_dim: Dimensionality (i.e. length) of each input capsule vector.
196 |             in_caps: Number of input capsules fed to this layer.
197 |             num_caps: Number of output capsules in this layer.
198 |             dim_caps: Dimensionality (i.e. length) of each output capsule vector.
199 |             num_routing: Number of iterations of the routing algorithm.
200 | """
201 | super(OriginCaps, self).__init__()
202 | self.in_dim = in_dim
203 | self.in_caps = in_caps
204 | self.num_caps = num_caps
205 | self.dim_caps = dim_caps
206 | self.num_routing = num_routing
207 | # self.W = nn.Parameter(0.01 * torch.randn(1, num_caps, in_caps, dim_caps, in_dim),
208 | # requires_grad=True)
209 | self.W = nn.Parameter(torch.empty(1, num_caps, in_caps, dim_caps, in_dim),
210 | requires_grad=True)
211 | self.init_weights()
212 |
213 | def init_weights(self):
214 | nn.init.uniform_(self.W, -0.1, 0.1)
215 |
216 | def forward(self, x):
217 | batch_size = x.size(0)
218 | # (batch_size, in_caps, in_dim) -> (batch_size, 1, in_caps, in_dim, 1)
219 | x = x.unsqueeze(1).unsqueeze(4)
220 | # print('x', torch.isnan(x).sum())
221 | #
222 | # W @ x =
223 | # (1, num_caps, in_caps, dim_caps, in_dim) @ (batch_size, 1, in_caps, in_dim, 1) =
224 | # (batch_size, num_caps, in_caps, dim_caps, 1)
225 | u_hat = torch.matmul(self.W, x)
226 | # (batch_size, num_caps, in_caps, dim_caps)
227 | u_hat = u_hat.squeeze(-1)
228 | # detach u_hat during routing iterations to prevent gradients from flowing
229 | temp_u_hat = u_hat.detach()
230 |
231 | b = torch.zeros(batch_size, self.num_caps, self.in_caps, 1, requires_grad=False).to(x.device)
232 |
233 | for route_iter in range(self.num_routing - 1):
234 | # (batch_size, num_caps, in_caps, 1) -> Softmax along num_caps
235 | c = b.softmax(dim=1)
236 |
237 | # element-wise multiplication
238 | # (batch_size, num_caps, in_caps, 1) * (batch_size, in_caps, num_caps, dim_caps) ->
239 | # (batch_size, num_caps, in_caps, dim_caps) sum across in_caps ->
240 | # (batch_size, num_caps, dim_caps)
241 | s = (c * temp_u_hat).sum(dim=2)
242 | # apply "squashing" non-linearity along dim_caps
243 | v, s = squash(s)
244 | # print('v', torch.isnan(v).sum())
245 | # dot product agreement between the current output vj and the prediction uj|i
246 | # (batch_size, num_caps, in_caps, dim_caps) @ (batch_size, num_caps, dim_caps, 1)
247 | # -> (batch_size, num_caps, in_caps, 1)
248 | uv = torch.matmul(temp_u_hat, v.unsqueeze(-1))
249 | b = b + uv
250 |
251 | # last iteration is done on the original u_hat, without the routing weights update
252 | c = b.softmax(dim=1)
253 | # print(c.shape, u_hat.shape)
254 | s = (c * u_hat).sum(dim=2)
255 | # apply "squashing" non-linearity along dim_caps
256 | v, p = squash(s)
257 | # print('last_v', torch.isnan(x).sum(), torch.isnan(v).sum(), torch.isnan(c).sum(), torch.isnan(s).sum())
258 | # print(x.squeeze(1))
259 | # print(v)
260 | # if torch.isnan(v).sum() > 0:
261 | # raise Exception
262 | # print(s)
263 |
264 | return v, p
265 |
266 | if __name__ == '__main__':
267 | pe = PositionalEncoding(128)
268 | print(pe.pe.shape)
269 | print(pe.pe[[1, 1, 2]])
270 | x = torch.randint(10, (3, 5))
271 | print(pe(x).shape)
272 | # x = torch.tensor([1, 5, 10, 4, 2, 10])
273 | # te = TemporalEncoding(128)
274 | # print(te(x).shape)
275 | # print([torch.norm(te(x)[i]) for i in range(x.shape[0])])
276 | # te = TemporalEncoding(128, random=False)
277 | # print(te(x).shape)
278 | # print([torch.norm(te(x)[i]) for i in range(x.shape[0])])
279 |
280 | # model = OriginCaps(in_dim=256, in_caps=1, num_caps=4, dim_caps=64, num_routing=3)
281 | # # (batch_size, in_caps, in_dim)
282 | # x = torch.randn(10, 1, 256)
283 | # print(model(x)[0].shape)
284 | # # print(model(x).norm(2))
285 | # cn = CapsuleNetwork(dim_in=256, dim_out=256)
286 | # x = torch.randn(8, 256)
287 | # print(cn(x))
288 |
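289 |     # A small, hypothetical check of AttentionModule: pooling a padded sequence into one vector
290 |     # (toy shapes; the additive mask uses 0 for valid steps and a large negative value for padding).
291 |     attn_pool = AttentionModule(dim_in=64)
292 |     tgt = torch.randn(2, 64)        # [batch, dim]
293 |     seq = torch.randn(2, 5, 64)     # [batch, time, dim]
294 |     msk = torch.zeros(2, 5)
295 |     msk[:, 3:] = -1e8               # mask out the last two time steps
296 |     print(attn_pool(tgt, seq, msk).shape)  # expected: torch.Size([2, 64])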
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ECNU-Text-Computing/H2CGL/f3a3af4b8af16e24f20236da62d219c9c39594dd/models/__init__.py
--------------------------------------------------------------------------------
/models/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ECNU-Text-Computing/H2CGL/f3a3af4b8af16e24f20236da62d219c9c39594dd/models/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/models/__pycache__/base_model.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ECNU-Text-Computing/H2CGL/f3a3af4b8af16e24f20236da62d219c9c39594dd/models/__pycache__/base_model.cpython-39.pyc
--------------------------------------------------------------------------------
/our_models/CL.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import time
3 |
4 | import dgl
5 | import pandas as pd
6 | import torch
7 | import torch.nn as nn
8 | from torch.nn.utils.rnn import pad_sequence
9 |
10 | from layers.GCN.CL_layers import *
11 | from models.base_model import BaseModel
12 | from utilis.scripts import custom_node_unbatch
13 |
14 |
15 | class CLWrapper(BaseModel):
16 | def __init__(self, encoder: BaseModel, cl_type='simple', aug_type=None, tau=0.4, cl_weight=0.5, aug_rate=0.1, **kwargs):
17 | super(CLWrapper, self).__init__(0, 0, 1, 0, None, 0, None, **kwargs)
18 | self.encoder = encoder
19 | self.aux_criterion = encoder.aux_criterion
20 | self.aug_rate = aug_rate
21 | print('aug_rate:', self.aug_rate)
22 | if aug_type == 'md':
23 | self.aug_methods = (('attr_mask',), ('edge_drop',))
24 | self.aug_configs = ({
25 | 'attr_mask': {
26 | 'prob': self.aug_rate,
27 | 'types': ['paper', 'author', 'journal', 'time']
28 | }
29 | }, {
30 | 'edge_drop': {
31 | 'prob': self.aug_rate,
32 | 'types': ["cites", "is cited by", "writes", "is writen by", "publishes",
33 | "is published by",
34 | "shows", "is shown in"]
35 | }
36 | })
37 | elif aug_type == 'dmd':
38 | self.aug_methods = (('attr_mask', 'edge_drop'), ('attr_mask', 'edge_drop'))
39 | self.aug_configs = ({
40 | 'attr_mask': {
41 | 'prob': self.aug_rate,
42 | 'types': ['paper', 'author', 'journal', 'time']
43 | },
44 | 'edge_drop': {
45 | 'prob': self.aug_rate,
46 | 'types': ["cites", "is cited by", "writes", "is writen by", "publishes",
47 | "is published by",
48 | "shows", "is shown in"]
49 | }
50 |
51 | }, {
52 | 'attr_mask': {
53 | 'prob': self.aug_rate,
54 | 'types': ['paper', 'author', 'journal', 'time']
55 | },
56 | 'edge_drop': {
57 | 'prob': self.aug_rate,
58 | 'types': ["cites", "is cited by", "writes", "is writen by", "publishes",
59 | "is published by",
60 | "shows", "is shown in"]
61 | }
62 | })
63 | elif aug_type == 'dd':
64 | self.aug_methods = (('attr_drop', 'edge_drop'), ('attr_drop', 'edge_drop'))
65 | self.aug_configs = ({
66 | 'attr_drop': {
67 | 'prob': self.aug_rate,
68 | 'types': ['paper', 'author', 'journal', 'time']
69 | },
70 | 'edge_drop': {
71 | 'prob': self.aug_rate,
72 | 'types': ["cites", "is cited by", "writes", "is writen by", "publishes",
73 | "is published by",
74 | "shows", "is shown in"]
75 | }
76 | }, {
77 | 'edge_drop': {
78 | 'prob': self.aug_rate,
79 | 'types': ["cites", "is cited by", "writes", "is writen by", "publishes",
80 | "is published by",
81 | "shows", "is shown in"]
82 | },
83 | 'attr_drop': {
84 | 'prob': self.aug_rate,
85 | 'types': ['paper', 'author', 'journal', 'time']
86 | },
87 | })
88 | else:
89 | aug_type = 'normal'
90 | self.aug_methods = (('attr_drop',), ('edge_drop',))
91 | self.aug_configs = ({
92 | 'attr_drop': {
93 | 'prob': self.aug_rate,
94 | 'types': ['paper', 'author', 'journal', 'time']
95 | }
96 | }, {
97 | 'edge_drop': {
98 | 'prob': self.aug_rate,
99 | 'types': ["cites", "is cited by", "writes", "is writen by", "publishes",
100 | "is published by",
101 | "shows", "is shown in"]
102 | }
103 | })
104 | print(self.aug_methods)
105 | self.model_name = self.encoder.model_name + '_CL_' + cl_type + '_' + aug_type
106 | if self.aug_rate != 0.1:
107 | self.model_name += '_p{}'.format(self.aug_rate)
108 | self.similarity = 'cos'
109 | self.cl_type = cl_type
110 | self.encoder.cl_type = self.cl_type
111 | self.tau = tau
112 | if self.tau != 0.4:
113 | self.model_name += '_t{}'.format(self.tau)
114 | self.cross_weight = 0.1
115 | self.cl_weight = cl_weight
116 | if self.cl_weight != 0.5:
117 | self.model_name += '_cw{}'.format(self.cl_weight)
118 | self.df_hn = None
119 | if cl_type in ['simple', 'cross_snapshot']:
120 | self.cl_ntypes = self.encoder.ntypes
121 | elif cl_type.startswith('focus'):
122 | self.cl_ntypes = [self.encoder.pred_type]
123 | elif cl_type.endswith('hard_negative'):
124 | self.cl_ntypes = self.encoder.ntypes
125 | print('wip')
126 | # self.cl_modules = nn.ModuleDict({ntype: CommonCLLayer(self.encoder.out_dim, self.encoder.out_dim, tau=self.tau)
127 | # for ntype in self.cl_ntypes})
128 | self.proj_heads = nn.ModuleDict({ntype: ProjectionHead(self.encoder.out_dim, self.encoder.out_dim) for
129 | ntype in self.cl_ntypes})
130 | print('cur cl mode:', cl_type)
131 | print('weights:', self.cl_weight, self.cross_weight)
132 |
133 | def load_graphs(self, graphs):
134 | dealt_graphs = self.encoder.load_graphs(graphs)
135 | dealt_graphs = self.encoder.aug_graphs(dealt_graphs, self.aug_configs, self.aug_methods)
136 | df_hn = pd.read_csv('./data/{}/hard_negative.csv'.format(graphs['data_source']), index_col=0)
137 | paper_trans = graphs['node_trans']['paper']
138 | df_hn.index = [paper_trans[str(paper)] for paper in df_hn.index]
139 | for column in df_hn:
140 | df_hn[column] = df_hn[column].apply(lambda x: [paper_trans[str(paper)] for paper in eval(x)])
141 | self.df_hn = df_hn
142 | return dealt_graphs
143 |
144 | def graph_to_device(self, graphs):
145 | return self.encoder.graph_to_device(graphs)
146 |
147 | def deal_graphs(self, graphs, data_path=None, path=None, log_path=None):
148 | return self.encoder.deal_graphs(graphs, data_path, path, log_path)
149 |
150 | def get_aux_pred(self, pred):
151 | return self.encoder.get_aux_pred(pred)
152 |
153 | def get_new_gt(self, gt):
154 | return self.encoder.get_new_gt(gt)
155 |
156 | def get_aux_loss(self, pred, gt):
157 | return self.encoder.get_aux_loss(pred, gt)
158 |
159 | def get_cl_inputs(self, snapshots, require_grad=True):
160 | if require_grad:
161 | if type(snapshots) == list:
162 | for snapshot in snapshots:
163 | for ntype in self.cl_ntypes:
164 | snapshot.nodes[ntype].data['h'] = self.proj_heads[ntype](snapshot.nodes[ntype].data['h'])
165 | else:
166 | for ntype in self.cl_ntypes:
167 | snapshots.nodes[ntype].data['h'] = self.proj_heads[ntype](snapshots.nodes[ntype].data['h'])
168 | else:
169 | if type(snapshots) == list:
170 | for snapshot in snapshots:
171 | for ntype in self.cl_ntypes:
172 | snapshot.nodes[ntype].data['h'] = self.proj_heads[ntype](snapshot.nodes[ntype].data['h']).detach()
173 | else:
174 | for ntype in self.cl_ntypes:
175 | snapshots.nodes[ntype].data['h'] = self.proj_heads[ntype](snapshots.nodes[ntype].data['h']).detach()
176 | return snapshots
177 |
178 | def cl(self, s1, s2, ids):
179 | # cl_loss_list = []
180 | cl_loss = 0
181 | z1_snapshots, z2_snapshots = [], []
182 | z1_ids, z2_ids = [], []
183 | if self.cl_type in ['simple', 'focus', 'cross_snapshot', 'focus_cross_snapshot']:
184 | for ntype in self.cl_ntypes:
185 | cur_loss = 0
186 | for t in range(self.encoder.time_length):
187 | if 'HTSGCN' in self.encoder.model_name:
188 | if ntype == 'snapshot':
189 | s1_index = torch.where(s1[t].nodes[ntype].data['mask'] > 0)
190 | s2_index = torch.where(s2[t].nodes[ntype].data['mask'] > 0)
191 | z1 = s1[t].nodes[ntype].data['h'][s1_index]
192 | z2 = s2[t].nodes[ntype].data['h'][s2_index]
193 | else:
194 | z1 = s1[t].nodes[ntype].data['h']
195 | z2 = s2[t].nodes[ntype].data['h']
196 | elif 'CTSGCN' in self.encoder.model_name:
197 | s1_index = torch.where(s1.nodes[ntype].data['snapshot_idx'] == t)
198 | s2_index = torch.where(s2.nodes[ntype].data['snapshot_idx'] == t)
199 | # print(s1_index)
200 | z1 = s1.nodes[ntype].data['h'][s1_index]
201 | z2 = s2.nodes[ntype].data['h'][s2_index]
202 | # print(z1.shape)
203 | if (self.cl_type.endswith('cross_snapshot')) and (ntype == 'snapshot'):
204 | z1_snapshots.append(z1)
205 | z2_snapshots.append(z2)
206 | z1_ids.append(s1.nodes[ntype].data['target'][s1_index])
207 | z2_ids.append(s2.nodes[ntype].data['target'][s2_index])
208 | # cur_list.append(self.cl_modules[ntype](z1, z2))
209 | # temp_loss = self.cl_modules[ntype](z1, z2)
210 | temp_loss = simple_cl(z1, z2, self.tau)
211 | cur_loss = cur_loss + temp_loss
212 | # print(self.cl_modules[ntype](z1, z2))
213 | # if torch.isnan(temp_loss):
214 | # print(z1)
215 | # print(torch.isnan(z1).sum())
216 | # print(z2)
217 | # print(torch.isnan(z2).sum())
218 | # cl_loss_list.append()
219 | cl_loss = cl_loss + cur_loss / self.encoder.time_length
220 | cl_loss = cl_loss / len(self.cl_ntypes)
221 | elif self.cl_type.endswith('hard_negative'):
222 | for ntype in self.cl_ntypes:
223 | cur_loss = 0
224 | p1_index = {}
225 | p2_index = {}
226 | for paper in ids:
227 | p1_index[paper.item()] = s1.nodes[ntype].data['target'] == paper.item()
228 | p2_index[paper.item()] = s2.nodes[ntype].data['target'] == paper.item()
229 | # print(p1_index.keys())
230 | for t in range(self.encoder.time_length):
231 | # start_time = time.time()
232 | cur_z1s = []
233 | cur_z2s = []
234 | len_z1s = []
235 | # len_z2s = []
236 | t1_index = s1.nodes[ntype].data['snapshot_idx'] == t
237 | t2_index = s2.nodes[ntype].data['snapshot_idx'] == t
238 | for paper in ids:
239 | s1_index = torch.where(t1_index & p1_index[paper.item()])
240 | s2_index = torch.where(t2_index & p2_index[paper.item()])
241 | z1 = s1.nodes[ntype].data['h'][s1_index]
242 | z2 = s2.nodes[ntype].data['h'][s2_index]
243 | cur_z1s.append(z1)
244 | cur_z2s.append(z2)
245 | cur_len = len(z1) if len(z1) > 0 else 1
246 | len_z1s.append(cur_len)
247 | # len_z2s.append(len(z2))
248 | cur_z1s = pad_sequence(cur_z1s, batch_first=True)
249 | cur_z2s = pad_sequence(cur_z2s, batch_first=True)
250 | len_z1s = torch.tensor(len_z1s).to(cur_z1s.device)
251 | # len_z2s = torch.tensor(len_z2s).to(cur_z2s.device)
252 | # print('index_time: {:.3f}'.format(time.time() - start_time))
253 | temp_loss = simple_cl(cur_z1s, cur_z2s, self.tau, valid_lens=len_z1s)
254 | # print('loss_time: {:.3f}'.format(time.time() - start_time))
255 | cur_loss = cur_loss + temp_loss
256 | # print(ntype, cur_loss / self.encoder.time_length)
257 | cl_loss = cl_loss + cur_loss / self.encoder.time_length
258 | cl_loss = cl_loss / len(self.cl_ntypes)
259 |
260 | if self.cl_type.endswith('cross_snapshot'):
261 | snapshot_loss = 0
262 | z1_dict = {paper.item(): [] for paper in ids}
263 | z2_dict = {paper.item(): [] for paper in ids}
264 | # print(z1_dict)
265 | for i in range(len(z1_snapshots)):
266 | cur_z1 = z1_snapshots[i]
267 | cur_z2 = z2_snapshots[i]
268 | cur_z1_ids = z1_ids[i]
269 | cur_z2_ids = z2_ids[i]
270 | for j in range(len(cur_z1_ids)):
271 | z1_dict[cur_z1_ids[j].item()].append(cur_z1[j])
272 | z2_dict[cur_z2_ids[j].item()].append(cur_z2[j])
273 | for paper in ids:
274 | # print(torch.stack(z1_dict[paper.item()], dim=0).shape)
275 | snapshot_loss = snapshot_loss + simple_cl(torch.stack(z1_dict[paper.item()], dim=0),
276 | torch.stack(z2_dict[paper.item()], dim=0),
277 | self.tau)
278 |
279 | cl_loss = cl_loss + self.cross_weight * snapshot_loss
280 | # print(cl_loss)
281 | return cl_loss
282 |
283 | def forward(self, content, lengths, masks, ids, graph, times, **kwargs):
284 | # cur_time = time.time()
285 | # g1_snapshots, g2_snapshots, cur_snapshot_adj = self.get_graphs(ids, graph)
286 | # g1_inputs = self.encoder.get_graphs(ids, graph, aug=(self.aug_methods[0], self.aug_configs[0]))
287 | # g2_inputs = self.encoder.get_graphs(ids, graph, aug=(self.aug_methods[1], self.aug_configs[1]))
288 | if self.cl_type.endswith('aug_hard_negative'):
289 | g1_inputs, g2_inputs = self.encoder.get_graphs(content, lengths, masks, ids, graph, times, phase='train',
290 | df_hn=self.df_hn)
291 | else:
292 | g1_inputs, g2_inputs = self.encoder.get_graphs(content, lengths, masks, ids, graph, times, phase='train')
293 | # logging.info('aug done:{:.3f}s'.format(time.time() - cur_time))
294 |
295 | # cur_time = time.time()
296 | # g1_out, g1_snapshots = self.encode(g1_inputs)
297 | # g2_out, g2_snapshots = self.encode(g2_inputs)
298 | g1_out, g1_snapshots = self.encoder.encode(g1_inputs)
299 | g2_out, g2_snapshots = self.encoder.encode(g2_inputs)
300 | # logging.info('encode done:{:.3f}s'.format(time.time() - cur_time))
301 |
302 | # cur_time = time.time()
303 | g1_snapshots, g2_snapshots = self.get_cl_inputs(g1_snapshots), self.get_cl_inputs(g2_snapshots)
304 | self.other_loss = self.cl_weight * self.cl(g1_snapshots, g2_snapshots, ids)
305 | # logging.info('cl done:{:.3f}s'.format(time.time() - cur_time))
306 |
307 | time_out = (g1_out + g2_out) / 2
308 | output = self.encoder.decode(time_out)
309 | # print(time_out.shape)
310 | # logging.info('=' * 20 + 'finish' + '=' * 20)
311 | return output
312 |
313 | def predict(self, content, lengths, masks, ids, graph, times, **kwargs):
314 | # print(final_out.shape)
315 | inputs = self.encoder.get_graphs(content, lengths, masks, ids, graph, times)
316 | # print(inputs)
317 | time_out, _ = self.encoder.encode(inputs)
318 | output = self.encoder.decode(time_out)
319 | print(output[0].shape)
320 | return output
321 |
322 | def show(self, content, lengths, masks, ids, graph, times, **kwargs):
323 | output = self.encoder.show(content, lengths, masks, ids, graph, times, **kwargs)
324 | return output
325 |
326 |
327 | def get_unbatched_edge_graphs(batched_graph, ntype, etype):
328 | subgraph = dgl.edge_type_subgraph(batched_graph, etypes=[etype])
329 | subgraph.set_batch_num_nodes(batched_graph.batch_num_nodes(ntype))
330 | # print(batched_graph.batch_num_edges('cites'))
331 | subgraph.set_batch_num_edges(batched_graph.batch_num_edges(etype))
332 | # print(subgraph)
333 | # print(dgl.unbatch(subgraph))
334 | return dgl.unbatch(subgraph)
335 |
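336 |
337 | if __name__ == '__main__':
338 |     # Hypothetical smoke test for get_unbatched_edge_graphs (toy heterograph, not used in training).
339 |     toy_g = dgl.heterograph({('paper', 'cites', 'paper'): ([0, 1], [1, 2]),
340 |                              ('author', 'writes', 'paper'): ([0], [0])})
341 |     batched = dgl.batch([toy_g, toy_g])
342 |     unbatched = get_unbatched_edge_graphs(batched, 'paper', 'cites')
343 |     print(len(unbatched))  # expected: 2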
--------------------------------------------------------------------------------
/our_models/DCTSGCN.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import dgl
3 |
4 | from layers.GCN.DTGCN_layers import DCTSGCNLayer, SimpleDCTSGCNLayer
5 | from our_models.CTSGCN import CTSGCN
6 |
7 |
8 | class DCTSGCN(CTSGCN):
9 | def __init__(self, vocab_size, embed_dim, num_classes, pad_index, word2vec=None, dropout=0.5, n_layers=3,
10 | pred_type='paper', time_pooling='mean', hop=2, linear_before_gcn=False, encoder_type='GCN',
11 | time_length=10, hidden_dim=300, out_dim=256, ntypes=None, etypes=None, edge_sequences=None,
12 | graph_type=None, gcn_out='last', snapshot_weight=True, hn=None, hn_method=None, max_time_length=5,
13 | **kwargs):
14 | super(DCTSGCN, self).__init__(vocab_size, embed_dim, num_classes, pad_index, word2vec, dropout, n_layers,
15 | pred_type, time_pooling, hop, linear_before_gcn, 'GCN',
16 | time_length, hidden_dim, out_dim, ntypes, etypes, graph_type, gcn_out, hn,
17 | hn_method, max_time_length,
18 | **kwargs)
19 | self.model_name = self.model_name.replace('CTSGCN', 'DCTSGCN').replace('_GCN', '_' + encoder_type) + \
20 | '_{}'.format(edge_sequences)
21 | self.model_type = 'DCTSGCN'
22 | if snapshot_weight:
23 | self.model_name += '_weighted'
24 | if edge_sequences == 'meta':
25 | edge_sequences = [[('meta', [('author', 'writes', 'paper'),
26 | ('journal', 'publishes', 'paper'),
27 | ('time', 'shows', 'paper'),
28 | ('paper', 'is writen by', 'author'),
29 | ('paper', 'is published by', 'journal'),
30 | ('paper', 'is shown in', 'time')]),
31 | ('inter_snapshots', [('paper', 'cites', 'paper'),
32 | ('paper', 'is in', 'snapshot'),
33 | ('paper', 'is cited by', 'paper'),
34 | ('snapshot', 'has', 'paper')]),
35 | ('intra_snapshots', [('snapshot', 't_cites', 'snapshot'),
36 | ('snapshot', 'is t_cited by', 'snapshot')])]]
37 | elif edge_sequences == 'no_gcn':
38 | edge_sequences = [[('inter_snapshots', [('author', 'writes', 'paper'),
39 | ('journal', 'publishes', 'paper'),
40 | ('time', 'shows', 'paper'),
41 | ('paper', 'cites', 'paper'),
42 | ('paper', 'is in', 'snapshot'),
43 | ('paper', 'is cited by', 'paper'),
44 | ('paper', 'is writen by', 'author'),
45 | ('paper', 'is published by', 'journal'),
46 | ('paper', 'is shown in', 'time'),
47 | ('snapshot', 'has', 'paper')])]]
48 | else:
49 | # basic
50 | edge_sequences = [[('inter_snapshots', [('author', 'writes', 'paper'),
51 | ('journal', 'publishes', 'paper'),
52 | ('time', 'shows', 'paper'),
53 | ('paper', 'cites', 'paper'),
54 | ('paper', 'is in', 'snapshot'),
55 | ('paper', 'is cited by', 'paper'),
56 | ('paper', 'is writen by', 'author'),
57 | ('paper', 'is published by', 'journal'),
58 | ('paper', 'is shown in', 'time'),
59 | ('snapshot', 'has', 'paper')]),
60 | ('intra_snapshots', [('snapshot', 't_cites', 'snapshot'),
61 | ('snapshot', 'is t_cited by', 'snapshot')])]]
62 | etypes_set = set(etypes)
63 | for i in range(len(edge_sequences)):
64 | for j in range(len(edge_sequences[i])):
65 | # print('-'*30)
66 | cur_edges = edge_sequences[i][j][1].copy()
67 | for etype in cur_edges:
68 | # print(edge_sequences[i][j])
69 | # print(etype)
70 | if etype[1] not in etypes_set:
71 | edge_sequences[i][j][1].remove(etype)
72 | print('removing', edge_sequences[i][j][0], etype)
73 | print(edge_sequences)
74 | combine_method = 'sum'
75 | # if len(edge_sequences) == 1:
76 | # self.model_name += '_single'
77 | self.graph_encoder = None
78 | # self.graph_encoder = DCTSGCNLayer(in_feat=embed_dim, hidden_feat=hidden_dim, out_feat=out_dim, ntypes=ntypes,
79 | # edge_sequences=edge_sequences, k=n_layers, time_length=time_length).to(self.device)
80 |
81 | encoder_dict = {}
82 | if encoder_type == 'GIN':
83 | # ablation
84 | encoder_dict = {('paper', 'is in', 'snapshot'): {'cur_type': 'GAT'}}
85 | elif encoder_type == 'DGINP' or encoder_type == 'CGIN+RGAT':
86 | print('>>>{}<<<'.format(encoder_type))
87 | # CGIN+RGAT
88 | encoder_dict = {('paper', 'is in', 'snapshot'): {'cur_type': 'RGAT',
89 | 'snapshot_types': 3,
90 | 'node_attr_types': ['is_cite', 'is_ref', 'is_target']
91 | },
92 | ('paper', 'cites', 'paper'): {'cur_type': 'DGINP'},
93 | # ('paper', 'is cited by', 'paper'): {'cur_type': 'DGIN'}
94 | }
95 | encoder_type = 'GIN'
96 |
97 | return_list = False if gcn_out == 'last' else True
98 | self.graph_encoder = SimpleDCTSGCNLayer(in_feat=self.embed_dim, hidden_feat=self.hidden_dim, out_feat=self.out_dim,
99 | ntypes=ntypes,
100 | edge_sequences=edge_sequences, k=n_layers, time_length=time_length,
101 | combine_method=combine_method, encoder_type=encoder_type,
102 | encoder_dict=encoder_dict, return_list=return_list,
103 | snapshot_weight=snapshot_weight, **kwargs).to(self.device)
104 |
105 |
--------------------------------------------------------------------------------
/our_models/EncoderCL.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import pandas as pd
4 | import torch
5 |
6 | from models.base_model import BaseModel
7 | from our_models.CL import CLWrapper
8 |
9 | class EncoderCLWrapper(CLWrapper):
10 | def __init__(self, encoder: BaseModel, cl_type='simple'):
11 | super(EncoderCLWrapper, self).__init__(encoder, cl_type, None)
12 | # self.aux_criterion = encoder.aux_criterion
13 | self.vice_encoder = copy.deepcopy(self.encoder)
14 | self.eta = 1.
15 | self.model_name = self.encoder.model_name + '_ECL_' + cl_type
16 |
17 | def load_graphs(self, graphs):
18 | # no graph aug
19 | dealt_graphs = self.encoder.load_graphs(graphs)
20 | df_hn = pd.read_csv('./data/{}/hard_negative.csv'.format(graphs['data_source']), index_col=0)
21 | paper_trans = graphs['node_trans']['paper']
22 | df_hn.index = [paper_trans[str(paper)] for paper in df_hn.index]
23 | for column in df_hn:
24 | df_hn[column] = df_hn[column].apply(lambda x: [paper_trans[str(paper)] for paper in eval(x)])
25 | self.df_hn = df_hn
26 | return dealt_graphs
27 |
28 | def get_aug_encoder(self):
29 | for (adv_name, adv_param), (name, param) in zip(self.vice_encoder.named_parameters(),
30 | self.encoder.named_parameters()):
31 |             if (not name.startswith('fc')) or ('text' not in name):
32 | # print(name, param.data.std())
33 | adv_param.data = param.data.detach() + self.eta * torch.normal(0, torch.ones_like(
34 | param.data) * param.data.std()).detach().to(self.device)
35 |
36 | def forward(self, content, lengths, masks, ids, graph, times, **kwargs):
37 | # cur_inputs = self.encoder.get_graphs(ids, times, graph)
38 | if self.cl_type == 'aug_hard_negative':
39 | cur_inputs = self.encoder.get_graphs(content, lengths, masks, ids, graph, times, df_hn=self.df_hn)
40 | else:
41 | cur_inputs = self.encoder.get_graphs(content, lengths, masks, ids, graph, times)
42 | aug_inputs = copy.deepcopy(cur_inputs)
43 | self.get_aug_encoder()
44 | g1_out, g1_snapshots = self.encoder.encode(cur_inputs)
45 | g2_out, g2_snapshots = self.vice_encoder.encode(aug_inputs)
46 | # logging.info('encode done:{:.3f}s'.format(time.time() - cur_time))
47 |
48 | # cur_time = time.time()
49 | g1_snapshots = self.get_cl_inputs(g1_snapshots)
50 | g2_snapshots = self.get_cl_inputs(g2_snapshots, require_grad=False)
51 | self.other_loss = self.cl_weight * self.cl(g1_snapshots, g2_snapshots, ids)
52 | # logging.info('cl done:{:.3f}s'.format(time.time() - cur_time))
53 |
54 | time_out = g1_out
55 | output = self.encoder.decode(time_out)
56 | # print(time_out.shape)
57 | # logging.info('=' * 20 + 'finish' + '=' * 20)
58 | return output
59 |
60 |
--------------------------------------------------------------------------------
/our_models/FCL.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import dgl
4 | import torch
5 |
6 | from layers.GCN.CL_layers import simple_cl
7 | from models.base_model import BaseModel
8 | from our_models.HCL import HCLWrapper
9 |
10 | class FCLWrapper(HCLWrapper):
11 | def __init__(self, encoder: BaseModel, cl_type='simple', aug_type=None, tau=0.4, cl_weight=0.5, aug_rate=0.1, **kwargs):
12 | super(FCLWrapper, self).__init__(encoder, cl_type, aug_type, tau, cl_weight, aug_rate, **kwargs)
13 | self.model_name = self.model_name.replace('_HCL_', '_FCL_')
14 | if 'DCTSGCN' in self.model_name:
15 | self.model_name = 'H2CGL'
16 |
17 | def get_cl_inputs(self, snapshots, require_grad=True):
18 | if require_grad:
19 | readout = dgl.readout_nodes(snapshots, 'h', op=self.encoder.time_pooling, ntype='snapshot')
20 | else:
21 | readout = dgl.readout_nodes(snapshots, 'h', op=self.encoder.time_pooling, ntype='snapshot').detach()
22 | readout = self.proj_heads['snapshot'](readout)
23 | # print(readout.shape)
24 | return readout
25 |
26 | def cl(self, s1, s2, ids):
27 | # cl_loss_list = []
28 | cl_loss = 0
29 | if self.cl_type in ['simple', 'focus', 'cross_snapshot', 'focus_cross_snapshot']:
30 | temp_loss = simple_cl(s1, s2, self.tau)
31 | cl_loss = cl_loss + temp_loss
32 | elif self.cl_type.endswith('hard_negative'):
33 | sample_num = self.encoder.hn
34 | true_len = s1.shape[0] // sample_num if sample_num else s1.shape[0]
35 | temp_index_list = []
36 | for i in range(sample_num):
37 | temp_index_list.append(range(true_len * i, true_len * (i + 1)))
38 | # cur_index = list(zip(range(true_len), range(true_len, true_len * 2)))
39 | cur_index = list(zip(*temp_index_list))
40 | s1 = s1[cur_index, :]
41 | s2 = s2[cur_index, :]
42 | # print(s1.shape)
43 | temp_loss = simple_cl(s1, s2, self.tau)
44 | cl_loss = cl_loss + temp_loss
45 | return cl_loss
46 |
47 | def forward(self, content, lengths, masks, ids, graph, times, **kwargs):
48 | # cur_time = time.time()
49 | # g1_snapshots, g2_snapshots, cur_snapshot_adj = self.get_graphs(ids, graph)
50 | # g1_inputs = self.encoder.get_graphs(ids, graph, aug=(self.aug_methods[0], self.aug_configs[0]))
51 | # g2_inputs = self.encoder.get_graphs(ids, graph, aug=(self.aug_methods[1], self.aug_configs[1]))
52 | if self.cl_type.endswith('aug_hard_negative'):
53 | all_inputs = self.encoder.get_graphs(content, lengths, masks, ids, graph, times,
54 | phase='train_combined', df_hn=self.df_hn)
55 | else:
56 | all_inputs = self.encoder.get_graphs(content, lengths, masks, ids, graph, times,
57 | phase='train_combined')
58 | # print('get graphs done:{:.3f}s'.format(time.time() - cur_time))
59 | # cur_time = time.time()
60 | all_out, all_snapshots = self.encoder.encode(all_inputs)
61 | # print('encode done:{:.3f}s'.format(time.time() - cur_time))
62 | all_snapshots = self.get_cl_inputs(all_snapshots)
63 |
64 | batch_size = all_out.shape[0] // 2
65 | # print(batch_size, all_out.shape[0])
66 | g1_out, g2_out = all_out[:batch_size, :], all_out[batch_size:, :]
67 | g1_snapshots, g2_snapshots = all_snapshots[:batch_size, :], all_snapshots[batch_size:, :]
68 |
69 | # cur_time = time.time()
70 | self.other_loss = self.cl_weight * self.cl(g1_snapshots, g2_snapshots, ids)
71 | # logging.info('cl done:{:.3f}s'.format(time.time() - cur_time))
72 | # print('cl done:{:.3f}s'.format(time.time() - cur_time))
73 |
74 | time_out = (g1_out + g2_out) / 2
75 | output = self.encoder.decode(time_out)
76 | # print(time_out.shape)
77 | # logging.info('=' * 20 + 'finish' + '=' * 20)
78 | return output
79 |
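80 | # Worked example of the hard-negative regrouping in FCLWrapper.cl (illustrative only,
81 | # assuming encoder.hn == 2 hard-negative samples and a true batch of 3 target papers):
82 | # the rows of s1 / s2 are assumed to be laid out as [t0, t1, t2, n0, n1, n2]
83 | # (targets first, then their hard negatives), so
84 | #     list(zip(range(0, 3), range(3, 6))) == [(0, 3), (1, 4), (2, 5)]
85 | # and indexing s1[cur_index, :] reshapes the views to [3, 2, d], letting simple_cl
86 | # contrast each target together with its own hard negative.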
--------------------------------------------------------------------------------
/our_models/HCL.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import torch
3 | import torch.nn as nn
4 | from torch.nn.utils.rnn import pad_sequence
5 |
6 | from layers.GCN.CL_layers import simple_cl, ProjectionHead
7 | from models.base_model import BaseModel
8 | from our_models.CL import CLWrapper
9 |
10 | class HCLWrapper(CLWrapper):
11 | def __init__(self, encoder: BaseModel, cl_type='simple', aug_type=None, tau=0.4, cl_weight=0.5, aug_rate=0.1, **kwargs):
12 | super(HCLWrapper, self).__init__(encoder, cl_type, 'normal', tau, cl_weight, aug_rate, **kwargs)
13 |         self.model_name = self.model_name.replace('_CL_', '_HCL_').replace('_normal', '_' + (aug_type or 'normal'))
14 | if self.cl_type in ['simple', 'cross_snapshot']:
15 |             print('contrast for each node type is not supported')
16 |             raise ValueError('unsupported cl_type: {}'.format(cl_type))
17 | self.cl_ntypes = ['snapshot']
18 | self.proj_heads = nn.ModuleDict({ntype: ProjectionHead(self.encoder.out_dim, self.encoder.out_dim) for
19 | ntype in self.cl_ntypes})
20 | print('cur cl mode:', cl_type)
21 | # print('weights:', self.cl_weight, self.cross_weight)
22 | self.aug_type = aug_type
23 |
24 | if aug_type == 'dd':
25 | self.aug_methods = (('attr_drop', 'edge_drop'), ('attr_drop', 'edge_drop'))
26 | self.aug_configs = ({
27 | 'attr_drop': {
28 | 'prob': self.aug_rate,
29 | 'types': ['paper', 'author', 'journal', 'time']
30 | },
31 | 'edge_drop': {
32 | 'prob': self.aug_rate,
33 | 'types': [('paper', 'cites', 'paper'),
34 | ('paper', 'is cited by', 'paper'),
35 | ('author', 'writes', 'paper'),
36 | ('paper', 'is writen by', 'author'),
37 | ('journal', 'publishes', 'paper'),
38 | ('paper', 'is published by', 'journal'),
39 | ('time', 'shows', 'paper'),
40 | ('paper', 'is shown in', 'time')]
41 | }
42 | }, {
43 | 'edge_drop': {
44 | 'prob': self.aug_rate,
45 | 'types': [('paper', 'cites', 'paper'),
46 | ('paper', 'is cited by', 'paper'),
47 | ('author', 'writes', 'paper'),
48 | ('paper', 'is writen by', 'author'),
49 | ('journal', 'publishes', 'paper'),
50 | ('paper', 'is published by', 'journal'),
51 | ('time', 'shows', 'paper'),
52 | ('paper', 'is shown in', 'time')]
53 | },
54 | 'attr_drop': {
55 | 'prob': self.aug_rate,
56 | 'types': ['paper', 'author', 'journal', 'time']
57 | },
58 | })
59 | elif aug_type == 'dmd':
60 | self.aug_methods = (('attr_mask', 'edge_drop'), ('attr_mask', 'edge_drop'))
61 | self.aug_configs = ({
62 | 'attr_mask': {
63 | 'prob': self.aug_rate,
64 | 'types': ['paper', 'author', 'journal', 'time']
65 | },
66 | 'edge_drop': {
67 | 'prob': self.aug_rate,
68 | 'types': [('paper', 'cites', 'paper'),
69 | ('paper', 'is cited by', 'paper'),
70 | ('author', 'writes', 'paper'),
71 | ('paper', 'is writen by', 'author'),
72 | ('journal', 'publishes', 'paper'),
73 | ('paper', 'is published by', 'journal'),
74 | ('time', 'shows', 'paper'),
75 | ('paper', 'is shown in', 'time')]
76 | }
77 | }, {
78 | 'edge_drop': {
79 | 'prob': self.aug_rate,
80 | 'types': [('paper', 'cites', 'paper'),
81 | ('paper', 'is cited by', 'paper'),
82 | ('author', 'writes', 'paper'),
83 | ('paper', 'is writen by', 'author'),
84 | ('journal', 'publishes', 'paper'),
85 | ('paper', 'is published by', 'journal'),
86 | ('time', 'shows', 'paper'),
87 | ('paper', 'is shown in', 'time')]
88 | },
89 | 'attr_mask': {
90 | 'prob': self.aug_rate,
91 | 'types': ['paper', 'author', 'journal', 'time']
92 | },
93 | })
94 | elif aug_type == 'cg':
95 | self.aug_methods = (('node_drop',), ('node_drop',))
96 | self.aug_configs = ({
97 | 'node_drop': {
98 | 'prob': self.aug_rate,
99 | 'types': []
100 | }}, {
101 | 'node_drop': {
102 | 'prob': self.aug_rate,
103 | 'types': []
104 | }})
105 | else:
106 | aug_type = 'normal'
107 | self.aug_methods = (('attr_drop',), ('edge_drop',))
108 | self.aug_configs = ({
109 | 'attr_drop': {
110 | 'prob': self.aug_rate,
111 | 'types': ['paper', 'author', 'journal', 'time']
112 | }
113 | }, {
114 | 'edge_drop': {
115 | 'prob': self.aug_rate,
116 | 'types': [('paper', 'cites', 'paper'),
117 | ('paper', 'is cited by', 'paper'),
118 | ('author', 'writes', 'paper'),
119 | ('paper', 'is writen by', 'author'),
120 | ('journal', 'publishes', 'paper'),
121 | ('paper', 'is published by', 'journal'),
122 | ('time', 'shows', 'paper'),
123 | ('paper', 'is shown in', 'time')]
124 | }
125 | })
126 | print(self.aug_methods)
127 |
128 | def load_graphs(self, graphs):
129 | dealt_graphs = self.encoder.load_graphs(graphs)
130 | if self.aug_type != 'cg':
131 | dealt_graphs = self.encoder.aug_graphs_plus(dealt_graphs, self.aug_configs, self.aug_methods)
132 | paper_trans = graphs['node_trans']['paper']
133 | if 'label' in self.cl_type:
134 | df_hn = pd.read_csv('./data/{}/hard_negative_labeled.csv'.format(graphs['data_source']), index_col=0)
135 | df_hn.index = [paper_trans[str(paper)] for paper in df_hn.index]
136 | self.encoder.re_group = {label: list(df_hn[df_hn['label'] != label].index) for label in [0, 1, 2]}
137 | else:
138 | df_hn = pd.read_csv('./data/{}/hard_negative.csv'.format(graphs['data_source']), index_col=0)
139 | df_hn.index = [paper_trans[str(paper)] for paper in df_hn.index]
140 | for column in df_hn:
141 | if column in ['pub_time', 'label']:
142 | df_hn[column] = df_hn[column].astype(int)
143 | else:
144 | df_hn[column] = df_hn[column].apply(lambda x: [paper_trans[str(paper)] for paper in eval(x)])
145 | # df_hn[column] = df_hn[column].astype('object')
146 | self.df_hn = df_hn
147 | return dealt_graphs
148 |
149 | def cl(self, s1, s2, ids):
150 | # cl_loss_list = []
151 | cl_loss = 0
152 | z1_snapshots, z2_snapshots = [], []
153 | z1_ids, z2_ids = [], []
154 | if self.cl_type in ['simple', 'focus', 'cross_snapshot', 'focus_cross_snapshot']:
155 | for ntype in self.cl_ntypes:
156 | cur_loss = 0
157 | for t in range(self.encoder.time_length):
158 | if 'HTSGCN' in self.encoder.model_name:
159 | if ntype == 'snapshot':
160 | s1_index = torch.where(s1[t].nodes[ntype].data['mask'] > 0)
161 | s2_index = torch.where(s2[t].nodes[ntype].data['mask'] > 0)
162 | z1 = s1[t].nodes[ntype].data['h'][s1_index]
163 | z2 = s2[t].nodes[ntype].data['h'][s2_index]
164 | else:
165 | z1 = s1[t].nodes[ntype].data['h']
166 | z2 = s2[t].nodes[ntype].data['h']
167 | elif 'CTSGCN' in self.encoder.model_name:
168 | s1_index = torch.where(s1.nodes[ntype].data['snapshot_idx'] == t)
169 | s2_index = torch.where(s2.nodes[ntype].data['snapshot_idx'] == t)
170 | # print(s1_index)
171 | z1 = s1.nodes[ntype].data['h'][s1_index]
172 | z2 = s2.nodes[ntype].data['h'][s2_index]
173 | # print(z1.shape)
174 | if (self.cl_type.endswith('cross_snapshot')) and (ntype == 'snapshot'):
175 | z1_snapshots.append(z1)
176 | z2_snapshots.append(z2)
177 | z1_ids.append(s1.nodes[ntype].data['target'][s1_index])
178 | z2_ids.append(s2.nodes[ntype].data['target'][s2_index])
179 | # cur_list.append(self.cl_modules[ntype](z1, z2))
180 | # temp_loss = self.cl_modules[ntype](z1, z2)
181 | temp_loss = simple_cl(z1, z2, self.tau)
182 | cur_loss = cur_loss + temp_loss
183 |
184 | cl_loss = cl_loss + cur_loss / self.encoder.time_length
185 | cl_loss = cl_loss / len(self.cl_ntypes)
186 | elif self.cl_type.endswith('hard_negative'):
187 | for ntype in self.cl_ntypes:
188 | cur_loss = 0
189 | p1_index = {}
190 | p2_index = {}
191 | for paper in ids:
192 | p1_index[paper.item()] = s1.nodes[ntype].data['target'] == paper.item()
193 | p2_index[paper.item()] = s2.nodes[ntype].data['target'] == paper.item()
194 | # print(p1_index.keys())
195 | time_length = 0
196 | for t in range(self.encoder.time_length):
197 | # start_time = time.time()
198 | cur_z1s = []
199 | cur_z2s = []
200 | len_z1s = []
201 | # len_z2s = []
202 | t1_index = s1.nodes[ntype].data['snapshot_idx'] == t
203 | t2_index = s2.nodes[ntype].data['snapshot_idx'] == t
204 | for paper in ids:
205 | # s1_index = torch.where(t1_index & p1_index[paper.item()])
206 | # s2_index = torch.where(t2_index & p2_index[paper.item()])
207 | # z1 = s1.nodes[ntype].data['h'][s1_index]
208 | # z2 = s2.nodes[ntype].data['h'][s2_index]
209 | # cur_z1s.append(z1)
210 | # cur_z2s.append(z2)
211 | # cur_len = len(z1) if len(z1) > 0 else 1
212 | # len_z1s.append(cur_len)
213 | # print(s1_index[0].shape[0], cur_len)
214 | s1_index = torch.where(t1_index & p1_index[paper.item()])
215 | s2_index = torch.where(t2_index & p2_index[paper.item()])
216 | if s1_index[0].shape[0] >= 2:
217 | # if s1_index[0].shape[0] >= self.encoder.hn:
218 | z1 = s1.nodes[ntype].data['h'][s1_index]
219 | z2 = s2.nodes[ntype].data['h'][s2_index]
220 | cur_z1s.append(z1)
221 | cur_z2s.append(z2)
222 | cur_len = len(z1)
223 | len_z1s.append(cur_len)
224 | # print(s1_index[0].shape[0], cur_len)
225 | # len_z2s.append(len(z2))
226 | # if s1_index[0].shape[0] > 2:
227 | # print('sidx:', (s1.nodes[ntype].data['snapshot_idx'] == 9).sum())
228 | # print(ntype)
229 | # print(torch.where(p1_index[paper.item()]))
230 | # print(p1_index[paper.item()].sum())
231 | # print(s1_index)
232 | if len(cur_z1s) > 0:
233 | cur_z1s = pad_sequence(cur_z1s, batch_first=True)
234 | cur_z2s = pad_sequence(cur_z2s, batch_first=True)
235 | len_z1s = torch.tensor(len_z1s).to(cur_z1s.device)
236 | # len_z2s = torch.tensor(len_z2s).to(cur_z2s.device)
237 | # print('index_time: {:.3f}'.format(time.time() - start_time))
238 | # print(cur_z1s.shape)
239 | temp_loss = simple_cl(cur_z1s, cur_z2s, self.tau, valid_lens=len_z1s)
240 | time_length = time_length + 1
241 | else:
242 | temp_loss = 0
243 | # print(temp_loss)
244 | # print('loss_time: {:.3f}'.format(time.time() - start_time))
245 | cur_loss = cur_loss + temp_loss
246 | # print(ntype, cur_loss / self.encoder.time_length)
247 | cl_loss = cl_loss + cur_loss / time_length
248 | cl_loss = cl_loss / len(self.cl_ntypes)
249 |
250 | if self.cl_type.endswith('cross_snapshot'):
251 | snapshot_loss = 0
252 | z1_dict = {paper.item(): [] for paper in ids}
253 | z2_dict = {paper.item(): [] for paper in ids}
254 | # print(z1_dict)
255 | for i in range(len(z1_snapshots)):
256 | cur_z1 = z1_snapshots[i]
257 | cur_z2 = z2_snapshots[i]
258 | cur_z1_ids = z1_ids[i]
259 | cur_z2_ids = z2_ids[i]
260 | for j in range(len(cur_z1_ids)):
261 | z1_dict[cur_z1_ids[j].item()].append(cur_z1[j])
262 | z2_dict[cur_z2_ids[j].item()].append(cur_z2[j])
263 | for paper in ids:
264 | # print(torch.stack(z1_dict[paper.item()], dim=0).shape)
265 | snapshot_loss = snapshot_loss + simple_cl(torch.stack(z1_dict[paper.item()], dim=0),
266 | torch.stack(z2_dict[paper.item()], dim=0),
267 | self.tau)
268 |
269 | cl_loss = cl_loss + self.cross_weight * snapshot_loss
270 | # print(cl_loss)
271 | return cl_loss
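272 |
273 | # Shape sketch for the hard_negative branch above (illustrative only, assuming the
274 | # node feature dim is d): cur_z1s / cur_z2s are padded to [num_valid_papers, max_group, d]
275 | # and len_z1s stores each paper's real group size, so simple_cl(..., valid_lens=len_z1s)
276 | # is assumed to mask the padded rows when computing its similarity scores.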
--------------------------------------------------------------------------------
/our_models/TSGCN.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import json
4 | import logging
5 | import time
6 | from collections import Counter
7 | # import multiprocessing
8 | import joblib
9 | import pandas as pd
10 | from dgl import multiprocessing
11 |
12 | import dgl
13 | import torch
14 | from torch import nn
15 | from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
16 | from tqdm import tqdm
17 |
18 | from layers.GCN.RGCN import StochasticKLayerRGCN
19 | from layers.common_layers import MLP
20 | from models.base_model import BaseModel
21 | from utilis.log_bar import TqdmToLogger
22 |
23 |
24 | class TM:
25 | def __init__(self, paper_ids, hop, log_path):
26 | self.paper_ids = paper_ids
27 | self.hop = hop
28 | # logger = logging.getLogger()
29 | # logger.handlers.append(logging.FileHandler(log_path, mode='w+'))
30 | # self.tqdm_log = TqdmToLogger(logger, level=logging.INFO)
31 |
32 | def get_subgraph(self, cur_graph):
33 | return get_paper_ego_subgraph(self.paper_ids, hop=self.hop, graph=cur_graph, batched=True)
34 |
35 |
36 | class TSGCN(BaseModel):
37 | '''The basic model which is also called H2GCN'''
38 | def __init__(self, vocab_size, embed_dim, num_classes, pad_index, word2vec=None, dropout=0.5, n_layers=3,
39 | time_encoder='gru', time_encoder_layers=2, time_pooling='mean', hop=2,
40 | time_length=10, hidden_dim=300, out_dim=256, ntypes=None, etypes=None, graph_type=None, **kwargs):
41 | super(TSGCN, self).__init__(vocab_size, embed_dim, num_classes, pad_index, word2vec, dropout, **kwargs)
42 |
43 | self.model_name = 'TSGCN'
44 | if graph_type:
45 | self.model_name += '_' + graph_type
46 | self.hop = hop
47 | self.activation = nn.Tanh()
48 | self.drop_en = nn.Dropout(dropout)
49 | self.n_layers = n_layers
50 | if self.n_layers != 2:
51 | self.model_name += '_n{}'.format(n_layers)
52 | self.time_length = time_length
53 | final_dim = out_dim
54 | print(ntypes)
55 | residual = False
56 | # print(residual)
57 | if residual:
58 | self.model_name += '_residual'
59 | self.graph_encoder = nn.ModuleList(
60 | [StochasticKLayerRGCN(embed_dim, hidden_dim, out_dim, ntypes, etypes, k=n_layers, residual=residual)
61 | for i in range(time_length)])
62 | if time_encoder == 'gru':
63 | self.time_encoder = nn.GRU(input_size=out_dim, hidden_size=out_dim, dropout=dropout, batch_first=True,
64 | bidirectional=True, num_layers=2)
65 | final_dim = 2 * out_dim
66 | elif time_encoder == 'self-attention':
67 | self.time_encoder = nn.Sequential(*[
68 | nn.TransformerEncoderLayer(d_model=out_dim, nhead=out_dim // 64, dim_feedforward=2 * out_dim)
69 | for i in range(time_encoder_layers)])
70 | # final_dim = 2 * out_dim
71 | self.time_pooling = time_pooling
72 | # self.fc = MLP(dim_in=final_dim, dim_inner=final_dim // 2, dim_out=num_classes, num_layers=2)
73 | self.fc_reg = MLP(dim_in=final_dim, dim_inner=final_dim // 2, dim_out=1, num_layers=2)
74 | self.fc_cls = MLP(dim_in=final_dim, dim_inner=final_dim // 2, dim_out=num_classes, num_layers=2)
75 |
76 |
77 | def load_graphs(self, graphs):
78 | '''
79 | {
80 | 'data': {phase: None},
81 | 'node_trans': json.load(open(self.data_cat_path + 'sample_node_trans.json', 'r')),
82 | 'time': {'train': train_time, 'val': val_time, 'test': test_time},
83 | 'all_graphs': graphs,
84 | 'selected_papers': {phase: [list]},
85 | 'all_trans_index': dict,
86 | 'all_ego_graphs': ego_graphs
87 | }
88 | '''
89 | print(graphs['data'])
90 | logger = logging.getLogger()
91 | # LOG = logging.getLogger(__name__)
92 | tqdm_out = TqdmToLogger(logger, level=logging.INFO)
93 |
94 | graphs['all_trans_index'] = json.load(open(graphs['graphs_path'] + '_trans.json', 'r'))
95 | # ego_graphs
96 | graphs['all_ego_graphs'] = {}
97 | for pub_time in tqdm(graphs['all_graphs'], file=tqdm_out, mininterval=30):
98 | graphs['all_ego_graphs'][pub_time] = joblib.load(
99 | graphs['graphs_path'] + '_{}.job'.format(pub_time))
100 |
101 | phase_dict = {phase: [] for phase in graphs['data']}
102 | for phase in graphs['data']:
103 | selected_papers = [graphs['node_trans']['paper'][paper] for paper in graphs['selected_papers'][phase]]
104 | trans_index = dict(zip(selected_papers, range(len(selected_papers))))
105 | graphs['selected_papers'][phase] = selected_papers
106 | graphs['data'][phase] = [None, trans_index]
107 |
108 | time2phase = {pub_time: [phase for phase in graphs['data']
109 | if (pub_time <= graphs['time'][phase]) and (pub_time > graphs['time'][phase] - 10)]
110 | for pub_time in graphs['all_graphs']}
111 | for t in sorted(list(time2phase.keys())):
112 | cur_snapshot = graphs['all_ego_graphs'][t]
113 | start_time = time.time()
114 | for phase in time2phase[t]:
115 | batched_graphs = [cur_snapshot[graphs['all_trans_index'][str(paper)]]
116 | for paper in graphs['selected_papers'][phase]]
117 | cur_emb_graph = graphs['all_graphs'][t]
118 | # unbatched_graphs = dgl.unbatch(batched_graphs)
119 | unbatched_graphs = batched_graphs
120 | logging.info('unbatching done! {:.2f}s'.format(time.time() - start_time))
121 | for graph in tqdm(unbatched_graphs, file=tqdm_out, mininterval=30):
122 | for ntype in graph.ntypes:
123 | cur_index = graph.nodes[ntype].data[dgl.NID]
124 | graph.nodes[ntype].data['h'] = cur_emb_graph.nodes[ntype].data['h'][cur_index].detach()
125 | phase_dict[phase].append(unbatched_graphs)
126 |
127 | logging.info('-' * 30 + str(t) + ' done' + '-' * 30)
128 | del cur_snapshot
129 | del graphs['all_graphs'][t]
130 | del graphs['all_ego_graphs'][t]
131 |
132 | for phase in phase_dict:
133 | graphs['data'][phase][0] = phase_dict[phase]
134 |
135 | del phase_dict
136 | del graphs['all_graphs']
137 | del graphs['all_ego_graphs']
138 | return graphs
139 |
140 | def deal_graphs(self, graphs, data_path=None, path=None, log_path=None):
141 | tqdm_out = None
142 | if log_path:
143 | logging.basicConfig(level=logging.INFO,
144 | filename=log_path,
145 | filemode='w+',
146 | format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s',
147 | force=True)
148 | logger = logging.getLogger()
149 | # LOG = logging.getLogger(__name__)
150 | tqdm_out = TqdmToLogger(logger, level=logging.INFO)
151 |
152 | phases = list(graphs['data'].keys())
153 | # dealt_graphs = {'data': {}, 'time': graphs['time'], 'node_trans': graphs['node_trans'], 'all_graphs': graphs['all_graphs']}
154 | dealt_graphs = {'data': {}, 'time': graphs['time'], 'node_trans': graphs['node_trans'],
155 | 'all_graphs': None}
156 | split_data = torch.load(data_path)
157 | all_selected_papers = [graphs['node_trans']['paper'][paper] for paper in split_data['test'][0]]
158 | logging.info('all count: {}'.format(str(len(all_selected_papers))))
159 | # all_snapshots = {}
160 | all_trans_index = dict(zip(all_selected_papers, range(len(all_selected_papers))))
161 | pub_times = list(graphs['all_graphs'].keys())
162 | # del graphs['all_graphs']
163 | del graphs['data']
164 | # print('not get ego_subgraph for saving time')
165 | for pub_time in pub_times:
166 | # all_snapshots[pub_time] = get_paper_ego_subgraph(all_selected_papers, self.hop,
167 | # graphs['all_graphs'][pub_time],
168 | # batched=False, tqdm_log=tqdm_out)
169 | cur_snapshot = get_paper_ego_subgraph(all_selected_papers, self.hop,
170 | graphs['all_graphs'][pub_time],
171 | batched=False, tqdm_log=tqdm_out)
172 | # if path:
173 | # torch.save(cur_snapshot, path + '_' + str(pub_time))
174 | if path:
175 | joblib.dump(cur_snapshot, path + '_' + str(pub_time) + '.job')
176 | # del graphs['all_graphs'][pub_time]
177 |
178 | if path:
179 | json.dump(all_trans_index, open(path + '_trans.json', 'w+'))
180 |
181 | # all_trans_index = json.load()
182 | del graphs['all_graphs']
183 |
184 | joblib.dump([], path)
185 | return dealt_graphs
186 |
187 | def forward(self, content, lengths, masks, ids, graph, times, **kwargs):
188 |         # graph: T batched graphs (one per time step) plus the paper-id -> index mapping
189 | # start_time = time.time()
190 | # print(start_time)
191 | # snapshots = [batch.to(self.device) for batch in snapshots]
192 | # print(time.time() - start_time)
193 | batched_graphs_list, trans_index = graph
194 | # graphs_list = dgl.unbatch(batched_graphs)
195 | snapshots = []
196 | for batched_graphs in batched_graphs_list:
197 | temp_snapshots = []
198 | # temp_graphs = dgl.unbatch(batched_graphs)
199 | temp_graphs = batched_graphs
200 | for paper in ids:
201 | temp_snapshots.append(temp_graphs[trans_index[paper.item()]])
202 | snapshots.append(dgl.batch(temp_snapshots).to(self.device))
203 | # print(time.time() - start_time)
204 |
205 | all_snapshots = []
206 | for t in range(self.time_length):
207 | cur_snapshot = snapshots[t]
208 | out_emb = self.graph_encoder[t](cur_snapshot, cur_snapshot.srcdata['h'])
209 | cur_snapshot.dstdata['h'] = out_emb
210 | sum_readout = dgl.readout_nodes(cur_snapshot, 'h', op='sum', ntype='paper')
211 | all_snapshots.append(sum_readout)
212 | all_snapshots = torch.stack(all_snapshots, dim=1) # [B, T, d_out]
213 | # print(all_snapshots.shape)
214 | time_out = self.time_encoder(all_snapshots)
215 |         if not isinstance(time_out, torch.Tensor):  # nn.GRU returns (output, h_n)
216 | time_out = time_out[0]
217 | # print(time_out.shape)
218 | if self.time_pooling == 'mean':
219 | final_out = time_out.mean(dim=1)
220 | elif self.time_pooling == 'sum':
221 | final_out = time_out.sum(dim=1)
222 | elif self.time_pooling == 'max':
223 | final_out = time_out.max(dim=1)[0]
224 | # print(final_out.shape)
225 | output_reg = self.fc_reg(final_out)
226 | output_cls = self.fc_cls(final_out)
227 | return output_reg, output_cls
228 |
229 |
230 | def get_paper_ego_subgraph(batch_ids, hop, graph, batched=True, tqdm_log=None):
231 |     '''
232 |     get the ego subgraphs of the target papers for the current time step
233 |     (target-centric graphs extracted from the complete heterogeneous graph)
234 |     :param batch_ids: original paper ids of the target papers (None entries allowed)
235 |     :param hop: number of neighborhood hops around each target paper
236 |     :param graph: the complete heterogeneous graph of the current time step
237 |     :return: a single batched graph if batched is True, otherwise a list of ego subgraphs
238 |     '''
239 | oids = graph.nodes['paper'].data[dgl.NID].numpy().tolist()
240 | cids = graph.nodes('paper').numpy().tolist()
241 | oc_trans = dict(zip(oids, cids))
242 | trans_papers = [oc_trans.get(paper, None) for paper in batch_ids]
243 | graph.nodes['paper'].data['citations'] = graph.in_degrees(etype='cites')
244 |
245 | paper_subgraph = dgl.node_type_subgraph(graph, ['paper'])
246 | all_subgraphs = []
247 | single_subgraph = SingleSubgraph(paper_subgraph, graph, hop)
248 | for paper in tqdm(trans_papers, file=tqdm_log, mininterval=30):
249 | # for paper in tqdm(trans_papers):
250 | cur_graph = single_subgraph.get_single_subgraph(paper)
251 | for ntype in cur_graph.ntypes:
252 | if 'h' in cur_graph.nodes[ntype].data:
253 | del cur_graph.nodes[ntype].data['h']
254 | all_subgraphs.append(cur_graph)
255 | # logging.info('cur year done!')
256 | if batched:
257 | return dgl.batch(all_subgraphs)
258 | else:
259 | return all_subgraphs
260 |
261 |
262 | class SingleSubgraph:
263 | def __init__(self, paper_subgraph, origin_graph, hop, citation='global'):
264 | # one-hop instead
265 | self.paper_subgraph = paper_subgraph
266 | # self.origin_graph = origin_graph
267 | self.origin_graph = dgl.node_type_subgraph(origin_graph, ntypes=['paper', 'journal', 'author'])
268 | self.hop = 1
269 | self.citation = citation
270 | self.max_nodes = 100
271 | # self.oids = self.paper_subgraph.nodes['paper'].data[dgl.NID]
272 | self.citations = self.paper_subgraph.nodes['paper'].data['citations']
273 | self.time = self.paper_subgraph.nodes['paper'].data['time']
274 |
275 | def get_single_subgraph(self, paper):
276 | if paper is not None:
277 | ref_hop_papers = self.paper_subgraph.in_edges(paper, etype='is cited by')[0].numpy().tolist()
278 | if len(ref_hop_papers) > self.max_nodes:
279 | temp_df = pd.DataFrame(
280 | data={
281 | 'id': ref_hop_papers,
282 | 'citations': self.citations[ref_hop_papers].numpy(),
283 | 'time': self.time[ref_hop_papers].squeeze(dim=-1).numpy()})
284 | ref_hop_papers = list(temp_df.sort_values(by=['time', 'citations'],
285 | ascending=False).head(self.max_nodes)['id'].tolist())
286 | cite_hop_papers = self.paper_subgraph.in_edges(paper, etype='cites')[0].numpy().tolist()
287 | if len(cite_hop_papers) > self.max_nodes:
288 | temp_df = pd.DataFrame(
289 | data={
290 | 'id': cite_hop_papers,
291 | 'citations': self.citations[cite_hop_papers].numpy(),
292 | 'time': self.time[cite_hop_papers].squeeze(dim=-1).numpy()})
293 | cite_hop_papers = list(temp_df.sort_values(by=['time', 'citations'],
294 | ascending=False).head(self.max_nodes)['id'].tolist())
295 |
296 | # print(len(ref_hop_papers), len(cite_hop_papers))
297 | # print(len(cite_hop_papers))
298 | k_hop_papers = list(set(ref_hop_papers + cite_hop_papers + [paper]))
299 | # print(ref_hop_papers, cite_hop_papers)
300 | # print(k_hop_papers)
301 | # print(self.origin_graph)
302 | temp_graph = dgl.in_subgraph(self.origin_graph, nodes={'paper': k_hop_papers}, relabel_nodes=True)
303 | journals = temp_graph.nodes['journal'].data[dgl.NID]
304 | authors = temp_graph.nodes['author'].data[dgl.NID]
305 | cur_graph = dgl.node_subgraph(self.origin_graph,
306 | nodes={'paper': k_hop_papers, 'journal': journals, 'author': authors})
307 |
308 |             # set the role indicators (is_ref / is_cite / is_target) on the paper nodes
309 | ref_set = set(ref_hop_papers)
310 | cite_set = set(cite_hop_papers)
311 | cur_graph.nodes['paper'].data['is_ref'] = torch.zeros(len(k_hop_papers), dtype=torch.long)
312 | cur_graph.nodes['paper'].data['is_cite'] = torch.zeros(len(k_hop_papers), dtype=torch.long)
313 | cur_graph.nodes['paper'].data['is_target'] = torch.zeros(len(k_hop_papers), dtype=torch.long)
314 | oids = cur_graph.nodes['paper'].data[dgl.NID].numpy().tolist()
315 | id_trans = dict(zip(oids, range(len(oids))))
316 | ref_idx = [id_trans[paper] for paper in ref_set]
317 | cite_idx = [id_trans[paper] for paper in cite_set]
318 | cur_graph.nodes['paper'].data['is_ref'][ref_idx] = 1
319 | cur_graph.nodes['paper'].data['is_cite'][cite_idx] = 1
320 | cur_graph.nodes['paper'].data['is_target'][id_trans[paper]] = 1
321 | else:
322 | cur_graph = dgl.node_subgraph(self.origin_graph, nodes={'paper': []})
323 | cur_graph.nodes['paper'].data['is_ref'] = torch.zeros(0, dtype=torch.long)
324 | cur_graph.nodes['paper'].data['is_cite'] = torch.zeros(0, dtype=torch.long)
325 | cur_graph.nodes['paper'].data['is_target'] = torch.zeros(0, dtype=torch.long)
326 | return cur_graph
327 |
328 |
329 | if __name__ == '__main__':
330 | start_time = datetime.datetime.now()
331 | parser = argparse.ArgumentParser(description='Process some description.')
332 | parser.add_argument('--phase', default='test', help='the function name.')
333 |
334 | args = parser.parse_args()
335 | rel_names = ['cites', 'is cited by', 'writes', 'is writen by', 'publishes', 'is published by']
336 | # vocab_size, embed_dim, num_classes, pad_index
337 |
338 | graphs = {}
339 | graph_name = 'graph_sample_feature_vector'
340 | time_list = range(2002, 2012)
341 | for i in range(len(time_list)):
342 | graphs[i] = torch.load('../data/pubmed/' + graph_name + '_' + str(time_list[i]))
343 |
344 | graph_data = {
345 | 'data': graphs,
346 | 'node_trans': json.load(open('../data/pubmed/' + 'sample_node_trans.json', 'r'))
347 | }
348 | print(graph_data['data'])
349 | print(graph_data['data'][0].nodes('paper'))
350 | # print(graph_data['node_trans']['paper'])
351 | reverse_paper_dict = dict(zip(graph_data['node_trans']['paper'].values(), graph_data['node_trans']['paper'].keys()))
352 | # print(reverse_paper_dict)
353 | citation_dict = json.load(open('../data/pubmed/sample_citation_accum.json'))
354 | valid_ids = set([reverse_paper_dict[idx.item()] for idx in graph_data['data'][0].nodes('paper')])
355 | useful_ids = [idx for idx in valid_ids if idx in citation_dict]
356 | print(len(useful_ids))
357 | print(useful_ids[:10])
358 |
359 | # sample_ids = ['3266430', '2790711', '2747541', '3016217', '7111196']
360 | # trans_ids = [graph_data['node_trans']['paper'][idx] for idx in sample_ids]
361 | # print(trans_ids)
362 | # x = torch.randint(100, (5, 20)).cuda()
363 | # lens = [20] * 5
364 | # print(x)
365 | # model = TSGCN(100, 768, 0, 0, etypes=etypes).cuda()
366 | # output = model(x, lens, None, trans_ids, graph_data['data'])
367 | # print(output)
368 | # print(output.shape)
369 |
370 | sample_ids = [4724, 12968, 8536, 9558, None]
371 | cur_graph = graph_data['data'][9]
372 | print(cur_graph.nodes['paper'].data[dgl.NID][[68, 2855, 377, 845]])
373 | batch_graphs = get_paper_ego_subgraph(sample_ids, 2, cur_graph, batched=False)
374 | print(batch_graphs)
375 | # print(batch_graphs.batch_size)
376 | # snapshots = []
377 | # for i in range(len(time_list)):
378 | # snapshots.append(get_paper_ego_subgraph(sample_ids, 2, graph_data['data'][i]))
379 | # print(snapshots)
380 | # graph_encoder = StochasticKLayerRGCN(300, 128, 128, etypes, k=2)
381 | # cur_snapshot = snapshots[0]
382 | # out_emb = graph_encoder(cur_snapshot, cur_snapshot.srcdata['h'])
383 | # cur_snapshot.dstdata['h'] = out_emb
384 |
385 | # sum_readout = dgl.readout_nodes(cur_snapshot, 'h', op='max', ntype='paper')
386 | # print([vector.norm(2) for vector in sum_readout])
387 | # print(sum_readout.shape)
388 |
389 | # model = TSGCN(0, 300, 1, 0, etypes=etypes, time_encoder='gru').cuda()
390 | # # print(model(None, None, None, None, snapshots))
391 | # g1 = torch.load('../data/pubmed/graph_sample_feature_vector_2002')
392 | # g2 = torch.load('../data/pubmed/graph_sample_feature_vector_2003')
393 | # g3 = torch.load('../data/pubmed/graph_sample_feature_vector_2004')
394 | # graphs = [g1, g2, g3]
395 | # # sub_g = g1.edge_type_subgraph(['is cited by'])
396 | # # print(sub_g)
397 | # # print(sub_g.ndata[dgl.NID])
398 | # # print(sub_g.ndata['h'].shape)
399 | # node_trans = json.load(open('../data/pubmed/sample_node_trans.json'))
400 | # graphs_dict = {
401 | # 'data': {'train': graphs},
402 | # 'node_trans': node_trans
403 | # }
404 | # graphs = model.deal_graphs(graphs_dict, '../data/pubmed/split_data', '../checkpoints/pubmed/TSGCN')
405 |
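406 | # Tensor-shape sketch of TSGCN.forward (illustrative only, assuming batch size B,
407 | # out_dim d and the default 'gru' time encoder):
408 | #   per-snapshot sum readout over paper nodes -> [B, d]
409 | #   stacked over the T time steps             -> [B, T, d]
410 | #   bidirectional 2-layer GRU                 -> [B, T, 2d]
411 | #   mean / sum / max pooling over T           -> [B, 2d]
412 | #   fc_reg / fc_cls heads                     -> [B, 1] and [B, num_classes]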
--------------------------------------------------------------------------------
/our_models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ECNU-Text-Computing/H2CGL/f3a3af4b8af16e24f20236da62d219c9c39594dd/our_models/__init__.py
--------------------------------------------------------------------------------
/our_models/__pycache__/CL.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ECNU-Text-Computing/H2CGL/f3a3af4b8af16e24f20236da62d219c9c39594dd/our_models/__pycache__/CL.cpython-39.pyc
--------------------------------------------------------------------------------
/our_models/__pycache__/CTSGCN.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ECNU-Text-Computing/H2CGL/f3a3af4b8af16e24f20236da62d219c9c39594dd/our_models/__pycache__/CTSGCN.cpython-39.pyc
--------------------------------------------------------------------------------
/our_models/__pycache__/DCTSGCN.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ECNU-Text-Computing/H2CGL/f3a3af4b8af16e24f20236da62d219c9c39594dd/our_models/__pycache__/DCTSGCN.cpython-39.pyc
--------------------------------------------------------------------------------
/our_models/__pycache__/FCL.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ECNU-Text-Computing/H2CGL/f3a3af4b8af16e24f20236da62d219c9c39594dd/our_models/__pycache__/FCL.cpython-39.pyc
--------------------------------------------------------------------------------
/our_models/__pycache__/HCL.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ECNU-Text-Computing/H2CGL/f3a3af4b8af16e24f20236da62d219c9c39594dd/our_models/__pycache__/HCL.cpython-39.pyc
--------------------------------------------------------------------------------
/our_models/__pycache__/TSGCN.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ECNU-Text-Computing/H2CGL/f3a3af4b8af16e24f20236da62d219c9c39594dd/our_models/__pycache__/TSGCN.cpython-39.pyc
--------------------------------------------------------------------------------
/our_models/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ECNU-Text-Computing/H2CGL/f3a3af4b8af16e24f20236da62d219c9c39594dd/our_models/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/results/readme.md:
--------------------------------------------------------------------------------
1 | The results are output here.
--------------------------------------------------------------------------------
/utilis/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ECNU-Text-Computing/H2CGL/f3a3af4b8af16e24f20236da62d219c9c39594dd/utilis/__init__.py
--------------------------------------------------------------------------------
/utilis/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ECNU-Text-Computing/H2CGL/f3a3af4b8af16e24f20236da62d219c9c39594dd/utilis/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/utilis/__pycache__/log_bar.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ECNU-Text-Computing/H2CGL/f3a3af4b8af16e24f20236da62d219c9c39594dd/utilis/__pycache__/log_bar.cpython-39.pyc
--------------------------------------------------------------------------------
/utilis/__pycache__/scripts.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ECNU-Text-Computing/H2CGL/f3a3af4b8af16e24f20236da62d219c9c39594dd/utilis/__pycache__/scripts.cpython-39.pyc
--------------------------------------------------------------------------------
/utilis/dblp.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | from collections import defaultdict, Counter
4 |
5 | import dgl
6 | import pandas as pd
7 | import numpy as np
8 | import json
9 | import re
10 | from pprint import pprint
11 |
12 |
13 | # useful attr: _id, abstract, authors, fos, keywords, n_citations, references, title, venue, year
14 | import torch
15 |
16 |
17 | def get_lines_data(data_path):
18 | with open(data_path + 'dblpv13.json') as fr:
19 | data = fr.read()
20 | data = re.sub(r"NumberInt\((\d+)\)", r"\1", data)
21 |         print('replaced successfully!')
22 | data = json.loads(data)
23 |
24 | info_fw = open(data_path + 'temp_info.lines', 'w+')
25 | abstract_fw = open(data_path + 'temp_abstract.lines', 'w+')
26 | ref_fw = open(data_path + 'temp_ref.lines', 'w+')
27 | for paper in data:
28 | info_fw.write(str([paper['_id'],
29 | {
30 | 'authors': paper['authors'] if 'authors' in paper else [],
31 | 'fos': paper['fos'] if 'fos' in paper else [],
32 | 'keywords': paper['keywords'] if 'keywords' in paper else [],
33 | 'n_citations': paper['n_citations'] if 'n_citations' in paper else 0,
34 | 'title': paper['title'] if 'title' in paper else None,
35 | 'venue': paper['venue'] if 'venue' in paper else None,
36 | 'year': paper['year'] if 'year' in paper else None,
37 | }]) + '\n')
38 | temp_ab = paper['abstract'] if 'abstract' in paper else None
39 | temp_ref = paper['references'] if 'references' in paper else []
40 | abstract_fw.write(str([paper['_id'], temp_ab]) + '\n')
41 | ref_fw.write(str([paper['_id'], temp_ref]) + '\n')
42 |
43 |
44 | def get_all_dict(data_path, part='info'):
45 | result_dict = {}
46 | with open(data_path + 'temp_{}.lines'.format(part)) as fr:
47 | for line in fr:
48 | data = eval(line)
49 | result_dict[data[0]] = data[1]
50 | json.dump(result_dict, open(data_path + 'all_{}_dict.json'.format(part), 'w+'))
51 |
52 |
53 | # def get_cite_data(data_path):
54 | # ref_dict = json.load(open(data_path + 'all_ref_dict.json'))
55 | # cite_dict = defaultdict(list)
56 | # for paper in ref_dict:
57 | # # ref_list = ref_dict[paper]
58 | # for ref in ref_dict[paper]:
59 | # cite_dict[ref].append(paper)
60 | # json.dump(cite_dict, open(data_path + 'all_cite_dict.json', 'w+'))
61 | # del ref_dict
62 | #
63 | # year_dict = {}
64 | # info_dict = json.load(open(data_path + 'all_info_dict.json', 'r'))
65 | # for paper in cite_dict:
66 | # year_dict[paper] = list(filter(lambda x: x, map(lambda x: info_dict[x]['year'], cite_dict[paper])))
67 | # json.dump(year_dict, open(data_path + 'cite_year_dict.json', 'w+'))
68 | def get_valid_papers(data_path, time_point=None):
69 | data = json.load(open(data_path + 'all_info_dict.json', 'r'))
70 | print(len(data))
71 | # journal
72 | data = dict(filter(lambda x: x[1]['venue'].get('_id', False) if x[1]['venue'] else False, data.items()))
73 | print('has journal:', len(data))
74 |
75 | # authors
76 | data = dict(filter(lambda x: len([author['_id'] for author in x[1]['authors'] if '_id' in author]) > 0,
77 | data.items()))
78 | print('has author:', len(data))
79 | # title
80 | data = dict(filter(lambda x: x[1]['title'], data.items()))
81 | print('has title:', len(data))
82 | # json.dump(data, open(data_path + 'sample_info_dict.json', 'w+'))
83 | selected_list = set(data.keys())
84 | # time
85 | if time_point:
86 | data = dict(filter(lambda x: (int(x[1]['year']) <= time_point), data.items()))
87 |         print('<= year:', len(data))
88 | # json.dump(data, open(data_path + 'sample_info_dict.json', 'w+'))
89 | selected_list = set(data.keys())
90 |
91 | del data
92 | # words
93 | abs_dict = json.load(open(data_path + 'all_abstract_dict.json', 'r'))
94 | abs_dict = dict(filter(lambda x: (x[0] in selected_list) & (len(str(x[1]).strip().split(' ')) >= 20),
95 | abs_dict.items()))
96 | selected_list = set(abs_dict.keys())
97 | print('has abs:', len(selected_list))
98 | # json.dump(abs_dict, open(data_path + 'sample_abstract_dict.json', 'w+'))
99 | return set(selected_list)
100 |
101 |
102 | def get_cite_data(data_path):
103 | # here already filter the data
104 | selected_set = get_valid_papers(data_path)
105 | ref_dict = json.load(open(data_path + 'all_ref_dict.json'))
106 | cite_dict = defaultdict(list)
107 | # for paper in ref_dict:
108 | # # ref_list = ref_dict[paper]
109 | # for ref in ref_dict[paper]:
110 | # cite_dict[ref].append(paper)
111 | for paper in selected_set:
112 | ref_list = ref_dict.get(paper, [])
113 | for ref in ref_list:
114 | if ref in selected_set:
115 | cite_dict[ref].append(paper)
116 | json.dump(cite_dict, open(data_path + 'all_cite_dict.json', 'w+'))
117 | del ref_dict
118 |
119 | year_dict = {}
120 | info_dict = json.load(open(data_path + 'all_info_dict.json', 'r'))
121 | for paper in cite_dict:
122 | year_dict[paper] = list(filter(lambda x: x, map(lambda x: info_dict[x]['year'], cite_dict[paper])))
123 | json.dump(year_dict, open(data_path + 'cite_year_dict.json', 'w+'))
124 |
125 |
126 |
127 | def show_data(data_path):
128 | info_dict = json.load(open(data_path + 'all_info_dict.json', 'r'))
129 | print('total_count', len(info_dict))
130 | # 'authors', 'fos', 'keywords', 'n_citations', 'title', 'venue', 'year'
131 | cite = json.load(open(data_path + 'all_cite_dict.json', 'r'))
132 | dealt_dict = map(lambda x: {
133 | 'id': x[0],
134 | 'authors': len([author['_id'] for author in x[1]['authors'] if '_id' in author]),
135 | 'authors_valid': len([author.get('_id', '') for author in x[1]['authors']
136 | if len(author.get('_id', '').strip()) > 0]),
137 | 'authors_g': len([author['gid'] for author in x[1]['authors'] if 'gid' in author]),
138 | 'venue': x[1]['venue'].get('_id', None) if x[1]['venue'] else None,
139 | 'venue_raw': x[1]['venue'].get('raw', None) if x[1]['venue'] else None,
140 | 'fos': len(x[1]['fos']),
141 | 'keywords': len(x[1]['keywords']),
142 | 'year': x[1]['year'],
143 | 'citations': len(cite.get(x[0], []))
144 | }, info_dict.items())
145 | no_cite_paper = [paper for paper in info_dict if paper not in cite]
146 | print(len(no_cite_paper))
147 | print(len(info_dict))
148 |
149 | df = pd.DataFrame(dealt_dict)
150 | print(df.describe())
151 | df_venue = df.groupby('venue').count().sort_values(by='id', ascending=False)['id']
152 | df_venue.to_csv(data_path + 'stats_venue_count.csv')
153 | print(df_venue.head(20))
154 | df_venue_raw = df.groupby(['venue', 'venue_raw']).count().sort_values(by='id', ascending=False)['id']
155 | df_venue_raw.to_csv(data_path + 'stats_venue_raw_count.csv')
156 | print(df_venue_raw.head(20))
157 |
158 | df_year = df.groupby('year').count()
159 | df_year.to_csv(data_path + 'stats_year_count.csv')
160 | print(df_year)
161 | del dealt_dict
162 |
163 | abstract_dict = json.load(open(data_path + 'all_abstract_dict.json'))
164 | # valid_paper = [len(abstract.strip().split(' ')) >= 20 for abstract in abstract_dict.values() if abstract]
165 | valid_paper = set(map(lambda x: x[0],
166 | filter(lambda x: len(str(x[1]).strip().split(' ')) >= 20, abstract_dict.items())))
167 | print(len(valid_paper))
168 |
169 | df_journal_counts = df.groupby('venue').count().sort_values(by='id', ascending=False)['id']
170 | df_journal_citations = df.groupby('venue').sum().sort_values(by='citations', ascending=False)['citations']
171 | print(df_journal_citations.head(20))
172 | df_journal_avg_citations = (df_journal_citations / df_journal_counts).sort_values(ascending=False)
173 | print(df_journal_avg_citations.head(20))
174 | df_journal_counts.to_csv(data_path + 'journal_count.csv')
175 | df_journal_citations.to_csv(data_path + 'journal_citation.csv')
176 | df_journal_avg_citations.to_csv(data_path + 'journal_avg_citation.csv')
177 |
178 | df_selected = df[df['id'].isin(valid_paper)]
179 | df_selected.to_csv(data_path + 'selected_data.csv')
180 | df_selected_year = df_selected.groupby('year').count()
181 | df_selected_year.to_csv(data_path + 'stats_selected_year_count.csv')
182 | print(df_selected_year)
183 |
184 |
185 | def get_subset(data_path, time_point=None):
186 | print(time_point)
187 |
188 | selected_set = get_valid_papers(data_path, time_point)
189 | abs_dict = json.load(open(data_path + 'all_abstract_dict.json', 'r'))
190 | abs_dict = dict(
191 | filter(lambda x: (x[0] in selected_set), abs_dict.items()))
192 |
193 | # save subset
194 | json.dump(abs_dict, open(data_path + 'sample_abstract_dict.json', 'w+'))
195 | del abs_dict
196 | data = json.load(open(data_path + 'all_info_dict.json', 'r'))
197 | data = dict(filter(lambda x: x[0] in selected_set, data.items()))
198 | json.dump(data, open(data_path + 'sample_info_dict.json', 'w+'))
199 | del data
200 |
201 | ref = json.load(open(data_path + 'all_ref_dict.json', 'r'))
202 | # cite = json.load(open(data_path + 'all_cite_data.json', 'r'))
203 | cite_year = json.load(open(data_path + 'cite_year_dict.json', 'r'))
204 |
205 | ref = dict(filter(lambda x: x[0] in selected_set, ref.items()))
206 | ref = dict(map(lambda x: (x[0], [paper for paper in x[1] if paper in selected_set]), ref.items()))
207 | json.dump(ref, open(data_path + 'sample_ref_dict.json', 'w+'))
208 |
209 | cite_year = dict(
210 | map(lambda x: (x[0], Counter(x[1])), filter(lambda x: x[0] in selected_set, cite_year.items())))
211 | json.dump(cite_year, open(data_path + 'sample_cite_year_dict.json', 'w+'))
212 |
213 | def get_citation_accum(data_path, time_point, all_times=(2011, 2013, 2015)):
214 | cite_year_path = data_path + 'sample_cite_year_dict.json'
215 | ref_path = data_path + 'sample_ref_dict.json'
216 | info_path = data_path + 'sample_info_dict.json'
217 | cite_year_dict = json.load(open(cite_year_path, 'r'))
218 | ref_dict = json.load(open(ref_path, 'r'))
219 | info_dict = json.load(open(info_path, 'r'))
220 |
221 | predicted_list = list(cite_year_dict.keys())
222 | accum_num_dict = {}
223 |
224 | # for paper in predicted_list:
225 | # temp_cum_num = []
226 | # pub_year = int(info_dict[paper]['pub_date']['year'])
227 | # count_dict = cite_year_dict[paper]
228 | # count = 0
229 | # for year in range(pub_year, pub_year+time_window):
230 | # if str(year) in count_dict:
231 | # count += int(count_dict[str(year)])
232 | # temp_cum_num.append(count)
233 | # accum_num_dict[paper] = temp_cum_num
234 | #
235 | # print(len(accum_num_dict))
236 | # print(len(list(filter(lambda x: x[-1] > 0, accum_num_dict.values()))))
237 | # print(len(list(filter(lambda x: x[-1] >= 10, accum_num_dict.values()))))
238 | # print(len(list(filter(lambda x: x[-1] >= 100, accum_num_dict.values()))))
239 | #
240 | # if subset:
241 | # accum_num_dict = dict(filter(lambda x: x[1][-1] >= 10, accum_num_dict.items()))
242 | # json.dump(accum_num_dict, open(data_path + 'sample_citation_accum.json', 'w+'))
243 | # else:
244 | # json.dump(accum_num_dict, open(data_path + 'all_citation_accum.json', 'w+'))
245 | print('writing citations accum')
246 | pub_dict = {}
247 | for paper in predicted_list:
248 | count_dict = dict(map(lambda x: (int(x[0]), x[1]), cite_year_dict[paper].items()))
249 | if (sum(count_dict.values()) >= 10) and (len(ref_dict[paper]) >= 5):
250 | # if len(ref_dict[paper]) >= 5:
251 | # pub_year = int(info_dict[paper]['pub_date']['year'])
252 | pub_year = int(info_dict[paper]['year'])
253 | pub_dict[paper] = pub_year
254 | start_year = max(pub_year, time_point - 4)
255 | # print(count_dict)
256 |             count = sum(dict(filter(lambda x: x[0] < start_year, count_dict.items())).values())  # citations strictly before start_year; the loop below adds start_year onward
257 | # print(start_count)
258 | # temp_input_accum_num = [None] * (start_year + 4 - time_point)
259 | temp_input_accum_num = [-1] * (start_year + 4 - time_point)
260 | temp_output_accum_num = []
261 | for year in range(start_year, time_point + 1):
262 | if year in count_dict:
263 | count += int(count_dict[year])
264 | temp_input_accum_num.append(count)
265 | for year in range(time_point + 1, time_point + 6):
266 | if year in count_dict:
267 | count += int(count_dict[year])
268 | temp_output_accum_num.append(count)
269 | accum_num_dict[paper] = (temp_input_accum_num, temp_output_accum_num)
270 | # print(temp_input_accum_num)
271 | print(len(accum_num_dict))
272 | # accum_num_dict = dict(filter(lambda x: x[1][-1] >= 10, accum_num_dict.items()))
273 | print('writing file')
274 | json.dump(accum_num_dict, open(data_path + 'sample_citation_accum.json', 'w+'))
275 | for cur_time in all_times:
276 | cur_dict = {}
277 | cur_papers = map(lambda x: x[0], filter(lambda x: x[1] <= cur_time, pub_dict.items()))
278 | for paper in cur_papers:
279 | cur_dict[paper] = accum_num_dict[paper][1][cur_time + 4 - time_point] - \
280 | accum_num_dict[paper][0][cur_time + 4 - time_point]
281 | print(cur_time, len(cur_dict))
282 | df = pd.DataFrame(data=cur_dict.items())
283 | df.columns = ['paper', 'citations']
284 | df.to_csv(data_path + 'citation_{}.csv'.format(cur_time))
285 |
286 |
287 |
288 | def get_input_data(data_path, time_point=None, subset=False):
289 | info_path = data_path + 'all_info_dict.json'
290 | ref_path = data_path + 'all_ref_dict.json'
291 | cite_year_path = data_path + 'all_cite_year_dict.json'
292 | # abs_path = 'all_abstract_dict.json'
293 | if subset:
294 | info_path = data_path + 'sample_info_dict.json'
295 | ref_path = data_path + 'sample_ref_dict.json'
296 | cite_year_path = data_path + 'sample_cite_year_dict.json'
297 | # abs_path = 'sample_abstract_dict.json'
298 |
299 | info_dict = json.load(open(info_path, 'r'))
300 | ref_dict = json.load(open(ref_path, 'r'))
301 | cite_year_dict = json.load(open(cite_year_path, 'r'))
302 |
303 | node_trans = {}
304 | paper_trans = {}
305 | # graph created
306 | src_list = []
307 | dst_list = []
308 | index = 0
309 | for dst in ref_dict:
310 | if dst in paper_trans:
311 | dst_idx = paper_trans[dst]
312 | else:
313 | dst_idx = index
314 | paper_trans[dst] = dst_idx
315 | index += 1
316 | for src in set(ref_dict[dst]):
317 | if src in paper_trans:
318 | src_idx = paper_trans[src]
319 | else:
320 | src_idx = index
321 | paper_trans[src] = src_idx
322 | index += 1
323 | src_list.append(src_idx)
324 | dst_list.append(dst_idx)
325 |
326 | node_trans['paper'] = paper_trans
327 |
328 | author_trans = {}
329 | journal_trans = {}
330 | author_src, author_dst = [], []
331 | journal_src, journal_dst = [], []
332 | author_index = 0
333 | journal_index = 0
334 | for paper in paper_trans:
335 | meta_data = info_dict[paper]
336 | try:
337 | # authors = list(map(lambda x: (x['name'].get('given-names', '') + ' ' + x['name'].get('sur', '')).strip().lower(),
338 | # meta_data['authors']))
339 | authors = set([author['_id'] for author in meta_data['authors'] if '_id' in author])
340 | except Exception as e:
341 | print(meta_data['authors'])
342 | authors = None
343 | # print(authors)
344 | # print(meta_data['authors'][0]['name'])
345 | # journal = meta_data['journal']['ids']['nlm-ta'] if 'nlm-ta' in meta_data['journal']['ids'] else None
346 | journal = meta_data['venue']['_id'] if '_id' in meta_data['venue'] else None
347 |
348 | for author in authors:
349 | if author in author_trans:
350 | cur_idx = author_trans[author]
351 | else:
352 | cur_idx = author_index
353 | author_trans[author] = cur_idx
354 | author_index += 1
355 | author_src.append(cur_idx)
356 | author_dst.append(paper_trans[paper])
357 |
358 | if journal:
359 | if journal in journal_trans:
360 | cur_idx = journal_trans[journal]
361 | else:
362 | cur_idx = journal_index
363 | journal_trans[journal] = cur_idx
364 | journal_index += 1
365 | journal_src.append(cur_idx)
366 | journal_dst.append(paper_trans[paper])
367 |
368 | node_trans['author'] = author_trans
369 | node_trans['journal'] = journal_trans
370 |
371 | # node_trans_reverse = dict(map(lambda x: (x[0], dict(zip(x[1].values(), x[1].keys()))), node_trans.items()))
372 | node_trans_reverse = {key: dict(zip(node_trans[key].values(), node_trans[key].keys())) for key in node_trans}
373 | # paper_link_time = [int(info_dict[node_trans_reverse['paper'][paper]]['pub_date']['year']) for paper in dst_list]
374 | # author_link_time = [int(info_dict[node_trans_reverse['paper'][paper]]['pub_date']['year']) for paper in author_dst]
375 | # journal_link_time = [int(info_dict[node_trans_reverse['paper'][paper]]['pub_date']['year']) for paper in journal_dst]
376 | paper_link_time = [int(info_dict[node_trans_reverse['paper'][paper]]['year']) for paper in dst_list]
377 | author_link_time = [int(info_dict[node_trans_reverse['paper'][paper]]['year']) for paper in author_dst]
378 | journal_link_time = [int(info_dict[node_trans_reverse['paper'][paper]]['year']) for paper in journal_dst]
379 | paper_time = [int(info_dict[paper]['year']) for paper in node_trans['paper'].keys()]
380 |
381 | if subset:
382 | json.dump(node_trans, open(data_path + 'sample_node_trans.json', 'w+'))
383 | else:
384 | json.dump(node_trans, open(data_path + 'all_node_trans.json', 'w+'))
385 |
386 | # graph = dgl.graph((src_list, dst_list), num_nodes=len(paper_trans))
387 | graph = dgl.heterograph({
388 | ('paper', 'is cited by', 'paper'): (src_list, dst_list),
389 | ('paper', 'cites', 'paper'): (dst_list, src_list),
390 | ('author', 'writes', 'paper'): (author_src, author_dst),
391 | ('paper', 'is writen by', 'author'): (author_dst, author_src),
392 | ('journal', 'publishes', 'paper'): (journal_src, journal_dst),
393 | ('paper', 'is published by', 'journal'): (journal_dst, journal_src),
394 | }, num_nodes_dict={
395 | 'paper': len(paper_trans),
396 | 'author': len(author_trans),
397 | 'journal': len(journal_trans)
398 | })
399 | # graph.ndata['paper_id'] = torch.tensor(list(node_trans.keys())).unsqueeze(dim=0)
400 | graph.nodes['paper'].data['time'] = torch.tensor(paper_time, dtype=torch.int16).unsqueeze(dim=-1)
401 | graph.edges['is cited by'].data['time'] = torch.tensor(paper_link_time, dtype=torch.int16).unsqueeze(dim=-1)
402 | graph.edges['cites'].data['time'] = torch.tensor(paper_link_time, dtype=torch.int16).unsqueeze(dim=-1)
403 | graph.edges['writes'].data['time'] = torch.tensor(author_link_time, dtype=torch.int16).unsqueeze(dim=-1)
404 | graph.edges['is writen by'].data['time'] = torch.tensor(author_link_time, dtype=torch.int16).unsqueeze(dim=-1)
405 | graph.edges['publishes'].data['time'] = torch.tensor(journal_link_time, dtype=torch.int16).unsqueeze(dim=-1)
406 | graph.edges['is published by'].data['time'] = torch.tensor(journal_link_time, dtype=torch.int16).unsqueeze(dim=-1)
407 |
408 | graph = dgl.remove_self_loop(graph, 'is cited by')
409 | graph = dgl.remove_self_loop(graph, 'cites')
410 | print(graph)
411 | print(graph.edges['cites'].data['time'])
412 | if subset:
413 | torch.save(graph, data_path + 'graph_sample')
414 | else:
415 | torch.save(graph, data_path + 'graph')
416 |
417 | del graph, src_list, dst_list
418 |
419 | if __name__ == "__main__":
420 | start_time = datetime.datetime.now()
421 | parser = argparse.ArgumentParser(description='Process some description.')
422 | parser.add_argument('--phase', default='test', help='the function name.')
423 | parser.add_argument('--data_path', default=None, help='the input.')
424 | parser.add_argument('--name', default=None, help='file name.')
425 | parser.add_argument('--out_path', default=None, help='the output.')
426 | # parser.add_argument('--seed', default=123, help='the seed.')
427 | args = parser.parse_args()
428 | # TIME_POINT = 2015
429 | # ALL_TIMES = (2011, 2013, 2015)
430 | TIME_POINT = 2013
431 | ALL_TIMES = (2009, 2011, 2013)
432 | if args.phase == 'test':
433 | print('This is a test process.')
434 | # get_lines_data('./data/')
435 | # get_all_dict('./data/', 'info')
436 | # get_all_dict('./data/', 'abstract')
437 | # get_all_dict('./data/', 'ref')
438 | # get_cite_data('../data/')
439 | elif args.phase == 'lines_data':
440 | get_lines_data(args.data_path)
441 | print('lines data done')
442 | elif args.phase == 'info_data':
443 | get_all_dict(args.data_path, 'info')
444 |         print('info data done')
445 | elif args.phase == 'abstract_data':
446 | get_all_dict(args.data_path, 'abstract')
447 | print('abstract data done')
448 | elif args.phase == 'ref_data':
449 | get_all_dict(args.data_path, 'ref')
450 | print('ref data done')
451 | elif args.phase == 'cite_data':
452 | get_cite_data(args.data_path)
453 | print('cite data done')
454 | elif args.phase == 'show_data':
455 | show_data(args.data_path)
456 | print('show data done')
457 | elif args.phase == 'subset':
458 | get_subset(args.data_path, time_point=TIME_POINT)
459 | print('subset data done')
460 | elif args.phase == 'subset_input_data':
461 | get_input_data(args.data_path, subset=True, time_point=TIME_POINT)
462 | print('subset input data done')
463 | elif args.phase == 'citation_accum':
464 | get_citation_accum(args.data_path, time_point=TIME_POINT, all_times=ALL_TIMES)
465 | print('citation accum done')
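466 |
467 | # Worked example for get_citation_accum (illustrative only; TIME_POINT = 2013 and a
468 | # paper published in 2012 with >= 5 references and yearly citations
469 | # {2013: 6, 2014: 7, 2015: 3, 2016: 2}):
470 | #   start_year = max(2012, 2013 - 4) = 2012, so the input list gets
471 | #   2012 + 4 - 2013 = 3 leading placeholders of -1
472 | #   input accum  (years 2012..2013) -> [-1, -1, -1, 0, 6]
473 | #   output accum (years 2014..2018) -> [13, 16, 18, 18, 18]
474 | #   citation_2013.csv label = 18 - 6 = 12, i.e. the citations gained in the
475 | #   five years after the observation point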
--------------------------------------------------------------------------------
/utilis/log_bar.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import time
3 | from tqdm import tqdm
4 | import io
5 |
6 | class TqdmToLogger(io.StringIO):
7 | """
8 |         Output stream for tqdm which writes to the logging module instead of
9 |         stdout.
10 | """
11 | logger = None
12 | level = None
13 | buf = ''
14 | def __init__(self,logger,level=None):
15 | super(TqdmToLogger, self).__init__()
16 | self.logger = logger
17 | self.level = level or logging.INFO
18 | def write(self,buf):
19 | self.buf = buf.strip('\r\n\t ')
20 | def flush(self):
21 | self.logger.log(self.level, self.buf)
22 |
23 | if __name__ == "__main__":
24 | logging.basicConfig(format='%(asctime)s [%(levelname)-8s] %(message)s')
25 | logger = logging.getLogger()
26 | logger.setLevel(logging.DEBUG)
27 |
28 | tqdm_out = TqdmToLogger(logger,level=logging.INFO)
29 | for x in tqdm(range(100),file=tqdm_out,mininterval=1,):
30 | time.sleep(2)
31 |
--------------------------------------------------------------------------------
/utilis/scripts.py:
--------------------------------------------------------------------------------
1 | import json
2 | import math
3 | import random
4 |
5 | import dgl
6 | import joblib
7 | import numpy as np
8 | import pandas as pd
9 | import torch
10 | from matplotlib import pyplot as plt
11 | from sklearn.manifold import TSNE
12 | from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score, mean_squared_log_error
13 | from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, log_loss, roc_auc_score, \
14 | precision_recall_curve, auc
15 | import torch.nn.functional as F
16 |
17 |
18 | def eval_result(all_true_values, all_predicted_values):
19 | # print(mean_absolute_error(all_true_values, all_predicted_values, multioutput='raw_values'))
20 | all_true_values = all_true_values.flatten(order='C')
21 | # print(all_true_values)
22 | all_predicted_values = all_predicted_values.flatten(order='C')
23 | all_predicted_values = np.where(all_predicted_values == np.inf, 1e20, all_predicted_values)
24 | # print(all_predicted_values)
25 | mae = mean_absolute_error(all_true_values, all_predicted_values)
26 | mse = mean_squared_error(all_true_values, all_predicted_values)
27 | rmse = math.sqrt(mse)
28 | r2 = r2_score(all_true_values, all_predicted_values)
29 |
30 | rse = (all_true_values - all_predicted_values) ** 2 / (all_true_values + 1)
31 | Mrse = np.mean(rse, axis=0)
32 | mrse = np.percentile(rse, 50)
33 |
34 | acc = np.mean(np.all(((0.5 * all_true_values) <= all_predicted_values,
35 | all_predicted_values <= (1.5 * all_true_values)), axis=0), axis=0)
36 |
37 | # print(all_true_values)
38 | # add min for log
39 | log_true_values = np.log(1 + all_true_values)
40 | log_predicted_values = np.log(1 + np.where(all_predicted_values >= 0, all_predicted_values, 0))
41 | male = mean_absolute_error(log_true_values, log_predicted_values)
42 | msle = mean_squared_error(log_true_values, log_predicted_values)
43 | rmsle = math.sqrt(msle)
44 | log_r2 = r2_score(log_true_values, log_predicted_values)
45 | smape = np.mean(np.abs(log_true_values - log_predicted_values) /
46 | ((np.abs(log_true_values) + np.abs(log_predicted_values) + 0.1) / 2), axis=0)
47 | mape = np.mean(np.abs(log_true_values - log_predicted_values) / (np.abs(log_true_values) + 0.1), axis=0)
48 | # print('true', all_true_values[:10])
49 | # print('pred', all_predicted_values[:10])
50 | return [acc, mae, r2, mse, rmse, Mrse, mrse, male, log_r2, msle, rmsle, smape, mape]
51 |
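# Usage note (sketch): result_format, defined later in this file, unpacks a leading loss
# value followed by these 13 metrics, e.g. result_format([val_loss] + eval_result(y_true, y_pred)),
# where val_loss, y_true and y_pred are illustrative names.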
52 |
53 | def eval_aux_results(all_true_label, all_predicted_result):
54 | all_predicted_label = np.argmax(all_predicted_result, axis=1)
55 | # all
56 | acc = accuracy_score(all_true_label, all_predicted_label)
57 | one_hot = np.eye(all_predicted_result.shape[-1])[all_true_label]
58 | try:
59 | roc_auc = roc_auc_score(one_hot, all_predicted_result)
60 | log_loss_value = log_loss(all_true_label, all_predicted_result)
61 | except Exception as e:
62 | print(e)
63 | roc_auc = 0
64 | log_loss_value = 0
65 | # print(all_predicted_result)
66 | # # pos
67 | # prec = precision_score(all_true_label, all_predicted_label)
68 | # recall = recall_score(all_true_label, all_predicted_label)
69 | # f1 = f1_score(all_true_label, all_predicted_label, average='binary')
70 | # pr_p, pr_r, _ = precision_recall_curve(all_true_label, all_predicted_result[:, 1])
71 | # pr_auc = auc(pr_r, pr_p)
72 | # # neg
73 | # prec_neg = precision_score(all_true_label, all_predicted_label, pos_label=0)
74 | # recall_neg = recall_score(all_true_label, all_predicted_label, pos_label=0)
75 | # f1_neg = f1_score(all_true_label, all_predicted_label, average='binary', pos_label=0)
76 | # pr_np, pr_nr, _ = precision_recall_curve(all_true_label, all_predicted_result[:, 0], pos_label=0)
77 | # pr_auc_neg = auc(pr_nr, pr_np)
78 | # for all
79 | micro_prec = precision_score(all_true_label, all_predicted_label, average='micro')
80 | micro_recall = recall_score(all_true_label, all_predicted_label, average='micro')
81 | micro_f1 = f1_score(all_true_label, all_predicted_label, average='micro')
82 | macro_prec = precision_score(all_true_label, all_predicted_label, average='macro')
83 | macro_recall = recall_score(all_true_label, all_predicted_label, average='macro')
84 | macro_f1 = f1_score(all_true_label, all_predicted_label, average='macro')
85 |
86 | results = [acc, roc_auc, log_loss_value, micro_prec, micro_recall, micro_f1, macro_prec, macro_recall, macro_f1]
87 | return results
88 |
89 |
90 | def eval_aux_labels(all_true_label, all_predicted_label):
91 | # acc, mi_prec, mi_recall, mi_f1, ma_prec, ma_recall, ma_f1
92 | # all_predicted_label = np.argmax(all_predicted_result, axis=1)
93 | # all
94 | acc = accuracy_score(all_true_label, all_predicted_label)
95 | # one_hot = np.eye(all_predicted_result.shape[-1])[all_true_label]
96 | # try:
97 | # roc_auc = roc_auc_score(one_hot, all_predicted_result)
98 | # log_loss_value = log_loss(all_true_label, all_predicted_result)
99 | # except Exception as e:
100 | # print(e)
101 | # roc_auc = 0
102 | # log_loss_value = 0
103 | # print(all_predicted_result)
104 | # # pos
105 | # prec = precision_score(all_true_label, all_predicted_label)
106 | # recall = recall_score(all_true_label, all_predicted_label)
107 | # f1 = f1_score(all_true_label, all_predicted_label, average='binary')
108 | # pr_p, pr_r, _ = precision_recall_curve(all_true_label, all_predicted_result[:, 1])
109 | # pr_auc = auc(pr_r, pr_p)
110 | # # neg
111 | # prec_neg = precision_score(all_true_label, all_predicted_label, pos_label=0)
112 | # recall_neg = recall_score(all_true_label, all_predicted_label, pos_label=0)
113 | # f1_neg = f1_score(all_true_label, all_predicted_label, average='binary', pos_label=0)
114 | # pr_np, pr_nr, _ = precision_recall_curve(all_true_label, all_predicted_result[:, 0], pos_label=0)
115 | # pr_auc_neg = auc(pr_nr, pr_np)
116 | # for all
117 | micro_prec = precision_score(all_true_label, all_predicted_label, average='micro')
118 | micro_recall = recall_score(all_true_label, all_predicted_label, average='micro')
119 | micro_f1 = f1_score(all_true_label, all_predicted_label, average='micro')
120 | macro_prec = precision_score(all_true_label, all_predicted_label, average='macro')
121 | macro_recall = recall_score(all_true_label, all_predicted_label, average='macro')
122 | macro_f1 = f1_score(all_true_label, all_predicted_label, average='macro')
123 |
124 | results = [acc, micro_prec, micro_recall, micro_f1, macro_prec, macro_recall, macro_f1]
125 | return results
126 |
127 |
128 | def result_format(results):
129 | loss, acc, mae, r2, mse, rmse, Mrse, mrse, male, log_r2, msle, rmsle, smape, mape = results
130 | format_str = '| loss {:8.4f} | acc {:9.4f} |\n' \
131 | '| MAE {:9.4f} | R2 {:10.4f} |\n' \
132 | '| MSE {:9.4f} | RMSE {:8.4f} |\n'\
133 | '| MRSE {:8.4f} | mRSE {:8.4f} |\n'\
134 | '| MALE {:8.4f} | LR2 {:9.4f} |\n'\
135 | '| MSLE {:8.4f} | RMSLE {:7.4f} |\n'\
136 | '| SMAPE {:8.4f} | MAPE {:7.4f} |\n'.format(loss, acc,
137 | mae, r2,
138 | mse, rmse,
139 | Mrse, mrse,
140 | male, log_r2,
141 | msle, rmsle,
142 | smape, mape) + \
143 | '-' * 59
144 | print(format_str)
145 | return format_str
146 |
147 | def aux_result_format(results, phase):
148 | avg_loss, acc, roc_auc, log_loss_value, micro_prec, micro_recall, micro_f1, macro_prec, macro_recall, macro_f1 = results
149 |
150 | format_str = '| aux loss {:8.3f} | {:9} {:7.3f} |\n' \
151 | '| roc_auc {:9.3f} | log_loss {:8.3f} |\n' \
152 | 'micro:\n' \
153 | '| precision {:7.3f} | recall {:10.3f} |\n' \
154 | '| f1 {:14.3f} |\n' \
155 | 'macro:\n' \
156 | '| precision {:7.3f} | recall {:10.3f} |\n' \
157 | '| f1 {:14.3f} |\n'.format(avg_loss, phase + ' acc', acc, roc_auc, log_loss_value,
158 | micro_prec, micro_recall, micro_f1,
159 | macro_prec, macro_recall, macro_f1)
160 | print(format_str)
161 | return format_str
162 |
163 |
164 | def get_configs(data_source, model_list, replace_dict=None):
165 |     with open('./configs/{}.json'.format(data_source)) as fr:  # close the config file after reading
166 |         configs = json.load(fr)
167 | full_configs = {'default': configs['default']}
168 | if replace_dict:
169 | print('>>>>>>>>>>replacing here<<<<<<<<<<<<<<<')
170 | print(replace_dict)
171 | for key in replace_dict:
172 | full_configs['default'][key] = replace_dict[key]
173 | for model in model_list:
174 | full_configs[model] = configs['default'].copy()
175 | if model in configs.keys():
176 | for key in configs[model].keys():
177 | full_configs[model][key] = configs[model][key]
178 | return full_configs
179 |
180 |
181 | def setup_seed(seed):
182 | torch.manual_seed(seed)
183 | torch.cuda.manual_seed(seed)
184 | torch.cuda.manual_seed_all(seed)
185 | np.random.seed(seed)
186 | random.seed(seed)
187 |
188 | try:
189 | dgl.seed(seed)
190 | except Exception as E:
191 | print(E)
192 | # torch.backends.cudnn.deterministic = True
193 | # os.environ["OMP_NUM_THREADS"] = '1'
194 |
195 |
196 | class IndexDict(dict):
197 | def __init__(self):
198 | super(IndexDict, self).__init__()
199 | self.count = 0
200 |
201 | def __getitem__(self, item):
202 | if item not in self.keys():
203 | super().__setitem__(item, self.count)
204 | self.count += 1
205 | return super().__getitem__(item)
206 |
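# Usage sketch (illustrative): IndexDict hands out consecutive integer ids on first lookup,
# which is convenient for mapping raw node identifiers to graph indices:
#   idx = IndexDict()
#   idx['p1']   # -> 0
#   idx['p2']   # -> 1
#   idx['p1']   # -> 0 (already registered); idx.count == 2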
207 |
208 | def add_new_elements(graph, nodes={}, ndata={}, edges={}, edata={}):
209 |     '''
210 |     :param graph: dgl heterograph to extend
211 |     :param nodes: {ntype: num_nodes} for node types to add or resize
212 |     :param ndata: {ntype: {attr: []}}
213 |     :param edges: {(src_type, etype, dst_type): [(src, dst)]}
214 |     :param edata: {etype: {attr: []}}
215 |     :return: a new heterograph containing the original elements plus the added ones
216 |     '''
217 | # etypes = graph.canonical_etypes
218 | num_nodes_dict = {ntype: graph.num_nodes(ntype) for ntype in graph.ntypes}
219 | for ntype in nodes:
220 | num_nodes_dict[ntype] = nodes[ntype]
221 | # etypes.extend(list(edges.keys()))
222 |
223 | relations = {}
224 | for etype in graph.canonical_etypes:
225 | src, dst = graph.edges(etype=etype[1])
226 | relations[etype] = (src, dst)
227 | for etype in edges:
228 | relations[etype] = edges[etype]
229 |
230 | # print(relations)
231 | # print(num_nodes_dict)
232 | new_g = dgl.heterograph(relations, num_nodes_dict=num_nodes_dict)
233 |
234 | for ntype in graph.ntypes:
235 | for k, v in graph.nodes[ntype].data.items():
236 | new_g.nodes[ntype].data[k] = v.detach().clone()
237 | for etype in graph.etypes:
238 | for k, v in graph.edges[etype].data.items():
239 | new_g.edges[etype].data[k] = v.detach().clone()
240 |
241 | for ntype in ndata:
242 | for attr in ndata[ntype]:
243 | new_g.nodes[ntype].data[attr] = ndata[ntype][attr]
244 |
245 | for etype in edata:
246 | for attr in edata[etype]:
247 | # try:
248 | new_g.edges[etype].data[attr] = edata[etype][attr]
249 | # except Exception as e:
250 | # print(e)
251 | # print(etype, attr)
252 | # print(new_g.edges(etype=etype))
253 | # print(edata[etype][attr])
254 | # raise Exception
255 | # print(new_g)
256 | return new_g
257 |
258 |
259 | def get_norm_adj(sparse_adj):
260 |     I = torch.diag(torch.ones(sparse_adj.shape[0])).to_sparse()  # identity matrix for self-loops
261 |     sparse_adj = sparse_adj + I
262 |     D = torch.diag(torch.sparse.sum(sparse_adj, dim=-1).pow(-1).to_dense())  # inverse degree matrix
263 |     # print(I)
264 |     # print(sparse_adj.to_dense())
265 |     return torch.matmul(D, sparse_adj.to_dense()).to_sparse()  # row-normalized adjacency D^-1 (A + I)
266 |
267 |
268 | def custom_node_unbatch(g):
269 |     # unbatch only the node attributes (used for contrastive learning); edge data is skipped for efficiency
270 | node_split = {ntype: g.batch_num_nodes(ntype) for ntype in g.ntypes}
271 | node_split = {k: dgl.backend.asnumpy(split).tolist() for k, split in node_split.items()}
272 | print(node_split)
273 | node_attr_dict = {ntype: {} for ntype in g.ntypes}
274 | for ntype in g.ntypes:
275 | for key, feat in g.nodes[ntype].data.items():
276 | subfeats = dgl.backend.split(feat, node_split[ntype], 0)
277 | # print(subfeats)
278 | node_attr_dict[ntype][key] = subfeats
279 | return node_attr_dict
280 |
281 | def get_label(citation):
282 | if citation < 10:
283 | return 0
284 | elif citation >= 100:
285 | return 2
286 | else:
287 | return 1
288 |
289 |
290 | def get_sampled_index(data):
291 | print('wip')
292 | df = pd.DataFrame(data['test'][1])
293 | df['label'] = df[0].apply(get_label)
294 | sampled_count = 1000
295 | print(df[df['label'] == 0])
296 | cls_0 = df[df['label'] == 0].sample(sampled_count, random_state=123).index
297 | cls_1 = df[df['label'] == 1].sample(sampled_count, random_state=123).index
298 | print(cls_0[:5], cls_1[:5])
299 | cls_2 = df[df['label'] == 2].sample(sampled_count, random_state=123).index
300 | all_index = list(cls_0) + list(cls_1) + list(cls_2)
301 | return all_index
302 |
303 |
304 | def get_selected_data(ids, dataProcessor):
305 | data = dataProcessor.get_data()
306 |
--------------------------------------------------------------------------------
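
The configuration and seeding helpers in scripts.py are typically combined at the start of an experiment. A minimal sketch, assuming the repository root is the working directory (the seed and model name below are illustrative; get_configs itself reads ./configs/<data_source>.json as shown above):

    from utilis.scripts import get_configs, setup_seed

    setup_seed(123)                            # seeds torch, numpy, random and dgl
    configs = get_configs('dblp', ['CTSGCN'])  # loads ./configs/dblp.json
    model_config = configs['CTSGCN']           # per-model keys override the defaults
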