├── .gitignore ├── LICENSE ├── README.md ├── checkpoints └── readme.md ├── configs ├── dblp.json └── pubmed.json ├── data └── readme.md ├── data_processor.py ├── layers ├── GCN │ ├── CL_layers.py │ ├── DTGCN_layers.py │ ├── RGCN.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── CL_layers.cpython-39.pyc │ │ ├── DTGCN_layers.cpython-39.pyc │ │ ├── RGCN.cpython-39.pyc │ │ ├── __init__.cpython-39.pyc │ │ └── custom_gcn.cpython-39.pyc │ └── custom_gcn.py ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-39.pyc │ └── common_layers.cpython-39.pyc └── common_layers.py ├── main.py ├── models ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-39.pyc │ └── base_model.cpython-39.pyc └── base_model.py ├── our_models ├── CL.py ├── CTSGCN.py ├── DCTSGCN.py ├── EncoderCL.py ├── FCL.py ├── HCL.py ├── TSGCN.py ├── __init__.py └── __pycache__ │ ├── CL.cpython-39.pyc │ ├── CTSGCN.cpython-39.pyc │ ├── DCTSGCN.cpython-39.pyc │ ├── FCL.cpython-39.pyc │ ├── HCL.cpython-39.pyc │ ├── TSGCN.cpython-39.pyc │ └── __init__.cpython-39.pyc ├── results └── readme.md └── utilis ├── __init__.py ├── __pycache__ ├── __init__.cpython-39.pyc ├── log_bar.cpython-39.pyc └── scripts.cpython-39.pyc ├── dblp.py ├── log_bar.py ├── pubmed.py └── scripts.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | pip-wheel-metadata/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # pipenv 89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 92 | # install all needed dependencies. 93 | #Pipfile.lock 94 | 95 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow 96 | __pypackages__/ 97 | 98 | # Celery stuff 99 | celerybeat-schedule 100 | celerybeat.pid 101 | 102 | # SageMath parsed files 103 | *.sage.py 104 | 105 | # Environments 106 | .env 107 | .venv 108 | env/ 109 | venv/ 110 | ENV/ 111 | env.bak/ 112 | venv.bak/ 113 | 114 | # Spyder project settings 115 | .spyderproject 116 | .spyproject 117 | 118 | # Rope project settings 119 | .ropeproject 120 | 121 | # mkdocs documentation 122 | /site 123 | 124 | # mypy 125 | .mypy_cache/ 126 | .dmypy.json 127 | dmypy.json 128 | 129 | # Pyre type checker 130 | .pyre/ 131 | # Project data and outputs 132 | /results/* 133 | /data/* 134 | /checkpoints/* 135 | 136 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner.
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # H2CGL 2 | H2CGL: Modeling Dynamics of Citation Network for Impact Prediction
3 | https://arxiv.org/abs/2305.01572 4 | 5 | ## Requirements 6 | 7 | * pytorch >= 1.10.2 8 | * numpy >= 1.13.3 9 | * scikit-learn 10 | * python 3.9 11 | * transformers 12 | * dgl == 0.9.1 13 | 14 | ## Usage 15 | Original datasets: 16 | * PMC (pubmed): https://www.ncbi.nlm.nih.gov/pmc/tools/openftlist/ 17 | * DBLPv13 (dblp): https://www.aminer.org/citation 18 | 19 | Download the preprocessed dataset from https://pan.baidu.com/s/1BkcIUiyMbdn9FR3hLdws1w&pwd=kr7h
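A minimal environment-setup sketch based on the requirement list above (the conda environment name and the exact PyTorch/DGL builds are assumptions; pick the wheels that match your CUDA setup):

```sh
# create and activate an isolated Python 3.9 environment (conda assumed; venv works too)
conda create -n h2cgl python=3.9 -y
conda activate h2cgl

# install the packages listed under Requirements
pip install torch==1.10.2 numpy scikit-learn transformers
pip install dgl==0.9.1   # match the pinned DGL version
```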
20 | 21 | You can get the H2CGL model here https://drive.google.com/file/d/1nJEgQdTenTssbZfjywe6poiOQKYvJpBo/view?usp=sharing 22 | 23 | ### Training 24 | ```sh 25 | # H2CGL 26 | python main.py --phase H2CGL --data_source h_pubmed/h_dblp --cl_type label_aug_hard_negative --aug_type cg --encoder_type 'CGIN+RGAT' --n_layers 4 --hn 2 --hn_method co_cite 27 | ``` 28 | 29 | ### Testing 30 | 31 | ```sh 32 | # H2CGL 33 | python main.py --phase test_results --model H2CGL --data_source h_pubmed/h_dblp --cl_type label_aug_hard_negative --aug_type cg --encoder_type 'CGIN+RGAT' --n_layers 4 --hn 2 --hn_method co_cite 34 | ``` 35 | -------------------------------------------------------------------------------- /checkpoints/readme.md: -------------------------------------------------------------------------------- 1 | Here place the input graphs and models. -------------------------------------------------------------------------------- /configs/dblp.json: -------------------------------------------------------------------------------- 1 | { 2 | "default": { 3 | "tokenizer_type": "basic", 4 | "tokenizer_path": null, 5 | "embed_dim": 300, 6 | "batch_size": 32, 7 | "num_classes": 3, 8 | "max_len": 256, 9 | "epochs": 20, 10 | "rate":0.8, 11 | "fixed_num": 120000, 12 | "optimizer": "ADAM", 13 | "criterion": "MSE", 14 | "lr": 1e-3, 15 | "weight_decay": 1e-4, 16 | "scheduler": false, 17 | "best_metric": "male", 18 | "dropout": 0.3, 19 | "saved_data": false, 20 | "fixed_data": false, 21 | "use_graph": false, 22 | "time": [2009, 2010, 2011], 23 | "cut_time": 2013, 24 | "type": "normal", 25 | "split_abstract": true, 26 | "save_model": true, 27 | "num_workers": 0, 28 | "batch_graph": false, 29 | "test_interval": 0, 30 | "graph_name": "graph_sample_feature_vector", 31 | "time_length": 5, 32 | "cut_threshold": [10, 100], 33 | "cl_weight": 0.5, 34 | "aux_weight": 0.5 35 | }, 36 | "TSGCN": { 37 | "embed_dim": 300, 38 | "hidden_dim": 256, 39 | "out_dim": 256, 40 | "n_layers": 4, 41 | "time_encoder": "self-attention", 42 | "time_encoder_layers": 4, 43 | "batch_size": 32, 44 | "use_graph": "repeatedly", 45 | "graph_name": "graph_sample_feature_vector", 46 | "etypes": [ 47 | "cites", "is cited by", "publishes", "is published by", 48 | "writes", "is writen by" 49 | ], 50 | "ntypes": ["author", "paper", "journal"], 51 | "time_length": 5, 52 | "hop": 1, 53 | "num_workers": 0, 54 | "lr": 1e-4 55 | }, 56 | "DCTSGCN": { 57 | "lr": 1e-4, 58 | "num_classes": 3, 59 | "embed_dim": 300, 60 | "hidden_dim": 256, 61 | "out_dim": 256, 62 | "n_layers": 3, 63 | "batch_size": 32, 64 | "use_graph": "repeatedly", 65 | "graph_name": "graph_sample_feature_vector", 66 | "linear_before_gcn": false, 67 | "time_length": 5, 68 | "hop": 1, 69 | "num_workers": 0, 70 | "etypes": [ 71 | "cites", "is cited by", "writes", "is writen by", "publishes", "is published by", 72 | "shows", "is shown in", "is in", "has", "t_cites", "is t_cited by" 73 | ], 74 | "ntypes": ["paper", "author", "journal", "time", "snapshot"], 75 | "pred_type": "snapshot", 76 | "edge_sequences": "basic" 77 | }, 78 | "H2CGL": { 79 | "lr": 1e-4, 80 | "num_classes": 3, 81 | "embed_dim": 300, 82 | "hidden_dim": 256, 83 | "out_dim": 256, 84 | "n_layers": 3, 85 | "batch_size": 32, 86 | "use_graph": "repeatedly", 87 | "graph_name": "graph_sample_feature_vector", 88 | "linear_before_gcn": false, 89 | "time_length": 5, 90 | "hop": 1, 91 | "num_workers": 0, 92 | "etypes": [ 93 | "cites", "is cited by", "writes", "is writen by", "publishes", "is published by", 94 | "shows", "is shown in", "is 
in", "has", "t_cites", "is t_cited by" 95 | ], 96 | "ntypes": ["paper", "author", "journal", "time", "snapshot"], 97 | "pred_type": "snapshot", 98 | "edge_sequences": "basic", 99 | "tau": 0.4 100 | } 101 | } -------------------------------------------------------------------------------- /configs/pubmed.json: -------------------------------------------------------------------------------- 1 | { 2 | "default": { 3 | "tokenizer_type": "basic", 4 | "tokenizer_path": null, 5 | "embed_dim": 300, 6 | "batch_size": 32, 7 | "num_classes": 3, 8 | "max_len": 256, 9 | "epochs": 30, 10 | "rate":0.8, 11 | "fixed_num": 150000, 12 | "optimizer": "ADAM", 13 | "criterion": "MSE", 14 | "lr": 1e-3, 15 | "weight_decay": 1e-5, 16 | "scheduler": false, 17 | "best_metric": "male", 18 | "dropout": 0.3, 19 | "saved_data": false, 20 | "fixed_data": false, 21 | "use_graph": false, 22 | "time": [2011, 2012, 2013], 23 | "type": "normal", 24 | "split_abstract": false, 25 | "save_model": true, 26 | "num_workers": 0, 27 | "batch_graph": false, 28 | "test_interval": 0, 29 | "graph_name": "graph_sample_feature_vector", 30 | "time_length": 5, 31 | "cut_threshold": [10, 100], 32 | "cut_time": 2015, 33 | "cl_weight": 0.5, 34 | "aux_weight": 0.5 35 | }, 36 | "TSGCN": { 37 | "embed_dim": 300, 38 | "hidden_dim": 256, 39 | "out_dim": 256, 40 | "n_layers": 4, 41 | "time_encoder": "self-attention", 42 | "time_encoder_layers": 4, 43 | "batch_size": 32, 44 | "use_graph": "repeatedly", 45 | "graph_name": "graph_sample_feature_vector", 46 | "etypes": [ 47 | "cites", "is cited by", "publishes", "is published by", 48 | "writes", "is writen by" 49 | ], 50 | "ntypes": ["author", "paper", "journal"], 51 | "time_length": 5, 52 | "hop": 1, 53 | "num_workers": 0, 54 | "lr": 1e-4 55 | }, 56 | "DCTSGCN": { 57 | "lr": 1e-4, 58 | "num_classes": 3, 59 | "embed_dim": 300, 60 | "hidden_dim": 256, 61 | "out_dim": 256, 62 | "n_layers": 3, 63 | "batch_size": 32, 64 | "use_graph": "repeatedly", 65 | "graph_name": "graph_sample_feature_vector", 66 | "time_length": 5, 67 | "hop": 1, 68 | "num_workers": 0, 69 | "linear_before_gcn": false, 70 | "etypes": [ 71 | "cites", "is cited by", "publishes", "is published by", 72 | "shows", "is shown in", "is in", "has", "t_cites", "is t_cited by" 73 | ], 74 | "ntypes": ["paper", "author", "journal", "time", "snapshot"], 75 | "pred_type": "snapshot", 76 | "edge_sequences": "basic" 77 | }, 78 | "H2CGL": { 79 | "lr": 1e-4, 80 | "embed_dim": 300, 81 | "hidden_dim": 256, 82 | "out_dim": 256, 83 | "n_layers": 3, 84 | "batch_size": 32, 85 | "use_graph": "repeatedly", 86 | "graph_name": "graph_sample_feature_vector", 87 | "time_length": 5, 88 | "hop": 1, 89 | "num_workers": 0, 90 | "linear_before_gcn": false, 91 | "etypes": [ 92 | "cites", "is cited by", "publishes", "is published by", 93 | "shows", "is shown in", "is in", "has", "t_cites", "is t_cited by" 94 | ], 95 | "ntypes": ["paper", "author", "journal", "time", "snapshot"], 96 | "pred_type": "snapshot", 97 | "edge_sequences": "basic", 98 | "tau": 0.4 99 | } 100 | } -------------------------------------------------------------------------------- /data/readme.md: -------------------------------------------------------------------------------- 1 | Here place the data. 
-------------------------------------------------------------------------------- /layers/GCN/CL_layers.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import dgl.nn.pytorch as dglnn 8 | import dgl 9 | from tqdm import tqdm 10 | 11 | from layers.common_layers import MLP 12 | from utilis.scripts import add_new_elements 13 | 14 | 15 | class CommonCLLayer(nn.Module): 16 | def __init__(self, in_feats, out_feats, hidden_feats=None, tau=0.4): 17 | super(CommonCLLayer, self).__init__() 18 | if hidden_feats is None: 19 | hidden_feats = in_feats 20 | # print(in_feats, out_feats, hidden_feats) 21 | self.proj_head = MLP(in_feats, out_feats, dim_inner=hidden_feats, activation=nn.ELU()) 22 | self.tau = tau 23 | 24 | def forward(self, z1, z2): 25 | h1, h2 = F.normalize(self.proj_head(z1)), F.normalize(self.proj_head(z2)) 26 | s12 = cos_sim(h1, h2, self.tau, norm=False) 27 | if s12.shape[0] == 0: 28 | s12 = torch.tensor([1]).reshape(1, 1).to(h1.device) 29 | s21 = s12 30 | 31 | # compute InfoNCE 32 | loss12 = -torch.log(s12.diag()) + torch.log(s12.sum(1)) 33 | loss21 = -torch.log(s21.diag()) + torch.log(s21.sum(1)) 34 | L_node = (loss12 + loss21) / 2 35 | 36 | 37 | return L_node.mean() 38 | 39 | 40 | class ProjectionHead(nn.Module): 41 | def __init__(self, in_feats, out_feats, hidden_feats=None): 42 | super(ProjectionHead, self).__init__() 43 | if hidden_feats is None: 44 | hidden_feats = in_feats 45 | self.proj_head = MLP(in_feats, out_feats, dim_inner=hidden_feats, activation=nn.ELU()) 46 | 47 | def forward(self, z): 48 | return F.normalize(self.proj_head(z)) 49 | 50 | 51 | def simple_cl(z1, z2, tau, valid_lens=None): 52 | h1, h2 = z1, z2 53 | s12 = cos_sim(h1, h2, tau, norm=False) 54 | if s12.shape[0] == 0: 55 | s12 = torch.tensor([1]).reshape(1, 1).to(h1.device) 56 | s21 = s12 57 | 58 | # compute InfoNCE 59 | if s12.dim() > 2: 60 | # print(s12.shape) 61 | num_node = s12.shape[-1] 62 | # loss12 = -torch.log(torch.flatten(torch.diagonal(s12, dim2=-2, dim1=-1))) + \ 63 | # torch.log(s12.reshape(-1, num_node).sum(1)) 64 | # loss21 = -torch.log(torch.flatten(torch.diagonal(s21, dim2=-2, dim1=-1))) + \ 65 | # torch.log(s21.reshape(-1, num_node).sum(1)) 66 | # print(loss21.shape) 67 | loss12 = -torch.log(torch.diagonal(s12, dim2=-2, dim1=-1)) + \ 68 | torch.log(s12.sum(-1)) 69 | loss21 = -torch.log(torch.diagonal(s21, dim2=-2, dim1=-1)) + \ 70 | torch.log(s21.sum(-1)) 71 | else: 72 | loss12 = -torch.log(s12.diag()) + torch.log(s12.sum(1)) 73 | loss21 = -torch.log(s21.diag()) + torch.log(s21.sum(1)) 74 | L_node = (loss12 + loss21) / 2 75 | 76 | if type(valid_lens) == torch.Tensor: 77 | # print(valid_lens) 78 | # print(L_node.sum(dim=-1).shape) 79 | L_node = L_node.sum(dim=-1) / valid_lens 80 | # print(L_node.mean()) 81 | return L_node.mean() 82 | 83 | def combined_sparse(matrices): 84 | indices = [] 85 | values = torch.cat([torch.flatten(matrix) for matrix in matrices]) 86 | # print(values) 87 | x, y = 0, 0 88 | for matrix in matrices: 89 | cur_x, cur_y = matrix.shape 90 | for i in range(x, x + cur_x): 91 | for j in range(y, y + cur_y): 92 | indices.append([i, j]) 93 | # values = torch.cat((values, torch.flatten(matrix))) 94 | x += cur_x 95 | y += cur_y 96 | # print(len(indices)) 97 | # print(indices) 98 | # print(values.shape) 99 | return torch.sparse_coo_tensor(indices=torch.tensor(indices, device=matrices[0].device).T, 100 | values=values, 
size=(x, y), requires_grad=True) 101 | 102 | 103 | def cos_sim(x1, x2, tau: float, norm: bool = False): 104 | if x1.is_sparse: 105 | if norm: 106 | return torch.softmax(torch.sparse.mm(x1, x2.transpose(0, 1)).to_dense() / tau, dim=1) 107 | else: 108 | return torch.exp(torch.sparse.mm(x1, x2.transpose(0, 1)).to_dense() / tau) 109 | elif x1.dim() > 2: 110 | result = x1 @ x2.transpose(-2, -1) / tau 111 | if norm: 112 | return torch.softmax(result, dim=-1) 113 | else: 114 | return torch.exp(result) 115 | else: 116 | if norm: 117 | return torch.softmax(x1 @ x2.T / tau, dim=1) 118 | else: 119 | return torch.exp(x1 @ x2.T / tau) 120 | 121 | 122 | def RBF_sim(x1, x2, tau: float, norm: bool = False): 123 | xx1, xx2 = torch.stack(len(x2) * [x1]), torch.stack(len(x1) * [x2]) 124 | sub = xx1.transpose(0, 1) - xx2 125 | R = (sub * sub).sum(dim=2) 126 | if norm: 127 | return F.softmax(-R / (tau * tau), dim=1) 128 | else: 129 | return torch.exp(-R / (tau * tau)) 130 | 131 | 132 | def augmenting_graphs(graphs, aug_configs, aug_methods=(('attr_drop',), ('edge_drop',))): 133 | ''' 134 | :param aug_configs: 135 | :param graphs: 136 | :param aug_methods: 137 | :return: 138 | ''' 139 | g1_list = [] 140 | g2_list = [] 141 | for graph in graphs: 142 | # g1, g2 = copy.deepcopy(graph), copy.deepcopy(graph) 143 | g1, g2 = add_new_elements(graph), add_new_elements(graph) 144 | for aug_method in aug_methods[0]: 145 | cur_method = eval(aug_method) 146 | g1 = cur_method(g1, aug_configs[0][aug_method]['prob'], aug_configs[0][aug_method].get('types', None)) 147 | for aug_method in aug_methods[1]: 148 | cur_method = eval(aug_method) 149 | g2 = cur_method(g2, aug_configs[1][aug_method]['prob'], aug_configs[1][aug_method].get('types', None)) 150 | g1_list.append(g1) 151 | g2_list.append(g2) 152 | 153 | del graphs 154 | return g1_list, g2_list 155 | 156 | 157 | def augmenting_graphs_ctsgcn(graphs, aug_configs, aug_methods=(('attr_drop',), ('edge_drop',)), time_length=10): 158 | ''' 159 | :param time_length: 160 | :param aug_configs: 161 | :param graphs: 162 | :param aug_methods: 163 | :return: 164 | ''' 165 | g1_list = [] 166 | g2_list = [] 167 | for graph in graphs: 168 | # g1, g2 = copy.deepcopy(graph), copy.deepcopy(graph) 169 | g1, g2 = add_new_elements(graph), add_new_elements(graph) 170 | for t in range(time_length): 171 | # g1 172 | for aug_method in aug_methods[0]: 173 | selected_nodes = {} 174 | for ntype in graph.ntypes: 175 | selected_nodes[ntype] = g1.nodes[ntype].data['snapshot_idx'] == t 176 | if torch.sum(selected_nodes['paper']).item() > 0: 177 | cur_method = eval(aug_method) 178 | g1 = cur_method(g1, aug_configs[0][aug_method]['prob'], aug_configs[0][aug_method].get('types', None), 179 | selected_nodes) 180 | # g2 181 | for aug_method in aug_methods[1]: 182 | selected_nodes = {} 183 | for ntype in graph.ntypes: 184 | selected_nodes[ntype] = g2.nodes[ntype].data['snapshot_idx'] == t 185 | if torch.sum(selected_nodes['paper']).item() > 0: 186 | cur_method = eval(aug_method) 187 | g2 = cur_method(g2, aug_configs[1][aug_method]['prob'], aug_configs[1][aug_method].get('types', None), 188 | selected_nodes) 189 | g1_list.append(g1) 190 | g2_list.append(g2) 191 | 192 | del graphs 193 | return g1_list, g2_list 194 | 195 | 196 | def augmenting_graphs_single(graphs, aug_config, aug_methods=(('attr_drop',), ('edge_drop',))): 197 | ''' 198 | :param aug_configs: 199 | :param graphs: 200 | :param aug_methods: 201 | :return: 202 | ''' 203 | graph_list = [] 204 | for graph in graphs: 205 | # dealt_graph = 
copy.deepcopy(graph) 206 | dealt_graph = add_new_elements(graph) 207 | for aug_method in aug_methods: 208 | cur_method = eval(aug_method) 209 | dealt_graph = cur_method(dealt_graph, aug_config[aug_method]['prob'], 210 | aug_config[aug_method].get('types', None)) 211 | graph_list.append(dealt_graph) 212 | return graph_list 213 | 214 | 215 | def augmenting_graphs_single_ctsgcn(graphs, aug_config, aug_methods=(('attr_drop',), ('edge_drop',)), time_length=10): 216 | ''' 217 | :param aug_configs: 218 | :param graphs: 219 | :param aug_methods: 220 | :return: 221 | ''' 222 | graph_list = [] 223 | for graph in graphs: 224 | # dealt_graph = copy.deepcopy(graph) 225 | dealt_graph = add_new_elements(graph) 226 | for aug_method in aug_methods: 227 | cur_method = eval(aug_method) 228 | for t in range(time_length): 229 | selected_nodes = {} 230 | for ntype in graph.ntypes: 231 | selected_nodes[ntype] = graph.nodes[ntype].data['snapshot_idx'] == t 232 | if torch.sum(selected_nodes['paper']).item() == 0: 233 | continue 234 | dealt_graph = cur_method(dealt_graph, aug_config[aug_method]['prob'], 235 | aug_config[aug_method].get('types', None), selected_nodes) 236 | graph_list.append(dealt_graph) 237 | return graph_list 238 | 239 | 240 | def attr_mask(graph, prob, ntypes=None, selected_nodes=None): 241 | if ntypes is None: 242 | ntypes = graph.ntypes 243 | if selected_nodes: 244 | for ntype in ntypes: 245 | emb_dim = graph.nodes[ntype].data['h'].shape[-1] 246 | p = torch.bernoulli(torch.tensor([1 - prob] * emb_dim)).to(torch.bool) 247 | graph.nodes[ntype].data['h'][selected_nodes[ntype]][:, p] = 0. 248 | else: 249 | for ntype in ntypes: 250 | emb_dim = graph.nodes[ntype].data['h'].shape[-1] 251 | p = torch.bernoulli(torch.tensor([1 - prob] * emb_dim)).to(torch.bool) 252 | graph.nodes[ntype].data['h'][:, p] = 0. 
253 | return graph 254 | 255 | 256 | def attr_drop(graph, prob, ntypes=None, selected_nodes=None): 257 | if ntypes is None: 258 | ntypes = graph.ntypes 259 | if selected_nodes: 260 | for ntype in ntypes: 261 | node_count = graph.nodes[ntype].data['h'][selected_nodes[ntype]].shape[0] 262 | emb_dim = graph.nodes[ntype].data['h'].shape[-1] 263 | p = np.random.rand(node_count) 264 | drop_nodes = torch.where(selected_nodes[ntype])[0][np.where(p <= prob)[0]] 265 | graph.nodes[ntype].data['h'][drop_nodes, :] = torch.zeros(len(drop_nodes), emb_dim) 266 | else: 267 | for ntype in ntypes: 268 | node_count = graph.nodes[ntype].data['h'].shape[0] 269 | emb_dim = graph.nodes[ntype].data['h'].shape[-1] 270 | p = np.random.rand(node_count) 271 | drop_nodes = np.where(p <= prob)[0] 272 | graph.nodes[ntype].data['h'][drop_nodes, :] = torch.zeros(len(drop_nodes), emb_dim) 273 | return graph 274 | 275 | 276 | def edge_drop(graph, prob, etypes=None, selected_nodes=None): 277 | # graph = copy.deepcopy(graph) 278 | if etypes is None: 279 | etypes = graph.canonical_etypes 280 | if selected_nodes: 281 | # print('wip') 282 | for etype in etypes: 283 | dst_type = etype[-1] 284 | dst_nodes = set(graph.nodes(dst_type)[selected_nodes[dst_type]].numpy().tolist()) 285 | # print(graph.edges(form='all', etype=etype, order='eid')) 286 | src, dst, edges = graph.edges(form='all', etype=etype, order='eid') 287 | dst_index = [] 288 | # print(edges) 289 | for i in range(len(dst)): 290 | if dst[i].item() in dst_nodes: 291 | dst_index.append(i) 292 | edges = edges[dst_index] 293 | # print(dst_nodes, dst_index) 294 | 295 | p = np.random.rand(edges.shape[0]) 296 | # print(edges[np.where(p < 0.2)]) 297 | drop_edges = edges[np.where(p <= prob)] 298 | graph.remove_edges(drop_edges, etype=etype) 299 | else: 300 | for etype in etypes: 301 | edges = graph.edges(form='eid', etype=etype) 302 | p = np.random.rand(edges.shape[0]) 303 | # print(edges[np.where(p < 0.2)]) 304 | drop_edges = edges[np.where(p <= prob)] 305 | graph.remove_edges(drop_edges, etype=etype) 306 | # print(graph.edges(form='eid', etype='cites')) 307 | return graph 308 | 309 | def node_drop(graph, prob, ntypes=None, selected_nodes=None): 310 | # drop nodes according to the citation, so just drop the paper, not including other nodes and the target node 311 | selected_index = (1 - graph.nodes['paper'].data['is_target']).to(torch.bool) 312 | if selected_nodes: 313 | selected_nodes = selected_nodes['paper'] 314 | selected_index = selected_index & selected_nodes 315 | 316 | count = torch.sum(selected_index).item() 317 | # print(count) 318 | if count > 0: 319 | drop_rate = count * prob 320 | cur_p = torch.softmax(-graph.nodes['paper'].data['citations'][selected_index].float(), dim=-1).numpy() * drop_rate 321 | drop_indicator = np.random.rand(count) < cur_p 322 | drop_nodes = torch.where(selected_index)[0][drop_indicator] 323 | graph.remove_nodes(drop_nodes, ntype='paper') 324 | return graph 325 | 326 | 327 | if __name__ == '__main__': 328 | cl = CommonCLLayer(128, 64) 329 | z1 = torch.randn(10, 128) 330 | z2 = torch.randn(10, 128) 331 | print(cl(z1, z2)) 332 | 333 | print(eval('attr_drop')) 334 | -------------------------------------------------------------------------------- /layers/GCN/DTGCN_layers.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | from collections import defaultdict 4 | 5 | import dgl 6 | import torch 7 | import torch.nn as nn 8 | import dgl.nn.pytorch as dglnn 9 | from 
dgl.nn.pytorch.hetero import get_aggregate_fn 10 | 11 | from layers.GCN.RGCN import StochasticKLayerRGCN, CustomGCN 12 | 13 | 14 | class DCTSGCNLayer(nn.Module): 15 | def __init__(self, in_feat, hidden_feat, out_feat=None, ntypes=None, edge_sequences=None, k=3, time_length=10): 16 | super().__init__() 17 | # T * k layers 18 | self.k = k 19 | self.time_length = time_length 20 | self.ntypes = ntypes 21 | if out_feat is None: 22 | out_feat = hidden_feat 23 | if edge_sequences: 24 | forward_sequences, backward_sequences = edge_sequences 25 | self.concat_fc = nn.ModuleList( 26 | [dglnn.HeteroLinear({ntype: hidden_feat * 2 for ntype in ntypes}, hidden_feat)]) 27 | # print('fc', self.concat_fc[0].linears['author'].weight.device) 28 | self.forward_layers = nn.ModuleList([DirectedCTSGCNLayer(ntypes, time_length, forward_sequences, 29 | in_feat, hidden_feat)]) 30 | self.backward_layers = nn.ModuleList([DirectedCTSGCNLayer(ntypes, time_length, backward_sequences, 31 | in_feat, hidden_feat)]) 32 | for i in range(k - 2): 33 | self.forward_layers.append(DirectedCTSGCNLayer(ntypes, time_length, forward_sequences, 34 | hidden_feat, hidden_feat)) 35 | self.backward_layers.append(DirectedCTSGCNLayer(ntypes, time_length, backward_sequences, 36 | hidden_feat, hidden_feat)) 37 | self.concat_fc.append(dglnn.HeteroLinear({ntype: hidden_feat * 2 for ntype in ntypes}, hidden_feat)) 38 | self.forward_layers.append(DirectedCTSGCNLayer(ntypes, time_length, forward_sequences, 39 | hidden_feat, out_feat)) 40 | self.backward_layers.append(DirectedCTSGCNLayer(ntypes, time_length, backward_sequences, 41 | hidden_feat, out_feat)) 42 | self.concat_fc.append(dglnn.HeteroLinear({ntype: out_feat * 2 for ntype in ntypes}, out_feat)) 43 | 44 | # self.concat_fc = self.concat_fc.to(device) 45 | # print('fc', self.concat_fc[0].linears['author'].weight.device) 46 | # self.forward_layers = self.forward_layers.to(device) 47 | # self.backward_layers = self.backward_layers.to(device) 48 | 49 | def forward(self, snapshots, input_feats): 50 | hidden_embs = input_feats 51 | # hidden_embs = snapshots.srcdata['h'] 52 | for i in range(self.k): 53 | # print('-'*30, i, '-'*30) 54 | # print([emb.shape for emb in hidden_embs[0].values()]) 55 | forward_embs = self.forward_layers[i](snapshots, hidden_embs) 56 | # print('=' * 61) 57 | # print([emb.shape for emb in hidden_embs[0].values()]) 58 | backward_embs = self.backward_layers[i](snapshots, hidden_embs) 59 | # out_embs = [ 60 | # {ntype: torch.cat((forward_embs[t][ntype], backward_embs[t][ntype]), dim=-1) for ntype in self.ntypes} 61 | # for t in range(self.time_length)] 62 | out_embs = {ntype: torch.cat((forward_embs[ntype], backward_embs[ntype]), dim=-1) for ntype in self.ntypes} 63 | # print('here', out_embs[0]) 64 | # hidden_embs = [self.concat_fc[i](out_embs[t]) for t in range(self.time_length)] 65 | hidden_embs = self.concat_fc[i](out_embs) 66 | 67 | return hidden_embs 68 | 69 | def to(self, *args, **kwargs): 70 | self = super().to(*args, **kwargs) 71 | print('whole graph_encoder to device') 72 | if self.concat_fc: 73 | self.concat_fc = self.concat_fc.to(*args, **kwargs) 74 | for i in range(len(self.forward_layers)): 75 | self.forward_layers[i] = self.forward_layers[i].to(*args, **kwargs) 76 | # print(self.forward_layers) 77 | for i in range(len(self.backward_layers)): 78 | self.backward_layers[i] = self.backward_layers[i].to(*args, **kwargs) 79 | return self 80 | 81 | 82 | class DirectedCTSGCNLayer(nn.Module): 83 | def __init__(self, ntypes, time_length, edge_sequences, in_feat, 
hidden_feat, out_feat=None, agg_method='sum', 84 | activation=torch.relu): 85 | super(DirectedCTSGCNLayer, self).__init__() 86 | if not out_feat: 87 | out_feat = hidden_feat 88 | self.time_length = time_length 89 | # meta, paper, to_snapshot, snapshot 90 | self.sequence = [edges_type for edges_type, _ in edge_sequences] 91 | self.edges_type = {edges_type: etypes for edges_type, etypes in edge_sequences} 92 | self.agg_method = get_aggregate_fn(agg_method) 93 | self.snapshot_weight = 'none' 94 | self.convs = nn.ModuleDict({}) 95 | self.activation = activation 96 | self.skip_connect_fc = nn.ModuleDict({edges_type: nn.ModuleDict({}) for edges_type in self.edges_type}) 97 | 98 | trans_nodes = {} 99 | input_dim = {ntype: in_feat for ntype in ntypes} 100 | # if self.sequence[0] != 'snapshot': 101 | # ntypes = [stype for stype, _, _ in edge_sequences[0][-1]] 102 | # trans_nodes = {ntype: in_feat for ntype in ntypes} 103 | # for ntype in ntypes: 104 | # if ntype not in trans_nodes: 105 | # trans_nodes[ntype] = out_feat 106 | for i in range(len(edge_sequences)): 107 | edges_type, etypes = edge_sequences[i] 108 | if edges_type != 'intra_snapshots': 109 | # if i == 0: 110 | # self.convs[edges_type] = nn.ModuleList([dglnn.HeteroGraphConv( 111 | # {etype: dglnn.GraphConv(in_feat, hidden_feat, norm='right') 112 | # for etype in etypes}) for t in range(time_length)]) 113 | # elif i == len(edge_sequences) - 1: 114 | # self.convs[edges_type] = nn.ModuleList([dglnn.HeteroGraphConv( 115 | # {etype: dglnn.GraphConv(hidden_feat, out_feat, norm='right') 116 | # for etype in etypes}) for t in range(time_length)]) 117 | # else: 118 | # self.convs[edges_type] = nn.ModuleList([dglnn.HeteroGraphConv( 119 | # {etype: dglnn.GraphConv(hidden_feat, hidden_feat, norm='right') 120 | # for etype in etypes}) for t in range(time_length)]) 121 | if i == 0: 122 | self.convs[edges_type] = nn.ModuleDict({str(etype): 123 | dglnn.GraphConv(in_feat, hidden_feat, norm='right', 124 | allow_zero_in_degree=True) 125 | for etype in etypes}) 126 | cur_out = hidden_feat 127 | # print(list(self.convs[edges_type].values())[0][0].weight.device) 128 | elif i == len(edge_sequences) - 1: 129 | self.convs[edges_type] = nn.ModuleDict({str(etype): 130 | dglnn.GraphConv(hidden_feat, out_feat, norm='right', 131 | allow_zero_in_degree=True) 132 | for etype in etypes}) 133 | cur_out = out_feat 134 | else: 135 | self.convs[edges_type] = nn.ModuleDict({ 136 | str(etype): 137 | dglnn.GraphConv(hidden_feat, hidden_feat, norm='right', allow_zero_in_degree=True) 138 | for etype in etypes}) 139 | cur_out = hidden_feat 140 | for etype in etypes: 141 | trans_nodes[etype[-1]] = cur_out 142 | self.skip_connect_fc[edges_type][etype[-1]] = nn.Linear(input_dim[etype[-1]], cur_out) 143 | for etype in etypes: 144 | input_dim[etype[-1]] = cur_out 145 | else: 146 | if self.snapshot_weight: 147 | snapshot_norm = 'none' 148 | else: 149 | snapshot_norm = 'right' 150 | if i == 0: 151 | self.convs[edges_type] = dglnn.GraphConv(in_feat, hidden_feat, norm=snapshot_norm) 152 | cur_out = hidden_feat 153 | else: 154 | self.convs[edges_type] = dglnn.GraphConv(hidden_feat, out_feat, norm=snapshot_norm) 155 | cur_out = out_feat 156 | trans_nodes['snapshot'] = cur_out 157 | input_dim['snapshot'] = cur_out 158 | # self.skip_connect_fc[edges_type]['snapshot'] = nn.Linear(cur_in, cur_out) 159 | for ntype in ntypes: 160 | if ntype not in trans_nodes: 161 | trans_nodes[ntype] = in_feat 162 | # self.trans_fc = None 163 | self.trans_fc = dglnn.HeteroLinear(trans_nodes, out_feat) 164 | 165 | 
# def to(self, *args, **kwargs): 166 | # self = super().to(*args, **kwargs) 167 | # # for edges_type in self.convs: 168 | # # print(edges_type + ' to device') 169 | # # self.convs[edges_type] = self.convs[edges_type].to(*args, **kwargs) 170 | # # self.convs = self.convs.to(*args, **kwargs) 171 | # return self 172 | 173 | def normal_forward(self, snapshots, input_embs, edges_type): 174 | # conv_embs = [] 175 | outputs = defaultdict(list) 176 | all_embs = input_embs 177 | for etype in self.edges_type[edges_type]: 178 | # print('etype', etype) 179 | stype, _, dtype = etype 180 | subgraph = snapshots[tuple(etype)] 181 | # subgraph = dgl.add_self_loop(subgraph) 182 | # print(etype, subgraph.in_degrees()) 183 | srcdata = all_embs[stype] 184 | dstdata = all_embs[dtype] 185 | outputs[dtype].append(self.convs[edges_type][str(etype)](subgraph, (srcdata, dstdata))) 186 | for dtype in outputs: 187 | if len(outputs[dtype]) > 0: 188 | all_embs[dtype] = self.skip_connect_fc[edges_type][dtype](all_embs[dtype]) \ 189 | + self.agg_method(outputs[dtype], dtype) 190 | # conv_embs.append(all_embs) 191 | 192 | return all_embs 193 | 194 | def snapshot_forward(self, input_embs, snapshot_graphs): 195 | ndata = input_embs['snapshot'] 196 | subgraph = snapshot_graphs[tuple(self.edges_type['intra_snapshots'])] 197 | if self.snapshot_weight: 198 | norm = dglnn.EdgeWeightNorm(norm=self.snapshot_weight) 199 | norm_edge_weight = norm(subgraph, subgraph.edata['w']) 200 | # print(ndata.shape) 201 | # print(self.convs['intra_snapshots'].weight.shape) 202 | # print(ndata) 203 | # print(self.edges_type) 204 | conv_emb = self.convs['intra_snapshots'](subgraph, (ndata, ndata), edge_weight=norm_edge_weight) 205 | else: 206 | conv_emb = self.convs['intra_snapshots'](subgraph, (ndata, ndata)) 207 | # dstdata = torch.split(conv_emb, batch_size, dim=0) # T[B, d] 208 | # for t in range(self.time_length): 209 | # input_embs[t]['snapshot'] = dstdata[t] 210 | input_embs['snapshot'] = conv_emb 211 | return input_embs 212 | 213 | def forward(self, snapshots, input_embs): 214 | # print('-'*59) 215 | # hidden_embs = [snapshot.srcdata['h'] for snapshot in snapshots] 216 | # copy for dual-direction 217 | input_embs = input_embs.copy() 218 | hidden_embs = input_embs 219 | # print(input_embs[0]['snapshot']) 220 | for edges_type in self.sequence: 221 | if edges_type != 'intra_snapshots': 222 | hidden_embs = self.normal_forward(snapshots, hidden_embs, edges_type) 223 | else: 224 | hidden_embs = self.snapshot_forward(hidden_embs, snapshots) 225 | if self.activation: 226 | for ntype in hidden_embs.keys(): 227 | hidden_embs[ntype] = self.activation(hidden_embs[ntype]) 228 | # 共用一个MLP 229 | hidden_embs = self.trans_fc(hidden_embs) 230 | # T[nodes.data] 231 | return hidden_embs 232 | 233 | 234 | class SimpleDCTSGCNLayer(DCTSGCNLayer): 235 | def __init__(self, in_feat, hidden_feat, out_feat=None, ntypes=None, edge_sequences=None, k=3, time_length=10, 236 | combine_method='concat', encoder_type='GCN', encoder_dict=None, return_list=False, snapshot_weight=None, 237 | **kwargs): 238 | super().__init__(in_feat, hidden_feat, out_feat, ntypes, None, k, time_length) 239 | print(encoder_type, encoder_dict) 240 | # T * k layers 241 | self.k = k 242 | self.time_length = time_length 243 | self.ntypes = ntypes 244 | if out_feat is None: 245 | out_feat = hidden_feat 246 | self.fc_in = dglnn.HeteroLinear({ntype: in_feat for ntype in ntypes}, hidden_feat) 247 | # self.fc_out = dglnn.HeteroLinear({ntype: hidden_feat for ntype in ntypes}, out_feat) 248 | self.fc_out = 
nn.ModuleList([dglnn.HeteroLinear({ntype: hidden_feat for ntype in ntypes}, out_feat) 249 | for i in range(k)]) 250 | self.return_list = return_list 251 | 252 | print(combine_method) 253 | self.concat_fc = None 254 | if combine_method == 'concat': 255 | self.concat_fc = nn.ModuleList( 256 | [dglnn.HeteroLinear({ntype: hidden_feat * 2 for ntype in ntypes}, hidden_feat) 257 | for i in range(k)]) 258 | # print('fc', self.concat_fc[0].linears['author'].weight.device) 259 | if len(edge_sequences) > 1: 260 | self.dual_view = True 261 | forward_sequences, backward_sequences = edge_sequences 262 | self.forward_layers = nn.ModuleList([SimpleDirectedCTSGCNLayer(time_length, forward_sequences, hidden_feat, 263 | encoder_type=encoder_type, 264 | encoder_dict=encoder_dict, 265 | snapshot_weight=snapshot_weight, **kwargs) 266 | for i in range(k)]) 267 | self.backward_layers = nn.ModuleList([SimpleDirectedCTSGCNLayer(time_length, forward_sequences, hidden_feat, 268 | encoder_type=encoder_type, 269 | encoder_dict=encoder_dict, 270 | snapshot_weight=snapshot_weight, **kwargs) 271 | for i in range(k)]) 272 | else: 273 | self.dual_view = False 274 | single_sequences = edge_sequences[0] 275 | self.forward_layers = [] 276 | self.backward_layers = [] 277 | self.single_layers = nn.ModuleList([SimpleDirectedCTSGCNLayer(time_length, single_sequences, hidden_feat, 278 | encoder_type=encoder_type, 279 | encoder_dict=encoder_dict, 280 | snapshot_weight=snapshot_weight, **kwargs) 281 | for i in range(k)]) 282 | print('dual view:', self.dual_view) 283 | 284 | def forward(self, snapshots, hidden_embs, get_attention=False): 285 | # hidden_embs = snapshots.srcdata['h'] 286 | output_list = [] 287 | attn_list = [] 288 | hidden_embs = self.fc_in(hidden_embs) 289 | for i in range(self.k): 290 | if self.dual_view: 291 | forward_embs = self.forward_layers[i](snapshots, hidden_embs) 292 | backward_embs = self.backward_layers[i](snapshots, hidden_embs) 293 | if self.concat_fc: 294 | out_embs = {ntype: torch.cat((forward_embs[ntype], backward_embs[ntype]), dim=-1) for ntype in 295 | self.ntypes} 296 | hidden_embs = self.concat_fc[i](out_embs) 297 | else: 298 | hidden_embs = {ntype: forward_embs[ntype] + backward_embs[ntype] for ntype in 299 | self.ntypes} 300 | else: 301 | if get_attention: 302 | hidden_embs, attn = self.single_layers[i](snapshots, hidden_embs, get_attention=True) 303 | attn_list.append(attn) 304 | else: 305 | hidden_embs = self.single_layers[i](snapshots, hidden_embs) 306 | if self.return_list: 307 | output_list.append(self.fc_out[i](hidden_embs)) 308 | hidden_embs = self.fc_out[-1](hidden_embs) 309 | # for key in hidden_embs: 310 | # print(hidden_embs[key].shape) 311 | if get_attention: 312 | return hidden_embs, output_list, attn_list 313 | else: 314 | return hidden_embs, output_list 315 | 316 | 317 | class SimpleDirectedCTSGCNLayer(nn.Module): 318 | def __init__(self, time_length, edge_sequences, hidden_feat, agg_method='sum', encoder_type='GCN', 319 | encoder_dict=None, snapshot_weight=None, 320 | activation=nn.LeakyReLU(), **kwargs): 321 | # simplified version 322 | super(SimpleDirectedCTSGCNLayer, self).__init__() 323 | self.time_length = time_length 324 | # meta, paper, to_snapshot, snapshot 325 | print(len(edge_sequences)) 326 | self.sequence = [edges_type for edges_type, _ in edge_sequences] 327 | self.edges_type = {edges_type: etypes for edges_type, etypes in edge_sequences} 328 | self.agg_method = get_aggregate_fn(agg_method) 329 | self.snapshot_weight = snapshot_weight 330 | # self.snapshot_weight = 
None 331 | print('snapshot_weight', self.snapshot_weight) 332 | self.convs = nn.ModuleDict({}) 333 | self.activation = activation 334 | 335 | for i in range(len(edge_sequences)): 336 | edges_type, etypes = edge_sequences[i] 337 | if edges_type != 'intra_snapshots': 338 | # self.convs[edges_type] = nn.ModuleDict({str(etype): 339 | # dglnn.GraphConv(hidden_feat, hidden_feat, norm='right', 340 | # allow_zero_in_degree=True) 341 | # for etype in etypes}) 342 | # self.convs[edges_type] = nn.ModuleDict({str(etype): 343 | # CustomGCN(encoder_type, hidden_feat, hidden_feat, layer=i+1, 344 | # activation=activation, **kwargs) 345 | # for etype in etypes}) 346 | self.convs[edges_type] = nn.ModuleDict({}) 347 | for etype in etypes: 348 | if etype not in encoder_dict: 349 | cur_encoder_type = encoder_type 350 | encoder_args = {} 351 | else: 352 | cur_encoder_type = encoder_dict[etype]['cur_type'] 353 | encoder_args = encoder_dict[etype] 354 | self.convs[edges_type][str(etype)] = CustomGCN(cur_encoder_type, hidden_feat, hidden_feat, 355 | layer=i + 1, activation=activation, **encoder_args) 356 | 357 | else: 358 | if self.snapshot_weight: 359 | self.snapshot_weight = 'none' 360 | else: 361 | self.snapshot_weight = 'right' 362 | self.convs[edges_type] = nn.ModuleDict({}) 363 | for etype in etypes: 364 | if etype not in encoder_dict: 365 | cur_encoder_type = encoder_type 366 | encoder_args = {} 367 | else: 368 | cur_encoder_type = encoder_dict[etype]['cur_type'] 369 | encoder_args = encoder_dict[etype] 370 | print(encoder_args) 371 | self.convs[edges_type][str(etype)] = CustomGCN(cur_encoder_type, hidden_feat, hidden_feat, 372 | layer=i + 1, activation=activation, 373 | **encoder_args) 374 | 375 | def normal_forward(self, snapshots, input_embs, edges_type): 376 | # conv_embs = [] 377 | outputs = defaultdict(list) 378 | all_embs = input_embs 379 | for etype in self.edges_type[edges_type]: 380 | # print('etype', etype) 381 | stype, _, dtype = etype 382 | subgraph = snapshots[tuple(etype)] 383 | # subgraph = dgl.add_self_loop(subgraph) 384 | # print(etype, subgraph.in_degrees()) 385 | srcdata = all_embs[stype] 386 | dstdata = all_embs[dtype] 387 | outputs[dtype].append(self.convs[edges_type][str(etype)](subgraph, (srcdata, dstdata))) 388 | for dtype in outputs: 389 | if len(outputs[dtype]) > 0: 390 | all_embs[dtype] = self.agg_method(outputs[dtype], dtype) 391 | # conv_embs.append(all_embs) 392 | 393 | return all_embs 394 | 395 | def snapshot_forward(self, input_embs, snapshot_graphs): 396 | outputs = defaultdict(list) 397 | all_embs = input_embs 398 | for etype in self.edges_type['intra_snapshots']: 399 | # print('etype', etype) 400 | stype, _, dtype = etype 401 | subgraph = snapshot_graphs[tuple(etype)] 402 | # subgraph = dgl.add_self_loop(subgraph) 403 | # print(etype, subgraph.in_degrees()) 404 | srcdata = all_embs[stype] 405 | dstdata = all_embs[dtype] 406 | norm_edge_weight = None 407 | if self.snapshot_weight: 408 | norm = dglnn.EdgeWeightNorm(norm=self.snapshot_weight) 409 | norm_edge_weight = norm(subgraph, subgraph.edata['w']) 410 | outputs[dtype].append(self.convs['intra_snapshots'][str(etype)](subgraph, (srcdata, dstdata), 411 | edge_weight=norm_edge_weight)) 412 | for dtype in outputs: 413 | if len(outputs[dtype]) > 0: 414 | all_embs[dtype] = self.agg_method(outputs[dtype], dtype) 415 | return all_embs 416 | 417 | def get_attn(self, snapshots, input_embs, edges_type, etype): 418 | # conv_embs = [] 419 | outputs = defaultdict(list) 420 | all_embs = input_embs 421 | # print('etype', etype) 422 | # 
etype = ('paper', 'is in', 'snapshot') 423 | stype, _, dtype = etype 424 | subgraph = snapshots[tuple(etype)] 425 | # subgraph = dgl.add_self_loop(subgraph) 426 | # print(etype, subgraph.in_degrees()) 427 | srcdata = all_embs[stype] 428 | dstdata = all_embs[dtype] 429 | _, attn = self.convs[edges_type][str(etype)](subgraph, (srcdata, dstdata), get_attention=True) 430 | return attn 431 | 432 | def forward(self, snapshots, input_embs, get_attention=False): 433 | input_embs = input_embs.copy() 434 | hidden_embs = input_embs 435 | for edges_type in self.sequence: 436 | if edges_type != 'intra_snapshots': 437 | if get_attention: 438 | attn_rgat = self.get_attn(snapshots, hidden_embs, edges_type, ('paper', 'is in', 'snapshot')) 439 | attn_cgin = self.get_attn(snapshots, hidden_embs, edges_type, ('paper', 'cites', 'paper')) 440 | attn = [attn_rgat, attn_cgin] 441 | hidden_embs = self.normal_forward(snapshots, hidden_embs, edges_type) 442 | else: 443 | hidden_embs = self.snapshot_forward(hidden_embs, snapshots) 444 | if self.activation: 445 | for ntype in hidden_embs.keys(): 446 | hidden_embs[ntype] = self.activation(hidden_embs[ntype]) 447 | if get_attention: 448 | return hidden_embs, attn 449 | else: 450 | return hidden_embs 451 | 452 | 453 | if __name__ == '__main__': 454 | start_time = datetime.datetime.now() 455 | parser = argparse.ArgumentParser(description='Process some description.') 456 | parser.add_argument('--phase', default='test', help='the function name.') 457 | 458 | args = parser.parse_args() 459 | rel_names = ['cites', 'is cited by', 'writes', 'is writen by', 'publishes', 'is published by'] 460 | model = StochasticKLayerRGCN(768, 300, 300, rel_names) 461 | print(model) 462 | -------------------------------------------------------------------------------- /layers/GCN/RGCN.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | 4 | import dgl 5 | import torch 6 | import torch.nn as nn 7 | import dgl.nn.pytorch as dglnn 8 | from layers.common_layers import MLP, CapsuleNetwork 9 | from layers.GCN.custom_gcn import DAGIN, RGIN, RGIN_N, RGAT, DGIN, DGINM 10 | 11 | 12 | def get_gcn_layer(encoder_type, in_feat, hidden_feat, layer, activation=None, **kwargs): 13 | if encoder_type == 'GCN': 14 | return dglnn.GraphConv(in_feat, hidden_feat, norm='right', allow_zero_in_degree=True) 15 | elif encoder_type == 'GIN': 16 | return dglnn.GINConv(MLP(in_feat, hidden_feat, activation=activation)) 17 | elif encoder_type == 'GAT': 18 | nheads = kwargs.get('nheads', 4) 19 | out_feat = hidden_feat // nheads 20 | return dglnn.GATv2Conv(in_feat, out_feat, num_heads=nheads, activation=activation, allow_zero_in_degree=True) 21 | elif encoder_type == 'APPNP': 22 | k = kwargs.get('k', 5) 23 | alpha = kwargs.get('alpha', 0.5) 24 | return dglnn.APPNPConv(k=k, alpha=alpha) 25 | elif encoder_type == 'GCN2': 26 | return dglnn.GCN2Conv(in_feat, layer, allow_zero_in_degree=True, activation=activation) 27 | elif encoder_type == 'SAGE': 28 | return dglnn.SAGEConv(in_feat, hidden_feat, 'mean', activation=activation) 29 | elif encoder_type == 'PNA': 30 | return dglnn.PNAConv(in_feat, hidden_feat, ['mean', 'max', 'sum'], ['identity', 'amplification'], 2.5) 31 | elif encoder_type == 'DGN': 32 | return dglnn.DGNConv(in_feat, hidden_feat, ['dir1-av', 'dir1-dx', 'sum'], ['identity', 'amplification'], 2.5) 33 | elif encoder_type == 'DAGIN': 34 | distance_interval = kwargs.get('distance_interval', 5) 35 | return DAGIN(in_feat, hidden_feat, 
distance_interval=distance_interval, activation=activation) 36 | elif encoder_type == 'DGIN': 37 | attr_indicator = kwargs.get('attr_indicator', 'citations') 38 | return DGIN(in_feat, hidden_feat, attr_indicator=attr_indicator, activation=activation) 39 | elif encoder_type == 'DGINM': 40 | attr_indicator = kwargs.get('attr_indicator', 'citations') 41 | return DGINM(in_feat, hidden_feat, attr_indicator=attr_indicator, activation=activation) 42 | elif encoder_type == 'DGINP': 43 | attr_indicator = kwargs.get('attr_indicator', 'citations') 44 | return DGIN(in_feat, hidden_feat, degree_trans_method='pos', 45 | attr_indicator=attr_indicator, activation=activation) 46 | elif encoder_type == 'DGINPM': 47 | attr_indicator = kwargs.get('attr_indicator', 'citations') 48 | return DGINM(in_feat, hidden_feat, degree_trans_method='pos', 49 | attr_indicator=attr_indicator, activation=activation) 50 | elif encoder_type == 'RGIN': 51 | # print(kwargs) 52 | snapshot_types = kwargs.get('snapshot_types', 3) 53 | node_attr_types = kwargs.get('node_attr_types', []) 54 | return RGIN_N(in_feat, hidden_feat, snapshot_types, node_attr_types, activation=activation) 55 | elif encoder_type == 'CGIN': 56 | return dglnn.GINConv(CapsuleNetwork(in_feat, hidden_feat, activation=activation)) 57 | elif encoder_type == 'RGAT': 58 | snapshot_types = kwargs.get('snapshot_types', 3) 59 | node_attr_types = kwargs.get('node_attr_types', []) 60 | print(node_attr_types) 61 | nheads = kwargs.get('nheads', 4) 62 | out_feat = hidden_feat // nheads 63 | return RGAT(in_feat, out_feat, nheads, snapshot_types, node_attr_types, 64 | activation=activation, allow_zero_in_degree=True) 65 | 66 | 67 | class CustomGCN(nn.Module): 68 | # homo single GCN encoder 69 | def __init__(self, encoder_type, in_feat, hidden_feat, layer, activation=None, **kwargs): 70 | super(CustomGCN, self).__init__() 71 | self.encoder_type = encoder_type 72 | self.fc_out = None 73 | # self.conv = get_gcn_layer(encoder_type, in_feat, hidden_feat, layer, activation=activation) 74 | if encoder_type in ['DAGIN', 'CGIN', 'DGIN']: 75 | self.conv = get_gcn_layer(encoder_type, in_feat, hidden_feat, layer, activation=activation) 76 | elif encoder_type == 'RGIN': 77 | snapshot_types = kwargs['snapshot_types'] 78 | node_attr_types = kwargs['node_attr_types'] 79 | self.conv = get_gcn_layer(encoder_type, in_feat, hidden_feat, layer, activation=activation, 80 | snapshot_types=snapshot_types, node_attr_types=node_attr_types) 81 | elif encoder_type == 'RGAT': 82 | snapshot_types = kwargs['snapshot_types'] 83 | node_attr_types = kwargs['node_attr_types'] 84 | self.conv = get_gcn_layer(encoder_type, in_feat, hidden_feat, layer, activation=activation, 85 | snapshot_types=snapshot_types, node_attr_types=node_attr_types) 86 | else: 87 | self.conv = get_gcn_layer(encoder_type, in_feat, hidden_feat, layer, activation=activation) 88 | 89 | if 'GAT' in self.encoder_type: 90 | self.fc_out = nn.Linear(hidden_feat, hidden_feat) 91 | elif self.encoder_type in ['APPNP', 'GCN2']: 92 | self.fc_out = nn.Linear(in_feat, hidden_feat) 93 | 94 | def forward(self, graph, feat, edge_weight=None, **kwargs): 95 | ''' 96 | :param graph: 97 | :param feat: tuple of src and dst 98 | :param edge_weight: 99 | :param kwargs: 100 | :return: hidden_embs: {ntype: [N, out_feat]} 101 | ''' 102 | # print('wip') 103 | # print(self.encoder_type) 104 | attn = None 105 | if self.encoder_type == 'GAT': 106 | feat = self.conv(graph, feat, **kwargs) 107 | else: 108 | feat = self.conv(graph, feat, edge_weight=edge_weight, **kwargs) 
109 | if 'GAT' in self.encoder_type: 110 | if kwargs.get('get_attention', False): 111 | feat, attn = feat 112 | num_nodes = feat.shape[0] 113 | feat = feat.reshape(num_nodes, -1) 114 | if self.fc_out: 115 | feat = self.fc_out(feat) 116 | if attn is not None: 117 | return feat, attn 118 | else: 119 | return feat 120 | 121 | 122 | class BasicHeteroGCN(nn.Module): 123 | def __init__(self, encoder_type, in_feat, out_feat, ntypes, etypes, layer, activation=None, **kwargs): 124 | super(BasicHeteroGCN, self).__init__() 125 | 126 | if encoder_type == 'HGT': 127 | nheads = kwargs.get('nheads', 4) 128 | self.conv = dglnn.HGTConv(in_feat, out_feat // nheads, nheads, 129 | num_ntypes=len(ntypes), num_etypes=len(etypes)) 130 | else: 131 | self.conv = dglnn.HeteroGraphConv({etype: get_gcn_layer(encoder_type, in_feat, out_feat, layer, 132 | activation=activation) for etype in etypes}) 133 | self.encoder_type = encoder_type 134 | print(self.encoder_type) 135 | self.fc_out = None 136 | if self.encoder_type == 'GAT': 137 | self.fc_out = dglnn.HeteroLinear({ntype: out_feat for ntype in ntypes}, out_feat) 138 | elif self.encoder_type in ['APPNP', 'GCN2']: 139 | self.fc_out = dglnn.HeteroLinear({ntype: in_feat for ntype in ntypes}, out_feat) 140 | 141 | def forward(self, graph, feat, edge_weight=None, **kwargs): 142 | ''' 143 | :param graph: 144 | :param feat: 145 | :param edge_weight: 146 | :param kwargs: 147 | :return: hidden_embs: {ntype: [N, out_feat]} 148 | ''' 149 | # print('wip') 150 | # print('start', [torch.isnan(feat[x]).sum() for x in feat]) 151 | # print(self.encoder_type) 152 | # print(self.conv) 153 | if self.encoder_type == 'HGT': 154 | # print(feat.shape) 155 | feat = self.conv(graph, feat, graph.ndata[dgl.NTYPE], graph.edata[dgl.ETYPE], presorted=True) 156 | if self.encoder_type == 'GAT': 157 | feat = self.conv(graph, feat) 158 | else: 159 | # feat = self.conv(graph, feat, edge_weight=edge_weight) 160 | feat = self.conv(graph, feat, mod_kwargs={'edge_weight': edge_weight}) 161 | # print('end', [torch.isnan(feat[x]).sum() for x in feat]) 162 | if self.encoder_type == 'GAT': 163 | for ntype in feat: 164 | num_nodes = feat[ntype].shape[0] 165 | feat[ntype] = feat[ntype].reshape(num_nodes, -1) 166 | if self.fc_out: 167 | feat = self.fc_out(feat) 168 | return feat 169 | 170 | 171 | class CustomHeteroGCN(nn.Module): 172 | def __init__(self, encoder_type, in_feat, out_feat, ntypes, etypes, layer, activation=None, **kwargs): 173 | super(CustomHeteroGCN, self).__init__() 174 | etype_dict = None 175 | 176 | if encoder_type == 'HGT': 177 | nheads = kwargs.get('nheads', 4) 178 | self.conv = dglnn.HGTConv(in_feat, out_feat // nheads, nheads, 179 | num_ntypes=len(ntypes), num_etypes=len(etypes)) 180 | elif encoder_type in ['DAGIN', 'CGIN']: 181 | # print(kwargs) 182 | etype_dict = {} 183 | for etype in kwargs.get('special_edges', []): 184 | etype_dict[etype] = get_gcn_layer(encoder_type, in_feat, out_feat, layer, activation=activation) 185 | for etype in etypes: 186 | if etype not in etype_dict: 187 | etype_dict[etype] = get_gcn_layer('GIN', in_feat, out_feat, layer, activation=activation) 188 | elif encoder_type == 'RGIN': 189 | etype_dict = {} 190 | for etype in kwargs.get('special_edges', []): 191 | snapshot_types = kwargs['snapshot_types'] 192 | node_attr_types = kwargs['node_attr_types'] 193 | etype_dict[etype] = get_gcn_layer(encoder_type, in_feat, out_feat, layer, activation=activation, 194 | snapshot_types=snapshot_types, node_attr_types=node_attr_types) 195 | # print(etype_dict) 196 | for etype in 
etypes: 197 | if etype not in etype_dict: 198 | etype_dict[etype] = get_gcn_layer('GIN', in_feat, out_feat, layer, activation=activation) 199 | elif encoder_type == 'AGIN': 200 | etype_dict = {} 201 | # DAGIN 202 | if 'special_edges' in kwargs: 203 | for etype in kwargs['special_edges']['DAGIN']: 204 | etype_dict[etype] = get_gcn_layer('DAGIN', in_feat, out_feat, layer, activation=activation) 205 | # RGIN 206 | for etype in kwargs['special_edges']['RGIN']: 207 | snapshot_types = kwargs['snapshot_types'] 208 | node_attr_types = kwargs['node_attr_types'] 209 | etype_dict[etype] = get_gcn_layer('RGIN', in_feat, out_feat, layer, activation=activation, 210 | snapshot_types=snapshot_types, node_attr_types=node_attr_types) 211 | # print(etype_dict) 212 | for etype in etypes: 213 | if etype not in etype_dict: 214 | etype_dict[etype] = get_gcn_layer('GIN', in_feat, out_feat, layer, activation=activation) 215 | elif 'encoder_dict' in kwargs: 216 | etype_dict = {} 217 | encoder_dict = kwargs['encoder_dict'] 218 | for special_encoder_type in encoder_dict: 219 | for etype in encoder_dict[special_encoder_type]: 220 | etype_dict[etype] = get_gcn_layer(special_encoder_type, in_feat, out_feat, layer, 221 | activation=activation) 222 | print(etype_dict) 223 | for etype in etypes: 224 | if etype not in etype_dict: 225 | etype_dict[etype] = get_gcn_layer(encoder_type, in_feat, out_feat, layer, activation=activation) 226 | else: 227 | etype_dict = {etype: get_gcn_layer(encoder_type, in_feat, out_feat, layer, 228 | activation=activation) for etype in etypes} 229 | if etype_dict: 230 | self.conv = dglnn.HeteroGraphConv(etype_dict) 231 | self.encoder_type = encoder_type 232 | # print(self.encoder_type) 233 | self.fc_out = None 234 | if self.encoder_type == 'GAT': 235 | self.fc_out = dglnn.HeteroLinear({ntype: out_feat for ntype in ntypes}, out_feat) 236 | elif self.encoder_type in ['APPNP', 'GCN2']: 237 | self.fc_out = dglnn.HeteroLinear({ntype: in_feat for ntype in ntypes}, out_feat) 238 | 239 | def forward(self, graph, feat, edge_weight=None, **kwargs): 240 | ''' 241 | :param graph: 242 | :param feat: 243 | :param edge_weight: 244 | :param kwargs: 245 | :return: hidden_embs: {ntype: [N, out_feat]} 246 | ''' 247 | # print('wip') 248 | if self.encoder_type == 'HGT': 249 | # print(feat.shape) 250 | feat = self.conv(graph, feat, graph.ndata[dgl.NTYPE], graph.edata[dgl.ETYPE], presorted=True) 251 | else: 252 | feat = self.conv(graph, feat, edge_weight=edge_weight) 253 | if self.encoder_type == 'GAT': 254 | for ntype in feat: 255 | num_nodes = feat[ntype].shape[0] 256 | feat[ntype] = feat[ntype].reshape(num_nodes, -1) 257 | if self.fc_out: 258 | feat = self.fc_out(feat) 259 | return feat 260 | 261 | 262 | class StochasticKLayerRGCN(nn.Module): 263 | def __init__(self, in_feat, hidden_feat, out_feat, ntypes, etypes, k=2, residual=False, encoder_type='GCN'): 264 | super().__init__() 265 | self.residual = residual 266 | self.convs = nn.ModuleList() 267 | # print(encoder_type, in_feat, hidden_feat) 268 | # self.conv_start = dglnn.HeteroGraphConv({ 269 | # rel: get_gcn_layer(encoder_type, in_feat, hidden_feat, activation=nn.LeakyReLU()) 270 | # for rel in etypes 271 | # }) 272 | # self.conv_end = dglnn.HeteroGraphConv({ 273 | # rel: get_gcn_layer(encoder_type, hidden_feat, out_feat, activation=nn.LeakyReLU()) 274 | # for rel in etypes 275 | # }) 276 | # self.convs.append(self.conv_start) 277 | # for i in range(k - 2): 278 | # self.convs.append(dglnn.HeteroGraphConv({ 279 | # rel: get_gcn_layer(encoder_type, hidden_feat, 
hidden_feat, activation=nn.LeakyReLU()) 280 | # for rel in etypes 281 | # })) 282 | # self.convs.append(self.conv_end) 283 | print(ntypes) 284 | self.skip_fcs = nn.ModuleList() 285 | self.conv_start = BasicHeteroGCN(encoder_type, in_feat, hidden_feat, ntypes, etypes, layer=1, 286 | activation=nn.LeakyReLU()) 287 | self.skip_fc_start = dglnn.HeteroLinear({key: in_feat for key in ntypes}, hidden_feat) 288 | self.conv_end = BasicHeteroGCN(encoder_type, hidden_feat, out_feat, ntypes, etypes, layer=k, 289 | activation=nn.LeakyReLU()) 290 | self.skip_fc_end = dglnn.HeteroLinear({key: hidden_feat for key in ntypes}, out_feat) 291 | self.convs.append(self.conv_start) 292 | self.skip_fcs.append(self.skip_fc_start) 293 | for i in range(k - 2): 294 | self.convs.append(BasicHeteroGCN(encoder_type, hidden_feat, hidden_feat, ntypes, etypes, layer=i + 2, 295 | activation=nn.LeakyReLU())) 296 | self.skip_fcs.append(dglnn.HeteroLinear({key: hidden_feat for key in ntypes}, hidden_feat)) 297 | self.convs.append(self.conv_end) 298 | self.skip_fcs.append(self.skip_fc_end) 299 | 300 | # def forward(self, graph, x): 301 | # x = self.conv_start(graph, x) 302 | # x = self.conv_end(graph, x) 303 | # return x 304 | def forward(self, graph, x): 305 | # print(graph.device) 306 | # print(x.device) 307 | # count = 0 308 | for i in range(len(self.convs)): 309 | # print('-' * 30 + str(count) + '-' * 30) 310 | # print(x) 311 | if self.residual: 312 | # x = x + conv(graph, x) 313 | out = self.convs[i](graph, x) 314 | x = self.skip_fcs[i](x) 315 | for key in x: 316 | x[key] = x[key] + out[key] 317 | else: 318 | x = self.convs[i](graph, x) 319 | # print([torch.isnan(x[e]).sum() for e in x]) 320 | # print(x) 321 | # count += 1 322 | return x 323 | 324 | # def forward_block(self, blocks, x): 325 | # x = self.conv_start(blocks[0], x) 326 | # x = self.conv_end(blocks[1], x) 327 | # return x 328 | def forward_block(self, blocks, x): 329 | for i in range(len(self.convs)): 330 | x = self.convs[i](blocks[i], x) 331 | return x 332 | 333 | 334 | class CustomRGCN(nn.Module): 335 | def __init__(self, in_feat, hidden_feat, out_feat, ntypes, etypes, k=2, residual=False, encoder_type='GCN', 336 | return_list=False, **kwargs): 337 | super().__init__() 338 | self.residual = residual 339 | self.convs = nn.ModuleList() 340 | self.conv_start = CustomHeteroGCN(encoder_type, in_feat, hidden_feat, ntypes, etypes, layer=1, 341 | activation=nn.LeakyReLU(), **kwargs) 342 | self.conv_end = CustomHeteroGCN(encoder_type, hidden_feat, out_feat, ntypes, etypes, layer=k, 343 | activation=nn.LeakyReLU(), **kwargs) 344 | self.convs.append(self.conv_start) 345 | for i in range(k - 2): 346 | self.convs.append(CustomHeteroGCN(encoder_type, hidden_feat, hidden_feat, ntypes, etypes, layer=i + 2, 347 | activation=nn.LeakyReLU(), **kwargs)) 348 | self.convs.append(self.conv_end) 349 | self.return_list = return_list 350 | 351 | def forward(self, graph, x): 352 | output_list = [] 353 | for conv in self.convs: 354 | if self.residual: 355 | x = x + conv(graph, x) 356 | else: 357 | x = conv(graph, x) 358 | if self.return_list: 359 | output_list.append(x) 360 | return x, output_list 361 | 362 | 363 | class HTSGCNLayer(nn.Module): 364 | def __init__(self, in_feat, hidden_feat, out_feat, ntypes, etypes, k=2, time_length=10, encoder_type='GCN', 365 | return_list=False): 366 | super().__init__() 367 | # T * k layers 368 | self.k = k 369 | self.time_length = time_length 370 | # self.spatial_encoders = nn.ModuleList( 371 | # [StochasticKLayerRGCN(in_feat, hidden_feat, out_feat, 
rel_names, k=k) for i in range(time_length)]) 372 | self.spatial_encoders = nn.ModuleList( 373 | [CustomRGCN(in_feat, hidden_feat, out_feat, ntypes, etypes, k=k, 374 | encoder_type=encoder_type, return_list=return_list) for i in range(time_length)]) 375 | self.time_encoders = nn.ModuleList() 376 | for i in range(k - 1): 377 | self.time_encoders.append(HTEncoder(hidden_feat, hidden_feat, activation=torch.relu)) 378 | self.time_encoders.append(HTEncoder(out_feat, out_feat, activation=torch.relu)) 379 | 380 | # def forward(self, graph, x): 381 | # x = self.conv_start(graph, x) 382 | # x = self.conv_end(graph, x) 383 | # return x 384 | def forward(self, graph_list, time_adj): 385 | for i in range(self.k): 386 | # print(i) 387 | for t in range(self.time_length): 388 | # 注意这里都是每年batch好的图 389 | # try: 390 | graph_list[t].dstdata['h'] = self.spatial_encoders[t].convs[i](graph_list[t], 391 | graph_list[t].srcdata['h']) 392 | # except Exception as e: 393 | # print(e) 394 | # print(graph_list[t]) 395 | # print(graph_list[t].batch_size) 396 | # raise Exception 397 | # print(graph_list[t].dstdata['h']['snapshot'].shape) 398 | graph_list = self.time_encoders[i](graph_list, time_adj) 399 | return graph_list 400 | 401 | 402 | class HTEncoder(nn.Module): 403 | def __init__(self, in_feats, out_feats, bias=False, activation=None): 404 | super(HTEncoder, self).__init__() 405 | self.in_feats = in_feats 406 | self.out_feats = out_feats 407 | self.l_weight = nn.Parameter(torch.Tensor(in_feats, out_feats)) 408 | self.r_weight = nn.Parameter(torch.Tensor(in_feats, out_feats)) 409 | self.l_bias = None 410 | self.r_bias = None 411 | self.fc_out = nn.Linear(2 * out_feats, out_feats, bias=bias) 412 | self.activation = activation 413 | if bias: 414 | self.l_bias = nn.Parameter(torch.Tensor(out_feats)) 415 | self.r_bias = nn.Parameter(torch.Tensor(out_feats)) 416 | self.init_weights() 417 | 418 | def init_weights(self): 419 | if self.l_weight is not None: 420 | nn.init.xavier_uniform_(self.l_weight) 421 | nn.init.xavier_uniform_(self.r_weight) 422 | if self.l_bias is not None: 423 | nn.init.zeros_(self.l_bias) 424 | nn.init.zeros_(self.r_bias) 425 | 426 | def forward(self, snapshots, time_adj): 427 | # T[B], [B, 2, T, T] 428 | # print('wip') 429 | emb_list = [] 430 | for snapshot in snapshots: 431 | # print(snapshot.nodes['snapshot'].data['h'].shape) 432 | emb_list.append(snapshot.nodes['snapshot'].data['h']) 433 | time_emb = torch.stack(emb_list, dim=1) # [B, T, d_in] 434 | l_emb = torch.matmul(torch.matmul(time_adj[:, 0], time_emb), self.l_weight) 435 | r_emb = torch.matmul(torch.matmul(time_adj[:, 1], time_emb), self.r_weight) 436 | out_emb = torch.cat((l_emb, r_emb), dim=-1) 437 | if self.activation: 438 | out_emb = self.activation(out_emb) 439 | time_out = self.fc_out(out_emb) # [B, T, d_out] 440 | for i in range(len(snapshots)): 441 | snapshots[i].nodes['snapshot'].data['h'] = time_out[:, i] 442 | return snapshots 443 | 444 | 445 | if __name__ == '__main__': 446 | start_time = datetime.datetime.now() 447 | parser = argparse.ArgumentParser(description='Process some description.') 448 | parser.add_argument('--phase', default='test', help='the function name.') 449 | 450 | args = parser.parse_args() 451 | rel_names = ['cites', 'is cited by', 'writes', 'is writen by', 'publishes', 'is published by'] 452 | model = StochasticKLayerRGCN(768, 300, 300, rel_names) 453 | print(model) 454 | -------------------------------------------------------------------------------- /layers/GCN/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECNU-Text-Computing/H2CGL/f3a3af4b8af16e24f20236da62d219c9c39594dd/layers/GCN/__init__.py -------------------------------------------------------------------------------- /layers/GCN/__pycache__/CL_layers.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECNU-Text-Computing/H2CGL/f3a3af4b8af16e24f20236da62d219c9c39594dd/layers/GCN/__pycache__/CL_layers.cpython-39.pyc -------------------------------------------------------------------------------- /layers/GCN/__pycache__/DTGCN_layers.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECNU-Text-Computing/H2CGL/f3a3af4b8af16e24f20236da62d219c9c39594dd/layers/GCN/__pycache__/DTGCN_layers.cpython-39.pyc -------------------------------------------------------------------------------- /layers/GCN/__pycache__/RGCN.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECNU-Text-Computing/H2CGL/f3a3af4b8af16e24f20236da62d219c9c39594dd/layers/GCN/__pycache__/RGCN.cpython-39.pyc -------------------------------------------------------------------------------- /layers/GCN/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECNU-Text-Computing/H2CGL/f3a3af4b8af16e24f20236da62d219c9c39594dd/layers/GCN/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /layers/GCN/__pycache__/custom_gcn.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECNU-Text-Computing/H2CGL/f3a3af4b8af16e24f20236da62d219c9c39594dd/layers/GCN/__pycache__/custom_gcn.cpython-39.pyc -------------------------------------------------------------------------------- /layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECNU-Text-Computing/H2CGL/f3a3af4b8af16e24f20236da62d219c9c39594dd/layers/__init__.py -------------------------------------------------------------------------------- /layers/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECNU-Text-Computing/H2CGL/f3a3af4b8af16e24f20236da62d219c9c39594dd/layers/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /layers/__pycache__/common_layers.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECNU-Text-Computing/H2CGL/f3a3af4b8af16e24f20236da62d219c9c39594dd/layers/__pycache__/common_layers.cpython-39.pyc -------------------------------------------------------------------------------- /layers/common_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | 6 | class PositionalEncodingSC(nn.Module): 7 | def __init__(self, d_model, dropout=0.5, max_len=5000): 8 | super(PositionalEncodingSC, self).__init__() 9 | self.dropout = nn.Dropout(p=dropout) 10 | 11 | position = torch.arange(max_len).unsqueeze(1) 12 | div_term = 
torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) 13 | pe = torch.zeros(max_len, 1, d_model) 14 | pe[:, 0, 0::2] = torch.sin(position * div_term) 15 | pe[:, 0, 1::2] = torch.cos(position * div_term) 16 | self.register_buffer('pe', pe) 17 | 18 | def forward(self, x): 19 | """ 20 | :param x: Tensor, shape [seq_len, batch_size, embedding_dim] 21 | :return Tensor: 22 | """ 23 | x = x + self.pe[:x.size(0)] 24 | out = self.dropout(x) 25 | return out 26 | 27 | 28 | class PositionalEncoding(nn.Module): 29 | def __init__(self, d_model, dropout=0.5, max_len=5000): 30 | super(PositionalEncoding, self).__init__() 31 | self.dropout = nn.Dropout(p=dropout) 32 | 33 | position = torch.arange(max_len).unsqueeze(1) 34 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) 35 | pe = torch.zeros(max_len, 1, d_model) 36 | pe[:, 0, 0::2] = torch.sin(position * div_term) 37 | pe[:, 0, 1::2] = torch.cos(position * div_term) 38 | pe = pe.squeeze(dim=1) 39 | self.register_buffer('pe', pe) 40 | 41 | def forward(self, x): 42 | """ 43 | :param x: Tensor, shape [batch_size, seq_len] 44 | :return Tensor: 45 | """ 46 | # x = x + self.pe[:x.size(0)] 47 | # out = self.dropout(x) 48 | out = self.pe[x] 49 | return out 50 | 51 | 52 | class TemporalEncoding(nn.Module): 53 | def __init__(self, d_model, dropout=0, max_len=5000, random=True): 54 | super(TemporalEncoding, self).__init__() 55 | self.dropout = nn.Dropout(p=dropout) 56 | 57 | position = torch.arange(max_len).unsqueeze(1) # t 58 | div_term = 1 / math.sqrt(d_model) 59 | if random: 60 | w = torch.randn((1, d_model)) 61 | else: 62 | w = torch.arange(0, d_model, 1).unsqueeze(0) / d_model 63 | # temp = nn.Parameter(position * w) 64 | # print(position.shape, w.shape) 65 | temp = position * w 66 | te = torch.cos(temp) * torch.cos(temp) * div_term 67 | # print(te) 68 | self.register_buffer('te', te) 69 | 70 | def forward(self, x): 71 | """ 72 | :param x: Tensor, shape [batch_size, seq_len] 73 | :return Tensor: 74 | """ 75 | # print(self.te.device) 76 | out = self.te[x] 77 | return out 78 | 79 | 80 | class AttentionModule(nn.Module): 81 | def __init__(self, dim_in, dim_out=None, activation=nn.Tanh()): 82 | super(AttentionModule, self).__init__() 83 | if dim_out is None: 84 | dim_out = dim_in 85 | # self.trans = nn.Linear(dim_in, dim_out) 86 | # self.seq_trans = nn.Linear(dim_in, dim_out) 87 | self.target_trans = MLP(dim_in, dim_out) 88 | self.seq_trans = MLP(dim_in, dim_out) 89 | # self.w = nn.Linear(dim_out, 1, bias=False) 90 | self.w = nn.Linear(dim_out, dim_out, bias=False) 91 | self.u = nn.Linear(dim_out, dim_out, bias=False) 92 | self.activation = activation 93 | 94 | def forward(self, target, sequence, mask): 95 | # batch_size = target.shape[0] 96 | target = self.target_trans(target) 97 | # print(sequence.shape) 98 | sequence = self.seq_trans(sequence) 99 | 100 | # attn = self.activation(sequence @ self.w(target).unsqueeze(dim=-1)) + mask.unsqueeze(dim=-1).detach() 101 | # print((self.u(sequence) @ self.w(target).unsqueeze(dim=-1)).shape) 102 | # attn = self.activation(self.u(sequence) @ self.w(target).unsqueeze(dim=-1)) + mask.unsqueeze(dim=-1).detach() 103 | attn = self.u(sequence) @ self.w(target).unsqueeze(dim=-1) + mask.unsqueeze(dim=-1).detach() 104 | # attn = attn.view(batch_size, -1) 105 | # attn[mask] = -1e8 106 | attn = torch.softmax(attn, dim=-2) 107 | 108 | out = (sequence * attn).sum(1) 109 | # print(out.shape) 110 | return out 111 | 112 | 113 | class GateModule(nn.Module): 114 | def __init__(self, 
input_dim, out_dim): 115 | super(GateModule, self).__init__() 116 | self.input_trans = MLP(input_dim, input_dim) 117 | self.target_trans = MLP(input_dim, input_dim) 118 | self.mix_trans = MLP(input_dim * 2, out_dim) 119 | self.sigmoid = nn.Sigmoid() 120 | 121 | def forward(self, input_emb, target_emb): 122 | alpha = self.sigmoid(self.input_trans(input_emb) + self.target_trans(target_emb)) 123 | input_emb = alpha * input_emb 124 | mix_out = self.mix_trans(torch.cat((input_emb, target_emb), dim=-1)) 125 | return mix_out 126 | 127 | 128 | class MLP(nn.Module): 129 | def __init__(self, dim_in, dim_out, bias=True, dim_inner=None, 130 | num_layers=2, activation=nn.ReLU(inplace=False), **kwargs): 131 | super(MLP, self).__init__() 132 | layers = [] 133 | dim_inner = dim_in if dim_inner is None else dim_inner 134 | if num_layers > 1: 135 | for i in range(num_layers - 1): 136 | # print(dim_in, dim_inner, bias) 137 | layers.append(nn.Linear(dim_in, dim_inner, bias)) 138 | if activation: 139 | layers.append(activation) 140 | layers.append(nn.Linear(dim_inner, dim_out, bias)) 141 | else: 142 | layers.append(nn.Linear(dim_in, dim_out, bias)) 143 | self.model = nn.Sequential(*layers) 144 | 145 | def forward(self, batch): 146 | batch = self.model(batch) 147 | return batch 148 | 149 | 150 | class CapsuleNetwork(nn.Module): 151 | def __init__(self, dim_in, dim_out, bias=True, num_routing=3, num_caps=4, return_prob=False, 152 | activation=nn.ReLU(), **kwargs): 153 | # projection and routing 154 | super(CapsuleNetwork, self).__init__() 155 | dim_capsule = dim_out // num_caps 156 | self.dim_out = dim_out 157 | self.model = OriginCaps(in_dim=dim_in, in_caps=1, num_caps=num_caps, dim_caps=dim_capsule, 158 | num_routing=num_routing) 159 | self.bn_out = nn.BatchNorm1d(dim_capsule) # bn is important for capsule 160 | self.fc_out = MLP(dim_out, dim_out, bias=bias, activation=activation) 161 | self.return_prob = return_prob 162 | 163 | def forward(self, x): 164 | # [N, dim_in] 165 | # x = self.bn_in(x) 166 | x = x.unsqueeze(dim=1) 167 | x, p = self.model(x) # [N, Nc, dc] p是capsule probability 168 | # print(p) 169 | x = self.bn_out(x.transpose(1, 2)).transpose(1, 2) 170 | x = self.fc_out(x.reshape(-1, self.dim_out)) 171 | # x = self.bn_out(x) 172 | if self.return_prob: 173 | return x, p 174 | else: 175 | return x 176 | 177 | 178 | def squash(x, dim=-1): 179 | # squared_norm = (x ** 2).sum(dim=dim, keepdim=True) 180 | squared_norm = torch.square(x).sum(dim=dim, keepdim=True) 181 | scale = squared_norm / (1 + squared_norm) 182 | # vector = scale * x / (squared_norm.sqrt() + 1e-8) 183 | vector = scale * x / torch.sqrt(squared_norm + 1e-11) 184 | return vector, scale 185 | 186 | 187 | class OriginCaps(nn.Module): 188 | """Original capsule graph_encoder.""" 189 | 190 | def __init__(self, in_dim, in_caps, num_caps, dim_caps, num_routing): 191 | """ 192 | Initialize the graph_encoder. 193 | 194 | Args: 195 | in_dim: Dimensionality (i.e. length) of each capsule vector. 196 | in_caps: Number of input capsules if digits graph_encoder. 197 | num_caps: Number of capsules in the capsule graph_encoder 198 | dim_caps: Dimensionality, i.e. length, of the output capsule vector. 
199 | num_routing: Number of iterations during routing algorithm 200 | """ 201 | super(OriginCaps, self).__init__() 202 | self.in_dim = in_dim 203 | self.in_caps = in_caps 204 | self.num_caps = num_caps 205 | self.dim_caps = dim_caps 206 | self.num_routing = num_routing 207 | # self.W = nn.Parameter(0.01 * torch.randn(1, num_caps, in_caps, dim_caps, in_dim), 208 | # requires_grad=True) 209 | self.W = nn.Parameter(torch.empty(1, num_caps, in_caps, dim_caps, in_dim), 210 | requires_grad=True) 211 | self.init_weights() 212 | 213 | def init_weights(self): 214 | nn.init.uniform_(self.W, -0.1, 0.1) 215 | 216 | def forward(self, x): 217 | batch_size = x.size(0) 218 | # (batch_size, in_caps, in_dim) -> (batch_size, 1, in_caps, in_dim, 1) 219 | x = x.unsqueeze(1).unsqueeze(4) 220 | # print('x', torch.isnan(x).sum()) 221 | # 222 | # W @ x = 223 | # (1, num_caps, in_caps, dim_caps, in_dim) @ (batch_size, 1, in_caps, in_dim, 1) = 224 | # (batch_size, num_caps, in_caps, dim_caps, 1) 225 | u_hat = torch.matmul(self.W, x) 226 | # (batch_size, num_caps, in_caps, dim_caps) 227 | u_hat = u_hat.squeeze(-1) 228 | # detach u_hat during routing iterations to prevent gradients from flowing 229 | temp_u_hat = u_hat.detach() 230 | 231 | b = torch.zeros(batch_size, self.num_caps, self.in_caps, 1, requires_grad=False).to(x.device) 232 | 233 | for route_iter in range(self.num_routing - 1): 234 | # (batch_size, num_caps, in_caps, 1) -> Softmax along num_caps 235 | c = b.softmax(dim=1) 236 | 237 | # element-wise multiplication 238 | # (batch_size, num_caps, in_caps, 1) * (batch_size, in_caps, num_caps, dim_caps) -> 239 | # (batch_size, num_caps, in_caps, dim_caps) sum across in_caps -> 240 | # (batch_size, num_caps, dim_caps) 241 | s = (c * temp_u_hat).sum(dim=2) 242 | # apply "squashing" non-linearity along dim_caps 243 | v, s = squash(s) 244 | # print('v', torch.isnan(v).sum()) 245 | # dot product agreement between the current output vj and the prediction uj|i 246 | # (batch_size, num_caps, in_caps, dim_caps) @ (batch_size, num_caps, dim_caps, 1) 247 | # -> (batch_size, num_caps, in_caps, 1) 248 | uv = torch.matmul(temp_u_hat, v.unsqueeze(-1)) 249 | b = b + uv 250 | 251 | # last iteration is done on the original u_hat, without the routing weights update 252 | c = b.softmax(dim=1) 253 | # print(c.shape, u_hat.shape) 254 | s = (c * u_hat).sum(dim=2) 255 | # apply "squashing" non-linearity along dim_caps 256 | v, p = squash(s) 257 | # print('last_v', torch.isnan(x).sum(), torch.isnan(v).sum(), torch.isnan(c).sum(), torch.isnan(s).sum()) 258 | # print(x.squeeze(1)) 259 | # print(v) 260 | # if torch.isnan(v).sum() > 0: 261 | # raise Exception 262 | # print(s) 263 | 264 | return v, p 265 | 266 | if __name__ == '__main__': 267 | pe = PositionalEncoding(128) 268 | print(pe.pe.shape) 269 | print(pe.pe[[1, 1, 2]]) 270 | x = torch.randint(10, (3, 5)) 271 | print(pe(x).shape) 272 | # x = torch.tensor([1, 5, 10, 4, 2, 10]) 273 | # te = TemporalEncoding(128) 274 | # print(te(x).shape) 275 | # print([torch.norm(te(x)[i]) for i in range(x.shape[0])]) 276 | # te = TemporalEncoding(128, random=False) 277 | # print(te(x).shape) 278 | # print([torch.norm(te(x)[i]) for i in range(x.shape[0])]) 279 | 280 | # model = OriginCaps(in_dim=256, in_caps=1, num_caps=4, dim_caps=64, num_routing=3) 281 | # # (batch_size, in_caps, in_dim) 282 | # x = torch.randn(10, 1, 256) 283 | # print(model(x)[0].shape) 284 | # # print(model(x).norm(2)) 285 | # cn = CapsuleNetwork(dim_in=256, dim_out=256) 286 | # x = torch.randn(8, 256) 287 | # print(cn(x)) 288 | 
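# Editorial sketch appended for clarity (not part of the original file): squash()
# rescales each capsule vector to length |s|^2 / (1 + |s|^2) < 1 while preserving its
# direction, which is what lets the returned scale act as a capsule probability, e.g.
#   s = torch.tensor([[3.0, 4.0]])   # norm 5
#   v, scale = squash(s)             # scale = 25 / 26
#   print(v.norm(dim=-1))            # ~0.9615, same direction as s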
-------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECNU-Text-Computing/H2CGL/f3a3af4b8af16e24f20236da62d219c9c39594dd/models/__init__.py -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECNU-Text-Computing/H2CGL/f3a3af4b8af16e24f20236da62d219c9c39594dd/models/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/base_model.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECNU-Text-Computing/H2CGL/f3a3af4b8af16e24f20236da62d219c9c39594dd/models/__pycache__/base_model.cpython-39.pyc -------------------------------------------------------------------------------- /our_models/CL.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | 4 | import dgl 5 | import pandas as pd 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn.utils.rnn import pad_sequence 9 | 10 | from layers.GCN.CL_layers import * 11 | from models.base_model import BaseModel 12 | from utilis.scripts import custom_node_unbatch 13 | 14 | 15 | class CLWrapper(BaseModel): 16 | def __init__(self, encoder: BaseModel, cl_type='simple', aug_type=None, tau=0.4, cl_weight=0.5, aug_rate=0.1, **kwargs): 17 | super(CLWrapper, self).__init__(0, 0, 1, 0, None, 0, None, **kwargs) 18 | self.encoder = encoder 19 | self.aux_criterion = encoder.aux_criterion 20 | self.aug_rate = aug_rate 21 | print('aug_rate:', self.aug_rate) 22 | if aug_type == 'md': 23 | self.aug_methods = (('attr_mask',), ('edge_drop',)) 24 | self.aug_configs = ({ 25 | 'attr_mask': { 26 | 'prob': self.aug_rate, 27 | 'types': ['paper', 'author', 'journal', 'time'] 28 | } 29 | }, { 30 | 'edge_drop': { 31 | 'prob': self.aug_rate, 32 | 'types': ["cites", "is cited by", "writes", "is writen by", "publishes", 33 | "is published by", 34 | "shows", "is shown in"] 35 | } 36 | }) 37 | elif aug_type == 'dmd': 38 | self.aug_methods = (('attr_mask', 'edge_drop'), ('attr_mask', 'edge_drop')) 39 | self.aug_configs = ({ 40 | 'attr_mask': { 41 | 'prob': self.aug_rate, 42 | 'types': ['paper', 'author', 'journal', 'time'] 43 | }, 44 | 'edge_drop': { 45 | 'prob': self.aug_rate, 46 | 'types': ["cites", "is cited by", "writes", "is writen by", "publishes", 47 | "is published by", 48 | "shows", "is shown in"] 49 | } 50 | 51 | }, { 52 | 'attr_mask': { 53 | 'prob': self.aug_rate, 54 | 'types': ['paper', 'author', 'journal', 'time'] 55 | }, 56 | 'edge_drop': { 57 | 'prob': self.aug_rate, 58 | 'types': ["cites", "is cited by", "writes", "is writen by", "publishes", 59 | "is published by", 60 | "shows", "is shown in"] 61 | } 62 | }) 63 | elif aug_type == 'dd': 64 | self.aug_methods = (('attr_drop', 'edge_drop'), ('attr_drop', 'edge_drop')) 65 | self.aug_configs = ({ 66 | 'attr_drop': { 67 | 'prob': self.aug_rate, 68 | 'types': ['paper', 'author', 'journal', 'time'] 69 | }, 70 | 'edge_drop': { 71 | 'prob': self.aug_rate, 72 | 'types': ["cites", "is cited by", "writes", "is writen by", "publishes", 73 | "is published by", 74 | "shows", "is shown in"] 75 | } 76 | }, { 77 | 'edge_drop': { 78 | 'prob': self.aug_rate, 79 | 
'types': ["cites", "is cited by", "writes", "is writen by", "publishes", 80 | "is published by", 81 | "shows", "is shown in"] 82 | }, 83 | 'attr_drop': { 84 | 'prob': self.aug_rate, 85 | 'types': ['paper', 'author', 'journal', 'time'] 86 | }, 87 | }) 88 | else: 89 | aug_type = 'normal' 90 | self.aug_methods = (('attr_drop',), ('edge_drop',)) 91 | self.aug_configs = ({ 92 | 'attr_drop': { 93 | 'prob': self.aug_rate, 94 | 'types': ['paper', 'author', 'journal', 'time'] 95 | } 96 | }, { 97 | 'edge_drop': { 98 | 'prob': self.aug_rate, 99 | 'types': ["cites", "is cited by", "writes", "is writen by", "publishes", 100 | "is published by", 101 | "shows", "is shown in"] 102 | } 103 | }) 104 | print(self.aug_methods) 105 | self.model_name = self.encoder.model_name + '_CL_' + cl_type + '_' + aug_type 106 | if self.aug_rate != 0.1: 107 | self.model_name += '_p{}'.format(self.aug_rate) 108 | self.similarity = 'cos' 109 | self.cl_type = cl_type 110 | self.encoder.cl_type = self.cl_type 111 | self.tau = tau 112 | if self.tau != 0.4: 113 | self.model_name += '_t{}'.format(self.tau) 114 | self.cross_weight = 0.1 115 | self.cl_weight = cl_weight 116 | if self.cl_weight != 0.5: 117 | self.model_name += '_cw{}'.format(self.cl_weight) 118 | self.df_hn = None 119 | if cl_type in ['simple', 'cross_snapshot']: 120 | self.cl_ntypes = self.encoder.ntypes 121 | elif cl_type.startswith('focus'): 122 | self.cl_ntypes = [self.encoder.pred_type] 123 | elif cl_type.endswith('hard_negative'): 124 | self.cl_ntypes = self.encoder.ntypes 125 | print('wip') 126 | # self.cl_modules = nn.ModuleDict({ntype: CommonCLLayer(self.encoder.out_dim, self.encoder.out_dim, tau=self.tau) 127 | # for ntype in self.cl_ntypes}) 128 | self.proj_heads = nn.ModuleDict({ntype: ProjectionHead(self.encoder.out_dim, self.encoder.out_dim) for 129 | ntype in self.cl_ntypes}) 130 | print('cur cl mode:', cl_type) 131 | print('weights:', self.cl_weight, self.cross_weight) 132 | 133 | def load_graphs(self, graphs): 134 | dealt_graphs = self.encoder.load_graphs(graphs) 135 | dealt_graphs = self.encoder.aug_graphs(dealt_graphs, self.aug_configs, self.aug_methods) 136 | df_hn = pd.read_csv('./data/{}/hard_negative.csv'.format(graphs['data_source']), index_col=0) 137 | paper_trans = graphs['node_trans']['paper'] 138 | df_hn.index = [paper_trans[str(paper)] for paper in df_hn.index] 139 | for column in df_hn: 140 | df_hn[column] = df_hn[column].apply(lambda x: [paper_trans[str(paper)] for paper in eval(x)]) 141 | self.df_hn = df_hn 142 | return dealt_graphs 143 | 144 | def graph_to_device(self, graphs): 145 | return self.encoder.graph_to_device(graphs) 146 | 147 | def deal_graphs(self, graphs, data_path=None, path=None, log_path=None): 148 | return self.encoder.deal_graphs(graphs, data_path, path, log_path) 149 | 150 | def get_aux_pred(self, pred): 151 | return self.encoder.get_aux_pred(pred) 152 | 153 | def get_new_gt(self, gt): 154 | return self.encoder.get_new_gt(gt) 155 | 156 | def get_aux_loss(self, pred, gt): 157 | return self.encoder.get_aux_loss(pred, gt) 158 | 159 | def get_cl_inputs(self, snapshots, require_grad=True): 160 | if require_grad: 161 | if type(snapshots) == list: 162 | for snapshot in snapshots: 163 | for ntype in self.cl_ntypes: 164 | snapshot.nodes[ntype].data['h'] = self.proj_heads[ntype](snapshot.nodes[ntype].data['h']) 165 | else: 166 | for ntype in self.cl_ntypes: 167 | snapshots.nodes[ntype].data['h'] = self.proj_heads[ntype](snapshots.nodes[ntype].data['h']) 168 | else: 169 | if type(snapshots) == list: 170 | for snapshot in 
snapshots: 171 | for ntype in self.cl_ntypes: 172 | snapshot.nodes[ntype].data['h'] = self.proj_heads[ntype](snapshot.nodes[ntype].data['h']).detach() 173 | else: 174 | for ntype in self.cl_ntypes: 175 | snapshots.nodes[ntype].data['h'] = self.proj_heads[ntype](snapshots.nodes[ntype].data['h']).detach() 176 | return snapshots 177 | 178 | def cl(self, s1, s2, ids): 179 | # cl_loss_list = [] 180 | cl_loss = 0 181 | z1_snapshots, z2_snapshots = [], [] 182 | z1_ids, z2_ids = [], [] 183 | if self.cl_type in ['simple', 'focus', 'cross_snapshot', 'focus_cross_snapshot']: 184 | for ntype in self.cl_ntypes: 185 | cur_loss = 0 186 | for t in range(self.encoder.time_length): 187 | if 'HTSGCN' in self.encoder.model_name: 188 | if ntype == 'snapshot': 189 | s1_index = torch.where(s1[t].nodes[ntype].data['mask'] > 0) 190 | s2_index = torch.where(s2[t].nodes[ntype].data['mask'] > 0) 191 | z1 = s1[t].nodes[ntype].data['h'][s1_index] 192 | z2 = s2[t].nodes[ntype].data['h'][s2_index] 193 | else: 194 | z1 = s1[t].nodes[ntype].data['h'] 195 | z2 = s2[t].nodes[ntype].data['h'] 196 | elif 'CTSGCN' in self.encoder.model_name: 197 | s1_index = torch.where(s1.nodes[ntype].data['snapshot_idx'] == t) 198 | s2_index = torch.where(s2.nodes[ntype].data['snapshot_idx'] == t) 199 | # print(s1_index) 200 | z1 = s1.nodes[ntype].data['h'][s1_index] 201 | z2 = s2.nodes[ntype].data['h'][s2_index] 202 | # print(z1.shape) 203 | if (self.cl_type.endswith('cross_snapshot')) and (ntype == 'snapshot'): 204 | z1_snapshots.append(z1) 205 | z2_snapshots.append(z2) 206 | z1_ids.append(s1.nodes[ntype].data['target'][s1_index]) 207 | z2_ids.append(s2.nodes[ntype].data['target'][s2_index]) 208 | # cur_list.append(self.cl_modules[ntype](z1, z2)) 209 | # temp_loss = self.cl_modules[ntype](z1, z2) 210 | temp_loss = simple_cl(z1, z2, self.tau) 211 | cur_loss = cur_loss + temp_loss 212 | # print(self.cl_modules[ntype](z1, z2)) 213 | # if torch.isnan(temp_loss): 214 | # print(z1) 215 | # print(torch.isnan(z1).sum()) 216 | # print(z2) 217 | # print(torch.isnan(z2).sum()) 218 | # cl_loss_list.append() 219 | cl_loss = cl_loss + cur_loss / self.encoder.time_length 220 | cl_loss = cl_loss / len(self.cl_ntypes) 221 | elif self.cl_type.endswith('hard_negative'): 222 | for ntype in self.cl_ntypes: 223 | cur_loss = 0 224 | p1_index = {} 225 | p2_index = {} 226 | for paper in ids: 227 | p1_index[paper.item()] = s1.nodes[ntype].data['target'] == paper.item() 228 | p2_index[paper.item()] = s2.nodes[ntype].data['target'] == paper.item() 229 | # print(p1_index.keys()) 230 | for t in range(self.encoder.time_length): 231 | # start_time = time.time() 232 | cur_z1s = [] 233 | cur_z2s = [] 234 | len_z1s = [] 235 | # len_z2s = [] 236 | t1_index = s1.nodes[ntype].data['snapshot_idx'] == t 237 | t2_index = s2.nodes[ntype].data['snapshot_idx'] == t 238 | for paper in ids: 239 | s1_index = torch.where(t1_index & p1_index[paper.item()]) 240 | s2_index = torch.where(t2_index & p2_index[paper.item()]) 241 | z1 = s1.nodes[ntype].data['h'][s1_index] 242 | z2 = s2.nodes[ntype].data['h'][s2_index] 243 | cur_z1s.append(z1) 244 | cur_z2s.append(z2) 245 | cur_len = len(z1) if len(z1) > 0 else 1 246 | len_z1s.append(cur_len) 247 | # len_z2s.append(len(z2)) 248 | cur_z1s = pad_sequence(cur_z1s, batch_first=True) 249 | cur_z2s = pad_sequence(cur_z2s, batch_first=True) 250 | len_z1s = torch.tensor(len_z1s).to(cur_z1s.device) 251 | # len_z2s = torch.tensor(len_z2s).to(cur_z2s.device) 252 | # print('index_time: {:.3f}'.format(time.time() - start_time)) 253 | temp_loss = 
simple_cl(cur_z1s, cur_z2s, self.tau, valid_lens=len_z1s) 254 | # print('loss_time: {:.3f}'.format(time.time() - start_time)) 255 | cur_loss = cur_loss + temp_loss 256 | # print(ntype, cur_loss / self.encoder.time_length) 257 | cl_loss = cl_loss + cur_loss / self.encoder.time_length 258 | cl_loss = cl_loss / len(self.cl_ntypes) 259 | 260 | if self.cl_type.endswith('cross_snapshot'): 261 | snapshot_loss = 0 262 | z1_dict = {paper.item(): [] for paper in ids} 263 | z2_dict = {paper.item(): [] for paper in ids} 264 | # print(z1_dict) 265 | for i in range(len(z1_snapshots)): 266 | cur_z1 = z1_snapshots[i] 267 | cur_z2 = z2_snapshots[i] 268 | cur_z1_ids = z1_ids[i] 269 | cur_z2_ids = z2_ids[i] 270 | for j in range(len(cur_z1_ids)): 271 | z1_dict[cur_z1_ids[j].item()].append(cur_z1[j]) 272 | z2_dict[cur_z2_ids[j].item()].append(cur_z2[j]) 273 | for paper in ids: 274 | # print(torch.stack(z1_dict[paper.item()], dim=0).shape) 275 | snapshot_loss = snapshot_loss + simple_cl(torch.stack(z1_dict[paper.item()], dim=0), 276 | torch.stack(z2_dict[paper.item()], dim=0), 277 | self.tau) 278 | 279 | cl_loss = cl_loss + self.cross_weight * snapshot_loss 280 | # print(cl_loss) 281 | return cl_loss 282 | 283 | def forward(self, content, lengths, masks, ids, graph, times, **kwargs): 284 | # cur_time = time.time() 285 | # g1_snapshots, g2_snapshots, cur_snapshot_adj = self.get_graphs(ids, graph) 286 | # g1_inputs = self.encoder.get_graphs(ids, graph, aug=(self.aug_methods[0], self.aug_configs[0])) 287 | # g2_inputs = self.encoder.get_graphs(ids, graph, aug=(self.aug_methods[1], self.aug_configs[1])) 288 | if self.cl_type.endswith('aug_hard_negative'): 289 | g1_inputs, g2_inputs = self.encoder.get_graphs(content, lengths, masks, ids, graph, times, phase='train', 290 | df_hn=self.df_hn) 291 | else: 292 | g1_inputs, g2_inputs = self.encoder.get_graphs(content, lengths, masks, ids, graph, times, phase='train') 293 | # logging.info('aug done:{:.3f}s'.format(time.time() - cur_time)) 294 | 295 | # cur_time = time.time() 296 | # g1_out, g1_snapshots = self.encode(g1_inputs) 297 | # g2_out, g2_snapshots = self.encode(g2_inputs) 298 | g1_out, g1_snapshots = self.encoder.encode(g1_inputs) 299 | g2_out, g2_snapshots = self.encoder.encode(g2_inputs) 300 | # logging.info('encode done:{:.3f}s'.format(time.time() - cur_time)) 301 | 302 | # cur_time = time.time() 303 | g1_snapshots, g2_snapshots = self.get_cl_inputs(g1_snapshots), self.get_cl_inputs(g2_snapshots) 304 | self.other_loss = self.cl_weight * self.cl(g1_snapshots, g2_snapshots, ids) 305 | # logging.info('cl done:{:.3f}s'.format(time.time() - cur_time)) 306 | 307 | time_out = (g1_out + g2_out) / 2 308 | output = self.encoder.decode(time_out) 309 | # print(time_out.shape) 310 | # logging.info('=' * 20 + 'finish' + '=' * 20) 311 | return output 312 | 313 | def predict(self, content, lengths, masks, ids, graph, times, **kwargs): 314 | # print(final_out.shape) 315 | inputs = self.encoder.get_graphs(content, lengths, masks, ids, graph, times) 316 | # print(inputs) 317 | time_out, _ = self.encoder.encode(inputs) 318 | output = self.encoder.decode(time_out) 319 | print(output[0].shape) 320 | return output 321 | 322 | def show(self, content, lengths, masks, ids, graph, times, **kwargs): 323 | output = self.encoder.show(content, lengths, masks, ids, graph, times, **kwargs) 324 | return output 325 | 326 | 327 | def get_unbatched_edge_graphs(batched_graph, ntype, etype): 328 | subgraph = dgl.edge_type_subgraph(batched_graph, etypes=[etype]) 329 | 
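    # Editorial sketch (not in the original source): the two set_batch_* calls below copy the
    # per-sample node/edge counts from the parent batched graph onto the relation subgraph so
    # that dgl.unbatch can split it back into one graph per sample; roughly (hypothetical names):
    #   bg = dgl.batch([g1, g2])
    #   sub = dgl.edge_type_subgraph(bg, etypes=['cites'])
    #   sub.set_batch_num_nodes(bg.batch_num_nodes('paper'))
    #   sub.set_batch_num_edges(bg.batch_num_edges('cites'))
    #   parts = dgl.unbatch(sub)   # -> one 'cites' subgraph per original sample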
subgraph.set_batch_num_nodes(batched_graph.batch_num_nodes(ntype)) 330 | # print(batched_graph.batch_num_edges('cites')) 331 | subgraph.set_batch_num_edges(batched_graph.batch_num_edges(etype)) 332 | # print(subgraph) 333 | # print(dgl.unbatch(subgraph)) 334 | return dgl.unbatch(subgraph) 335 | -------------------------------------------------------------------------------- /our_models/DCTSGCN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import dgl 3 | 4 | from layers.GCN.DTGCN_layers import DCTSGCNLayer, SimpleDCTSGCNLayer 5 | from our_models.CTSGCN import CTSGCN 6 | 7 | 8 | class DCTSGCN(CTSGCN): 9 | def __init__(self, vocab_size, embed_dim, num_classes, pad_index, word2vec=None, dropout=0.5, n_layers=3, 10 | pred_type='paper', time_pooling='mean', hop=2, linear_before_gcn=False, encoder_type='GCN', 11 | time_length=10, hidden_dim=300, out_dim=256, ntypes=None, etypes=None, edge_sequences=None, 12 | graph_type=None, gcn_out='last', snapshot_weight=True, hn=None, hn_method=None, max_time_length=5, 13 | **kwargs): 14 | super(DCTSGCN, self).__init__(vocab_size, embed_dim, num_classes, pad_index, word2vec, dropout, n_layers, 15 | pred_type, time_pooling, hop, linear_before_gcn, 'GCN', 16 | time_length, hidden_dim, out_dim, ntypes, etypes, graph_type, gcn_out, hn, 17 | hn_method, max_time_length, 18 | **kwargs) 19 | self.model_name = self.model_name.replace('CTSGCN', 'DCTSGCN').replace('_GCN', '_' + encoder_type) + \ 20 | '_{}'.format(edge_sequences) 21 | self.model_type = 'DCTSGCN' 22 | if snapshot_weight: 23 | self.model_name += '_weighted' 24 | if edge_sequences == 'meta': 25 | edge_sequences = [[('meta', [('author', 'writes', 'paper'), 26 | ('journal', 'publishes', 'paper'), 27 | ('time', 'shows', 'paper'), 28 | ('paper', 'is writen by', 'author'), 29 | ('paper', 'is published by', 'journal'), 30 | ('paper', 'is shown in', 'time')]), 31 | ('inter_snapshots', [('paper', 'cites', 'paper'), 32 | ('paper', 'is in', 'snapshot'), 33 | ('paper', 'is cited by', 'paper'), 34 | ('snapshot', 'has', 'paper')]), 35 | ('intra_snapshots', [('snapshot', 't_cites', 'snapshot'), 36 | ('snapshot', 'is t_cited by', 'snapshot')])]] 37 | elif edge_sequences == 'no_gcn': 38 | edge_sequences = [[('inter_snapshots', [('author', 'writes', 'paper'), 39 | ('journal', 'publishes', 'paper'), 40 | ('time', 'shows', 'paper'), 41 | ('paper', 'cites', 'paper'), 42 | ('paper', 'is in', 'snapshot'), 43 | ('paper', 'is cited by', 'paper'), 44 | ('paper', 'is writen by', 'author'), 45 | ('paper', 'is published by', 'journal'), 46 | ('paper', 'is shown in', 'time'), 47 | ('snapshot', 'has', 'paper')])]] 48 | else: 49 | # basic 50 | edge_sequences = [[('inter_snapshots', [('author', 'writes', 'paper'), 51 | ('journal', 'publishes', 'paper'), 52 | ('time', 'shows', 'paper'), 53 | ('paper', 'cites', 'paper'), 54 | ('paper', 'is in', 'snapshot'), 55 | ('paper', 'is cited by', 'paper'), 56 | ('paper', 'is writen by', 'author'), 57 | ('paper', 'is published by', 'journal'), 58 | ('paper', 'is shown in', 'time'), 59 | ('snapshot', 'has', 'paper')]), 60 | ('intra_snapshots', [('snapshot', 't_cites', 'snapshot'), 61 | ('snapshot', 'is t_cited by', 'snapshot')])]] 62 | etypes_set = set(etypes) 63 | for i in range(len(edge_sequences)): 64 | for j in range(len(edge_sequences[i])): 65 | # print('-'*30) 66 | cur_edges = edge_sequences[i][j][1].copy() 67 | for etype in cur_edges: 68 | # print(edge_sequences[i][j]) 69 | # print(etype) 70 | if etype[1] not in etypes_set: 71 | 
edge_sequences[i][j][1].remove(etype) 72 | print('removing', edge_sequences[i][j][0], etype) 73 | print(edge_sequences) 74 | combine_method = 'sum' 75 | # if len(edge_sequences) == 1: 76 | # self.model_name += '_single' 77 | self.graph_encoder = None 78 | # self.graph_encoder = DCTSGCNLayer(in_feat=embed_dim, hidden_feat=hidden_dim, out_feat=out_dim, ntypes=ntypes, 79 | # edge_sequences=edge_sequences, k=n_layers, time_length=time_length).to(self.device) 80 | 81 | encoder_dict = {} 82 | if encoder_type == 'GIN': 83 | # ablation 84 | encoder_dict = {('paper', 'is in', 'snapshot'): {'cur_type': 'GAT'}} 85 | elif encoder_type == 'DGINP' or encoder_type == 'CGIN+RGAT': 86 | print('>>>{}<<<'.format(encoder_type)) 87 | # CGIN+RGAT 88 | encoder_dict = {('paper', 'is in', 'snapshot'): {'cur_type': 'RGAT', 89 | 'snapshot_types': 3, 90 | 'node_attr_types': ['is_cite', 'is_ref', 'is_target'] 91 | }, 92 | ('paper', 'cites', 'paper'): {'cur_type': 'DGINP'}, 93 | # ('paper', 'is cited by', 'paper'): {'cur_type': 'DGIN'} 94 | } 95 | encoder_type = 'GIN' 96 | 97 | return_list = False if gcn_out == 'last' else True 98 | self.graph_encoder = SimpleDCTSGCNLayer(in_feat=self.embed_dim, hidden_feat=self.hidden_dim, out_feat=self.out_dim, 99 | ntypes=ntypes, 100 | edge_sequences=edge_sequences, k=n_layers, time_length=time_length, 101 | combine_method=combine_method, encoder_type=encoder_type, 102 | encoder_dict=encoder_dict, return_list=return_list, 103 | snapshot_weight=snapshot_weight, **kwargs).to(self.device) 104 | 105 | -------------------------------------------------------------------------------- /our_models/EncoderCL.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import pandas as pd 4 | import torch 5 | 6 | from models.base_model import BaseModel 7 | from our_models.CL import CLWrapper 8 | 9 | class EncoderCLWrapper(CLWrapper): 10 | def __init__(self, encoder: BaseModel, cl_type='simple'): 11 | super(EncoderCLWrapper, self).__init__(encoder, cl_type, None) 12 | # self.aux_criterion = encoder.aux_criterion 13 | self.vice_encoder = copy.deepcopy(self.encoder) 14 | self.eta = 1. 
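        # Editorial sketch (not in the original source): eta scales the random weight noise that
        # get_aug_encoder() adds to vice_encoder, so the second contrastive view comes from a
        # perturbed copy of the encoder rather than from an augmented graph; roughly:
        #   adv_param.data = param.data + eta * torch.normal(0, torch.ones_like(param.data) * param.data.std())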
15 | self.model_name = self.encoder.model_name + '_ECL_' + cl_type 16 | 17 | def load_graphs(self, graphs): 18 | # no graph aug 19 | dealt_graphs = self.encoder.load_graphs(graphs) 20 | df_hn = pd.read_csv('./data/{}/hard_negative.csv'.format(graphs['data_source']), index_col=0) 21 | paper_trans = graphs['node_trans']['paper'] 22 | df_hn.index = [paper_trans[str(paper)] for paper in df_hn.index] 23 | for column in df_hn: 24 | df_hn[column] = df_hn[column].apply(lambda x: [paper_trans[str(paper)] for paper in eval(x)]) 25 | self.df_hn = df_hn 26 | return dealt_graphs 27 | 28 | def get_aug_encoder(self): 29 | for (adv_name, adv_param), (name, param) in zip(self.vice_encoder.named_parameters(), 30 | self.encoder.named_parameters()): 31 | if (not name.startswith('fc')) | ('text' not in name): 32 | # print(name, param.data.std()) 33 | adv_param.data = param.data.detach() + self.eta * torch.normal(0, torch.ones_like( 34 | param.data) * param.data.std()).detach().to(self.device) 35 | 36 | def forward(self, content, lengths, masks, ids, graph, times, **kwargs): 37 | # cur_inputs = self.encoder.get_graphs(ids, times, graph) 38 | if self.cl_type == 'aug_hard_negative': 39 | cur_inputs = self.encoder.get_graphs(content, lengths, masks, ids, graph, times, df_hn=self.df_hn) 40 | else: 41 | cur_inputs = self.encoder.get_graphs(content, lengths, masks, ids, graph, times) 42 | aug_inputs = copy.deepcopy(cur_inputs) 43 | self.get_aug_encoder() 44 | g1_out, g1_snapshots = self.encoder.encode(cur_inputs) 45 | g2_out, g2_snapshots = self.vice_encoder.encode(aug_inputs) 46 | # logging.info('encode done:{:.3f}s'.format(time.time() - cur_time)) 47 | 48 | # cur_time = time.time() 49 | g1_snapshots = self.get_cl_inputs(g1_snapshots) 50 | g2_snapshots = self.get_cl_inputs(g2_snapshots, require_grad=False) 51 | self.other_loss = self.cl_weight * self.cl(g1_snapshots, g2_snapshots, ids) 52 | # logging.info('cl done:{:.3f}s'.format(time.time() - cur_time)) 53 | 54 | time_out = g1_out 55 | output = self.encoder.decode(time_out) 56 | # print(time_out.shape) 57 | # logging.info('=' * 20 + 'finish' + '=' * 20) 58 | return output 59 | 60 | -------------------------------------------------------------------------------- /our_models/FCL.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import dgl 4 | import torch 5 | 6 | from layers.GCN.CL_layers import simple_cl 7 | from models.base_model import BaseModel 8 | from our_models.HCL import HCLWrapper 9 | 10 | class FCLWrapper(HCLWrapper): 11 | def __init__(self, encoder: BaseModel, cl_type='simple', aug_type=None, tau=0.4, cl_weight=0.5, aug_rate=0.1, **kwargs): 12 | super(FCLWrapper, self).__init__(encoder, cl_type, aug_type, tau, cl_weight, aug_rate, **kwargs) 13 | self.model_name = self.model_name.replace('_HCL_', '_FCL_') 14 | if 'DCTSGCN' in self.model_name: 15 | self.model_name = 'H2CGL' 16 | 17 | def get_cl_inputs(self, snapshots, require_grad=True): 18 | if require_grad: 19 | readout = dgl.readout_nodes(snapshots, 'h', op=self.encoder.time_pooling, ntype='snapshot') 20 | else: 21 | readout = dgl.readout_nodes(snapshots, 'h', op=self.encoder.time_pooling, ntype='snapshot').detach() 22 | readout = self.proj_heads['snapshot'](readout) 23 | # print(readout.shape) 24 | return readout 25 | 26 | def cl(self, s1, s2, ids): 27 | # cl_loss_list = [] 28 | cl_loss = 0 29 | if self.cl_type in ['simple', 'focus', 'cross_snapshot', 'focus_cross_snapshot']: 30 | temp_loss = simple_cl(s1, s2, self.tau) 31 | cl_loss = cl_loss + 
temp_loss 32 | elif self.cl_type.endswith('hard_negative'): 33 | sample_num = self.encoder.hn 34 | true_len = s1.shape[0] // sample_num if sample_num else s1.shape[0] 35 | temp_index_list = [] 36 | for i in range(sample_num): 37 | temp_index_list.append(range(true_len * i, true_len * (i + 1))) 38 | # cur_index = list(zip(range(true_len), range(true_len, true_len * 2))) 39 | cur_index = list(zip(*temp_index_list)) 40 | s1 = s1[cur_index, :] 41 | s2 = s2[cur_index, :] 42 | # print(s1.shape) 43 | temp_loss = simple_cl(s1, s2, self.tau) 44 | cl_loss = cl_loss + temp_loss 45 | return cl_loss 46 | 47 | def forward(self, content, lengths, masks, ids, graph, times, **kwargs): 48 | # cur_time = time.time() 49 | # g1_snapshots, g2_snapshots, cur_snapshot_adj = self.get_graphs(ids, graph) 50 | # g1_inputs = self.encoder.get_graphs(ids, graph, aug=(self.aug_methods[0], self.aug_configs[0])) 51 | # g2_inputs = self.encoder.get_graphs(ids, graph, aug=(self.aug_methods[1], self.aug_configs[1])) 52 | if self.cl_type.endswith('aug_hard_negative'): 53 | all_inputs = self.encoder.get_graphs(content, lengths, masks, ids, graph, times, 54 | phase='train_combined', df_hn=self.df_hn) 55 | else: 56 | all_inputs = self.encoder.get_graphs(content, lengths, masks, ids, graph, times, 57 | phase='train_combined') 58 | # print('get graphs done:{:.3f}s'.format(time.time() - cur_time)) 59 | # cur_time = time.time() 60 | all_out, all_snapshots = self.encoder.encode(all_inputs) 61 | # print('encode done:{:.3f}s'.format(time.time() - cur_time)) 62 | all_snapshots = self.get_cl_inputs(all_snapshots) 63 | 64 | batch_size = all_out.shape[0] // 2 65 | # print(batch_size, all_out.shape[0]) 66 | g1_out, g2_out = all_out[:batch_size, :], all_out[batch_size:, :] 67 | g1_snapshots, g2_snapshots = all_snapshots[:batch_size, :], all_snapshots[batch_size:, :] 68 | 69 | # cur_time = time.time() 70 | self.other_loss = self.cl_weight * self.cl(g1_snapshots, g2_snapshots, ids) 71 | # logging.info('cl done:{:.3f}s'.format(time.time() - cur_time)) 72 | # print('cl done:{:.3f}s'.format(time.time() - cur_time)) 73 | 74 | time_out = (g1_out + g2_out) / 2 75 | output = self.encoder.decode(time_out) 76 | # print(time_out.shape) 77 | # logging.info('=' * 20 + 'finish' + '=' * 20) 78 | return output 79 | -------------------------------------------------------------------------------- /our_models/HCL.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn.utils.rnn import pad_sequence 5 | 6 | from layers.GCN.CL_layers import simple_cl, ProjectionHead 7 | from models.base_model import BaseModel 8 | from our_models.CL import CLWrapper 9 | 10 | class HCLWrapper(CLWrapper): 11 | def __init__(self, encoder: BaseModel, cl_type='simple', aug_type=None, tau=0.4, cl_weight=0.5, aug_rate=0.1, **kwargs): 12 | super(HCLWrapper, self).__init__(encoder, cl_type, 'normal', tau, cl_weight, aug_rate, **kwargs) 13 | self.model_name = self.model_name.replace('_CL_', '_HCL_').replace('_normal', '_' + aug_type) 14 | if self.cl_type in ['simple', 'cross_snapshot']: 15 | print('not support contrast for each type') 16 | raise Exception 17 | self.cl_ntypes = ['snapshot'] 18 | self.proj_heads = nn.ModuleDict({ntype: ProjectionHead(self.encoder.out_dim, self.encoder.out_dim) for 19 | ntype in self.cl_ntypes}) 20 | print('cur cl mode:', cl_type) 21 | # print('weights:', self.cl_weight, self.cross_weight) 22 | self.aug_type = aug_type 23 | 24 | if aug_type == 'dd': 
25 | self.aug_methods = (('attr_drop', 'edge_drop'), ('attr_drop', 'edge_drop')) 26 | self.aug_configs = ({ 27 | 'attr_drop': { 28 | 'prob': self.aug_rate, 29 | 'types': ['paper', 'author', 'journal', 'time'] 30 | }, 31 | 'edge_drop': { 32 | 'prob': self.aug_rate, 33 | 'types': [('paper', 'cites', 'paper'), 34 | ('paper', 'is cited by', 'paper'), 35 | ('author', 'writes', 'paper'), 36 | ('paper', 'is writen by', 'author'), 37 | ('journal', 'publishes', 'paper'), 38 | ('paper', 'is published by', 'journal'), 39 | ('time', 'shows', 'paper'), 40 | ('paper', 'is shown in', 'time')] 41 | } 42 | }, { 43 | 'edge_drop': { 44 | 'prob': self.aug_rate, 45 | 'types': [('paper', 'cites', 'paper'), 46 | ('paper', 'is cited by', 'paper'), 47 | ('author', 'writes', 'paper'), 48 | ('paper', 'is writen by', 'author'), 49 | ('journal', 'publishes', 'paper'), 50 | ('paper', 'is published by', 'journal'), 51 | ('time', 'shows', 'paper'), 52 | ('paper', 'is shown in', 'time')] 53 | }, 54 | 'attr_drop': { 55 | 'prob': self.aug_rate, 56 | 'types': ['paper', 'author', 'journal', 'time'] 57 | }, 58 | }) 59 | elif aug_type == 'dmd': 60 | self.aug_methods = (('attr_mask', 'edge_drop'), ('attr_mask', 'edge_drop')) 61 | self.aug_configs = ({ 62 | 'attr_mask': { 63 | 'prob': self.aug_rate, 64 | 'types': ['paper', 'author', 'journal', 'time'] 65 | }, 66 | 'edge_drop': { 67 | 'prob': self.aug_rate, 68 | 'types': [('paper', 'cites', 'paper'), 69 | ('paper', 'is cited by', 'paper'), 70 | ('author', 'writes', 'paper'), 71 | ('paper', 'is writen by', 'author'), 72 | ('journal', 'publishes', 'paper'), 73 | ('paper', 'is published by', 'journal'), 74 | ('time', 'shows', 'paper'), 75 | ('paper', 'is shown in', 'time')] 76 | } 77 | }, { 78 | 'edge_drop': { 79 | 'prob': self.aug_rate, 80 | 'types': [('paper', 'cites', 'paper'), 81 | ('paper', 'is cited by', 'paper'), 82 | ('author', 'writes', 'paper'), 83 | ('paper', 'is writen by', 'author'), 84 | ('journal', 'publishes', 'paper'), 85 | ('paper', 'is published by', 'journal'), 86 | ('time', 'shows', 'paper'), 87 | ('paper', 'is shown in', 'time')] 88 | }, 89 | 'attr_mask': { 90 | 'prob': self.aug_rate, 91 | 'types': ['paper', 'author', 'journal', 'time'] 92 | }, 93 | }) 94 | elif aug_type == 'cg': 95 | self.aug_methods = (('node_drop',), ('node_drop',)) 96 | self.aug_configs = ({ 97 | 'node_drop': { 98 | 'prob': self.aug_rate, 99 | 'types': [] 100 | }}, { 101 | 'node_drop': { 102 | 'prob': self.aug_rate, 103 | 'types': [] 104 | }}) 105 | else: 106 | aug_type = 'normal' 107 | self.aug_methods = (('attr_drop',), ('edge_drop',)) 108 | self.aug_configs = ({ 109 | 'attr_drop': { 110 | 'prob': self.aug_rate, 111 | 'types': ['paper', 'author', 'journal', 'time'] 112 | } 113 | }, { 114 | 'edge_drop': { 115 | 'prob': self.aug_rate, 116 | 'types': [('paper', 'cites', 'paper'), 117 | ('paper', 'is cited by', 'paper'), 118 | ('author', 'writes', 'paper'), 119 | ('paper', 'is writen by', 'author'), 120 | ('journal', 'publishes', 'paper'), 121 | ('paper', 'is published by', 'journal'), 122 | ('time', 'shows', 'paper'), 123 | ('paper', 'is shown in', 'time')] 124 | } 125 | }) 126 | print(self.aug_methods) 127 | 128 | def load_graphs(self, graphs): 129 | dealt_graphs = self.encoder.load_graphs(graphs) 130 | if self.aug_type != 'cg': 131 | dealt_graphs = self.encoder.aug_graphs_plus(dealt_graphs, self.aug_configs, self.aug_methods) 132 | paper_trans = graphs['node_trans']['paper'] 133 | if 'label' in self.cl_type: 134 | df_hn = 
pd.read_csv('./data/{}/hard_negative_labeled.csv'.format(graphs['data_source']), index_col=0) 135 | df_hn.index = [paper_trans[str(paper)] for paper in df_hn.index] 136 | self.encoder.re_group = {label: list(df_hn[df_hn['label'] != label].index) for label in [0, 1, 2]} 137 | else: 138 | df_hn = pd.read_csv('./data/{}/hard_negative.csv'.format(graphs['data_source']), index_col=0) 139 | df_hn.index = [paper_trans[str(paper)] for paper in df_hn.index] 140 | for column in df_hn: 141 | if column in ['pub_time', 'label']: 142 | df_hn[column] = df_hn[column].astype(int) 143 | else: 144 | df_hn[column] = df_hn[column].apply(lambda x: [paper_trans[str(paper)] for paper in eval(x)]) 145 | # df_hn[column] = df_hn[column].astype('object') 146 | self.df_hn = df_hn 147 | return dealt_graphs 148 | 149 | def cl(self, s1, s2, ids): 150 | # cl_loss_list = [] 151 | cl_loss = 0 152 | z1_snapshots, z2_snapshots = [], [] 153 | z1_ids, z2_ids = [], [] 154 | if self.cl_type in ['simple', 'focus', 'cross_snapshot', 'focus_cross_snapshot']: 155 | for ntype in self.cl_ntypes: 156 | cur_loss = 0 157 | for t in range(self.encoder.time_length): 158 | if 'HTSGCN' in self.encoder.model_name: 159 | if ntype == 'snapshot': 160 | s1_index = torch.where(s1[t].nodes[ntype].data['mask'] > 0) 161 | s2_index = torch.where(s2[t].nodes[ntype].data['mask'] > 0) 162 | z1 = s1[t].nodes[ntype].data['h'][s1_index] 163 | z2 = s2[t].nodes[ntype].data['h'][s2_index] 164 | else: 165 | z1 = s1[t].nodes[ntype].data['h'] 166 | z2 = s2[t].nodes[ntype].data['h'] 167 | elif 'CTSGCN' in self.encoder.model_name: 168 | s1_index = torch.where(s1.nodes[ntype].data['snapshot_idx'] == t) 169 | s2_index = torch.where(s2.nodes[ntype].data['snapshot_idx'] == t) 170 | # print(s1_index) 171 | z1 = s1.nodes[ntype].data['h'][s1_index] 172 | z2 = s2.nodes[ntype].data['h'][s2_index] 173 | # print(z1.shape) 174 | if (self.cl_type.endswith('cross_snapshot')) and (ntype == 'snapshot'): 175 | z1_snapshots.append(z1) 176 | z2_snapshots.append(z2) 177 | z1_ids.append(s1.nodes[ntype].data['target'][s1_index]) 178 | z2_ids.append(s2.nodes[ntype].data['target'][s2_index]) 179 | # cur_list.append(self.cl_modules[ntype](z1, z2)) 180 | # temp_loss = self.cl_modules[ntype](z1, z2) 181 | temp_loss = simple_cl(z1, z2, self.tau) 182 | cur_loss = cur_loss + temp_loss 183 | 184 | cl_loss = cl_loss + cur_loss / self.encoder.time_length 185 | cl_loss = cl_loss / len(self.cl_ntypes) 186 | elif self.cl_type.endswith('hard_negative'): 187 | for ntype in self.cl_ntypes: 188 | cur_loss = 0 189 | p1_index = {} 190 | p2_index = {} 191 | for paper in ids: 192 | p1_index[paper.item()] = s1.nodes[ntype].data['target'] == paper.item() 193 | p2_index[paper.item()] = s2.nodes[ntype].data['target'] == paper.item() 194 | # print(p1_index.keys()) 195 | time_length = 0 196 | for t in range(self.encoder.time_length): 197 | # start_time = time.time() 198 | cur_z1s = [] 199 | cur_z2s = [] 200 | len_z1s = [] 201 | # len_z2s = [] 202 | t1_index = s1.nodes[ntype].data['snapshot_idx'] == t 203 | t2_index = s2.nodes[ntype].data['snapshot_idx'] == t 204 | for paper in ids: 205 | # s1_index = torch.where(t1_index & p1_index[paper.item()]) 206 | # s2_index = torch.where(t2_index & p2_index[paper.item()]) 207 | # z1 = s1.nodes[ntype].data['h'][s1_index] 208 | # z2 = s2.nodes[ntype].data['h'][s2_index] 209 | # cur_z1s.append(z1) 210 | # cur_z2s.append(z2) 211 | # cur_len = len(z1) if len(z1) > 0 else 1 212 | # len_z1s.append(cur_len) 213 | # print(s1_index[0].shape[0], cur_len) 214 | s1_index = 
torch.where(t1_index & p1_index[paper.item()]) 215 | s2_index = torch.where(t2_index & p2_index[paper.item()]) 216 | if s1_index[0].shape[0] >= 2: 217 | # if s1_index[0].shape[0] >= self.encoder.hn: 218 | z1 = s1.nodes[ntype].data['h'][s1_index] 219 | z2 = s2.nodes[ntype].data['h'][s2_index] 220 | cur_z1s.append(z1) 221 | cur_z2s.append(z2) 222 | cur_len = len(z1) 223 | len_z1s.append(cur_len) 224 | # print(s1_index[0].shape[0], cur_len) 225 | # len_z2s.append(len(z2)) 226 | # if s1_index[0].shape[0] > 2: 227 | # print('sidx:', (s1.nodes[ntype].data['snapshot_idx'] == 9).sum()) 228 | # print(ntype) 229 | # print(torch.where(p1_index[paper.item()])) 230 | # print(p1_index[paper.item()].sum()) 231 | # print(s1_index) 232 | if len(cur_z1s) > 0: 233 | cur_z1s = pad_sequence(cur_z1s, batch_first=True) 234 | cur_z2s = pad_sequence(cur_z2s, batch_first=True) 235 | len_z1s = torch.tensor(len_z1s).to(cur_z1s.device) 236 | # len_z2s = torch.tensor(len_z2s).to(cur_z2s.device) 237 | # print('index_time: {:.3f}'.format(time.time() - start_time)) 238 | # print(cur_z1s.shape) 239 | temp_loss = simple_cl(cur_z1s, cur_z2s, self.tau, valid_lens=len_z1s) 240 | time_length = time_length + 1 241 | else: 242 | temp_loss = 0 243 | # print(temp_loss) 244 | # print('loss_time: {:.3f}'.format(time.time() - start_time)) 245 | cur_loss = cur_loss + temp_loss 246 | # print(ntype, cur_loss / self.encoder.time_length) 247 | cl_loss = cl_loss + cur_loss / time_length 248 | cl_loss = cl_loss / len(self.cl_ntypes) 249 | 250 | if self.cl_type.endswith('cross_snapshot'): 251 | snapshot_loss = 0 252 | z1_dict = {paper.item(): [] for paper in ids} 253 | z2_dict = {paper.item(): [] for paper in ids} 254 | # print(z1_dict) 255 | for i in range(len(z1_snapshots)): 256 | cur_z1 = z1_snapshots[i] 257 | cur_z2 = z2_snapshots[i] 258 | cur_z1_ids = z1_ids[i] 259 | cur_z2_ids = z2_ids[i] 260 | for j in range(len(cur_z1_ids)): 261 | z1_dict[cur_z1_ids[j].item()].append(cur_z1[j]) 262 | z2_dict[cur_z2_ids[j].item()].append(cur_z2[j]) 263 | for paper in ids: 264 | # print(torch.stack(z1_dict[paper.item()], dim=0).shape) 265 | snapshot_loss = snapshot_loss + simple_cl(torch.stack(z1_dict[paper.item()], dim=0), 266 | torch.stack(z2_dict[paper.item()], dim=0), 267 | self.tau) 268 | 269 | cl_loss = cl_loss + self.cross_weight * snapshot_loss 270 | # print(cl_loss) 271 | return cl_loss -------------------------------------------------------------------------------- /our_models/TSGCN.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import json 4 | import logging 5 | import time 6 | from collections import Counter 7 | # import multiprocessing 8 | import joblib 9 | import pandas as pd 10 | from dgl import multiprocessing 11 | 12 | import dgl 13 | import torch 14 | from torch import nn 15 | from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence 16 | from tqdm import tqdm 17 | 18 | from layers.GCN.RGCN import StochasticKLayerRGCN 19 | from layers.common_layers import MLP 20 | from models.base_model import BaseModel 21 | from utilis.log_bar import TqdmToLogger 22 | 23 | 24 | class TM: 25 | def __init__(self, paper_ids, hop, log_path): 26 | self.paper_ids = paper_ids 27 | self.hop = hop 28 | # logger = logging.getLogger() 29 | # logger.handlers.append(logging.FileHandler(log_path, mode='w+')) 30 | # self.tqdm_log = TqdmToLogger(logger, level=logging.INFO) 31 | 32 | def get_subgraph(self, cur_graph): 33 | return 
get_paper_ego_subgraph(self.paper_ids, hop=self.hop, graph=cur_graph, batched=True) 34 | 35 | 36 | class TSGCN(BaseModel): 37 | '''The basic model which is also called H2GCN''' 38 | def __init__(self, vocab_size, embed_dim, num_classes, pad_index, word2vec=None, dropout=0.5, n_layers=3, 39 | time_encoder='gru', time_encoder_layers=2, time_pooling='mean', hop=2, 40 | time_length=10, hidden_dim=300, out_dim=256, ntypes=None, etypes=None, graph_type=None, **kwargs): 41 | super(TSGCN, self).__init__(vocab_size, embed_dim, num_classes, pad_index, word2vec, dropout, **kwargs) 42 | 43 | self.model_name = 'TSGCN' 44 | if graph_type: 45 | self.model_name += '_' + graph_type 46 | self.hop = hop 47 | self.activation = nn.Tanh() 48 | self.drop_en = nn.Dropout(dropout) 49 | self.n_layers = n_layers 50 | if self.n_layers != 2: 51 | self.model_name += '_n{}'.format(n_layers) 52 | self.time_length = time_length 53 | final_dim = out_dim 54 | print(ntypes) 55 | residual = False 56 | # print(residual) 57 | if residual: 58 | self.model_name += '_residual' 59 | self.graph_encoder = nn.ModuleList( 60 | [StochasticKLayerRGCN(embed_dim, hidden_dim, out_dim, ntypes, etypes, k=n_layers, residual=residual) 61 | for i in range(time_length)]) 62 | if time_encoder == 'gru': 63 | self.time_encoder = nn.GRU(input_size=out_dim, hidden_size=out_dim, dropout=dropout, batch_first=True, 64 | bidirectional=True, num_layers=2) 65 | final_dim = 2 * out_dim 66 | elif time_encoder == 'self-attention': 67 | self.time_encoder = nn.Sequential(*[ 68 | nn.TransformerEncoderLayer(d_model=out_dim, nhead=out_dim // 64, dim_feedforward=2 * out_dim) 69 | for i in range(time_encoder_layers)]) 70 | # final_dim = 2 * out_dim 71 | self.time_pooling = time_pooling 72 | # self.fc = MLP(dim_in=final_dim, dim_inner=final_dim // 2, dim_out=num_classes, num_layers=2) 73 | self.fc_reg = MLP(dim_in=final_dim, dim_inner=final_dim // 2, dim_out=1, num_layers=2) 74 | self.fc_cls = MLP(dim_in=final_dim, dim_inner=final_dim // 2, dim_out=num_classes, num_layers=2) 75 | 76 | 77 | def load_graphs(self, graphs): 78 | ''' 79 | { 80 | 'data': {phase: None}, 81 | 'node_trans': json.load(open(self.data_cat_path + 'sample_node_trans.json', 'r')), 82 | 'time': {'train': train_time, 'val': val_time, 'test': test_time}, 83 | 'all_graphs': graphs, 84 | 'selected_papers': {phase: [list]}, 85 | 'all_trans_index': dict, 86 | 'all_ego_graphs': ego_graphs 87 | } 88 | ''' 89 | print(graphs['data']) 90 | logger = logging.getLogger() 91 | # LOG = logging.getLogger(__name__) 92 | tqdm_out = TqdmToLogger(logger, level=logging.INFO) 93 | 94 | graphs['all_trans_index'] = json.load(open(graphs['graphs_path'] + '_trans.json', 'r')) 95 | # ego_graphs 96 | graphs['all_ego_graphs'] = {} 97 | for pub_time in tqdm(graphs['all_graphs'], file=tqdm_out, mininterval=30): 98 | graphs['all_ego_graphs'][pub_time] = joblib.load( 99 | graphs['graphs_path'] + '_{}.job'.format(pub_time)) 100 | 101 | phase_dict = {phase: [] for phase in graphs['data']} 102 | for phase in graphs['data']: 103 | selected_papers = [graphs['node_trans']['paper'][paper] for paper in graphs['selected_papers'][phase]] 104 | trans_index = dict(zip(selected_papers, range(len(selected_papers)))) 105 | graphs['selected_papers'][phase] = selected_papers 106 | graphs['data'][phase] = [None, trans_index] 107 | 108 | time2phase = {pub_time: [phase for phase in graphs['data'] 109 | if (pub_time <= graphs['time'][phase]) and (pub_time > graphs['time'][phase] - 10)] 110 | for pub_time in graphs['all_graphs']} 111 | for t in 
sorted(list(time2phase.keys())): 112 | cur_snapshot = graphs['all_ego_graphs'][t] 113 | start_time = time.time() 114 | for phase in time2phase[t]: 115 | batched_graphs = [cur_snapshot[graphs['all_trans_index'][str(paper)]] 116 | for paper in graphs['selected_papers'][phase]] 117 | cur_emb_graph = graphs['all_graphs'][t] 118 | # unbatched_graphs = dgl.unbatch(batched_graphs) 119 | unbatched_graphs = batched_graphs 120 | logging.info('unbatching done! {:.2f}s'.format(time.time() - start_time)) 121 | for graph in tqdm(unbatched_graphs, file=tqdm_out, mininterval=30): 122 | for ntype in graph.ntypes: 123 | cur_index = graph.nodes[ntype].data[dgl.NID] 124 | graph.nodes[ntype].data['h'] = cur_emb_graph.nodes[ntype].data['h'][cur_index].detach() 125 | phase_dict[phase].append(unbatched_graphs) 126 | 127 | logging.info('-' * 30 + str(t) + ' done' + '-' * 30) 128 | del cur_snapshot 129 | del graphs['all_graphs'][t] 130 | del graphs['all_ego_graphs'][t] 131 | 132 | for phase in phase_dict: 133 | graphs['data'][phase][0] = phase_dict[phase] 134 | 135 | del phase_dict 136 | del graphs['all_graphs'] 137 | del graphs['all_ego_graphs'] 138 | return graphs 139 | 140 | def deal_graphs(self, graphs, data_path=None, path=None, log_path=None): 141 | tqdm_out = None 142 | if log_path: 143 | logging.basicConfig(level=logging.INFO, 144 | filename=log_path, 145 | filemode='w+', 146 | format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s', 147 | force=True) 148 | logger = logging.getLogger() 149 | # LOG = logging.getLogger(__name__) 150 | tqdm_out = TqdmToLogger(logger, level=logging.INFO) 151 | 152 | phases = list(graphs['data'].keys()) 153 | # dealt_graphs = {'data': {}, 'time': graphs['time'], 'node_trans': graphs['node_trans'], 'all_graphs': graphs['all_graphs']} 154 | dealt_graphs = {'data': {}, 'time': graphs['time'], 'node_trans': graphs['node_trans'], 155 | 'all_graphs': None} 156 | split_data = torch.load(data_path) 157 | all_selected_papers = [graphs['node_trans']['paper'][paper] for paper in split_data['test'][0]] 158 | logging.info('all count: {}'.format(str(len(all_selected_papers)))) 159 | # all_snapshots = {} 160 | all_trans_index = dict(zip(all_selected_papers, range(len(all_selected_papers)))) 161 | pub_times = list(graphs['all_graphs'].keys()) 162 | # del graphs['all_graphs'] 163 | del graphs['data'] 164 | # print('not get ego_subgraph for saving time') 165 | for pub_time in pub_times: 166 | # all_snapshots[pub_time] = get_paper_ego_subgraph(all_selected_papers, self.hop, 167 | # graphs['all_graphs'][pub_time], 168 | # batched=False, tqdm_log=tqdm_out) 169 | cur_snapshot = get_paper_ego_subgraph(all_selected_papers, self.hop, 170 | graphs['all_graphs'][pub_time], 171 | batched=False, tqdm_log=tqdm_out) 172 | # if path: 173 | # torch.save(cur_snapshot, path + '_' + str(pub_time)) 174 | if path: 175 | joblib.dump(cur_snapshot, path + '_' + str(pub_time) + '.job') 176 | # del graphs['all_graphs'][pub_time] 177 | 178 | if path: 179 | json.dump(all_trans_index, open(path + '_trans.json', 'w+')) 180 | 181 | # all_trans_index = json.load() 182 | del graphs['all_graphs'] 183 | 184 | joblib.dump([], path) 185 | return dealt_graphs 186 | 187 | def forward(self, content, lengths, masks, ids, graph, times, **kwargs): 188 | # graphs T个batched_graph 189 | # start_time = time.time() 190 | # print(start_time) 191 | # snapshots = [batch.to(self.device) for batch in snapshots] 192 | # print(time.time() - start_time) 193 | batched_graphs_list, trans_index = graph 194 | # graphs_list = 
dgl.unbatch(batched_graphs) 195 | snapshots = [] 196 | for batched_graphs in batched_graphs_list: 197 | temp_snapshots = [] 198 | # temp_graphs = dgl.unbatch(batched_graphs) 199 | temp_graphs = batched_graphs 200 | for paper in ids: 201 | temp_snapshots.append(temp_graphs[trans_index[paper.item()]]) 202 | snapshots.append(dgl.batch(temp_snapshots).to(self.device)) 203 | # print(time.time() - start_time) 204 | 205 | all_snapshots = [] 206 | for t in range(self.time_length): 207 | cur_snapshot = snapshots[t] 208 | out_emb = self.graph_encoder[t](cur_snapshot, cur_snapshot.srcdata['h']) 209 | cur_snapshot.dstdata['h'] = out_emb 210 | sum_readout = dgl.readout_nodes(cur_snapshot, 'h', op='sum', ntype='paper') 211 | all_snapshots.append(sum_readout) 212 | all_snapshots = torch.stack(all_snapshots, dim=1) # [B, T, d_out] 213 | # print(all_snapshots.shape) 214 | time_out = self.time_encoder(all_snapshots) 215 | if type(time_out) != torch.Tensor: 216 | time_out = time_out[0] 217 | # print(time_out.shape) 218 | if self.time_pooling == 'mean': 219 | final_out = time_out.mean(dim=1) 220 | elif self.time_pooling == 'sum': 221 | final_out = time_out.sum(dim=1) 222 | elif self.time_pooling == 'max': 223 | final_out = time_out.max(dim=1)[0] 224 | # print(final_out.shape) 225 | output_reg = self.fc_reg(final_out) 226 | output_cls = self.fc_cls(final_out) 227 | return output_reg, output_cls 228 | 229 | 230 | def get_paper_ego_subgraph(batch_ids, hop, graph, batched=True, tqdm_log=None): 231 | ''' 232 | get ego batch_graphs of current time 233 | target-centric graphs extracted from complete graph 234 | :param batch_ids: 235 | :param hop: 236 | :param graph: 237 | :return: 238 | ''' 239 | oids = graph.nodes['paper'].data[dgl.NID].numpy().tolist() 240 | cids = graph.nodes('paper').numpy().tolist() 241 | oc_trans = dict(zip(oids, cids)) 242 | trans_papers = [oc_trans.get(paper, None) for paper in batch_ids] 243 | graph.nodes['paper'].data['citations'] = graph.in_degrees(etype='cites') 244 | 245 | paper_subgraph = dgl.node_type_subgraph(graph, ['paper']) 246 | all_subgraphs = [] 247 | single_subgraph = SingleSubgraph(paper_subgraph, graph, hop) 248 | for paper in tqdm(trans_papers, file=tqdm_log, mininterval=30): 249 | # for paper in tqdm(trans_papers): 250 | cur_graph = single_subgraph.get_single_subgraph(paper) 251 | for ntype in cur_graph.ntypes: 252 | if 'h' in cur_graph.nodes[ntype].data: 253 | del cur_graph.nodes[ntype].data['h'] 254 | all_subgraphs.append(cur_graph) 255 | # logging.info('cur year done!') 256 | if batched: 257 | return dgl.batch(all_subgraphs) 258 | else: 259 | return all_subgraphs 260 | 261 | 262 | class SingleSubgraph: 263 | def __init__(self, paper_subgraph, origin_graph, hop, citation='global'): 264 | # one-hop instead 265 | self.paper_subgraph = paper_subgraph 266 | # self.origin_graph = origin_graph 267 | self.origin_graph = dgl.node_type_subgraph(origin_graph, ntypes=['paper', 'journal', 'author']) 268 | self.hop = 1 269 | self.citation = citation 270 | self.max_nodes = 100 271 | # self.oids = self.paper_subgraph.nodes['paper'].data[dgl.NID] 272 | self.citations = self.paper_subgraph.nodes['paper'].data['citations'] 273 | self.time = self.paper_subgraph.nodes['paper'].data['time'] 274 | 275 | def get_single_subgraph(self, paper): 276 | if paper is not None: 277 | ref_hop_papers = self.paper_subgraph.in_edges(paper, etype='is cited by')[0].numpy().tolist() 278 | if len(ref_hop_papers) > self.max_nodes: 279 | temp_df = pd.DataFrame( 280 | data={ 281 | 'id': ref_hop_papers, 282 | 
'citations': self.citations[ref_hop_papers].numpy(), 283 | 'time': self.time[ref_hop_papers].squeeze(dim=-1).numpy()}) 284 | ref_hop_papers = list(temp_df.sort_values(by=['time', 'citations'], 285 | ascending=False).head(self.max_nodes)['id'].tolist()) 286 | cite_hop_papers = self.paper_subgraph.in_edges(paper, etype='cites')[0].numpy().tolist() 287 | if len(cite_hop_papers) > self.max_nodes: 288 | temp_df = pd.DataFrame( 289 | data={ 290 | 'id': cite_hop_papers, 291 | 'citations': self.citations[cite_hop_papers].numpy(), 292 | 'time': self.time[cite_hop_papers].squeeze(dim=-1).numpy()}) 293 | cite_hop_papers = list(temp_df.sort_values(by=['time', 'citations'], 294 | ascending=False).head(self.max_nodes)['id'].tolist()) 295 | 296 | # print(len(ref_hop_papers), len(cite_hop_papers)) 297 | # print(len(cite_hop_papers)) 298 | k_hop_papers = list(set(ref_hop_papers + cite_hop_papers + [paper])) 299 | # print(ref_hop_papers, cite_hop_papers) 300 | # print(k_hop_papers) 301 | # print(self.origin_graph) 302 | temp_graph = dgl.in_subgraph(self.origin_graph, nodes={'paper': k_hop_papers}, relabel_nodes=True) 303 | journals = temp_graph.nodes['journal'].data[dgl.NID] 304 | authors = temp_graph.nodes['author'].data[dgl.NID] 305 | cur_graph = dgl.node_subgraph(self.origin_graph, 306 | nodes={'paper': k_hop_papers, 'journal': journals, 'author': authors}) 307 | 308 | # 修改了indicator 309 | ref_set = set(ref_hop_papers) 310 | cite_set = set(cite_hop_papers) 311 | cur_graph.nodes['paper'].data['is_ref'] = torch.zeros(len(k_hop_papers), dtype=torch.long) 312 | cur_graph.nodes['paper'].data['is_cite'] = torch.zeros(len(k_hop_papers), dtype=torch.long) 313 | cur_graph.nodes['paper'].data['is_target'] = torch.zeros(len(k_hop_papers), dtype=torch.long) 314 | oids = cur_graph.nodes['paper'].data[dgl.NID].numpy().tolist() 315 | id_trans = dict(zip(oids, range(len(oids)))) 316 | ref_idx = [id_trans[paper] for paper in ref_set] 317 | cite_idx = [id_trans[paper] for paper in cite_set] 318 | cur_graph.nodes['paper'].data['is_ref'][ref_idx] = 1 319 | cur_graph.nodes['paper'].data['is_cite'][cite_idx] = 1 320 | cur_graph.nodes['paper'].data['is_target'][id_trans[paper]] = 1 321 | else: 322 | cur_graph = dgl.node_subgraph(self.origin_graph, nodes={'paper': []}) 323 | cur_graph.nodes['paper'].data['is_ref'] = torch.zeros(0, dtype=torch.long) 324 | cur_graph.nodes['paper'].data['is_cite'] = torch.zeros(0, dtype=torch.long) 325 | cur_graph.nodes['paper'].data['is_target'] = torch.zeros(0, dtype=torch.long) 326 | return cur_graph 327 | 328 | 329 | if __name__ == '__main__': 330 | start_time = datetime.datetime.now() 331 | parser = argparse.ArgumentParser(description='Process some description.') 332 | parser.add_argument('--phase', default='test', help='the function name.') 333 | 334 | args = parser.parse_args() 335 | rel_names = ['cites', 'is cited by', 'writes', 'is writen by', 'publishes', 'is published by'] 336 | # vocab_size, embed_dim, num_classes, pad_index 337 | 338 | graphs = {} 339 | graph_name = 'graph_sample_feature_vector' 340 | time_list = range(2002, 2012) 341 | for i in range(len(time_list)): 342 | graphs[i] = torch.load('../data/pubmed/' + graph_name + '_' + str(time_list[i])) 343 | 344 | graph_data = { 345 | 'data': graphs, 346 | 'node_trans': json.load(open('../data/pubmed/' + 'sample_node_trans.json', 'r')) 347 | } 348 | print(graph_data['data']) 349 | print(graph_data['data'][0].nodes('paper')) 350 | # print(graph_data['node_trans']['paper']) 351 | reverse_paper_dict = 
dict(zip(graph_data['node_trans']['paper'].values(), graph_data['node_trans']['paper'].keys())) 352 | # print(reverse_paper_dict) 353 | citation_dict = json.load(open('../data/pubmed/sample_citation_accum.json')) 354 | valid_ids = set([reverse_paper_dict[idx.item()] for idx in graph_data['data'][0].nodes('paper')]) 355 | useful_ids = [idx for idx in valid_ids if idx in citation_dict] 356 | print(len(useful_ids)) 357 | print(useful_ids[:10]) 358 | 359 | # sample_ids = ['3266430', '2790711', '2747541', '3016217', '7111196'] 360 | # trans_ids = [graph_data['node_trans']['paper'][idx] for idx in sample_ids] 361 | # print(trans_ids) 362 | # x = torch.randint(100, (5, 20)).cuda() 363 | # lens = [20] * 5 364 | # print(x) 365 | # model = TSGCN(100, 768, 0, 0, etypes=etypes).cuda() 366 | # output = model(x, lens, None, trans_ids, graph_data['data']) 367 | # print(output) 368 | # print(output.shape) 369 | 370 | sample_ids = [4724, 12968, 8536, 9558, None] 371 | cur_graph = graph_data['data'][9] 372 | print(cur_graph.nodes['paper'].data[dgl.NID][[68, 2855, 377, 845]]) 373 | batch_graphs = get_paper_ego_subgraph(sample_ids, 2, cur_graph, batched=False) 374 | print(batch_graphs) 375 | # print(batch_graphs.batch_size) 376 | # snapshots = [] 377 | # for i in range(len(time_list)): 378 | # snapshots.append(get_paper_ego_subgraph(sample_ids, 2, graph_data['data'][i])) 379 | # print(snapshots) 380 | # graph_encoder = StochasticKLayerRGCN(300, 128, 128, etypes, k=2) 381 | # cur_snapshot = snapshots[0] 382 | # out_emb = graph_encoder(cur_snapshot, cur_snapshot.srcdata['h']) 383 | # cur_snapshot.dstdata['h'] = out_emb 384 | 385 | # sum_readout = dgl.readout_nodes(cur_snapshot, 'h', op='max', ntype='paper') 386 | # print([vector.norm(2) for vector in sum_readout]) 387 | # print(sum_readout.shape) 388 | 389 | # model = TSGCN(0, 300, 1, 0, etypes=etypes, time_encoder='gru').cuda() 390 | # # print(model(None, None, None, None, snapshots)) 391 | # g1 = torch.load('../data/pubmed/graph_sample_feature_vector_2002') 392 | # g2 = torch.load('../data/pubmed/graph_sample_feature_vector_2003') 393 | # g3 = torch.load('../data/pubmed/graph_sample_feature_vector_2004') 394 | # graphs = [g1, g2, g3] 395 | # # sub_g = g1.edge_type_subgraph(['is cited by']) 396 | # # print(sub_g) 397 | # # print(sub_g.ndata[dgl.NID]) 398 | # # print(sub_g.ndata['h'].shape) 399 | # node_trans = json.load(open('../data/pubmed/sample_node_trans.json')) 400 | # graphs_dict = { 401 | # 'data': {'train': graphs}, 402 | # 'node_trans': node_trans 403 | # } 404 | # graphs = model.deal_graphs(graphs_dict, '../data/pubmed/split_data', '../checkpoints/pubmed/TSGCN') 405 | -------------------------------------------------------------------------------- /our_models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECNU-Text-Computing/H2CGL/f3a3af4b8af16e24f20236da62d219c9c39594dd/our_models/__init__.py -------------------------------------------------------------------------------- /our_models/__pycache__/CL.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECNU-Text-Computing/H2CGL/f3a3af4b8af16e24f20236da62d219c9c39594dd/our_models/__pycache__/CL.cpython-39.pyc -------------------------------------------------------------------------------- /our_models/__pycache__/CTSGCN.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ECNU-Text-Computing/H2CGL/f3a3af4b8af16e24f20236da62d219c9c39594dd/our_models/__pycache__/CTSGCN.cpython-39.pyc
--------------------------------------------------------------------------------
/our_models/__pycache__/DCTSGCN.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ECNU-Text-Computing/H2CGL/f3a3af4b8af16e24f20236da62d219c9c39594dd/our_models/__pycache__/DCTSGCN.cpython-39.pyc
--------------------------------------------------------------------------------
/our_models/__pycache__/FCL.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ECNU-Text-Computing/H2CGL/f3a3af4b8af16e24f20236da62d219c9c39594dd/our_models/__pycache__/FCL.cpython-39.pyc
--------------------------------------------------------------------------------
/our_models/__pycache__/HCL.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ECNU-Text-Computing/H2CGL/f3a3af4b8af16e24f20236da62d219c9c39594dd/our_models/__pycache__/HCL.cpython-39.pyc
--------------------------------------------------------------------------------
/our_models/__pycache__/TSGCN.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ECNU-Text-Computing/H2CGL/f3a3af4b8af16e24f20236da62d219c9c39594dd/our_models/__pycache__/TSGCN.cpython-39.pyc
--------------------------------------------------------------------------------
/our_models/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ECNU-Text-Computing/H2CGL/f3a3af4b8af16e24f20236da62d219c9c39594dd/our_models/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/results/readme.md:
--------------------------------------------------------------------------------
1 | The results are output here.
-------------------------------------------------------------------------------- /utilis/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECNU-Text-Computing/H2CGL/f3a3af4b8af16e24f20236da62d219c9c39594dd/utilis/__init__.py -------------------------------------------------------------------------------- /utilis/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECNU-Text-Computing/H2CGL/f3a3af4b8af16e24f20236da62d219c9c39594dd/utilis/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /utilis/__pycache__/log_bar.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECNU-Text-Computing/H2CGL/f3a3af4b8af16e24f20236da62d219c9c39594dd/utilis/__pycache__/log_bar.cpython-39.pyc -------------------------------------------------------------------------------- /utilis/__pycache__/scripts.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECNU-Text-Computing/H2CGL/f3a3af4b8af16e24f20236da62d219c9c39594dd/utilis/__pycache__/scripts.cpython-39.pyc -------------------------------------------------------------------------------- /utilis/dblp.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | from collections import defaultdict, Counter 4 | 5 | import dgl 6 | import pandas as pd 7 | import numpy as np 8 | import json 9 | import re 10 | from pprint import pprint 11 | 12 | 13 | # useful attr: _id, abstract, authors, fos, keywords, n_citations, references, title, venue, year 14 | import torch 15 | 16 | 17 | def get_lines_data(data_path): 18 | with open(data_path + 'dblpv13.json') as fr: 19 | data = fr.read() 20 | data = re.sub(r"NumberInt\((\d+)\)", r"\1", data) 21 | print('replace successfully!') 22 | data = json.loads(data) 23 | 24 | info_fw = open(data_path + 'temp_info.lines', 'w+') 25 | abstract_fw = open(data_path + 'temp_abstract.lines', 'w+') 26 | ref_fw = open(data_path + 'temp_ref.lines', 'w+') 27 | for paper in data: 28 | info_fw.write(str([paper['_id'], 29 | { 30 | 'authors': paper['authors'] if 'authors' in paper else [], 31 | 'fos': paper['fos'] if 'fos' in paper else [], 32 | 'keywords': paper['keywords'] if 'keywords' in paper else [], 33 | 'n_citations': paper['n_citations'] if 'n_citations' in paper else 0, 34 | 'title': paper['title'] if 'title' in paper else None, 35 | 'venue': paper['venue'] if 'venue' in paper else None, 36 | 'year': paper['year'] if 'year' in paper else None, 37 | }]) + '\n') 38 | temp_ab = paper['abstract'] if 'abstract' in paper else None 39 | temp_ref = paper['references'] if 'references' in paper else [] 40 | abstract_fw.write(str([paper['_id'], temp_ab]) + '\n') 41 | ref_fw.write(str([paper['_id'], temp_ref]) + '\n') 42 | 43 | 44 | def get_all_dict(data_path, part='info'): 45 | result_dict = {} 46 | with open(data_path + 'temp_{}.lines'.format(part)) as fr: 47 | for line in fr: 48 | data = eval(line) 49 | result_dict[data[0]] = data[1] 50 | json.dump(result_dict, open(data_path + 'all_{}_dict.json'.format(part), 'w+')) 51 | 52 | 53 | # def get_cite_data(data_path): 54 | # ref_dict = json.load(open(data_path + 'all_ref_dict.json')) 55 | # cite_dict = defaultdict(list) 56 | # for paper in 
ref_dict: 57 | # # ref_list = ref_dict[paper] 58 | # for ref in ref_dict[paper]: 59 | # cite_dict[ref].append(paper) 60 | # json.dump(cite_dict, open(data_path + 'all_cite_dict.json', 'w+')) 61 | # del ref_dict 62 | # 63 | # year_dict = {} 64 | # info_dict = json.load(open(data_path + 'all_info_dict.json', 'r')) 65 | # for paper in cite_dict: 66 | # year_dict[paper] = list(filter(lambda x: x, map(lambda x: info_dict[x]['year'], cite_dict[paper]))) 67 | # json.dump(year_dict, open(data_path + 'cite_year_dict.json', 'w+')) 68 | def get_valid_papers(data_path, time_point=None): 69 | data = json.load(open(data_path + 'all_info_dict.json', 'r')) 70 | print(len(data)) 71 | # journal 72 | data = dict(filter(lambda x: x[1]['venue'].get('_id', False) if x[1]['venue'] else False, data.items())) 73 | print('has journal:', len(data)) 74 | 75 | # authors 76 | data = dict(filter(lambda x: len([author['_id'] for author in x[1]['authors'] if '_id' in author]) > 0, 77 | data.items())) 78 | print('has author:', len(data)) 79 | # title 80 | data = dict(filter(lambda x: x[1]['title'], data.items())) 81 | print('has title:', len(data)) 82 | # json.dump(data, open(data_path + 'sample_info_dict.json', 'w+')) 83 | selected_list = set(data.keys()) 84 | # time 85 | if time_point: 86 | data = dict(filter(lambda x: (int(x[1]['year']) <= time_point), data.items())) 87 | print('<=year:',len(data)) 88 | # json.dump(data, open(data_path + 'sample_info_dict.json', 'w+')) 89 | selected_list = set(data.keys()) 90 | 91 | del data 92 | # words 93 | abs_dict = json.load(open(data_path + 'all_abstract_dict.json', 'r')) 94 | abs_dict = dict(filter(lambda x: (x[0] in selected_list) & (len(str(x[1]).strip().split(' ')) >= 20), 95 | abs_dict.items())) 96 | selected_list = set(abs_dict.keys()) 97 | print('has abs:', len(selected_list)) 98 | # json.dump(abs_dict, open(data_path + 'sample_abstract_dict.json', 'w+')) 99 | return set(selected_list) 100 | 101 | 102 | def get_cite_data(data_path): 103 | # here already filter the data 104 | selected_set = get_valid_papers(data_path) 105 | ref_dict = json.load(open(data_path + 'all_ref_dict.json')) 106 | cite_dict = defaultdict(list) 107 | # for paper in ref_dict: 108 | # # ref_list = ref_dict[paper] 109 | # for ref in ref_dict[paper]: 110 | # cite_dict[ref].append(paper) 111 | for paper in selected_set: 112 | ref_list = ref_dict.get(paper, []) 113 | for ref in ref_list: 114 | if ref in selected_set: 115 | cite_dict[ref].append(paper) 116 | json.dump(cite_dict, open(data_path + 'all_cite_dict.json', 'w+')) 117 | del ref_dict 118 | 119 | year_dict = {} 120 | info_dict = json.load(open(data_path + 'all_info_dict.json', 'r')) 121 | for paper in cite_dict: 122 | year_dict[paper] = list(filter(lambda x: x, map(lambda x: info_dict[x]['year'], cite_dict[paper]))) 123 | json.dump(year_dict, open(data_path + 'cite_year_dict.json', 'w+')) 124 | 125 | 126 | 127 | def show_data(data_path): 128 | info_dict = json.load(open(data_path + 'all_info_dict.json', 'r')) 129 | print('total_count', len(info_dict)) 130 | # 'authors', 'fos', 'keywords', 'n_citations', 'title', 'venue', 'year' 131 | cite = json.load(open(data_path + 'all_cite_dict.json', 'r')) 132 | dealt_dict = map(lambda x: { 133 | 'id': x[0], 134 | 'authors': len([author['_id'] for author in x[1]['authors'] if '_id' in author]), 135 | 'authors_valid': len([author.get('_id', '') for author in x[1]['authors'] 136 | if len(author.get('_id', '').strip()) > 0]), 137 | 'authors_g': len([author['gid'] for author in x[1]['authors'] if 'gid' in author]), 
138 | 'venue': x[1]['venue'].get('_id', None) if x[1]['venue'] else None, 139 | 'venue_raw': x[1]['venue'].get('raw', None) if x[1]['venue'] else None, 140 | 'fos': len(x[1]['fos']), 141 | 'keywords': len(x[1]['keywords']), 142 | 'year': x[1]['year'], 143 | 'citations': len(cite.get(x[0], [])) 144 | }, info_dict.items()) 145 | no_cite_paper = [paper for paper in info_dict if paper not in cite] 146 | print(len(no_cite_paper)) 147 | print(len(info_dict)) 148 | 149 | df = pd.DataFrame(dealt_dict) 150 | print(df.describe()) 151 | df_venue = df.groupby('venue').count().sort_values(by='id', ascending=False)['id'] 152 | df_venue.to_csv(data_path + 'stats_venue_count.csv') 153 | print(df_venue.head(20)) 154 | df_venue_raw = df.groupby(['venue', 'venue_raw']).count().sort_values(by='id', ascending=False)['id'] 155 | df_venue_raw.to_csv(data_path + 'stats_venue_raw_count.csv') 156 | print(df_venue_raw.head(20)) 157 | 158 | df_year = df.groupby('year').count() 159 | df_year.to_csv(data_path + 'stats_year_count.csv') 160 | print(df_year) 161 | del dealt_dict 162 | 163 | abstract_dict = json.load(open(data_path + 'all_abstract_dict.json')) 164 | # valid_paper = [len(abstract.strip().split(' ')) >= 20 for abstract in abstract_dict.values() if abstract] 165 | valid_paper = set(map(lambda x: x[0], 166 | filter(lambda x: len(str(x[1]).strip().split(' ')) >= 20, abstract_dict.items()))) 167 | print(len(valid_paper)) 168 | 169 | df_journal_counts = df.groupby('venue').count().sort_values(by='id', ascending=False)['id'] 170 | df_journal_citations = df.groupby('venue').sum().sort_values(by='citations', ascending=False)['citations'] 171 | print(df_journal_citations.head(20)) 172 | df_journal_avg_citations = (df_journal_citations / df_journal_counts).sort_values(ascending=False) 173 | print(df_journal_avg_citations.head(20)) 174 | df_journal_counts.to_csv(data_path + 'journal_count.csv') 175 | df_journal_citations.to_csv(data_path + 'journal_citation.csv') 176 | df_journal_avg_citations.to_csv(data_path + 'journal_avg_citation.csv') 177 | 178 | df_selected = df[df['id'].isin(valid_paper)] 179 | df_selected.to_csv(data_path + 'selected_data.csv') 180 | df_selected_year = df_selected.groupby('year').count() 181 | df_selected_year.to_csv(data_path + 'stats_selected_year_count.csv') 182 | print(df_selected_year) 183 | 184 | 185 | def get_subset(data_path, time_point=None): 186 | print(time_point) 187 | 188 | selected_set = get_valid_papers(data_path, time_point) 189 | abs_dict = json.load(open(data_path + 'all_abstract_dict.json', 'r')) 190 | abs_dict = dict( 191 | filter(lambda x: (x[0] in selected_set), abs_dict.items())) 192 | 193 | # save subset 194 | json.dump(abs_dict, open(data_path + 'sample_abstract_dict.json', 'w+')) 195 | del abs_dict 196 | data = json.load(open(data_path + 'all_info_dict.json', 'r')) 197 | data = dict(filter(lambda x: x[0] in selected_set, data.items())) 198 | json.dump(data, open(data_path + 'sample_info_dict.json', 'w+')) 199 | del data 200 | 201 | ref = json.load(open(data_path + 'all_ref_dict.json', 'r')) 202 | # cite = json.load(open(data_path + 'all_cite_data.json', 'r')) 203 | cite_year = json.load(open(data_path + 'cite_year_dict.json', 'r')) 204 | 205 | ref = dict(filter(lambda x: x[0] in selected_set, ref.items())) 206 | ref = dict(map(lambda x: (x[0], [paper for paper in x[1] if paper in selected_set]), ref.items())) 207 | json.dump(ref, open(data_path + 'sample_ref_dict.json', 'w+')) 208 | 209 | cite_year = dict( 210 | map(lambda x: (x[0], Counter(x[1])), filter(lambda x: 
x[0] in selected_set, cite_year.items()))) 211 | json.dump(cite_year, open(data_path + 'sample_cite_year_dict.json', 'w+')) 212 | 213 | def get_citation_accum(data_path, time_point, all_times=(2011, 2013, 2015)): 214 | cite_year_path = data_path + 'sample_cite_year_dict.json' 215 | ref_path = data_path + 'sample_ref_dict.json' 216 | info_path = data_path + 'sample_info_dict.json' 217 | cite_year_dict = json.load(open(cite_year_path, 'r')) 218 | ref_dict = json.load(open(ref_path, 'r')) 219 | info_dict = json.load(open(info_path, 'r')) 220 | 221 | predicted_list = list(cite_year_dict.keys()) 222 | accum_num_dict = {} 223 | 224 | # for paper in predicted_list: 225 | # temp_cum_num = [] 226 | # pub_year = int(info_dict[paper]['pub_date']['year']) 227 | # count_dict = cite_year_dict[paper] 228 | # count = 0 229 | # for year in range(pub_year, pub_year+time_window): 230 | # if str(year) in count_dict: 231 | # count += int(count_dict[str(year)]) 232 | # temp_cum_num.append(count) 233 | # accum_num_dict[paper] = temp_cum_num 234 | # 235 | # print(len(accum_num_dict)) 236 | # print(len(list(filter(lambda x: x[-1] > 0, accum_num_dict.values())))) 237 | # print(len(list(filter(lambda x: x[-1] >= 10, accum_num_dict.values())))) 238 | # print(len(list(filter(lambda x: x[-1] >= 100, accum_num_dict.values())))) 239 | # 240 | # if subset: 241 | # accum_num_dict = dict(filter(lambda x: x[1][-1] >= 10, accum_num_dict.items())) 242 | # json.dump(accum_num_dict, open(data_path + 'sample_citation_accum.json', 'w+')) 243 | # else: 244 | # json.dump(accum_num_dict, open(data_path + 'all_citation_accum.json', 'w+')) 245 | print('writing citations accum') 246 | pub_dict = {} 247 | for paper in predicted_list: 248 | count_dict = dict(map(lambda x: (int(x[0]), x[1]), cite_year_dict[paper].items())) 249 | if (sum(count_dict.values()) >= 10) and (len(ref_dict[paper]) >= 5): 250 | # if len(ref_dict[paper]) >= 5: 251 | # pub_year = int(info_dict[paper]['pub_date']['year']) 252 | pub_year = int(info_dict[paper]['year']) 253 | pub_dict[paper] = pub_year 254 | start_year = max(pub_year, time_point - 4) 255 | # print(count_dict) 256 | count = sum(dict(filter(lambda x: x[0] <= start_year, count_dict.items())).values()) 257 | # print(start_count) 258 | # temp_input_accum_num = [None] * (start_year + 4 - time_point) 259 | temp_input_accum_num = [-1] * (start_year + 4 - time_point) 260 | temp_output_accum_num = [] 261 | for year in range(start_year, time_point + 1): 262 | if year in count_dict: 263 | count += int(count_dict[year]) 264 | temp_input_accum_num.append(count) 265 | for year in range(time_point + 1, time_point + 6): 266 | if year in count_dict: 267 | count += int(count_dict[year]) 268 | temp_output_accum_num.append(count) 269 | accum_num_dict[paper] = (temp_input_accum_num, temp_output_accum_num) 270 | # print(temp_input_accum_num) 271 | print(len(accum_num_dict)) 272 | # accum_num_dict = dict(filter(lambda x: x[1][-1] >= 10, accum_num_dict.items())) 273 | print('writing file') 274 | json.dump(accum_num_dict, open(data_path + 'sample_citation_accum.json', 'w+')) 275 | for cur_time in all_times: 276 | cur_dict = {} 277 | cur_papers = map(lambda x: x[0], filter(lambda x: x[1] <= cur_time, pub_dict.items())) 278 | for paper in cur_papers: 279 | cur_dict[paper] = accum_num_dict[paper][1][cur_time + 4 - time_point] - \ 280 | accum_num_dict[paper][0][cur_time + 4 - time_point] 281 | print(cur_time, len(cur_dict)) 282 | df = pd.DataFrame(data=cur_dict.items()) 283 | df.columns = ['paper', 'citations'] 284 | 
df.to_csv(data_path + 'citation_{}.csv'.format(cur_time)) 285 | 286 | 287 | 288 | def get_input_data(data_path, time_point=None, subset=False): 289 | info_path = data_path + 'all_info_dict.json' 290 | ref_path = data_path + 'all_ref_dict.json' 291 | cite_year_path = data_path + 'all_cite_year_dict.json' 292 | # abs_path = 'all_abstract_dict.json' 293 | if subset: 294 | info_path = data_path + 'sample_info_dict.json' 295 | ref_path = data_path + 'sample_ref_dict.json' 296 | cite_year_path = data_path + 'sample_cite_year_dict.json' 297 | # abs_path = 'sample_abstract_dict.json' 298 | 299 | info_dict = json.load(open(info_path, 'r')) 300 | ref_dict = json.load(open(ref_path, 'r')) 301 | cite_year_dict = json.load(open(cite_year_path, 'r')) 302 | 303 | node_trans = {} 304 | paper_trans = {} 305 | # graph created 306 | src_list = [] 307 | dst_list = [] 308 | index = 0 309 | for dst in ref_dict: 310 | if dst in paper_trans: 311 | dst_idx = paper_trans[dst] 312 | else: 313 | dst_idx = index 314 | paper_trans[dst] = dst_idx 315 | index += 1 316 | for src in set(ref_dict[dst]): 317 | if src in paper_trans: 318 | src_idx = paper_trans[src] 319 | else: 320 | src_idx = index 321 | paper_trans[src] = src_idx 322 | index += 1 323 | src_list.append(src_idx) 324 | dst_list.append(dst_idx) 325 | 326 | node_trans['paper'] = paper_trans 327 | 328 | author_trans = {} 329 | journal_trans = {} 330 | author_src, author_dst = [], [] 331 | journal_src, journal_dst = [], [] 332 | author_index = 0 333 | journal_index = 0 334 | for paper in paper_trans: 335 | meta_data = info_dict[paper] 336 | try: 337 | # authors = list(map(lambda x: (x['name'].get('given-names', '') + ' ' + x['name'].get('sur', '')).strip().lower(), 338 | # meta_data['authors'])) 339 | authors = set([author['_id'] for author in meta_data['authors'] if '_id' in author]) 340 | except Exception as e: 341 | print(meta_data['authors']) 342 | authors = None 343 | # print(authors) 344 | # print(meta_data['authors'][0]['name']) 345 | # journal = meta_data['journal']['ids']['nlm-ta'] if 'nlm-ta' in meta_data['journal']['ids'] else None 346 | journal = meta_data['venue']['_id'] if '_id' in meta_data['venue'] else None 347 | 348 | for author in authors: 349 | if author in author_trans: 350 | cur_idx = author_trans[author] 351 | else: 352 | cur_idx = author_index 353 | author_trans[author] = cur_idx 354 | author_index += 1 355 | author_src.append(cur_idx) 356 | author_dst.append(paper_trans[paper]) 357 | 358 | if journal: 359 | if journal in journal_trans: 360 | cur_idx = journal_trans[journal] 361 | else: 362 | cur_idx = journal_index 363 | journal_trans[journal] = cur_idx 364 | journal_index += 1 365 | journal_src.append(cur_idx) 366 | journal_dst.append(paper_trans[paper]) 367 | 368 | node_trans['author'] = author_trans 369 | node_trans['journal'] = journal_trans 370 | 371 | # node_trans_reverse = dict(map(lambda x: (x[0], dict(zip(x[1].values(), x[1].keys()))), node_trans.items())) 372 | node_trans_reverse = {key: dict(zip(node_trans[key].values(), node_trans[key].keys())) for key in node_trans} 373 | # paper_link_time = [int(info_dict[node_trans_reverse['paper'][paper]]['pub_date']['year']) for paper in dst_list] 374 | # author_link_time = [int(info_dict[node_trans_reverse['paper'][paper]]['pub_date']['year']) for paper in author_dst] 375 | # journal_link_time = [int(info_dict[node_trans_reverse['paper'][paper]]['pub_date']['year']) for paper in journal_dst] 376 | paper_link_time = [int(info_dict[node_trans_reverse['paper'][paper]]['year']) for paper in 
dst_list] 377 | author_link_time = [int(info_dict[node_trans_reverse['paper'][paper]]['year']) for paper in author_dst] 378 | journal_link_time = [int(info_dict[node_trans_reverse['paper'][paper]]['year']) for paper in journal_dst] 379 | paper_time = [int(info_dict[paper]['year']) for paper in node_trans['paper'].keys()] 380 | 381 | if subset: 382 | json.dump(node_trans, open(data_path + 'sample_node_trans.json', 'w+')) 383 | else: 384 | json.dump(node_trans, open(data_path + 'all_node_trans.json', 'w+')) 385 | 386 | # graph = dgl.graph((src_list, dst_list), num_nodes=len(paper_trans)) 387 | graph = dgl.heterograph({ 388 | ('paper', 'is cited by', 'paper'): (src_list, dst_list), 389 | ('paper', 'cites', 'paper'): (dst_list, src_list), 390 | ('author', 'writes', 'paper'): (author_src, author_dst), 391 | ('paper', 'is writen by', 'author'): (author_dst, author_src), 392 | ('journal', 'publishes', 'paper'): (journal_src, journal_dst), 393 | ('paper', 'is published by', 'journal'): (journal_dst, journal_src), 394 | }, num_nodes_dict={ 395 | 'paper': len(paper_trans), 396 | 'author': len(author_trans), 397 | 'journal': len(journal_trans) 398 | }) 399 | # graph.ndata['paper_id'] = torch.tensor(list(node_trans.keys())).unsqueeze(dim=0) 400 | graph.nodes['paper'].data['time'] = torch.tensor(paper_time, dtype=torch.int16).unsqueeze(dim=-1) 401 | graph.edges['is cited by'].data['time'] = torch.tensor(paper_link_time, dtype=torch.int16).unsqueeze(dim=-1) 402 | graph.edges['cites'].data['time'] = torch.tensor(paper_link_time, dtype=torch.int16).unsqueeze(dim=-1) 403 | graph.edges['writes'].data['time'] = torch.tensor(author_link_time, dtype=torch.int16).unsqueeze(dim=-1) 404 | graph.edges['is writen by'].data['time'] = torch.tensor(author_link_time, dtype=torch.int16).unsqueeze(dim=-1) 405 | graph.edges['publishes'].data['time'] = torch.tensor(journal_link_time, dtype=torch.int16).unsqueeze(dim=-1) 406 | graph.edges['is published by'].data['time'] = torch.tensor(journal_link_time, dtype=torch.int16).unsqueeze(dim=-1) 407 | 408 | graph = dgl.remove_self_loop(graph, 'is cited by') 409 | graph = dgl.remove_self_loop(graph, 'cites') 410 | print(graph) 411 | print(graph.edges['cites'].data['time']) 412 | if subset: 413 | torch.save(graph, data_path + 'graph_sample') 414 | else: 415 | torch.save(graph, data_path + 'graph') 416 | 417 | del graph, src_list, dst_list 418 | 419 | if __name__ == "__main__": 420 | start_time = datetime.datetime.now() 421 | parser = argparse.ArgumentParser(description='Process some description.') 422 | parser.add_argument('--phase', default='test', help='the function name.') 423 | parser.add_argument('--data_path', default=None, help='the input.') 424 | parser.add_argument('--name', default=None, help='file name.') 425 | parser.add_argument('--out_path', default=None, help='the output.') 426 | # parser.add_argument('--seed', default=123, help='the seed.') 427 | args = parser.parse_args() 428 | # TIME_POINT = 2015 429 | # ALL_TIMES = (2011, 2013, 2015) 430 | TIME_POINT = 2013 431 | ALL_TIMES = (2009, 2011, 2013) 432 | if args.phase == 'test': 433 | print('This is a test process.') 434 | # get_lines_data('./data/') 435 | # get_all_dict('./data/', 'info') 436 | # get_all_dict('./data/', 'abstract') 437 | # get_all_dict('./data/', 'ref') 438 | # get_cite_data('../data/') 439 | elif args.phase == 'lines_data': 440 | get_lines_data(args.data_path) 441 | print('lines data done') 442 | elif args.phase == 'info_data': 443 | get_all_dict(args.data_path, 'info') 444 | print('nfo data done') 
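    # NOTE: a typical end-to-end order for these preprocessing phases, inferred from the
    # files each step reads and writes (not an official run script), is:
    #   lines_data -> info_data / abstract_data / ref_data -> cite_data [-> show_data]
    #   -> subset -> subset_input_data -> citation_accum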
445 | elif args.phase == 'abstract_data': 446 | get_all_dict(args.data_path, 'abstract') 447 | print('abstract data done') 448 | elif args.phase == 'ref_data': 449 | get_all_dict(args.data_path, 'ref') 450 | print('ref data done') 451 | elif args.phase == 'cite_data': 452 | get_cite_data(args.data_path) 453 | print('cite data done') 454 | elif args.phase == 'show_data': 455 | show_data(args.data_path) 456 | print('show data done') 457 | elif args.phase == 'subset': 458 | get_subset(args.data_path, time_point=TIME_POINT) 459 | print('subset data done') 460 | elif args.phase == 'subset_input_data': 461 | get_input_data(args.data_path, subset=True, time_point=TIME_POINT) 462 | print('subset input data done') 463 | elif args.phase == 'citation_accum': 464 | get_citation_accum(args.data_path, time_point=TIME_POINT, all_times=ALL_TIMES) 465 | print('citation accum done') -------------------------------------------------------------------------------- /utilis/log_bar.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | from tqdm import tqdm 4 | import io 5 | 6 | class TqdmToLogger(io.StringIO): 7 | """ 8 | Output stream for TQDM which will output to logger module instead of 9 | the StdOut. 10 | """ 11 | logger = None 12 | level = None 13 | buf = '' 14 | def __init__(self,logger,level=None): 15 | super(TqdmToLogger, self).__init__() 16 | self.logger = logger 17 | self.level = level or logging.INFO 18 | def write(self,buf): 19 | self.buf = buf.strip('\r\n\t ') 20 | def flush(self): 21 | self.logger.log(self.level, self.buf) 22 | 23 | if __name__ == "__main__": 24 | logging.basicConfig(format='%(asctime)s [%(levelname)-8s] %(message)s') 25 | logger = logging.getLogger() 26 | logger.setLevel(logging.DEBUG) 27 | 28 | tqdm_out = TqdmToLogger(logger,level=logging.INFO) 29 | for x in tqdm(range(100),file=tqdm_out,mininterval=1,): 30 | time.sleep(2) 31 | -------------------------------------------------------------------------------- /utilis/scripts.py: -------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | import random 4 | 5 | import dgl 6 | import joblib 7 | import numpy as np 8 | import pandas as pd 9 | import torch 10 | from matplotlib import pyplot as plt 11 | from sklearn.manifold import TSNE 12 | from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score, mean_squared_log_error 13 | from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, log_loss, roc_auc_score, \ 14 | precision_recall_curve, auc 15 | import torch.nn.functional as F 16 | 17 | 18 | def eval_result(all_true_values, all_predicted_values): 19 | # print(mean_absolute_error(all_true_values, all_predicted_values, multioutput='raw_values')) 20 | all_true_values = all_true_values.flatten(order='C') 21 | # print(all_true_values) 22 | all_predicted_values = all_predicted_values.flatten(order='C') 23 | all_predicted_values = np.where(all_predicted_values == np.inf, 1e20, all_predicted_values) 24 | # print(all_predicted_values) 25 | mae = mean_absolute_error(all_true_values, all_predicted_values) 26 | mse = mean_squared_error(all_true_values, all_predicted_values) 27 | rmse = math.sqrt(mse) 28 | r2 = r2_score(all_true_values, all_predicted_values) 29 | 30 | rse = (all_true_values - all_predicted_values) ** 2 / (all_true_values + 1) 31 | Mrse = np.mean(rse, axis=0) 32 | mrse = np.percentile(rse, 50) 33 | 34 | acc = np.mean(np.all(((0.5 * all_true_values) <= 
all_predicted_values, 35 | all_predicted_values <= (1.5 * all_true_values)), axis=0), axis=0) 36 | 37 | # print(all_true_values) 38 | # add min for log 39 | log_true_values = np.log(1 + all_true_values) 40 | log_predicted_values = np.log(1 + np.where(all_predicted_values >= 0, all_predicted_values, 0)) 41 | male = mean_absolute_error(log_true_values, log_predicted_values) 42 | msle = mean_squared_error(log_true_values, log_predicted_values) 43 | rmsle = math.sqrt(msle) 44 | log_r2 = r2_score(log_true_values, log_predicted_values) 45 | smape = np.mean(np.abs(log_true_values - log_predicted_values) / 46 | ((np.abs(log_true_values) + np.abs(log_predicted_values) + 0.1) / 2), axis=0) 47 | mape = np.mean(np.abs(log_true_values - log_predicted_values) / (np.abs(log_true_values) + 0.1), axis=0) 48 | # print('true', all_true_values[:10]) 49 | # print('pred', all_predicted_values[:10]) 50 | return [acc, mae, r2, mse, rmse, Mrse, mrse, male, log_r2, msle, rmsle, smape, mape] 51 | 52 | 53 | def eval_aux_results(all_true_label, all_predicted_result): 54 | all_predicted_label = np.argmax(all_predicted_result, axis=1) 55 | # all 56 | acc = accuracy_score(all_true_label, all_predicted_label) 57 | one_hot = np.eye(all_predicted_result.shape[-1])[all_true_label] 58 | try: 59 | roc_auc = roc_auc_score(one_hot, all_predicted_result) 60 | log_loss_value = log_loss(all_true_label, all_predicted_result) 61 | except Exception as e: 62 | print(e) 63 | roc_auc = 0 64 | log_loss_value = 0 65 | # print(all_predicted_result) 66 | # # pos 67 | # prec = precision_score(all_true_label, all_predicted_label) 68 | # recall = recall_score(all_true_label, all_predicted_label) 69 | # f1 = f1_score(all_true_label, all_predicted_label, average='binary') 70 | # pr_p, pr_r, _ = precision_recall_curve(all_true_label, all_predicted_result[:, 1]) 71 | # pr_auc = auc(pr_r, pr_p) 72 | # # neg 73 | # prec_neg = precision_score(all_true_label, all_predicted_label, pos_label=0) 74 | # recall_neg = recall_score(all_true_label, all_predicted_label, pos_label=0) 75 | # f1_neg = f1_score(all_true_label, all_predicted_label, average='binary', pos_label=0) 76 | # pr_np, pr_nr, _ = precision_recall_curve(all_true_label, all_predicted_result[:, 0], pos_label=0) 77 | # pr_auc_neg = auc(pr_nr, pr_np) 78 | # for all 79 | micro_prec = precision_score(all_true_label, all_predicted_label, average='micro') 80 | micro_recall = recall_score(all_true_label, all_predicted_label, average='micro') 81 | micro_f1 = f1_score(all_true_label, all_predicted_label, average='micro') 82 | macro_prec = precision_score(all_true_label, all_predicted_label, average='macro') 83 | macro_recall = recall_score(all_true_label, all_predicted_label, average='macro') 84 | macro_f1 = f1_score(all_true_label, all_predicted_label, average='macro') 85 | 86 | results = [acc, roc_auc, log_loss_value, micro_prec, micro_recall, micro_f1, macro_prec, macro_recall, macro_f1] 87 | return results 88 | 89 | 90 | def eval_aux_labels(all_true_label, all_predicted_label): 91 | # acc, mi_prec, mi_recall, mi_f1, ma_prec, ma_recall, ma_f1 92 | # all_predicted_label = np.argmax(all_predicted_result, axis=1) 93 | # all 94 | acc = accuracy_score(all_true_label, all_predicted_label) 95 | # one_hot = np.eye(all_predicted_result.shape[-1])[all_true_label] 96 | # try: 97 | # roc_auc = roc_auc_score(one_hot, all_predicted_result) 98 | # log_loss_value = log_loss(all_true_label, all_predicted_result) 99 | # except Exception as e: 100 | # print(e) 101 | # roc_auc = 0 102 | # log_loss_value = 0 103 | # 
print(all_predicted_result) 104 | # # pos 105 | # prec = precision_score(all_true_label, all_predicted_label) 106 | # recall = recall_score(all_true_label, all_predicted_label) 107 | # f1 = f1_score(all_true_label, all_predicted_label, average='binary') 108 | # pr_p, pr_r, _ = precision_recall_curve(all_true_label, all_predicted_result[:, 1]) 109 | # pr_auc = auc(pr_r, pr_p) 110 | # # neg 111 | # prec_neg = precision_score(all_true_label, all_predicted_label, pos_label=0) 112 | # recall_neg = recall_score(all_true_label, all_predicted_label, pos_label=0) 113 | # f1_neg = f1_score(all_true_label, all_predicted_label, average='binary', pos_label=0) 114 | # pr_np, pr_nr, _ = precision_recall_curve(all_true_label, all_predicted_result[:, 0], pos_label=0) 115 | # pr_auc_neg = auc(pr_nr, pr_np) 116 | # for all 117 | micro_prec = precision_score(all_true_label, all_predicted_label, average='micro') 118 | micro_recall = recall_score(all_true_label, all_predicted_label, average='micro') 119 | micro_f1 = f1_score(all_true_label, all_predicted_label, average='micro') 120 | macro_prec = precision_score(all_true_label, all_predicted_label, average='macro') 121 | macro_recall = recall_score(all_true_label, all_predicted_label, average='macro') 122 | macro_f1 = f1_score(all_true_label, all_predicted_label, average='macro') 123 | 124 | results = [acc, micro_prec, micro_recall, micro_f1, macro_prec, macro_recall, macro_f1] 125 | return results 126 | 127 | 128 | def result_format(results): 129 | loss, acc, mae, r2, mse, rmse, Mrse, mrse, male, log_r2, msle, rmsle, smape, mape = results 130 | format_str = '| loss {:8.4f} | acc {:9.4f} |\n' \ 131 | '| MAE {:9.4f} | R2 {:10.4f} |\n' \ 132 | '| MSE {:9.4f} | RMSE {:8.4f} |\n'\ 133 | '| MRSE {:8.4f} | mRSE {:8.4f} |\n'\ 134 | '| MALE {:8.4f} | LR2 {:9.4f} |\n'\ 135 | '| MSLE {:8.4f} | RMSLE {:7.4f} |\n'\ 136 | '| SMAPE {:8.4f} | MAPE {:7.4f} |\n'.format(loss, acc, 137 | mae, r2, 138 | mse, rmse, 139 | Mrse, mrse, 140 | male, log_r2, 141 | msle, rmsle, 142 | smape, mape) + \ 143 | '-' * 59 144 | print(format_str) 145 | return format_str 146 | 147 | def aux_result_format(results, phase): 148 | avg_loss, acc, roc_auc, log_loss_value, micro_prec, micro_recall, micro_f1, macro_prec, macro_recall, macro_f1 = results 149 | 150 | format_str = '| aux loss {:8.3f} | {:9} {:7.3f} |\n' \ 151 | '| roc_auc {:9.3f} | log_loss {:8.3f} |\n' \ 152 | 'micro:\n' \ 153 | '| precision {:7.3f} | recall {:10.3f} |\n' \ 154 | '| f1 {:14.3f} |\n' \ 155 | 'macro:\n' \ 156 | '| precision {:7.3f} | recall {:10.3f} |\n' \ 157 | '| f1 {:14.3f} |\n'.format(avg_loss, phase + ' acc', acc, roc_auc, log_loss_value, 158 | micro_prec, micro_recall, micro_f1, 159 | macro_prec, macro_recall, macro_f1) 160 | print(format_str) 161 | return format_str 162 | 163 | 164 | def get_configs(data_source, model_list, replace_dict=None): 165 | fr = open('./configs/{}.json'.format(data_source)) 166 | configs = json.load(fr) 167 | full_configs = {'default': configs['default']} 168 | if replace_dict: 169 | print('>>>>>>>>>>replacing here<<<<<<<<<<<<<<<') 170 | print(replace_dict) 171 | for key in replace_dict: 172 | full_configs['default'][key] = replace_dict[key] 173 | for model in model_list: 174 | full_configs[model] = configs['default'].copy() 175 | if model in configs.keys(): 176 | for key in configs[model].keys(): 177 | full_configs[model][key] = configs[model][key] 178 | return full_configs 179 | 180 | 181 | def setup_seed(seed): 182 | torch.manual_seed(seed) 183 | torch.cuda.manual_seed(seed) 184 | 


def add_new_elements(graph, nodes={}, ndata={}, edges={}, edata={}):
    '''
    Build a new heterograph containing everything in `graph` plus the given
    extra nodes/edges and their features.

    :param graph: dgl heterograph
    :param nodes: {ntype: []}
    :param ndata: {ntype: {attr: []}}
    :param edges: {(src_type, etype, dst_type): [(src, dst)]}
    :param edata: {etype: {attr: []}}
    :return: new dgl heterograph
    '''
    # etypes = graph.canonical_etypes
    num_nodes_dict = {ntype: graph.num_nodes(ntype) for ntype in graph.ntypes}
    for ntype in nodes:
        num_nodes_dict[ntype] = nodes[ntype]
    # etypes.extend(list(edges.keys()))

    relations = {}
    for etype in graph.canonical_etypes:
        src, dst = graph.edges(etype=etype[1])
        relations[etype] = (src, dst)
    for etype in edges:
        relations[etype] = edges[etype]

    # print(relations)
    # print(num_nodes_dict)
    new_g = dgl.heterograph(relations, num_nodes_dict=num_nodes_dict)

    # copy the existing node/edge features, then attach the new ones
    for ntype in graph.ntypes:
        for k, v in graph.nodes[ntype].data.items():
            new_g.nodes[ntype].data[k] = v.detach().clone()
    for etype in graph.etypes:
        for k, v in graph.edges[etype].data.items():
            new_g.edges[etype].data[k] = v.detach().clone()

    for ntype in ndata:
        for attr in ndata[ntype]:
            new_g.nodes[ntype].data[attr] = ndata[ntype][attr]

    for etype in edata:
        for attr in edata[etype]:
            # try:
            new_g.edges[etype].data[attr] = edata[etype][attr]
            # except Exception as e:
            #     print(e)
            #     print(etype, attr)
            #     print(new_g.edges(etype=etype))
            #     print(edata[etype][attr])
            #     raise Exception
    # print(new_g)
    return new_g


def get_norm_adj(sparse_adj):
    # row-normalised adjacency with self-loops: D^-1 (A + I)
    I = torch.diag(torch.ones(sparse_adj.shape[0])).to_sparse()
    sparse_adj = sparse_adj + I
    D = torch.diag(torch.sparse.sum(sparse_adj, dim=-1).pow(-1).to_dense())
    # print(I)
    # print(sparse_adj.to_dense())
    return torch.matmul(D, sparse_adj.to_dense()).to_sparse()
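
# Illustration only (not part of the original file): for a tiny 2-node graph
# with a single edge 0 -> 1, get_norm_adj returns D^-1 (A + I), i.e. each row of
# the self-loop-augmented adjacency divided by its degree.
#
#   adj = torch.tensor([[0., 1.],
#                       [0., 0.]]).to_sparse()
#   get_norm_adj(adj).to_dense()
#   # expected: tensor([[0.5000, 0.5000],
#   #                   [0.0000, 1.0000]])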


def custom_node_unbatch(g):
    # just unbatch the node attrs for cl, not including etypes for efficiency
    node_split = {ntype: g.batch_num_nodes(ntype) for ntype in g.ntypes}
    node_split = {k: dgl.backend.asnumpy(split).tolist() for k, split in node_split.items()}
    print(node_split)
    node_attr_dict = {ntype: {} for ntype in g.ntypes}
    for ntype in g.ntypes:
        for key, feat in g.nodes[ntype].data.items():
            subfeats = dgl.backend.split(feat, node_split[ntype], 0)
            # print(subfeats)
            node_attr_dict[ntype][key] = subfeats
    return node_attr_dict


def get_label(citation):
    # bucket citation counts into three classes: <10 -> 0, [10, 100) -> 1, >=100 -> 2
    if citation < 10:
        return 0
    elif citation >= 100:
        return 2
    else:
        return 1


def get_sampled_index(data):
    # stratified sample of the test split: sampled_count papers per citation bucket
    print('wip')
    df = pd.DataFrame(data['test'][1])
    df['label'] = df[0].apply(get_label)
    sampled_count = 1000
    print(df[df['label'] == 0])
    cls_0 = df[df['label'] == 0].sample(sampled_count, random_state=123).index
    cls_1 = df[df['label'] == 1].sample(sampled_count, random_state=123).index
    print(cls_0[:5], cls_1[:5])
    cls_2 = df[df['label'] == 2].sample(sampled_count, random_state=123).index
    all_index = list(cls_0) + list(cls_1) + list(cls_2)
    return all_index


def get_selected_data(ids, dataProcessor):
    data = dataProcessor.get_data()
--------------------------------------------------------------------------------