├── History ├── pretrn_gen0.his ├── pretrn_gen1.his └── pretrn_gen2.his ├── LICENSE ├── Models └── readme ├── README.md ├── datasets ├── amazon-book │ ├── fewshot_mat_1.pkl │ ├── fewshot_mat_5.pkl │ ├── trn_mat.pkl.zip │ └── tst_mat.pkl ├── citeseer │ ├── adj_-1.pkl │ ├── adj_1.pkl │ ├── adj_5.pkl │ ├── feats.pkl.zip │ ├── label.pkl │ ├── mask_-1.pkl │ ├── mask_1.pkl │ └── mask_5.pkl ├── collab │ ├── fewshot_mat_1.pkl │ ├── fewshot_mat_5.pkl │ ├── trn_mat.pkl.zip │ ├── tst_mat.pkl │ └── val_mat.pkl ├── cora │ ├── adj_-1.pkl │ ├── adj_1.pkl │ ├── adj_5.pkl │ ├── feats.pkl │ ├── label.pkl │ ├── mask_-1.pkl │ ├── mask_1.pkl │ └── mask_5.pkl ├── ddi │ ├── fewshot_mat_1.pkl │ ├── fewshot_mat_5.pkl │ ├── trn_mat.pkl.zip │ ├── tst_mat.pkl │ └── val_mat.pkl ├── gen0 │ ├── trn_mat.pkl │ ├── tst_mat.pkl │ └── val_mat.pkl ├── gen1 │ ├── trn_mat.pkl │ ├── tst_mat.pkl │ └── val_mat.pkl ├── gen2 │ ├── trn_mat.pkl │ ├── tst_mat.pkl │ └── val_mat.pkl ├── ml10m │ ├── fewshot_mat_1.pkl │ ├── fewshot_mat_5.pkl │ ├── trn_mat.pkl.zip │ └── tst_mat.pkl.zip ├── ml1m │ ├── fewshot_mat_1.pkl │ ├── fewshot_mat_5.pkl │ ├── trn_mat.pkl │ └── tst_mat.pkl └── pubmed │ ├── adj_-1.pkl │ ├── adj_1.pkl │ ├── adj_5.pkl │ ├── feats.pkl.zip │ ├── label.pkl │ ├── mask_-1.pkl │ ├── mask_1.pkl │ └── mask_5.pkl ├── graph_generation ├── Exp_Utils │ ├── Emailer.py │ └── TimeLogger.py ├── Utils.py ├── embedding_generation.py ├── gen_results │ ├── datasets │ │ └── gen_data_ecommerce │ │ │ ├── embedding_dict.pkl │ │ │ ├── interaction_base-0_iter-0.pkl │ │ │ ├── item_list.pkl │ │ │ └── res │ │ │ ├── interaction_fuse_iter-0.pkl │ │ │ ├── iter-0_imap.pkl │ │ │ ├── iter-0_test.pkl │ │ │ ├── iter-0_train.pkl │ │ │ └── iter-0_valid.pkl │ ├── products_e-commerce platform like Amazon.txt │ ├── tem │ │ ├── e-commerce platform like Amazon_depth1_products │ │ ├── e-commerce platform like Amazon_depth2_products, Automotive │ │ ├── e-commerce platform like Amazon_depth2_products, Baby │ │ ├── e-commerce platform like 
Amazon_depth2_products, Beauty │ │ ├── e-commerce platform like Amazon_depth2_products, Books │ │ ├── e-commerce platform like Amazon_depth2_products, Clothing │ │ ├── e-commerce platform like Amazon_depth2_products, Electronics │ │ ├── e-commerce platform like Amazon_depth2_products, Handmade │ │ ├── e-commerce platform like Amazon_depth2_products, Health and Personal Care │ │ ├── e-commerce platform like Amazon_depth2_products, Home Improvement │ │ ├── e-commerce platform like Amazon_depth2_products, Industrial and Scientific │ │ ├── e-commerce platform like Amazon_depth2_products, Jewelry │ │ ├── e-commerce platform like Amazon_depth2_products, Musical Instruments │ │ ├── e-commerce platform like Amazon_depth2_products, Office Products │ │ ├── e-commerce platform like Amazon_depth2_products, Pet Supplies │ │ ├── e-commerce platform like Amazon_depth2_products, Sports and Outdoors │ │ ├── e-commerce platform like Amazon_depth2_products, Tools and Home Improvement │ │ └── e-commerce platform like Amazon_depth2_products, Toys │ └── tree_wInstanceNum_products_e-commerce platform like Amazon.pkl ├── human_item_generation_gibbsSampling_embedEstimation.py ├── instance_number_estimation_hierarchical.py ├── itemCollecting_dfsIterator.py └── make_adjs.py ├── imgs ├── article cover.jpg ├── framework.png ├── graph tokenizer.png ├── impact of datasets.png ├── intro.png ├── opengraph_article_cover_full.png ├── performance.png ├── prompt.png └── sampling.png ├── link_prediction ├── Utils │ └── TimeLogger.py ├── data_handler.py ├── main.py ├── model.py └── params.py └── node_classification ├── Utils └── TimeLogger.py ├── data_handler.py ├── main.py ├── model.py └── params.py /History/pretrn_gen0.his: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/History/pretrn_gen0.his -------------------------------------------------------------------------------- 
/History/pretrn_gen1.his: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/History/pretrn_gen1.his -------------------------------------------------------------------------------- /History/pretrn_gen2.his: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/History/pretrn_gen2.his -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /Models/readme: -------------------------------------------------------------------------------- 1 | Download the pre-trained model via the link: https://drive.google.com/drive/folders/1d-Jn7LHJ2ZSmndKveU4_qS70-Z6j9utB?usp=drive_link -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # OpenGraph: Towards Open Graph Foundation Models 2 | 3 |
4 | 5 | 6 | 7 | 8 | 9 | 10 | Lianghao Xia, Ben Kao, and Chao Huang* (*Correspondence) 11 | 12 | 13 | 14 | 15 | 16 | Presenting OpenGraph, a foundation graph model distilling zero-shot graph generalizability from LLMs. 17 | 18 | 19 | 20 |
21 | 22 | To achieve this goal, OpenGraph addresses several key technical challenges: 23 | - We propose a unified graph tokenizer to adapt our graph model to generalize well on unseen graph data, even when the underlying graph properties differ significantly from those encountered during training. 24 | - We develop a scalable graph transformer as the foundational encoder, which effectively and efficiently captures node-wise dependencies within the global topological context. 25 | - We introduce a data augmentation mechanism enhanced by a large language model (LLM) to alleviate the limitations of data scarcity in real-world scenarios. 26 | 27 | 28 | 29 | Extensive experiments validate the effectiveness of our framework. By adapting OpenGraph to new graph characteristics and comprehending the nuances of diverse graphs, our approach achieves remarkable zero-shot graph learning performance across various settings and domains. 30 | 31 | ## Environment Setup 32 | You need to unzip some of the data files in `datasets/`. Download the pre-trained models using the link in `Models/readme`. Our experiments were conducted with the following package versions: 33 | * python==3.10.13 34 | * torch==1.13.0 35 | * numpy==1.23.4 36 | * scipy==1.9.3 37 | 38 | ## Brief Code Structure 39 | Here is a brief overview of the code structures. The explanations for each directory are enclosed in quotes (##...##). For a more detailed version, please refer to the full version listed at the end of this readme. 
40 | ``` 41 | ./ 42 | │ └── README.md 43 | │ ├── History/ ## Training history of pre-trained models ## 44 | │ ├── Models/ ## Pre-trained models ## 45 | │ ├── datasets/ 46 | │ ├── graph_generation/ ## Code and examples for graph generation ## 47 | │ ├── imgs/ ## Images used in readme ## 48 | │ ├── link_prediction/ ## code for link prediction and pre-training ## 49 | │ │ ├── data_handler.py 50 | │ │ ├── main.py 51 | │ │ ├── model.py 52 | │ │ └── params.py 53 | │ │ ├── Utils/ 54 | │ │ │ └── TimeLogger.py 55 | │ ├── node_classification/ ## code for testing on node classification ## 56 | │ │ ├── data_handler.py 57 | │ │ ├── main.py 58 | │ │ ├── model.py 59 | │ │ └── params.py 60 | │ │ ├── Utils/ 61 | │ │ │ └── TimeLogger.py 62 | ``` 63 | 64 | ## Usage 65 | #### To reproduce the test performance reported in the paper, run the following command lines: 66 | ``` 67 | cd link_prediction/ 68 | python main.py --load pretrn_gen1 --epoch 0 # test on OGBL-Collab, ML-1M, ML-10M 69 | python main.py --load pretrn_gen0 --tstdata amazon-book --epoch 0 # test on Amazon-Book 70 | python main.py --load pretrn_gen2 --tstdata ddi --epoch 0 # test on OGBL-ddi 71 | cd ../node_classification/ 72 | python main.py --load pretrn_gen1 --tstdata cora # test on Cora 73 | python main.py --load pretrn_gen1 --tstdata citeseer # test on Citeseer 74 | python main.py --load pretrn_gen1 --tstdata pubmed # test on Pubmed 75 | ``` 76 | 77 | #### To re-pretrain OpenGraph by yourself, run the following command lines: 78 | ``` 79 | cd ../link_prediction/ 80 | python main.py --save pretrn_gen1 81 | python main.py --trndata gen0 --tstdata amazon-book --save pretrn_gen0 82 | python main.py --trndata gen2 --tstdata ddi --save pretrn_gen2 83 | ``` 84 | 85 | #### To explore pretraining with multiple different pre-training and testing datasets, modify `trn_datasets` and `tst_datasets` in line 241 of `link_prediction/main.py`. 86 | 87 | ## Graph Data Generation 88 | The graph generation code is in `graph_generation/`. 
A toy dataset of small size is given. You need to fill in your OpenAI key in `Utils.py` and `itemCollecting_dfsIterator.py` first. To generate your dataset, modify the `descs` and `hyperparams` dicts, and follow this procedure: 89 | ``` 90 | cd graph_generation/ 91 | python itemCollecting_dfsIterator.py 92 | python instance_number_estimation_hierarchical.py 93 | python embedding_generation.py 94 | python human_item_generation_gibbsSampling_embedEstimation.py 95 | python make_adjs.py 96 | ``` 97 | 98 | Below shows our prompt template, as well as examples for prompt configurations and generated nodes. 99 | 100 | 101 | 102 | ## Evaluation Results 103 | 104 | ### Overall Generalization Performance 105 | OpenGraph achieves best performance under the 0-shot setting, compared to baselines trained/tuned with 1-shot and 5-shot data. 106 | 107 | 108 | ### Pre-training Dataset Study 109 | We studied the influence of using different pre-training datasets. Results below indicate that: 110 | - The generation techniques (Norm, Loc, Topo) have positive effects on performance. 111 | - Real-world datasets (Yelp2018, Gowalla) may yield worse results compared to our generated ones. 112 | - A relevant pre-training dataset (ML-10M for test data ML-1M and ML-10M) results in superior performance. 113 | 114 | 115 | 116 | ### Graph Tokenizer Study 117 | We tuned configurations of our unified graph tokenizer, by adjusting the order of graph smoothing, and replacing our topology-aware projection with alternatives. Our findings include: 118 | - **Adjacency smoothing is important**, as OpenGraph with 0-order smoothing yields inferior performance. 119 | - **Topology-aware projection is superior in performance**. 
Alternatives include *One-hot* which learns a big and unified representation table for all datasets, *Random* which holds no assumption for the node-wise relations and distributes them uniformly, *Degree* which is a widely-used method for non-attributed graphs and seems applicable for cross-graph scenario. 120 | 121 | 122 | 123 | ### Sampling Techniques Study 124 | We ablated the two sampling techniques in the graph transformer, and show their positive effects on both memory and time costs below. Surprisingly, token sequence sampling shows a positive effect over the model performance. 125 | 126 | 127 | 128 | ## Citation 129 | If you find this work useful for your research, please consider citing our paper: 130 | ``` 131 | @inproceedings{xia2024opengraph, 132 | title={OpenGraph: Towards Open Graph Foundation Models}, 133 | author={Xia, Lianghao and Kao, Ben and Huang, Chao}, 134 | booktitle={EMNLP}, 135 | year={2024} 136 | } 137 | ``` 138 | 139 | ## Detailed Code Structures 140 | ``` 141 | ./ 142 | │ └── README.md 143 | │ ├── History/ ## Training history of pre-trained models ## 144 | │ │ ├── pretrn_gen0.his 145 | │ │ ├── pretrn_gen2.his 146 | │ │ └── pretrn_gen1.his 147 | │ ├── Models/ ## Pre-trained models ## 148 | │ │ └── readme ## Download pre-trained models using the link inside ## 149 | │ ├── datasets/ 150 | │ │ ├── amazon-book/ 151 | │ │ │ ├── fewshot_mat_1.pkl 152 | │ │ │ ├── trn_mat.pkl.zip ## Unzip it manually ## 153 | │ │ │ ├── tst_mat.pkl 154 | │ │ │ └── fewshot_mat_5.pkl 155 | │ │ ├── citeseer/ 156 | │ │ │ ├── adj_-1.pkl 157 | │ │ │ ├── adj_1.pkl 158 | │ │ │ ├── adj_5.pkl 159 | │ │ │ ├── feats.pkl.zip ## Unzip it manually ## 160 | │ │ │ ├── label.pkl 161 | │ │ │ ├── mask_-1.pkl 162 | │ │ │ ├── mask_1.pkl 163 | │ │ │ └── mask_5.pkl 164 | │ │ ├── collab/ 165 | │ │ │ ├── fewshot_mat_5.pkl 166 | │ │ │ ├── trn_mat.pkl.zip ## Unzip it manually ## 167 | │ │ │ ├── tst_mat.pkl 168 | │ │ │ ├── val_mat.pkl 169 | │ │ │ └── fewshot_mat_1.pkl 170 | │ │ ├── cora/ 171 | 
│ │ │ ├── adj_-1.pkl 172 | │ │ │ ├── adj_1.pkl 173 | │ │ │ ├── adj_5.pkl 174 | │ │ │ ├── feats.pkl 175 | │ │ │ ├── label.pkl 176 | │ │ │ ├── mask_-1.pkl 177 | │ │ │ ├── mask_1.pkl 178 | │ │ │ └── mask_5.pkl 179 | │ │ ├── ddi/ 180 | │ │ │ ├── fewshot_mat_1.pkl 181 | │ │ │ ├── trn_mat.pkl.zip ## Unzip it manually ## 182 | │ │ │ ├── tst_mat.pkl 183 | │ │ │ ├── val_mat.pkl 184 | │ │ │ └── fewshot_mat_5.pkl 185 | │ │ ├── gen0/ 186 | │ │ │ ├── trn_mat.pkl 187 | │ │ │ ├── val_mat.pkl 188 | │ │ │ └── tst_mat.pkl 189 | │ │ ├── gen1/ 190 | │ │ │ ├── trn_mat.pkl 191 | │ │ │ ├── tst_mat.pkl 192 | │ │ │ └── val_mat.pkl 193 | │ │ ├── gen2/ 194 | │ │ │ ├── trn_mat.pkl 195 | │ │ │ ├── val_mat.pkl 196 | │ │ │ └── tst_mat.pkl 197 | │ │ ├── ml10m/ 198 | │ │ │ ├── fewshot_mat_1.pkl 199 | │ │ │ ├── trn_mat.pkl.zip ## Unzip it manually ## 200 | │ │ │ ├── tst_mat.pkl.zip ## Unzip it manually ## 201 | │ │ │ └── fewshot_mat_5.pkl 202 | │ │ ├── ml1m/ 203 | │ │ │ ├── fewshot_mat_5.pkl 204 | │ │ │ ├── trn_mat.pkl 205 | │ │ │ ├── tst_mat.pkl 206 | │ │ │ └── fewshot_mat_1.pkl 207 | │ │ ├── pubmed/ 208 | │ │ │ ├── adj_-1.pkl 209 | │ │ │ ├── adj_1.pkl 210 | │ │ │ ├── feats.pkl.zip ## Unzip it manually ## 211 | │ │ │ ├── label.pkl 212 | │ │ │ ├── mask_-1.pkl 213 | │ │ │ ├── mask_1.pkl 214 | │ │ │ ├── mask_5.pkl 215 | │ │ │ └── adj_5.pkl 216 | │ ├── graph_generation/ ## Code and examples for graph generation ## 217 | │ │ ├── embedding_generation.py ## Node embedding generation ## 218 | │ │ ├── human_item_generation_gibbsSampling_embedEstimation.py ## Edge generation ## 219 | │ │ ├── instance_number_estimation_hierarchical.py ## Estimate amount for each node. Not mentioned in the paper. 
## 220 | │ │ ├── itemCollecting_dfsIterator.py ## Node generation ## 221 | │ │ ├── make_adjs.py ## Making datasets for generated graphs ## 222 | │ │ └── Utils.py 223 | │ │ ├── Exp_Utils/ 224 | │ │ │ ├── Emailer.py ## A tool to send warning email for experiments ## 225 | │ │ │ └── TimeLogger.py 226 | │ │ ├── gen_results/ 227 | │ │ │ ├── tree_wInstanceNum_products_e-commerce platform like Amazon.pkl ## Tree data structure ## 228 | │ │ │ └── products_e-commerce platform like Amazon.txt ## Node list ## 229 | │ │ │ ├── datasets/ 230 | │ │ │ │ ├── gen_data_ecommerce/ 231 | │ │ │ │ │ ├── embedding_dict.pkl 232 | │ │ │ │ │ ├── item_list.pkl 233 | │ │ │ │ │ └── interaction_base-0_iter-0.pkl ## generated edges ## 234 | │ │ │ │ │ ├── res/ 235 | │ │ │ │ │ │ ├── iter-0_imap.pkl ## Id map for nodes ## 236 | │ │ │ │ │ │ ├── iter-0_test.pkl 237 | │ │ │ │ │ │ ├── iter-0_train.pkl 238 | │ │ │ │ │ │ ├── iter-0_valid.pkl 239 | │ │ │ │ │ │ └── interaction_fuse_iter-0.pkl 240 | │ │ │ ├── tem/ ## Temporary files for node generation ## 241 | │ │ │ │ ├── e-commerce platform like Amazon_depth1_products 242 | │ │ │ │ ├── e-commerce platform like Amazon_depth2_products, Automotive 243 | │ │ │ │ ├── e-commerce platform like Amazon_depth2_products, Baby 244 | │ │ │ │ ├── e-commerce platform like Amazon_depth2_products, Beauty 245 | │ │ │ │ ├── e-commerce platform like Amazon_depth2_products, Books 246 | │ │ │ │ ├── e-commerce platform like Amazon_depth2_products, Clothing 247 | │ │ │ │ ├── e-commerce platform like Amazon_depth2_products, Electronics 248 | │ │ │ │ ├── e-commerce platform like Amazon_depth2_products, Handmade 249 | │ │ │ │ ├── e-commerce platform like Amazon_depth2_products, Health and Personal Care 250 | │ │ │ │ ├── e-commerce platform like Amazon_depth2_products, Home Improvement 251 | │ │ │ │ ├── e-commerce platform like Amazon_depth2_products, Industrial and Scientific 252 | │ │ │ │ ├── e-commerce platform like Amazon_depth2_products, Jewelry 253 | │ │ │ │ ├── e-commerce 
platform like Amazon_depth2_products, Musical Instruments 254 | │ │ │ │ ├── e-commerce platform like Amazon_depth2_products, Office Products 255 | │ │ │ │ ├── e-commerce platform like Amazon_depth2_products, Pet Supplies 256 | │ │ │ │ ├── e-commerce platform like Amazon_depth2_products, Tools and Home Improvement 257 | │ │ │ │ ├── e-commerce platform like Amazon_depth2_products, Toys 258 | │ │ │ │ └── e-commerce platform like Amazon_depth2_products, Sports and Outdoors 259 | │ ├── imgs/ ## Images used in readme ## 260 | │ │ ├── framework.png 261 | │ │ ├── intro.png 262 | │ │ ├── performance.png 263 | │ │ └── article cover.jpg 264 | │ ├── link_prediction/ ## code for link prediction and pre-training ## 265 | │ │ ├── data_handler.py 266 | │ │ ├── main.py 267 | │ │ ├── model.py 268 | │ │ └── params.py 269 | │ │ ├── Utils/ 270 | │ │ │ └── TimeLogger.py 271 | │ ├── node_classification/ ## code for testing on node classification ## 272 | │ │ ├── data_handler.py 273 | │ │ ├── main.py 274 | │ │ ├── model.py 275 | │ │ └── params.py 276 | │ │ ├── Utils/ 277 | │ │ │ └── TimeLogger.py 278 | ``` 279 | -------------------------------------------------------------------------------- /datasets/amazon-book/fewshot_mat_1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/amazon-book/fewshot_mat_1.pkl -------------------------------------------------------------------------------- /datasets/amazon-book/fewshot_mat_5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/amazon-book/fewshot_mat_5.pkl -------------------------------------------------------------------------------- /datasets/amazon-book/trn_mat.pkl.zip: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/amazon-book/trn_mat.pkl.zip -------------------------------------------------------------------------------- /datasets/amazon-book/tst_mat.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/amazon-book/tst_mat.pkl -------------------------------------------------------------------------------- /datasets/citeseer/adj_-1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/citeseer/adj_-1.pkl -------------------------------------------------------------------------------- /datasets/citeseer/adj_1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/citeseer/adj_1.pkl -------------------------------------------------------------------------------- /datasets/citeseer/adj_5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/citeseer/adj_5.pkl -------------------------------------------------------------------------------- /datasets/citeseer/feats.pkl.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/citeseer/feats.pkl.zip -------------------------------------------------------------------------------- /datasets/citeseer/label.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/citeseer/label.pkl -------------------------------------------------------------------------------- /datasets/citeseer/mask_-1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/citeseer/mask_-1.pkl -------------------------------------------------------------------------------- /datasets/citeseer/mask_1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/citeseer/mask_1.pkl -------------------------------------------------------------------------------- /datasets/citeseer/mask_5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/citeseer/mask_5.pkl -------------------------------------------------------------------------------- /datasets/collab/fewshot_mat_1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/collab/fewshot_mat_1.pkl -------------------------------------------------------------------------------- /datasets/collab/fewshot_mat_5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/collab/fewshot_mat_5.pkl -------------------------------------------------------------------------------- /datasets/collab/trn_mat.pkl.zip: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/collab/trn_mat.pkl.zip -------------------------------------------------------------------------------- /datasets/collab/tst_mat.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/collab/tst_mat.pkl -------------------------------------------------------------------------------- /datasets/collab/val_mat.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/collab/val_mat.pkl -------------------------------------------------------------------------------- /datasets/cora/adj_-1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/cora/adj_-1.pkl -------------------------------------------------------------------------------- /datasets/cora/adj_1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/cora/adj_1.pkl -------------------------------------------------------------------------------- /datasets/cora/adj_5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/cora/adj_5.pkl -------------------------------------------------------------------------------- /datasets/cora/feats.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/cora/feats.pkl 
-------------------------------------------------------------------------------- /datasets/cora/label.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/cora/label.pkl -------------------------------------------------------------------------------- /datasets/cora/mask_-1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/cora/mask_-1.pkl -------------------------------------------------------------------------------- /datasets/cora/mask_1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/cora/mask_1.pkl -------------------------------------------------------------------------------- /datasets/cora/mask_5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/cora/mask_5.pkl -------------------------------------------------------------------------------- /datasets/ddi/fewshot_mat_1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/ddi/fewshot_mat_1.pkl -------------------------------------------------------------------------------- /datasets/ddi/fewshot_mat_5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/ddi/fewshot_mat_5.pkl -------------------------------------------------------------------------------- /datasets/ddi/trn_mat.pkl.zip: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/ddi/trn_mat.pkl.zip -------------------------------------------------------------------------------- /datasets/ddi/tst_mat.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/ddi/tst_mat.pkl -------------------------------------------------------------------------------- /datasets/ddi/val_mat.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/ddi/val_mat.pkl -------------------------------------------------------------------------------- /datasets/gen0/trn_mat.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/gen0/trn_mat.pkl -------------------------------------------------------------------------------- /datasets/gen0/tst_mat.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/gen0/tst_mat.pkl -------------------------------------------------------------------------------- /datasets/gen0/val_mat.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/gen0/val_mat.pkl -------------------------------------------------------------------------------- /datasets/gen1/trn_mat.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/gen1/trn_mat.pkl -------------------------------------------------------------------------------- /datasets/gen1/tst_mat.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/gen1/tst_mat.pkl -------------------------------------------------------------------------------- /datasets/gen1/val_mat.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/gen1/val_mat.pkl -------------------------------------------------------------------------------- /datasets/gen2/trn_mat.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/gen2/trn_mat.pkl -------------------------------------------------------------------------------- /datasets/gen2/tst_mat.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/gen2/tst_mat.pkl -------------------------------------------------------------------------------- /datasets/gen2/val_mat.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/gen2/val_mat.pkl -------------------------------------------------------------------------------- /datasets/ml10m/fewshot_mat_1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/ml10m/fewshot_mat_1.pkl 
-------------------------------------------------------------------------------- /datasets/ml10m/fewshot_mat_5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/ml10m/fewshot_mat_5.pkl -------------------------------------------------------------------------------- /datasets/ml10m/trn_mat.pkl.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/ml10m/trn_mat.pkl.zip -------------------------------------------------------------------------------- /datasets/ml10m/tst_mat.pkl.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/ml10m/tst_mat.pkl.zip -------------------------------------------------------------------------------- /datasets/ml1m/fewshot_mat_1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/ml1m/fewshot_mat_1.pkl -------------------------------------------------------------------------------- /datasets/ml1m/fewshot_mat_5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/ml1m/fewshot_mat_5.pkl -------------------------------------------------------------------------------- /datasets/ml1m/trn_mat.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/ml1m/trn_mat.pkl -------------------------------------------------------------------------------- 
/datasets/ml1m/tst_mat.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/ml1m/tst_mat.pkl -------------------------------------------------------------------------------- /datasets/pubmed/adj_-1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/pubmed/adj_-1.pkl -------------------------------------------------------------------------------- /datasets/pubmed/adj_1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/pubmed/adj_1.pkl -------------------------------------------------------------------------------- /datasets/pubmed/adj_5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/pubmed/adj_5.pkl -------------------------------------------------------------------------------- /datasets/pubmed/feats.pkl.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/pubmed/feats.pkl.zip -------------------------------------------------------------------------------- /datasets/pubmed/label.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/pubmed/label.pkl -------------------------------------------------------------------------------- /datasets/pubmed/mask_-1.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/pubmed/mask_-1.pkl -------------------------------------------------------------------------------- /datasets/pubmed/mask_1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/pubmed/mask_1.pkl -------------------------------------------------------------------------------- /datasets/pubmed/mask_5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/pubmed/mask_5.pkl -------------------------------------------------------------------------------- /graph_generation/Exp_Utils/Emailer.py: -------------------------------------------------------------------------------- 1 | import smtplib 2 | from email.mime.text import MIMEText 3 | from email.header import Header 4 | 5 | def SendMail(message, subject=None): 6 | return 7 | default_subject = 'Experimental Anomaly Report' 8 | if subject != None: 9 | default_subject += ':' + subject 10 | subject = default_subject 11 | 12 | sender = 'xxx@xx.com' 13 | receivers = ['xxx@xx.com'] 14 | message = 'Dear Artificial Anomaly Investigator,\n\n' + 'I am writing to bring to your attention that an anomaly occurs and cannot be solved by your exception handler in the recent experiment. 
Please refer to the following message for details.\n\n' + message + '\n\nBest regards,\nIntelligent Experiment Assistant' 15 | 16 | message = MIMEText(message, 'plain', 'utf-8') 17 | message['Subject'] = Header(subject, 'utf-8') 18 | message['From'] = Header('intel_assistant') 19 | message['To'] = Header('investigator') 20 | 21 | mail_host = 'smtp.xxx.com' 22 | mail_user = 'xxx@xxx.com' 23 | mail_pass = 'xxx' 24 | 25 | smtpObj = smtplib.SMTP() 26 | smtpObj.connect(mail_host, 25) 27 | smtpObj.login(mail_user, mail_pass) 28 | smtpObj.sendmail(sender, receivers, message.as_string()) -------------------------------------------------------------------------------- /graph_generation/Exp_Utils/TimeLogger.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | logmsg = '' 4 | timemark = dict() 5 | saveDefault = False 6 | def log(msg, save=None, oneline=False): 7 | global logmsg 8 | global saveDefault 9 | time = datetime.datetime.now() 10 | tem = '%s: %s' % (time, msg) 11 | if save != None: 12 | if save: 13 | logmsg += tem + '\n' 14 | elif saveDefault: 15 | logmsg += tem + '\n' 16 | if oneline: 17 | print(tem, end='\r') 18 | else: 19 | print(tem) 20 | 21 | def marktime(marker): 22 | global timemark 23 | timemark[marker] = datetime.datetime.now() 24 | 25 | 26 | if __name__ == '__main__': 27 | log('') -------------------------------------------------------------------------------- /graph_generation/Utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import openai 3 | import json 4 | import tiktoken 5 | import numpy as np 6 | import Exp_Utils.TimeLogger as logger 7 | from Exp_Utils.TimeLogger import log 8 | from Exp_Utils.Emailer import SendMail 9 | import time 10 | 11 | openai.api_key = "xx-xxxxxx" 12 | 13 | class DataGenAgent: 14 | def __init__(self): 15 | super(DataGenAgent, self).__init__() 16 | self.token_num = 0 17 | self.encoding = 
tiktoken.encoding_for_model('gpt-3.5-turbo') 18 | 19 | def openai_embedding(self, message): 20 | try: 21 | embedding = openai.Embedding.create( 22 | model='text-embedding-ada-002', 23 | input = message 24 | )['data'][0]['embedding'] 25 | # time.sleep() 26 | return np.array(embedding) 27 | except Exception as e: 28 | print('OpenAI request error: {err_msg}. Retry in 10 seconds.'.format(err_msg=e)) 29 | time.sleep(10) 30 | return self.openai_embedding(message) 31 | 32 | def openai(self, message): 33 | try: 34 | completion = openai.ChatCompletion.create( 35 | model='gpt-3.5-turbo-1106', 36 | # model='gpt-4', 37 | messages=[ 38 | {"role": "user", "content": message}, 39 | ] 40 | ) 41 | response = completion.choices[0].message.content 42 | time.sleep(1) 43 | self.token_num += len(self.encoding.encode(json.dumps(message))) 44 | return response 45 | except Exception as e: 46 | print('OpenAI request error: {err_msg}. Retry in 10 seconds.'.format(err_msg=e)) 47 | time.sleep(10) 48 | return self.openai(message) 49 | 50 | def handling_llm_exceptions(self, message, interpret_func, interpret_args, failure_tolerance): 51 | try: 52 | answers_text = self.openai(message) 53 | print('Answers text:') 54 | print(answers_text) 55 | print('----------\n') 56 | return 0, interpret_func(answers_text, *interpret_args) 57 | except Exception as e: 58 | self.failure += 1 59 | log('\n**********\nERROR\n') 60 | log('Exception occurs when interpreting. Exception message: {exception}'.format(exception=e), save=True) 61 | log('Failure times: {failure}'.format(failure=self.failure, save=True)) 62 | log('Prompt text:\n{prompt}'.format(prompt=message), save=True) 63 | log('Response text:\n{response}'.format(response=answers_text), save=True) 64 | if self.failure < failure_tolerance: 65 | log('Retry in 10 seconds.', save=True) 66 | time.sleep(10) 67 | log('\n**********\n') 68 | return 1, None 69 | else: 70 | log('Reaching maximum failure tolerance. 
CANNOT HANDLE!'.format(failure=self.failure), save=True) 71 | log('Sending report email.', save=True) 72 | SendMail(logger.logmsg) 73 | logger.logmsg = '' 74 | log('\n**********\n') 75 | return 2, None 76 | 77 | class EntityTreeNode: 78 | def __init__(self, entity_name, depth, parent=None): 79 | self.entity_name = entity_name 80 | self.frequency = [] 81 | self.quantity = -1 82 | self.children = dict() 83 | self.parent = parent 84 | self.depth = depth 85 | 86 | def is_child(self, entity_name): 87 | return entity_name in self.children 88 | 89 | def to_child(self, entity_name): 90 | return self.children[entity_name] 91 | 92 | def add_child(self, entity_name): 93 | child = EntityTreeNode(entity_name, self.depth+1, self) 94 | self.children[entity_name] = child 95 | 96 | def iterate_children(self): 97 | for key, node in self.children.items(): 98 | yield key, node 99 | 100 | def allocate_number(self, quantity): 101 | print('Allocating depth {depth} {entity_name}, quantity: {quantity}'.format(depth=self.depth, entity_name=self.entity_name, quantity=quantity)) 102 | self.quantity = quantity 103 | if len(self.children) == 0: 104 | return 105 | child_list = list(self.children.values()) 106 | child_freq = list(map(lambda x: x.frequency, child_list)) 107 | child_freq = np.array(child_freq) # N * T 108 | if child_freq.shape[1] == 0: 109 | raise Exception('No estimated frequency for children.') 110 | summ = np.sum(child_freq, axis=0, keepdims=True) # 1 * T 111 | child_freq = child_freq / summ # N * T 112 | child_num = np.mean(child_freq, axis=1) * self.quantity # N 113 | for i, child in enumerate(child_list): 114 | child.allocate_number(child_num[i]) 115 | 116 | def get_list_of_leaves(self, entity_name, with_branches=False): 117 | if len(self.children) == 0: 118 | num = max(1, int(self.quantity)) 119 | entity_list = list() 120 | cur_entity_name = entity_name + ', ' + self.entity_name 121 | for i in range(num): 122 | # entity_list.append(cur_entity_name + ' #{idx}'.format(idx=i)) 
123 | entity_list.append(self.entity_name + ' #{idx}'.format(idx=i)) 124 | return entity_list 125 | entity_list = list() 126 | if with_branches: 127 | if self.depth <= 2: 128 | tem_entity_name = self.entity_name 129 | else: 130 | tem_entity_name = entity_name + ', ' + self.entity_name 131 | entity_list.append(tem_entity_name) 132 | for _, child in self.iterate_children(): 133 | nxt_entity_name = self.entity_name if self.depth <= 2 else (entity_name + ', ' + self.entity_name) 134 | entities = child.get_list_of_leaves(nxt_entity_name, with_branches) 135 | entity_list = entity_list + entities 136 | return entity_list 137 | 138 | class EntityTreeConstructer: 139 | def __init__(self, entity_lines): 140 | super(EntityTreeConstructer, self).__init__() 141 | 142 | root_name = self.line_process(entity_lines[0])[0] 143 | self.root = EntityTreeNode(root_name, depth=1) 144 | self.root.frequency.append(1.0) 145 | self.construct_tree(entity_lines) 146 | 147 | def add_node(self, cur_node, descriptions, cur): 148 | parent_entity_name = descriptions[cur-1] 149 | if cur_node.entity_name != parent_entity_name: 150 | print(cur_node.entity_name, parent_entity_name) 151 | print(descriptions) 152 | assert cur_node.entity_name == parent_entity_name 153 | cur_entity_name = descriptions[cur] 154 | if not cur_node.is_child(cur_entity_name): 155 | cur_node.add_child(cur_entity_name) 156 | if cur + 1 < len(descriptions): 157 | self.add_node(cur_node.to_child(cur_entity_name), descriptions, cur+1) 158 | 159 | def line_process(self, entity_line, check=False): 160 | entity_line = entity_line.strip() 161 | descriptions = list(map(lambda x: x.strip(), entity_line.split(','))) 162 | if not check: 163 | return descriptions 164 | if descriptions[0] != self.root.entity_name: 165 | raise Exception('Cannot find root') 166 | if len(descriptions) <= 1: 167 | raise Exception('Fail to split') 168 | return descriptions 169 | 170 | def construct_tree(self, entity_lines): 171 | for entity_line in entity_lines: 
172 | try: 173 | descriptions = self.line_process(entity_line, check=True) 174 | except Exception as e: 175 | print(str(e), ':', entity_line) 176 | continue 177 | self.add_node(self.root, descriptions, cur=1) -------------------------------------------------------------------------------- /graph_generation/embedding_generation.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | from Utils import DataGenAgent 4 | from Exp_Utils.TimeLogger import log 5 | 6 | def load_item_list(item_file, entity_file, item_num): 7 | if not os.path.exists(item_file): 8 | with open(entity_file, 'rb') as fs: 9 | entity_tree_root = pickle.load(fs) 10 | entity_tree_root.allocate_number(item_num) 11 | item_list = entity_tree_root.get_list_of_leaves('') 12 | with open(item_file, 'wb+') as fs: 13 | pickle.dump(item_list, fs) 14 | else: 15 | with open(item_file, 'rb') as fs: 16 | item_list = pickle.load(fs) 17 | return item_list 18 | 19 | descs = { 20 | 'data_name': 'gen_data_venues', 21 | 'scenario_desc': 'venue rating platform like yelp', 22 | 'human_role': 'user', 23 | 'interaction_verb': 'interact', 24 | 'initial_entity': 'business venues', 25 | } 26 | descs = { 27 | 'data_name': 'gen_data_books', 28 | 'scenario_desc': 'book rating platform', 29 | 'human_role': 'user', 30 | 'interaction_verb': 'interact', 31 | 'initial_entity': 'books', 32 | } 33 | descs = { 34 | 'data_name': 'gen_data_ai_papers', 35 | 'scenario_desc': 'published paper list of top AI conferences', 36 | 'human_role': 'user', 37 | 'interaction_verb': 'interact', 38 | 'initial_entity': 'deep learning papers', 39 | } 40 | descs = { 41 | 'data_name': 'gen_data_ecommerce', 42 | 'scenario_desc': 'e-commerce platform like Amazon', 43 | 'human_role': 'user', 44 | 'interaction_verb': 'interact', 45 | 'initial_entity': 'products', 46 | } 47 | file_root = 'gen_results/datasets/{data_name}/'.format(data_name=descs['data_name']) 48 | entity_file = 
'gen_results/tree_wInstanceNum_{initial_entity}_{scenario}.pkl'.format(initial_entity=descs['initial_entity'], scenario=descs['scenario_desc']) 49 | embed_file = file_root + 'embedding_dict.pkl' 50 | 51 | with open(entity_file, 'rb') as fs: 52 | entity_tree_root = pickle.load(fs) 53 | entity_tree_root.allocate_number(1) 54 | item_list = entity_tree_root.get_list_of_leaves('', with_branches=True) 55 | item_list = list(map(lambda item_name: item_name if ' #' not in item_name else item_name[:item_name.index(' #')], item_list)) 56 | print(item_list) 57 | print('Num of items', len(item_list)) 58 | agent = DataGenAgent() 59 | embedding_dict = dict() 60 | for i, item in enumerate(item_list): 61 | log('{idx} / {tot}'.format(idx=i, tot=len(item_list))) 62 | embedding = agent.openai_embedding(item) 63 | embedding_dict[item] = embedding 64 | with open(embed_file, 'wb') as fs: 65 | pickle.dump(embedding_dict, fs) -------------------------------------------------------------------------------- /graph_generation/gen_results/datasets/gen_data_ecommerce/embedding_dict.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/graph_generation/gen_results/datasets/gen_data_ecommerce/embedding_dict.pkl -------------------------------------------------------------------------------- /graph_generation/gen_results/datasets/gen_data_ecommerce/interaction_base-0_iter-0.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/graph_generation/gen_results/datasets/gen_data_ecommerce/interaction_base-0_iter-0.pkl -------------------------------------------------------------------------------- /graph_generation/gen_results/datasets/gen_data_ecommerce/item_list.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/graph_generation/gen_results/datasets/gen_data_ecommerce/item_list.pkl -------------------------------------------------------------------------------- /graph_generation/gen_results/datasets/gen_data_ecommerce/res/interaction_fuse_iter-0.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/graph_generation/gen_results/datasets/gen_data_ecommerce/res/interaction_fuse_iter-0.pkl -------------------------------------------------------------------------------- /graph_generation/gen_results/datasets/gen_data_ecommerce/res/iter-0_imap.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/graph_generation/gen_results/datasets/gen_data_ecommerce/res/iter-0_imap.pkl -------------------------------------------------------------------------------- /graph_generation/gen_results/datasets/gen_data_ecommerce/res/iter-0_test.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/graph_generation/gen_results/datasets/gen_data_ecommerce/res/iter-0_test.pkl -------------------------------------------------------------------------------- /graph_generation/gen_results/datasets/gen_data_ecommerce/res/iter-0_train.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/graph_generation/gen_results/datasets/gen_data_ecommerce/res/iter-0_train.pkl -------------------------------------------------------------------------------- /graph_generation/gen_results/datasets/gen_data_ecommerce/res/iter-0_valid.pkl: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/graph_generation/gen_results/datasets/gen_data_ecommerce/res/iter-0_valid.pkl -------------------------------------------------------------------------------- /graph_generation/gen_results/products_e-commerce platform like Amazon.txt: -------------------------------------------------------------------------------- 1 | products, Electronics, Headphones 2 | products, Electronics, Speakers 3 | products, Electronics, Cameras 4 | products, Electronics, Smartphones 5 | products, Electronics, Smartwatches 6 | products, Electronics, Tablets 7 | products, Electronics, Laptops 8 | products, Electronics, Desktop Computers 9 | products, Electronics, Home Security Systems 10 | products, Electronics, Drones 11 | products, Electronics, Gaming Consoles 12 | products, Electronics, Virtual Reality Headsets 13 | products, Electronics, GPS Devices 14 | products, Electronics, E-book Readers 15 | products, Clothing, Women's clothing 16 | products, Clothing, Men's clothing 17 | products, Clothing, Kid's clothing 18 | products, Clothing, Shoes 19 | products, Clothing, Accessories 20 | products, Books, Fiction 21 | products, Books, Non-fiction 22 | products, Books, Mystery 23 | products, Books, Science Fiction 24 | products, Books, Romance 25 | products, Books, Fantasy 26 | products, Books, Thriller 27 | products, Books, Biography 28 | products, Books, History 29 | products, Books, Children's Books 30 | products, Books, Self-help 31 | products, Books, Cookbook 32 | products, Books, Art 33 | products, Books, Travel 34 | products, Books, Science 35 | products, Books, Religion 36 | products, Books, Reference 37 | products, Books, Poetry 38 | products, Books, Graphic Novels 39 | products, Books, Horror 40 | products, Books, Young Adult 41 | products, Beauty, Makeup 42 | products, Beauty, Skincare 43 | products, Beauty, 
Haircare 44 | products, Beauty, Fragrance 45 | products, Beauty, Bath & Body 46 | products, Beauty, Tools & Accessories 47 | products, Beauty, Men's Grooming 48 | products, Beauty, Beauty Supplements 49 | products, Home Improvement, Power tools 50 | products, Home Improvement, Hand tools 51 | products, Home Improvement, Electrical 52 | products, Home Improvement, Hardware 53 | products, Home Improvement, Plumbing 54 | products, Home Improvement, Building supplies 55 | products, Home Improvement, Paint 56 | products, Home Improvement, Heating 57 | products, Home Improvement, Cooling 58 | products, Home Improvement, Home security 59 | products, Home Improvement, Lighting 60 | products, Home Improvement, Storage 61 | products, Home Improvement, Cleaning supplies 62 | products, Home Improvement, Appliances 63 | products, Toys, Action Figures 64 | products, Toys, Dolls 65 | products, Toys, Stuffed Animals 66 | products, Toys, Building Sets 67 | products, Toys, Arts & Crafts 68 | products, Toys, Vehicles 69 | products, Toys, Puzzles 70 | products, Toys, Outdoor Play 71 | products, Toys, Learning & Education 72 | products, Toys, Games 73 | products, Toys, Play Food & Grocery 74 | products, Toys, Electronic Toys 75 | products, Toys, Musical Instruments 76 | products, Automotive, Car parts 77 | products, Automotive, Car accessories 78 | products, Automotive, Car care products 79 | products, Automotive, Car electronics 80 | products, Automotive, Car interior accessories 81 | products, Automotive, Car exterior accessories 82 | products, Automotive, Car performance parts 83 | products, Automotive, Motorcycle parts 84 | products, Automotive, Motorcycle accessories 85 | products, Sports and Outdoors, Camping and Hiking 86 | products, Sports and Outdoors, Exercise and Fitness 87 | products, Sports and Outdoors, Golf 88 | products, Sports and Outdoors, Hunting and Fishing 89 | products, Sports and Outdoors, Outdoor Recreation 90 | products, Sports and Outdoors, Team Sports 91 | 
products, Sports and Outdoors, Water Sports 92 | products, Sports and Outdoors, Winter Sports 93 | products, Health and Personal Care, Vitamins 94 | products, Health and Personal Care, Supplements 95 | products, Health and Personal Care, Health Care 96 | products, Health and Personal Care, Baby & Child Care 97 | products, Health and Personal Care, Sports Nutrition 98 | products, Health and Personal Care, Household Supplies 99 | products, Health and Personal Care, Medical Supplies & Equipment 100 | products, Health and Personal Care, Sexual Wellness 101 | products, Health and Personal Care, Health & Wellness 102 | products, Health and Personal Care, Oral Care 103 | products, Health and Personal Care, Beauty & Grooming 104 | products, Health and Personal Care, Pet Supplies 105 | products, Health and Personal Care, Health & Household 106 | products, Health and Personal Care, Baby & Child Care 107 | products, Health and Personal Care, Health Supplies 108 | products, Jewelry, Rings 109 | products, Jewelry, Necklaces 110 | products, Jewelry, Bracelets 111 | products, Jewelry, Earrings 112 | products, Jewelry, Pendants 113 | products, Jewelry, Brooches 114 | products, Jewelry, Charms 115 | products, Jewelry, Anklets 116 | products, Jewelry, Jewelry Sets 117 | products, Jewelry, Body Jewelry 118 | products, Jewelry, Hair Jewelry 119 | products, Jewelry, Toe Rings 120 | products, Jewelry, Cufflinks 121 | products, Jewelry, Tie Clips 122 | products, Pet Supplies, Dog food 123 | products, Pet Supplies, Cat food 124 | products, Pet Supplies, Dog treats 125 | products, Pet Supplies, Cat treats 126 | products, Pet Supplies, Cat litter 127 | products, Pet Supplies, Dog grooming supplies 128 | products, Pet Supplies, Cat grooming supplies 129 | products, Pet Supplies, Dog toys 130 | products, Pet Supplies, Cat toys 131 | products, Pet Supplies, Dog collars and leashes 132 | products, Pet Supplies, Cat collars and leashes 133 | products, Pet Supplies, Fish supplies 134 | products, 
Pet Supplies, Small animal supplies 135 | products, Pet Supplies, Bird supplies 136 | products, Baby, Baby girls' clothing 137 | products, Baby, baby boys' clothing 138 | products, Baby, baby shoes 139 | products, Baby, baby accessories 140 | products, Baby, baby gear 141 | products, Baby, baby feeding 142 | products, Baby, baby bathing 143 | products, Baby, baby nursery 144 | products, Baby, baby toys 145 | products, Office Products, Office furniture 146 | products, Office Products, Office lighting 147 | products, Office Products, Desk accessories 148 | products, Office Products, Filing products 149 | products, Office Products, Office electronics 150 | products, Office Products, Paper products 151 | products, Office Products, Writing instruments 152 | products, Office Products, Presentation supplies 153 | products, Office Products, Office storage 154 | products, Office Products, Office maintenance 155 | products, Office Products, Office organization 156 | products, Tools and Home Improvement, Power Tools 157 | products, Tools and Home Improvement, Hand Tools 158 | products, Tools and Home Improvement, Tool Organizers 159 | products, Tools and Home Improvement, Tool Sets 160 | products, Tools and Home Improvement, Measuring and Layout Tools 161 | products, Tools and Home Improvement, Welding and Soldering 162 | products, Tools and Home Improvement, Tool Storage 163 | products, Tools and Home Improvement, Material Handling 164 | products, Tools and Home Improvement, Jobsite Safety 165 | products, Tools and Home Improvement, Flashlights 166 | products, Tools and Home Improvement, Jobsite Radios 167 | products, Tools and Home Improvement, Welding Equipment 168 | products, Tools and Home Improvement, Woodworking 169 | products, Tools and Home Improvement, Workshop Equipment 170 | products, Tools and Home Improvement, Air Tools 171 | products, Tools and Home Improvement, Welding and Soldering 172 | products, Tools and Home Improvement, Power Finishing Tools 173 | 
products, Tools and Home Improvement, Fireplace and Stove Accessories 174 | products, Tools and Home Improvement, Abrasive and Finishing Products 175 | products, Tools and Home Improvement, Building Supplies 176 | products, Tools and Home Improvement, Cleaning Supplies 177 | products, Tools and Home Improvement, Electrical 178 | products, Tools and Home Improvement, Hardware 179 | products, Tools and Home Improvement, Kitchen and Bath Fixtures 180 | products, Tools and Home Improvement, Light Bulbs 181 | products, Tools and Home Improvement, Lighting and Ceiling Fans 182 | products, Tools and Home Improvement, Measuring and Layout Tools 183 | products, Tools and Home Improvement, Paint 184 | products, Tools and Home Improvement, Wall Treatments and Supplies 185 | products, Tools and Home Improvement, Power and Hand Tools 186 | products, Tools and Home Improvement, Rough Plumbing 187 | products, Tools and Home Improvement, Safety and Security 188 | products, Tools and Home Improvement, Storage and Home Organization 189 | products, Tools and Home Improvement, Winterization products 190 | products, Tools and Home Improvement, Outdoor Power Tools 191 | products, Tools and Home Improvement, Hand Tools 192 | products, Tools and Home Improvement, Power Tools 193 | products, Tools and Home Improvement, Painting Supplies and Wall Treatments 194 | products, Tools and Home Improvement, Electrical 195 | products, Tools and Home Improvement, Hardware 196 | products, Tools and Home Improvement, Kitchen and Bath Fixtures 197 | products, Tools and Home Improvement, Light Bulbs 198 | products, Tools and Home Improvement, Light Fixtures 199 | products, Tools and Home Improvement, Measuring and Layout Tools 200 | products, Tools and Home Improvement, Painting Supplies and Wall Treatments 201 | products, Tools and Home Improvement, Power and Hand Tools 202 | products, Tools and Home Improvement, Rough Plumbing 203 | products, Tools and Home Improvement, Safety and Security 204 | products, Tools and Home Improvement, 
Storage and Home Organization 205 | products, Tools and Home Improvement, Building Supplies 206 | products, Tools and Home Improvement, Rough Plumbing 207 | products, Tools and Home Improvement, Building Materials 208 | products, Tools and Home Improvement, Material Handling 209 | products, Tools and Home Improvement, Measuring and Layout Tools 210 | products, Tools and Home Improvement, Jobsite Safety 211 | products, Tools and Home Improvement, Work Wear and Safety Gear 212 | products, Tools and Home Improvement, Air Tools & Compressors 213 | products, Musical Instruments, Amplifiers & Effects 214 | products, Musical Instruments, Band & Orchestra 215 | products, Musical Instruments, Drums & Percussion 216 | products, Musical Instruments, Folk & World Instruments 217 | products, Musical Instruments, Guitars 218 | products, Musical Instruments, Keyboards 219 | products, Musical Instruments, Live Sound & Stage 220 | products, Musical Instruments, Microphones & Accessories 221 | products, Musical Instruments, Recording Equipment 222 | products, Industrial and Scientific, Industrial Electrical 223 | products, Industrial and Scientific, Occupational Health & Safety Products 224 | products, Industrial and Scientific, Test 225 | products, Industrial and Scientific, Measure & Inspect 226 | products, Industrial and Scientific, Abrasive & Finishing Products 227 | products, Industrial and Scientific, Janitorial & Sanitation Supplies 228 | products, Industrial and Scientific, Industrial Hardware 229 | products, Handmade, Jewelry 230 | products, Handmade, Clothing 231 | products, Handmade, Home decor 232 | products, Handmade, Art 233 | products, Handmade, Stationery 234 | products, Handmade, Toys 235 | products, Handmade, Beauty products 236 | products, Handmade, Accessories 237 | products, Handmade, Pet supplies 238 | products, Handmade, Craft supplies 239 | -------------------------------------------------------------------------------- 
/graph_generation/gen_results/tem/e-commerce platform like Amazon_depth1_products: -------------------------------------------------------------------------------- 1 | products, Electronics, Headphones 2 | products, Electronics, Speakers 3 | products, Electronics, Cameras 4 | products, Electronics, Smartphones 5 | products, Electronics, Smartwatches 6 | products, Electronics, Tablets 7 | products, Electronics, Laptops 8 | products, Electronics, Desktop Computers 9 | products, Electronics, Home Security Systems 10 | products, Electronics, Drones 11 | products, Electronics, Gaming Consoles 12 | products, Electronics, Virtual Reality Headsets 13 | products, Electronics, GPS Devices 14 | products, Electronics, E-book Readers 15 | products, Clothing, Women's clothing 16 | products, Clothing, Men's clothing 17 | products, Clothing, Kid's clothing 18 | products, Clothing, Shoes 19 | products, Clothing, Accessories 20 | products, Books, Fiction 21 | products, Books, Non-fiction 22 | products, Books, Mystery 23 | products, Books, Science Fiction 24 | products, Books, Romance 25 | products, Books, Fantasy 26 | products, Books, Thriller 27 | products, Books, Biography 28 | products, Books, History 29 | products, Books, Children's Books 30 | products, Books, Self-help 31 | products, Books, Cookbook 32 | products, Books, Art 33 | products, Books, Travel 34 | products, Books, Science 35 | products, Books, Religion 36 | products, Books, Reference 37 | products, Books, Poetry 38 | products, Books, Graphic Novels 39 | products, Books, Horror 40 | products, Books, Young Adult 41 | products, Beauty, Makeup 42 | products, Beauty, Skincare 43 | products, Beauty, Haircare 44 | products, Beauty, Fragrance 45 | products, Beauty, Bath & Body 46 | products, Beauty, Tools & Accessories 47 | products, Beauty, Men's Grooming 48 | products, Beauty, Beauty Supplements 49 | products, Home Improvement, Power tools 50 | products, Home Improvement, Hand tools 51 | products, Home Improvement, 
Electrical 52 | products, Home Improvement, Hardware 53 | products, Home Improvement, Plumbing 54 | products, Home Improvement, Building supplies 55 | products, Home Improvement, Paint 56 | products, Home Improvement, Heating 57 | products, Home Improvement, Cooling 58 | products, Home Improvement, Home security 59 | products, Home Improvement, Lighting 60 | products, Home Improvement, Storage 61 | products, Home Improvement, Cleaning supplies 62 | products, Home Improvement, Appliances 63 | products, Toys, Action Figures 64 | products, Toys, Dolls 65 | products, Toys, Stuffed Animals 66 | products, Toys, Building Sets 67 | products, Toys, Arts & Crafts 68 | products, Toys, Vehicles 69 | products, Toys, Puzzles 70 | products, Toys, Outdoor Play 71 | products, Toys, Learning & Education 72 | products, Toys, Games 73 | products, Toys, Play Food & Grocery 74 | products, Toys, Electronic Toys 75 | products, Toys, Musical Instruments 76 | products, Automotive, Car parts 77 | products, Automotive, Car accessories 78 | products, Automotive, Car care products 79 | products, Automotive, Car electronics 80 | products, Automotive, Car interior accessories 81 | products, Automotive, Car exterior accessories 82 | products, Automotive, Car performance parts 83 | products, Automotive, Motorcycle parts 84 | products, Automotive, Motorcycle accessories 85 | products, Sports and Outdoors, Camping and Hiking 86 | products, Sports and Outdoors, Exercise and Fitness 87 | products, Sports and Outdoors, Golf 88 | products, Sports and Outdoors, Hunting and Fishing 89 | products, Sports and Outdoors, Outdoor Recreation 90 | products, Sports and Outdoors, Team Sports 91 | products, Sports and Outdoors, Water Sports 92 | products, Sports and Outdoors, Winter Sports 93 | products, Health and Personal Care, Vitamins 94 | products, Health and Personal Care, Supplements 95 | products, Health and Personal Care, Health Care 96 | products, Health and Personal Care, Baby & Child Care 97 | products, 
Health and Personal Care, Sports Nutrition 98 | products, Health and Personal Care, Household Supplies 99 | products, Health and Personal Care, Medical Supplies & Equipment 100 | products, Health and Personal Care, Sexual Wellness 101 | products, Health and Personal Care, Health & Wellness 102 | products, Health and Personal Care, Oral Care 103 | products, Health and Personal Care, Beauty & Grooming 104 | products, Health and Personal Care, Pet Supplies 105 | products, Health and Personal Care, Health & Household 106 | products, Health and Personal Care, Baby & Child Care 107 | products, Health and Personal Care, Health Supplies 108 | products, Jewelry, Rings 109 | products, Jewelry, Necklaces 110 | products, Jewelry, Bracelets 111 | products, Jewelry, Earrings 112 | products, Jewelry, Pendants 113 | products, Jewelry, Brooches 114 | products, Jewelry, Charms 115 | products, Jewelry, Anklets 116 | products, Jewelry, Jewelry Sets 117 | products, Jewelry, Body Jewelry 118 | products, Jewelry, Hair Jewelry 119 | products, Jewelry, Toe Rings 120 | products, Jewelry, Cufflinks 121 | products, Jewelry, Tie Clips 122 | products, Pet Supplies, Dog food 123 | products, Pet Supplies, Cat food 124 | products, Pet Supplies, Dog treats 125 | products, Pet Supplies, Cat treats 126 | products, Pet Supplies, Cat litter 127 | products, Pet Supplies, Dog grooming supplies 128 | products, Pet Supplies, Cat grooming supplies 129 | products, Pet Supplies, Dog toys 130 | products, Pet Supplies, Cat toys 131 | products, Pet Supplies, Dog collars and leashes 132 | products, Pet Supplies, Cat collars and leashes 133 | products, Pet Supplies, Fish supplies 134 | products, Pet Supplies, Small animal supplies 135 | products, Pet Supplies, Bird supplies 136 | products, Baby, Baby girls' clothing 137 | products, Baby, baby boys' clothing 138 | products, Baby, baby shoes 139 | products, Baby, baby accessories 140 | products, Baby, baby gear 141 | products, Baby, baby feeding 142 | products, 
Baby, baby bathing 143 | products, Baby, baby nursery 144 | products, Baby, baby toys 145 | products, Office Products, Office furniture 146 | products, Office Products, Office lighting 147 | products, Office Products, Desk accessories 148 | products, Office Products, Filing products 149 | products, Office Products, Office electronics 150 | products, Office Products, Paper products 151 | products, Office Products, Writing instruments 152 | products, Office Products, Presentation supplies 153 | products, Office Products, Office storage 154 | products, Office Products, Office maintenance 155 | products, Office Products, Office organization 156 | products, Tools and Home Improvement, Power Tools 157 | products, Tools and Home Improvement, Hand Tools 158 | products, Tools and Home Improvement, Tool Organizers 159 | products, Tools and Home Improvement, Tool Sets 160 | products, Tools and Home Improvement, Measuring and Layout Tools 161 | products, Tools and Home Improvement, Welding and Soldering 162 | products, Tools and Home Improvement, Tool Storage 163 | products, Tools and Home Improvement, Material Handling 164 | products, Tools and Home Improvement, Jobsite Safety 165 | products, Tools and Home Improvement, Flashlights 166 | products, Tools and Home Improvement, Jobsite Radios 167 | products, Tools and Home Improvement, Welding Equipment 168 | products, Tools and Home Improvement, Woodworking 169 | products, Tools and Home Improvement, Workshop Equipment 170 | products, Tools and Home Improvement, Air Tools 171 | products, Tools and Home Improvement, Welding and Soldering 172 | products, Tools and Home Improvement, Power Finishing Tools 173 | products, Tools and Home Improvement, Fireplace and Stove Accessories 174 | products, Tools and Home Improvement, Abrasive and Finishing Products 175 | products, Tools and Home Improvement, Building Supplies 176 | products, Tools and Home Improvement, Cleaning Supplies 177 | products, Tools and Home Improvement, Electrical 
178 | products, Tools and Home Improvement, Hardware 179 | products, Tools and Home Improvement, Kitchen and Bath Fixtures 180 | products, Tools and Home Improvement, Light Bulbs 181 | products, Tools and Home Improvement, Lighting and Ceiling Fans 182 | products, Tools and Home Improvement, Measuring and Layout Tools 183 | products, Tools and Home Improvement, Paint 184 | products, Tools and Home Improvement, Wall Treatments and Supplies 185 | products, Tools and Home Improvement, Power and Hand Tools 186 | products, Tools and Home Improvement, Rough Plumbing 187 | products, Tools and Home Improvement, Safety and Security 188 | products, Tools and Home Improvement, Storage and Home Organization 189 | products, Tools and Home Improvement, Winterization products 190 | products, Tools and Home Improvement, Outdoor Power Tools 191 | products, Tools and Home Improvement, Hand Tools 192 | products, Tools and Home Improvement, Power Tools 193 | products, Tools and Home Improvement, Painting Supplies and Wall Treatments 194 | products, Tools and Home Improvement, Electrical 195 | products, Tools and Home Improvement, Hardware 196 | products, Tools and Home Improvement, Kitchen and Bath Fixtures 197 | products, Tools and Home Improvement, Light Bulbs 198 | products, Tools and Home Improvement, Light Fixtures 199 | products, Tools and Home Improvement, Measuring and Layout Tools 200 | products, Tools and Home Improvement, Painting Supplies and Wall Treatments 201 | products, Tools and Home Improvement, Power and Hand Tools 202 | products, Tools and Home Improvement, Rough Plumbing 203 | products, Tools and Home Improvement, Safety and Security 204 | products, Tools and Home Improvement, Storage and Home Organization 205 | products, Tools and Home Improvement, Building Supplies 206 | products, Tools and Home Improvement, Rough Plumbing 207 | products, Tools and Home Improvement, Building Materials 208 | products, Tools and Home Improvement, Material Handling 209 | products, Tools and Home Improvement, Measuring 
and Layout Tools 210 | products, Tools and Home Improvement, Jobsite Safety 211 | products, Tools and Home Improvement, Work Wear and Safety Gear 212 | products, Tools and Home Improvement, Air Tools & Compressors 213 | products, Musical Instruments, Amplifiers & Effects 214 | products, Musical Instruments, Band & Orchestra 215 | products, Musical Instruments, Drums & Percussion 216 | products, Musical Instruments, Folk & World Instruments 217 | products, Musical Instruments, Guitars 218 | products, Musical Instruments, Keyboards 219 | products, Musical Instruments, Live Sound & Stage 220 | products, Musical Instruments, Microphones & Accessories 221 | products, Musical Instruments, Recording Equipment 222 | products, Industrial and Scientific, Industrial Electrical 223 | products, Industrial and Scientific, Occupational Health & Safety Products 224 | products, Industrial and Scientific, Test 225 | products, Industrial and Scientific, Measure & Inspect 226 | products, Industrial and Scientific, Abrasive & Finishing Products 227 | products, Industrial and Scientific, Janitorial & Sanitation Supplies 228 | products, Industrial and Scientific, Industrial Hardware 229 | products, Handmade, Jewelry 230 | products, Handmade, Clothing 231 | products, Handmade, Home decor 232 | products, Handmade, Art 233 | products, Handmade, Stationery 234 | products, Handmade, Toys 235 | products, Handmade, Beauty products 236 | products, Handmade, Accessories 237 | products, Handmade, Pet supplies 238 | products, Handmade, Craft supplies 239 | -------------------------------------------------------------------------------- /graph_generation/gen_results/tem/e-commerce platform like Amazon_depth2_products, Automotive: -------------------------------------------------------------------------------- 1 | products, Automotive, Car parts 2 | products, Automotive, Car accessories 3 | products, Automotive, Car care products 4 | products, Automotive, Car electronics 5 | products, Automotive, Car 
interior accessories 6 | products, Automotive, Car exterior accessories 7 | products, Automotive, Car performance parts 8 | products, Automotive, Motorcycle parts 9 | products, Automotive, Motorcycle accessories 10 | -------------------------------------------------------------------------------- /graph_generation/gen_results/tem/e-commerce platform like Amazon_depth2_products, Baby: -------------------------------------------------------------------------------- 1 | products, Baby, Baby girls' clothing 2 | products, Baby, baby boys' clothing 3 | products, Baby, baby shoes 4 | products, Baby, baby accessories 5 | products, Baby, baby gear 6 | products, Baby, baby feeding 7 | products, Baby, baby bathing 8 | products, Baby, baby nursery 9 | products, Baby, baby toys 10 | -------------------------------------------------------------------------------- /graph_generation/gen_results/tem/e-commerce platform like Amazon_depth2_products, Beauty: -------------------------------------------------------------------------------- 1 | products, Beauty, Makeup 2 | products, Beauty, Skincare 3 | products, Beauty, Haircare 4 | products, Beauty, Fragrance 5 | products, Beauty, Bath & Body 6 | products, Beauty, Tools & Accessories 7 | products, Beauty, Men's Grooming 8 | products, Beauty, Beauty Supplements 9 | -------------------------------------------------------------------------------- /graph_generation/gen_results/tem/e-commerce platform like Amazon_depth2_products, Books: -------------------------------------------------------------------------------- 1 | products, Books, Fiction 2 | products, Books, Non-fiction 3 | products, Books, Mystery 4 | products, Books, Science Fiction 5 | products, Books, Romance 6 | products, Books, Fantasy 7 | products, Books, Thriller 8 | products, Books, Biography 9 | products, Books, History 10 | products, Books, Children's Books 11 | products, Books, Self-help 12 | products, Books, Cookbook 13 | products, Books, Art 14 | products, Books, Travel 
15 | products, Books, Science 16 | products, Books, Religion 17 | products, Books, Reference 18 | products, Books, Poetry 19 | products, Books, Graphic Novels 20 | products, Books, Horror 21 | products, Books, Young Adult 22 | -------------------------------------------------------------------------------- /graph_generation/gen_results/tem/e-commerce platform like Amazon_depth2_products, Clothing: -------------------------------------------------------------------------------- 1 | products, Clothing, Women's clothing 2 | products, Clothing, Men's clothing 3 | products, Clothing, Kid's clothing 4 | products, Clothing, Shoes 5 | products, Clothing, Accessories 6 | -------------------------------------------------------------------------------- /graph_generation/gen_results/tem/e-commerce platform like Amazon_depth2_products, Electronics: -------------------------------------------------------------------------------- 1 | products, Electronics, Headphones 2 | products, Electronics, Speakers 3 | products, Electronics, Cameras 4 | products, Electronics, Smartphones 5 | products, Electronics, Smartwatches 6 | products, Electronics, Tablets 7 | products, Electronics, Laptops 8 | products, Electronics, Desktop Computers 9 | products, Electronics, Home Security Systems 10 | products, Electronics, Drones 11 | products, Electronics, Gaming Consoles 12 | products, Electronics, Virtual Reality Headsets 13 | products, Electronics, GPS Devices 14 | products, Electronics, E-book Readers 15 | -------------------------------------------------------------------------------- /graph_generation/gen_results/tem/e-commerce platform like Amazon_depth2_products, Handmade: -------------------------------------------------------------------------------- 1 | products, Handmade, Jewelry 2 | products, Handmade, Clothing 3 | products, Handmade, Home decor 4 | products, Handmade, Art 5 | products, Handmade, Stationery 6 | products, Handmade, Toys 7 | products, Handmade, Beauty products 8 | 
products, Handmade, Accessories 9 | products, Handmade, Pet supplies 10 | products, Handmade, Craft supplies 11 | -------------------------------------------------------------------------------- /graph_generation/gen_results/tem/e-commerce platform like Amazon_depth2_products, Health and Personal Care: -------------------------------------------------------------------------------- 1 | products, Health and Personal Care, Vitamins 2 | products, Health and Personal Care, Supplements 3 | products, Health and Personal Care, Health Care 4 | products, Health and Personal Care, Baby & Child Care 5 | products, Health and Personal Care, Sports Nutrition 6 | products, Health and Personal Care, Household Supplies 7 | products, Health and Personal Care, Medical Supplies & Equipment 8 | products, Health and Personal Care, Sexual Wellness 9 | products, Health and Personal Care, Health & Wellness 10 | products, Health and Personal Care, Oral Care 11 | products, Health and Personal Care, Beauty & Grooming 12 | products, Health and Personal Care, Pet Supplies 13 | products, Health and Personal Care, Health & Household 14 | products, Health and Personal Care, Baby & Child Care 15 | products, Health and Personal Care, Health Supplies 16 | -------------------------------------------------------------------------------- /graph_generation/gen_results/tem/e-commerce platform like Amazon_depth2_products, Home Improvement: -------------------------------------------------------------------------------- 1 | products, Home Improvement, Power tools 2 | products, Home Improvement, Hand tools 3 | products, Home Improvement, Electrical 4 | products, Home Improvement, Hardware 5 | products, Home Improvement, Plumbing 6 | products, Home Improvement, Building supplies 7 | products, Home Improvement, Paint 8 | products, Home Improvement, Heating 9 | products, Home Improvement, Cooling 10 | products, Home Improvement, Home security 11 | products, Home Improvement, Lighting 12 | products, Home 
Improvement, Storage 13 | products, Home Improvement, Cleaning supplies 14 | products, Home Improvement, Appliances 15 | -------------------------------------------------------------------------------- /graph_generation/gen_results/tem/e-commerce platform like Amazon_depth2_products, Industrial and Scientific: -------------------------------------------------------------------------------- 1 | products, Industrial and Scientific, Industrial Electrical 2 | products, Industrial and Scientific, Occupational Health & Safety Products 3 | products, Industrial and Scientific, Test 4 | products, Industrial and Scientific, Measure & Inspect 5 | products, Industrial and Scientific, Abrasive & Finishing Products 6 | products, Industrial and Scientific, Janitorial & Sanitation Supplies 7 | products, Industrial and Scientific, Industrial Hardware 8 | -------------------------------------------------------------------------------- /graph_generation/gen_results/tem/e-commerce platform like Amazon_depth2_products, Jewelry: -------------------------------------------------------------------------------- 1 | products, Jewelry, Rings 2 | products, Jewelry, Necklaces 3 | products, Jewelry, Bracelets 4 | products, Jewelry, Earrings 5 | products, Jewelry, Pendants 6 | products, Jewelry, Brooches 7 | products, Jewelry, Charms 8 | products, Jewelry, Anklets 9 | products, Jewelry, Jewelry Sets 10 | products, Jewelry, Body Jewelry 11 | products, Jewelry, Hair Jewelry 12 | products, Jewelry, Toe Rings 13 | products, Jewelry, Cufflinks 14 | products, Jewelry, Tie Clips 15 | -------------------------------------------------------------------------------- /graph_generation/gen_results/tem/e-commerce platform like Amazon_depth2_products, Musical Instruments: -------------------------------------------------------------------------------- 1 | products, Musical Instruments, Amplifiers & Effects 2 | products, Musical Instruments, Band & Orchestra 3 | products, Musical Instruments, Drums & 
Percussion 4 | products, Musical Instruments, Folk & World Instruments 5 | products, Musical Instruments, Guitars 6 | products, Musical Instruments, Keyboards 7 | products, Musical Instruments, Live Sound & Stage 8 | products, Musical Instruments, Microphones & Accessories 9 | products, Musical Instruments, Recording Equipment 10 | -------------------------------------------------------------------------------- /graph_generation/gen_results/tem/e-commerce platform like Amazon_depth2_products, Office Products: -------------------------------------------------------------------------------- 1 | products, Office Products, Office furniture 2 | products, Office Products, Office lighting 3 | products, Office Products, Desk accessories 4 | products, Office Products, Filing products 5 | products, Office Products, Office electronics 6 | products, Office Products, Paper products 7 | products, Office Products, Writing instruments 8 | products, Office Products, Presentation supplies 9 | products, Office Products, Office storage 10 | products, Office Products, Office maintenance 11 | products, Office Products, Office organization 12 | -------------------------------------------------------------------------------- /graph_generation/gen_results/tem/e-commerce platform like Amazon_depth2_products, Pet Supplies: -------------------------------------------------------------------------------- 1 | products, Pet Supplies, Dog food 2 | products, Pet Supplies, Cat food 3 | products, Pet Supplies, Dog treats 4 | products, Pet Supplies, Cat treats 5 | products, Pet Supplies, Cat litter 6 | products, Pet Supplies, Dog grooming supplies 7 | products, Pet Supplies, Cat grooming supplies 8 | products, Pet Supplies, Dog toys 9 | products, Pet Supplies, Cat toys 10 | products, Pet Supplies, Dog collars and leashes 11 | products, Pet Supplies, Cat collars and leashes 12 | products, Pet Supplies, Fish supplies 13 | products, Pet Supplies, Small animal supplies 14 | products, Pet Supplies, Bird 
supplies 15 | -------------------------------------------------------------------------------- /graph_generation/gen_results/tem/e-commerce platform like Amazon_depth2_products, Sports and Outdoors: -------------------------------------------------------------------------------- 1 | products, Sports and Outdoors, Camping and Hiking 2 | products, Sports and Outdoors, Exercise and Fitness 3 | products, Sports and Outdoors, Golf 4 | products, Sports and Outdoors, Hunting and Fishing 5 | products, Sports and Outdoors, Outdoor Recreation 6 | products, Sports and Outdoors, Team Sports 7 | products, Sports and Outdoors, Water Sports 8 | products, Sports and Outdoors, Winter Sports 9 | -------------------------------------------------------------------------------- /graph_generation/gen_results/tem/e-commerce platform like Amazon_depth2_products, Tools and Home Improvement: -------------------------------------------------------------------------------- 1 | products, Tools and Home Improvement, Power Tools 2 | products, Tools and Home Improvement, Hand Tools 3 | products, Tools and Home Improvement, Tool Organizers 4 | products, Tools and Home Improvement, Tool Sets 5 | products, Tools and Home Improvement, Measuring and Layout Tools 6 | products, Tools and Home Improvement, Welding and Soldering 7 | products, Tools and Home Improvement, Tool Storage 8 | products, Tools and Home Improvement, Material Handling 9 | products, Tools and Home Improvement, Jobsite Safety 10 | products, Tools and Home Improvement, Flashlights 11 | products, Tools and Home Improvement, Jobsite Radios 12 | products, Tools and Home Improvement, Welding Equipment 13 | products, Tools and Home Improvement, Woodworking 14 | products, Tools and Home Improvement, Workshop Equipment 15 | products, Tools and Home Improvement, Air Tools 16 | products, Tools and Home Improvement, Welding and Soldering 17 | products, Tools and Home Improvement, Power Finishing Tools 18 | products, Tools and Home Improvement, 
Fireplace and Stove Accessories 19 | products, Tools and Home Improvement, Abrasive and Finishing Products 20 | products, Tools and Home Improvement, Building Supplies 21 | products, Tools and Home Improvement, Cleaning Supplies 22 | products, Tools and Home Improvement, Electrical 23 | products, Tools and Home Improvement, Hardware 24 | products, Tools and Home Improvement, Kitchen and Bath Fixtures 25 | products, Tools and Home Improvement, Light Bulbs 26 | products, Tools and Home Improvement, Lighting and Ceiling Fans 27 | products, Tools and Home Improvement, Measuring and Layout Tools 28 | products, Tools and Home Improvement, Paint 29 | products, Tools and Home Improvement, Wall Treatments and Supplies 30 | products, Tools and Home Improvement, Power and Hand Tools 31 | products, Tools and Home Improvement, Rough Plumbing 32 | products, Tools and Home Improvement, Safety and Security 33 | products, Tools and Home Improvement, Storage and Home Organization 34 | products, Tools and Home Improvement, Winterization products 35 | products, Tools and Home Improvement, Outdoor Power Tools 36 | products, Tools and Home Improvement, Hand Tools 37 | products, Tools and Home Improvement, Power Tools 38 | products, Tools and Home Improvement, Painting Supplies and Wall Treatments 39 | products, Tools and Home Improvement, Electrical 40 | products, Tools and Home Improvement, Hardware 41 | products, Tools and Home Improvement, Kitchen and Bath Fixtures 42 | products, Tools and Home Improvement, Light Bulbs 43 | products, Tools and Home Improvement, Light Fixtures 44 | products, Tools and Home Improvement, Measuring and Layout Tools 45 | products, Tools and Home Improvement, Painting Supplies and Wall Treatments 46 | products, Tools and Home Improvement, Power and Hand Tools 47 | products, Tools and Home Improvement, Rough Plumbing 48 | products, Tools and Home Improvement, Safety and Security 49 | products, Tools and Home Improvement, Storage and Home Organization 50 | products, Tools and Home Improvement, 
Building Supplies 51 | products, Tools and Home Improvement, Rough Plumbing 52 | products, Tools and Home Improvement, Building Materials 53 | products, Tools and Home Improvement, Material Handling 54 | products, Tools and Home Improvement, Measuring and Layout Tools 55 | products, Tools and Home Improvement, Jobsite Safety 56 | products, Tools and Home Improvement, Work Wear and Safety Gear 57 | products, Tools and Home Improvement, Air Tools & Compressors 58 | -------------------------------------------------------------------------------- /graph_generation/gen_results/tem/e-commerce platform like Amazon_depth2_products, Toys: -------------------------------------------------------------------------------- 1 | products, Toys, Action Figures 2 | products, Toys, Dolls 3 | products, Toys, Stuffed Animals 4 | products, Toys, Building Sets 5 | products, Toys, Arts & Crafts 6 | products, Toys, Vehicles 7 | products, Toys, Puzzles 8 | products, Toys, Outdoor Play 9 | products, Toys, Learning & Education 10 | products, Toys, Games 11 | products, Toys, Play Food & Grocery 12 | products, Toys, Electronic Toys 13 | products, Toys, Musical Instruments 14 | -------------------------------------------------------------------------------- /graph_generation/gen_results/tree_wInstanceNum_products_e-commerce platform like Amazon.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/graph_generation/gen_results/tree_wInstanceNum_products_e-commerce platform like Amazon.pkl -------------------------------------------------------------------------------- /graph_generation/human_item_generation_gibbsSampling_embedEstimation.py: -------------------------------------------------------------------------------- 1 | import random 2 | from Utils import DataGenAgent 3 | import pickle 4 | import os 5 | import Exp_Utils.TimeLogger as logger 6 | from Exp_Utils.TimeLogger 
from Exp_Utils.TimeLogger import log
from Exp_Utils.Emailer import SendMail
import numpy as np
from scipy.stats import norm  # NOTE(review): `norm` appears unused in this module — confirm before removing
import copy

class HumanItemRelationGeneration(DataGenAgent):
    """Generates synthetic human-item interaction lists via Gibbs sampling.

    Each sample is a binary vector over `item_list`; a dimension is flipped
    on/off based on an acceptance probability computed from text-embedding
    similarity between the candidate item and the items already in the sample,
    damped by a logistic penalty on interaction count and a community-distance
    decay.
    """
    def __init__(self, item_list, length_sampler, descs, hyperparams, text_embedding_dict=None):
        """
        Args:
            item_list: list of item strings; each is expected to contain a
                ' #<count>' suffix (methods below strip it via `item.index(' #')`).
            length_sampler: callable returning an interaction length.
                NOTE(review): never invoked inside this class — verify caller.
            descs: dataset description dict (e.g. 'data_name').
            hyperparams: sampling hyperparameters; 'item_num' is overwritten here.
            text_embedding_dict: optional precomputed text -> embedding mapping.
        """
        super(HumanItemRelationGeneration, self).__init__()

        self.item_list = item_list
        self.length_sampler = length_sampler
        self.descs = descs
        self.hyperparams = hyperparams
        # keep hyperparams consistent with the actual item universe
        self.hyperparams['item_num'] = len(self.item_list)
        self.item_to_id = dict()
        for iid, item in enumerate(item_list):
            self.item_to_id[item] = iid
        self.reject_cnt = 0  # consecutive unchanged Gibbs steps
        # fixed random visiting order over dimensions for Gibbs updates
        self.item_perm = np.random.permutation(len(self.item_list))
        self.score_history = []  # sliding window used by dynamic_normalize
        self.text_embedding_dict = dict() if text_embedding_dict is None else text_embedding_dict

        # fst_text = 'Clothing, Plus size clothing, Plus size formal wear, Dresses'
        # scd_text = 'Health & Household, Household Supplies, Air fresheners, Scented beads & charms'
        # print(fst_text)
        # print(scd_text)
        # print(self.similarity(self.text_embedding(fst_text), self.text_embedding(scd_text)))
        # exit()

    def binvec2list(self, bin_sample_vec):
        """Map a binary sample vector to the list of item strings it selects."""
        idxs = np.reshape(np.argwhere(bin_sample_vec != 0), [-1])
        return list(map(lambda x: self.item_list[x], idxs))

    def list_text(self, item_list, nums):
        """Render (item, frequency) pairs as a numbered text list, sorted by
        descending frequency. Used only by commented-out debug prints below."""
        item_num_list = list(zip(item_list, nums))
        item_num_list.sort(key=lambda x: x[1], reverse=True)
        ret = ''
        for i, pair in enumerate(item_num_list):
            item, num = pair[0], pair[1]
            ret += '{idx}. {item}. Frequency: {num}\n'.format(idx=i, item=item, num=num)
        return ret

    def summarize(self, item_list):
        """Compress `item_list` to at most hyperparams['context_limit'] distinct
        entries by greedily fusing items into shared category prefixes, deepest
        prefixes first. Returns (items, counts)."""
        def fuse(item_list, prefix):
            # replace every item under `prefix` with the prefix itself (in place)
            for i in range(len(item_list)):
                item = item_list[i]
                if item.startswith(prefix):
                    item_list[i] = prefix
            return item_list
        def count_and_shrink(item_list):
            # deduplicate, returning parallel lists of unique items and counts
            dic = dict()
            for item in item_list:
                if item not in dic:
                    dic[item] = 0
                dic[item] += 1
            ret_item, ret_cnt = [], []
            for key, cnt in dic.items():
                ret_item.append(key)
                ret_cnt.append(cnt)
            return ret_item, ret_cnt
        def count_prefixes_of_different_depth(item_list, max_depth):
            # per-depth frequency of each comma-joined category prefix;
            # also strips the ' #<count>' suffix from every item
            ret_item_list = []
            prefix_dicts = [dict() for i in range(max_depth + 1)]
            for item in item_list:
                num_idx = item.index(' #')
                tem_item = item[:num_idx]
                ret_item_list.append(tem_item)
                entities = tem_item.split(', ')
                entities = list(map(lambda entity: entity.strip(), entities))
                for depth in range(max_depth + 1):
                    if depth + 1 >= len(entities):
                        break
                    tem_prefix = ', '.join(entities[:depth + 1])
                    if tem_prefix not in prefix_dicts[depth]:
                        prefix_dicts[depth][tem_prefix] = 0
                    prefix_dicts[depth][tem_prefix] += 1
            return prefix_dicts, ret_item_list
        max_depth = len(item_list[0].split(', '))
        prefix_dicts, ret_item_list = count_prefixes_of_different_depth(item_list, max_depth)
        if len(ret_item_list) < self.hyperparams['context_limit']:
            # already small enough — return as-is with unit counts
            return ret_item_list, [1] * len(ret_item_list)
        # greedy search
        flag = False
        for depth in range(max_depth, -1, -1):
            prefix_list = [(prefix, cnt) for prefix, cnt in prefix_dicts[depth].items()]
            prefix_list.sort(key=lambda x: x[1], reverse=True)
            for prefix, cnt in prefix_list:
                if cnt == 1:
                    # remaining prefixes fuse nothing — move a level shallower
                    break
                ret_item_list = fuse(ret_item_list, prefix)
                if depth != 0:
                    # adjust the counts of shallow entities
                    shrinked_prefix = ', '.join(prefix.split(', ')[:-1])
                    prefix_dicts[depth - 1][shrinked_prefix] -= cnt - 1
                if len(set(ret_item_list)) <= self.hyperparams['context_limit']:
                    flag=True
                    break
            if flag:
                return count_and_shrink(ret_item_list)
        return count_and_shrink(ret_item_list)

    def text_embedding(self, text):
        """Look up the precomputed embedding for `text`.

        NOTE(review): on a cache miss this logs and calls exit(), so the
        openai_embedding fallback below is unreachable — confirm intended.
        """
        if text in self.text_embedding_dict:
            embeds = self.text_embedding_dict[text]
            return embeds
        print('Embedding not found!')
        print(text)
        exit()
        embedding = self.openai_embedding(text)
        self.text_embedding_dict[text] = embedding
        return embedding

    def similarity(self, fst_embed, scd_embed):
        """Dot-product similarity between two embedding vectors
        (normalization is intentionally disabled, see commented lines)."""
        # fst_embed = fst_embed / np.sqrt(np.sum(np.square(fst_embed)))
        # scd_embed = scd_embed / np.sqrt(np.sum(np.square(scd_embed)))
        return np.sum(fst_embed * scd_embed)

    def dynamic_normalize(self, score):
        """Rescale `score` into [0, 1] using a mean +/- 1.96*std window over the
        last 5000 observed scores. Currently not called by estimate_probability
        (the call site is commented out)."""
        self.score_history.append(score)
        if len(self.score_history) > 5000:
            self.score_history = self.score_history[-5000:]
        if len(self.score_history) < 5:
            # too few observations to estimate a range — just clip
            return max(min(1.0, score), 0.0)
        score_samples = np.array(self.score_history)
        mean = np.mean(score_samples)
        std = np.sqrt(np.mean((score_samples - mean) ** 2))
        minn = mean - 1.96 * std
        maxx = mean + 1.96 * std#2.96 * std
        # minn = np.min(self.score_history)
        # maxx = np.max(self.score_history)
        ret = (score - minn) / (maxx - minn)
        ret = max(min(ret, 1.0), 0.0)
        # add margin
        # ret = (ret + 0.1) / 1.2
        # print('minn', minn)
        # print('maxx', maxx)
        # print('score', score)
        # print((score - minn) / (maxx - minn))

        # if len(self.score_history) > 10000:
        #     import matplotlib.pyplot as plt
        #     plt.hist(self.score_history, 100)
        #     plt.savefig('tem.pdf')
        #     exit()

        return ret

    def estimate_probability(self, bin_sample_vec, cur_dim, is_deleting=False):
        """Acceptance score for adding item `cur_dim` to the current sample.

        Score = dot(avg embedding of selected items, candidate embedding),
        multiplied by a logistic penalty that shrinks as the interaction count
        exceeds hyperparams['length_center']. `is_deleting` is currently unused.
        """
        item_list = self.binvec2list(bin_sample_vec)
        new_item = self.item_list[cur_dim]
        candidate_embedding = self.text_embedding(new_item[:new_item.index(' #')])

        # summary similarity
        # if 'real_data' not in self.descs['data_name']:
        #     summarized_item_list, nums = self.summarize(item_list)
        #     summarized_score = 0
        #     for i, summarized_item in enumerate(summarized_item_list):
        #         sim = self.similarity(self.text_embedding(summarized_item), candidate_embedding)
        #         summarized_score += sim * nums[i] / sum(nums)

        # top instance similarity
        # sims = list(map(lambda item: (self.similarity(self.text_embedding(item[:item.index(' #')]), candidate_embedding), item[:item.index(' #')]), item_list))
        # sims.sort(reverse=True, key=lambda x: x[0])
        # sim_scores = list(map(lambda x: x[0], sims[:self.hyperparams['context_limit']]))
        # instance_score = sum(sim_scores) / len(sim_scores)
        # most_sim_items = list(map(lambda x: x[1], sims[:self.hyperparams['context_limit']]))

        embed_list = list(map(lambda item: self.text_embedding(item[:item.index(' #')]), item_list))
        avg_embed = sum(embed_list) / len(embed_list)
        instance_score = np.sum(avg_embed * candidate_embedding)


        # print('## Candidate item:', new_item)
        # print('## Summary context:')
        # print(self.list_text(summarized_item_list, nums))
        # print('## Most similar items:')
        # print(self.list_text(most_sim_items, [1] * len(most_sim_items)))

        # if 'real_data' in self.descs['data_name']:
        score = instance_score
        # else:
        #     score = (summarized_score + instance_score) / 2
        # log('Original score: summary-{sum_score}, instance-{ins_score}, sum-{tot_score}'.format(sum_score=summarized_score, ins_score=instance_score, tot_score=score))
        # score = self.dynamic_normalize(score)
        # log('Dynamic normalized probability: {prob}'.format(prob=score))
        interaction_num = np.sum(bin_sample_vec != 0)
        # logistic length penalty centered at length_center, slope length_center//2
        interaction_prob = 1.0 / (1.0 + np.exp((interaction_num - self.hyperparams['length_center'])/(self.hyperparams['length_center']//2)))#(100, 50), (150, 75)
        score = score * interaction_prob# * 1.1
        # log('Interaction num normalized probability: {prob}'.format(prob=score))
        return score

    def update_sample(self, last_sample, cur_dim, should_include):
        """Flip dimension `cur_dim` to match `should_include`.

        Returns (sample, changed): a deep copy with the bit flipped when the
        current value disagrees with `should_include`, else the original
        sample unchanged.
        """
        if last_sample[cur_dim] == 0.0 and should_include or last_sample[cur_dim] > 0.0 and not should_include:
            new_sample = copy.deepcopy(last_sample)
            new_sample[cur_dim] = 1.0 - last_sample[cur_dim]
            return new_sample, True
        else:
            return last_sample, False

    def Gibbs_Sampling(self):
        """Run the Gibbs-sampling loop and return the list of accepted samples.

        Restarts from a fresh random sample every 'restart_step' outer steps
        (rotating the active community); each outer step performs up to
        'gibbs_step' single-dimension updates. A run of >50 consecutive
        rejections aborts the inner loop.
        """
        samples = []
        idx = 0
        update_cnt = 0
        cur_community = 0
        for step in range(self.hyperparams['sample_num']):
            if step % self.hyperparams['restart_step'] == 0:
                samples.append(self.random_sample())
                cur_community = (cur_community + 1) % self.hyperparams['community_num']
            last_sample = samples[-1]
            update_flag = False
            for small_step in range(self.hyperparams['gibbs_step']):
                cur_dim = self.item_perm[idx]
                nnz = np.sum(last_sample != 0)
                delete_dice = random.uniform(0, 1)
                if nnz > self.hyperparams['delete_nnz'] and delete_dice < 0.5:# 0.5, 0.4, 0.75, 2
                    # with prob 0.5, propose removing an already-selected item instead
                    cur_dim = np.random.choice(np.reshape(np.argwhere(last_sample > 0.0), [-1]))
                tem_delete_flag = False
                if last_sample[cur_dim] > 0.0:
                    # log('Deleting')
                    # temporarily clear the bit so estimate_probability scores
                    # the candidate against the *other* selected items
                    tem_delete_flag = True
                    last_sample[cur_dim] = 0.0
                idx = (idx + 1) % len(self.item_list)
                self.failure = 0
                prob = self.estimate_probability(last_sample, cur_dim, tem_delete_flag)

                # community modifier
                diff = abs(cur_community - cur_dim % self.hyperparams['community_num'])
                prob *= self.hyperparams['com_decay'] ** diff

                dice = random.uniform(0, 1) - self.hyperparams['dice_shift']
                # if dice < prob:
                #     print('Edge should be included')
                # else:
                #     print('Edge should not be included')
                last_sample, change_flag = self.update_sample(last_sample, cur_dim, dice < prob)
                if tem_delete_flag:
                    # the bit was pre-cleared above, so "unchanged" actually
                    # means the deletion stuck — invert the flag
                    change_flag = not change_flag
                if change_flag:
                    if small_step == 0:
                        log('Sample Updated! Step {step}_{small_step}, update cnt {update_cnt}, interaction num {int_num}, sample num {samp_num}'.format(step=step, small_step=small_step, update_cnt=update_cnt, int_num=np.sum(last_sample!=0.0), samp_num=len(samples)), oneline=True)
                    self.reject_cnt = 0
                    update_cnt += 1
                    update_flag = True
                else:
                    if small_step == 0:
                        log('Sample UNCHANGED! Step {step}_{small_step}, update cnt {update_cnt}, interaction num {int_num}, sample num {samp_num}'.format(step=step, small_step=small_step, update_cnt=update_cnt, int_num=np.sum(last_sample!=0.0), samp_num=len(samples)), oneline=True)
                    self.reject_cnt += 1
                # print('*******\n')
                if self.reject_cnt > 50:
                    log('Consecutive rejection {rej_cnt} when sampling!'.format(rej_cnt=self.reject_cnt), save=True)
                    log('Last sample: {last_sample}'.format(last_sample=self.binvec2list(samples[-1])))
                    # NOTE(review): update_sample returns a (sample, flag) tuple,
                    # but binvec2list expects a vector — this log line looks buggy; confirm
                    log('New sample: {new_sample}'.format(new_sample=self.binvec2list(self.update_sample(last_sample, cur_dim, dice >= prob))))
                    log('Sending report email.', save=True)
                    # SendMail(logger.logmsg)
                    self.reject_cnt = 0
                    break
            if update_flag:
                if step % self.hyperparams['restart_step'] < self.hyperparams['gibbs_skip_step']: # original 50
                    # burn-in phase right after a restart: overwrite instead of append
                    samples[-1] = last_sample
                else:
                    samples.append(last_sample)
        return samples

    def random_sample(self):
        """Return a fresh binary vector with 'seed_num' randomly selected items."""
        picked_idxs = random.sample(list(range(len(self.item_list))), self.hyperparams['seed_num'])
        last_interaction = np.zeros(len(self.item_list))
        last_interaction[picked_idxs] = 1.0
        return last_interaction

    def run(self):
        """Sample interactions and return them as lists of item strings."""
        samples_binvec = self.Gibbs_Sampling()
        picked_items_list = []
        for vec in samples_binvec:
            picked_items = self.binvec2list(vec)
            picked_items_list.append(picked_items)
        return picked_items_list

def load_item_list(item_file, entity_file, item_num):
    """Load the item list, materializing it from the entity tree pickle
    (allocating `item_num` instances across leaves) on first use."""
    if not os.path.exists(item_file):
        with open(entity_file, 'rb') as fs:
            entity_tree_root = pickle.load(fs)
        entity_tree_root.allocate_number(item_num)
        item_list = entity_tree_root.get_list_of_leaves('')
        with open(item_file, 'wb+') as fs:
            pickle.dump(item_list, fs)
    else:
        with open(item_file, 'rb') as fs:
            item_list = pickle.load(fs)
    return item_list

def get_gen_iter(file_root, interaction_file_prefix):
    """Return the largest generation base found among existing interaction
    files, or -1 when none exist.

    NOTE(review): the suffix '_iter-0.pkl' is hard-coded, so files from other
    iterations would raise ValueError on .index() — confirm only iter-0 files
    can exist in file_root.
    """
    max_existing_iter = -1
    for filename in os.listdir(file_root):
        cur_filename = file_root + filename
        if interaction_file_prefix in cur_filename:
            st_idx = len(interaction_file_prefix)
            ed_idx = cur_filename.index('_iter-0.pkl')
            cur_iter = int(cur_filename[st_idx: ed_idx])
            max_existing_iter = max(max_existing_iter, cur_iter)
    return max_existing_iter

def load_interactions(prev_interaction_file):
    """Unpickle a previous-iteration interaction file; None for iteration -1."""
    if 'iter--1' in prev_interaction_file:
        return None
    with open(prev_interaction_file, 'rb') as fs:
        prev_interactions = pickle.load(fs)
    return prev_interactions

def load_embedding_dict(embed_file):
    """Unpickle the text->embedding cache, or None when the file is absent."""
    if not os.path.exists(embed_file):
        return None
    with open(embed_file, 'rb') as fs:
        ret = pickle.load(fs)
    return ret

if __name__ == '__main__':
    # parameter specification
    descs = {
        'data_name': 'gen_data_ecommerce',
        'scenario_desc': 'e-commerce platform like Amazon',
        'human_role': 'user',
        'interaction_verb': 'interact',
        'initial_entity': 'products',
    }
    hyperparams = {
        'seed_num': 6,
        'item_num': 1,#200000,
        'sample_num': 1400,#20000
        'context_limit': 15,
        'gibbs_step': 1000,
        'gen_base': 0,
        'restart_step': 100,# shift community when restart
        'gibbs_skip_step': 1,# 100
        'delete_nnz': 1, # 5
        'length_center': 400, # 60, 100, 150
        'community_num': 7,
        'itmfuse': True,
        'com_decay': 0.95,
        # 'dice_shift': 0.1,
        'dice_shift': 0.1,
    }

    # file name definition
    file_root = 'gen_results/datasets/{data_name}/'.format(data_name=descs['data_name'])
    entity_file = 'gen_results/tree_wInstanceNum_{initial_entity}_{scenario}.pkl'.format(initial_entity=descs['initial_entity'], scenario=descs['scenario_desc'])
    item_file = file_root + 'item_list.pkl'
    embed_file = file_root + 'embedding_dict.pkl'

    # load data
    item_list = load_item_list(item_file, entity_file, hyperparams['item_num'])
    if hyperparams['itmfuse']:
        # collapse duplicate items: normalize every ' #<count>' suffix to ' #1'
        item_list = list(map(lambda x: x[:x.index(' #')] + ' #1', item_list))
    embedding_dict = load_embedding_dict(embed_file)

    def length_sampler():
        # NOTE(review): 'min_len'/'max_len' are not in hyperparams above, so
        # calling this would raise KeyError; it appears to be unused — confirm
        min_len, max_len = hyperparams['min_len'], hyperparams['max_len']
        return random.randint(min_len, max_len)

    # generate new interactions
    generator = HumanItemRelationGeneration(item_list, length_sampler, descs, hyperparams, embedding_dict)
    sampled_interactions = generator.run()

    # store to disk
    interaction_file_prefix = file_root + 'interaction_base-'
    if 'gen_base' in hyperparams:
        gen_base = hyperparams['gen_base']
        next_interaction_file = interaction_file_prefix + str(gen_base) + '_iter-0.pkl'
        if os.path.exists(next_interaction_file):
            # target file already exists: bump to one past the largest base on disk
            gen_base = get_gen_iter(file_root, interaction_file_prefix) + 1
            next_interaction_file = interaction_file_prefix + str(gen_base) + '_iter-0.pkl'
        with open(next_interaction_file, 'wb+') as fs:
            pickle.dump(sampled_interactions, fs)
import numpy as np
from Utils import DataGenAgent, EntityTreeConstructer, EntityTreeNode
import os
import time
import pickle
import Exp_Utils.TimeLogger as logger
from Exp_Utils.TimeLogger import log
from Exp_Utils.Emailer import SendMail

class HierarchicalInstanceNumberEstimator(DataGenAgent):
    """Estimates relative instance frequencies for each node of an entity
    category tree (via LLM prompts), then allocates `total_num` instances
    across the tree and pickles the annotated tree to gen_results/."""
    def __init__(self, entity_tree_root, total_num, depth, initial_entity, scenario_desc):
        """
        Args:
            entity_tree_root: root node of the category tree (has .children,
                .depth, .entity_name, .frequency, .allocate_number).
            total_num: total instance count to distribute over the tree.
            depth: maximum tree depth; controls how many estimation rounds
                each node receives in run().
            initial_entity: name of the root entity category (e.g. 'products').
            scenario_desc: textual scenario used in prompts and file names.
        """
        super(HierarchicalInstanceNumberEstimator, self).__init__()

        self.entity_tree_root = entity_tree_root
        self.total_num = total_num
        self.initial_entity = initial_entity
        self.scenario_desc = scenario_desc
        self.depth = depth
        self.failure = 0  # consecutive interpretation failures for the current batch

    def _entity_list_to_text(self, entity_list):
        """Render entities as a 1-indexed numbered list, one per line."""
        ret = ''
        for i, entity in enumerate(entity_list):
            ret += '{idx}. {entity}\n'.format(idx=i+1, entity=entity)
        return ret

    def interpret_one_answer(self, answer_text, subcategory):
        """Parse one LLM answer line into a frequency multiplier.

        Raises:
            Exception: when the sub-category name or a recognized frequency
                phrase is missing from the answer line.
        """
        answer_lower = answer_text.lower()
        if subcategory.lower() not in answer_lower:
            log('ERROR: Entity name not found.', save=True)
            log('subcategory: {subcategory}'.format(subcategory=subcategory), save=True)
            log('answer_lower: {answer_lower}'.format(answer_lower=answer_lower), save=True)
            raise Exception('Entity name not found.')
        # phrase -> multiplier; "less frequent" maps to the reciprocal
        estimation_choices = ['average frequency', '1.2 times more frequent', '1.2 times less frequent', '1.5 times more frequent', '1.5 times less frequent', '2 times more frequent', '2 times less frequent', '4 times more frequent', '4 times less frequent', '8 times more frequent', '8 times less frequent']
        estimation_scores = [1.0, 1.2, 1/1.2, 1.5, 1/1.5, 2.0, 1/2.0, 4.0, 1/4.0, 8.0, 1/8.0]
        for i, choice in enumerate(estimation_choices):
            if choice in answer_lower:
                return estimation_scores[i]
        raise Exception('Estimation not found.')

    def interpret(self, answers_text, subcategories):
        """Parse a multi-line LLM response, one line per sub-category, into a
        list of frequency multipliers. Asserts line count matches."""
        answer_list = answers_text.strip().split('\n')
        assert len(answer_list) == len(subcategories), 'Length does not match.'
        answers = []
        for i in range(len(answer_list)):
            answers.append(self.interpret_one_answer(answer_list[i], subcategories[i]))
        return answers

    def estimate_subcategories(self, subcategories, category):
        """Ask the LLM for relative frequencies of `subcategories` under
        `category`, retrying up to 5 times on parse failure; falls back to
        all-1.0 and emails the log after repeated failures."""
        subcategories_text = self._entity_list_to_text(subcategories)
        if category != self.initial_entity:
            text = '''In the context of {scenario_desc}, you are given a list of sub-categories below, which belong to the {category} category of {initial_entity}. Using your intuition and common sense, your goal is to identify the frequency of these sub-categories compared to the average frequency of all possible {category} {initial_entity}.'''.format(scenario_desc=self.scenario_desc, initial_entity=self.initial_entity, category=category)
        else:
            text = '''In the context of {scenario_desc}, you are given a list of sub-categories below, which belong to {initial_entity}. Using your intuition and common sense, your goal is to identify the frequency of these sub-categories compared to the average frequency of all possible {initial_entity}.'''.format(scenario_desc=self.scenario_desc, initial_entity=self.initial_entity)
        text += '''Your answer should contain one line for each of the sub-categories, EXACTLY following the following format: "[serial number]. [sub-category name same as in the input]; [your frequency estimation]; [one-sentence explaination for your estimation]". The frequency estimation should be one of the following choices: [average frequency, 1.2 times more/less frequent, 1.5 times more/less frequent, 2 times more/less frequent, 4 times more/less frequent, 8 times more/less frequent]. No other words should be included in your response. The sub-categories list is as follows:\n\n''' + subcategories_text
        # print('input')
        # print(text)
        try:
            answers_text = self.openai(text)
            print('Answers text:')
            print(answers_text)
            return self.interpret(answers_text, subcategories)
        except Exception as e:
            # NOTE(review): if self.openai(text) itself raised, answers_text is
            # unbound and the log lines below raise NameError — confirm
            self.failure += 1
            if self.failure < 5:
                log('Exception occurs when interpreting. Retry in 10 seconds.', save=True)
                log('Exception message: {exception}'.format(exception=e), save=True)
                log('Failure times: {failure}'.format(failure=self.failure), save=True)
                log('Prompt text:\n{prompt}'.format(prompt=text), save=True)
                log('Response text:\n{response}'.format(response=answers_text), save=True)
                time.sleep(10)
                return self.estimate_subcategories(subcategories, category)
            else:
                log('Exception occurs {failure} times when interpreting. CANNOT HANDLE.'.format(failure=str(self.failure)), save=True)
                log('Exception message: {exception}'.format(exception=e), save=True)
                log('Prompt text:\n{prompt}'.format(prompt=text), save=True)
                log('Response text:\n{response}'.format(response=answers_text), save=True)
                log('Sending report email.', save=True)
                SendMail(logger.logmsg)
                logger.logmsg = ''
                # best-effort fallback: treat all sub-categories as average frequency
                return [1.0] * len(subcategories)

    def run(self):
        """BFS over the tree, appending frequency estimates for every child of
        every internal node (currently hard-coded to 1.0 — LLM estimation is
        commented out), then allocate counts and pickle the tree."""
        que = [self.entity_tree_root]
        while len(que) > 0:
            cur_entity = que[0]
            que = que[1:]
            if len(cur_entity.children) == 0:
                # leaf node: nothing to estimate
                continue
            cur_children_entities = list(cur_entity.children.values())
            que = que + cur_children_entities
            cur_children_names = list(map(lambda x: x.entity_name, cur_children_entities))
            assert self.depth - cur_entity.depth > 0 and self.depth - cur_entity.depth < self.depth
            # deeper nodes get fewer estimation rounds (more rounds near the root)
            for _ in range(self.depth - cur_entity.depth + 1):
                self.failure = 0
                # answers = self.estimate_subcategories(cur_children_names, cur_entity.entity_name)
                # print(answers)
                # print('-----------------')
                # print()
                for j, entity in enumerate(cur_children_entities):
                    # entity.frequency.append(answers[j])
                    entity.frequency.append(1.0)
        self.entity_tree_root.allocate_number(self.total_num)
        with open('gen_results/tree_wInstanceNum_{initial_entity}_{scenario}.pkl'.format(initial_entity=self.initial_entity, scenario=self.scenario_desc), 'wb') as fs:
            pickle.dump(self.entity_tree_root, fs)


# ---- script configuration: pick exactly one scenario block below ----
scenario = 'e-commerce platform like Amazon'
initial_entity = 'products'
total_num = 200000
depth = 5

# scenario = 'published paper list of top AI conferences'
# initial_entity = 'deep learning papers'
# total_num = 1000000
# depth = 6

# scenario = 'venue rating platform like yelp'
# initial_entity = 'business venues'
# total_num = 30000
# depth = 5

# scenario = 'book rating platform'
# initial_entity = 'books'
# total_num = 30000
# depth = 5

# load entities
file = os.path.join('gen_results/', '{entity}_{scenario}.txt'.format(entity=initial_entity, scenario=scenario))
entity_lines = []
with open(file, 'r') as fs:
    for line in fs:
        entity_lines.append(line)
entity_tree_constructer = EntityTreeConstructer(entity_lines)
entity_tree_root = entity_tree_constructer.root

estimator = HierarchicalInstanceNumberEstimator(entity_tree_root, total_num=total_num, depth=depth, initial_entity=initial_entity, scenario_desc=scenario)
estimator.run()
    def __init__(self, initial_entity, scenario_desc, depth):
        """
        Args:
            initial_entity: root entity category to decompose (e.g. 'products').
            scenario_desc: textual scenario used in all prompts.
            depth: maximum recursion depth for category decomposition.
        """
        super(DataGenAgent, self).__init__()

        # self.openai = OpenAI(temperature=0)
        self.initial_entity = initial_entity
        self.scenario_desc = scenario_desc
        self.encoding = tiktoken.encoding_for_model('gpt-3.5-turbo')
        self.token_num = 0   # running count of prompt tokens sent
        self.total_num = 0   # running count of concrete leaf entities found
        self.depth = depth

    def openai(self, message):
        """Send one user message to the chat API and return the reply text.

        Tracks prompt token usage via tiktoken and sleeps 1s between calls.
        NOTE(review): retries recursively on ANY exception with no retry cap —
        a persistent failure (e.g. bad API key) loops forever; confirm intended.
        NOTE(review): this method shadows the module-level `openai` package
        name as an attribute (`self.openai`); inside the body, `openai` still
        resolves to the imported module.
        """
        try:
            completion = openai.ChatCompletion.create(
                model='gpt-3.5-turbo-1106',
                messages=[
                    {"role": "user", "content": message},
                ]
            )
            response = completion.choices[0].message.content
            time.sleep(1)
            self.token_num += len(self.encoding.encode(json.dumps(message)))
            return response
        except Exception as e:
            print('Exception occurs. Retry in 10 seconds.')
            time.sleep(10)
            return self.openai(message)

    def check_if_concrete(self, entity_stack):
        """Ask the LLM whether the category path in `entity_stack` is already a
        concrete instance (i.e. should not be subdivided further). Currently
        unused: the call site in decompose_category is commented out."""
        entity_name = ', '.join(entity_stack)
        text = 'In the context of {scenario_desc}, is {entity_name} a concrete instance or category that can hardly be divided into sub-categories with prominent differences? Response should starts with "True" or "False".'.format(scenario_desc=self.scenario_desc, entity_name=entity_name)
        answer = self.openai(text)
        # print('answer to {entity_name}: {answer}'.format(entity_name=entity_name, answer=answer))
        if answer.startswith('True'):
            print('Concrete Check True')
            return True
        return False

    def category_enum(self, prefix, entity_name):
        """Ask the LLM to enumerate non-overlapping sub-categories of
        `entity_name` (optionally within `prefix`); returns them as a list of
        cleaned strings split on commas."""
        if prefix == '':
            text = 'List all distinct sub-categories of {entity_name} in the context of {scenario_desc}, ensuring a finer level of granularity. The sub-categories should not overlap with each other. And a sub-category should be a smaller subset of {entity_name}. Directly present the list EXACTLY following the form: "sub-category a, sub-category b, sub-category c, ..." without other words, format symbols, new lines, serial numbers.'.format(entity_name=entity_name, prefix=prefix, scenario_desc=self.scenario_desc)
        else:
            text = 'List all distinct sub-categories of {entity_name} within the {prefix} category in the context of {scenario_desc}, ensuring a finer level of granularity. The sub-categories should not overlap with each other. And a sub-category should be a smaller subset of {entity_name}. Directly present the list EXACTLY following the form: "sub-category a, sub-category b, sub-category c, ..." without other words, format symbols, new lines, serial numbers.'.format(entity_name=entity_name, prefix=prefix, scenario_desc=self.scenario_desc)
        # text = 'List all distinct sub-categories of {entity_name} within the {prefix} category in the context of {scenario_desc}, ensuring a finer level of granularity. The sub-categories should not overlap with each other. Present the list exactly following the form: "sub-category a, sub-category b, sub-category c, ...". There should be no serial number, new lines or other format symbols. Separate each pair of sub-categories with a comma.'.format(entity_name=entity_name, prefix=prefix, scenario_desc=self.scenario_desc)
        answer = self.openai(text)
        return list(map(lambda x: x.strip().strip(',').strip('.'), answer.split(',')))

    def decompose_category(self, entity_stack, depth):
        """Recursively (DFS) expand the category at `entity_stack` into
        sub-categories until `self.depth` is reached; returns the list of
        concrete leaf entity paths found under it.

        Side effects: prints progress at shallow depths and checkpoints
        intermediate results to gen_results/tem/ at depth <= 2.
        """
        entity_name = ', '.join(entity_stack)
        prefix = '' if depth == 1 else ', '.join(entity_stack[:-1])
        if depth >= self.depth:# or self.check_if_concrete(entity_stack) is True:
            # print('{entity_name} is considered a concrete instance.'.format(entity_name=entity_name))
            self.total_num += 1
            return [entity_name]
        print('\nCurrent entity: {entity_name}'.format(entity_name=entity_name))
        concrete_entities = []
        sub_entities = self.category_enum(prefix, entity_stack[-1])
        print('sub-categories of {entity_name} includes:'.format(entity_name=entity_name), sub_entities)
        for sub_entity in sub_entities:
            if sub_entity in entity_name:
                # skip sub-categories that merely repeat an ancestor name
                continue
            new_concrete_entities = self.decompose_category(entity_stack + [sub_entity], depth+1)
            concrete_entities += new_concrete_entities
        if depth <= 4:
            print('Depth {depth}, current num of nodes {num}, total num of nodes {total_num}, num of tokens {token}'.format(depth=depth, num=len(concrete_entities), total_num=self.total_num, token=self.token_num))
        if depth <= 2:
            print('Storing...')
            tem_file = 'gen_results/tem/{scenario}_depth{depth}_{cur_entity}'.format(scenario=self.scenario_desc, depth=str(depth), cur_entity=entity_name)
            with open(tem_file, 'w+') as fs:
                for node in concrete_entities:
                    fs.write(node + '\n')
        return concrete_entities

    def run(self):
        """Entry point: decompose the configured root entity from depth 1."""
        return self.decompose_category([self.initial_entity], 1)

# ---- script configuration: pick exactly one entity/scenario block ----
entity = 'products'
scenario = 'e-commerce platform like Amazon'
depth = 3

# entity = 'movies'
# scenario = 'movie rating platform'

# entity = 'books'
# scenario = 'book rating platform'

# entity = 'business venues'
# scenario = 'venue rating platform like yelp'

# entity = 'movies'
# scenario = 'movie rating platform'
# depth = 5

# entity = 'deep learning papers'
# scenario = 'published paper list of top AI conferences'
# depth = 6

# entity = 'ideology'
# scenario = "people's political ideologies"
# depth = 4

# entity = 'jobs'
# scenario = "people's occupations and professions"
# depth = 5

# run the DFS decomposition and persist the flat leaf-entity list
agent = DataGenAgent(entity, scenario, depth)
nodes = agent.run()
with open('gen_results/{entity}_{scenario}.txt'.format(entity=entity, scenario=scenario), 'w+') as fs:
    for node in nodes:
        fs.write(node+'\n')

# ==================== graph_generation/make_adjs.py ====================
# Fuses generated interaction lists into a bipartite user-item graph,
# optionally k-core filters it, and writes train/valid/test sparse matrices.
import pickle
import os
import argparse
from scipy.sparse import coo_matrix
import numpy as np
import random
import networkx as nx

descs = {
    'data_name': 'gen_data_ecommerce',
}
params = {
    'itmfusion': True,   # strip ' #<count>' suffixes so duplicate items merge
    'kcore': 0,          # k for k-core filtering (skipped only when == 1, see below)
    'sep': 1,            # keep every sep-th interaction list
    'min_base': 0,       # inclusive range of generation bases to fuse
    'max_base': 100,
}
parser = argparse.ArgumentParser(description='Dataset information')
parser.add_argument('--gen_iter', default=0, type=int, help='maximum generation iteration')
args = parser.parse_args()



file_root = 'gen_results/datasets/{data_name}/'.format(data_name=descs['data_name'])
fuse_file_path = file_root + 'res/interaction_fuse_iter-{iter}.pkl'.format(iter=args.gen_iter)

def get_all_bases():
    """Scan file_root for interaction_base-<b>_iter-* files and return the
    sorted list of bases within [min_base, max_base]."""
    bases = set()
    prefix = 'interaction_base-'
    suffix = '_iter-'
    for filename in os.listdir(file_root):
        if prefix in filename and suffix in filename:
            prefix_idx = len(prefix)
            suffix_idx = filename.index(suffix)
            base = int(filename[prefix_idx: suffix_idx])
            if base >= params['min_base'] and base <= params['max_base']:
                bases.add(base)
    bases = list(bases)
    bases.sort()
    return bases

def fuse_bases():
    """Concatenate interaction lists across all bases (taking, per base, the
    newest iteration <= args.gen_iter), subsample by params['sep'], and cache
    the fused result at fuse_file_path.

    NOTE(review): when the fused file already exists this prints and calls
    exit(), so the load-and-return code right after is unreachable dead code —
    confirm whether reuse was meant to be enabled.
    """
    if os.path.exists(fuse_file_path):
        print('Fused interaction file exists! REUSING!')
        print('This may happen inadverdently!')
        exit()
        with open(fuse_file_path, 'rb') as fs:
            interactions = pickle.load(fs)
        return interactions
    all_bases = get_all_bases()
    interactions = []
    for gen_base in all_bases:
        file_path = None
        # fall back from gen_iter down to 0 until a file for this base exists
        for iter in range(args.gen_iter, -1, -1):
            file_path = file_root + 'interaction_base-{gen_base}_iter-{gen_iter}.pkl'.format(gen_base=gen_base, gen_iter=iter)
            if os.path.exists(file_path):
                break
        with open(file_path, 'rb') as fs:
            tem_cur_base_interactions = pickle.load(fs)
        cur_base_interactions = []
        for i in range(len(tem_cur_base_interactions)):
            cur_base_interactions.append(tem_cur_base_interactions[i])
        interactions += cur_base_interactions

    # keep every sep-th interaction list
    new_interactions = []
    for i in range(len(interactions)):
        if i % params['sep'] == 0:
            # interactions[i] = interactions[i][:len(interactions[i]) // 3]
            new_interactions.append(interactions[i])
    interactions = new_interactions
    # interactions = new_interactions[:20000]

    with open(fuse_file_path, 'wb+') as fs:
        pickle.dump(interactions, fs)
    return interactions

def make_id_map(interactions, itm_criteria=None):
    """Build edge lists (rows=user ids, cols=item ids) from interaction lists.

    Users are indexed by position in `interactions`; items are deduplicated
    (with ' #<count>' suffix stripped when params['itmfusion']) and optionally
    filtered by `itm_criteria` applied to each item's occurrence count.
    Returns (rows, cols, item->id map, user count, item count).
    """
    u_num = len(interactions)
    i_set = set()
    i_cnt = dict()
    for interaction in interactions:
        for item in interaction:
            num_idx = item.index(' #')
            tem_item = item if not params['itmfusion'] else item[:num_idx]
            i_set.add(tem_item)
            if tem_item not in i_cnt:
                i_cnt[tem_item] = 0
            i_cnt[tem_item] += 1
    i_list = list(i_set)
    if itm_criteria is not None:
        tem_i_list = list()
        for item in i_list:
            if itm_criteria(i_cnt[item]):
                tem_i_list.append(item)
        print('Filtering {new_num} / {old_num}'.format(new_num=len(tem_i_list), old_num=len(i_list)))
        i_list = tem_i_list
    i_num = len(i_list)
    i_mapp = dict()
    for i, item in enumerate(i_list):
        i_mapp[item] = i
    rows = []
    cols = []
    for uid, interaction in enumerate(interactions):
        for item in interaction:
            num_idx = item.index(' #')
            tem_item = item if not params['itmfusion'] else item[:num_idx]
            if tem_item not in i_mapp:
                # item filtered out by itm_criteria
                continue
            iid = i_mapp[tem_item]
            rows.append(uid)
            cols.append(iid)
    return rows, cols, i_mapp, u_num, i_num

def id_map(nodes):
    """Assign a dense 0-based id to each unique node value."""
    uniq_nodes = list(set(nodes))
    dic = dict()
    for i, node in enumerate(uniq_nodes):
        dic[node] = i
    return dic

def k_core(rows, cols, i_mapp, i_num, k):
    """Apply networkx k-core filtering to the bipartite graph and reindex.

    User nodes are offset by i_num so they don't collide with item ids.
    Returns re-densified (rows, cols, item->id map, user count, item count).
    """
    edge_list = list(map(lambda idx: (rows[idx] + i_num, cols[idx]), range(len(rows))))
    G = nx.Graph(edge_list)
    edge_list = list(nx.k_core(G, k=k).edges())
    rows = [None] * len(edge_list)
    cols = [None] * len(edge_list)
    for i, edge in enumerate(edge_list):
        # undirected edges come back in arbitrary endpoint order; the node
        # with id < i_num is the item, the other (offset) one is the user
        if edge[0] < i_num:
            rows[i] = edge[1] - i_num
            cols[i] = edge[0]
        else:
            rows[i] = edge[0] - i_num
            cols[i] = edge[1]
    row_map = id_map(rows)
    col_map = id_map(cols)
    new_rows = list(map(lambda x: row_map[x], rows))
    new_cols = list(map(lambda x: col_map[x], cols))
    new_i_mapp = dict()
    for key in i_mapp:
        tem_item = i_mapp[key]
        if tem_item not in col_map:
            # item dropped by the k-core filter
            continue
        new_i_mapp[key] = col_map[tem_item]
    return new_rows, new_cols, new_i_mapp, len(row_map), len(col_map)

def make_mat(rows, cols, st, ed, u_num, i_num, perm, decrease=False):
    """Build a u_num x i_num COO matrix from the permuted edge slice [st, ed);
    when `decrease`, keep only the first third of the slice."""
    rows = np.array(rows)[perm]
    cols = np.array(cols)[perm]
    rows = rows[st: ed]
    cols = cols[st: ed]
    if decrease:
        rows = rows[:len(rows)//3]
        cols = cols[:len(cols)//3]
    vals = np.ones_like(rows)
    return coo_matrix((vals, (rows, cols)), shape=[u_num, i_num])

def data_split(rows, cols, u_num, i_num):
    """Randomly split edges 70/5/25 into train/valid/test matrices
    (test additionally downsampled to a third via decrease=True)."""
    leng = len(rows)
    perm = np.random.permutation(leng)
    trn_split = int(leng * 0.7)
    val_split = int(leng * 0.75)
    trn_mat = make_mat(rows, cols, 0, trn_split, u_num, i_num, perm)
    val_mat = make_mat(rows, cols, trn_split, val_split, u_num, i_num, perm)
    tst_mat = make_mat(rows, cols, val_split, leng, u_num, i_num, perm, decrease=True)
    return trn_mat, val_mat, tst_mat

interactions = fuse_bases()
rows, cols, i_mapp, u_num, i_num = make_id_map(interactions)#, lambda x: x>20)
# NOTE(review): condition is `!= 1`, so kcore=0 still calls k_core(k=0)
# (which only drops isolated nodes) — confirm the sentinel is intended
if params['kcore'] != 1:
    rows, cols, i_mapp, u_num, i_num = k_core(rows, cols, i_mapp, i_num, params['kcore'])
print('U NUM', u_num, 'I Num', i_num, 'E Num', len(rows))
with open(file_root + 'res/iter-{gen_iter}_imap.pkl'.format(gen_iter=args.gen_iter), 'wb+') as fs:
    pickle.dump(i_mapp, fs)
trn_mat, val_mat, tst_mat = data_split(rows, cols, u_num, i_num)
with open(file_root + 'res/iter-{gen_iter}_train.pkl'.format(gen_iter=args.gen_iter), 'wb+') as fs:
    pickle.dump(trn_mat, fs)
with open(file_root + 'res/iter-{gen_iter}_valid.pkl'.format(gen_iter=args.gen_iter), 'wb+') as fs:
    pickle.dump(val_mat, fs)
with open(file_root + 'res/iter-{gen_iter}_test.pkl'.format(gen_iter=args.gen_iter), 'wb+') as fs:
    pickle.dump(tst_mat, fs)
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/imgs/framework.png -------------------------------------------------------------------------------- /imgs/graph tokenizer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/imgs/graph tokenizer.png -------------------------------------------------------------------------------- /imgs/impact of datasets.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/imgs/impact of datasets.png -------------------------------------------------------------------------------- /imgs/intro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/imgs/intro.png -------------------------------------------------------------------------------- /imgs/opengraph_article_cover_full.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/imgs/opengraph_article_cover_full.png -------------------------------------------------------------------------------- /imgs/performance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/imgs/performance.png -------------------------------------------------------------------------------- /imgs/prompt.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/imgs/prompt.png
--------------------------------------------------------------------------------
/imgs/sampling.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/imgs/sampling.png
--------------------------------------------------------------------------------
/link_prediction/Utils/TimeLogger.py:
--------------------------------------------------------------------------------
import datetime

# Accumulated log text (only messages flagged for saving are appended).
logmsg = ''
# Named timestamps recorded by marktime().
timemark = dict()
# Module-wide default for whether log() appends to logmsg when save is None.
saveDefault = False
def log(msg, save=None, oneline=False):
    """Print a timestamped message; optionally accumulate it in ``logmsg``.

    save: True/False overrides the module default; None falls back to
        ``saveDefault``.
    oneline: print with carriage return so successive calls overwrite the
        same console line (progress-style output).
    """
    global logmsg
    global saveDefault
    time = datetime.datetime.now()
    tem = '%s: %s' % (time, msg)
    if save != None:
        if save:
            logmsg += tem + '\n'
    elif saveDefault:
        logmsg += tem + '\n'
    if oneline:
        print(tem, end='\r')
    else:
        print(tem)

def marktime(marker):
    """Record the current time under ``marker`` in the module dict."""
    global timemark
    timemark[marker] = datetime.datetime.now()


if __name__ == '__main__':
    log('')
--------------------------------------------------------------------------------
/link_prediction/data_handler.py:
--------------------------------------------------------------------------------
import pickle
import numpy as np
from scipy.sparse import csr_matrix, coo_matrix, dok_matrix
from params import args
import scipy.sparse as sp
from Utils.TimeLogger import log
import torch as t
import torch.utils.data as data
import torch_geometric.transforms as T
from model import InitialProjector
import os

class MultiDataHandler:
    """Holds one DataHandler per dataset and a joint training loader."""
    def __init__(self, trn_datasets, tst_datasets):
        all_datasets = list(set(trn_datasets + tst_datasets))
        self.trn_handlers = []
        self.tst_handlers = []
        for data_name in all_datasets:
            trn_flag = data_name in trn_datasets
            tst_flag = data_name in tst_datasets
            handler = DataHandler(data_name, trn_flag, tst_flag)
            if trn_flag:
                self.trn_handlers.append(handler)
            if tst_flag:
                self.tst_handlers.append(handler)
        self.make_joint_trn_loader()

    def make_joint_trn_loader(self):
        # batch_size=1 because TrnData items are themselves whole mini-batches.
        trn_data = TrnData(self.trn_handlers)
        self.trn_loader = data.DataLoader(trn_data, batch_size=1, shuffle=True, num_workers=0)

    def remake_initial_projections(self):
        """Resample the random initial projection of every training dataset."""
        for i in range(len(self.trn_handlers)):
            self.remake_one_initial_projection(i)

    def remake_one_initial_projection(self, idx):
        trn_handler = self.trn_handlers[idx]
        trn_handler.initial_projector = InitialProjector(trn_handler.asym_adj)

class DataHandler:
    """Loads one dataset's matrices and builds its (normalized) torch adjacency."""
    def __init__(self, data_name, trn_flag, tst_flag):
        self.data_name = data_name
        self.trn_flag = trn_flag
        self.tst_flag = tst_flag
        self.get_data_files()
        self.load_data()

    def get_data_files(self):
        """Resolve train/test/valid pickle paths; fall back to test as valid."""
        predir = os.path.join(args.data_dir, self.data_name)
        self.trnfile = os.path.join(predir, 'trn_mat.pkl')
        self.tstfile = os.path.join(predir, 'tst_mat.pkl')
        self.valfile = os.path.join(predir, 'val_mat.pkl')
        if not os.path.exists(self.valfile):
            self.valfile = self.tstfile

    def load_one_file(self, filename):
        """Unpickle a sparse matrix and coerce it to float32 COO."""
        with open(filename, 'rb') as fs:
            ret = (pickle.load(fs)).astype(np.float32)
            if type(ret) != coo_matrix:
                ret = sp.coo_matrix(ret)
        return ret

    def normalize_adj(self, mat):
        """Symmetric D^-1/2 A D^-1/2 normalization (square), or row/col
        degree normalization for a rectangular user-item matrix."""
        degree = np.array(mat.sum(axis=-1))
        d_inv_sqrt = np.reshape(np.power(degree, -0.5), [-1])
        d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0  # zero-degree rows -> 0, not inf
        d_inv_sqrt_mat = sp.diags(d_inv_sqrt)
        if mat.shape[0] == mat.shape[1]:
            return mat.dot(d_inv_sqrt_mat).transpose().dot(d_inv_sqrt_mat).tocoo()
        else:
            tem = d_inv_sqrt_mat.dot(mat)
            col_degree = np.array(mat.sum(axis=0))
            d_inv_sqrt = np.reshape(np.power(col_degree, -0.5), [-1])
            d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0
            d_inv_sqrt_mat = sp.diags(d_inv_sqrt)
            return tem.dot(d_inv_sqrt_mat).tocoo()

    def unique_numpy(self, row, col):
        """Deduplicate (row, col) pairs via a row*node_num+col hash."""
        hash_vals = row * self.node_num + col
        hash_vals = np.unique(hash_vals).astype(np.int64)
        col = hash_vals % self.node_num
        row = (hash_vals - col).astype(np.int64) // self.node_num
        return row, col

    def make_torch_adj(self, mat):
        """Build (normalized torch sparse adjacency, asymmetric adjacency).

        Square input: optionally symmetrized (ddi), self-looped, normalized.
        Rectangular input: also assembled into the big [[0, R], [R^T, 0]]
        bipartite adjacency. Returns (full adj, asym adj) — identical for
        square inputs.
        """
        if mat.shape[0] == mat.shape[1]:
            # to symmetric
            if self.data_name in ['ddi']:
                _row = mat.row
                _col = mat.col
                row = np.concatenate([_row, _col]).astype(np.int64)
                col = np.concatenate([_col, _row]).astype(np.int64)
                # row, col = self.unique_numpy(row, col)
                data = mat.data
                data = np.concatenate([data, data]).astype(np.float32)
            else:
                row, col = mat.row, mat.col
                data = mat.data
            # data = np.ones_like(row)
            mat = coo_matrix((data, (row, col)), mat.shape)
            if args.selfloop == 1:
                mat = (mat + sp.eye(mat.shape[0])) * 1.0
        normed_asym_mat = self.normalize_adj(mat)
        row = t.from_numpy(normed_asym_mat.row).long()
        col = t.from_numpy(normed_asym_mat.col).long()
        idxs = t.stack([row, col], dim=0)
        vals = t.from_numpy(normed_asym_mat.data).float()
        shape = t.Size(normed_asym_mat.shape)
        asym_adj = t.sparse.FloatTensor(idxs, vals, shape)
        if mat.shape[0] == mat.shape[1]:
            return asym_adj, asym_adj
        else:
            # make ui adj: block matrix [[0, R], [R^T, 0]] over users+items
            a = sp.csr_matrix((self.user_num, self.user_num))
            b = sp.csr_matrix((self.item_num, self.item_num))
            mat = sp.vstack([sp.hstack([a, mat]), sp.hstack([mat.transpose(), b])])
            mat = (mat != 0) * 1.0  # binarize
            if args.selfloop == 1:
                mat = (mat + sp.eye(mat.shape[0])) * 1.0
            mat = self.normalize_adj(mat)

            # make cuda tensor
            idxs = t.from_numpy(np.vstack([mat.row, mat.col]).astype(np.int64))
            vals = t.from_numpy(mat.data.astype(np.float32))
            shape = t.Size(mat.shape)
            return t.sparse.FloatTensor(idxs, vals, shape), asym_adj

    def load_data(self):
        """Load matrices, build adjacencies/projector, create eval loaders."""
        trn_mat = self.load_one_file(self.trnfile)
        # if self.trn_flag:
        self.trn_mat = trn_mat
        if trn_mat.shape[0] != trn_mat.shape[1]:
            # Rectangular: bipartite user-item dataset.
            self.user_num, self.item_num = trn_mat.shape
            self.node_num = self.user_num + self.item_num
            print('Dataset: {data_name}, User num: {user_num}, Item num: {item_num}, Node num: {node_num}, Edge num: {edge_num}'.format(data_name=self.data_name, user_num=self.user_num, item_num=self.item_num, node_num=self.node_num, edge_num=trn_mat.nnz))
        else:
            self.node_num = trn_mat.shape[0]
            print('Dataset: {data_name}, Node num: {node_num}, Edge num: {edge_num}'.format(data_name=self.data_name, node_num=self.node_num, edge_num=trn_mat.nnz))
        self.torch_adj, self.asym_adj = self.make_torch_adj(trn_mat)
        if args.cache_proj:
            self.asym_adj = self.asym_adj.to(args.devices[0])
        if args.cache_adj:
            self.torch_adj = self.torch_adj.to(args.devices[0])

        self.initial_projector = InitialProjector(self.asym_adj)

        if self.tst_flag:
            val_mat = self.load_one_file(self.valfile)
            val_data = TstData(val_mat, trn_mat)
            self.val_loader = data.DataLoader(val_data, batch_size=args.tst_batch, shuffle=False, num_workers=0)
            tst_mat = self.load_one_file(self.tstfile)
            tst_data = TstData(tst_mat, trn_mat)
            self.tst_loader = data.DataLoader(tst_data, batch_size=args.tst_batch, shuffle=False, num_workers=0)

class TstData(data.Dataset):
    """Eval dataset: iterates test users; keeps per-user ground-truth lists
    and a binarized train matrix for masking seen interactions."""
    def __init__(self, coomat, trn_mat):
        self.csrmat = (trn_mat.tocsr() != 0) * 1.0

        tstLocs = [None] * coomat.shape[0]  # per-row list of held-out cols
        tst_nodes = set()
        for i in range(len(coomat.data)):
            row = coomat.row[i]
            col = coomat.col[i]
            if tstLocs[row] is None:
                tstLocs[row] = list()
            tstLocs[row].append(col)
            tst_nodes.add(row)
        tst_nodes = np.array(list(tst_nodes))
        self.tst_nodes = tst_nodes
        self.tstLocs = tstLocs

    def __len__(self):
        return len(self.tst_nodes)

    def __getitem__(self, idx):
        return self.tst_nodes[idx]

class TrnData(data.Dataset):
    """Joint training dataset: each item is one pre-built BPR mini-batch
    (anchors, positives, negatives, dataset index) for one dataset."""
    def __init__(self, trn_handlers):
        self.dataset_num = len(trn_handlers)
        self.trn_handlers = trn_handlers
        self.ancs_list = [None] * self.dataset_num
        self.poss_list = [None] * self.dataset_num
        self.negs_list = [None] * self.dataset_num
        self.edge_nums = [None] * self.dataset_num
        self.sample_nums = [None] * self.dataset_num
        for i, handler in enumerate(self.trn_handlers):
            trn_mat = handler.trn_mat
            ancs = np.array(trn_mat.row)
            poss = np.array(trn_mat.col)
            self.ancs_list[i] = ancs
            self.poss_list[i] = poss
            self.edge_nums[i] = len(ancs)
            # ceil(edge_num / batch) batches per dataset
            self.sample_nums[i] = self.edge_nums[i] // args.batch + (1 if self.edge_nums[i] % args.batch != 0 else 0)
        self.total_sample_num = sum(self.sample_nums)
        self.samples = [None] * self.total_sample_num

    def data_shuffling(self):
        """Resample negatives and rebuild all mini-batches (call per epoch)."""
        sample_idx = 0
        for i in range(self.dataset_num):
            edge_num = self.edge_nums[i]
            perms = np.random.permutation(edge_num)
            handler = self.trn_handlers[i]
            asym_flag = handler.trn_mat.shape[0] != handler.trn_mat.shape[1]
            cand_num = handler.item_num if asym_flag else handler.node_num
            self.negs_list[i] = self.neg_sampling(self.ancs_list[i], handler.trn_mat.todok(), cand_num)
            # self.negs_list[i] = np.random.randint(cand_num, size=edge_num)
            for j in range(self.sample_nums[i]):
                st_idx = j * args.batch
                ed_idx = min((j + 1) * args.batch, edge_num)
                pick_idxs = perms[st_idx: ed_idx]
                ancs = self.ancs_list[i][pick_idxs]
                poss = self.poss_list[i][pick_idxs]
                negs = self.negs_list[i][pick_idxs]
                if asym_flag:
                    # Shift item ids into the joint user+item id space.
                    poss += handler.user_num
                    negs += handler.user_num
                self.samples[sample_idx] = (ancs, poss, negs, i)
                sample_idx += 1
        assert sample_idx == self.total_sample_num

    def neg_sampling(self, ancs, dokmat, cand_num):
        """Uniformly sample one unobserved negative per anchor (rejection)."""
        negs = np.zeros_like(ancs)
        for i in range(len(ancs)):
            u = ancs[i]
            while True:
                i_neg = np.random.randint(cand_num)
                if (u, i_neg) not in dokmat:
                    break
            negs[i] = i_neg
        return negs

    def __len__(self):
        return self.total_sample_num

    def __getitem__(self, idx):
        ancs, poss, negs, adj_id = self.samples[idx]
        return ancs, poss, negs, adj_id

--------------------------------------------------------------------------------
/link_prediction/main.py:
--------------------------------------------------------------------------------
import torch as t
from torch import nn
import Utils.TimeLogger as logger
from Utils.TimeLogger import log
from params import args
from model import OpenGraph, ALRS
from data_handler import DataHandler, MultiDataHandler
import numpy as np
import pickle
import os
import setproctitle
import time

class Exp:
    """Experiment driver: trains OpenGraph across datasets and evaluates."""
    def __init__(self, multi_handler):
        self.multi_handler = multi_handler
        # Metric histories: Train{Loss,preLoss} plus Test<dataset>{Recall,NDCG}.
        self.metrics = dict()
        trn_mets = ['Loss', 'preLoss']
        tst_mets = ['Recall', 'NDCG']
        mets = trn_mets + tst_mets
        for met in mets:
            if met in trn_mets:
                self.metrics['Train' + met] = list()
            else:
                for handler in self.multi_handler.tst_handlers:
                    self.metrics['Test' + handler.data_name + met] = list()

    def make_print(self, name, ep, reses, save, data_name=None):
        """Format one result dict as a log line; optionally record history."""
        if data_name is None:
            ret = 'Epoch %d/%d, %s: ' % (ep, args.epoch, name)
        else:
            ret = 'Epoch %d/%d, %s %s: ' % (ep, args.epoch, data_name, name)
        for metric in reses:
            val = reses[metric]
            ret += '%s = %.4f, ' % (metric, val)
            tem = name + metric if
data_name is None else name + data_name + metric
            if save and tem in self.metrics:
                self.metrics[tem].append(val)
        ret = ret[:-2] + ' '
        return ret

    def run(self):
        """Main loop: train, periodically validate, then average 10 test runs
        (each with freshly resampled initial projections)."""
        self.prepare_model()
        log('Model Prepared')
        stloc = 0
        if args.load_model != None:
            self.load_model()
            # Resume epoch index from the recorded history length.
            stloc = len(self.metrics['TrainLoss']) * args.tst_epoch - (args.tst_epoch - 1)
        for ep in range(stloc, args.epoch):
            tst_flag = (ep % args.tst_epoch == 0)
            reses = self.train_epoch()
            log(self.make_print('Train', ep, reses, tst_flag))
            if ep % 1 == 0:
                self.multi_handler.remake_initial_projections()
            if tst_flag:
                for handler in self.multi_handler.tst_handlers:
                    reses = self.test_epoch(handler.val_loader, handler)
                    # Note that this is the validation performance
                    log(self.make_print('Test', ep, reses, tst_flag, handler.data_name))
                self.save_history()
            print()

        # Final test: average over several runs to smooth projection randomness.
        for handler in self.multi_handler.tst_handlers:
            res_summary = dict()
            times = 10
            st = time.time()
            for i in range(times):
                reses = self.test_epoch(handler.tst_loader, handler)
                log(self.make_print('Test', args.epoch, reses, False, handler.data_name))
                self.add_res_to_summary(res_summary, reses)
                self.multi_handler.remake_initial_projections()
            for key in res_summary:
                res_summary[key] /= times
            log(self.make_print('AVG', args.epoch, res_summary, False, handler.data_name))
            print(time.time() - st)
        self.save_history()

    def add_res_to_summary(self, summary, res):
        """Accumulate one result dict into a running-sum summary dict."""
        for key in res:
            if key not in summary:
                summary[key] = 0
            summary[key] += res[key]

    def print_model_size(self):
        """Log total / trainable / frozen parameter counts (in millions)."""
        total_params = 0
        trainable_params = 0
        non_trainable_params = 0
        for param in self.model.parameters():
            tem = np.prod(param.size())
            total_params += tem
            if param.requires_grad:
                trainable_params += tem
            else:
                non_trainable_params += tem
        print(f'Total params: {total_params/1e6}')
        print(f'Trainable params: {trainable_params/1e6}')
        print(f'Non-trainable params: {non_trainable_params/1e6}')

    def prepare_model(self):
        """Instantiate the model, the Adam optimizer, and the ALRS scheduler."""
        self.model = OpenGraph()
        t.cuda.empty_cache()
        self.opt = t.optim.Adam(self.model.parameters(), lr=args.lr, weight_decay=0)
        self.lr_scheduler = ALRS(self.opt)
        self.print_model_size()

    def train_epoch(self):
        """Run one epoch over the joint loader; returns averaged losses."""
        self.model.train()
        trn_loader = self.multi_handler.trn_loader
        trn_loader.dataset.data_shuffling()
        ep_loss, ep_preloss, ep_regloss = 0, 0, 0
        steps = len(trn_loader)
        tot_samp_num = 0
        counter = [0] * len(self.multi_handler.trn_handlers)
        for i, batch_data in enumerate(trn_loader):
            if args.epoch_max_step > 0 and i >= args.epoch_max_step:
                break
            ancs, poss, negs, adj_idx = batch_data
            # DataLoader batch_size=1 wraps each field; unwrap index 0.
            adj_idx = adj_idx[0]
            ancs = ancs[0].long()
            poss = poss[0].long()
            negs = negs[0].long()
            adj = self.multi_handler.trn_handlers[adj_idx].torch_adj
            if args.cache_adj == 0:
                adj = adj.to(args.devices[0])
            initial_projector = self.multi_handler.trn_handlers[adj_idx].initial_projector
            if args.cache_proj == 0:
                initial_projector = initial_projector.to(args.devices[0])
            loss, loss_dict = self.model.cal_loss((ancs, poss, negs), adj, initial_projector)
            self.opt.zero_grad()
            loss.backward()
            # nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10, norm_type=2)
            self.opt.step()

            sample_num = ancs.shape[0]
            tot_samp_num += sample_num
            ep_loss += loss.item() * sample_num
            ep_preloss += loss_dict['preloss'].item() * sample_num
            ep_regloss += loss_dict['regloss'].item()
            log('Step %d/%d: loss = %.3f, pre = %.3f, reg = %.3f, pos = %.3f, neg = %.3f        ' % (i, steps, loss, loss_dict['preloss'], loss_dict['regloss'], loss_dict['posloss'], loss_dict['negloss']), save=False, oneline=True)

            # Periodically resample this dataset's initial projection.
            counter[adj_idx] += 1
            if args.proj_trn_steps > 0 and counter[adj_idx] >= args.proj_trn_steps:
                counter[adj_idx] = 0
                dice = np.random.uniform()
                # NOTE(review): dice < 999 is always true, so the else branch
                # (self-initialization) is effectively disabled here.
                if dice < 999:
                    self.multi_handler.remake_one_initial_projection(adj_idx)
                else:
                    self.multi_handler.make_one_self_initialization(self.model, adj_idx)
        ret = dict()
        ret['Loss'] = ep_loss / tot_samp_num
        ret['preLoss'] = ep_preloss / tot_samp_num
        ret['regLoss'] = ep_regloss / steps
        t.cuda.empty_cache()
        self.lr_scheduler.step(ret['Loss'])
        return ret

    def test_epoch(self, tst_loader, tst_handler):
        """Evaluate top-k Recall/NDCG over one loader (no gradients)."""
        with t.no_grad():
            self.model.eval()
            ep_recall, ep_ndcg = 0, 0
            ep_tstnum = len(tst_loader.dataset)
            steps = max(ep_tstnum // args.tst_batch, 1)
            for i, batch_data in enumerate(tst_loader):
                usrs = batch_data
                numpy_usrs = usrs.numpy()
                usrs = usrs.long().to(args.devices[1])
                # Coordinates of already-seen train interactions, to mask out.
                trn_masks = tst_loader.dataset.csrmat[numpy_usrs].tocoo()
                cand_size = trn_masks.shape[1]
                trn_masks = t.from_numpy(np.stack([trn_masks.row, trn_masks.col], axis=0)).long().cuda()
                adj = tst_handler.torch_adj
                if args.cache_adj == 0:
                    adj = adj.to(args.devices[0])
                initial_projector = tst_handler.initial_projector#.cuda()
                if args.cache_proj == 0:
                    initial_projector = initial_projector.to(args.devices[0])
                # Embeddings recomputed only on the first batch of the pass.
                all_preds = self.model.pred_for_test((usrs, trn_masks), adj, initial_projector, cand_size, rerun_embed=False if i!=0 else True)
                _, top_locs = t.topk(all_preds, args.topk)
                top_locs = top_locs.cpu().numpy()
                recall, ndcg = self.calc_recall_ndcg(top_locs, tst_loader.dataset.tstLocs, usrs)
                ep_recall += recall
                ep_ndcg += ndcg
                log('Steps %d/%d: recall = %.2f, ndcg = %.2f          ' % (i, steps, recall, ndcg), save=False, oneline=True)
            # t.cuda.empty_cache()
            ret = dict()
            ret['Recall'] = ep_recall / ep_tstnum
            ret['NDCG'] = ep_ndcg / ep_tstnum
        t.cuda.empty_cache()
        return ret

    def calc_recall_ndcg(self, topLocs, tstLocs, batIds):
        """Sum Recall@topk and NDCG@topk over the users in this batch."""
        assert topLocs.shape[0] == len(batIds)
        allRecall = allNdcg = 0
        for i in range(len(batIds)):
            temTopLocs = list(topLocs[i])
            temTstLocs = tstLocs[batIds[i]]
            tstNum = len(temTstLocs)
            # Ideal DCG for min(tstNum, topk) perfectly-ranked hits.
            maxDcg = np.sum([np.reciprocal(np.log2(loc + 2)) for loc in range(min(tstNum, args.topk))])
            recall = dcg = 0
            for val in temTstLocs:
                if val in temTopLocs:
                    recall += 1
                    dcg += np.reciprocal(np.log2(temTopLocs.index(val) + 2))
            recall = recall / tstNum
            ndcg = dcg / maxDcg
            allRecall += recall
            allNdcg += ndcg
        return allRecall, allNdcg

    def save_history(self):
        """Pickle metric history and save the model checkpoint."""
        if args.epoch == 0:
            return
        with open('../History/' + args.save_path + '.his', 'wb') as fs:
            pickle.dump(self.metrics, fs)

        content = {
            'model': self.model,
        }
        t.save(content, '../Models/' + args.save_path + '.mod')
        log('Model Saved: %s' % args.save_path)

    def load_model(self):
        """Restore model and metric history; rebuild the optimizer."""
        ckp = t.load('../Models/' + args.load_model + '.mod')
        self.model = ckp['model']
        self.opt = t.optim.Adam(self.model.parameters(), lr=args.lr, weight_decay=0)

        with open('../History/' + args.load_model + '.his', 'rb') as fs:
            self.metrics = pickle.load(fs)
        log('Model Loaded')

if __name__ == '__main__':
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    if len(args.gpu.split(',')) > 1:
        args.devices = ['cuda:0', 'cuda:1']
    else:
        args.devices = ['cuda:0', 'cuda:0']
    args.devices = list(map(lambda x: t.device(x), args.devices))
    logger.saveDefault = True
    setproctitle.setproctitle('OpenGraph')

    log('Start')
    trn_datasets = ['gen1']
    tst_datasets = ['ml1m', 'ml10m', 'collab']

    # trn_datasets = ['gen2']
    # tst_datasets = ['ddi']

    #
trn_datasets = ['gen0']
    # tst_datasets = ['amazon-book']

    # Command-line dataset choices override the hard-coded defaults above.
    if len(args.tstdata) != 0:
        tst_datasets = [args.tstdata]
    if len(args.trndata) != 0:
        trn_datasets = [args.trndata]
    trn_datasets = list(set(trn_datasets))
    tst_datasets = list(set(tst_datasets))
    multi_handler = MultiDataHandler(trn_datasets, tst_datasets)
    log('Load Data')

    exp = Exp(multi_handler)
    exp.run()

--------------------------------------------------------------------------------
/link_prediction/model.py:
--------------------------------------------------------------------------------
import torch as t
from torch import nn
import torch.nn.functional as F
from params import args
import numpy as np
from Utils.TimeLogger import log
from torch.nn import MultiheadAttention
from time import time

init = nn.init.xavier_uniform_
uniformInit = nn.init.uniform_

class InitialProjector(nn.Module):
    """Produces the (fixed, resampleable) initial node embedding table.

    The projection method is selected by args.proj_method: random uniform,
    low-rank uniform, truncated SVD of the adjacency, their sum, or identity.
    """
    def __init__(self, adj, input_is_embeds=False):
        super(InitialProjector, self).__init__()

        if input_is_embeds:
            # `adj` already holds precomputed embeddings; just wrap them.
            projection = adj
            if args.cache_proj:
                projection = projection.to(args.devices[0])
            else:
                projection = projection.cpu()
            self.proj_embeds = nn.Parameter(projection)
            t.cuda.empty_cache()
            return
        if args.proj_method == 'uniform':
            self.proj_embeds = nn.Parameter(self.uniform_proj(adj))
        elif args.proj_method == 'lowrank_uniform':
            self.proj_embeds = nn.Parameter(self.lowrank_uniform_proj(adj))
        elif args.proj_method == 'svd':
            self.proj_embeds = nn.Parameter(self.svd_proj(adj))
        elif args.proj_method == 'both':
            self.proj_embeds = nn.Parameter(self.uniform_proj(adj) + self.svd_proj(adj))
        elif args.proj_method == 'id':
            self.proj_embeds = nn.Parameter(self.id_proj(adj))
        else:
            raise Exception('Unrecognized Initial Embedding')
        t.cuda.empty_cache()

    def uniform_proj(self, adj):
        """Xavier-uniform random (node_num, latdim) table."""
        node_num = adj.shape[0] if adj.shape[0] == adj.shape[1] else adj.shape[0] + adj.shape[1]
        projection = init(t.empty(node_num, args.latdim))
        if args.cache_proj:
            projection = projection.to(args.devices[0])
        return projection

    def id_proj(self, adj):
        """Identity (one-hot) embedding table."""
        node_num = adj.shape[0] if adj.shape[0] == adj.shape[1] else adj.shape[0] + adj.shape[1]
        return t.eye(node_num)

    def lowrank_uniform_proj(self, adj):
        """Random table factored through a rank-16 bottleneck.
        NOTE(review): assumes a rectangular adj (sums both dims
        unconditionally, unlike uniform_proj) — confirm for square inputs."""
        node_num = adj.shape[0] + adj.shape[1]
        rank = 16
        projection1 = init(t.empty(node_num, rank))
        projection2 = init(t.empty(rank, args.latdim))
        projection = projection1 @ projection2
        if args.cache_proj:
            projection = projection.to(args.devices[0])
        return projection

    def svd_proj(self, adj):
        """Truncated-SVD embeddings: U*sqrt(S) and V*sqrt(S), zero-padded to
        latdim when the matrix is smaller than latdim; concatenated (rows then
        cols) for rectangular adj, summed for square adj."""
        if not args.cache_proj:
            adj = adj.to(args.devices[0])
        q = args.latdim
        if args.latdim > adj.shape[0] or args.latdim > adj.shape[1]:
            dim = min(adj.shape[0], adj.shape[1])
            svd_u, s, svd_v = t.svd_lowrank(adj, q=dim, niter=args.niter)
            svd_u = t.concat([svd_u, t.zeros([svd_u.shape[0], args.latdim-dim]).to(args.devices[0])], dim=1)
            svd_v = t.concat([svd_v, t.zeros([svd_v.shape[0], args.latdim-dim]).to(args.devices[0])], dim=1)
            s = t.concat([s, t.zeros(args.latdim-dim).to(args.devices[0])])
        else:
            svd_u, s, svd_v = t.svd_lowrank(adj, q=q, niter=args.niter)
        svd_u = svd_u @ t.diag(t.sqrt(s))
        svd_v = svd_v @ t.diag(t.sqrt(s))
        if adj.shape[0] != adj.shape[1]:
            projection = t.concat([svd_u, svd_v], dim=0)
        else:
            projection = svd_u + svd_v
        if not args.cache_proj:
            projection = projection.cpu()
        return projection

    def forward(self):
        """Return the stored embedding table."""
        return self.proj_embeds

class TopoEncoder(nn.Module):
    """Propagates (layer-normalized) embeddings over the graph and sums all
    GNN-layer outputs."""
    def __init__(self):
        super(TopoEncoder, self).__init__()

        self.layer_norm = nn.LayerNorm(args.latdim, elementwise_affine=False)#, dtype=t.bfloat16)


    def forward(self, adj, embeds):
        embeds = self.layer_norm(embeds)
        embeds_list = []
        if args.gnn_layer == 0:
            embeds_list.append(embeds)
        for i in range(args.gnn_layer):
            embeds = t.spmm(adj, embeds)
            embeds_list.append(embeds)
        embeds = sum(embeds_list)
        # embeds = t.concat([embeds_list[-1][:user_num], embeds_list[-2][user_num:]], dim=0)
        embeds = embeds#.to(t.bfloat16)
        return embeds

class GraphTransformer(nn.Module):
    """Stack of GTLayer blocks; each layer's output is divided by scale_layer."""
    def __init__(self):
        super(GraphTransformer, self).__init__()
        self.gt_layers = nn.Sequential(*[GTLayer() for i in range(args.gt_layer)])

    def forward(self, embeds):
        for i, layer in enumerate(self.gt_layers):
            embeds = layer(embeds) / args.scale_layer
        return embeds

class GTLayer(nn.Module):
    """Anchor-based transformer layer: attends a random anchor subset to all
    nodes, then all nodes to the anchors, followed by an FFN (pre-norm residuals)."""
    def __init__(self):
        super(GTLayer, self).__init__()
        self.multi_head_attention = MultiheadAttention(args.latdim, args.head, dropout=0.1, bias=False)#, dtype=t.bfloat16)
        self.dense_layers = nn.Sequential(*[FeedForwardLayer(args.latdim, args.latdim, bias=True, act=args.act) for _ in range(2)])# bias=False
        self.layer_norm1 = nn.LayerNorm(args.latdim, elementwise_affine=True)#, dtype=t.bfloat16)
        self.layer_norm2 = nn.LayerNorm(args.latdim, elementwise_affine=True)#, dtype=t.bfloat16)
        self.fc_dropout = nn.Dropout(p=args.drop_rate)

    def _attention(self, anchor_embeds, embeds):
        # NOTE(review): references self.Q/K/V/att_linear, which are not
        # defined in __init__ — this manual-attention path appears unused
        # (forward() uses self.multi_head_attention instead).
        q_embeds = t.einsum('ne,ehd->nhd', anchor_embeds, self.Q)
        k_embeds = t.einsum('ne,ehd->nhd', embeds, self.K)
        v_embeds = t.einsum('ne,ehd->nhd', embeds, self.V)
        att = t.einsum('khd,nhd->knh', q_embeds, k_embeds) / np.sqrt(args.latdim / args.head)
        att = t.softmax(att, dim=1)
        res = t.einsum('knh,nhd->khd', att, v_embeds).reshape([-1, args.latdim])
        res = self.att_linear(res)
        return res

    def _pick_anchors(self, embeds):
        """Uniformly sample args.anchor rows as anchor nodes."""
        perm = t.randperm(embeds.shape[0])
        anchors = perm[:args.anchor]
        return embeds[anchors]

    def forward(self, embeds):
        anchor_embeds = self._pick_anchors(embeds)
        # Anchors gather from all nodes, then nodes gather from anchors.
        _anchor_embeds, _ = self.multi_head_attention(anchor_embeds, embeds, embeds)
        anchor_embeds = _anchor_embeds + anchor_embeds
        _embeds, _ = self.multi_head_attention(embeds, anchor_embeds, anchor_embeds, need_weights=False)
        embeds = self.layer_norm1(_embeds + embeds)
        _embeds = self.fc_dropout(self.dense_layers(embeds))
        embeds = (self.layer_norm2(_embeds + embeds))
        return embeds

class FeedForwardLayer(nn.Module):
    """Linear layer with an optional activation chosen by name."""
    def __init__(self, in_feat, out_feat, bias=True, act=None):
        super(FeedForwardLayer, self).__init__()
        self.linear = nn.Linear(in_feat, out_feat, bias=bias)#, dtype=t.bfloat16)
        if act == 'identity' or act is None:
            self.act = None
        elif act == 'leaky':
            self.act = nn.LeakyReLU(negative_slope=args.leaky)
        elif act == 'relu':
            self.act = nn.ReLU()
        elif act == 'relu6':
            self.act = nn.ReLU6()
        else:
            raise Exception('Error')

    def forward(self, embeds):
        if self.act is None:
            return self.linear(embeds)
        return self.act(self.linear(embeds))

class Masker(nn.Module):
    """Removes (masks) edges from the sparse adjacency before encoding:
    either exactly the training batch's edges ('trn'), a random fraction
    ('random'), or nothing ('none')."""
    def __init__(self):
        super(Masker, self).__init__()

    def forward(self, adj, edges):
        if args.mask_method is None or args.mask_method == 'none':
            return adj
        elif args.mask_method == 'trn':
            # NOTE(review): node_num sums both dims — double the node count
            # for a square adj; the hash stays collision-free either way.
            node_num = adj.shape[0] + adj.shape[1]
            rows = adj._indices()[0, :]
            cols = adj._indices()[1, :]
            pck_rows, pck_cols = edges

            # Encode each edge (and its reverse) as a single integer hash.
            hashvals = rows * node_num + cols
            pck_hashvals1 = pck_rows * node_num + pck_cols
            pck_hashvals2 = pck_cols * node_num + pck_rows
            pck_hashvals = t.concat([pck_hashvals1, pck_hashvals2])

            if args.mask_alg == 'cross':
                masked_hashvals = self._mask_by_cross(hashvals, pck_hashvals)
            elif args.mask_alg == 'linear':
                masked_hashvals = self._mask_by_linear(hashvals, pck_hashvals)

            cols = masked_hashvals % node_num
            rows = t.div((masked_hashvals - cols).long(), node_num, rounding_mode='trunc').long()

            adj = t.sparse.FloatTensor(t.stack([rows, cols], dim=0), t.ones_like(rows, dtype=t.float32).to(args.devices[0]), adj.shape)
            return self._normalize_adj(adj)
        elif args.mask_method == 'random':
            return self._random_mask_edge(adj)

    def _mask_by_cross(self, hashvals, pck_hashvals):
        """Drop edges via batched pairwise comparison against picked hashes."""
        for i in range(args.batch * 2 // args.mask_bat):
            bat_pck_hashvals = pck_hashvals[i * args.mask_bat: (i+1) * args.mask_bat]
            idct = (hashvals.view([-1, 1]) - bat_pck_hashvals.view([1, -1]) == 0).sum(-1).bool()
            hashvals = hashvals[t.logical_not(idct)]
        return hashvals

    def _mask_by_linear(self, hashvals, pck_hashvals):
        """Drop edges via unique-count trick: values appearing twice after
        concatenation are exactly those present in both sets."""
        hashvals = t.unique(hashvals)
        pck_hashvals = t.unique(pck_hashvals)
        hashvals = t.concat([hashvals, pck_hashvals])
        hashvals, counts = t.unique(hashvals, return_counts=True)
        hashvals = hashvals[counts==1]
        return hashvals

    def _random_mask_edge(self, adj):
        """Keep each edge with probability 1 - random_mask_rate."""
        if args.random_mask_rate == 0.0:
            return adj
        vals = adj._values()
        idxs = adj._indices()
        edgeNum = vals.size()
        mask = ((t.rand(edgeNum) + 1.0 - args.random_mask_rate).floor()).type(t.bool)
        newIdxs = idxs[:, mask]
        newVals = t.ones(newIdxs.shape[1]).to(args.devices[0]).float()
        return self._normalize_adj(t.sparse.FloatTensor(newIdxs, newVals, adj.shape))

    def _normalize_adj(self, adj):
        """Symmetric sqrt-degree normalization of a torch sparse adjacency."""
        row_degree = t.pow(t.sparse.sum(adj, dim=1).to_dense(), 0.5)
        col_degree = t.pow(t.sparse.sum(adj, dim=0).to_dense(), 0.5)
        newRows, newCols = adj._indices()[0, :], adj._indices()[1, :]
        rowNorm, colNorm = row_degree[newRows], col_degree[newCols]
        newVals = adj._values() / rowNorm / colNorm
        return t.sparse.FloatTensor(adj._indices(), newVals, adj.shape)

class OpenGraph(nn.Module):
    """Full model: TopoEncoder -> GraphTransformer, with edge masking for training."""
    def __init__(self):
        super(OpenGraph, self).__init__()
        self.topoEncoder = TopoEncoder().to(args.devices[0])
        self.graphTransformer = GraphTransformer().to(args.devices[1])
        self.masker = Masker().to(args.devices[0])

    def forward(self, adj, initial_projector, user_num):
        # NOTE(review): TopoEncoder.forward is defined as (adj, embeds); this
        # 3-argument call would raise a TypeError if this path were executed —
        # cal_loss below calls the encoder with 2 args. Confirm before use.
        topo_embeds = self.topoEncoder(adj, initial_projector(), user_num).to(args.devices[1])
        final_embeds = self.graphTransformer(topo_embeds)
        return final_embeds

    def pred_norm(self, pos_preds, neg_preds):
        """Jointly shift positive/negative scores so the max is 0 (stability)."""
        pos_preds_num = pos_preds.shape[0]
        neg_preds_shape = neg_preds.shape
        preds = t.concat([pos_preds, neg_preds.view(-1)])
        preds = preds - preds.max()
        pos_preds = preds[:pos_preds_num]
        neg_preds = preds[pos_preds_num:].view(neg_preds_shape)
        return pos_preds, neg_preds

    def cal_loss(self, batch_data, adj, initial_projector):
        """BPR-style loss on (anchor, positive, negative) triples with the
        batch's edges masked out of the adjacency (no gradient through masking)."""
        ancs, poss, negs = batch_data
        with t.no_grad():
            masked_adj = self.masker(adj, (ancs.to(args.devices[0]), (poss.to(args.devices[0]))))
            initial_embeds = initial_projector()
            topo_embeds = self.topoEncoder(masked_adj, initial_embeds).to(args.devices[1])
        ancs, poss, negs = ancs.to(args.devices[1]), poss.to(args.devices[1]), negs.to(args.devices[1])
        input_seq = t.concat([ancs, poss, negs])
        input_seq = topo_embeds[input_seq]
        final_embeds = self.graphTransformer(input_seq)
        anc_embeds, pos_embeds, neg_embeds = t.split(final_embeds[:ancs.shape[0] * 3], [ancs.shape[0]] * 3)
        # anc_embeds, pos_embeds, neg_embeds = final_embeds[ancs], final_embeds[poss], final_embeds[negs]
        if final_embeds.isinf().any() or final_embeds.isnan().any():
            raise Exception('Final embedding fails')

        pos_preds, neg_preds = self.pred_norm((anc_embeds * pos_embeds).sum(-1), anc_embeds @ neg_embeds.T)
        if
pos_preds.isinf().any() or pos_preds.isnan().any() or neg_preds.isinf().any() or neg_preds.isnan().any(): 273 | raise Exception('Preds fails') 274 | pos_loss = pos_preds 275 | neg_loss = (neg_preds.exp().sum(-1) + pos_preds.exp() + 1e-8).log() 276 | pre_loss = -(pos_loss - neg_loss).mean() 277 | 278 | if t.isinf(pre_loss).any() or t.isnan(pre_loss).any(): 279 | raise Exception('NaN or Inf') 280 | 281 | reg_loss = sum(list(map(lambda W: W.norm(2).square() * args.reg, self.parameters()))) 282 | loss_dict = {'preloss': pre_loss, 'regloss': reg_loss, 'posloss': pos_loss.mean(), 'negloss': neg_loss.mean()} 283 | return pre_loss + reg_loss, loss_dict 284 | 285 | def pred_for_test(self, batch_data, adj, initial_projector, cand_size, rerun_embed=True): 286 | ancs, trn_mask = batch_data 287 | if rerun_embed: 288 | final_embeds = self.graphTransformer(self.topoEncoder(adj, initial_projector()).to(args.devices[1])) 289 | self.final_embeds = final_embeds 290 | final_embeds = self.final_embeds 291 | anc_embeds = final_embeds[ancs] 292 | cand_embeds = final_embeds[-cand_size:] 293 | 294 | mask_mat = t.sparse.FloatTensor(trn_mask, t.ones(trn_mask.shape[1]).cuda(), t.Size([ancs.shape[0], cand_size])) 295 | dense_mat = mask_mat.to_dense() 296 | 297 | all_preds = anc_embeds @ cand_embeds.T * (1 - dense_mat) - dense_mat * 1e8 298 | return all_preds 299 | 300 | class ALRS: 301 | def __init__(self, optimizer, loss_threshold=0.01, loss_ratio_threshold=0.01, decay_rate=0.97): 302 | self.optimizer = optimizer 303 | self.loss_threshold = loss_threshold 304 | self.decay_rate = decay_rate 305 | self.loss_ratio_threshold = loss_ratio_threshold 306 | self.last_loss = 1e9 307 | 308 | def step(self, loss): 309 | delta = self.last_loss - loss 310 | if delta < self.loss_threshold and delta / self.last_loss < self.loss_ratio_threshold: 311 | for group in self.optimizer.param_groups: 312 | group['lr'] *= self.decay_rate 313 | self.last_loss = loss 314 | 
-------------------------------------------------------------------------------- /link_prediction/params.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def parse_args(): 4 | parser = argparse.ArgumentParser(description='Model Parameters') 5 | parser.add_argument('--lr', default=1e-4, type=float, help='learning rate') 6 | parser.add_argument('--batch', default=1024, type=int, help='training batch size') 7 | parser.add_argument('--tst_batch', default=256, type=int, help='testing batch size (number of users)') 8 | parser.add_argument('--epoch', default=50, type=int, help='number of epochs') 9 | parser.add_argument('--save_path', default='tem', help='file name to save model and training record') 10 | parser.add_argument('--load_model', default=None, help='model name to load') 11 | parser.add_argument('--data', default='', type=str, help='name of dataset') 12 | parser.add_argument('--trndata', default='', type=str, help='name of train dataset') 13 | parser.add_argument('--tstdata', default='', type=str, help='name of test dataset') 14 | parser.add_argument('--tst_epoch', default=1, type=int, help='number of epoch to test while training') 15 | parser.add_argument('--gpu', default='0', type=str, help='indicates which gpu to use') 16 | parser.add_argument('--topk', default=20, type=int, help='topk in evaluation') 17 | parser.add_argument('--cache_adj', default=0, type=int, help='indicates wheter cache bidirectional adjs') 18 | parser.add_argument('--cache_proj', default=1, type=int, help='indicates wheter cache projector and matrices') 19 | parser.add_argument('--epoch_max_step', default=-1, type=int, help='indicates the maximum number of steps in one epoch, -1 denotes full steps') 20 | parser.add_argument('--data_dir', default='../datasets', type=str, help='dataset directory') 21 | 22 | parser.add_argument('--niter', default=2, type=int, help='number of iteration in svd') 23 | parser.add_argument('--reg', 
default=1e-7, type=float, help='weight decay regularizer') 24 | parser.add_argument('--drop_rate', default=0.1, type=float, help='dropout rate') 25 | parser.add_argument('--scale_layer', default=10, type=float, help='per-layer scale factor') 26 | parser.add_argument('--clamp', default=-1, type=float, help='absolute value for the limit of prediction scores while training') 27 | parser.add_argument('--mask_method', default='trn', type=str, help='which graph masking method to use') 28 | parser.add_argument('--mask_alg', default='linear', type=str, help='which graph masking algorithm in trn mask_method') 29 | parser.add_argument('--random_mask_rate', default=0.5, type=float, help='mask ratio in random mask_method') 30 | parser.add_argument('--act', default='leaky', type=str, help='activation function') 31 | parser.add_argument('--leaky', default=0.5, type=float, help='slope of leaky relu activation') 32 | parser.add_argument('--latdim', default=1024, type=int, help='latent dimensionality') 33 | parser.add_argument('--head', default=4, type=int, help='number of attention heads') 34 | parser.add_argument('--selfloop', default=0, type=int, help='indicating using self-loop or not') 35 | parser.add_argument('--gnn_layer', default=3, type=int, help='number of gnn iterations') 36 | parser.add_argument('--gt_layer', default=4, type=int, help='number of graph transformer layers') 37 | parser.add_argument('--proj_method', default='svd', type=str, help='initial projection method') 38 | parser.add_argument('--mask', default='trn', type=str, help='indicating which mask strategy to apply') 39 | parser.add_argument('--loss', default='ce', type=str, help='loss function') 40 | parser.add_argument('--mask_bat', default=512, type=int, help='batch size for masking') 41 | parser.add_argument('--anchor', default=256, type=int, help='number of anchor nodes in the compressed graph transformer') 42 | parser.add_argument('--pred_iter', default=1, type=int, help='number of prediction 
iterations') 43 | parser.add_argument('--proj_trn_steps', default=10, type=int, help='number of training steps for one initial projection') 44 | return parser.parse_args() 45 | args = parse_args() -------------------------------------------------------------------------------- /node_classification/Utils/TimeLogger.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | logmsg = '' 4 | timemark = dict() 5 | saveDefault = False 6 | def log(msg, save=None, oneline=False): 7 | global logmsg 8 | global saveDefault 9 | time = datetime.datetime.now() 10 | tem = '%s: %s' % (time, msg) 11 | if save != None: 12 | if save: 13 | logmsg += tem + '\n' 14 | elif saveDefault: 15 | logmsg += tem + '\n' 16 | if oneline: 17 | print(tem, end='\r') 18 | else: 19 | print(tem) 20 | 21 | def marktime(marker): 22 | global timemark 23 | timemark[marker] = datetime.datetime.now() 24 | 25 | 26 | if __name__ == '__main__': 27 | log('') -------------------------------------------------------------------------------- /node_classification/data_handler.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | from scipy.sparse import csr_matrix, coo_matrix, dok_matrix 4 | from params import args 5 | import scipy.sparse as sp 6 | from Utils.TimeLogger import log 7 | import torch as t 8 | import torch.utils.data as data 9 | import torch_geometric.transforms as T 10 | from model import InitialProjector 11 | import os 12 | 13 | class MultiDataHandler: 14 | def __init__(self, trn_datasets, tst_datasets): 15 | all_datasets = list(set(trn_datasets + tst_datasets)) 16 | self.trn_handlers = [] 17 | self.tst_handlers = [] 18 | for data_name in all_datasets: 19 | trn_flag = data_name in trn_datasets 20 | tst_flag = data_name in tst_datasets 21 | handler = DataHandler(data_name, trn_flag, tst_flag) 22 | if trn_flag: 23 | self.trn_handlers.append(handler) 24 | if tst_flag: 25 | 
self.tst_handlers.append(handler) 26 | 27 | def make_joint_trn_loader(self): 28 | trn_data = TrnData(self.trn_handlers) 29 | self.trn_loader = data.DataLoader(trn_data, batch_size=1, shuffle=True, num_workers=0) 30 | 31 | def remake_initial_projections(self): 32 | for i in range(len(self.trn_handlers)): 33 | self.remake_one_initial_projection(i) 34 | 35 | def remake_one_initial_projection(self, idx): 36 | trn_handler = self.trn_handlers[idx] 37 | trn_handler.initial_projector = InitialProjector(trn_handler.asym_adj) 38 | 39 | class DataHandler: 40 | def __init__(self, data_name, trn_flag, tst_flag): 41 | self.data_name = data_name 42 | self.trn_flag = trn_flag 43 | self.tst_flag = tst_flag 44 | self.get_data_files() 45 | self.load_data() 46 | 47 | def get_data_files(self): 48 | predir = os.path.join(args.data_dir, self.data_name) 49 | self.adj_file = os.path.join(predir, 'adj_-1.pkl') 50 | self.label_file = os.path.join(predir, 'label.pkl') 51 | self.mask_file = os.path.join(predir, 'mask_-1.pkl') 52 | 53 | def load_one_file(self, filename): 54 | with open(filename, 'rb') as fs: 55 | ret = (pickle.load(fs)).astype(np.float32) 56 | if type(ret) != coo_matrix: 57 | ret = sp.coo_matrix(ret) 58 | return ret 59 | 60 | def normalize_adj(self, mat): 61 | degree = np.array(mat.sum(axis=-1)) 62 | d_inv_sqrt = np.reshape(np.power(degree, -0.5), [-1]) 63 | d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0 64 | d_inv_sqrt_mat = sp.diags(d_inv_sqrt) 65 | if mat.shape[0] == mat.shape[1]: 66 | return mat.dot(d_inv_sqrt_mat).transpose().dot(d_inv_sqrt_mat).tocoo() 67 | else: 68 | tem = d_inv_sqrt_mat.dot(mat) 69 | col_degree = np.array(mat.sum(axis=0)) 70 | d_inv_sqrt = np.reshape(np.power(col_degree, -0.5), [-1]) 71 | d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0 72 | d_inv_sqrt_mat = sp.diags(d_inv_sqrt) 73 | return tem.dot(d_inv_sqrt_mat).tocoo() 74 | 75 | def load_feats(self, filename): 76 | try: 77 | with open(filename, 'rb') as fs: 78 | feats = pickle.load(fs) 79 | except Exception as e: 80 | 
print(filename + str(e)) 81 | exit() 82 | return feats 83 | 84 | def unique_numpy(self, row, col): 85 | hash_vals = row * self.node_num + col 86 | hash_vals = np.unique(hash_vals).astype(np.int64) 87 | col = hash_vals % self.node_num 88 | row = (hash_vals - col).astype(np.int64) // self.node_num 89 | return row, col 90 | 91 | def make_torch_adj(self, mat): 92 | if mat.shape[0] == mat.shape[1]: 93 | # to symmetric 94 | if self.data_name in ['ddi']: 95 | _row = mat.row 96 | _col = mat.col 97 | row = np.concatenate([_row, _col]).astype(np.int64) 98 | col = np.concatenate([_col, _row]).astype(np.int64) 99 | # row, col = self.unique_numpy(row, col) 100 | data = mat.data 101 | data = np.concatenate([data, data]).astype(np.float32) 102 | else: 103 | row, col = mat.row, mat.col 104 | data = mat.data 105 | # data = np.ones_like(row) 106 | mat = coo_matrix((data, (row, col)), mat.shape) 107 | if args.selfloop == 1: 108 | mat = (mat + sp.eye(mat.shape[0])) * 1.0 109 | normed_asym_mat = self.normalize_adj(mat) 110 | row = t.from_numpy(normed_asym_mat.row).long() 111 | col = t.from_numpy(normed_asym_mat.col).long() 112 | idxs = t.stack([row, col], dim=0) 113 | vals = t.from_numpy(normed_asym_mat.data).float() 114 | shape = t.Size(normed_asym_mat.shape) 115 | asym_adj = t.sparse.FloatTensor(idxs, vals, shape) 116 | if mat.shape[0] == mat.shape[1]: 117 | return asym_adj, asym_adj 118 | else: 119 | # make ui adj 120 | a = sp.csr_matrix((self.user_num, self.user_num)) 121 | b = sp.csr_matrix((self.item_num, self.item_num)) 122 | mat = sp.vstack([sp.hstack([a, mat]), sp.hstack([mat.transpose(), b])]) 123 | mat = (mat != 0) * 1.0 124 | if args.selfloop == 1: 125 | mat = (mat + sp.eye(mat.shape[0])) * 1.0 126 | mat = self.normalize_adj(mat) 127 | 128 | # make cuda tensor 129 | idxs = t.from_numpy(np.vstack([mat.row, mat.col]).astype(np.int64)) 130 | vals = t.from_numpy(mat.data.astype(np.float32)) 131 | shape = t.Size(mat.shape) 132 | return t.sparse.FloatTensor(idxs, vals, shape), 
asym_adj 133 | 134 | def load_data(self): 135 | self.adj = self.load_one_file(self.adj_file) 136 | self.labels = self.load_feats(self.label_file) 137 | if np.min(self.labels) != 0: 138 | log(f'Class label starts from {np.min(self.labels)}') 139 | self.labels -= np.min(self.labels) 140 | args.class_num = np.max(self.labels) + 1 141 | masks = self.load_feats(self.mask_file) 142 | self.trn_mask, self.val_mask, self.tst_mask = masks['train'], masks['valid'], masks['test'] 143 | 144 | self.node_num = self.adj.shape[0] 145 | print('Dataset: {data_name}, Node num: {node_num}, Edge num: {edge_num}'.format(data_name=self.data_name, node_num=self.node_num, edge_num=self.adj.nnz)) 146 | 147 | self.torch_adj, self.asym_adj = self.make_torch_adj(self.adj) 148 | if args.cache_proj: 149 | self.asym_adj = self.asym_adj.to(args.devices[0]) 150 | if args.cache_adj: 151 | self.torch_adj = self.torch_adj.to(args.devices[0]) 152 | 153 | self.initial_projector = InitialProjector(self.asym_adj) 154 | 155 | if self.tst_flag: 156 | tst_data = NodeData(self.labels, self.tst_mask) 157 | self.tst_loader = data.DataLoader(tst_data, batch_size=args.tst_batch, shuffle=False, num_workers=0) 158 | 159 | val_data = NodeData(self.labels, self.val_mask) 160 | self.val_loader = data.DataLoader(val_data, batch_size=args.tst_batch, shuffle=False, num_workers=0) 161 | 162 | trn_data = NodeData(self.labels, self.trn_mask) 163 | self.trn_loader = data.DataLoader(trn_data, batch_size=args.batch, shuffle=True, num_workers=0) 164 | 165 | 166 | class NodeData(data.Dataset): 167 | def __init__(self, labels, mask): 168 | self.iter_nodes = np.reshape(np.argwhere(np.array(mask) == True), -1) 169 | self.labels = labels[self.iter_nodes] 170 | 171 | def __len__(self): 172 | return len(self.iter_nodes) 173 | 174 | def __getitem__(self, idx): 175 | return self.iter_nodes[idx], self.labels[idx]# + args.node_num - args.class_num 176 | 177 | class TrnData(data.Dataset): 178 | def __init__(self, trn_handlers): 179 | 
self.dataset_num = len(trn_handlers) 180 | self.trn_handlers = trn_handlers 181 | self.ancs_list = [None] * self.dataset_num 182 | self.poss_list = [None] * self.dataset_num 183 | self.negs_list = [None] * self.dataset_num 184 | self.edge_nums = [None] * self.dataset_num 185 | self.sample_nums = [None] * self.dataset_num 186 | for i, handler in enumerate(self.trn_handlers): 187 | trn_mat = handler.trn_mat 188 | ancs = np.array(trn_mat.row) 189 | poss = np.array(trn_mat.col) 190 | self.ancs_list[i] = ancs 191 | self.poss_list[i] = poss 192 | self.edge_nums[i] = len(ancs) 193 | self.sample_nums[i] = self.edge_nums[i] // args.batch + (1 if self.edge_nums[i] % args.batch != 0 else 0) 194 | self.total_sample_num = sum(self.sample_nums) 195 | self.samples = [None] * self.total_sample_num 196 | 197 | def data_shuffling(self): 198 | sample_idx = 0 199 | for i in range(self.dataset_num): 200 | edge_num = self.edge_nums[i] 201 | perms = np.random.permutation(edge_num) 202 | handler = self.trn_handlers[i] 203 | asym_flag = handler.trn_mat.shape[0] != handler.trn_mat.shape[1] 204 | cand_num = handler.item_num if asym_flag else handler.node_num 205 | self.negs_list[i] = self.neg_sampling(self.ancs_list[i], handler.trn_mat.todok(), cand_num) 206 | # self.negs_list[i] = np.random.randint(cand_num, size=edge_num) 207 | for j in range(self.sample_nums[i]): 208 | st_idx = j * args.batch 209 | ed_idx = min((j + 1) * args.batch, edge_num) 210 | pick_idxs = perms[st_idx: ed_idx] 211 | ancs = self.ancs_list[i][pick_idxs] 212 | poss = self.poss_list[i][pick_idxs] 213 | negs = self.negs_list[i][pick_idxs] 214 | if asym_flag: 215 | poss += handler.user_num 216 | negs += handler.user_num 217 | self.samples[sample_idx] = (ancs, poss, negs, i) 218 | sample_idx += 1 219 | assert sample_idx == self.total_sample_num 220 | 221 | def neg_sampling(self, ancs, dokmat, cand_num): 222 | negs = np.zeros_like(ancs) 223 | for i in range(len(ancs)): 224 | u = ancs[i] 225 | while True: 226 | i_neg = 
np.random.randint(cand_num) 227 | if (u, i_neg) not in dokmat: 228 | break 229 | negs[i] = i_neg 230 | return negs 231 | 232 | def __len__(self): 233 | return self.total_sample_num 234 | 235 | def __getitem__(self, idx): 236 | ancs, poss, negs, adj_id = self.samples[idx] 237 | return ancs, poss, negs, adj_id 238 | -------------------------------------------------------------------------------- /node_classification/main.py: -------------------------------------------------------------------------------- 1 | import torch as t 2 | from torch import nn 3 | import Utils.TimeLogger as logger 4 | from Utils.TimeLogger import log 5 | from params import args 6 | from model import OpenGraph, ALRS 7 | from data_handler import DataHandler, MultiDataHandler 8 | import numpy as np 9 | import pickle 10 | import os 11 | import setproctitle 12 | import time 13 | from sklearn.metrics import f1_score 14 | 15 | class Exp: 16 | def __init__(self, multi_handler): 17 | self.multi_handler = multi_handler 18 | self.metrics = dict() 19 | trn_mets = ['Loss', 'preLoss', 'Acc'] 20 | tst_mets = ['Acc'] 21 | mets = trn_mets + tst_mets 22 | for met in mets: 23 | if met in trn_mets: 24 | self.metrics['Train' + met] = list() 25 | else: 26 | for handler in self.multi_handler.tst_handlers: 27 | self.metrics['Test' + handler.data_name + met] = list() 28 | 29 | def make_print(self, name, ep, reses, save, data_name=None): 30 | if data_name is None: 31 | ret = 'Epoch %d/%d, %s: ' % (ep, args.epoch, name) 32 | else: 33 | ret = 'Epoch %d/%d, %s %s: ' % (ep, args.epoch, data_name, name) 34 | for metric in reses: 35 | val = reses[metric] 36 | ret += '%s = %.4f, ' % (metric, val) 37 | tem = name + metric if data_name is None else name + data_name + metric 38 | if save and tem in self.metrics: 39 | self.metrics[tem].append(val) 40 | ret = ret[:-2] + ' ' 41 | return ret 42 | 43 | def run(self): 44 | self.prepare_model() 45 | log('Model Prepared') 46 | stloc = 0 47 | if args.load_model != None: 48 | 
self.load_model() 49 | stloc = len(self.metrics['TrainLoss']) * args.tst_epoch - (args.tst_epoch - 1) 50 | 51 | for handler in self.multi_handler.tst_handlers: 52 | # reses = self.test_epoch(handler.val_loader, handler) 53 | # log(self.make_print('Valid', args.epoch, reses, True, handler.data_name)) 54 | res_summary = dict() 55 | times = 10 56 | for i in range(times): 57 | reses = self.test_epoch(handler.tst_loader, handler) 58 | log(self.make_print('Test', args.epoch, reses, False, handler.data_name)) 59 | self.add_res_to_summary(res_summary, reses) 60 | self.multi_handler.remake_initial_projections() 61 | for key in res_summary: 62 | res_summary[key] /= times 63 | log(self.make_print('AVG', args.epoch, res_summary, False, handler.data_name)) 64 | self.save_history() 65 | 66 | def add_res_to_summary(self, summary, res): 67 | for key in res: 68 | if key not in summary: 69 | summary[key] = 0 70 | summary[key] += res[key] 71 | 72 | def print_model_size(self): 73 | total_params = 0 74 | trainable_params = 0 75 | non_trainable_params = 0 76 | for param in self.model.parameters(): 77 | tem = np.prod(param.size()) 78 | total_params += tem 79 | if param.requires_grad: 80 | trainable_params += tem 81 | else: 82 | non_trainable_params += tem 83 | print(f'Total params: {total_params/1e6}') 84 | print(f'Trainable params: {trainable_params/1e6}') 85 | print(f'Non-trainable params: {non_trainable_params/1e6}') 86 | 87 | def prepare_model(self): 88 | self.model = OpenGraph()#.to(args.devices[1])#.cuda() 89 | t.cuda.empty_cache() 90 | self.opt = t.optim.Adam(self.model.parameters(), lr=args.lr, weight_decay=0) 91 | self.lr_scheduler = ALRS(self.opt) 92 | self.print_model_size() 93 | 94 | def test_epoch(self, tst_loader, tst_handler): 95 | with t.no_grad(): 96 | self.model.eval() 97 | ep_acc, ep_tot = 0, 0 98 | ep_tstnum = len(tst_loader.dataset) 99 | steps = max(ep_tstnum // args.tst_batch, 1) 100 | for i, batch_data in enumerate(tst_loader): 101 | nodes, labels = list(map(lambda 
x: x.long().to(args.devices[1]), batch_data)) 102 | adj = tst_handler.torch_adj 103 | if args.cache_adj == 0: 104 | adj = adj.to(args.devices[0]) 105 | initial_projector = tst_handler.initial_projector 106 | if args.cache_proj == 0: 107 | initial_projector = initial_projector.to(args.devices[0]) 108 | preds = self.model.pred_for_node_test(nodes, adj, initial_projector, rerun_embed=False if i!=0 else True) 109 | if i == 0: 110 | all_preds, all_labels = preds, labels 111 | else: 112 | all_preds = t.concatenate([all_preds, preds]) 113 | all_labels = t.concatenate([all_labels, labels]) 114 | hit = (labels == preds).float().sum().item() 115 | ep_acc += hit 116 | ep_tot += labels.shape[0] 117 | log('Steps %d/%d: hit = %d, tot = %d ' % (i, steps, ep_acc, ep_tot), save=False, oneline=True) 118 | # t.cuda.empty_cache() 119 | ret = dict() 120 | ret['Acc'] = ep_acc / ep_tot 121 | ret['F1'] = f1_score(all_labels.cpu().numpy(), all_preds.cpu().numpy(), average='macro') 122 | t.cuda.empty_cache() 123 | return ret 124 | 125 | def calc_recall_ndcg(self, topLocs, tstLocs, batIds): 126 | assert topLocs.shape[0] == len(batIds) 127 | allRecall = allNdcg = 0 128 | for i in range(len(batIds)): 129 | temTopLocs = list(topLocs[i]) 130 | temTstLocs = tstLocs[batIds[i]] 131 | tstNum = len(temTstLocs) 132 | maxDcg = np.sum([np.reciprocal(np.log2(loc + 2)) for loc in range(min(tstNum, args.topk))]) 133 | recall = dcg = 0 134 | for val in temTstLocs: 135 | if val in temTopLocs: 136 | recall += 1 137 | dcg += np.reciprocal(np.log2(temTopLocs.index(val) + 2)) 138 | recall = recall / tstNum 139 | ndcg = dcg / maxDcg 140 | allRecall += recall 141 | allNdcg += ndcg 142 | return allRecall, allNdcg 143 | 144 | def save_history(self): 145 | if args.epoch == 0: 146 | return 147 | with open('../History/' + args.save_path + '.his', 'wb') as fs: 148 | pickle.dump(self.metrics, fs) 149 | 150 | content = { 151 | 'model': self.model, 152 | } 153 | t.save(content, '../Models/' + args.save_path + '.mod') 154 | 
log('Model Saved: %s' % args.save_path) 155 | 156 | def load_model(self): 157 | ckp = t.load('../Models/' + args.load_model + '.mod') 158 | self.model = ckp['model'].to(args.devices[1]) 159 | self.opt = t.optim.Adam(self.model.parameters(), lr=args.lr, weight_decay=0) 160 | 161 | with open('../History/' + args.load_model + '.his', 'rb') as fs: 162 | self.metrics = pickle.load(fs) 163 | log('Model Loaded') 164 | 165 | if __name__ == '__main__': 166 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 167 | if len(args.gpu.split(',')) > 1: 168 | args.devices = ['cuda:0', 'cuda:1'] 169 | else: 170 | args.devices = ['cuda:0', 'cuda:0'] 171 | args.devices = list(map(lambda x: t.device(x), args.devices)) 172 | logger.saveDefault = True 173 | setproctitle.setproctitle('OpenGraph') 174 | 175 | log('Start') 176 | trn_datasets = tst_datasets = [args.tstdata] 177 | trn_datasets = list(set(trn_datasets)) 178 | tst_datasets = list(set(tst_datasets)) 179 | multi_handler = MultiDataHandler(trn_datasets, tst_datasets) 180 | log('Load Data') 181 | 182 | exp = Exp(multi_handler) 183 | exp.run() 184 | -------------------------------------------------------------------------------- /node_classification/model.py: -------------------------------------------------------------------------------- 1 | import torch as t 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from params import args 5 | import numpy as np 6 | from Utils.TimeLogger import log 7 | from torch.nn import MultiheadAttention 8 | from time import time 9 | 10 | init = nn.init.xavier_uniform_ 11 | uniformInit = nn.init.uniform_ 12 | 13 | class InitialProjector(nn.Module): 14 | def __init__(self, adj, input_is_embeds=False): 15 | super(InitialProjector, self).__init__() 16 | 17 | if input_is_embeds: 18 | projection = adj 19 | if args.cache_proj: 20 | projection = projection.to(args.devices[0]) 21 | else: 22 | projection = projection.cpu() 23 | self.proj_embeds = nn.Parameter(projection) 24 | t.cuda.empty_cache() 25 
| return 26 | if args.proj_method == 'uniform': 27 | self.proj_embeds = nn.Parameter(self.uniform_proj(adj)) 28 | elif args.proj_method == 'lowrank_uniform': 29 | self.proj_embeds = nn.Parameter(self.lowrank_uniform_proj(adj)) 30 | elif args.proj_method == 'svd': 31 | self.proj_embeds = nn.Parameter(self.svd_proj(adj)) 32 | elif args.proj_method == 'both': 33 | self.proj_embeds = nn.Parameter(self.uniform_proj(adj) + self.svd_proj(adj)) 34 | elif args.proj_method == 'id': 35 | self.proj_embeds = nn.Parameter(self.id_proj(adj)) 36 | else: 37 | raise Exception('Unrecognized Initial Embedding') 38 | t.cuda.empty_cache() 39 | 40 | def uniform_proj(self, adj): 41 | node_num = adj.shape[0] if adj.shape[0] == adj.shape[1] else adj.shape[0] + adj.shape[1] 42 | projection = init(t.empty(node_num, args.latdim)) 43 | if args.cache_proj: 44 | projection = projection.to(args.devices[0]) 45 | return projection 46 | 47 | def id_proj(self, adj): 48 | node_num = adj.shape[0] if adj.shape[0] == adj.shape[1] else adj.shape[0] + adj.shape[1] 49 | return t.eye(node_num) 50 | 51 | def lowrank_uniform_proj(self, adj): 52 | node_num = adj.shape[0] + adj.shape[1] 53 | rank = 16 54 | projection1 = init(t.empty(node_num, rank)) 55 | projection2 = init(t.empty(rank, args.latdim)) 56 | projection = projection1 @ projection2 57 | if args.cache_proj: 58 | projection = projection.to(args.devices[0]) 59 | return projection 60 | 61 | def svd_proj(self, adj): 62 | if not args.cache_proj: 63 | adj = adj.to(args.devices[0]) 64 | q = args.latdim 65 | if args.latdim > adj.shape[0] or args.latdim > adj.shape[1]: 66 | dim = min(adj.shape[0], adj.shape[1]) 67 | svd_u, s, svd_v = t.svd_lowrank(adj, q=dim, niter=args.niter) 68 | svd_u = t.concat([svd_u, t.zeros([svd_u.shape[0], args.latdim-dim]).to(args.devices[0])], dim=1) 69 | svd_v = t.concat([svd_v, t.zeros([svd_v.shape[0], args.latdim-dim]).to(args.devices[0])], dim=1) 70 | s = t.concat([s, t.zeros(args.latdim-dim).to(args.devices[0])]) 71 | else: 72 | 
svd_u, s, svd_v = t.svd_lowrank(adj, q=q, niter=args.niter) 73 | svd_u = svd_u @ t.diag(t.sqrt(s)) 74 | svd_v = svd_v @ t.diag(t.sqrt(s)) 75 | if adj.shape[0] != adj.shape[1]: 76 | projection = t.concat([svd_u, svd_v], dim=0) 77 | else: 78 | projection = svd_u + svd_v 79 | if not args.cache_proj: 80 | projection = projection.cpu() 81 | return projection 82 | 83 | def forward(self): 84 | return ((self.proj_embeds))#.cuda()#[perms, :] 85 | 86 | class TopoEncoder(nn.Module): 87 | def __init__(self): 88 | super(TopoEncoder, self).__init__() 89 | 90 | self.layer_norm = nn.LayerNorm(args.latdim, elementwise_affine=False)#, dtype=t.bfloat16) 91 | 92 | def forward(self, adj, embeds): 93 | embeds = self.layer_norm(embeds) 94 | embeds_list = [] 95 | if args.gnn_layer == 0: 96 | embeds_list.append(embeds) 97 | for i in range(args.gnn_layer): 98 | embeds = t.spmm(adj, embeds) 99 | embeds_list.append(embeds) 100 | embeds = sum(embeds_list) 101 | # embeds = t.concat([embeds_list[-1][:user_num], embeds_list[-2][user_num:]], dim=0) 102 | embeds = embeds#.to(t.bfloat16) 103 | return embeds 104 | 105 | class GraphTransformer(nn.Module): 106 | def __init__(self): 107 | super(GraphTransformer, self).__init__() 108 | self.gt_layers = nn.Sequential(*[GTLayer() for i in range(args.gt_layer)]) 109 | 110 | def forward(self, embeds): 111 | for i, layer in enumerate(self.gt_layers): 112 | embeds = layer(embeds) / args.scale_layer 113 | return embeds 114 | 115 | class GTLayer(nn.Module): 116 | def __init__(self): 117 | super(GTLayer, self).__init__() 118 | self.multi_head_attention = MultiheadAttention(args.latdim, args.head, dropout=0.1, bias=False)#, dtype=t.bfloat16) 119 | self.dense_layers = nn.Sequential(*[FeedForwardLayer(args.latdim, args.latdim, bias=True, act=args.act) for _ in range(2)])# bias=False 120 | self.layer_norm1 = nn.LayerNorm(args.latdim, elementwise_affine=True)#, dtype=t.bfloat16) 121 | self.layer_norm2 = nn.LayerNorm(args.latdim, elementwise_affine=True)#, 
dtype=t.bfloat16) 122 | self.fc_dropout = nn.Dropout(p=args.drop_rate) 123 | 124 | def _attention(self, anchor_embeds, embeds): 125 | q_embeds = t.einsum('ne,ehd->nhd', anchor_embeds, self.Q) 126 | k_embeds = t.einsum('ne,ehd->nhd', embeds, self.K) 127 | v_embeds = t.einsum('ne,ehd->nhd', embeds, self.V) 128 | att = t.einsum('khd,nhd->knh', q_embeds, k_embeds) / np.sqrt(args.latdim / args.head) 129 | att = t.softmax(att, dim=1) 130 | res = t.einsum('knh,nhd->khd', att, v_embeds).reshape([-1, args.latdim]) 131 | res = self.att_linear(res) 132 | return res 133 | 134 | def _pick_anchors(self, embeds): 135 | perm = t.randperm(embeds.shape[0]) 136 | anchors = perm[:args.anchor] 137 | return embeds[anchors] 138 | 139 | def print_nodewise_std(self, embeds): 140 | mean = embeds.mean(0) 141 | std = (embeds - mean).square().mean(0).sqrt().mean() 142 | print(embeds) 143 | print(std.item()) 144 | 145 | def forward(self, embeds): 146 | anchor_embeds = self._pick_anchors(embeds) 147 | _anchor_embeds, _ = self.multi_head_attention(anchor_embeds, embeds, embeds) 148 | anchor_embeds = _anchor_embeds + anchor_embeds 149 | _embeds, _ = self.multi_head_attention(embeds, anchor_embeds, anchor_embeds, need_weights=False) 150 | embeds = self.layer_norm1(_embeds + embeds) 151 | _embeds = self.fc_dropout(self.dense_layers(embeds)) 152 | embeds = (self.layer_norm2(_embeds + embeds)) 153 | return embeds 154 | 155 | class FeedForwardLayer(nn.Module): 156 | def __init__(self, in_feat, out_feat, bias=True, act=None): 157 | super(FeedForwardLayer, self).__init__() 158 | self.linear = nn.Linear(in_feat, out_feat, bias=bias)#, dtype=t.bfloat16) 159 | if act == 'identity' or act is None: 160 | self.act = None 161 | elif act == 'leaky': 162 | self.act = nn.LeakyReLU(negative_slope=args.leaky) 163 | elif act == 'relu': 164 | self.act = nn.ReLU() 165 | elif act == 'relu6': 166 | self.act = nn.ReLU6() 167 | else: 168 | raise Exception('Error') 169 | 170 | def forward(self, embeds): 171 | if self.act is 
class Masker(nn.Module):
    """Removes (masks) edges from a sparse adjacency matrix before encoding.

    Supported strategies (selected via args.mask_method):
      * None / 'none': return the adjacency unchanged.
      * 'trn': remove the training edges of the current batch (and their
        reverses) so the model cannot trivially read off its targets.
      * 'random': drop a random fraction args.random_mask_rate of all edges.
    """
    def __init__(self):
        super(Masker, self).__init__()

    def forward(self, adj, edges):
        """Return a masked, re-normalized copy of ``adj``.

        Args:
            adj: sparse (rows x cols) adjacency tensor.
            edges: (row_indices, col_indices) of the batch's training edges.

        Raises:
            ValueError: on an unrecognized mask_method or mask_alg.
        """
        if args.mask_method is None or args.mask_method == 'none':
            return adj
        elif args.mask_method == 'trn':
            # Hash each edge (r, c) to r * node_num + c so edge sets can be
            # compared as flat integer tensors. node_num = rows + cols is a
            # valid base because every col index is < adj.shape[1].
            node_num = adj.shape[0] + adj.shape[1]
            rows = adj._indices()[0, :]
            cols = adj._indices()[1, :]
            pck_rows, pck_cols = edges

            hashvals = rows * node_num + cols
            pck_hashvals1 = pck_rows * node_num + pck_cols
            pck_hashvals2 = pck_cols * node_num + pck_rows  # reverse direction too
            pck_hashvals = t.concat([pck_hashvals1, pck_hashvals2])

            if args.mask_alg == 'cross':
                masked_hashvals = self._mask_by_cross(hashvals, pck_hashvals)
            elif args.mask_alg == 'linear':
                masked_hashvals = self._mask_by_linear(hashvals, pck_hashvals)
            else:
                # was an UnboundLocalError on the next line; fail loudly instead
                raise ValueError(f'Unsupported mask_alg: {args.mask_alg!r}')

            # Decode the surviving hashes back into (row, col) pairs.
            cols = masked_hashvals % node_num
            rows = t.div((masked_hashvals - cols).long(), node_num, rounding_mode='trunc').long()

            adj = t.sparse.FloatTensor(t.stack([rows, cols], dim=0), t.ones_like(rows, dtype=t.float32).to(args.devices[0]), adj.shape)
            return self._normalize_adj(adj)
        elif args.mask_method == 'random':
            return self._random_mask_edge(adj)
        else:
            # was an implicit `return None`; fail loudly instead
            raise ValueError(f'Unsupported mask_method: {args.mask_method!r}')

    def _mask_by_cross(self, hashvals, pck_hashvals):
        """Remove picked edge hashes from hashvals by chunked pairwise compare.

        Chunking bounds the (|adj edges| x chunk) comparison matrix.
        """
        # Ceiling division: the old floor division silently skipped the tail
        # chunk when args.batch * 2 was not a multiple of args.mask_bat.
        # Extra iterations over an exhausted slice are harmless no-ops.
        for i in range((args.batch * 2 + args.mask_bat - 1) // args.mask_bat):
            bat_pck_hashvals = pck_hashvals[i * args.mask_bat: (i+1) * args.mask_bat]
            idct = (hashvals.view([-1, 1]) - bat_pck_hashvals.view([1, -1]) == 0).sum(-1).bool()
            hashvals = hashvals[t.logical_not(idct)]
        return hashvals

    def _mask_by_linear(self, hashvals, pck_hashvals):
        """Remove picked edge hashes via unique-count bookkeeping (O(E log E)).

        NOTE(review): this computes the symmetric difference, so a picked edge
        absent from hashvals would be *added*; presumably picked edges always
        come from the adjacency, making the two equivalent -- confirm.
        """
        hashvals = t.unique(hashvals)
        pck_hashvals = t.unique(pck_hashvals)
        hashvals = t.concat([hashvals, pck_hashvals])
        hashvals, counts = t.unique(hashvals, return_counts=True)
        hashvals = hashvals[counts==1]
        return hashvals

    def _random_mask_edge(self, adj):
        """Keep each edge independently with prob (1 - args.random_mask_rate)."""
        if args.random_mask_rate == 0.0:
            return adj
        vals = adj._values()
        idxs = adj._indices()
        edgeNum = vals.size()
        # floor(rand + 1 - rate) is 1 with probability (1 - rate)
        mask = ((t.rand(edgeNum) + 1.0 - args.random_mask_rate).floor()).type(t.bool)
        newIdxs = idxs[:, mask]
        newVals = t.ones(newIdxs.shape[1]).to(args.devices[0]).float()
        return self._normalize_adj(t.sparse.FloatTensor(newIdxs, newVals, adj.shape))

    def _normalize_adj(self, adj):
        """Symmetric degree normalization: A_hat = D_r^{-1/2} A D_c^{-1/2}.

        NOTE(review): a node left with zero degree after masking yields an
        inf value here (division by zero) -- confirm upstream guarantees
        every node keeps at least one edge.
        """
        row_degree = t.pow(t.sparse.sum(adj, dim=1).to_dense(), 0.5)
        col_degree = t.pow(t.sparse.sum(adj, dim=0).to_dense(), 0.5)
        newRows, newCols = adj._indices()[0, :], adj._indices()[1, :]
        rowNorm, colNorm = row_degree[newRows], col_degree[newCols]
        newVals = adj._values() / rowNorm / colNorm
        return t.sparse.FloatTensor(adj._indices(), newVals, adj.shape)
class OpenGraph(nn.Module):
    """End-to-end model: topology encoder followed by a graph transformer.

    The topology encoder and masker live on args.devices[0], the transformer
    on args.devices[1]; embeddings are moved between devices explicitly.
    """
    def __init__(self):
        super(OpenGraph, self).__init__()
        self.topoEncoder = TopoEncoder().to(args.devices[0])
        self.graphTransformer = GraphTransformer().to(args.devices[1])
        self.masker = Masker().to(args.devices[0])

    def forward(self, adj, initial_projector, user_num):
        """Encode the full graph and return final node embeddings.

        NOTE(review): topoEncoder is called with 3 args here but 2 args in
        cal_loss / pred_for_node_test -- confirm its signature covers both.
        """
        topo_embeds = self.topoEncoder(adj, initial_projector(), user_num).to(args.devices[1])
        final_embeds = self.graphTransformer(topo_embeds)
        return final_embeds

    def pred_norm(self, pos_preds, neg_preds):
        """Shift all scores by their joint maximum so exp() cannot overflow.

        Shapes are preserved; the shift does not change score ordering.
        """
        pos_preds_num = pos_preds.shape[0]
        neg_preds_shape = neg_preds.shape
        preds = t.concat([pos_preds, neg_preds.view(-1)])
        preds = preds - preds.max()
        pos_preds = preds[:pos_preds_num]
        neg_preds = preds[pos_preds_num:].view(neg_preds_shape)
        return pos_preds, neg_preds

    def _ranking_loss(self, anc_embeds, pos_embeds, neg_embeds):
        """Softmax-style ranking loss shared by cal_loss and cal_loss_node.

        Returns (pre_loss, pos_loss, neg_loss). Raises so a NaN/Inf batch
        aborts training instead of silently corrupting the parameters.
        (Extracted: this code was duplicated verbatim in both loss methods.)
        """
        pos_preds, neg_preds = self.pred_norm((anc_embeds * pos_embeds).sum(-1), anc_embeds @ neg_embeds.T)
        if pos_preds.isinf().any() or pos_preds.isnan().any() or neg_preds.isinf().any() or neg_preds.isnan().any():
            raise Exception('Preds fails')
        pos_loss = pos_preds
        # log of (sum of negative exps + own positive exp); 1e-8 guards log(0)
        neg_loss = (neg_preds.exp().sum(-1) + pos_preds.exp() + 1e-8).log()
        pre_loss = -(pos_loss - neg_loss).mean()
        if t.isinf(pre_loss).any() or t.isnan(pre_loss).any():
            raise Exception('NaN or Inf')
        return pre_loss, pos_loss, neg_loss

    def _reg_loss(self):
        """L2 weight-decay term over all parameters, scaled by args.reg."""
        return sum(W.norm(2).square() * args.reg for W in self.parameters())

    def cal_loss(self, batch_data, adj, initial_projector):
        """Link-prediction training loss for a batch of (anchor, pos, neg) nodes."""
        ancs, poss, negs = batch_data
        # NOTE(review): no_grad scope inferred to cover masking + topo
        # encoding (projector is tuned separately) -- confirm.
        with t.no_grad():
            masked_adj = self.masker(adj, (ancs.to(args.devices[0]), (poss.to(args.devices[0]))))
            initial_embeds = initial_projector()
            topo_embeds = self.topoEncoder(masked_adj, initial_embeds).to(args.devices[1])
        ancs, poss, negs = ancs.to(args.devices[1]), poss.to(args.devices[1]), negs.to(args.devices[1])
        input_seq = t.concat([ancs, poss, negs])
        input_seq = topo_embeds[input_seq]
        final_embeds = self.graphTransformer(input_seq)
        # assumes ancs, poss and negs all have the same length -- TODO confirm
        anc_embeds, pos_embeds, neg_embeds = t.split(final_embeds[:ancs.shape[0] * 3], [ancs.shape[0]] * 3)
        if final_embeds.isinf().any() or final_embeds.isnan().any():
            raise Exception('Final embedding fails')

        pre_loss, pos_loss, neg_loss = self._ranking_loss(anc_embeds, pos_embeds, neg_embeds)
        reg_loss = self._reg_loss()
        loss_dict = {'preloss': pre_loss, 'regloss': reg_loss, 'posloss': pos_loss.mean(), 'negloss': neg_loss.mean()}
        return pre_loss + reg_loss, loss_dict

    def cal_loss_node(self, batch_data, adj, initial_projector):
        """Node-classification loss; class nodes occupy the last
        args.class_num rows of the (augmented) graph."""
        ancs, labels = batch_data
        # positive target: the class node matching each label
        poss = labels + adj.shape[0] - args.class_num
        # all class nodes as candidates; t.arange replaces the original
        # t.from_numpy(np.array(list(range(...)))).to(t.int64) (equivalent)
        negs = t.arange(args.class_num).cuda() + adj.shape[0] - args.class_num
        with t.no_grad():
            masked_adj = self.masker(adj, (ancs.to(args.devices[0]), (poss.to(args.devices[0]))))
            initial_embeds = initial_projector()
            topo_embeds = self.topoEncoder(masked_adj, initial_embeds).to(args.devices[1])
        ancs, poss, negs = ancs.to(args.devices[1]), poss.to(args.devices[1]), negs.to(args.devices[1])
        input_seq = t.concat([ancs, poss, negs])
        input_seq = topo_embeds[input_seq]
        final_embeds = self.graphTransformer(input_seq)
        anc_embeds = final_embeds[:ancs.shape[0]]
        pos_embeds = final_embeds[ancs.shape[0]:ancs.shape[0]+poss.shape[0]]
        neg_embeds = final_embeds[-negs.shape[0]:]
        if final_embeds.isinf().any() or final_embeds.isnan().any():
            raise Exception('Final embedding fails')

        pre_loss, pos_loss, neg_loss = self._ranking_loss(anc_embeds, pos_embeds, neg_embeds)
        reg_loss = self._reg_loss()
        # raw class scores; argmax is shift-invariant so the pred_norm offset
        # inside _ranking_loss does not matter here
        preds = anc_embeds @ neg_embeds.T
        loss_dict = {'preloss': pre_loss, 'regloss': reg_loss, 'posloss': pos_loss.mean(), 'negloss': neg_loss.mean(), 'preds': t.argmax(preds, dim=-1)}
        return pre_loss + reg_loss, loss_dict

    def pred_for_node_test(self, nodes, adj, initial_projector, rerun_embed=True):
        """Predict class ids for ``nodes``.

        When rerun_embed is False, reuses self.final_embeds cached by a
        previous call with rerun_embed=True.
        """
        if rerun_embed:
            final_embeds = self.graphTransformer(self.topoEncoder(adj, initial_projector()).to(args.devices[1]))
            self.final_embeds = final_embeds
        final_embeds = self.final_embeds
        pck_embeds = final_embeds[nodes]
        class_embeds = final_embeds[-args.class_num:]
        preds = pck_embeds @ class_embeds.T
        return t.argmax(preds, dim=-1)
class ALRS:
    """Adaptive learning-rate scheduler.

    After each training step, compare the new loss with the previous one.
    If the improvement is small both in absolute terms (< loss_threshold)
    and relative to the previous loss (< loss_ratio_threshold), multiply
    every parameter group's learning rate by decay_rate.
    """
    def __init__(self, optimizer, loss_threshold=0.01, loss_ratio_threshold=0.01, decay_rate=0.97):
        self.optimizer = optimizer
        self.loss_threshold = loss_threshold
        self.decay_rate = decay_rate
        self.loss_ratio_threshold = loss_ratio_threshold
        self.last_loss = 1e9  # sentinel: first step never looks like a plateau

    def step(self, loss):
        """Record ``loss`` and decay the LR if progress has stalled."""
        improvement = self.last_loss - loss
        # Short-circuit keeps the ratio untouched when the absolute
        # improvement alone rules out a decay.
        stalled = (improvement < self.loss_threshold
                   and improvement / self.last_loss < self.loss_ratio_threshold)
        if stalled:
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = param_group['lr'] * self.decay_rate
        self.last_loss = loss
import argparse

def parse_args(argv=None):
    """Build and parse the command-line arguments for node classification.

    Args:
        argv: optional list of argument strings. Defaults to None, which
            makes argparse read sys.argv[1:] (the original behavior);
            passing an explicit list keeps the function testable.

    Returns:
        argparse.Namespace with all model/training hyper-parameters.
    """
    parser = argparse.ArgumentParser(description='Model Parameters')
    parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
    parser.add_argument('--batch', default=1024, type=int, help='training batch size')
    parser.add_argument('--tst_batch', default=256, type=int, help='testing batch size (number of users)')
    parser.add_argument('--shot', default=-1, type=int, help='number of shots for each node')
    parser.add_argument('--epoch', default=50, type=int, help='number of epochs')
    parser.add_argument('--tune_step', default=100, type=int, help='number of projection tuning steps')
    parser.add_argument('--save_path', default='tem', help='file name to save model and training record')
    parser.add_argument('--load_model', default=None, help='model name to load')
    parser.add_argument('--tstdata', default='', type=str, help='name of test dataset')
    parser.add_argument('--tst_epoch', default=1, type=int, help='number of epoch to test while training')
    parser.add_argument('--gpu', default='0', type=str, help='indicates which gpu to use')
    parser.add_argument('--topk', default=20, type=int, help='topk in evaluation')
    # typo fix: 'wheter' -> 'whether' in the two help strings below
    parser.add_argument('--cache_adj', default=0, type=int, help='indicates whether cache bidirectional adjs')
    parser.add_argument('--cache_proj', default=1, type=int, help='indicates whether cache projector and matrices')
    parser.add_argument('--epoch_max_step', default=-1, type=int, help='indicates the maximum number of steps in one epoch, -1 denotes full steps')
    parser.add_argument('--data_dir', default='../datasets', type=str, help='dataset directory')

    parser.add_argument('--niter', default=2, type=int, help='number of iteration in svd')
    parser.add_argument('--reg', default=1e-7, type=float, help='weight decay regularizer')
    parser.add_argument('--drop_rate', default=0.1, type=float, help='dropout rate')
    parser.add_argument('--scale_layer', default=10, type=float, help='per-layer scale factor')
    parser.add_argument('--clamp', default=-1, type=float, help='absolute value for the limit of prediction scores while training')
    parser.add_argument('--mask_method', default='trn', type=str, help='which graph masking method to use')
    parser.add_argument('--mask_alg', default='linear', type=str, help='which graph masking algorithm in trn mask_method')
    parser.add_argument('--random_mask_rate', default=0.5, type=float, help='mask ratio in random mask_method')
    parser.add_argument('--act', default='leaky', type=str, help='activation function')
    parser.add_argument('--leaky', default=0.5, type=float, help='slope of leaky relu activation')
    parser.add_argument('--latdim', default=1024, type=int, help='latent dimensionality')
    parser.add_argument('--head', default=4, type=int, help='number of attention heads')
    parser.add_argument('--selfloop', default=0, type=int, help='indicating using self-loop or not')
    parser.add_argument('--gnn_layer', default=3, type=int, help='number of gnn iterations')
    parser.add_argument('--gt_layer', default=4, type=int, help='number of graph transformer layers')
    parser.add_argument('--proj_method', default='svd', type=str, help='initial projection method')
    parser.add_argument('--mask', default='trn', type=str, help='indicating which mask strategy to apply')
    parser.add_argument('--loss', default='ce', type=str, help='loss function')
    parser.add_argument('--mask_bat', default=512, type=int, help='batch size for masking')
    parser.add_argument('--anchor', default=256, type=int, help='number of anchor nodes in the compressed graph transformer')
    parser.add_argument('--pred_iter', default=1, type=int, help='number of prediction iterations')
    parser.add_argument('--proj_trn_steps', default=10, type=int, help='number of training steps for one initial projection')
    return parser.parse_args(argv)

# Parsed once at import; the rest of the codebase reads `params.args`.
args = parse_args()