├── History
│   ├── pretrn_gen0.his
│   ├── pretrn_gen1.his
│   └── pretrn_gen2.his
├── LICENSE
├── Models
│   └── readme
├── README.md
├── datasets
│   ├── amazon-book
│   │   ├── fewshot_mat_1.pkl
│   │   ├── fewshot_mat_5.pkl
│   │   ├── trn_mat.pkl.zip
│   │   └── tst_mat.pkl
│   ├── citeseer
│   │   ├── adj_-1.pkl
│   │   ├── adj_1.pkl
│   │   ├── adj_5.pkl
│   │   ├── feats.pkl.zip
│   │   ├── label.pkl
│   │   ├── mask_-1.pkl
│   │   ├── mask_1.pkl
│   │   └── mask_5.pkl
│   ├── collab
│   │   ├── fewshot_mat_1.pkl
│   │   ├── fewshot_mat_5.pkl
│   │   ├── trn_mat.pkl.zip
│   │   ├── tst_mat.pkl
│   │   └── val_mat.pkl
│   ├── cora
│   │   ├── adj_-1.pkl
│   │   ├── adj_1.pkl
│   │   ├── adj_5.pkl
│   │   ├── feats.pkl
│   │   ├── label.pkl
│   │   ├── mask_-1.pkl
│   │   ├── mask_1.pkl
│   │   └── mask_5.pkl
│   ├── ddi
│   │   ├── fewshot_mat_1.pkl
│   │   ├── fewshot_mat_5.pkl
│   │   ├── trn_mat.pkl.zip
│   │   ├── tst_mat.pkl
│   │   └── val_mat.pkl
│   ├── gen0
│   │   ├── trn_mat.pkl
│   │   ├── tst_mat.pkl
│   │   └── val_mat.pkl
│   ├── gen1
│   │   ├── trn_mat.pkl
│   │   ├── tst_mat.pkl
│   │   └── val_mat.pkl
│   ├── gen2
│   │   ├── trn_mat.pkl
│   │   ├── tst_mat.pkl
│   │   └── val_mat.pkl
│   ├── ml10m
│   │   ├── fewshot_mat_1.pkl
│   │   ├── fewshot_mat_5.pkl
│   │   ├── trn_mat.pkl.zip
│   │   └── tst_mat.pkl.zip
│   ├── ml1m
│   │   ├── fewshot_mat_1.pkl
│   │   ├── fewshot_mat_5.pkl
│   │   ├── trn_mat.pkl
│   │   └── tst_mat.pkl
│   └── pubmed
│       ├── adj_-1.pkl
│       ├── adj_1.pkl
│       ├── adj_5.pkl
│       ├── feats.pkl.zip
│       ├── label.pkl
│       ├── mask_-1.pkl
│       ├── mask_1.pkl
│       └── mask_5.pkl
├── graph_generation
│   ├── Exp_Utils
│   │   ├── Emailer.py
│   │   └── TimeLogger.py
│   ├── Utils.py
│   ├── embedding_generation.py
│   ├── gen_results
│   │   ├── datasets
│   │   │   └── gen_data_ecommerce
│   │   │       ├── embedding_dict.pkl
│   │   │       ├── interaction_base-0_iter-0.pkl
│   │   │       ├── item_list.pkl
│   │   │       └── res
│   │   │           ├── interaction_fuse_iter-0.pkl
│   │   │           ├── iter-0_imap.pkl
│   │   │           ├── iter-0_test.pkl
│   │   │           ├── iter-0_train.pkl
│   │   │           └── iter-0_valid.pkl
│   │   ├── products_e-commerce platform like Amazon.txt
│   │   ├── tem
│   │   │   ├── e-commerce platform like Amazon_depth1_products
│   │   │   ├── e-commerce platform like Amazon_depth2_products, Automotive
│   │   │   ├── e-commerce platform like Amazon_depth2_products, Baby
│   │   │   ├── e-commerce platform like Amazon_depth2_products, Beauty
│   │   │   ├── e-commerce platform like Amazon_depth2_products, Books
│   │   │   ├── e-commerce platform like Amazon_depth2_products, Clothing
│   │   │   ├── e-commerce platform like Amazon_depth2_products, Electronics
│   │   │   ├── e-commerce platform like Amazon_depth2_products, Handmade
│   │   │   ├── e-commerce platform like Amazon_depth2_products, Health and Personal Care
│   │   │   ├── e-commerce platform like Amazon_depth2_products, Home Improvement
│   │   │   ├── e-commerce platform like Amazon_depth2_products, Industrial and Scientific
│   │   │   ├── e-commerce platform like Amazon_depth2_products, Jewelry
│   │   │   ├── e-commerce platform like Amazon_depth2_products, Musical Instruments
│   │   │   ├── e-commerce platform like Amazon_depth2_products, Office Products
│   │   │   ├── e-commerce platform like Amazon_depth2_products, Pet Supplies
│   │   │   ├── e-commerce platform like Amazon_depth2_products, Sports and Outdoors
│   │   │   ├── e-commerce platform like Amazon_depth2_products, Tools and Home Improvement
│   │   │   └── e-commerce platform like Amazon_depth2_products, Toys
│   │   └── tree_wInstanceNum_products_e-commerce platform like Amazon.pkl
│   ├── human_item_generation_gibbsSampling_embedEstimation.py
│   ├── instance_number_estimation_hierarchical.py
│   ├── itemCollecting_dfsIterator.py
│   └── make_adjs.py
├── imgs
│   ├── article cover.jpg
│   ├── framework.png
│   ├── graph tokenizer.png
│   ├── impact of datasets.png
│   ├── intro.png
│   ├── opengraph_article_cover_full.png
│   ├── performance.png
│   ├── prompt.png
│   └── sampling.png
├── link_prediction
│   ├── Utils
│   │   └── TimeLogger.py
│   ├── data_handler.py
│   ├── main.py
│   ├── model.py
│   └── params.py
└── node_classification
    ├── Utils
    │   └── TimeLogger.py
    ├── data_handler.py
    ├── main.py
    ├── model.py
    └── params.py
/History/pretrn_gen0.his:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/History/pretrn_gen0.his
--------------------------------------------------------------------------------
/History/pretrn_gen1.his:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/History/pretrn_gen1.his
--------------------------------------------------------------------------------
/History/pretrn_gen2.his:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/History/pretrn_gen2.his
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/Models/readme:
--------------------------------------------------------------------------------
1 | Download the pre-trained model via the link: https://drive.google.com/drive/folders/1d-Jn7LHJ2ZSmndKveU4_qS70-Z6j9utB?usp=drive_link
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # OpenGraph: Towards Open Graph Foundation Models
2 |
3 | [project badges and teaser figure omitted]
4 |
10 | Lianghao Xia, Ben Kao, and Chao Huang* (*Correspondence)
11 |
16 | Presenting OpenGraph, a foundation graph model distilling zero-shot graph generalizability from LLMs.
17 |
21 |
22 | To achieve this goal, OpenGraph addresses several key technical challenges:
23 | - We propose a unified graph tokenizer that allows our graph model to generalize well to unseen graph data, even when the underlying graph properties differ significantly from those encountered during training (a toy illustration of this idea is sketched below).
24 | - We develop a scalable graph transformer as the foundational encoder, which effectively and efficiently captures node-wise dependencies within the global topological context.
25 | - We introduce a data augmentation mechanism enhanced by a large language model (LLM) to alleviate the limitations of data scarcity in real-world scenarios.
26 |
27 |
28 |
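The unified tokenizer and the scalable transformer themselves are implemented in the `link_prediction/` code; the snippet below is only a minimal, illustrative sketch of the first idea above (high-order adjacency smoothing followed by a topology-aware projection), written against plain numpy/scipy. The SVD-based projection and all function names here are assumptions for illustration, not the repository's API.

```
import numpy as np
import scipy.sparse as sp
from scipy.sparse.linalg import svds

def smooth_adjacency(adj, order=2):
    # Symmetrically normalize the adjacency matrix, then apply `order` rounds of smoothing.
    deg = np.asarray(adj.sum(axis=1)).flatten()
    d_inv_sqrt = sp.diags(np.power(np.maximum(deg, 1e-12), -0.5))
    norm_adj = (d_inv_sqrt @ adj @ d_inv_sqrt).tocsr()
    smoothed = norm_adj
    for _ in range(order - 1):
        smoothed = smoothed @ norm_adj
    return smoothed

def topology_tokens(adj, dim=16, order=2):
    # Map nodes of an arbitrary graph to fixed-size token embeddings from topology alone,
    # here via a truncated SVD of the smoothed adjacency (an assumption for illustration).
    smoothed = smooth_adjacency(adj, order)
    u, s, _ = svds(smoothed.astype(np.float64), k=dim)
    return u * np.sqrt(s)

# Toy usage on a random symmetric graph.
adj = sp.random(200, 200, density=0.03, format='csr', random_state=0)
adj = ((adj + adj.T) > 0).astype(np.float64)
print(topology_tokens(adj).shape)  # (200, 16)
```

The point of the sketch is that the resulting token embeddings depend only on graph topology, so the same procedure can be applied to any unseen graph regardless of its feature space.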
29 | Extensive experiments validate the effectiveness of our framework. By adapting OpenGraph to new graph characteristics and comprehending the nuances of diverse graphs, our approach achieves remarkable zero-shot graph learning performance across various settings and domains.
30 |
31 | ## Environment Setup
32 | You need to unzip some of the data files in `datasets/` (a small helper script for this is sketched after the package list below). Download the pre-trained models using the link in `Models/readme`. Our experiments were conducted with the following package versions:
33 | * python==3.10.13
34 | * torch==1.13.0
35 | * numpy==1.23.4
36 | * scipy==1.9.3
37 |
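If you prefer to script the unzip step, a small helper along the lines of the following (not part of the repository, shown only as a convenience) extracts every `*.pkl.zip` archive under `datasets/` next to its source file:

```
import glob
import os
import zipfile

# Extract each zipped pickle (e.g. datasets/ml10m/trn_mat.pkl.zip) in place.
for path in glob.glob('datasets/**/*.pkl.zip', recursive=True):
    with zipfile.ZipFile(path) as zf:
        zf.extractall(os.path.dirname(path))
    print('extracted', path)
```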
38 | ## Brief Code Structure
39 | Here is a brief overview of the code structure. The explanation for each directory is enclosed in ##...## markers. For a more detailed version, please refer to the full listing at the end of this readme.
40 | ```
41 | ./
42 | │ └── README.md
43 | │ ├── History/ ## Training history of pre-trained models ##
44 | │ ├── Models/ ## Pre-trained models ##
45 | │ ├── datasets/
46 | │ ├── graph_generation/ ## Code and examples for graph generation ##
47 | │ ├── imgs/ ## Images used in readme ##
48 | │ ├── link_prediction/ ## code for link prediction and pre-training ##
49 | │ │ ├── data_handler.py
50 | │ │ ├── main.py
51 | │ │ ├── model.py
52 | │ │ └── params.py
53 | │ │ ├── Utils/
54 | │ │ │ └── TimeLogger.py
55 | │ ├── node_classification/ ## code for testing on node classification ##
56 | │ │ ├── data_handler.py
57 | │ │ ├── main.py
58 | │ │ ├── model.py
59 | │ │ └── params.py
60 | │ │ ├── Utils/
61 | │ │ │ └── TimeLogger.py
62 | ```
63 |
64 | ## Usage
65 | #### To reproduce the test performance reported in the paper, run the following commands:
66 | ```
67 | cd link_prediction/
68 | python main.py --load pretrn_gen1 --epoch 0 # test on OGBL-Collab, ML-1M, ML-10M
69 | python main.py --load pretrn_gen0 --tstdata amazon-book --epoch 0 # test on Amazon-Book
70 | python main.py --load pretrn_gen2 --tstdata ddi --epoch 0 # test on OGBL-ddi
71 | cd ../node_classification/
72 | python main.py --load pretrn_gen1 --tstdata cora # test on Cora
73 | python main.py --load pretrn_gen1 --tstdata citeseer # test on Citeseer
74 | python main.py --load pretrn_gen1 --tstdata pubmed # test on Pubmed
75 | ```
76 |
77 | #### To pre-train OpenGraph from scratch yourself, run the following commands:
78 | ```
79 | cd ../link_prediction/
80 | python main.py --save pretrn_gen1
81 | python main.py --trndata gen0 --tstdata amazon-book --save pretrn_gen0
82 | python main.py --trndata gen2 --tstdata ddi --save pretrn_gen2
83 | ```
84 |
85 | #### To explore different combinations of pre-training and testing datasets, modify `trn_datasets` and `tst_datasets` in line 241 of `link_prediction/main.py`.
86 |
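As a hypothetical illustration, using dataset folder names from `datasets/` (the actual lists at that line may differ), the two variables could look like:

```
trn_datasets = ['gen0', 'gen1', 'gen2']                           # pre-training graphs
tst_datasets = ['amazon-book', 'collab', 'ml1m', 'ml10m', 'ddi']  # zero-shot test graphs
```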
87 | ## Graph Data Generation
88 | The graph generation code is in `graph_generation/`. A small toy dataset is included. First, fill in your OpenAI key in `Utils.py` and `itemCollecting_dfsIterator.py`. To generate your own dataset, modify the `descs` and `hyperparams` dicts (an example `descs` is shown after the command list below), then run the following steps:
89 | ```
90 | cd graph_generation/
91 | python itemCollecting_dfsIterator.py
92 | python instance_number_estimation_hierarchical.py
93 | python embedding_generation.py
94 | python human_item_generation_gibbsSampling_embedEstimation.py
95 | python make_adjs.py
96 | ```
97 |
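For reference, the `descs` configuration below is the e-commerce example that ships in `graph_generation/embedding_generation.py`; the `hyperparams` dict lives in the other generation scripts and is not reproduced here. The inline comments are our annotations.

```
descs = {
    'data_name': 'gen_data_ecommerce',                   # output folder under gen_results/datasets/
    'scenario_desc': 'e-commerce platform like Amazon',  # scenario description given to the LLM
    'human_role': 'user',
    'interaction_verb': 'interact',
    'initial_entity': 'products',                        # root entity expanded into a category tree
}
```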
98 | Below we show our prompt template, along with examples of prompt configurations and generated nodes.
99 |
100 |
101 |
102 | ## Evaluation Results
103 |
104 | ### Overall Generalization Performance
105 | OpenGraph achieves the best performance under the zero-shot setting, compared to baselines trained/tuned with 1-shot and 5-shot data.
106 |
107 |
108 | ### Pre-training Dataset Study
109 | We studied the influence of using different pre-training datasets. Results below indicate that:
110 | - The generation techniques (Norm, Loc, Topo) have positive effects on performance.
111 | - Real-world datasets (Yelp2018, Gowalla) may yield worse results compared to our generated ones.
112 | - A pre-training dataset that is closely related to the test data (e.g., ML-10M for the ML-1M and ML-10M test sets) yields superior performance.
113 |
114 |
115 |
116 | ### Graph Tokenizer Study
117 | We tuned the configuration of our unified graph tokenizer by adjusting the order of graph smoothing and by replacing our topology-aware projection with alternatives. Our findings include:
118 | - **Adjacency smoothing is important**, as OpenGraph with 0-order smoothing yields inferior performance.
119 | - **Topology-aware projection is superior in performance**. Alternatives include *One-hot*, which learns one large unified representation table for all datasets; *Random*, which makes no assumption about node-wise relations and assigns representations uniformly at random; and *Degree*, a widely used method for non-attributed graphs that seems applicable to the cross-graph scenario (toy versions of the three are sketched below).
120 |
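To make the three alternatives concrete, here are toy, assumption-laden versions of each (illustrative only; the actual baseline configurations live in the training code):

```
import numpy as np

rng = np.random.default_rng(0)
num_nodes, dim, table_size = 1000, 16, 50000

# One-hot: every node of every dataset gets its own row in one big shared table.
unified_table = rng.normal(size=(table_size, dim))
one_hot_tokens = unified_table[:num_nodes]

# Random: each node gets an independent random representation, no topology assumption.
random_tokens = rng.normal(size=(num_nodes, dim))

# Degree: encode each node only by its (capped) degree, common for non-attributed graphs.
degrees = rng.integers(1, 50, size=num_nodes)
degree_tokens = np.eye(dim)[np.minimum(degrees, dim - 1)]
```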
121 |
122 |
123 | ### Sampling Techniques Study
124 | We ablated the two sampling techniques in the graph transformer and show their positive effects on both memory and time costs below. Surprisingly, token sequence sampling also has a positive effect on model performance.
125 |
126 |
127 |
128 | ## Citation
129 | If you find this work useful for your research, please consider citing our paper:
130 | ```
131 | @inproceedings{xia2024opengraph,
132 | title={OpenGraph: Towards Open Graph Foundation Models},
133 | author={Xia, Lianghao and Kao, Ben and Huang, Chao},
134 | booktitle={EMNLP},
135 | year={2024}
136 | }
137 | ```
138 |
139 | ## Detailed Code Structures
140 | ```
141 | ./
142 | │ └── README.md
143 | │ ├── History/ ## Training history of pre-trained models ##
144 | │ │ ├── pretrn_gen0.his
145 | │ │ ├── pretrn_gen2.his
146 | │ │ └── pretrn_gen1.his
147 | │ ├── Models/ ## Pre-trained models ##
148 | │ │ └── readme ## Download pre-trained models using the link inside ##
149 | │ ├── datasets/
150 | │ │ ├── amazon-book/
151 | │ │ │ ├── fewshot_mat_1.pkl
152 | │ │ │ ├── trn_mat.pkl.zip ## Unzip it manually ##
153 | │ │ │ ├── tst_mat.pkl
154 | │ │ │ └── fewshot_mat_5.pkl
155 | │ │ ├── citeseer/
156 | │ │ │ ├── adj_-1.pkl
157 | │ │ │ ├── adj_1.pkl
158 | │ │ │ ├── adj_5.pkl
159 | │ │ │ ├── feats.pkl.zip ## Unzip it manually ##
160 | │ │ │ ├── label.pkl
161 | │ │ │ ├── mask_-1.pkl
162 | │ │ │ ├── mask_1.pkl
163 | │ │ │ └── mask_5.pkl
164 | │ │ ├── collab/
165 | │ │ │ ├── fewshot_mat_5.pkl
166 | │ │ │ ├── trn_mat.pkl.zip ## Unzip it manually ##
167 | │ │ │ ├── tst_mat.pkl
168 | │ │ │ ├── val_mat.pkl
169 | │ │ │ └── fewshot_mat_1.pkl
170 | │ │ ├── cora/
171 | │ │ │ ├── adj_-1.pkl
172 | │ │ │ ├── adj_1.pkl
173 | │ │ │ ├── adj_5.pkl
174 | │ │ │ ├── feats.pkl
175 | │ │ │ ├── label.pkl
176 | │ │ │ ├── mask_-1.pkl
177 | │ │ │ ├── mask_1.pkl
178 | │ │ │ └── mask_5.pkl
179 | │ │ ├── ddi/
180 | │ │ │ ├── fewshot_mat_1.pkl
181 | │ │ │ ├── trn_mat.pkl.zip ## Unzip it manually ##
182 | │ │ │ ├── tst_mat.pkl
183 | │ │ │ ├── val_mat.pkl
184 | │ │ │ └── fewshot_mat_5.pkl
185 | │ │ ├── gen0/
186 | │ │ │ ├── trn_mat.pkl
187 | │ │ │ ├── val_mat.pkl
188 | │ │ │ └── tst_mat.pkl
189 | │ │ ├── gen1/
190 | │ │ │ ├── trn_mat.pkl
191 | │ │ │ ├── tst_mat.pkl
192 | │ │ │ └── val_mat.pkl
193 | │ │ ├── gen2/
194 | │ │ │ ├── trn_mat.pkl
195 | │ │ │ ├── val_mat.pkl
196 | │ │ │ └── tst_mat.pkl
197 | │ │ ├── ml10m/
198 | │ │ │ ├── fewshot_mat_1.pkl
199 | │ │ │ ├── trn_mat.pkl.zip ## Unzip it manually ##
200 | │ │ │ ├── tst_mat.pkl.zip ## Unzip it manually ##
201 | │ │ │ └── fewshot_mat_5.pkl
202 | │ │ ├── ml1m/
203 | │ │ │ ├── fewshot_mat_5.pkl
204 | │ │ │ ├── trn_mat.pkl
205 | │ │ │ ├── tst_mat.pkl
206 | │ │ │ └── fewshot_mat_1.pkl
207 | │ │ ├── pubmed/
208 | │ │ │ ├── adj_-1.pkl
209 | │ │ │ ├── adj_1.pkl
210 | │ │ │ ├── feats.pkl.zip ## Unzip it manually ##
211 | │ │ │ ├── label.pkl
212 | │ │ │ ├── mask_-1.pkl
213 | │ │ │ ├── mask_1.pkl
214 | │ │ │ ├── mask_5.pkl
215 | │ │ │ └── adj_5.pkl
216 | │ ├── graph_generation/ ## Code and examples for graph generation ##
217 | │ │ ├── embedding_generation.py ## Node embedding generation ##
218 | │ │ ├── human_item_generation_gibbsSampling_embedEstimation.py ## Edge generation ##
219 | │ │ ├── instance_number_estimation_hierarchical.py ## Estimates the number of instances for each node. Not mentioned in the paper. ##
220 | │ │ ├── itemCollecting_dfsIterator.py ## Node generation ##
221 | │ │ ├── make_adjs.py ## Making datasets for generated graphs ##
222 | │ │ └── Utils.py
223 | │ │ ├── Exp_Utils/
224 | │ │ │ ├── Emailer.py ## A tool to send warning email for experiments ##
225 | │ │ │ └── TimeLogger.py
226 | │ │ ├── gen_results/
227 | │ │ │ ├── tree_wInstanceNum_products_e-commerce platform like Amazon.pkl ## Tree data structure ##
228 | │ │ │ └── products_e-commerce platform like Amazon.txt ## Node list ##
229 | │ │ │ ├── datasets/
230 | │ │ │ │ ├── gen_data_ecommerce/
231 | │ │ │ │ │ ├── embedding_dict.pkl
232 | │ │ │ │ │ ├── item_list.pkl
233 | │ │ │ │ │ └── interaction_base-0_iter-0.pkl ## generated edges ##
234 | │ │ │ │ │ ├── res/
235 | │ │ │ │ │ │ ├── iter-0_imap.pkl ## Id map for nodes ##
236 | │ │ │ │ │ │ ├── iter-0_test.pkl
237 | │ │ │ │ │ │ ├── iter-0_train.pkl
238 | │ │ │ │ │ │ ├── iter-0_valid.pkl
239 | │ │ │ │ │ │ └── interaction_fuse_iter-0.pkl
240 | │ │ │ ├── tem/ ## Temporary files for node generation ##
241 | │ │ │ │ ├── e-commerce platform like Amazon_depth1_products
242 | │ │ │ │ ├── e-commerce platform like Amazon_depth2_products, Automotive
243 | │ │ │ │ ├── e-commerce platform like Amazon_depth2_products, Baby
244 | │ │ │ │ ├── e-commerce platform like Amazon_depth2_products, Beauty
245 | │ │ │ │ ├── e-commerce platform like Amazon_depth2_products, Books
246 | │ │ │ │ ├── e-commerce platform like Amazon_depth2_products, Clothing
247 | │ │ │ │ ├── e-commerce platform like Amazon_depth2_products, Electronics
248 | │ │ │ │ ├── e-commerce platform like Amazon_depth2_products, Handmade
249 | │ │ │ │ ├── e-commerce platform like Amazon_depth2_products, Health and Personal Care
250 | │ │ │ │ ├── e-commerce platform like Amazon_depth2_products, Home Improvement
251 | │ │ │ │ ├── e-commerce platform like Amazon_depth2_products, Industrial and Scientific
252 | │ │ │ │ ├── e-commerce platform like Amazon_depth2_products, Jewelry
253 | │ │ │ │ ├── e-commerce platform like Amazon_depth2_products, Musical Instruments
254 | │ │ │ │ ├── e-commerce platform like Amazon_depth2_products, Office Products
255 | │ │ │ │ ├── e-commerce platform like Amazon_depth2_products, Pet Supplies
256 | │ │ │ │ ├── e-commerce platform like Amazon_depth2_products, Tools and Home Improvement
257 | │ │ │ │ ├── e-commerce platform like Amazon_depth2_products, Toys
258 | │ │ │ │ └── e-commerce platform like Amazon_depth2_products, Sports and Outdoors
259 | │ ├── imgs/ ## Images used in readme ##
260 | │ │ ├── framework.png
261 | │ │ ├── intro.png
262 | │ │ ├── performance.png
263 | │ │ └── article cover.jpg
264 | │ ├── link_prediction/ ## code for link prediction and pre-training ##
265 | │ │ ├── data_handler.py
266 | │ │ ├── main.py
267 | │ │ ├── model.py
268 | │ │ └── params.py
269 | │ │ ├── Utils/
270 | │ │ │ └── TimeLogger.py
271 | │ ├── node_classification/ ## code for testing on node classification ##
272 | │ │ ├── data_handler.py
273 | │ │ ├── main.py
274 | │ │ ├── model.py
275 | │ │ └── params.py
276 | │ │ ├── Utils/
277 | │ │ │ └── TimeLogger.py
278 | ```
279 |
--------------------------------------------------------------------------------
/datasets/amazon-book/fewshot_mat_1.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/amazon-book/fewshot_mat_1.pkl
--------------------------------------------------------------------------------
/datasets/amazon-book/fewshot_mat_5.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/amazon-book/fewshot_mat_5.pkl
--------------------------------------------------------------------------------
/datasets/amazon-book/trn_mat.pkl.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/amazon-book/trn_mat.pkl.zip
--------------------------------------------------------------------------------
/datasets/amazon-book/tst_mat.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/amazon-book/tst_mat.pkl
--------------------------------------------------------------------------------
/datasets/citeseer/adj_-1.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/citeseer/adj_-1.pkl
--------------------------------------------------------------------------------
/datasets/citeseer/adj_1.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/citeseer/adj_1.pkl
--------------------------------------------------------------------------------
/datasets/citeseer/adj_5.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/citeseer/adj_5.pkl
--------------------------------------------------------------------------------
/datasets/citeseer/feats.pkl.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/citeseer/feats.pkl.zip
--------------------------------------------------------------------------------
/datasets/citeseer/label.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/citeseer/label.pkl
--------------------------------------------------------------------------------
/datasets/citeseer/mask_-1.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/citeseer/mask_-1.pkl
--------------------------------------------------------------------------------
/datasets/citeseer/mask_1.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/citeseer/mask_1.pkl
--------------------------------------------------------------------------------
/datasets/citeseer/mask_5.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/citeseer/mask_5.pkl
--------------------------------------------------------------------------------
/datasets/collab/fewshot_mat_1.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/collab/fewshot_mat_1.pkl
--------------------------------------------------------------------------------
/datasets/collab/fewshot_mat_5.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/collab/fewshot_mat_5.pkl
--------------------------------------------------------------------------------
/datasets/collab/trn_mat.pkl.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/collab/trn_mat.pkl.zip
--------------------------------------------------------------------------------
/datasets/collab/tst_mat.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/collab/tst_mat.pkl
--------------------------------------------------------------------------------
/datasets/collab/val_mat.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/collab/val_mat.pkl
--------------------------------------------------------------------------------
/datasets/cora/adj_-1.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/cora/adj_-1.pkl
--------------------------------------------------------------------------------
/datasets/cora/adj_1.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/cora/adj_1.pkl
--------------------------------------------------------------------------------
/datasets/cora/adj_5.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/cora/adj_5.pkl
--------------------------------------------------------------------------------
/datasets/cora/feats.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/cora/feats.pkl
--------------------------------------------------------------------------------
/datasets/cora/label.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/cora/label.pkl
--------------------------------------------------------------------------------
/datasets/cora/mask_-1.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/cora/mask_-1.pkl
--------------------------------------------------------------------------------
/datasets/cora/mask_1.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/cora/mask_1.pkl
--------------------------------------------------------------------------------
/datasets/cora/mask_5.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/cora/mask_5.pkl
--------------------------------------------------------------------------------
/datasets/ddi/fewshot_mat_1.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/ddi/fewshot_mat_1.pkl
--------------------------------------------------------------------------------
/datasets/ddi/fewshot_mat_5.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/ddi/fewshot_mat_5.pkl
--------------------------------------------------------------------------------
/datasets/ddi/trn_mat.pkl.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/ddi/trn_mat.pkl.zip
--------------------------------------------------------------------------------
/datasets/ddi/tst_mat.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/ddi/tst_mat.pkl
--------------------------------------------------------------------------------
/datasets/ddi/val_mat.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/ddi/val_mat.pkl
--------------------------------------------------------------------------------
/datasets/gen0/trn_mat.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/gen0/trn_mat.pkl
--------------------------------------------------------------------------------
/datasets/gen0/tst_mat.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/gen0/tst_mat.pkl
--------------------------------------------------------------------------------
/datasets/gen0/val_mat.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/gen0/val_mat.pkl
--------------------------------------------------------------------------------
/datasets/gen1/trn_mat.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/gen1/trn_mat.pkl
--------------------------------------------------------------------------------
/datasets/gen1/tst_mat.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/gen1/tst_mat.pkl
--------------------------------------------------------------------------------
/datasets/gen1/val_mat.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/gen1/val_mat.pkl
--------------------------------------------------------------------------------
/datasets/gen2/trn_mat.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/gen2/trn_mat.pkl
--------------------------------------------------------------------------------
/datasets/gen2/tst_mat.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/gen2/tst_mat.pkl
--------------------------------------------------------------------------------
/datasets/gen2/val_mat.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/gen2/val_mat.pkl
--------------------------------------------------------------------------------
/datasets/ml10m/fewshot_mat_1.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/ml10m/fewshot_mat_1.pkl
--------------------------------------------------------------------------------
/datasets/ml10m/fewshot_mat_5.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/ml10m/fewshot_mat_5.pkl
--------------------------------------------------------------------------------
/datasets/ml10m/trn_mat.pkl.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/ml10m/trn_mat.pkl.zip
--------------------------------------------------------------------------------
/datasets/ml10m/tst_mat.pkl.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/ml10m/tst_mat.pkl.zip
--------------------------------------------------------------------------------
/datasets/ml1m/fewshot_mat_1.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/ml1m/fewshot_mat_1.pkl
--------------------------------------------------------------------------------
/datasets/ml1m/fewshot_mat_5.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/ml1m/fewshot_mat_5.pkl
--------------------------------------------------------------------------------
/datasets/ml1m/trn_mat.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/ml1m/trn_mat.pkl
--------------------------------------------------------------------------------
/datasets/ml1m/tst_mat.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/ml1m/tst_mat.pkl
--------------------------------------------------------------------------------
/datasets/pubmed/adj_-1.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/pubmed/adj_-1.pkl
--------------------------------------------------------------------------------
/datasets/pubmed/adj_1.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/pubmed/adj_1.pkl
--------------------------------------------------------------------------------
/datasets/pubmed/adj_5.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/pubmed/adj_5.pkl
--------------------------------------------------------------------------------
/datasets/pubmed/feats.pkl.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/pubmed/feats.pkl.zip
--------------------------------------------------------------------------------
/datasets/pubmed/label.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/pubmed/label.pkl
--------------------------------------------------------------------------------
/datasets/pubmed/mask_-1.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/pubmed/mask_-1.pkl
--------------------------------------------------------------------------------
/datasets/pubmed/mask_1.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/pubmed/mask_1.pkl
--------------------------------------------------------------------------------
/datasets/pubmed/mask_5.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/datasets/pubmed/mask_5.pkl
--------------------------------------------------------------------------------
/graph_generation/Exp_Utils/Emailer.py:
--------------------------------------------------------------------------------
1 | import smtplib
2 | from email.mime.text import MIMEText
3 | from email.header import Header
4 |
5 | def SendMail(message, subject=None):
6 | return  # NOTE: this early return disables sending; remove it after filling in the SMTP settings below to enable email reports
7 | default_subject = 'Experimental Anomaly Report'
8 | if subject != None:
9 | default_subject += ':' + subject
10 | subject = default_subject
11 |
12 | sender = 'xxx@xx.com'
13 | receivers = ['xxx@xx.com']
14 | message = 'Dear Artificial Anomaly Investigator,\n\n' + 'I am writing to bring to your attention that an anomaly occurs and cannot be solved by your exception handler in the recent experiment. Please refer to the following message for details.\n\n' + message + '\n\nBest regards,\nIntelligent Experiment Assistant'
15 |
16 | message = MIMEText(message, 'plain', 'utf-8')
17 | message['Subject'] = Header(subject, 'utf-8')
18 | message['From'] = Header('intel_assistant')
19 | message['To'] = Header('investigator')
20 |
21 | mail_host = 'smtp.xxx.com'
22 | mail_user = 'xxx@xxx.com'
23 | mail_pass = 'xxx'
24 |
25 | smtpObj = smtplib.SMTP()
26 | smtpObj.connect(mail_host, 25)
27 | smtpObj.login(mail_user, mail_pass)
28 | smtpObj.sendmail(sender, receivers, message.as_string())
--------------------------------------------------------------------------------
/graph_generation/Exp_Utils/TimeLogger.py:
--------------------------------------------------------------------------------
1 | import datetime
2 |
3 | logmsg = ''
4 | timemark = dict()
5 | saveDefault = False
6 | def log(msg, save=None, oneline=False):
7 | global logmsg
8 | global saveDefault
9 | time = datetime.datetime.now()
10 | tem = '%s: %s' % (time, msg)
11 | if save != None:
12 | if save:
13 | logmsg += tem + '\n'
14 | elif saveDefault:
15 | logmsg += tem + '\n'
16 | if oneline:
17 | print(tem, end='\r')
18 | else:
19 | print(tem)
20 |
21 | def marktime(marker):
22 | global timemark
23 | timemark[marker] = datetime.datetime.now()
24 |
25 |
26 | if __name__ == '__main__':
27 | log('')
--------------------------------------------------------------------------------
/graph_generation/Utils.py:
--------------------------------------------------------------------------------
1 | import time
2 | import openai
3 | import json
4 | import tiktoken
5 | import numpy as np
6 | import Exp_Utils.TimeLogger as logger
7 | from Exp_Utils.TimeLogger import log
8 | from Exp_Utils.Emailer import SendMail
9 | import time
10 |
11 | openai.api_key = "xx-xxxxxx"
12 |
13 | class DataGenAgent:
14 | def __init__(self):
15 | super(DataGenAgent, self).__init__()
16 | self.token_num = self.failure = 0  # initialize the failure counter used by handling_llm_exceptions
17 | self.encoding = tiktoken.encoding_for_model('gpt-3.5-turbo')
18 |
19 | def openai_embedding(self, message):
20 | try:
21 | embedding = openai.Embedding.create(
22 | model='text-embedding-ada-002',
23 | input = message
24 | )['data'][0]['embedding']
25 | # time.sleep()
26 | return np.array(embedding)
27 | except Exception as e:
28 | print('OpenAI request error: {err_msg}. Retry in 10 seconds.'.format(err_msg=e))
29 | time.sleep(10)
30 | return self.openai_embedding(message)
31 |
32 | def openai(self, message):
33 | try:
34 | completion = openai.ChatCompletion.create(
35 | model='gpt-3.5-turbo-1106',
36 | # model='gpt-4',
37 | messages=[
38 | {"role": "user", "content": message},
39 | ]
40 | )
41 | response = completion.choices[0].message.content
42 | time.sleep(1)
43 | self.token_num += len(self.encoding.encode(json.dumps(message)))
44 | return response
45 | except Exception as e:
46 | print('OpenAI request error: {err_msg}. Retry in 10 seconds.'.format(err_msg=e))
47 | time.sleep(10)
48 | return self.openai(message)
49 |
50 | def handling_llm_exceptions(self, message, interpret_func, interpret_args, failure_tolerance):
51 | try:
52 | answers_text = self.openai(message)
53 | print('Answers text:')
54 | print(answers_text)
55 | print('----------\n')
56 | return 0, interpret_func(answers_text, *interpret_args)
57 | except Exception as e:
58 | self.failure += 1
59 | log('\n**********\nERROR\n')
60 | log('Exception occurs when interpreting. Exception message: {exception}'.format(exception=e), save=True)
61 | log('Failure times: {failure}'.format(failure=self.failure), save=True)
62 | log('Prompt text:\n{prompt}'.format(prompt=message), save=True)
63 | log('Response text:\n{response}'.format(response=answers_text), save=True)
64 | if self.failure < failure_tolerance:
65 | log('Retry in 10 seconds.', save=True)
66 | time.sleep(10)
67 | log('\n**********\n')
68 | return 1, None
69 | else:
70 | log('Reaching maximum failure tolerance. CANNOT HANDLE!', save=True)
71 | log('Sending report email.', save=True)
72 | SendMail(logger.logmsg)
73 | logger.logmsg = ''
74 | log('\n**********\n')
75 | return 2, None
76 |
77 | class EntityTreeNode:
78 | def __init__(self, entity_name, depth, parent=None):
79 | self.entity_name = entity_name
80 | self.frequency = []
81 | self.quantity = -1
82 | self.children = dict()
83 | self.parent = parent
84 | self.depth = depth
85 |
86 | def is_child(self, entity_name):
87 | return entity_name in self.children
88 |
89 | def to_child(self, entity_name):
90 | return self.children[entity_name]
91 |
92 | def add_child(self, entity_name):
93 | child = EntityTreeNode(entity_name, self.depth+1, self)
94 | self.children[entity_name] = child
95 |
96 | def iterate_children(self):
97 | for key, node in self.children.items():
98 | yield key, node
99 |
100 | def allocate_number(self, quantity):
101 | print('Allocating depth {depth} {entity_name}, quantity: {quantity}'.format(depth=self.depth, entity_name=self.entity_name, quantity=quantity))
102 | self.quantity = quantity
103 | if len(self.children) == 0:
104 | return
105 | child_list = list(self.children.values())
106 | child_freq = list(map(lambda x: x.frequency, child_list))
107 | child_freq = np.array(child_freq) # N * T
108 | if child_freq.shape[1] == 0:
109 | raise Exception('No estimated frequency for children.')
110 | summ = np.sum(child_freq, axis=0, keepdims=True) # 1 * T
111 | child_freq = child_freq / summ # N * T
112 | child_num = np.mean(child_freq, axis=1) * self.quantity # N
113 | for i, child in enumerate(child_list):
114 | child.allocate_number(child_num[i])
115 |
116 | def get_list_of_leaves(self, entity_name, with_branches=False):
117 | if len(self.children) == 0:
118 | num = max(1, int(self.quantity))
119 | entity_list = list()
120 | cur_entity_name = entity_name + ', ' + self.entity_name
121 | for i in range(num):
122 | # entity_list.append(cur_entity_name + ' #{idx}'.format(idx=i))
123 | entity_list.append(self.entity_name + ' #{idx}'.format(idx=i))
124 | return entity_list
125 | entity_list = list()
126 | if with_branches:
127 | if self.depth <= 2:
128 | tem_entity_name = self.entity_name
129 | else:
130 | tem_entity_name = entity_name + ', ' + self.entity_name
131 | entity_list.append(tem_entity_name)
132 | for _, child in self.iterate_children():
133 | nxt_entity_name = self.entity_name if self.depth <= 2 else (entity_name + ', ' + self.entity_name)
134 | entities = child.get_list_of_leaves(nxt_entity_name, with_branches)
135 | entity_list = entity_list + entities
136 | return entity_list
137 |
138 | class EntityTreeConstructer:
139 | def __init__(self, entity_lines):
140 | super(EntityTreeConstructer, self).__init__()
141 |
142 | root_name = self.line_process(entity_lines[0])[0]
143 | self.root = EntityTreeNode(root_name, depth=1)
144 | self.root.frequency.append(1.0)
145 | self.construct_tree(entity_lines)
146 |
147 | def add_node(self, cur_node, descriptions, cur):
148 | parent_entity_name = descriptions[cur-1]
149 | if cur_node.entity_name != parent_entity_name:
150 | print(cur_node.entity_name, parent_entity_name)
151 | print(descriptions)
152 | assert cur_node.entity_name == parent_entity_name
153 | cur_entity_name = descriptions[cur]
154 | if not cur_node.is_child(cur_entity_name):
155 | cur_node.add_child(cur_entity_name)
156 | if cur + 1 < len(descriptions):
157 | self.add_node(cur_node.to_child(cur_entity_name), descriptions, cur+1)
158 |
159 | def line_process(self, entity_line, check=False):
160 | entity_line = entity_line.strip()
161 | descriptions = list(map(lambda x: x.strip(), entity_line.split(',')))
162 | if not check:
163 | return descriptions
164 | if descriptions[0] != self.root.entity_name:
165 | raise Exception('Cannot find root')
166 | if len(descriptions) <= 1:
167 | raise Exception('Fail to split')
168 | return descriptions
169 |
170 | def construct_tree(self, entity_lines):
171 | for entity_line in entity_lines:
172 | try:
173 | descriptions = self.line_process(entity_line, check=True)
174 | except Exception as e:
175 | print(str(e), ':', entity_line)
176 | continue
177 | self.add_node(self.root, descriptions, cur=1)
--------------------------------------------------------------------------------
/graph_generation/embedding_generation.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import os
3 | from Utils import DataGenAgent
4 | from Exp_Utils.TimeLogger import log
5 |
6 | def load_item_list(item_file, entity_file, item_num):
7 | if not os.path.exists(item_file):
8 | with open(entity_file, 'rb') as fs:
9 | entity_tree_root = pickle.load(fs)
10 | entity_tree_root.allocate_number(item_num)
11 | item_list = entity_tree_root.get_list_of_leaves('')
12 | with open(item_file, 'wb+') as fs:
13 | pickle.dump(item_list, fs)
14 | else:
15 | with open(item_file, 'rb') as fs:
16 | item_list = pickle.load(fs)
17 | return item_list
18 |
19 | descs = {  # several example configurations follow; only the last descs assignment (gen_data_ecommerce) takes effect
20 | 'data_name': 'gen_data_venues',
21 | 'scenario_desc': 'venue rating platform like yelp',
22 | 'human_role': 'user',
23 | 'interaction_verb': 'interact',
24 | 'initial_entity': 'business venues',
25 | }
26 | descs = {
27 | 'data_name': 'gen_data_books',
28 | 'scenario_desc': 'book rating platform',
29 | 'human_role': 'user',
30 | 'interaction_verb': 'interact',
31 | 'initial_entity': 'books',
32 | }
33 | descs = {
34 | 'data_name': 'gen_data_ai_papers',
35 | 'scenario_desc': 'published paper list of top AI conferences',
36 | 'human_role': 'user',
37 | 'interaction_verb': 'interact',
38 | 'initial_entity': 'deep learning papers',
39 | }
40 | descs = {
41 | 'data_name': 'gen_data_ecommerce',
42 | 'scenario_desc': 'e-commerce platform like Amazon',
43 | 'human_role': 'user',
44 | 'interaction_verb': 'interact',
45 | 'initial_entity': 'products',
46 | }
47 | file_root = 'gen_results/datasets/{data_name}/'.format(data_name=descs['data_name'])
48 | entity_file = 'gen_results/tree_wInstanceNum_{initial_entity}_{scenario}.pkl'.format(initial_entity=descs['initial_entity'], scenario=descs['scenario_desc'])
49 | embed_file = file_root + 'embedding_dict.pkl'
50 |
51 | with open(entity_file, 'rb') as fs:
52 | entity_tree_root = pickle.load(fs)
53 | entity_tree_root.allocate_number(1)
54 | item_list = entity_tree_root.get_list_of_leaves('', with_branches=True)
55 | item_list = list(map(lambda item_name: item_name if ' #' not in item_name else item_name[:item_name.index(' #')], item_list))
56 | print(item_list)
57 | print('Num of items', len(item_list))
58 | agent = DataGenAgent()
59 | embedding_dict = dict()
60 | for i, item in enumerate(item_list):
61 | log('{idx} / {tot}'.format(idx=i, tot=len(item_list)))
62 | embedding = agent.openai_embedding(item)
63 | embedding_dict[item] = embedding
64 | with open(embed_file, 'wb') as fs:
65 | pickle.dump(embedding_dict, fs)
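# --- Editor's sketch (not part of the original script). The embedding_dict.pkl written
# above maps each item name (its category path with the ' #k' suffix stripped) to the
# embedding returned by DataGenAgent.openai_embedding; the Gibbs-sampling script later
# looks items up by name and compares them with a plain dot product. A minimal way to
# inspect the file, assuming it has already been generated:
#
#     import pickle
#     import numpy as np
#     with open('gen_results/datasets/gen_data_ecommerce/embedding_dict.pkl', 'rb') as fs:
#         embedding_dict = pickle.load(fs)
#     key_a, key_b = list(embedding_dict)[:2]  # any two stored item names
#     sim = np.sum(np.array(embedding_dict[key_a]) * np.array(embedding_dict[key_b]))
#     print(key_a, '<->', key_b, 'dot-product similarity:', sim)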
--------------------------------------------------------------------------------
/graph_generation/gen_results/datasets/gen_data_ecommerce/embedding_dict.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/graph_generation/gen_results/datasets/gen_data_ecommerce/embedding_dict.pkl
--------------------------------------------------------------------------------
/graph_generation/gen_results/datasets/gen_data_ecommerce/interaction_base-0_iter-0.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/graph_generation/gen_results/datasets/gen_data_ecommerce/interaction_base-0_iter-0.pkl
--------------------------------------------------------------------------------
/graph_generation/gen_results/datasets/gen_data_ecommerce/item_list.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/graph_generation/gen_results/datasets/gen_data_ecommerce/item_list.pkl
--------------------------------------------------------------------------------
/graph_generation/gen_results/datasets/gen_data_ecommerce/res/interaction_fuse_iter-0.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/graph_generation/gen_results/datasets/gen_data_ecommerce/res/interaction_fuse_iter-0.pkl
--------------------------------------------------------------------------------
/graph_generation/gen_results/datasets/gen_data_ecommerce/res/iter-0_imap.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/graph_generation/gen_results/datasets/gen_data_ecommerce/res/iter-0_imap.pkl
--------------------------------------------------------------------------------
/graph_generation/gen_results/datasets/gen_data_ecommerce/res/iter-0_test.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/graph_generation/gen_results/datasets/gen_data_ecommerce/res/iter-0_test.pkl
--------------------------------------------------------------------------------
/graph_generation/gen_results/datasets/gen_data_ecommerce/res/iter-0_train.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/graph_generation/gen_results/datasets/gen_data_ecommerce/res/iter-0_train.pkl
--------------------------------------------------------------------------------
/graph_generation/gen_results/datasets/gen_data_ecommerce/res/iter-0_valid.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/graph_generation/gen_results/datasets/gen_data_ecommerce/res/iter-0_valid.pkl
--------------------------------------------------------------------------------
/graph_generation/gen_results/products_e-commerce platform like Amazon.txt:
--------------------------------------------------------------------------------
1 | products, Electronics, Headphones
2 | products, Electronics, Speakers
3 | products, Electronics, Cameras
4 | products, Electronics, Smartphones
5 | products, Electronics, Smartwatches
6 | products, Electronics, Tablets
7 | products, Electronics, Laptops
8 | products, Electronics, Desktop Computers
9 | products, Electronics, Home Security Systems
10 | products, Electronics, Drones
11 | products, Electronics, Gaming Consoles
12 | products, Electronics, Virtual Reality Headsets
13 | products, Electronics, GPS Devices
14 | products, Electronics, E-book Readers
15 | products, Clothing, Women's clothing
16 | products, Clothing, Men's clothing
17 | products, Clothing, Kid's clothing
18 | products, Clothing, Shoes
19 | products, Clothing, Accessories
20 | products, Books, Fiction
21 | products, Books, Non-fiction
22 | products, Books, Mystery
23 | products, Books, Science Fiction
24 | products, Books, Romance
25 | products, Books, Fantasy
26 | products, Books, Thriller
27 | products, Books, Biography
28 | products, Books, History
29 | products, Books, Children's Books
30 | products, Books, Self-help
31 | products, Books, Cookbook
32 | products, Books, Art
33 | products, Books, Travel
34 | products, Books, Science
35 | products, Books, Religion
36 | products, Books, Reference
37 | products, Books, Poetry
38 | products, Books, Graphic Novels
39 | products, Books, Horror
40 | products, Books, Young Adult
41 | products, Beauty, Makeup
42 | products, Beauty, Skincare
43 | products, Beauty, Haircare
44 | products, Beauty, Fragrance
45 | products, Beauty, Bath & Body
46 | products, Beauty, Tools & Accessories
47 | products, Beauty, Men's Grooming
48 | products, Beauty, Beauty Supplements
49 | products, Home Improvement, Power tools
50 | products, Home Improvement, Hand tools
51 | products, Home Improvement, Electrical
52 | products, Home Improvement, Hardware
53 | products, Home Improvement, Plumbing
54 | products, Home Improvement, Building supplies
55 | products, Home Improvement, Paint
56 | products, Home Improvement, Heating
57 | products, Home Improvement, Cooling
58 | products, Home Improvement, Home security
59 | products, Home Improvement, Lighting
60 | products, Home Improvement, Storage
61 | products, Home Improvement, Cleaning supplies
62 | products, Home Improvement, Appliances
63 | products, Toys, Action Figures
64 | products, Toys, Dolls
65 | products, Toys, Stuffed Animals
66 | products, Toys, Building Sets
67 | products, Toys, Arts & Crafts
68 | products, Toys, Vehicles
69 | products, Toys, Puzzles
70 | products, Toys, Outdoor Play
71 | products, Toys, Learning & Education
72 | products, Toys, Games
73 | products, Toys, Play Food & Grocery
74 | products, Toys, Electronic Toys
75 | products, Toys, Musical Instruments
76 | products, Automotive, Car parts
77 | products, Automotive, Car accessories
78 | products, Automotive, Car care products
79 | products, Automotive, Car electronics
80 | products, Automotive, Car interior accessories
81 | products, Automotive, Car exterior accessories
82 | products, Automotive, Car performance parts
83 | products, Automotive, Motorcycle parts
84 | products, Automotive, Motorcycle accessories
85 | products, Sports and Outdoors, Camping and Hiking
86 | products, Sports and Outdoors, Exercise and Fitness
87 | products, Sports and Outdoors, Golf
88 | products, Sports and Outdoors, Hunting and Fishing
89 | products, Sports and Outdoors, Outdoor Recreation
90 | products, Sports and Outdoors, Team Sports
91 | products, Sports and Outdoors, Water Sports
92 | products, Sports and Outdoors, Winter Sports
93 | products, Health and Personal Care, Vitamins
94 | products, Health and Personal Care, Supplements
95 | products, Health and Personal Care, Health Care
96 | products, Health and Personal Care, Baby & Child Care
97 | products, Health and Personal Care, Sports Nutrition
98 | products, Health and Personal Care, Household Supplies
99 | products, Health and Personal Care, Medical Supplies & Equipment
100 | products, Health and Personal Care, Sexual Wellness
101 | products, Health and Personal Care, Health & Wellness
102 | products, Health and Personal Care, Oral Care
103 | products, Health and Personal Care, Beauty & Grooming
104 | products, Health and Personal Care, Pet Supplies
105 | products, Health and Personal Care, Health & Household
106 | products, Health and Personal Care, Baby & Child Care
107 | products, Health and Personal Care, Health Supplies
108 | products, Jewelry, Rings
109 | products, Jewelry, Necklaces
110 | products, Jewelry, Bracelets
111 | products, Jewelry, Earrings
112 | products, Jewelry, Pendants
113 | products, Jewelry, Brooches
114 | products, Jewelry, Charms
115 | products, Jewelry, Anklets
116 | products, Jewelry, Jewelry Sets
117 | products, Jewelry, Body Jewelry
118 | products, Jewelry, Hair Jewelry
119 | products, Jewelry, Toe Rings
120 | products, Jewelry, Cufflinks
121 | products, Jewelry, Tie Clips
122 | products, Pet Supplies, Dog food
123 | products, Pet Supplies, Cat food
124 | products, Pet Supplies, Dog treats
125 | products, Pet Supplies, Cat treats
126 | products, Pet Supplies, Cat litter
127 | products, Pet Supplies, Dog grooming supplies
128 | products, Pet Supplies, Cat grooming supplies
129 | products, Pet Supplies, Dog toys
130 | products, Pet Supplies, Cat toys
131 | products, Pet Supplies, Dog collars and leashes
132 | products, Pet Supplies, Cat collars and leashes
133 | products, Pet Supplies, Fish supplies
134 | products, Pet Supplies, Small animal supplies
135 | products, Pet Supplies, Bird supplies
136 | products, Baby, Baby girls' clothing
137 | products, Baby, baby boys' clothing
138 | products, Baby, baby shoes
139 | products, Baby, baby accessories
140 | products, Baby, baby gear
141 | products, Baby, baby feeding
142 | products, Baby, baby bathing
143 | products, Baby, baby nursery
144 | products, Baby, baby toys
145 | products, Office Products, Office furniture
146 | products, Office Products, Office lighting
147 | products, Office Products, Desk accessories
148 | products, Office Products, Filing products
149 | products, Office Products, Office electronics
150 | products, Office Products, Paper products
151 | products, Office Products, Writing instruments
152 | products, Office Products, Presentation supplies
153 | products, Office Products, Office storage
154 | products, Office Products, Office maintenance
155 | products, Office Products, Office organization
156 | products, Tools and Home Improvement, Power Tools
157 | products, Tools and Home Improvement, Hand Tools
158 | products, Tools and Home Improvement, Tool Organizers
159 | products, Tools and Home Improvement, Tool Sets
160 | products, Tools and Home Improvement, Measuring and Layout Tools
161 | products, Tools and Home Improvement, Welding and Soldering
162 | products, Tools and Home Improvement, Tool Storage
163 | products, Tools and Home Improvement, Material Handling
164 | products, Tools and Home Improvement, Jobsite Safety
165 | products, Tools and Home Improvement, Flashlights
166 | products, Tools and Home Improvement, Jobsite Radios
167 | products, Tools and Home Improvement, Welding Equipment
168 | products, Tools and Home Improvement, Woodworking
169 | products, Tools and Home Improvement, Workshop Equipment
170 | products, Tools and Home Improvement, Air Tools
171 | products, Tools and Home Improvement, Welding and Soldering
172 | products, Tools and Home Improvement, Power Finishing Tools
173 | products, Tools and Home Improvement, Fireplace and Stove Accessories
174 | products, Tools and Home Improvement, Abrasive and Finishing Products
175 | products, Tools and Home Improvement, Building Supplies
176 | products, Tools and Home Improvement, Cleaning Supplies
177 | products, Tools and Home Improvement, Electrical
178 | products, Tools and Home Improvement, Hardware
179 | products, Tools and Home Improvement, Kitchen and Bath Fixtures
180 | products, Tools and Home Improvement, Light Bulbs
181 | products, Tools and Home Improvement, Lighting and Ceiling Fans
182 | products, Tools and Home Improvement, Measuring and Layout Tools
183 | products, Tools and Home Improvement, Paint
184 | products, Tools and Home Improvement, Wall Treatments and Supplies
185 | products, Tools and Home Improvement, Power and Hand Tools
186 | products, Tools and Home Improvement, Rough Plumbing
187 | products, Tools and Home Improvement, Safety and Security
188 | products, Tools and Home Improvement, Storage and Home Organization
189 | products, Tools and Home Improvement, Winterization products
190 | Outdoor Power Tools
191 | products, Tools and Home Improvement, Hand Tools
192 | products, Tools and Home Improvement, Power Tools
193 | products, Tools and Home Improvement, Painting Supplies and Wall Treatments
194 | products, Tools and Home Improvement, Electrical
195 | products, Tools and Home Improvement, Hardware
196 | products, Tools and Home Improvement, Kitchen and Bath Fixtures
197 | products, Tools and Home Improvement, Light Bulbs
198 | products, Tools and Home Improvement, Light Fixtures
199 | products, Tools and Home Improvement, Measuring and Layout Tools
200 | products, Tools and Home Improvement, Painting Supplies and Wall Treatments
201 | products, Tools and Home Improvement, Power and Hand Tools
202 | products, Tools and Home Improvement, Rough Plumbing
203 | products, Tools and Home Improvement, Safety and Security
204 | products, Tools and Home Improvement, Storage and Home Organization
205 | products, Tools and Home Improvement, Building Supplies
206 | products, Tools and Home Improvement, Rough Plumbing
207 | products, Tools and Home Improvement, Building Materials
208 | products, Tools and Home Improvement, Material Handling
209 | products, Tools and Home Improvement, Measuring and Layout Tools
210 | products, Tools and Home Improvement, Jobsite Safety
211 | products, Tools and Home Improvement, Work Wear and Safety Gear
212 | products, Tools and Home Improvement, Air Tools & Compressors
213 | products, Musical Instruments, Amplifiers & Effects
214 | products, Musical Instruments, Band & Orchestra
215 | products, Musical Instruments, Drums & Percussion
216 | products, Musical Instruments, Folk & World Instruments
217 | products, Musical Instruments, Guitars
218 | products, Musical Instruments, Keyboards
219 | products, Musical Instruments, Live Sound & Stage
220 | products, Musical Instruments, Microphones & Accessories
221 | products, Musical Instruments, Recording Equipment
222 | products, Industrial and Scientific, Industrial Electrical
223 | products, Industrial and Scientific, Occupational Health & Safety Products
224 | products, Industrial and Scientific, Test
225 | products, Industrial and Scientific, Measure & Inspect
226 | products, Industrial and Scientific, Abrasive & Finishing Products
227 | products, Industrial and Scientific, Janitorial & Sanitation Supplies
228 | products, Industrial and Scientific, Industrial Hardware
229 | products, Handmade, Jewelry
230 | products, Handmade, Clothing
231 | products, Handmade, Home decor
232 | products, Handmade, Art
233 | products, Handmade, Stationery
234 | products, Handmade, Toys
235 | products, Handmade, Beauty products
236 | products, Handmade, Accessories
237 | products, Handmade, Pet supplies
238 | products, Handmade, Craft supplies
239 |
--------------------------------------------------------------------------------
/graph_generation/gen_results/tem/e-commerce platform like Amazon_depth1_products:
--------------------------------------------------------------------------------
1 | products, Electronics, Headphones
2 | products, Electronics, Speakers
3 | products, Electronics, Cameras
4 | products, Electronics, Smartphones
5 | products, Electronics, Smartwatches
6 | products, Electronics, Tablets
7 | products, Electronics, Laptops
8 | products, Electronics, Desktop Computers
9 | products, Electronics, Home Security Systems
10 | products, Electronics, Drones
11 | products, Electronics, Gaming Consoles
12 | products, Electronics, Virtual Reality Headsets
13 | products, Electronics, GPS Devices
14 | products, Electronics, E-book Readers
15 | products, Clothing, Women's clothing
16 | products, Clothing, Men's clothing
17 | products, Clothing, Kid's clothing
18 | products, Clothing, Shoes
19 | products, Clothing, Accessories
20 | products, Books, Fiction
21 | products, Books, Non-fiction
22 | products, Books, Mystery
23 | products, Books, Science Fiction
24 | products, Books, Romance
25 | products, Books, Fantasy
26 | products, Books, Thriller
27 | products, Books, Biography
28 | products, Books, History
29 | products, Books, Children's Books
30 | products, Books, Self-help
31 | products, Books, Cookbook
32 | products, Books, Art
33 | products, Books, Travel
34 | products, Books, Science
35 | products, Books, Religion
36 | products, Books, Reference
37 | products, Books, Poetry
38 | products, Books, Graphic Novels
39 | products, Books, Horror
40 | products, Books, Young Adult
41 | products, Beauty, Makeup
42 | products, Beauty, Skincare
43 | products, Beauty, Haircare
44 | products, Beauty, Fragrance
45 | products, Beauty, Bath & Body
46 | products, Beauty, Tools & Accessories
47 | products, Beauty, Men's Grooming
48 | products, Beauty, Beauty Supplements
49 | products, Home Improvement, Power tools
50 | products, Home Improvement, Hand tools
51 | products, Home Improvement, Electrical
52 | products, Home Improvement, Hardware
53 | products, Home Improvement, Plumbing
54 | products, Home Improvement, Building supplies
55 | products, Home Improvement, Paint
56 | products, Home Improvement, Heating
57 | products, Home Improvement, Cooling
58 | products, Home Improvement, Home security
59 | products, Home Improvement, Lighting
60 | products, Home Improvement, Storage
61 | products, Home Improvement, Cleaning supplies
62 | products, Home Improvement, Appliances
63 | products, Toys, Action Figures
64 | products, Toys, Dolls
65 | products, Toys, Stuffed Animals
66 | products, Toys, Building Sets
67 | products, Toys, Arts & Crafts
68 | products, Toys, Vehicles
69 | products, Toys, Puzzles
70 | products, Toys, Outdoor Play
71 | products, Toys, Learning & Education
72 | products, Toys, Games
73 | products, Toys, Play Food & Grocery
74 | products, Toys, Electronic Toys
75 | products, Toys, Musical Instruments
76 | products, Automotive, Car parts
77 | products, Automotive, Car accessories
78 | products, Automotive, Car care products
79 | products, Automotive, Car electronics
80 | products, Automotive, Car interior accessories
81 | products, Automotive, Car exterior accessories
82 | products, Automotive, Car performance parts
83 | products, Automotive, Motorcycle parts
84 | products, Automotive, Motorcycle accessories
85 | products, Sports and Outdoors, Camping and Hiking
86 | products, Sports and Outdoors, Exercise and Fitness
87 | products, Sports and Outdoors, Golf
88 | products, Sports and Outdoors, Hunting and Fishing
89 | products, Sports and Outdoors, Outdoor Recreation
90 | products, Sports and Outdoors, Team Sports
91 | products, Sports and Outdoors, Water Sports
92 | products, Sports and Outdoors, Winter Sports
93 | products, Health and Personal Care, Vitamins
94 | products, Health and Personal Care, Supplements
95 | products, Health and Personal Care, Health Care
96 | products, Health and Personal Care, Baby & Child Care
97 | products, Health and Personal Care, Sports Nutrition
98 | products, Health and Personal Care, Household Supplies
99 | products, Health and Personal Care, Medical Supplies & Equipment
100 | products, Health and Personal Care, Sexual Wellness
101 | products, Health and Personal Care, Health & Wellness
102 | products, Health and Personal Care, Oral Care
103 | products, Health and Personal Care, Beauty & Grooming
104 | products, Health and Personal Care, Pet Supplies
105 | products, Health and Personal Care, Health & Household
106 | products, Health and Personal Care, Baby & Child Care
107 | products, Health and Personal Care, Health Supplies
108 | products, Jewelry, Rings
109 | products, Jewelry, Necklaces
110 | products, Jewelry, Bracelets
111 | products, Jewelry, Earrings
112 | products, Jewelry, Pendants
113 | products, Jewelry, Brooches
114 | products, Jewelry, Charms
115 | products, Jewelry, Anklets
116 | products, Jewelry, Jewelry Sets
117 | products, Jewelry, Body Jewelry
118 | products, Jewelry, Hair Jewelry
119 | products, Jewelry, Toe Rings
120 | products, Jewelry, Cufflinks
121 | products, Jewelry, Tie Clips
122 | products, Pet Supplies, Dog food
123 | products, Pet Supplies, Cat food
124 | products, Pet Supplies, Dog treats
125 | products, Pet Supplies, Cat treats
126 | products, Pet Supplies, Cat litter
127 | products, Pet Supplies, Dog grooming supplies
128 | products, Pet Supplies, Cat grooming supplies
129 | products, Pet Supplies, Dog toys
130 | products, Pet Supplies, Cat toys
131 | products, Pet Supplies, Dog collars and leashes
132 | products, Pet Supplies, Cat collars and leashes
133 | products, Pet Supplies, Fish supplies
134 | products, Pet Supplies, Small animal supplies
135 | products, Pet Supplies, Bird supplies
136 | products, Baby, Baby girls' clothing
137 | products, Baby, baby boys' clothing
138 | products, Baby, baby shoes
139 | products, Baby, baby accessories
140 | products, Baby, baby gear
141 | products, Baby, baby feeding
142 | products, Baby, baby bathing
143 | products, Baby, baby nursery
144 | products, Baby, baby toys
145 | products, Office Products, Office furniture
146 | products, Office Products, Office lighting
147 | products, Office Products, Desk accessories
148 | products, Office Products, Filing products
149 | products, Office Products, Office electronics
150 | products, Office Products, Paper products
151 | products, Office Products, Writing instruments
152 | products, Office Products, Presentation supplies
153 | products, Office Products, Office storage
154 | products, Office Products, Office maintenance
155 | products, Office Products, Office organization
156 | products, Tools and Home Improvement, Power Tools
157 | products, Tools and Home Improvement, Hand Tools
158 | products, Tools and Home Improvement, Tool Organizers
159 | products, Tools and Home Improvement, Tool Sets
160 | products, Tools and Home Improvement, Measuring and Layout Tools
161 | products, Tools and Home Improvement, Welding and Soldering
162 | products, Tools and Home Improvement, Tool Storage
163 | products, Tools and Home Improvement, Material Handling
164 | products, Tools and Home Improvement, Jobsite Safety
165 | products, Tools and Home Improvement, Flashlights
166 | products, Tools and Home Improvement, Jobsite Radios
167 | products, Tools and Home Improvement, Welding Equipment
168 | products, Tools and Home Improvement, Woodworking
169 | products, Tools and Home Improvement, Workshop Equipment
170 | products, Tools and Home Improvement, Air Tools
171 | products, Tools and Home Improvement, Welding and Soldering
172 | products, Tools and Home Improvement, Power Finishing Tools
173 | products, Tools and Home Improvement, Fireplace and Stove Accessories
174 | products, Tools and Home Improvement, Abrasive and Finishing Products
175 | products, Tools and Home Improvement, Building Supplies
176 | products, Tools and Home Improvement, Cleaning Supplies
177 | products, Tools and Home Improvement, Electrical
178 | products, Tools and Home Improvement, Hardware
179 | products, Tools and Home Improvement, Kitchen and Bath Fixtures
180 | products, Tools and Home Improvement, Light Bulbs
181 | products, Tools and Home Improvement, Lighting and Ceiling Fans
182 | products, Tools and Home Improvement, Measuring and Layout Tools
183 | products, Tools and Home Improvement, Paint
184 | products, Tools and Home Improvement, Wall Treatments and Supplies
185 | products, Tools and Home Improvement, Power and Hand Tools
186 | products, Tools and Home Improvement, Rough Plumbing
187 | products, Tools and Home Improvement, Safety and Security
188 | products, Tools and Home Improvement, Storage and Home Organization
189 | products, Tools and Home Improvement, Winterization products
190 | Outdoor Power Tools
191 | products, Tools and Home Improvement, Hand Tools
192 | products, Tools and Home Improvement, Power Tools
193 | products, Tools and Home Improvement, Painting Supplies and Wall Treatments
194 | products, Tools and Home Improvement, Electrical
195 | products, Tools and Home Improvement, Hardware
196 | products, Tools and Home Improvement, Kitchen and Bath Fixtures
197 | products, Tools and Home Improvement, Light Bulbs
198 | products, Tools and Home Improvement, Light Fixtures
199 | products, Tools and Home Improvement, Measuring and Layout Tools
200 | products, Tools and Home Improvement, Painting Supplies and Wall Treatments
201 | products, Tools and Home Improvement, Power and Hand Tools
202 | products, Tools and Home Improvement, Rough Plumbing
203 | products, Tools and Home Improvement, Safety and Security
204 | products, Tools and Home Improvement, Storage and Home Organization
205 | products, Tools and Home Improvement, Building Supplies
206 | products, Tools and Home Improvement, Rough Plumbing
207 | products, Tools and Home Improvement, Building Materials
208 | products, Tools and Home Improvement, Material Handling
209 | products, Tools and Home Improvement, Measuring and Layout Tools
210 | products, Tools and Home Improvement, Jobsite Safety
211 | products, Tools and Home Improvement, Work Wear and Safety Gear
212 | products, Tools and Home Improvement, Air Tools & Compressors
213 | products, Musical Instruments, Amplifiers & Effects
214 | products, Musical Instruments, Band & Orchestra
215 | products, Musical Instruments, Drums & Percussion
216 | products, Musical Instruments, Folk & World Instruments
217 | products, Musical Instruments, Guitars
218 | products, Musical Instruments, Keyboards
219 | products, Musical Instruments, Live Sound & Stage
220 | products, Musical Instruments, Microphones & Accessories
221 | products, Musical Instruments, Recording Equipment
222 | products, Industrial and Scientific, Industrial Electrical
223 | products, Industrial and Scientific, Occupational Health & Safety Products
224 | products, Industrial and Scientific, Test
225 | products, Industrial and Scientific, Measure & Inspect
226 | products, Industrial and Scientific, Abrasive & Finishing Products
227 | products, Industrial and Scientific, Janitorial & Sanitation Supplies
228 | products, Industrial and Scientific, Industrial Hardware
229 | products, Handmade, Jewelry
230 | products, Handmade, Clothing
231 | products, Handmade, Home decor
232 | products, Handmade, Art
233 | products, Handmade, Stationery
234 | products, Handmade, Toys
235 | products, Handmade, Beauty products
236 | products, Handmade, Accessories
237 | products, Handmade, Pet supplies
238 | products, Handmade, Craft supplies
239 |
--------------------------------------------------------------------------------
/graph_generation/gen_results/tem/e-commerce platform like Amazon_depth2_products, Automotive:
--------------------------------------------------------------------------------
1 | products, Automotive, Car parts
2 | products, Automotive, Car accessories
3 | products, Automotive, Car care products
4 | products, Automotive, Car electronics
5 | products, Automotive, Car interior accessories
6 | products, Automotive, Car exterior accessories
7 | products, Automotive, Car performance parts
8 | products, Automotive, Motorcycle parts
9 | products, Automotive, Motorcycle accessories
10 |
--------------------------------------------------------------------------------
/graph_generation/gen_results/tem/e-commerce platform like Amazon_depth2_products, Baby:
--------------------------------------------------------------------------------
1 | products, Baby, Baby girls' clothing
2 | products, Baby, baby boys' clothing
3 | products, Baby, baby shoes
4 | products, Baby, baby accessories
5 | products, Baby, baby gear
6 | products, Baby, baby feeding
7 | products, Baby, baby bathing
8 | products, Baby, baby nursery
9 | products, Baby, baby toys
10 |
--------------------------------------------------------------------------------
/graph_generation/gen_results/tem/e-commerce platform like Amazon_depth2_products, Beauty:
--------------------------------------------------------------------------------
1 | products, Beauty, Makeup
2 | products, Beauty, Skincare
3 | products, Beauty, Haircare
4 | products, Beauty, Fragrance
5 | products, Beauty, Bath & Body
6 | products, Beauty, Tools & Accessories
7 | products, Beauty, Men's Grooming
8 | products, Beauty, Beauty Supplements
9 |
--------------------------------------------------------------------------------
/graph_generation/gen_results/tem/e-commerce platform like Amazon_depth2_products, Books:
--------------------------------------------------------------------------------
1 | products, Books, Fiction
2 | products, Books, Non-fiction
3 | products, Books, Mystery
4 | products, Books, Science Fiction
5 | products, Books, Romance
6 | products, Books, Fantasy
7 | products, Books, Thriller
8 | products, Books, Biography
9 | products, Books, History
10 | products, Books, Children's Books
11 | products, Books, Self-help
12 | products, Books, Cookbook
13 | products, Books, Art
14 | products, Books, Travel
15 | products, Books, Science
16 | products, Books, Religion
17 | products, Books, Reference
18 | products, Books, Poetry
19 | products, Books, Graphic Novels
20 | products, Books, Horror
21 | products, Books, Young Adult
22 |
--------------------------------------------------------------------------------
/graph_generation/gen_results/tem/e-commerce platform like Amazon_depth2_products, Clothing:
--------------------------------------------------------------------------------
1 | products, Clothing, Women's clothing
2 | products, Clothing, Men's clothing
3 | products, Clothing, Kid's clothing
4 | products, Clothing, Shoes
5 | products, Clothing, Accessories
6 |
--------------------------------------------------------------------------------
/graph_generation/gen_results/tem/e-commerce platform like Amazon_depth2_products, Electronics:
--------------------------------------------------------------------------------
1 | products, Electronics, Headphones
2 | products, Electronics, Speakers
3 | products, Electronics, Cameras
4 | products, Electronics, Smartphones
5 | products, Electronics, Smartwatches
6 | products, Electronics, Tablets
7 | products, Electronics, Laptops
8 | products, Electronics, Desktop Computers
9 | products, Electronics, Home Security Systems
10 | products, Electronics, Drones
11 | products, Electronics, Gaming Consoles
12 | products, Electronics, Virtual Reality Headsets
13 | products, Electronics, GPS Devices
14 | products, Electronics, E-book Readers
15 |
--------------------------------------------------------------------------------
/graph_generation/gen_results/tem/e-commerce platform like Amazon_depth2_products, Handmade:
--------------------------------------------------------------------------------
1 | products, Handmade, Jewelry
2 | products, Handmade, Clothing
3 | products, Handmade, Home decor
4 | products, Handmade, Art
5 | products, Handmade, Stationery
6 | products, Handmade, Toys
7 | products, Handmade, Beauty products
8 | products, Handmade, Accessories
9 | products, Handmade, Pet supplies
10 | products, Handmade, Craft supplies
11 |
--------------------------------------------------------------------------------
/graph_generation/gen_results/tem/e-commerce platform like Amazon_depth2_products, Health and Personal Care:
--------------------------------------------------------------------------------
1 | products, Health and Personal Care, Vitamins
2 | products, Health and Personal Care, Supplements
3 | products, Health and Personal Care, Health Care
4 | products, Health and Personal Care, Baby & Child Care
5 | products, Health and Personal Care, Sports Nutrition
6 | products, Health and Personal Care, Household Supplies
7 | products, Health and Personal Care, Medical Supplies & Equipment
8 | products, Health and Personal Care, Sexual Wellness
9 | products, Health and Personal Care, Health & Wellness
10 | products, Health and Personal Care, Oral Care
11 | products, Health and Personal Care, Beauty & Grooming
12 | products, Health and Personal Care, Pet Supplies
13 | products, Health and Personal Care, Health & Household
14 | products, Health and Personal Care, Baby & Child Care
15 | products, Health and Personal Care, Health Supplies
16 |
--------------------------------------------------------------------------------
/graph_generation/gen_results/tem/e-commerce platform like Amazon_depth2_products, Home Improvement:
--------------------------------------------------------------------------------
1 | products, Home Improvement, Power tools
2 | products, Home Improvement, Hand tools
3 | products, Home Improvement, Electrical
4 | products, Home Improvement, Hardware
5 | products, Home Improvement, Plumbing
6 | products, Home Improvement, Building supplies
7 | products, Home Improvement, Paint
8 | products, Home Improvement, Heating
9 | products, Home Improvement, Cooling
10 | products, Home Improvement, Home security
11 | products, Home Improvement, Lighting
12 | products, Home Improvement, Storage
13 | products, Home Improvement, Cleaning supplies
14 | products, Home Improvement, Appliances
15 |
--------------------------------------------------------------------------------
/graph_generation/gen_results/tem/e-commerce platform like Amazon_depth2_products, Industrial and Scientific:
--------------------------------------------------------------------------------
1 | products, Industrial and Scientific, Industrial Electrical
2 | products, Industrial and Scientific, Occupational Health & Safety Products
3 | products, Industrial and Scientific, Test
4 | products, Industrial and Scientific, Measure & Inspect
5 | products, Industrial and Scientific, Abrasive & Finishing Products
6 | products, Industrial and Scientific, Janitorial & Sanitation Supplies
7 | products, Industrial and Scientific, Industrial Hardware
8 |
--------------------------------------------------------------------------------
/graph_generation/gen_results/tem/e-commerce platform like Amazon_depth2_products, Jewelry:
--------------------------------------------------------------------------------
1 | products, Jewelry, Rings
2 | products, Jewelry, Necklaces
3 | products, Jewelry, Bracelets
4 | products, Jewelry, Earrings
5 | products, Jewelry, Pendants
6 | products, Jewelry, Brooches
7 | products, Jewelry, Charms
8 | products, Jewelry, Anklets
9 | products, Jewelry, Jewelry Sets
10 | products, Jewelry, Body Jewelry
11 | products, Jewelry, Hair Jewelry
12 | products, Jewelry, Toe Rings
13 | products, Jewelry, Cufflinks
14 | products, Jewelry, Tie Clips
15 |
--------------------------------------------------------------------------------
/graph_generation/gen_results/tem/e-commerce platform like Amazon_depth2_products, Musical Instruments:
--------------------------------------------------------------------------------
1 | products, Musical Instruments, Amplifiers & Effects
2 | products, Musical Instruments, Band & Orchestra
3 | products, Musical Instruments, Drums & Percussion
4 | products, Musical Instruments, Folk & World Instruments
5 | products, Musical Instruments, Guitars
6 | products, Musical Instruments, Keyboards
7 | products, Musical Instruments, Live Sound & Stage
8 | products, Musical Instruments, Microphones & Accessories
9 | products, Musical Instruments, Recording Equipment
10 |
--------------------------------------------------------------------------------
/graph_generation/gen_results/tem/e-commerce platform like Amazon_depth2_products, Office Products:
--------------------------------------------------------------------------------
1 | products, Office Products, Office furniture
2 | products, Office Products, Office lighting
3 | products, Office Products, Desk accessories
4 | products, Office Products, Filing products
5 | products, Office Products, Office electronics
6 | products, Office Products, Paper products
7 | products, Office Products, Writing instruments
8 | products, Office Products, Presentation supplies
9 | products, Office Products, Office storage
10 | products, Office Products, Office maintenance
11 | products, Office Products, Office organization
12 |
--------------------------------------------------------------------------------
/graph_generation/gen_results/tem/e-commerce platform like Amazon_depth2_products, Pet Supplies:
--------------------------------------------------------------------------------
1 | products, Pet Supplies, Dog food
2 | products, Pet Supplies, Cat food
3 | products, Pet Supplies, Dog treats
4 | products, Pet Supplies, Cat treats
5 | products, Pet Supplies, Cat litter
6 | products, Pet Supplies, Dog grooming supplies
7 | products, Pet Supplies, Cat grooming supplies
8 | products, Pet Supplies, Dog toys
9 | products, Pet Supplies, Cat toys
10 | products, Pet Supplies, Dog collars and leashes
11 | products, Pet Supplies, Cat collars and leashes
12 | products, Pet Supplies, Fish supplies
13 | products, Pet Supplies, Small animal supplies
14 | products, Pet Supplies, Bird supplies
15 |
--------------------------------------------------------------------------------
/graph_generation/gen_results/tem/e-commerce platform like Amazon_depth2_products, Sports and Outdoors:
--------------------------------------------------------------------------------
1 | products, Sports and Outdoors, Camping and Hiking
2 | products, Sports and Outdoors, Exercise and Fitness
3 | products, Sports and Outdoors, Golf
4 | products, Sports and Outdoors, Hunting and Fishing
5 | products, Sports and Outdoors, Outdoor Recreation
6 | products, Sports and Outdoors, Team Sports
7 | products, Sports and Outdoors, Water Sports
8 | products, Sports and Outdoors, Winter Sports
9 |
--------------------------------------------------------------------------------
/graph_generation/gen_results/tem/e-commerce platform like Amazon_depth2_products, Tools and Home Improvement:
--------------------------------------------------------------------------------
1 | products, Tools and Home Improvement, Power Tools
2 | products, Tools and Home Improvement, Hand Tools
3 | products, Tools and Home Improvement, Tool Organizers
4 | products, Tools and Home Improvement, Tool Sets
5 | products, Tools and Home Improvement, Measuring and Layout Tools
6 | products, Tools and Home Improvement, Welding and Soldering
7 | products, Tools and Home Improvement, Tool Storage
8 | products, Tools and Home Improvement, Material Handling
9 | products, Tools and Home Improvement, Jobsite Safety
10 | products, Tools and Home Improvement, Flashlights
11 | products, Tools and Home Improvement, Jobsite Radios
12 | products, Tools and Home Improvement, Welding Equipment
13 | products, Tools and Home Improvement, Woodworking
14 | products, Tools and Home Improvement, Workshop Equipment
15 | products, Tools and Home Improvement, Air Tools
16 | products, Tools and Home Improvement, Welding and Soldering
17 | products, Tools and Home Improvement, Power Finishing Tools
18 | products, Tools and Home Improvement, Fireplace and Stove Accessories
19 | products, Tools and Home Improvement, Abrasive and Finishing Products
20 | products, Tools and Home Improvement, Building Supplies
21 | products, Tools and Home Improvement, Cleaning Supplies
22 | products, Tools and Home Improvement, Electrical
23 | products, Tools and Home Improvement, Hardware
24 | products, Tools and Home Improvement, Kitchen and Bath Fixtures
25 | products, Tools and Home Improvement, Light Bulbs
26 | products, Tools and Home Improvement, Lighting and Ceiling Fans
27 | products, Tools and Home Improvement, Measuring and Layout Tools
28 | products, Tools and Home Improvement, Paint
29 | products, Tools and Home Improvement, Wall Treatments and Supplies
30 | products, Tools and Home Improvement, Power and Hand Tools
31 | products, Tools and Home Improvement, Rough Plumbing
32 | products, Tools and Home Improvement, Safety and Security
33 | products, Tools and Home Improvement, Storage and Home Organization
34 | products, Tools and Home Improvement, Winterization products
35 | Outdoor Power Tools
36 | products, Tools and Home Improvement, Hand Tools
37 | products, Tools and Home Improvement, Power Tools
38 | products, Tools and Home Improvement, Painting Supplies and Wall Treatments
39 | products, Tools and Home Improvement, Electrical
40 | products, Tools and Home Improvement, Hardware
41 | products, Tools and Home Improvement, Kitchen and Bath Fixtures
42 | products, Tools and Home Improvement, Light Bulbs
43 | products, Tools and Home Improvement, Light Fixtures
44 | products, Tools and Home Improvement, Measuring and Layout Tools
45 | products, Tools and Home Improvement, Painting Supplies and Wall Treatments
46 | products, Tools and Home Improvement, Power and Hand Tools
47 | products, Tools and Home Improvement, Rough Plumbing
48 | products, Tools and Home Improvement, Safety and Security
49 | products, Tools and Home Improvement, Storage and Home Organization
50 | products, Tools and Home Improvement, Building Supplies
51 | products, Tools and Home Improvement, Rough Plumbing
52 | products, Tools and Home Improvement, Building Materials
53 | products, Tools and Home Improvement, Material Handling
54 | products, Tools and Home Improvement, Measuring and Layout Tools
55 | products, Tools and Home Improvement, Jobsite Safety
56 | products, Tools and Home Improvement, Work Wear and Safety Gear
57 | products, Tools and Home Improvement, Air Tools & Compressors
58 |
--------------------------------------------------------------------------------
/graph_generation/gen_results/tem/e-commerce platform like Amazon_depth2_products, Toys:
--------------------------------------------------------------------------------
1 | products, Toys, Action Figures
2 | products, Toys, Dolls
3 | products, Toys, Stuffed Animals
4 | products, Toys, Building Sets
5 | products, Toys, Arts & Crafts
6 | products, Toys, Vehicles
7 | products, Toys, Puzzles
8 | products, Toys, Outdoor Play
9 | products, Toys, Learning & Education
10 | products, Toys, Games
11 | products, Toys, Play Food & Grocery
12 | products, Toys, Electronic Toys
13 | products, Toys, Musical Instruments
14 |
--------------------------------------------------------------------------------
/graph_generation/gen_results/tree_wInstanceNum_products_e-commerce platform like Amazon.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/graph_generation/gen_results/tree_wInstanceNum_products_e-commerce platform like Amazon.pkl
--------------------------------------------------------------------------------
/graph_generation/human_item_generation_gibbsSampling_embedEstimation.py:
--------------------------------------------------------------------------------
1 | import random
2 | from Utils import DataGenAgent
3 | import pickle
4 | import os
5 | import Exp_Utils.TimeLogger as logger
6 | from Exp_Utils.TimeLogger import log
7 | from Exp_Utils.Emailer import SendMail
8 | import numpy as np
9 | from scipy.stats import norm
10 | import copy
11 |
12 | class HumanItemRelationGeneration(DataGenAgent):
13 | def __init__(self, item_list, length_sampler, descs, hyperparams, text_embedding_dict=None):
14 | super(HumanItemRelationGeneration, self).__init__()
15 |
16 | self.item_list = item_list
17 | self.length_sampler = length_sampler
18 | self.descs = descs
19 | self.hyperparams = hyperparams
20 | self.hyperparams['item_num'] = len(self.item_list)
21 | self.item_to_id = dict()
22 | for iid, item in enumerate(item_list):
23 | self.item_to_id[item] = iid
24 | self.reject_cnt = 0
25 | self.item_perm = np.random.permutation(len(self.item_list))
26 | self.score_history = []
27 | self.text_embedding_dict = dict() if text_embedding_dict is None else text_embedding_dict
28 |
29 | # fst_text = 'Clothing, Plus size clothing, Plus size formal wear, Dresses'
30 | # scd_text = 'Health & Household, Household Supplies, Air fresheners, Scented beads & charms'
31 | # print(fst_text)
32 | # print(scd_text)
33 | # print(self.similarity(self.text_embedding(fst_text), self.text_embedding(scd_text)))
34 | # exit()
35 |
36 | def binvec2list(self, bin_sample_vec):
37 | idxs = np.reshape(np.argwhere(bin_sample_vec != 0), [-1])
38 | return list(map(lambda x: self.item_list[x], idxs))
39 |
40 | def list_text(self, item_list, nums):
41 | item_num_list = list(zip(item_list, nums))
42 | item_num_list.sort(key=lambda x: x[1], reverse=True)
43 | ret = ''
44 | for i, pair in enumerate(item_num_list):
45 | item, num = pair[0], pair[1]
46 | ret += '{idx}. {item}. Frequency: {num}\n'.format(idx=i, item=item, num=num)
47 | return ret
48 |
49 | def summarize(self, item_list):
50 | def fuse(item_list, prefix):
51 | for i in range(len(item_list)):
52 | item = item_list[i]
53 | if item.startswith(prefix):
54 | item_list[i] = prefix
55 | return item_list
56 | def count_and_shrink(item_list):
57 | dic = dict()
58 | for item in item_list:
59 | if item not in dic:
60 | dic[item] = 0
61 | dic[item] += 1
62 | ret_item, ret_cnt = [], []
63 | for key, cnt in dic.items():
64 | ret_item.append(key)
65 | ret_cnt.append(cnt)
66 | return ret_item, ret_cnt
67 | def count_prefixes_of_different_depth(item_list, max_depth):
68 | ret_item_list = []
69 | prefix_dicts = [dict() for i in range(max_depth + 1)]
70 | for item in item_list:
71 | num_idx = item.index(' #')
72 | tem_item = item[:num_idx]
73 | ret_item_list.append(tem_item)
74 | entities = tem_item.split(', ')
75 | entities = list(map(lambda entity: entity.strip(), entities))
76 | for depth in range(max_depth + 1):
77 | if depth + 1 >= len(entities):
78 | break
79 | tem_prefix = ', '.join(entities[:depth + 1])
80 | if tem_prefix not in prefix_dicts[depth]:
81 | prefix_dicts[depth][tem_prefix] = 0
82 | prefix_dicts[depth][tem_prefix] += 1
83 | return prefix_dicts, ret_item_list
84 | max_depth = len(item_list[0].split(', '))
85 | prefix_dicts, ret_item_list = count_prefixes_of_different_depth(item_list, max_depth)
86 | if len(ret_item_list) < self.hyperparams['context_limit']:
87 | return ret_item_list, [1] * len(ret_item_list)
88 | # greedy search
89 | flag = False
90 | for depth in range(max_depth, -1, -1):
91 | prefix_list = [(prefix, cnt) for prefix, cnt in prefix_dicts[depth].items()]
92 | prefix_list.sort(key=lambda x: x[1], reverse=True)
93 | for prefix, cnt in prefix_list:
94 | if cnt == 1:
95 | break
96 | ret_item_list = fuse(ret_item_list, prefix)
97 | if depth != 0:
98 | # adjust the counts of shallow entities
99 | shrinked_prefix = ', '.join(prefix.split(', ')[:-1])
100 | prefix_dicts[depth - 1][shrinked_prefix] -= cnt - 1
101 | if len(set(ret_item_list)) <= self.hyperparams['context_limit']:
102 | flag=True
103 | break
104 | if flag:
105 | return count_and_shrink(ret_item_list)
106 | return count_and_shrink(ret_item_list)
107 |
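    # Editor's note (not in the original file): summarize() tries to compress a long
    # interaction list to at most hyperparams['context_limit'] entries by greedily replacing
    # items with their category prefixes, deepest prefix level first. For example, many
    # leaves under 'products, Books' would fuse into the single entry 'products, Books',
    # and the counts returned by count_and_shrink() record how many original items each
    # fused entry covers. It is only referenced from the commented-out summary-similarity
    # branch of estimate_probability() below.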
108 | def text_embedding(self, text):
109 | if text in self.text_embedding_dict:
110 | embeds = self.text_embedding_dict[text]
111 | return embeds
112 | print('Embedding not found!')
113 | print(text)
114 |         exit()  # missing embeddings are treated as fatal here, so the API fallback below is unreachable
115 | embedding = self.openai_embedding(text)
116 | self.text_embedding_dict[text] = embedding
117 | return embedding
118 |
119 | def similarity(self, fst_embed, scd_embed):
120 | # fst_embed = fst_embed / np.sqrt(np.sum(np.square(fst_embed)))
121 | # scd_embed = scd_embed / np.sqrt(np.sum(np.square(scd_embed)))
122 | return np.sum(fst_embed * scd_embed)
123 |
124 | def dynamic_normalize(self, score):
125 | self.score_history.append(score)
126 | if len(self.score_history) > 5000:
127 | self.score_history = self.score_history[-5000:]
128 | if len(self.score_history) < 5:
129 | return max(min(1.0, score), 0.0)
130 | score_samples = np.array(self.score_history)
131 | mean = np.mean(score_samples)
132 | std = np.sqrt(np.mean((score_samples - mean) ** 2))
133 | minn = mean - 1.96 * std
134 | maxx = mean + 1.96 * std#2.96 * std
135 | # minn = np.min(self.score_history)
136 | # maxx = np.max(self.score_history)
137 | ret = (score - minn) / (maxx - minn)
138 | ret = max(min(ret, 1.0), 0.0)
139 | # add margin
140 | # ret = (ret + 0.1) / 1.2
141 | # print('minn', minn)
142 | # print('maxx', maxx)
143 | # print('score', score)
144 | # print((score - minn) / (maxx - minn))
145 |
146 | # if len(self.score_history) > 10000:
147 | # import matplotlib.pyplot as plt
148 | # plt.hist(self.score_history, 100)
149 | # plt.savefig('tem.pdf')
150 | # exit()
151 |
152 | return ret
153 |
154 | def estimate_probability(self, bin_sample_vec, cur_dim, is_deleting=False):
155 | item_list = self.binvec2list(bin_sample_vec)
156 | new_item = self.item_list[cur_dim]
157 | candidate_embedding = self.text_embedding(new_item[:new_item.index(' #')])
158 |
159 | # summary similarity
160 | # if 'real_data' not in self.descs['data_name']:
161 | # summarized_item_list, nums = self.summarize(item_list)
162 | # summarized_score = 0
163 | # for i, summarized_item in enumerate(summarized_item_list):
164 | # sim = self.similarity(self.text_embedding(summarized_item), candidate_embedding)
165 | # summarized_score += sim * nums[i] / sum(nums)
166 |
167 | # top instance similarity
168 | # sims = list(map(lambda item: (self.similarity(self.text_embedding(item[:item.index(' #')]), candidate_embedding), item[:item.index(' #')]), item_list))
169 | # sims.sort(reverse=True, key=lambda x: x[0])
170 | # sim_scores = list(map(lambda x: x[0], sims[:self.hyperparams['context_limit']]))
171 | # instance_score = sum(sim_scores) / len(sim_scores)
172 | # most_sim_items = list(map(lambda x: x[1], sims[:self.hyperparams['context_limit']]))
173 |
174 | embed_list = list(map(lambda item: self.text_embedding(item[:item.index(' #')]), item_list))
175 | avg_embed = sum(embed_list) / len(embed_list)
176 | instance_score = np.sum(avg_embed * candidate_embedding)
177 |
178 |
179 | # print('## Candidate item:', new_item)
180 | # print('## Summary context:')
181 | # print(self.list_text(summarized_item_list, nums))
182 | # print('## Most similar items:')
183 | # print(self.list_text(most_sim_items, [1] * len(most_sim_items)))
184 |
185 | # if 'real_data' in self.descs['data_name']:
186 | score = instance_score
187 | # else:
188 | # score = (summarized_score + instance_score) / 2
189 | # log('Original score: summary-{sum_score}, instance-{ins_score}, sum-{tot_score}'.format(sum_score=summarized_score, ins_score=instance_score, tot_score=score))
190 | # score = self.dynamic_normalize(score)
191 | # log('Dynamic normalized probability: {prob}'.format(prob=score))
192 | interaction_num = np.sum(bin_sample_vec != 0)
193 | interaction_prob = 1.0 / (1.0 + np.exp((interaction_num - self.hyperparams['length_center'])/(self.hyperparams['length_center']//2)))#(100, 50), (150, 75)
194 | score = score * interaction_prob# * 1.1
195 | # log('Interaction num normalized probability: {prob}'.format(prob=score))
196 | return score
197 |
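    # Editor's note: the score above is the dot product between the candidate item's
    # embedding and the mean embedding of the items already in the sample, damped by a
    # logistic penalty on the current interaction count: 1 / (1 + exp((n - c) / (c // 2)))
    # with c = hyperparams['length_center']. For c = 400 this factor is roughly 0.88 at
    # n = 0, 0.5 at n = 400, and 0.12 at n = 800, which discourages overly long lists.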
198 | def update_sample(self, last_sample, cur_dim, should_include):
199 |         if (last_sample[cur_dim] == 0.0 and should_include) or (last_sample[cur_dim] > 0.0 and not should_include):
200 | new_sample = copy.deepcopy(last_sample)
201 | new_sample[cur_dim] = 1.0 - last_sample[cur_dim]
202 | return new_sample, True
203 | else:
204 | return last_sample, False
205 |
206 | def Gibbs_Sampling(self):
207 | samples = []
208 | idx = 0
209 | update_cnt = 0
210 | cur_community = 0
211 | for step in range(self.hyperparams['sample_num']):
212 | if step % self.hyperparams['restart_step'] == 0:
213 | samples.append(self.random_sample())
214 | cur_community = (cur_community + 1) % self.hyperparams['community_num']
215 | last_sample = samples[-1]
216 | update_flag = False
217 | for small_step in range(self.hyperparams['gibbs_step']):
218 | cur_dim = self.item_perm[idx]
219 | nnz = np.sum(last_sample != 0)
220 | delete_dice = random.uniform(0, 1)
221 | if nnz > self.hyperparams['delete_nnz'] and delete_dice < 0.5:# 0.5, 0.4, 0.75, 2
222 | cur_dim = np.random.choice(np.reshape(np.argwhere(last_sample > 0.0), [-1]))
223 | tem_delete_flag = False
224 | if last_sample[cur_dim] > 0.0:
225 | # log('Deleting')
226 | tem_delete_flag = True
227 | last_sample[cur_dim] = 0.0
228 | idx = (idx + 1) % len(self.item_list)
229 | self.failure = 0
230 | prob = self.estimate_probability(last_sample, cur_dim, tem_delete_flag)
231 |
232 | # community modifier
233 | diff = abs(cur_community - cur_dim % self.hyperparams['community_num'])
234 | prob *= self.hyperparams['com_decay'] ** diff
235 |
236 | dice = random.uniform(0, 1) - self.hyperparams['dice_shift']
237 | # if dice < prob:
238 | # print('Edge should be included')
239 | # else:
240 | # print('Edge should not be included')
241 | last_sample, change_flag = self.update_sample(last_sample, cur_dim, dice < prob)
242 | if tem_delete_flag:
243 | change_flag = not change_flag
244 | if change_flag:
245 | if small_step == 0:
246 | log('Sample Updated! Step {step}_{small_step}, update cnt {update_cnt}, interaction num {int_num}, sample num {samp_num}'.format(step=step, small_step=small_step, update_cnt=update_cnt, int_num=np.sum(last_sample!=0.0), samp_num=len(samples)), oneline=True)
247 | self.reject_cnt = 0
248 | update_cnt += 1
249 | update_flag = True
250 | else:
251 | if small_step == 0:
252 | log('Sample UNCHANGED! Step {step}_{small_step}, update cnt {update_cnt}, interaction num {int_num}, sample num {samp_num}'.format(step=step, small_step=small_step, update_cnt=update_cnt, int_num=np.sum(last_sample!=0.0), samp_num=len(samples)), oneline=True)
253 | self.reject_cnt += 1
254 | # print('*******\n')
255 | if self.reject_cnt > 50:
256 | log('Consecutive rejection {rej_cnt} when sampling!'.format(rej_cnt=self.reject_cnt), save=True)
257 | log('Last sample: {last_sample}'.format(last_sample=self.binvec2list(samples[-1])))
258 |                     log('New sample: {new_sample}'.format(new_sample=self.binvec2list(self.update_sample(last_sample, cur_dim, dice >= prob)[0])))  # update_sample returns (sample, changed); log the sample
259 | log('Sending report email.', save=True)
260 | # SendMail(logger.logmsg)
261 | self.reject_cnt = 0
262 | break
263 | if update_flag:
264 | if step % self.hyperparams['restart_step'] < self.hyperparams['gibbs_skip_step']: # original 50
265 | samples[-1] = last_sample
266 | else:
267 | samples.append(last_sample)
268 | return samples
269 |
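    # Editor's note: Gibbs_Sampling() restarts every hyperparams['restart_step'] outer steps
    # with a fresh random sample of 'seed_num' items and rotates the target community. Each
    # outer step proposes 'gibbs_step' single-item flips; once the sample holds more than
    # 'delete_nnz' items, half of the proposals re-test an already-selected item (a deletion
    # move). A flip is kept when a shifted uniform draw falls below estimate_probability()
    # scaled by com_decay ** (community distance), and the inner loop aborts after 50
    # consecutive rejections.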
270 | def random_sample(self):
271 | picked_idxs = random.sample(list(range(len(self.item_list))), self.hyperparams['seed_num'])
272 | last_interaction = np.zeros(len(self.item_list))
273 | last_interaction[picked_idxs] = 1.0
274 | return last_interaction
275 |
276 | def run(self):
277 | samples_binvec = self.Gibbs_Sampling()
278 | picked_items_list = []
279 | for vec in samples_binvec:
280 | picked_items = self.binvec2list(vec)
281 | picked_items_list.append(picked_items)
282 | return picked_items_list
283 |
284 | def load_item_list(item_file, entity_file, item_num):
285 | if not os.path.exists(item_file):
286 | with open(entity_file, 'rb') as fs:
287 | entity_tree_root = pickle.load(fs)
288 | entity_tree_root.allocate_number(item_num)
289 | item_list = entity_tree_root.get_list_of_leaves('')
290 | with open(item_file, 'wb+') as fs:
291 | pickle.dump(item_list, fs)
292 | else:
293 | with open(item_file, 'rb') as fs:
294 | item_list = pickle.load(fs)
295 | return item_list
296 |
297 | def get_gen_iter(file_root, interaction_file_prefix):
298 | max_existing_iter = -1
299 | for filename in os.listdir(file_root):
300 | cur_filename = file_root + filename
301 | if interaction_file_prefix in cur_filename:
302 | st_idx = len(interaction_file_prefix)
303 | ed_idx = cur_filename.index('_iter-0.pkl')
304 | cur_iter = int(cur_filename[st_idx: ed_idx])
305 | max_existing_iter = max(max_existing_iter, cur_iter)
306 | return max_existing_iter
307 |
308 | def load_interactions(prev_interaction_file):
309 | if 'iter--1' in prev_interaction_file:
310 | return None
311 | with open(prev_interaction_file, 'rb') as fs:
312 | prev_interactions = pickle.load(fs)
313 | return prev_interactions
314 |
315 | def load_embedding_dict(embed_file):
316 | if not os.path.exists(embed_file):
317 | return None
318 | with open(embed_file, 'rb') as fs:
319 | ret = pickle.load(fs)
320 | return ret
321 |
322 | if __name__ == '__main__':
323 | # parameter specification
324 | descs = {
325 | 'data_name': 'gen_data_ecommerce',
326 | 'scenario_desc': 'e-commerce platform like Amazon',
327 | 'human_role': 'user',
328 | 'interaction_verb': 'interact',
329 | 'initial_entity': 'products',
330 | }
331 | hyperparams = {
332 | 'seed_num': 6,
333 | 'item_num': 1,#200000,
334 | 'sample_num': 1400,#20000
335 | 'context_limit': 15,
336 | 'gibbs_step': 1000,
337 | 'gen_base': 0,
338 | 'restart_step': 100,# shift community when restart
339 | 'gibbs_skip_step': 1,# 100
340 | 'delete_nnz': 1, # 5
341 | 'length_center': 400, # 60, 100, 150
342 | 'community_num': 7,
343 | 'itmfuse': True,
344 | 'com_decay': 0.95,
345 | # 'dice_shift': 0.1,
346 | 'dice_shift': 0.1,
347 | }
348 |
349 | # file name definition
350 | file_root = 'gen_results/datasets/{data_name}/'.format(data_name=descs['data_name'])
351 | entity_file = 'gen_results/tree_wInstanceNum_{initial_entity}_{scenario}.pkl'.format(initial_entity=descs['initial_entity'], scenario=descs['scenario_desc'])
352 | item_file = file_root + 'item_list.pkl'
353 | embed_file = file_root + 'embedding_dict.pkl'
354 |
355 | # load data
356 | item_list = load_item_list(item_file, entity_file, hyperparams['item_num'])
357 | if hyperparams['itmfuse']:
358 | item_list = list(map(lambda x: x[:x.index(' #')] + ' #1', item_list))
359 | embedding_dict = load_embedding_dict(embed_file)
360 |
361 |     def length_sampler():
362 |         # NOTE: hyperparams defines 'length_center' but no 'min_len'/'max_len'; assumed bounds around that center to avoid a KeyError
363 |         return random.randint(1, 2 * hyperparams['length_center'])
364 |
365 | # generate new interactions
366 | generator = HumanItemRelationGeneration(item_list, length_sampler, descs, hyperparams, embedding_dict)
367 | sampled_interactions = generator.run()
368 |
369 | # store to disk
370 | interaction_file_prefix = file_root + 'interaction_base-'
371 | if 'gen_base' in hyperparams:
372 | gen_base = hyperparams['gen_base']
373 | next_interaction_file = interaction_file_prefix + str(gen_base) + '_iter-0.pkl'
374 | if os.path.exists(next_interaction_file):
375 | gen_base = get_gen_iter(file_root, interaction_file_prefix) + 1
376 | next_interaction_file = interaction_file_prefix + str(gen_base) + '_iter-0.pkl'
377 | with open(next_interaction_file, 'wb+') as fs:
378 | pickle.dump(sampled_interactions, fs)
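
For reference, a minimal sketch (not part of the script above) of how the saved interaction pickle can be inspected; the path is an assumption that follows this script's naming convention with its default settings.

    import pickle

    # assumed output location: file_root + 'interaction_base-0_iter-0.pkl'
    path = 'gen_results/datasets/gen_data_ecommerce/interaction_base-0_iter-0.pkl'
    with open(path, 'rb') as fs:
        interactions = pickle.load(fs)  # list of item-name lists, one per sampled user
    print('sampled users:', len(interactions))
    print('avg items per user:', sum(len(x) for x in interactions) / max(len(interactions), 1))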
--------------------------------------------------------------------------------
/graph_generation/instance_number_estimation_hierarchical.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from Utils import DataGenAgent, EntityTreeConstructer, EntityTreeNode
3 | import os
4 | import time
5 | import pickle
6 | import Exp_Utils.TimeLogger as logger
7 | from Exp_Utils.TimeLogger import log
8 | from Exp_Utils.Emailer import SendMail
9 |
10 | class HierarchicalInstanceNumberEstimator(DataGenAgent):
11 | def __init__(self, entity_tree_root, total_num, depth, initial_entity, scenario_desc):
12 | super(HierarchicalInstanceNumberEstimator, self).__init__()
13 |
14 | self.entity_tree_root = entity_tree_root
15 | self.total_num = total_num
16 | self.initial_entity = initial_entity
17 | self.scenario_desc = scenario_desc
18 | self.depth = depth
19 | self.failure = 0
20 |
21 | def _entity_list_to_text(self, entity_list):
22 | ret = ''
23 | for i, entity in enumerate(entity_list):
24 | ret += '{idx}. {entity}\n'.format(idx=i+1, entity=entity)
25 | return ret
26 |
27 | def interpret_one_answer(self, answer_text, subcategory):
28 | answer_lower = answer_text.lower()
29 | if subcategory.lower() not in answer_lower:
30 | log('ERROR: Entity name not found.', save=True)
31 | log('subcategory: {subcategory}'.format(subcategory=subcategory), save=True)
32 | log('answer_lower: {answer_lower}'.format(answer_lower=answer_lower), save=True)
33 | raise Exception('Entity name not found.')
34 | estimation_choices = ['average frequency', '1.2 times more frequent', '1.2 times less frequent', '1.5 times more frequent', '1.5 times less frequent', '2 times more frequent', '2 times less frequent', '4 times more frequent', '4 times less frequent', '8 times more frequent', '8 times less frequent']
35 | estimation_scores = [1.0, 1.2, 1/1.2, 1.5, 1/1.5, 2.0, 1/2.0, 4.0, 1/4.0, 8.0, 1/8.0]
36 | for i, choice in enumerate(estimation_choices):
37 | if choice in answer_lower:
38 | return estimation_scores[i]
39 | raise Exception('Estimation not found.')
40 |
41 | def interpret(self, answers_text, subcategories):
42 | answer_list = answers_text.strip().split('\n')
43 | assert len(answer_list) == len(subcategories), 'Length does not match.'
44 | answers = []
45 | for i in range(len(answer_list)):
46 | answers.append(self.interpret_one_answer(answer_list[i], subcategories[i]))
47 | return answers
48 |
49 | def estimate_subcategories(self, subcategories, category):
50 | subcategories_text = self._entity_list_to_text(subcategories)
51 | if category != self.initial_entity:
52 | text = '''In the context of {scenario_desc}, you are given a list of sub-categories below, which belong to the {category} category of {initial_entity}. Using your intuition and common sense, your goal is to identify the frequency of these sub-categories compared to the average frequency of all possible {category} {initial_entity}.'''.format(scenario_desc=self.scenario_desc, initial_entity=self.initial_entity, category=category)
53 | else:
54 | text = '''In the context of {scenario_desc}, you are given a list of sub-categories below, which belong to {initial_entity}. Using your intuition and common sense, your goal is to identify the frequency of these sub-categories compared to the average frequency of all possible {initial_entity}.'''.format(scenario_desc=self.scenario_desc, initial_entity=self.initial_entity)
55 |         text += '''Your answer should contain one line for each of the sub-categories, EXACTLY following this format: "[serial number]. [sub-category name same as in the input]; [your frequency estimation]; [one-sentence explanation for your estimation]". The frequency estimation should be one of the following choices: [average frequency, 1.2 times more/less frequent, 1.5 times more/less frequent, 2 times more/less frequent, 4 times more/less frequent, 8 times more/less frequent]. No other words should be included in your response. The sub-categories list is as follows:\n\n''' + subcategories_text
56 |         # print('input'); print(text)
57 |         answers_text = ''  # defined up front so the except-branch logging below cannot hit an UnboundLocalError
58 | try:
59 | answers_text = self.openai(text)
60 | print('Answers text:')
61 | print(answers_text)
62 | return self.interpret(answers_text, subcategories)
63 | except Exception as e:
64 | self.failure += 1
65 | if self.failure < 5:
66 | log('Exception occurs when interpreting. Retry in 10 seconds.', save=True)
67 | log('Exception message: {exception}'.format(exception=e), save=True)
68 | log('Failure times: {failure}'.format(failure=self.failure), save=True)
69 | log('Prompt text:\n{prompt}'.format(prompt=text), save=True)
70 | log('Response text:\n{response}'.format(response=answers_text), save=True)
71 | time.sleep(10)
72 | return self.estimate_subcategories(subcategories, category)
73 | else:
74 | log('Exception occurs {failure} times when interpreting. CANNOT HANDLE.'.format(failure=str(self.failure)), save=True)
75 | log('Exception message: {exception}'.format(exception=e), save=True)
76 | log('Prompt text:\n{prompt}'.format(prompt=text), save=True)
77 | log('Response text:\n{response}'.format(response=answers_text), save=True)
78 | log('Sending report email.', save=True)
79 | SendMail(logger.logmsg)
80 | logger.logmsg = ''
81 | return [1.0] * len(subcategories)
82 |
83 | def run(self):
84 | que = [self.entity_tree_root]
85 | while len(que) > 0:
86 | cur_entity = que[0]
87 | que = que[1:]
88 | if len(cur_entity.children) == 0:
89 | continue
90 | cur_children_entities = list(cur_entity.children.values())
91 | que = que + cur_children_entities
92 | cur_children_names = list(map(lambda x: x.entity_name, cur_children_entities))
93 | assert self.depth - cur_entity.depth > 0 and self.depth - cur_entity.depth < self.depth
94 | for _ in range(self.depth - cur_entity.depth + 1):
95 | self.failure = 0
96 | # answers = self.estimate_subcategories(cur_children_names, cur_entity.entity_name)
97 | # print(answers)
98 | # print('-----------------')
99 | # print()
100 | for j, entity in enumerate(cur_children_entities):
101 | # entity.frequency.append(answers[j])
102 | entity.frequency.append(1.0)
103 | self.entity_tree_root.allocate_number(self.total_num)
104 | with open('gen_results/tree_wInstanceNum_{initial_entity}_{scenario}.pkl'.format(initial_entity=self.initial_entity, scenario=self.scenario_desc), 'wb') as fs:
105 | pickle.dump(self.entity_tree_root, fs)
106 |
107 |
108 | scenario = 'e-commerce platform like Amazon'
109 | initial_entity = 'products'
110 | total_num = 200000
111 | depth = 5
112 |
113 | # scenario = 'published paper list of top AI conferences'
114 | # initial_entity = 'deep learning papers'
115 | # total_num = 1000000
116 | # depth = 6
117 |
118 | # scenario = 'venue rating platform like yelp'
119 | # initial_entity = 'business venues'
120 | # total_num = 30000
121 | # depth = 5
122 |
123 | # scenario = 'book rating platform'
124 | # initial_entity = 'books'
125 | # total_num = 30000
126 | # depth = 5
127 |
128 | # load entities
129 | file = os.path.join('gen_results/', '{entity}_{scenario}.txt'.format(entity=initial_entity, scenario=scenario))
130 | entity_lines = []
131 | with open(file, 'r') as fs:
132 | for line in fs:
133 | entity_lines.append(line)
134 | entity_tree_constructer = EntityTreeConstructer(entity_lines)
135 | entity_tree_root = entity_tree_constructer.root
136 |
137 | estimator = HierarchicalInstanceNumberEstimator(entity_tree_root, total_num=total_num, depth=depth, initial_entity=initial_entity, scenario_desc=scenario)
138 | estimator.run()
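
For reference, a standalone mirror (not part of the file above) of the mapping performed by interpret_one_answer: an LLM answer line is matched against the frequency phrases and converted to a multiplier. The example answer line is made up for illustration.

    choices = ['average frequency', '1.2 times more frequent', '1.2 times less frequent',
               '1.5 times more frequent', '1.5 times less frequent', '2 times more frequent',
               '2 times less frequent', '4 times more frequent', '4 times less frequent',
               '8 times more frequent', '8 times less frequent']
    scores = [1.0, 1.2, 1/1.2, 1.5, 1/1.5, 2.0, 1/2.0, 4.0, 1/4.0, 8.0, 1/8.0]

    def to_multiplier(answer_line):
        answer_lower = answer_line.lower()
        # '1.2 times ...' is checked before '2 times ...', so the more specific
        # phrase wins when it contains a later phrase as a substring
        for choice, score in zip(choices, scores):
            if choice in answer_lower:
                return score
        raise ValueError('Estimation not found.')

    print(to_multiplier('3. Electronics; 2 times more frequent; widely purchased'))  # -> 2.0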
--------------------------------------------------------------------------------
/graph_generation/itemCollecting_dfsIterator.py:
--------------------------------------------------------------------------------
1 | # from langchain.prompts import PromptTemplate
2 | # from langchain.llms import OpenAI
3 | import os
4 | import openai
5 | import time
6 | import tiktoken
7 | import json
8 |
9 | openai.api_key = "xxxxxx"
10 | class DataGenAgent:
11 | def __init__(self, initial_entity, scenario_desc, depth):
12 | super(DataGenAgent, self).__init__()
13 |
14 | # self.openai = OpenAI(temperature=0)
15 | self.initial_entity = initial_entity
16 | self.scenario_desc = scenario_desc
17 | self.encoding = tiktoken.encoding_for_model('gpt-3.5-turbo')
18 | self.token_num = 0
19 | self.total_num = 0
20 | self.depth = depth
21 |
22 | def openai(self, message):
23 | try:
24 | completion = openai.ChatCompletion.create(
25 | model='gpt-3.5-turbo-1106',
26 | messages=[
27 | {"role": "user", "content": message},
28 | ]
29 | )
30 | response = completion.choices[0].message.content
31 | time.sleep(1)
32 | self.token_num += len(self.encoding.encode(json.dumps(message)))
33 | return response
34 | except Exception as e:
35 | print('Exception occurs. Retry in 10 seconds.')
36 | time.sleep(10)
37 | return self.openai(message)
38 |
39 | def check_if_concrete(self, entity_stack):
40 | entity_name = ', '.join(entity_stack)
41 |         text = 'In the context of {scenario_desc}, is {entity_name} a concrete instance or category that can hardly be divided into sub-categories with prominent differences? The response should start with "True" or "False".'.format(scenario_desc=self.scenario_desc, entity_name=entity_name)
42 | answer = self.openai(text)
43 | # print('answer to {entity_name}: {answer}'.format(entity_name=entity_name, answer=answer))
44 | if answer.startswith('True'):
45 | print('Concrete Check True')
46 | return True
47 | return False
48 |
49 | def category_enum(self, prefix, entity_name):
50 | if prefix == '':
51 | text = 'List all distinct sub-categories of {entity_name} in the context of {scenario_desc}, ensuring a finer level of granularity. The sub-categories should not overlap with each other. And a sub-category should be a smaller subset of {entity_name}. Directly present the list EXACTLY following the form: "sub-category a, sub-category b, sub-category c, ..." without other words, format symbols, new lines, serial numbers.'.format(entity_name=entity_name, prefix=prefix, scenario_desc=self.scenario_desc)
52 | else:
53 | text = 'List all distinct sub-categories of {entity_name} within the {prefix} category in the context of {scenario_desc}, ensuring a finer level of granularity. The sub-categories should not overlap with each other. And a sub-category should be a smaller subset of {entity_name}. Directly present the list EXACTLY following the form: "sub-category a, sub-category b, sub-category c, ..." without other words, format symbols, new lines, serial numbers.'.format(entity_name=entity_name, prefix=prefix, scenario_desc=self.scenario_desc)
54 | # text = 'List all distinct sub-categories of {entity_name} within the {prefix} category in the context of {scenario_desc}, ensuring a finer level of granularity. The sub-categories should not overlap with each other. Present the list exactly following the form: "sub-category a, sub-category b, sub-category c, ...". There should be no serial number, new lines or other format symbols. Separate each pair of sub-categories with a comma.'.format(entity_name=entity_name, prefix=prefix, scenario_desc=self.scenario_desc)
55 | answer = self.openai(text)
56 | return list(map(lambda x: x.strip().strip(',').strip('.'), answer.split(',')))
57 |
58 | def decompose_category(self, entity_stack, depth):
59 | entity_name = ', '.join(entity_stack)
60 | prefix = '' if depth == 1 else ', '.join(entity_stack[:-1])
61 | if depth >= self.depth:# or self.check_if_concrete(entity_stack) is True:
62 | # print('{entity_name} is considered a concrete instance.'.format(entity_name=entity_name))
63 | self.total_num += 1
64 | return [entity_name]
65 | print('\nCurrent entity: {entity_name}'.format(entity_name=entity_name))
66 | concrete_entities = []
67 | sub_entities = self.category_enum(prefix, entity_stack[-1])
68 |         print('sub-categories of {entity_name} include:'.format(entity_name=entity_name), sub_entities)
69 | for sub_entity in sub_entities:
70 | if sub_entity in entity_name:
71 | continue
72 | new_concrete_entities = self.decompose_category(entity_stack + [sub_entity], depth+1)
73 | concrete_entities += new_concrete_entities
74 | if depth <= 4:
75 | print('Depth {depth}, current num of nodes {num}, total num of nodes {total_num}, num of tokens {token}'.format(depth=depth, num=len(concrete_entities), total_num=self.total_num, token=self.token_num))
76 | if depth <= 2:
77 | print('Storing...')
78 | tem_file = 'gen_results/tem/{scenario}_depth{depth}_{cur_entity}'.format(scenario=self.scenario_desc, depth=str(depth), cur_entity=entity_name)
79 | with open(tem_file, 'w+') as fs:
80 | for node in concrete_entities:
81 | fs.write(node + '\n')
82 | return concrete_entities
83 |
84 | def run(self):
85 | return self.decompose_category([self.initial_entity], 1)
86 |
87 | entity = 'products'
88 | scenario = 'e-commerce platform like Amazon'
89 | depth = 3
90 |
91 | # entity = 'movies'
92 | # scenario = 'movie rating platform'
93 |
94 | # entity = 'books'
95 | # scenario = 'book rating platform'
96 |
97 | # entity = 'business venues'
98 | # scenario = 'venue rating platform like yelp'
99 |
100 | # entity = 'movies'
101 | # scenario = 'movie rating platform'
102 | # depth = 5
103 |
104 | # entity = 'deep learning papers'
105 | # scenario = 'published paper list of top AI conferences'
106 | # depth = 6
107 |
108 | # entity = 'ideology'
109 | # scenario = "people's political ideologies"
110 | # depth = 4
111 |
112 | # entity = 'jobs'
113 | # scenario = "people's occupations and professions"
114 | # depth = 5
115 |
116 | agent = DataGenAgent(entity, scenario, depth)
117 | nodes = agent.run()
118 | with open('gen_results/{entity}_{scenario}.txt'.format(entity=entity, scenario=scenario), 'w+') as fs:
119 | for node in nodes:
120 | fs.write(node+'\n')
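
For reference, a small illustration (not part of the file above) of the post-processing done in category_enum: the LLM is asked for one comma-separated line, which is split and stripped of stray commas and periods. The raw answer below is a made-up example.

    raw_answer = 'Electronics, Books, Home Improvement, Pet Supplies.'
    sub_categories = list(map(lambda x: x.strip().strip(',').strip('.'), raw_answer.split(',')))
    print(sub_categories)  # ['Electronics', 'Books', 'Home Improvement', 'Pet Supplies']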
--------------------------------------------------------------------------------
/graph_generation/make_adjs.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import os
3 | import argparse
4 | from scipy.sparse import coo_matrix
5 | import numpy as np
6 | import random
7 | import networkx as nx
8 |
9 | descs = {
10 | 'data_name': 'gen_data_ecommerce',
11 | }
12 | params = {
13 | 'itmfusion': True,
14 | 'kcore': 0,
15 | 'sep': 1,
16 | 'min_base': 0,
17 | 'max_base': 100,
18 | }
19 | parser = argparse.ArgumentParser(description='Dataset information')
20 | parser.add_argument('--gen_iter', default=0, type=int, help='maximum generation iteration')
21 | args = parser.parse_args()
22 |
23 |
24 |
25 | file_root = 'gen_results/datasets/{data_name}/'.format(data_name=descs['data_name'])
26 | fuse_file_path = file_root + 'res/interaction_fuse_iter-{iter}.pkl'.format(iter=args.gen_iter)
27 |
28 | def get_all_bases():
29 | bases = set()
30 | prefix = 'interaction_base-'
31 | suffix = '_iter-'
32 | for filename in os.listdir(file_root):
33 | if prefix in filename and suffix in filename:
34 | prefix_idx = len(prefix)
35 | suffix_idx = filename.index(suffix)
36 | base = int(filename[prefix_idx: suffix_idx])
37 | if base >= params['min_base'] and base <= params['max_base']:
38 | bases.add(base)
39 | bases = list(bases)
40 | bases.sort()
41 | return bases
42 |
43 | def fuse_bases():
44 | if os.path.exists(fuse_file_path):
45 | print('Fused interaction file exists! REUSING!')
46 |         print('This may happen inadvertently!')
47 |         exit()  # aborts here; the reuse branch below only runs if this exit() is removed
48 | with open(fuse_file_path, 'rb') as fs:
49 | interactions = pickle.load(fs)
50 | return interactions
51 | all_bases = get_all_bases()
52 | interactions = []
53 | for gen_base in all_bases:
54 | file_path = None
55 | for iter in range(args.gen_iter, -1, -1):
56 | file_path = file_root + 'interaction_base-{gen_base}_iter-{gen_iter}.pkl'.format(gen_base=gen_base, gen_iter=iter)
57 | if os.path.exists(file_path):
58 | break
59 | with open(file_path, 'rb') as fs:
60 | tem_cur_base_interactions = pickle.load(fs)
61 | cur_base_interactions = []
62 | for i in range(len(tem_cur_base_interactions)):
63 | cur_base_interactions.append(tem_cur_base_interactions[i])
64 | interactions += cur_base_interactions
65 |
66 | new_interactions = []
67 | for i in range(len(interactions)):
68 | if i % params['sep'] == 0:
69 | # interactions[i] = interactions[i][:len(interactions[i]) // 3]
70 | new_interactions.append(interactions[i])
71 | interactions = new_interactions
72 | # interactions = new_interactions[:20000]
73 |
74 | with open(fuse_file_path, 'wb+') as fs:
75 | pickle.dump(interactions, fs)
76 | return interactions
77 |
78 | def make_id_map(interactions, itm_criteria=None):
79 | u_num = len(interactions)
80 | i_set = set()
81 | i_cnt = dict()
82 | for interaction in interactions:
83 | for item in interaction:
84 | num_idx = item.index(' #')
85 | tem_item = item if not params['itmfusion'] else item[:num_idx]
86 | i_set.add(tem_item)
87 | if tem_item not in i_cnt:
88 | i_cnt[tem_item] = 0
89 | i_cnt[tem_item] += 1
90 | i_list = list(i_set)
91 | if itm_criteria is not None:
92 | tem_i_list = list()
93 | for item in i_list:
94 | if itm_criteria(i_cnt[item]):
95 | tem_i_list.append(item)
96 | print('Filtering {new_num} / {old_num}'.format(new_num=len(tem_i_list), old_num=len(i_list)))
97 | i_list = tem_i_list
98 | i_num = len(i_list)
99 | i_mapp = dict()
100 | for i, item in enumerate(i_list):
101 | i_mapp[item] = i
102 | rows = []
103 | cols = []
104 | for uid, interaction in enumerate(interactions):
105 | for item in interaction:
106 | num_idx = item.index(' #')
107 | tem_item = item if not params['itmfusion'] else item[:num_idx]
108 | if tem_item not in i_mapp:
109 | continue
110 | iid = i_mapp[tem_item]
111 | rows.append(uid)
112 | cols.append(iid)
113 | return rows, cols, i_mapp, u_num, i_num
114 |
115 | def id_map(nodes):
116 | uniq_nodes = list(set(nodes))
117 | dic = dict()
118 | for i, node in enumerate(uniq_nodes):
119 | dic[node] = i
120 | return dic
121 |
122 | def k_core(rows, cols, i_mapp, i_num, k):
123 | edge_list = list(map(lambda idx: (rows[idx] + i_num, cols[idx]), range(len(rows))))
124 | G = nx.Graph(edge_list)
125 | edge_list = list(nx.k_core(G, k=k).edges())
126 | rows = [None] * len(edge_list)
127 | cols = [None] * len(edge_list)
128 | for i, edge in enumerate(edge_list):
129 | if edge[0] < i_num:
130 | rows[i] = edge[1] - i_num
131 | cols[i] = edge[0]
132 | else:
133 | rows[i] = edge[0] - i_num
134 | cols[i] = edge[1]
135 | row_map = id_map(rows)
136 | col_map = id_map(cols)
137 | new_rows = list(map(lambda x: row_map[x], rows))
138 | new_cols = list(map(lambda x: col_map[x], cols))
139 | new_i_mapp = dict()
140 | for key in i_mapp:
141 | tem_item = i_mapp[key]
142 | if tem_item not in col_map:
143 | continue
144 | new_i_mapp[key] = col_map[tem_item]
145 | return new_rows, new_cols, new_i_mapp, len(row_map), len(col_map)
146 |
147 | def make_mat(rows, cols, st, ed, u_num, i_num, perm, decrease=False):
148 | rows = np.array(rows)[perm]
149 | cols = np.array(cols)[perm]
150 | rows = rows[st: ed]
151 | cols = cols[st: ed]
152 | if decrease:
153 | rows = rows[:len(rows)//3]
154 | cols = cols[:len(cols)//3]
155 | vals = np.ones_like(rows)
156 | return coo_matrix((vals, (rows, cols)), shape=[u_num, i_num])
157 |
158 | def data_split(rows, cols, u_num, i_num):
159 | leng = len(rows)
160 | perm = np.random.permutation(leng)
161 | trn_split = int(leng * 0.7)
162 | val_split = int(leng * 0.75)
163 | trn_mat = make_mat(rows, cols, 0, trn_split, u_num, i_num, perm)
164 | val_mat = make_mat(rows, cols, trn_split, val_split, u_num, i_num, perm)
165 | tst_mat = make_mat(rows, cols, val_split, leng, u_num, i_num, perm, decrease=True)
166 | return trn_mat, val_mat, tst_mat
167 |
168 | interactions = fuse_bases()
169 | rows, cols, i_mapp, u_num, i_num = make_id_map(interactions)#, lambda x: x>20)
170 | if params['kcore'] != 1:
171 | rows, cols, i_mapp, u_num, i_num = k_core(rows, cols, i_mapp, i_num, params['kcore'])
172 | print('U NUM', u_num, 'I Num', i_num, 'E Num', len(rows))
173 | with open(file_root + 'res/iter-{gen_iter}_imap.pkl'.format(gen_iter=args.gen_iter), 'wb+') as fs:
174 | pickle.dump(i_mapp, fs)
175 | trn_mat, val_mat, tst_mat = data_split(rows, cols, u_num, i_num)
176 | with open(file_root + 'res/iter-{gen_iter}_train.pkl'.format(gen_iter=args.gen_iter), 'wb+') as fs:
177 | pickle.dump(trn_mat, fs)
178 | with open(file_root + 'res/iter-{gen_iter}_valid.pkl'.format(gen_iter=args.gen_iter), 'wb+') as fs:
179 | pickle.dump(val_mat, fs)
180 | with open(file_root + 'res/iter-{gen_iter}_test.pkl'.format(gen_iter=args.gen_iter), 'wb+') as fs:
181 | pickle.dump(tst_mat, fs)
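
For reference, a minimal consumer sketch (not part of the script above): the split matrices are plain scipy sparse matrices; the path below assumes the default data_name and --gen_iter 0.

    import pickle

    root = 'gen_results/datasets/gen_data_ecommerce/res/'
    with open(root + 'iter-0_train.pkl', 'rb') as fs:
        trn_mat = pickle.load(fs)  # scipy.sparse.coo_matrix of shape [u_num, i_num]
    print('train shape:', trn_mat.shape, 'nnz:', trn_mat.nnz)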
--------------------------------------------------------------------------------
/imgs/article cover.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/imgs/article cover.jpg
--------------------------------------------------------------------------------
/imgs/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/imgs/framework.png
--------------------------------------------------------------------------------
/imgs/graph tokenizer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/imgs/graph tokenizer.png
--------------------------------------------------------------------------------
/imgs/impact of datasets.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/imgs/impact of datasets.png
--------------------------------------------------------------------------------
/imgs/intro.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/imgs/intro.png
--------------------------------------------------------------------------------
/imgs/opengraph_article_cover_full.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/imgs/opengraph_article_cover_full.png
--------------------------------------------------------------------------------
/imgs/performance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/imgs/performance.png
--------------------------------------------------------------------------------
/imgs/prompt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/imgs/prompt.png
--------------------------------------------------------------------------------
/imgs/sampling.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUDS/OpenGraph/152b8b6439a281e46bd569a92c69f0c59f98a902/imgs/sampling.png
--------------------------------------------------------------------------------
/link_prediction/Utils/TimeLogger.py:
--------------------------------------------------------------------------------
1 | import datetime
2 |
3 | logmsg = ''
4 | timemark = dict()
5 | saveDefault = False
6 | def log(msg, save=None, oneline=False):
7 | global logmsg
8 | global saveDefault
9 | time = datetime.datetime.now()
10 | tem = '%s: %s' % (time, msg)
11 | if save != None:
12 | if save:
13 | logmsg += tem + '\n'
14 | elif saveDefault:
15 | logmsg += tem + '\n'
16 | if oneline:
17 | print(tem, end='\r')
18 | else:
19 | print(tem)
20 |
21 | def marktime(marker):
22 | global timemark
23 | timemark[marker] = datetime.datetime.now()
24 |
25 |
26 | if __name__ == '__main__':
27 | log('')
--------------------------------------------------------------------------------
/link_prediction/data_handler.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import numpy as np
3 | from scipy.sparse import csr_matrix, coo_matrix, dok_matrix
4 | from params import args
5 | import scipy.sparse as sp
6 | from Utils.TimeLogger import log
7 | import torch as t
8 | import torch.utils.data as data
9 | import torch_geometric.transforms as T
10 | from model import InitialProjector
11 | import os
12 |
13 | class MultiDataHandler:
14 | def __init__(self, trn_datasets, tst_datasets):
15 | all_datasets = list(set(trn_datasets + tst_datasets))
16 | self.trn_handlers = []
17 | self.tst_handlers = []
18 | for data_name in all_datasets:
19 | trn_flag = data_name in trn_datasets
20 | tst_flag = data_name in tst_datasets
21 | handler = DataHandler(data_name, trn_flag, tst_flag)
22 | if trn_flag:
23 | self.trn_handlers.append(handler)
24 | if tst_flag:
25 | self.tst_handlers.append(handler)
26 | self.make_joint_trn_loader()
27 |
28 | def make_joint_trn_loader(self):
29 | trn_data = TrnData(self.trn_handlers)
30 | self.trn_loader = data.DataLoader(trn_data, batch_size=1, shuffle=True, num_workers=0)
31 |
32 | def remake_initial_projections(self):
33 | for i in range(len(self.trn_handlers)):
34 | self.remake_one_initial_projection(i)
35 |
36 | def remake_one_initial_projection(self, idx):
37 | trn_handler = self.trn_handlers[idx]
38 | trn_handler.initial_projector = InitialProjector(trn_handler.asym_adj)
39 |
40 | class DataHandler:
41 | def __init__(self, data_name, trn_flag, tst_flag):
42 | self.data_name = data_name
43 | self.trn_flag = trn_flag
44 | self.tst_flag = tst_flag
45 | self.get_data_files()
46 | self.load_data()
47 |
48 | def get_data_files(self):
49 | predir = os.path.join(args.data_dir, self.data_name)
50 | self.trnfile = os.path.join(predir, 'trn_mat.pkl')
51 | self.tstfile = os.path.join(predir, 'tst_mat.pkl')
52 | self.valfile = os.path.join(predir, 'val_mat.pkl')
53 | if not os.path.exists(self.valfile):
54 | self.valfile = self.tstfile
55 |
56 | def load_one_file(self, filename):
57 | with open(filename, 'rb') as fs:
58 | ret = (pickle.load(fs)).astype(np.float32)
59 | if type(ret) != coo_matrix:
60 | ret = sp.coo_matrix(ret)
61 | return ret
62 |
63 | def normalize_adj(self, mat):
64 | degree = np.array(mat.sum(axis=-1))
65 | d_inv_sqrt = np.reshape(np.power(degree, -0.5), [-1])
66 | d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0
67 | d_inv_sqrt_mat = sp.diags(d_inv_sqrt)
68 | if mat.shape[0] == mat.shape[1]:
69 | return mat.dot(d_inv_sqrt_mat).transpose().dot(d_inv_sqrt_mat).tocoo()
70 | else:
71 | tem = d_inv_sqrt_mat.dot(mat)
72 | col_degree = np.array(mat.sum(axis=0))
73 | d_inv_sqrt = np.reshape(np.power(col_degree, -0.5), [-1])
74 | d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0
75 | d_inv_sqrt_mat = sp.diags(d_inv_sqrt)
76 | return tem.dot(d_inv_sqrt_mat).tocoo()
77 |
78 | def unique_numpy(self, row, col):
79 | hash_vals = row * self.node_num + col
80 | hash_vals = np.unique(hash_vals).astype(np.int64)
81 | col = hash_vals % self.node_num
82 | row = (hash_vals - col).astype(np.int64) // self.node_num
83 | return row, col
84 |
85 | def make_torch_adj(self, mat):
86 | if mat.shape[0] == mat.shape[1]:
87 | # to symmetric
88 | if self.data_name in ['ddi']:
89 | _row = mat.row
90 | _col = mat.col
91 | row = np.concatenate([_row, _col]).astype(np.int64)
92 | col = np.concatenate([_col, _row]).astype(np.int64)
93 | # row, col = self.unique_numpy(row, col)
94 | data = mat.data
95 | data = np.concatenate([data, data]).astype(np.float32)
96 | else:
97 | row, col = mat.row, mat.col
98 | data = mat.data
99 | # data = np.ones_like(row)
100 | mat = coo_matrix((data, (row, col)), mat.shape)
101 | if args.selfloop == 1:
102 | mat = (mat + sp.eye(mat.shape[0])) * 1.0
103 | normed_asym_mat = self.normalize_adj(mat)
104 | row = t.from_numpy(normed_asym_mat.row).long()
105 | col = t.from_numpy(normed_asym_mat.col).long()
106 | idxs = t.stack([row, col], dim=0)
107 | vals = t.from_numpy(normed_asym_mat.data).float()
108 | shape = t.Size(normed_asym_mat.shape)
109 | asym_adj = t.sparse.FloatTensor(idxs, vals, shape)
110 | if mat.shape[0] == mat.shape[1]:
111 | return asym_adj, asym_adj
112 | else:
113 | # make ui adj
114 | a = sp.csr_matrix((self.user_num, self.user_num))
115 | b = sp.csr_matrix((self.item_num, self.item_num))
116 | mat = sp.vstack([sp.hstack([a, mat]), sp.hstack([mat.transpose(), b])])
117 | mat = (mat != 0) * 1.0
118 | if args.selfloop == 1:
119 | mat = (mat + sp.eye(mat.shape[0])) * 1.0
120 | mat = self.normalize_adj(mat)
121 |
122 | # make cuda tensor
123 | idxs = t.from_numpy(np.vstack([mat.row, mat.col]).astype(np.int64))
124 | vals = t.from_numpy(mat.data.astype(np.float32))
125 | shape = t.Size(mat.shape)
126 | return t.sparse.FloatTensor(idxs, vals, shape), asym_adj
127 |
128 | def load_data(self):
129 | trn_mat = self.load_one_file(self.trnfile)
130 | # if self.trn_flag:
131 | self.trn_mat = trn_mat
132 | if trn_mat.shape[0] != trn_mat.shape[1]:
133 | self.user_num, self.item_num = trn_mat.shape
134 | self.node_num = self.user_num + self.item_num
135 | print('Dataset: {data_name}, User num: {user_num}, Item num: {item_num}, Node num: {node_num}, Edge num: {edge_num}'.format(data_name=self.data_name, user_num=self.user_num, item_num=self.item_num, node_num=self.node_num, edge_num=trn_mat.nnz))
136 | else:
137 | self.node_num = trn_mat.shape[0]
138 | print('Dataset: {data_name}, Node num: {node_num}, Edge num: {edge_num}'.format(data_name=self.data_name, node_num=self.node_num, edge_num=trn_mat.nnz))
139 | self.torch_adj, self.asym_adj = self.make_torch_adj(trn_mat)
140 | if args.cache_proj:
141 | self.asym_adj = self.asym_adj.to(args.devices[0])
142 | if args.cache_adj:
143 | self.torch_adj = self.torch_adj.to(args.devices[0])
144 |
145 | self.initial_projector = InitialProjector(self.asym_adj)
146 |
147 | if self.tst_flag:
148 | val_mat = self.load_one_file(self.valfile)
149 | val_data = TstData(val_mat, trn_mat)
150 | self.val_loader = data.DataLoader(val_data, batch_size=args.tst_batch, shuffle=False, num_workers=0)
151 | tst_mat = self.load_one_file(self.tstfile)
152 | tst_data = TstData(tst_mat, trn_mat)
153 | self.tst_loader = data.DataLoader(tst_data, batch_size=args.tst_batch, shuffle=False, num_workers=0)
154 |
155 | class TstData(data.Dataset):
156 | def __init__(self, coomat, trn_mat):
157 | self.csrmat = (trn_mat.tocsr() != 0) * 1.0
158 |
159 | tstLocs = [None] * coomat.shape[0]
160 | tst_nodes = set()
161 | for i in range(len(coomat.data)):
162 | row = coomat.row[i]
163 | col = coomat.col[i]
164 | if tstLocs[row] is None:
165 | tstLocs[row] = list()
166 | tstLocs[row].append(col)
167 | tst_nodes.add(row)
168 | tst_nodes = np.array(list(tst_nodes))
169 | self.tst_nodes = tst_nodes
170 | self.tstLocs = tstLocs
171 |
172 | def __len__(self):
173 | return len(self.tst_nodes)
174 |
175 | def __getitem__(self, idx):
176 | return self.tst_nodes[idx]
177 |
178 | class TrnData(data.Dataset):
179 | def __init__(self, trn_handlers):
180 | self.dataset_num = len(trn_handlers)
181 | self.trn_handlers = trn_handlers
182 | self.ancs_list = [None] * self.dataset_num
183 | self.poss_list = [None] * self.dataset_num
184 | self.negs_list = [None] * self.dataset_num
185 | self.edge_nums = [None] * self.dataset_num
186 | self.sample_nums = [None] * self.dataset_num
187 | for i, handler in enumerate(self.trn_handlers):
188 | trn_mat = handler.trn_mat
189 | ancs = np.array(trn_mat.row)
190 | poss = np.array(trn_mat.col)
191 | self.ancs_list[i] = ancs
192 | self.poss_list[i] = poss
193 | self.edge_nums[i] = len(ancs)
194 | self.sample_nums[i] = self.edge_nums[i] // args.batch + (1 if self.edge_nums[i] % args.batch != 0 else 0)
195 | self.total_sample_num = sum(self.sample_nums)
196 | self.samples = [None] * self.total_sample_num
197 |
198 | def data_shuffling(self):
199 | sample_idx = 0
200 | for i in range(self.dataset_num):
201 | edge_num = self.edge_nums[i]
202 | perms = np.random.permutation(edge_num)
203 | handler = self.trn_handlers[i]
204 | asym_flag = handler.trn_mat.shape[0] != handler.trn_mat.shape[1]
205 | cand_num = handler.item_num if asym_flag else handler.node_num
206 | self.negs_list[i] = self.neg_sampling(self.ancs_list[i], handler.trn_mat.todok(), cand_num)
207 | # self.negs_list[i] = np.random.randint(cand_num, size=edge_num)
208 | for j in range(self.sample_nums[i]):
209 | st_idx = j * args.batch
210 | ed_idx = min((j + 1) * args.batch, edge_num)
211 | pick_idxs = perms[st_idx: ed_idx]
212 | ancs = self.ancs_list[i][pick_idxs]
213 | poss = self.poss_list[i][pick_idxs]
214 | negs = self.negs_list[i][pick_idxs]
215 | if asym_flag:
216 | poss += handler.user_num
217 | negs += handler.user_num
218 | self.samples[sample_idx] = (ancs, poss, negs, i)
219 | sample_idx += 1
220 | assert sample_idx == self.total_sample_num
221 |
222 | def neg_sampling(self, ancs, dokmat, cand_num):
223 | negs = np.zeros_like(ancs)
224 | for i in range(len(ancs)):
225 | u = ancs[i]
226 | while True:
227 | i_neg = np.random.randint(cand_num)
228 | if (u, i_neg) not in dokmat:
229 | break
230 | negs[i] = i_neg
231 | return negs
232 |
233 | def __len__(self):
234 | return self.total_sample_num
235 |
236 | def __getitem__(self, idx):
237 | ancs, poss, negs, adj_id = self.samples[idx]
238 | return ancs, poss, negs, adj_id
239 |
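
For reference, a toy check (not part of the file above) of the symmetric normalization used in DataHandler.normalize_adj for square inputs, which computes D^{-1/2} A^T D^{-1/2}; the matrix values are illustrative.

    import numpy as np
    import scipy.sparse as sp

    A = sp.coo_matrix(np.array([[0., 1., 1.],
                                [1., 0., 0.],
                                [1., 0., 0.]]))
    deg = np.array(A.sum(axis=-1)).reshape(-1)
    d_inv_sqrt = sp.diags(np.power(deg, -0.5))
    normed = A.dot(d_inv_sqrt).transpose().dot(d_inv_sqrt).tocoo()
    print(normed.toarray())  # off-diagonal entries equal 1/sqrt(2) ~= 0.707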
--------------------------------------------------------------------------------
/link_prediction/main.py:
--------------------------------------------------------------------------------
1 | import torch as t
2 | from torch import nn
3 | import Utils.TimeLogger as logger
4 | from Utils.TimeLogger import log
5 | from params import args
6 | from model import OpenGraph, ALRS
7 | from data_handler import DataHandler, MultiDataHandler
8 | import numpy as np
9 | import pickle
10 | import os
11 | import setproctitle
12 | import time
13 |
14 | class Exp:
15 | def __init__(self, multi_handler):
16 | self.multi_handler = multi_handler
17 | self.metrics = dict()
18 | trn_mets = ['Loss', 'preLoss']
19 | tst_mets = ['Recall', 'NDCG']
20 | mets = trn_mets + tst_mets
21 | for met in mets:
22 | if met in trn_mets:
23 | self.metrics['Train' + met] = list()
24 | else:
25 | for handler in self.multi_handler.tst_handlers:
26 | self.metrics['Test' + handler.data_name + met] = list()
27 |
28 | def make_print(self, name, ep, reses, save, data_name=None):
29 | if data_name is None:
30 | ret = 'Epoch %d/%d, %s: ' % (ep, args.epoch, name)
31 | else:
32 | ret = 'Epoch %d/%d, %s %s: ' % (ep, args.epoch, data_name, name)
33 | for metric in reses:
34 | val = reses[metric]
35 | ret += '%s = %.4f, ' % (metric, val)
36 | tem = name + metric if data_name is None else name + data_name + metric
37 | if save and tem in self.metrics:
38 | self.metrics[tem].append(val)
39 | ret = ret[:-2] + ' '
40 | return ret
41 |
42 | def run(self):
43 | self.prepare_model()
44 | log('Model Prepared')
45 | stloc = 0
46 | if args.load_model != None:
47 | self.load_model()
48 | stloc = len(self.metrics['TrainLoss']) * args.tst_epoch - (args.tst_epoch - 1)
49 | for ep in range(stloc, args.epoch):
50 | tst_flag = (ep % args.tst_epoch == 0)
51 | reses = self.train_epoch()
52 | log(self.make_print('Train', ep, reses, tst_flag))
53 | if ep % 1 == 0:
54 | self.multi_handler.remake_initial_projections()
55 | if tst_flag:
56 | for handler in self.multi_handler.tst_handlers:
57 | reses = self.test_epoch(handler.val_loader, handler)
58 | # Note that this is the validation performance
59 | log(self.make_print('Test', ep, reses, tst_flag, handler.data_name))
60 | self.save_history()
61 | print()
62 |
63 | for handler in self.multi_handler.tst_handlers:
64 | res_summary = dict()
65 | times = 10
66 | st = time.time()
67 | for i in range(times):
68 | reses = self.test_epoch(handler.tst_loader, handler)
69 | log(self.make_print('Test', args.epoch, reses, False, handler.data_name))
70 | self.add_res_to_summary(res_summary, reses)
71 | self.multi_handler.remake_initial_projections()
72 | for key in res_summary:
73 | res_summary[key] /= times
74 | log(self.make_print('AVG', args.epoch, res_summary, False, handler.data_name))
75 | print(time.time() - st)
76 | self.save_history()
77 |
78 | def add_res_to_summary(self, summary, res):
79 | for key in res:
80 | if key not in summary:
81 | summary[key] = 0
82 | summary[key] += res[key]
83 |
84 | def print_model_size(self):
85 | total_params = 0
86 | trainable_params = 0
87 | non_trainable_params = 0
88 | for param in self.model.parameters():
89 | tem = np.prod(param.size())
90 | total_params += tem
91 | if param.requires_grad:
92 | trainable_params += tem
93 | else:
94 | non_trainable_params += tem
95 | print(f'Total params: {total_params/1e6}')
96 | print(f'Trainable params: {trainable_params/1e6}')
97 | print(f'Non-trainable params: {non_trainable_params/1e6}')
98 |
99 | def prepare_model(self):
100 | self.model = OpenGraph()
101 | t.cuda.empty_cache()
102 | self.opt = t.optim.Adam(self.model.parameters(), lr=args.lr, weight_decay=0)
103 | self.lr_scheduler = ALRS(self.opt)
104 | self.print_model_size()
105 |
106 | def train_epoch(self):
107 | self.model.train()
108 | trn_loader = self.multi_handler.trn_loader
109 | trn_loader.dataset.data_shuffling()
110 | ep_loss, ep_preloss, ep_regloss = 0, 0, 0
111 | steps = len(trn_loader)
112 | tot_samp_num = 0
113 | counter = [0] * len(self.multi_handler.trn_handlers)
114 | for i, batch_data in enumerate(trn_loader):
115 | if args.epoch_max_step > 0 and i >= args.epoch_max_step:
116 | break
117 | ancs, poss, negs, adj_idx = batch_data
118 | adj_idx = adj_idx[0]
119 | ancs = ancs[0].long()
120 | poss = poss[0].long()
121 | negs = negs[0].long()
122 | adj = self.multi_handler.trn_handlers[adj_idx].torch_adj
123 | if args.cache_adj == 0:
124 | adj = adj.to(args.devices[0])
125 | initial_projector = self.multi_handler.trn_handlers[adj_idx].initial_projector
126 | if args.cache_proj == 0:
127 | initial_projector = initial_projector.to(args.devices[0])
128 | loss, loss_dict = self.model.cal_loss((ancs, poss, negs), adj, initial_projector)
129 | self.opt.zero_grad()
130 | loss.backward()
131 | # nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10, norm_type=2)
132 | self.opt.step()
133 |
134 | sample_num = ancs.shape[0]
135 | tot_samp_num += sample_num
136 | ep_loss += loss.item() * sample_num
137 | ep_preloss += loss_dict['preloss'].item() * sample_num
138 | ep_regloss += loss_dict['regloss'].item()
139 | log('Step %d/%d: loss = %.3f, pre = %.3f, reg = %.3f, pos = %.3f, neg = %.3f ' % (i, steps, loss, loss_dict['preloss'], loss_dict['regloss'], loss_dict['posloss'], loss_dict['negloss']), save=False, oneline=True)
140 |
141 | counter[adj_idx] += 1
142 | if args.proj_trn_steps > 0 and counter[adj_idx] >= args.proj_trn_steps:
143 | counter[adj_idx] = 0
144 |                 dice = np.random.uniform()
145 |                 if dice < 999:  # always true, so the self-initialization branch below is effectively disabled
146 |                     self.multi_handler.remake_one_initial_projection(adj_idx)
147 |                 else:
148 |                     self.multi_handler.make_one_self_initialization(self.model, adj_idx)
149 | ret = dict()
150 | ret['Loss'] = ep_loss / tot_samp_num
151 | ret['preLoss'] = ep_preloss / tot_samp_num
152 | ret['regLoss'] = ep_regloss / steps
153 | t.cuda.empty_cache()
154 | self.lr_scheduler.step(ret['Loss'])
155 | return ret
156 |
157 | def test_epoch(self, tst_loader, tst_handler):
158 | with t.no_grad():
159 | self.model.eval()
160 | ep_recall, ep_ndcg = 0, 0
161 | ep_tstnum = len(tst_loader.dataset)
162 | steps = max(ep_tstnum // args.tst_batch, 1)
163 | for i, batch_data in enumerate(tst_loader):
164 | usrs = batch_data
165 | numpy_usrs = usrs.numpy()
166 | usrs = usrs.long().to(args.devices[1])
167 | trn_masks = tst_loader.dataset.csrmat[numpy_usrs].tocoo()
168 | cand_size = trn_masks.shape[1]
169 | trn_masks = t.from_numpy(np.stack([trn_masks.row, trn_masks.col], axis=0)).long().cuda()
170 | adj = tst_handler.torch_adj
171 | if args.cache_adj == 0:
172 | adj = adj.to(args.devices[0])
173 | initial_projector = tst_handler.initial_projector#.cuda()
174 | if args.cache_proj == 0:
175 | initial_projector = initial_projector.to(args.devices[0])
176 | all_preds = self.model.pred_for_test((usrs, trn_masks), adj, initial_projector, cand_size, rerun_embed=False if i!=0 else True)
177 | _, top_locs = t.topk(all_preds, args.topk)
178 | top_locs = top_locs.cpu().numpy()
179 | recall, ndcg = self.calc_recall_ndcg(top_locs, tst_loader.dataset.tstLocs, usrs)
180 | ep_recall += recall
181 | ep_ndcg += ndcg
182 | log('Steps %d/%d: recall = %.2f, ndcg = %.2f ' % (i, steps, recall, ndcg), save=False, oneline=True)
183 | # t.cuda.empty_cache()
184 | ret = dict()
185 | ret['Recall'] = ep_recall / ep_tstnum
186 | ret['NDCG'] = ep_ndcg / ep_tstnum
187 | t.cuda.empty_cache()
188 | return ret
189 |
190 | def calc_recall_ndcg(self, topLocs, tstLocs, batIds):
191 | assert topLocs.shape[0] == len(batIds)
192 | allRecall = allNdcg = 0
193 | for i in range(len(batIds)):
194 | temTopLocs = list(topLocs[i])
195 | temTstLocs = tstLocs[batIds[i]]
196 | tstNum = len(temTstLocs)
197 | maxDcg = np.sum([np.reciprocal(np.log2(loc + 2)) for loc in range(min(tstNum, args.topk))])
198 | recall = dcg = 0
199 | for val in temTstLocs:
200 | if val in temTopLocs:
201 | recall += 1
202 | dcg += np.reciprocal(np.log2(temTopLocs.index(val) + 2))
203 | recall = recall / tstNum
204 | ndcg = dcg / maxDcg
205 | allRecall += recall
206 | allNdcg += ndcg
207 | return allRecall, allNdcg
208 |
209 | def save_history(self):
210 | if args.epoch == 0:
211 | return
212 | with open('../History/' + args.save_path + '.his', 'wb') as fs:
213 | pickle.dump(self.metrics, fs)
214 |
215 | content = {
216 | 'model': self.model,
217 | }
218 | t.save(content, '../Models/' + args.save_path + '.mod')
219 | log('Model Saved: %s' % args.save_path)
220 |
221 | def load_model(self):
222 | ckp = t.load('../Models/' + args.load_model + '.mod')
223 | self.model = ckp['model']
224 | self.opt = t.optim.Adam(self.model.parameters(), lr=args.lr, weight_decay=0)
225 |
226 | with open('../History/' + args.load_model + '.his', 'rb') as fs:
227 | self.metrics = pickle.load(fs)
228 | log('Model Loaded')
229 |
230 | if __name__ == '__main__':
231 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
232 | if len(args.gpu.split(',')) > 1:
233 | args.devices = ['cuda:0', 'cuda:1']
234 | else:
235 | args.devices = ['cuda:0', 'cuda:0']
236 | args.devices = list(map(lambda x: t.device(x), args.devices))
237 | logger.saveDefault = True
238 | setproctitle.setproctitle('OpenGraph')
239 |
240 | log('Start')
241 | trn_datasets = ['gen1']
242 | tst_datasets = ['ml1m', 'ml10m', 'collab']
243 |
244 | # trn_datasets = ['gen2']
245 | # tst_datasets = ['ddi']
246 |
247 | # trn_datasets = ['gen0']
248 | # tst_datasets = ['amazon-book']
249 |
250 | if len(args.tstdata) != 0:
251 | tst_datasets = [args.tstdata]
252 | if len(args.trndata) != 0:
253 | trn_datasets = [args.trndata]
254 | trn_datasets = list(set(trn_datasets))
255 | tst_datasets = list(set(tst_datasets))
256 | multi_handler = MultiDataHandler(trn_datasets, tst_datasets)
257 | log('Load Data')
258 |
259 | exp = Exp(multi_handler)
260 | exp.run()
261 |
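
For reference, a standalone worked example (not part of the file above) of the Recall@K / NDCG@K computation in Exp.calc_recall_ndcg; the id lists are made up.

    import numpy as np

    topk = 3
    top_locs = [7, 2, 9]   # predicted top-3 candidate ids for one test node
    tst_locs = [2, 5]      # ground-truth test ids for that node
    max_dcg = sum(np.reciprocal(np.log2(r + 2)) for r in range(min(len(tst_locs), topk)))
    recall = sum(v in top_locs for v in tst_locs) / len(tst_locs)
    dcg = sum(np.reciprocal(np.log2(top_locs.index(v) + 2)) for v in tst_locs if v in top_locs)
    print(recall, dcg / max_dcg)  # 0.5, ~0.387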
--------------------------------------------------------------------------------
/link_prediction/model.py:
--------------------------------------------------------------------------------
1 | import torch as t
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from params import args
5 | import numpy as np
6 | from Utils.TimeLogger import log
7 | from torch.nn import MultiheadAttention
8 | from time import time
9 |
10 | init = nn.init.xavier_uniform_
11 | uniformInit = nn.init.uniform_
12 |
13 | class InitialProjector(nn.Module):
14 | def __init__(self, adj, input_is_embeds=False):
15 | super(InitialProjector, self).__init__()
16 |
17 | if input_is_embeds:
18 | projection = adj
19 | if args.cache_proj:
20 | projection = projection.to(args.devices[0])
21 | else:
22 | projection = projection.cpu()
23 | self.proj_embeds = nn.Parameter(projection)
24 | t.cuda.empty_cache()
25 | return
26 | if args.proj_method == 'uniform':
27 | self.proj_embeds = nn.Parameter(self.uniform_proj(adj))
28 | elif args.proj_method == 'lowrank_uniform':
29 | self.proj_embeds = nn.Parameter(self.lowrank_uniform_proj(adj))
30 | elif args.proj_method == 'svd':
31 | self.proj_embeds = nn.Parameter(self.svd_proj(adj))
32 | elif args.proj_method == 'both':
33 | self.proj_embeds = nn.Parameter(self.uniform_proj(adj) + self.svd_proj(adj))
34 | elif args.proj_method == 'id':
35 | self.proj_embeds = nn.Parameter(self.id_proj(adj))
36 | else:
37 | raise Exception('Unrecognized Initial Embedding')
38 | t.cuda.empty_cache()
39 |
40 | def uniform_proj(self, adj):
41 | node_num = adj.shape[0] if adj.shape[0] == adj.shape[1] else adj.shape[0] + adj.shape[1]
42 | projection = init(t.empty(node_num, args.latdim))
43 | if args.cache_proj:
44 | projection = projection.to(args.devices[0])
45 | return projection
46 |
47 | def id_proj(self, adj):
48 | node_num = adj.shape[0] if adj.shape[0] == adj.shape[1] else adj.shape[0] + adj.shape[1]
49 | return t.eye(node_num)
50 |
51 | def lowrank_uniform_proj(self, adj):
52 |         node_num = adj.shape[0] + adj.shape[1]  # assumes a rectangular (user-item) adjacency, unlike uniform_proj/id_proj which also handle square inputs
53 | rank = 16
54 | projection1 = init(t.empty(node_num, rank))
55 | projection2 = init(t.empty(rank, args.latdim))
56 | projection = projection1 @ projection2
57 | if args.cache_proj:
58 | projection = projection.to(args.devices[0])
59 | return projection
60 |
61 | def svd_proj(self, adj):
62 | if not args.cache_proj:
63 | adj = adj.to(args.devices[0])
64 | q = args.latdim
65 | if args.latdim > adj.shape[0] or args.latdim > adj.shape[1]:
66 | dim = min(adj.shape[0], adj.shape[1])
67 | svd_u, s, svd_v = t.svd_lowrank(adj, q=dim, niter=args.niter)
68 | svd_u = t.concat([svd_u, t.zeros([svd_u.shape[0], args.latdim-dim]).to(args.devices[0])], dim=1)
69 | svd_v = t.concat([svd_v, t.zeros([svd_v.shape[0], args.latdim-dim]).to(args.devices[0])], dim=1)
70 | s = t.concat([s, t.zeros(args.latdim-dim).to(args.devices[0])])
71 | else:
72 | svd_u, s, svd_v = t.svd_lowrank(adj, q=q, niter=args.niter)
73 | svd_u = svd_u @ t.diag(t.sqrt(s))
74 | svd_v = svd_v @ t.diag(t.sqrt(s))
75 | if adj.shape[0] != adj.shape[1]:
76 | projection = t.concat([svd_u, svd_v], dim=0)
77 | else:
78 | projection = svd_u + svd_v
79 | if not args.cache_proj:
80 | projection = projection.cpu()
81 | return projection
82 |
83 | def forward(self):
84 | return self.proj_embeds
85 |
86 | class TopoEncoder(nn.Module):
87 | def __init__(self):
88 | super(TopoEncoder, self).__init__()
89 |
90 | self.layer_norm = nn.LayerNorm(args.latdim, elementwise_affine=False)#, dtype=t.bfloat16)
91 |
92 |
93 | def forward(self, adj, embeds):
94 | embeds = self.layer_norm(embeds)
95 | embeds_list = []
96 | if args.gnn_layer == 0:
97 | embeds_list.append(embeds)
98 | for i in range(args.gnn_layer):
99 | embeds = t.spmm(adj, embeds)
100 | embeds_list.append(embeds)
101 | embeds = sum(embeds_list)
102 | # embeds = t.concat([embeds_list[-1][:user_num], embeds_list[-2][user_num:]], dim=0)
103 | embeds = embeds#.to(t.bfloat16)
104 | return embeds
105 |
106 | class GraphTransformer(nn.Module):
107 | def __init__(self):
108 | super(GraphTransformer, self).__init__()
109 | self.gt_layers = nn.Sequential(*[GTLayer() for i in range(args.gt_layer)])
110 |
111 | def forward(self, embeds):
112 | for i, layer in enumerate(self.gt_layers):
113 | embeds = layer(embeds) / args.scale_layer
114 | return embeds
115 |
116 | class GTLayer(nn.Module):
117 | def __init__(self):
118 | super(GTLayer, self).__init__()
119 | self.multi_head_attention = MultiheadAttention(args.latdim, args.head, dropout=0.1, bias=False)#, dtype=t.bfloat16)
120 | self.dense_layers = nn.Sequential(*[FeedForwardLayer(args.latdim, args.latdim, bias=True, act=args.act) for _ in range(2)])# bias=False
121 | self.layer_norm1 = nn.LayerNorm(args.latdim, elementwise_affine=True)#, dtype=t.bfloat16)
122 | self.layer_norm2 = nn.LayerNorm(args.latdim, elementwise_affine=True)#, dtype=t.bfloat16)
123 | self.fc_dropout = nn.Dropout(p=args.drop_rate)
124 |
125 |     def _attention(self, anchor_embeds, embeds):  # legacy manual attention, not called by forward(); it references self.Q/K/V/att_linear, which are not defined in this class
126 | q_embeds = t.einsum('ne,ehd->nhd', anchor_embeds, self.Q)
127 | k_embeds = t.einsum('ne,ehd->nhd', embeds, self.K)
128 | v_embeds = t.einsum('ne,ehd->nhd', embeds, self.V)
129 | att = t.einsum('khd,nhd->knh', q_embeds, k_embeds) / np.sqrt(args.latdim / args.head)
130 | att = t.softmax(att, dim=1)
131 | res = t.einsum('knh,nhd->khd', att, v_embeds).reshape([-1, args.latdim])
132 | res = self.att_linear(res)
133 | return res
134 |
135 | def _pick_anchors(self, embeds):
136 | perm = t.randperm(embeds.shape[0])
137 | anchors = perm[:args.anchor]
138 | return embeds[anchors]
139 |
140 | def forward(self, embeds):
141 | anchor_embeds = self._pick_anchors(embeds)
142 | _anchor_embeds, _ = self.multi_head_attention(anchor_embeds, embeds, embeds)
143 | anchor_embeds = _anchor_embeds + anchor_embeds
144 | _embeds, _ = self.multi_head_attention(embeds, anchor_embeds, anchor_embeds, need_weights=False)
145 | embeds = self.layer_norm1(_embeds + embeds)
146 | _embeds = self.fc_dropout(self.dense_layers(embeds))
147 | embeds = (self.layer_norm2(_embeds + embeds))
148 | return embeds
149 |
150 | class FeedForwardLayer(nn.Module):
151 | def __init__(self, in_feat, out_feat, bias=True, act=None):
152 | super(FeedForwardLayer, self).__init__()
153 | self.linear = nn.Linear(in_feat, out_feat, bias=bias)#, dtype=t.bfloat16)
154 | if act == 'identity' or act is None:
155 | self.act = None
156 | elif act == 'leaky':
157 | self.act = nn.LeakyReLU(negative_slope=args.leaky)
158 | elif act == 'relu':
159 | self.act = nn.ReLU()
160 | elif act == 'relu6':
161 | self.act = nn.ReLU6()
162 | else:
163 | raise Exception('Error')
164 |
165 | def forward(self, embeds):
166 | if self.act is None:
167 | return self.linear(embeds)
168 | return self.act(self.linear(embeds))
169 |
170 | class Masker(nn.Module):
171 | def __init__(self):
172 | super(Masker, self).__init__()
173 |
174 | def forward(self, adj, edges):
175 | if args.mask_method is None or args.mask_method == 'none':
176 | return adj
177 | elif args.mask_method == 'trn':
178 | node_num = adj.shape[0] + adj.shape[1]
179 | rows = adj._indices()[0, :]
180 | cols = adj._indices()[1, :]
181 | pck_rows, pck_cols = edges
182 |
183 | hashvals = rows * node_num + cols
184 | pck_hashvals1 = pck_rows * node_num + pck_cols
185 | pck_hashvals2 = pck_cols * node_num + pck_rows
186 | pck_hashvals = t.concat([pck_hashvals1, pck_hashvals2])
187 |
188 | if args.mask_alg == 'cross':
189 | masked_hashvals = self._mask_by_cross(hashvals, pck_hashvals)
190 | elif args.mask_alg == 'linear':
191 | masked_hashvals = self._mask_by_linear(hashvals, pck_hashvals)
192 |
193 | cols = masked_hashvals % node_num
194 | rows = t.div((masked_hashvals - cols).long(), node_num, rounding_mode='trunc').long()
195 |
196 | adj = t.sparse.FloatTensor(t.stack([rows, cols], dim=0), t.ones_like(rows, dtype=t.float32).to(args.devices[0]), adj.shape)
197 | return self._normalize_adj(adj)
198 | elif args.mask_method == 'random':
199 | return self._random_mask_edge(adj)
200 |
201 | def _mask_by_cross(self, hashvals, pck_hashvals):
202 | for i in range(args.batch * 2 // args.mask_bat):
203 | bat_pck_hashvals = pck_hashvals[i * args.mask_bat: (i+1) * args.mask_bat]
204 | idct = (hashvals.view([-1, 1]) - bat_pck_hashvals.view([1, -1]) == 0).sum(-1).bool()
205 | hashvals = hashvals[t.logical_not(idct)]
206 | return hashvals
207 |
208 | def _mask_by_linear(self, hashvals, pck_hashvals):
209 | hashvals = t.unique(hashvals)
210 | pck_hashvals = t.unique(pck_hashvals)
211 | hashvals = t.concat([hashvals, pck_hashvals])
212 | hashvals, counts = t.unique(hashvals, return_counts=True)
213 | hashvals = hashvals[counts==1]
214 | return hashvals
215 |
216 | def _random_mask_edge(self, adj):
217 | if args.random_mask_rate == 0.0:
218 | return adj
219 | vals = adj._values()
220 | idxs = adj._indices()
221 | edgeNum = vals.size()
222 | mask = ((t.rand(edgeNum) + 1.0 - args.random_mask_rate).floor()).type(t.bool)
223 | newIdxs = idxs[:, mask]
224 | newVals = t.ones(newIdxs.shape[1]).to(args.devices[0]).float()
225 | return self._normalize_adj(t.sparse.FloatTensor(newIdxs, newVals, adj.shape))
226 |
227 | def _normalize_adj(self, adj):
228 | row_degree = t.pow(t.sparse.sum(adj, dim=1).to_dense(), 0.5)
229 | col_degree = t.pow(t.sparse.sum(adj, dim=0).to_dense(), 0.5)
230 | newRows, newCols = adj._indices()[0, :], adj._indices()[1, :]
231 | rowNorm, colNorm = row_degree[newRows], col_degree[newCols]
232 | newVals = adj._values() / rowNorm / colNorm
233 | return t.sparse.FloatTensor(adj._indices(), newVals, adj.shape)
234 |
235 | class OpenGraph(nn.Module):
236 | def __init__(self):
237 | super(OpenGraph, self).__init__()
238 | self.topoEncoder = TopoEncoder().to(args.devices[0])
239 | self.graphTransformer = GraphTransformer().to(args.devices[1])
240 | self.masker = Masker().to(args.devices[0])
241 |
242 | def forward(self, adj, initial_projector, user_num):
243 |         topo_embeds = self.topoEncoder(adj, initial_projector()).to(args.devices[1])  # TopoEncoder.forward takes (adj, embeds); user_num is not used
244 | final_embeds = self.graphTransformer(topo_embeds)
245 | return final_embeds
246 |
247 | def pred_norm(self, pos_preds, neg_preds):
248 | pos_preds_num = pos_preds.shape[0]
249 | neg_preds_shape = neg_preds.shape
250 | preds = t.concat([pos_preds, neg_preds.view(-1)])
251 | preds = preds - preds.max()
252 | pos_preds = preds[:pos_preds_num]
253 | neg_preds = preds[pos_preds_num:].view(neg_preds_shape)
254 | return pos_preds, neg_preds
255 |
256 | def cal_loss(self, batch_data, adj, initial_projector):
257 | ancs, poss, negs = batch_data
258 | with t.no_grad():
259 | masked_adj = self.masker(adj, (ancs.to(args.devices[0]), (poss.to(args.devices[0]))))
260 | initial_embeds = initial_projector()
261 | topo_embeds = self.topoEncoder(masked_adj, initial_embeds).to(args.devices[1])
262 | ancs, poss, negs = ancs.to(args.devices[1]), poss.to(args.devices[1]), negs.to(args.devices[1])
263 | input_seq = t.concat([ancs, poss, negs])
264 | input_seq = topo_embeds[input_seq]
265 | final_embeds = self.graphTransformer(input_seq)
266 | anc_embeds, pos_embeds, neg_embeds = t.split(final_embeds[:ancs.shape[0] * 3], [ancs.shape[0]] * 3)
267 | # anc_embeds, pos_embeds, neg_embeds = final_embeds[ancs], final_embeds[poss], final_embeds[negs]
268 | if final_embeds.isinf().any() or final_embeds.isnan().any():
269 | raise Exception('Final embedding fails')
270 |
271 | pos_preds, neg_preds = self.pred_norm((anc_embeds * pos_embeds).sum(-1), anc_embeds @ neg_embeds.T)
272 | if pos_preds.isinf().any() or pos_preds.isnan().any() or neg_preds.isinf().any() or neg_preds.isnan().any():
273 | raise Exception('Preds fails')
274 | pos_loss = pos_preds
275 | neg_loss = (neg_preds.exp().sum(-1) + pos_preds.exp() + 1e-8).log()
276 | pre_loss = -(pos_loss - neg_loss).mean()
277 |
278 | if t.isinf(pre_loss).any() or t.isnan(pre_loss).any():
279 | raise Exception('NaN or Inf')
280 |
281 | reg_loss = sum(list(map(lambda W: W.norm(2).square() * args.reg, self.parameters())))
282 | loss_dict = {'preloss': pre_loss, 'regloss': reg_loss, 'posloss': pos_loss.mean(), 'negloss': neg_loss.mean()}
283 | return pre_loss + reg_loss, loss_dict
284 |
285 | def pred_for_test(self, batch_data, adj, initial_projector, cand_size, rerun_embed=True):
286 | ancs, trn_mask = batch_data
287 | if rerun_embed:
288 | final_embeds = self.graphTransformer(self.topoEncoder(adj, initial_projector()).to(args.devices[1]))
289 | self.final_embeds = final_embeds
290 | final_embeds = self.final_embeds
291 | anc_embeds = final_embeds[ancs]
292 | cand_embeds = final_embeds[-cand_size:]
293 |
294 | mask_mat = t.sparse.FloatTensor(trn_mask, t.ones(trn_mask.shape[1]).to(args.devices[1]), t.Size([ancs.shape[0], cand_size]))  # keep mask values on the same device as final_embeds
295 | dense_mat = mask_mat.to_dense()
296 |
297 | all_preds = anc_embeds @ cand_embeds.T * (1 - dense_mat) - dense_mat * 1e8
298 | return all_preds
299 |
300 | class ALRS:
301 | def __init__(self, optimizer, loss_threshold=0.01, loss_ratio_threshold=0.01, decay_rate=0.97):
302 | self.optimizer = optimizer
303 | self.loss_threshold = loss_threshold
304 | self.decay_rate = decay_rate
305 | self.loss_ratio_threshold = loss_ratio_threshold
306 | self.last_loss = 1e9
307 |
308 | def step(self, loss):
309 | delta = self.last_loss - loss
310 | if delta < self.loss_threshold and delta / self.last_loss < self.loss_ratio_threshold:
311 | for group in self.optimizer.param_groups:
312 | group['lr'] *= self.decay_rate
313 | self.last_loss = loss
314 |
--------------------------------------------------------------------------------
/link_prediction/params.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def parse_args():
4 | parser = argparse.ArgumentParser(description='Model Parameters')
5 | parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
6 | parser.add_argument('--batch', default=1024, type=int, help='training batch size')
7 | parser.add_argument('--tst_batch', default=256, type=int, help='testing batch size (number of users)')
8 | parser.add_argument('--epoch', default=50, type=int, help='number of epochs')
9 | parser.add_argument('--save_path', default='tem', help='file name to save model and training record')
10 | parser.add_argument('--load_model', default=None, help='model name to load')
11 | parser.add_argument('--data', default='', type=str, help='name of dataset')
12 | parser.add_argument('--trndata', default='', type=str, help='name of train dataset')
13 | parser.add_argument('--tstdata', default='', type=str, help='name of test dataset')
14 | parser.add_argument('--tst_epoch', default=1, type=int, help='interval (in epochs) between evaluations during training')
15 | parser.add_argument('--gpu', default='0', type=str, help='indicates which gpu to use')
16 | parser.add_argument('--topk', default=20, type=int, help='topk in evaluation')
17 | parser.add_argument('--cache_adj', default=0, type=int, help='indicates whether to cache the bidirectional adjacency matrices')
18 | parser.add_argument('--cache_proj', default=1, type=int, help='indicates whether to cache the projector and projection matrices')
19 | parser.add_argument('--epoch_max_step', default=-1, type=int, help='indicates the maximum number of steps in one epoch, -1 denotes full steps')
20 | parser.add_argument('--data_dir', default='../datasets', type=str, help='dataset directory')
21 |
22 | parser.add_argument('--niter', default=2, type=int, help='number of iterations in SVD')
23 | parser.add_argument('--reg', default=1e-7, type=float, help='weight decay regularizer')
24 | parser.add_argument('--drop_rate', default=0.1, type=float, help='dropout rate')
25 | parser.add_argument('--scale_layer', default=10, type=float, help='per-layer scale factor')
26 | parser.add_argument('--clamp', default=-1, type=float, help='absolute value for the limit of prediction scores while training')
27 | parser.add_argument('--mask_method', default='trn', type=str, help='which graph masking method to use')
28 | parser.add_argument('--mask_alg', default='linear', type=str, help='which graph masking algorithm in trn mask_method')
29 | parser.add_argument('--random_mask_rate', default=0.5, type=float, help='mask ratio in random mask_method')
30 | parser.add_argument('--act', default='leaky', type=str, help='activation function')
31 | parser.add_argument('--leaky', default=0.5, type=float, help='slope of leaky relu activation')
32 | parser.add_argument('--latdim', default=1024, type=int, help='latent dimensionality')
33 | parser.add_argument('--head', default=4, type=int, help='number of attention heads')
34 | parser.add_argument('--selfloop', default=0, type=int, help='whether to add self-loops to the adjacency matrix')
35 | parser.add_argument('--gnn_layer', default=3, type=int, help='number of gnn iterations')
36 | parser.add_argument('--gt_layer', default=4, type=int, help='number of graph transformer layers')
37 | parser.add_argument('--proj_method', default='svd', type=str, help='initial projection method')
38 | parser.add_argument('--mask', default='trn', type=str, help='indicating which mask strategy to apply')
39 | parser.add_argument('--loss', default='ce', type=str, help='loss function')
40 | parser.add_argument('--mask_bat', default=512, type=int, help='batch size for masking')
41 | parser.add_argument('--anchor', default=256, type=int, help='number of anchor nodes in the compressed graph transformer')
42 | parser.add_argument('--pred_iter', default=1, type=int, help='number of prediction iterations')
43 | parser.add_argument('--proj_trn_steps', default=10, type=int, help='number of training steps for one initial projection')
44 | return parser.parse_args()
45 | args = parse_args()
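
`parse_args()` runs at import time, so every module shares this single `args` object; the entry scripts are additionally expected to attach `args.devices` before any model is built (as `node_classification/main.py` later in this dump does). A hedged parser smoke test, where the dataset names are placeholders rather than documented values:

```python
# hypothetical parser smoke test; 'gen1' and 'ml1m' are placeholder dataset names
import sys
sys.argv = ['main.py', '--trndata', 'gen1', '--tstdata', 'ml1m', '--gpu', '0']

import torch as t
from params import args                       # triggers parse_args()

# mirror the device setup done in the entry scripts (assumption: single GPU, CPU fallback)
dev = t.device('cuda:0' if t.cuda.is_available() else 'cpu')
args.devices = [dev, dev]
print(args.lr, args.latdim, args.trndata, args.tstdata)
```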
--------------------------------------------------------------------------------
/node_classification/Utils/TimeLogger.py:
--------------------------------------------------------------------------------
1 | import datetime
2 |
3 | logmsg = ''
4 | timemark = dict()
5 | saveDefault = False
6 | def log(msg, save=None, oneline=False):
7 | global logmsg
8 | global saveDefault
9 | time = datetime.datetime.now()
10 | tem = '%s: %s' % (time, msg)
11 | if save != None:
12 | if save:
13 | logmsg += tem + '\n'
14 | elif saveDefault:
15 | logmsg += tem + '\n'
16 | if oneline:
17 | print(tem, end='\r')
18 | else:
19 | print(tem)
20 |
21 | def marktime(marker):
22 | global timemark
23 | timemark[marker] = datetime.datetime.now()
24 |
25 |
26 | if __name__ == '__main__':
27 | log('')
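
A short usage illustration of the logger's two modes (the messages are placeholders):

```python
import Utils.TimeLogger as logger
from Utils.TimeLogger import log

logger.saveDefault = True                        # buffer messages into logger.logmsg by default
for step in range(3):
    log('step %d running' % step, oneline=True)  # rewrites the same console line (end='\r')
log('epoch finished', save=True)                 # printed on its own line and kept in logger.logmsg
```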
--------------------------------------------------------------------------------
/node_classification/data_handler.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import numpy as np
3 | from scipy.sparse import csr_matrix, coo_matrix, dok_matrix
4 | from params import args
5 | import scipy.sparse as sp
6 | from Utils.TimeLogger import log
7 | import torch as t
8 | import torch.utils.data as data
9 | import torch_geometric.transforms as T
10 | from model import InitialProjector
11 | import os
12 |
13 | class MultiDataHandler:
14 | def __init__(self, trn_datasets, tst_datasets):
15 | all_datasets = list(set(trn_datasets + tst_datasets))
16 | self.trn_handlers = []
17 | self.tst_handlers = []
18 | for data_name in all_datasets:
19 | trn_flag = data_name in trn_datasets
20 | tst_flag = data_name in tst_datasets
21 | handler = DataHandler(data_name, trn_flag, tst_flag)
22 | if trn_flag:
23 | self.trn_handlers.append(handler)
24 | if tst_flag:
25 | self.tst_handlers.append(handler)
26 |
27 | def make_joint_trn_loader(self):
28 | trn_data = TrnData(self.trn_handlers)
29 | self.trn_loader = data.DataLoader(trn_data, batch_size=1, shuffle=True, num_workers=0)
30 |
31 | def remake_initial_projections(self):
32 | for i in range(len(self.trn_handlers)):
33 | self.remake_one_initial_projection(i)
34 |
35 | def remake_one_initial_projection(self, idx):
36 | trn_handler = self.trn_handlers[idx]
37 | trn_handler.initial_projector = InitialProjector(trn_handler.asym_adj)
38 |
39 | class DataHandler:
40 | def __init__(self, data_name, trn_flag, tst_flag):
41 | self.data_name = data_name
42 | self.trn_flag = trn_flag
43 | self.tst_flag = tst_flag
44 | self.get_data_files()
45 | self.load_data()
46 |
47 | def get_data_files(self):
48 | predir = os.path.join(args.data_dir, self.data_name)
49 | self.adj_file = os.path.join(predir, 'adj_-1.pkl')
50 | self.label_file = os.path.join(predir, 'label.pkl')
51 | self.mask_file = os.path.join(predir, 'mask_-1.pkl')
52 |
53 | def load_one_file(self, filename):
54 | with open(filename, 'rb') as fs:
55 | ret = (pickle.load(fs)).astype(np.float32)
56 | if type(ret) != coo_matrix:
57 | ret = sp.coo_matrix(ret)
58 | return ret
59 |
60 | def normalize_adj(self, mat):
61 | degree = np.array(mat.sum(axis=-1))
62 | d_inv_sqrt = np.reshape(np.power(degree, -0.5), [-1])
63 | d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0
64 | d_inv_sqrt_mat = sp.diags(d_inv_sqrt)
65 | if mat.shape[0] == mat.shape[1]:
66 | return mat.dot(d_inv_sqrt_mat).transpose().dot(d_inv_sqrt_mat).tocoo()
67 | else:
68 | tem = d_inv_sqrt_mat.dot(mat)
69 | col_degree = np.array(mat.sum(axis=0))
70 | d_inv_sqrt = np.reshape(np.power(col_degree, -0.5), [-1])
71 | d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0
72 | d_inv_sqrt_mat = sp.diags(d_inv_sqrt)
73 | return tem.dot(d_inv_sqrt_mat).tocoo()
74 |
75 | def load_feats(self, filename):
76 | try:
77 | with open(filename, 'rb') as fs:
78 | feats = pickle.load(fs)
79 | except Exception as e:
80 | print(filename + str(e))
81 | exit()
82 | return feats
83 |
84 | def unique_numpy(self, row, col):
85 | hash_vals = row * self.node_num + col
86 | hash_vals = np.unique(hash_vals).astype(np.int64)
87 | col = hash_vals % self.node_num
88 | row = (hash_vals - col).astype(np.int64) // self.node_num
89 | return row, col
90 |
91 | def make_torch_adj(self, mat):
92 | if mat.shape[0] == mat.shape[1]:
93 | # to symmetric
94 | if self.data_name in ['ddi']:
95 | _row = mat.row
96 | _col = mat.col
97 | row = np.concatenate([_row, _col]).astype(np.int64)
98 | col = np.concatenate([_col, _row]).astype(np.int64)
99 | # row, col = self.unique_numpy(row, col)
100 | data = mat.data
101 | data = np.concatenate([data, data]).astype(np.float32)
102 | else:
103 | row, col = mat.row, mat.col
104 | data = mat.data
105 | # data = np.ones_like(row)
106 | mat = coo_matrix((data, (row, col)), mat.shape)
107 | if args.selfloop == 1:
108 | mat = (mat + sp.eye(mat.shape[0])) * 1.0
109 | normed_asym_mat = self.normalize_adj(mat)
110 | row = t.from_numpy(normed_asym_mat.row).long()
111 | col = t.from_numpy(normed_asym_mat.col).long()
112 | idxs = t.stack([row, col], dim=0)
113 | vals = t.from_numpy(normed_asym_mat.data).float()
114 | shape = t.Size(normed_asym_mat.shape)
115 | asym_adj = t.sparse.FloatTensor(idxs, vals, shape)
116 | if mat.shape[0] == mat.shape[1]:
117 | return asym_adj, asym_adj
118 | else:
119 | # make ui adj (bipartite case; relies on self.user_num / self.item_num and is not reached for the square node-classification adjacency)
120 | a = sp.csr_matrix((self.user_num, self.user_num))
121 | b = sp.csr_matrix((self.item_num, self.item_num))
122 | mat = sp.vstack([sp.hstack([a, mat]), sp.hstack([mat.transpose(), b])])
123 | mat = (mat != 0) * 1.0
124 | if args.selfloop == 1:
125 | mat = (mat + sp.eye(mat.shape[0])) * 1.0
126 | mat = self.normalize_adj(mat)
127 |
128 | # make cuda tensor
129 | idxs = t.from_numpy(np.vstack([mat.row, mat.col]).astype(np.int64))
130 | vals = t.from_numpy(mat.data.astype(np.float32))
131 | shape = t.Size(mat.shape)
132 | return t.sparse.FloatTensor(idxs, vals, shape), asym_adj
133 |
134 | def load_data(self):
135 | self.adj = self.load_one_file(self.adj_file)
136 | self.labels = self.load_feats(self.label_file)
137 | if np.min(self.labels) != 0:
138 | log(f'Class label starts from {np.min(self.labels)}')
139 | self.labels -= np.min(self.labels)
140 | args.class_num = np.max(self.labels) + 1
141 | masks = self.load_feats(self.mask_file)
142 | self.trn_mask, self.val_mask, self.tst_mask = masks['train'], masks['valid'], masks['test']
143 |
144 | self.node_num = self.adj.shape[0]
145 | print('Dataset: {data_name}, Node num: {node_num}, Edge num: {edge_num}'.format(data_name=self.data_name, node_num=self.node_num, edge_num=self.adj.nnz))
146 |
147 | self.torch_adj, self.asym_adj = self.make_torch_adj(self.adj)
148 | if args.cache_proj:
149 | self.asym_adj = self.asym_adj.to(args.devices[0])
150 | if args.cache_adj:
151 | self.torch_adj = self.torch_adj.to(args.devices[0])
152 |
153 | self.initial_projector = InitialProjector(self.asym_adj)
154 |
155 | if self.tst_flag:
156 | tst_data = NodeData(self.labels, self.tst_mask)
157 | self.tst_loader = data.DataLoader(tst_data, batch_size=args.tst_batch, shuffle=False, num_workers=0)
158 |
159 | val_data = NodeData(self.labels, self.val_mask)
160 | self.val_loader = data.DataLoader(val_data, batch_size=args.tst_batch, shuffle=False, num_workers=0)
161 |
162 | trn_data = NodeData(self.labels, self.trn_mask)
163 | self.trn_loader = data.DataLoader(trn_data, batch_size=args.batch, shuffle=True, num_workers=0)
164 |
165 |
166 | class NodeData(data.Dataset):
167 | def __init__(self, labels, mask):
168 | self.iter_nodes = np.reshape(np.argwhere(np.array(mask) == True), -1)
169 | self.labels = labels[self.iter_nodes]
170 |
171 | def __len__(self):
172 | return len(self.iter_nodes)
173 |
174 | def __getitem__(self, idx):
175 | return self.iter_nodes[idx], self.labels[idx]# + args.node_num - args.class_num
176 |
177 | class TrnData(data.Dataset):  # joint training sampler; expects handlers that expose trn_mat (link-prediction-style data)
178 | def __init__(self, trn_handlers):
179 | self.dataset_num = len(trn_handlers)
180 | self.trn_handlers = trn_handlers
181 | self.ancs_list = [None] * self.dataset_num
182 | self.poss_list = [None] * self.dataset_num
183 | self.negs_list = [None] * self.dataset_num
184 | self.edge_nums = [None] * self.dataset_num
185 | self.sample_nums = [None] * self.dataset_num
186 | for i, handler in enumerate(self.trn_handlers):
187 | trn_mat = handler.trn_mat
188 | ancs = np.array(trn_mat.row)
189 | poss = np.array(trn_mat.col)
190 | self.ancs_list[i] = ancs
191 | self.poss_list[i] = poss
192 | self.edge_nums[i] = len(ancs)
193 | self.sample_nums[i] = self.edge_nums[i] // args.batch + (1 if self.edge_nums[i] % args.batch != 0 else 0)
194 | self.total_sample_num = sum(self.sample_nums)
195 | self.samples = [None] * self.total_sample_num
196 |
197 | def data_shuffling(self):
198 | sample_idx = 0
199 | for i in range(self.dataset_num):
200 | edge_num = self.edge_nums[i]
201 | perms = np.random.permutation(edge_num)
202 | handler = self.trn_handlers[i]
203 | asym_flag = handler.trn_mat.shape[0] != handler.trn_mat.shape[1]
204 | cand_num = handler.item_num if asym_flag else handler.node_num
205 | self.negs_list[i] = self.neg_sampling(self.ancs_list[i], handler.trn_mat.todok(), cand_num)
206 | # self.negs_list[i] = np.random.randint(cand_num, size=edge_num)
207 | for j in range(self.sample_nums[i]):
208 | st_idx = j * args.batch
209 | ed_idx = min((j + 1) * args.batch, edge_num)
210 | pick_idxs = perms[st_idx: ed_idx]
211 | ancs = self.ancs_list[i][pick_idxs]
212 | poss = self.poss_list[i][pick_idxs]
213 | negs = self.negs_list[i][pick_idxs]
214 | if asym_flag:
215 | poss += handler.user_num
216 | negs += handler.user_num
217 | self.samples[sample_idx] = (ancs, poss, negs, i)
218 | sample_idx += 1
219 | assert sample_idx == self.total_sample_num
220 |
221 | def neg_sampling(self, ancs, dokmat, cand_num):
222 | negs = np.zeros_like(ancs)
223 | for i in range(len(ancs)):
224 | u = ancs[i]
225 | while True:
226 | i_neg = np.random.randint(cand_num)
227 | if (u, i_neg) not in dokmat:
228 | break
229 | negs[i] = i_neg
230 | return negs
231 |
232 | def __len__(self):
233 | return self.total_sample_num
234 |
235 | def __getitem__(self, idx):
236 | ancs, poss, negs, adj_id = self.samples[idx]
237 | return ancs, poss, negs, adj_id
238 |
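
A minimal sketch of driving this handler directly, mirroring the way `main.py` (next file) uses it. The dataset name is an assumption and must correspond to a directory under `--data_dir` containing `adj_-1.pkl`, `label.pkl`, and `mask_-1.pkl`; `args.devices` has to be set first because `load_data()` may cache tensors on the GPU:

```python
import torch as t
from params import args

args.devices = [t.device('cuda:0'), t.device('cuda:0')]   # assumed single-GPU setup

from data_handler import MultiDataHandler

multi_handler = MultiDataHandler(trn_datasets=['cora'], tst_datasets=['cora'])
handler = multi_handler.tst_handlers[0]
print(handler.node_num, args.class_num)        # populated inside load_data()

for nodes, labels in handler.tst_loader:       # NodeData yields (node id, label) batches
    print(nodes.shape, labels.shape)
    break
```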
--------------------------------------------------------------------------------
/node_classification/main.py:
--------------------------------------------------------------------------------
1 | import torch as t
2 | from torch import nn
3 | import Utils.TimeLogger as logger
4 | from Utils.TimeLogger import log
5 | from params import args
6 | from model import OpenGraph, ALRS
7 | from data_handler import DataHandler, MultiDataHandler
8 | import numpy as np
9 | import pickle
10 | import os
11 | import setproctitle
12 | import time
13 | from sklearn.metrics import f1_score
14 |
15 | class Exp:
16 | def __init__(self, multi_handler):
17 | self.multi_handler = multi_handler
18 | self.metrics = dict()
19 | trn_mets = ['Loss', 'preLoss', 'Acc']
20 | tst_mets = ['Acc']
21 | mets = trn_mets + tst_mets
22 | for met in mets:
23 | if met in trn_mets:
24 | self.metrics['Train' + met] = list()
25 | else:
26 | for handler in self.multi_handler.tst_handlers:
27 | self.metrics['Test' + handler.data_name + met] = list()
28 |
29 | def make_print(self, name, ep, reses, save, data_name=None):
30 | if data_name is None:
31 | ret = 'Epoch %d/%d, %s: ' % (ep, args.epoch, name)
32 | else:
33 | ret = 'Epoch %d/%d, %s %s: ' % (ep, args.epoch, data_name, name)
34 | for metric in reses:
35 | val = reses[metric]
36 | ret += '%s = %.4f, ' % (metric, val)
37 | tem = name + metric if data_name is None else name + data_name + metric
38 | if save and tem in self.metrics:
39 | self.metrics[tem].append(val)
40 | ret = ret[:-2] + ' '
41 | return ret
42 |
43 | def run(self):
44 | self.prepare_model()
45 | log('Model Prepared')
46 | stloc = 0
47 | if args.load_model != None:
48 | self.load_model()
49 | stloc = len(self.metrics['TrainLoss']) * args.tst_epoch - (args.tst_epoch - 1)
50 |
51 | for handler in self.multi_handler.tst_handlers:
52 | # reses = self.test_epoch(handler.val_loader, handler)
53 | # log(self.make_print('Valid', args.epoch, reses, True, handler.data_name))
54 | res_summary = dict()
55 | times = 10
56 | for i in range(times):
57 | reses = self.test_epoch(handler.tst_loader, handler)
58 | log(self.make_print('Test', args.epoch, reses, False, handler.data_name))
59 | self.add_res_to_summary(res_summary, reses)
60 | self.multi_handler.remake_initial_projections()
61 | for key in res_summary:
62 | res_summary[key] /= times
63 | log(self.make_print('AVG', args.epoch, res_summary, False, handler.data_name))
64 | self.save_history()
65 |
66 | def add_res_to_summary(self, summary, res):
67 | for key in res:
68 | if key not in summary:
69 | summary[key] = 0
70 | summary[key] += res[key]
71 |
72 | def print_model_size(self):
73 | total_params = 0
74 | trainable_params = 0
75 | non_trainable_params = 0
76 | for param in self.model.parameters():
77 | tem = np.prod(param.size())
78 | total_params += tem
79 | if param.requires_grad:
80 | trainable_params += tem
81 | else:
82 | non_trainable_params += tem
83 | print(f'Total params: {total_params/1e6}')
84 | print(f'Trainable params: {trainable_params/1e6}')
85 | print(f'Non-trainable params: {non_trainable_params/1e6}')
86 |
87 | def prepare_model(self):
88 | self.model = OpenGraph()#.to(args.devices[1])#.cuda()
89 | t.cuda.empty_cache()
90 | self.opt = t.optim.Adam(self.model.parameters(), lr=args.lr, weight_decay=0)
91 | self.lr_scheduler = ALRS(self.opt)
92 | self.print_model_size()
93 |
94 | def test_epoch(self, tst_loader, tst_handler):
95 | with t.no_grad():
96 | self.model.eval()
97 | ep_acc, ep_tot = 0, 0
98 | ep_tstnum = len(tst_loader.dataset)
99 | steps = max(ep_tstnum // args.tst_batch, 1)
100 | for i, batch_data in enumerate(tst_loader):
101 | nodes, labels = list(map(lambda x: x.long().to(args.devices[1]), batch_data))
102 | adj = tst_handler.torch_adj
103 | if args.cache_adj == 0:
104 | adj = adj.to(args.devices[0])
105 | initial_projector = tst_handler.initial_projector
106 | if args.cache_proj == 0:
107 | initial_projector = initial_projector.to(args.devices[0])
108 | preds = self.model.pred_for_node_test(nodes, adj, initial_projector, rerun_embed=(i == 0))
109 | if i == 0:
110 | all_preds, all_labels = preds, labels
111 | else:
112 | all_preds = t.concatenate([all_preds, preds])
113 | all_labels = t.concatenate([all_labels, labels])
114 | hit = (labels == preds).float().sum().item()
115 | ep_acc += hit
116 | ep_tot += labels.shape[0]
117 | log('Steps %d/%d: hit = %d, tot = %d ' % (i, steps, ep_acc, ep_tot), save=False, oneline=True)
118 | # t.cuda.empty_cache()
119 | ret = dict()
120 | ret['Acc'] = ep_acc / ep_tot
121 | ret['F1'] = f1_score(all_labels.cpu().numpy(), all_preds.cpu().numpy(), average='macro')
122 | t.cuda.empty_cache()
123 | return ret
124 |
125 | def calc_recall_ndcg(self, topLocs, tstLocs, batIds):
126 | assert topLocs.shape[0] == len(batIds)
127 | allRecall = allNdcg = 0
128 | for i in range(len(batIds)):
129 | temTopLocs = list(topLocs[i])
130 | temTstLocs = tstLocs[batIds[i]]
131 | tstNum = len(temTstLocs)
132 | maxDcg = np.sum([np.reciprocal(np.log2(loc + 2)) for loc in range(min(tstNum, args.topk))])
133 | recall = dcg = 0
134 | for val in temTstLocs:
135 | if val in temTopLocs:
136 | recall += 1
137 | dcg += np.reciprocal(np.log2(temTopLocs.index(val) + 2))
138 | recall = recall / tstNum
139 | ndcg = dcg / maxDcg
140 | allRecall += recall
141 | allNdcg += ndcg
142 | return allRecall, allNdcg
143 |
144 | def save_history(self):
145 | if args.epoch == 0:
146 | return
147 | with open('../History/' + args.save_path + '.his', 'wb') as fs:
148 | pickle.dump(self.metrics, fs)
149 |
150 | content = {
151 | 'model': self.model,
152 | }
153 | t.save(content, '../Models/' + args.save_path + '.mod')
154 | log('Model Saved: %s' % args.save_path)
155 |
156 | def load_model(self):
157 | ckp = t.load('../Models/' + args.load_model + '.mod')
158 | self.model = ckp['model'].to(args.devices[1])
159 | self.opt = t.optim.Adam(self.model.parameters(), lr=args.lr, weight_decay=0)
160 |
161 | with open('../History/' + args.load_model + '.his', 'rb') as fs:
162 | self.metrics = pickle.load(fs)
163 | log('Model Loaded')
164 |
165 | if __name__ == '__main__':
166 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
167 | if len(args.gpu.split(',')) > 1:
168 | args.devices = ['cuda:0', 'cuda:1']
169 | else:
170 | args.devices = ['cuda:0', 'cuda:0']
171 | args.devices = list(map(lambda x: t.device(x), args.devices))
172 | logger.saveDefault = True
173 | setproctitle.setproctitle('OpenGraph')
174 |
175 | log('Start')
176 | trn_datasets = tst_datasets = [args.tstdata]
177 | trn_datasets = list(set(trn_datasets))
178 | tst_datasets = list(set(tst_datasets))
179 | multi_handler = MultiDataHandler(trn_datasets, tst_datasets)
180 | log('Load Data')
181 |
182 | exp = Exp(multi_handler)
183 | exp.run()
184 |
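
`calc_recall_ndcg` above computes standard Recall@K and NDCG@K over ranked candidate lists (it is shared with the link-prediction pipeline and not invoked by the evaluation loop in this script). A small worked check with invented rankings:

```python
import numpy as np

topk = 3
top_locs = [[5, 2, 9]]        # ranked candidates predicted for one query node
tst_locs = {0: [2, 7]}        # ground-truth positives, keyed by node id

max_dcg = sum(1.0 / np.log2(r + 2) for r in range(min(len(tst_locs[0]), topk)))
dcg = sum(1.0 / np.log2(top_locs[0].index(v) + 2) for v in tst_locs[0] if v in top_locs[0])
recall = sum(v in top_locs[0] for v in tst_locs[0]) / len(tst_locs[0])
print(round(recall, 3), round(dcg / max_dcg, 3))   # 0.5 0.387
```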
--------------------------------------------------------------------------------
/node_classification/model.py:
--------------------------------------------------------------------------------
1 | import torch as t
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from params import args
5 | import numpy as np
6 | from Utils.TimeLogger import log
7 | from torch.nn import MultiheadAttention
8 | from time import time
9 |
10 | init = nn.init.xavier_uniform_
11 | uniformInit = nn.init.uniform_
12 |
13 | class InitialProjector(nn.Module):
14 | def __init__(self, adj, input_is_embeds=False):
15 | super(InitialProjector, self).__init__()
16 |
17 | if input_is_embeds:
18 | projection = adj
19 | if args.cache_proj:
20 | projection = projection.to(args.devices[0])
21 | else:
22 | projection = projection.cpu()
23 | self.proj_embeds = nn.Parameter(projection)
24 | t.cuda.empty_cache()
25 | return
26 | if args.proj_method == 'uniform':
27 | self.proj_embeds = nn.Parameter(self.uniform_proj(adj))
28 | elif args.proj_method == 'lowrank_uniform':
29 | self.proj_embeds = nn.Parameter(self.lowrank_uniform_proj(adj))
30 | elif args.proj_method == 'svd':
31 | self.proj_embeds = nn.Parameter(self.svd_proj(adj))
32 | elif args.proj_method == 'both':
33 | self.proj_embeds = nn.Parameter(self.uniform_proj(adj) + self.svd_proj(adj))
34 | elif args.proj_method == 'id':
35 | self.proj_embeds = nn.Parameter(self.id_proj(adj))
36 | else:
37 | raise Exception('Unrecognized Initial Embedding')
38 | t.cuda.empty_cache()
39 |
40 | def uniform_proj(self, adj):
41 | node_num = adj.shape[0] if adj.shape[0] == adj.shape[1] else adj.shape[0] + adj.shape[1]
42 | projection = init(t.empty(node_num, args.latdim))
43 | if args.cache_proj:
44 | projection = projection.to(args.devices[0])
45 | return projection
46 |
47 | def id_proj(self, adj):
48 | node_num = adj.shape[0] if adj.shape[0] == adj.shape[1] else adj.shape[0] + adj.shape[1]
49 | return t.eye(node_num)
50 |
51 | def lowrank_uniform_proj(self, adj):
52 | node_num = adj.shape[0] if adj.shape[0] == adj.shape[1] else adj.shape[0] + adj.shape[1]  # match uniform_proj for square adjacency matrices
53 | rank = 16
54 | projection1 = init(t.empty(node_num, rank))
55 | projection2 = init(t.empty(rank, args.latdim))
56 | projection = projection1 @ projection2
57 | if args.cache_proj:
58 | projection = projection.to(args.devices[0])
59 | return projection
60 |
61 | def svd_proj(self, adj):
62 | if not args.cache_proj:
63 | adj = adj.to(args.devices[0])
64 | q = args.latdim
65 | if args.latdim > adj.shape[0] or args.latdim > adj.shape[1]:
66 | dim = min(adj.shape[0], adj.shape[1])
67 | svd_u, s, svd_v = t.svd_lowrank(adj, q=dim, niter=args.niter)
68 | svd_u = t.concat([svd_u, t.zeros([svd_u.shape[0], args.latdim-dim]).to(args.devices[0])], dim=1)
69 | svd_v = t.concat([svd_v, t.zeros([svd_v.shape[0], args.latdim-dim]).to(args.devices[0])], dim=1)
70 | s = t.concat([s, t.zeros(args.latdim-dim).to(args.devices[0])])
71 | else:
72 | svd_u, s, svd_v = t.svd_lowrank(adj, q=q, niter=args.niter)
73 | svd_u = svd_u @ t.diag(t.sqrt(s))
74 | svd_v = svd_v @ t.diag(t.sqrt(s))
75 | if adj.shape[0] != adj.shape[1]:
76 | projection = t.concat([svd_u, svd_v], dim=0)
77 | else:
78 | projection = svd_u + svd_v
79 | if not args.cache_proj:
80 | projection = projection.cpu()
81 | return projection
82 |
83 | def forward(self):
84 | return self.proj_embeds
85 |
86 | class TopoEncoder(nn.Module):
87 | def __init__(self):
88 | super(TopoEncoder, self).__init__()
89 |
90 | self.layer_norm = nn.LayerNorm(args.latdim, elementwise_affine=False)#, dtype=t.bfloat16)
91 |
92 | def forward(self, adj, embeds):
93 | embeds = self.layer_norm(embeds)
94 | embeds_list = []
95 | if args.gnn_layer == 0:
96 | embeds_list.append(embeds)
97 | for i in range(args.gnn_layer):
98 | embeds = t.spmm(adj, embeds)
99 | embeds_list.append(embeds)
100 | embeds = sum(embeds_list)
101 | # embeds = t.concat([embeds_list[-1][:user_num], embeds_list[-2][user_num:]], dim=0)
102 | embeds = embeds#.to(t.bfloat16)
103 | return embeds
104 |
105 | class GraphTransformer(nn.Module):
106 | def __init__(self):
107 | super(GraphTransformer, self).__init__()
108 | self.gt_layers = nn.Sequential(*[GTLayer() for i in range(args.gt_layer)])
109 |
110 | def forward(self, embeds):
111 | for i, layer in enumerate(self.gt_layers):
112 | embeds = layer(embeds) / args.scale_layer
113 | return embeds
114 |
115 | class GTLayer(nn.Module):
116 | def __init__(self):
117 | super(GTLayer, self).__init__()
118 | self.multi_head_attention = MultiheadAttention(args.latdim, args.head, dropout=0.1, bias=False)#, dtype=t.bfloat16)
119 | self.dense_layers = nn.Sequential(*[FeedForwardLayer(args.latdim, args.latdim, bias=True, act=args.act) for _ in range(2)])# bias=False
120 | self.layer_norm1 = nn.LayerNorm(args.latdim, elementwise_affine=True)#, dtype=t.bfloat16)
121 | self.layer_norm2 = nn.LayerNorm(args.latdim, elementwise_affine=True)#, dtype=t.bfloat16)
122 | self.fc_dropout = nn.Dropout(p=args.drop_rate)
123 |
124 | def _attention(self, anchor_embeds, embeds):
125 | q_embeds = t.einsum('ne,ehd->nhd', anchor_embeds, self.Q)
126 | k_embeds = t.einsum('ne,ehd->nhd', embeds, self.K)
127 | v_embeds = t.einsum('ne,ehd->nhd', embeds, self.V)
128 | att = t.einsum('khd,nhd->knh', q_embeds, k_embeds) / np.sqrt(args.latdim / args.head)
129 | att = t.softmax(att, dim=1)
130 | res = t.einsum('knh,nhd->khd', att, v_embeds).reshape([-1, args.latdim])
131 | res = self.att_linear(res)
132 | return res
133 |
134 | def _pick_anchors(self, embeds):
135 | perm = t.randperm(embeds.shape[0])
136 | anchors = perm[:args.anchor]
137 | return embeds[anchors]
138 |
139 | def print_nodewise_std(self, embeds):
140 | mean = embeds.mean(0)
141 | std = (embeds - mean).square().mean(0).sqrt().mean()
142 | print(embeds)
143 | print(std.item())
144 |
145 | def forward(self, embeds):
146 | anchor_embeds = self._pick_anchors(embeds)
147 | _anchor_embeds, _ = self.multi_head_attention(anchor_embeds, embeds, embeds)
148 | anchor_embeds = _anchor_embeds + anchor_embeds
149 | _embeds, _ = self.multi_head_attention(embeds, anchor_embeds, anchor_embeds, need_weights=False)
150 | embeds = self.layer_norm1(_embeds + embeds)
151 | _embeds = self.fc_dropout(self.dense_layers(embeds))
152 | embeds = (self.layer_norm2(_embeds + embeds))
153 | return embeds
154 |
155 | class FeedForwardLayer(nn.Module):
156 | def __init__(self, in_feat, out_feat, bias=True, act=None):
157 | super(FeedForwardLayer, self).__init__()
158 | self.linear = nn.Linear(in_feat, out_feat, bias=bias)#, dtype=t.bfloat16)
159 | if act == 'identity' or act is None:
160 | self.act = None
161 | elif act == 'leaky':
162 | self.act = nn.LeakyReLU(negative_slope=args.leaky)
163 | elif act == 'relu':
164 | self.act = nn.ReLU()
165 | elif act == 'relu6':
166 | self.act = nn.ReLU6()
167 | else:
168 | raise Exception('Error')
169 |
170 | def forward(self, embeds):
171 | if self.act is None:
172 | return self.linear(embeds)
173 | return self.act(self.linear(embeds))
174 |
175 | class Masker(nn.Module):
176 | def __init__(self):
177 | super(Masker, self).__init__()
178 |
179 | def forward(self, adj, edges):
180 | if args.mask_method is None or args.mask_method == 'none':
181 | return adj
182 | elif args.mask_method == 'trn':
183 | node_num = adj.shape[0] + adj.shape[1]
184 | rows = adj._indices()[0, :]
185 | cols = adj._indices()[1, :]
186 | pck_rows, pck_cols = edges
187 |
188 | hashvals = rows * node_num + cols
189 | pck_hashvals1 = pck_rows * node_num + pck_cols
190 | pck_hashvals2 = pck_cols * node_num + pck_rows
191 | pck_hashvals = t.concat([pck_hashvals1, pck_hashvals2])
192 |
193 | if args.mask_alg == 'cross':
194 | masked_hashvals = self._mask_by_cross(hashvals, pck_hashvals)
195 | elif args.mask_alg == 'linear':
196 | masked_hashvals = self._mask_by_linear(hashvals, pck_hashvals)
197 |
198 | cols = masked_hashvals % node_num
199 | rows = t.div((masked_hashvals - cols).long(), node_num, rounding_mode='trunc').long()
200 |
201 | adj = t.sparse.FloatTensor(t.stack([rows, cols], dim=0), t.ones_like(rows, dtype=t.float32).to(args.devices[0]), adj.shape)
202 | return self._normalize_adj(adj)
203 | elif args.mask_method == 'random':
204 | return self._random_mask_edge(adj)
205 |
206 | def _mask_by_cross(self, hashvals, pck_hashvals):
207 | for i in range(args.batch * 2 // args.mask_bat):
208 | bat_pck_hashvals = pck_hashvals[i * args.mask_bat: (i+1) * args.mask_bat]
209 | idct = (hashvals.view([-1, 1]) - bat_pck_hashvals.view([1, -1]) == 0).sum(-1).bool()
210 | hashvals = hashvals[t.logical_not(idct)]
211 | return hashvals
212 |
213 | def _mask_by_linear(self, hashvals, pck_hashvals):
214 | hashvals = t.unique(hashvals)
215 | pck_hashvals = t.unique(pck_hashvals)
216 | hashvals = t.concat([hashvals, pck_hashvals])
217 | hashvals, counts = t.unique(hashvals, return_counts=True)
218 | hashvals = hashvals[counts==1]
219 | return hashvals
220 |
221 | def _random_mask_edge(self, adj):
222 | if args.random_mask_rate == 0.0:
223 | return adj
224 | vals = adj._values()
225 | idxs = adj._indices()
226 | edgeNum = vals.size()
227 | mask = ((t.rand(edgeNum) + 1.0 - args.random_mask_rate).floor()).type(t.bool)
228 | newIdxs = idxs[:, mask]
229 | newVals = t.ones(newIdxs.shape[1]).to(args.devices[0]).float()
230 | return self._normalize_adj(t.sparse.FloatTensor(newIdxs, newVals, adj.shape))
231 |
232 | def _normalize_adj(self, adj):
233 | row_degree = t.pow(t.sparse.sum(adj, dim=1).to_dense(), 0.5)
234 | col_degree = t.pow(t.sparse.sum(adj, dim=0).to_dense(), 0.5)
235 | newRows, newCols = adj._indices()[0, :], adj._indices()[1, :]
236 | rowNorm, colNorm = row_degree[newRows], col_degree[newCols]
237 | newVals = adj._values() / rowNorm / colNorm
238 | return t.sparse.FloatTensor(adj._indices(), newVals, adj.shape)
239 |
240 | class OpenGraph(nn.Module):
241 | def __init__(self):
242 | super(OpenGraph, self).__init__()
243 | self.topoEncoder = TopoEncoder().to(args.devices[0])
244 | self.graphTransformer = GraphTransformer().to(args.devices[1])
245 | self.masker = Masker().to(args.devices[0])
246 |
247 | def forward(self, adj, initial_projector, user_num=None):  # user_num kept for interface parity; unused here
248 | topo_embeds = self.topoEncoder(adj, initial_projector()).to(args.devices[1])  # this TopoEncoder takes (adj, embeds) only
249 | final_embeds = self.graphTransformer(topo_embeds)
250 | return final_embeds
251 |
252 | def pred_norm(self, pos_preds, neg_preds):
253 | pos_preds_num = pos_preds.shape[0]
254 | neg_preds_shape = neg_preds.shape
255 | preds = t.concat([pos_preds, neg_preds.view(-1)])
256 | preds = preds - preds.max()
257 | pos_preds = preds[:pos_preds_num]
258 | neg_preds = preds[pos_preds_num:].view(neg_preds_shape)
259 | return pos_preds, neg_preds
260 |
261 | def cal_loss(self, batch_data, adj, initial_projector):
262 | ancs, poss, negs = batch_data
263 | with t.no_grad():
264 | masked_adj = self.masker(adj, (ancs.to(args.devices[0]), (poss.to(args.devices[0]))))
265 | initial_embeds = initial_projector()
266 | topo_embeds = self.topoEncoder(masked_adj, initial_embeds).to(args.devices[1])
267 | ancs, poss, negs = ancs.to(args.devices[1]), poss.to(args.devices[1]), negs.to(args.devices[1])
268 | input_seq = t.concat([ancs, poss, negs])
269 | input_seq = topo_embeds[input_seq]
270 | final_embeds = self.graphTransformer(input_seq)
271 | anc_embeds, pos_embeds, neg_embeds = t.split(final_embeds[:ancs.shape[0] * 3], [ancs.shape[0]] * 3)
272 | # anc_embeds, pos_embeds, neg_embeds = final_embeds[ancs], final_embeds[poss], final_embeds[negs]
273 | if final_embeds.isinf().any() or final_embeds.isnan().any():
274 | raise Exception('Final embedding fails')
275 |
276 | pos_preds, neg_preds = self.pred_norm((anc_embeds * pos_embeds).sum(-1), anc_embeds @ neg_embeds.T)
277 | if pos_preds.isinf().any() or pos_preds.isnan().any() or neg_preds.isinf().any() or neg_preds.isnan().any():
278 | raise Exception('Preds fails')
279 | pos_loss = pos_preds
280 | neg_loss = (neg_preds.exp().sum(-1) + pos_preds.exp() + 1e-8).log()
281 | pre_loss = -(pos_loss - neg_loss).mean()
282 |
283 | if t.isinf(pre_loss).any() or t.isnan(pre_loss).any():
284 | raise Exception('NaN or Inf')
285 |
286 | reg_loss = sum(list(map(lambda W: W.norm(2).square() * args.reg, self.parameters())))
287 | loss_dict = {'preloss': pre_loss, 'regloss': reg_loss, 'posloss': pos_loss.mean(), 'negloss': neg_loss.mean()}
288 | return pre_loss + reg_loss, loss_dict
289 |
290 | def cal_loss_node(self, batch_data, adj, initial_projector):
291 | ancs, labels = batch_data
292 | poss = labels + adj.shape[0] - args.class_num
293 | negs = t.from_numpy(np.array(list(range(args.class_num)))).to(t.int64).cuda() + adj.shape[0] - args.class_num
294 | with t.no_grad():
295 | masked_adj = self.masker(adj, (ancs.to(args.devices[0]), (poss.to(args.devices[0]))))
296 | initial_embeds = initial_projector()
297 | topo_embeds = self.topoEncoder(masked_adj, initial_embeds).to(args.devices[1])
298 | ancs, poss, negs = ancs.to(args.devices[1]), poss.to(args.devices[1]), negs.to(args.devices[1])
299 | input_seq = t.concat([ancs, poss, negs])
300 | input_seq = topo_embeds[input_seq]
301 | final_embeds = self.graphTransformer(input_seq)
302 | # anc_embeds, pos_embeds, neg_embeds = t.split(final_embeds[:ancs.shape[0] * 3], [ancs.shape[0]] * 3)
303 | anc_embeds = final_embeds[:ancs.shape[0]]
304 | pos_embeds = final_embeds[ancs.shape[0]:ancs.shape[0]+poss.shape[0]]
305 | neg_embeds = final_embeds[-negs.shape[0]:]
306 | # anc_embeds, pos_embeds, neg_embeds = final_embeds[ancs], final_embeds[poss], final_embeds[negs]
307 | if final_embeds.isinf().any() or final_embeds.isnan().any():
308 | raise Exception('Final embedding fails')
309 |
310 | pos_preds, neg_preds = self.pred_norm((anc_embeds * pos_embeds).sum(-1), anc_embeds @ neg_embeds.T)
311 | if pos_preds.isinf().any() or pos_preds.isnan().any() or neg_preds.isinf().any() or neg_preds.isnan().any():
312 | raise Exception('Preds fails')
313 | pos_loss = pos_preds
314 | neg_loss = (neg_preds.exp().sum(-1) + pos_preds.exp() + 1e-8).log()
315 | pre_loss = -(pos_loss - neg_loss).mean()
316 |
317 | if t.isinf(pre_loss).any() or t.isnan(pre_loss).any():
318 | raise Exception('NaN or Inf')
319 |
320 | reg_loss = sum(list(map(lambda W: W.norm(2).square() * args.reg, self.parameters())))
321 | preds = anc_embeds @ neg_embeds.T
322 | loss_dict = {'preloss': pre_loss, 'regloss': reg_loss, 'posloss': pos_loss.mean(), 'negloss': neg_loss.mean(), 'preds': t.argmax(preds, dim=-1)}
323 | return pre_loss + reg_loss, loss_dict
324 |
325 | def pred_for_node_test(self, nodes, adj, initial_projector, rerun_embed=True):
326 | if rerun_embed:
327 | final_embeds = self.graphTransformer(self.topoEncoder(adj, initial_projector()).to(args.devices[1]))
328 | self.final_embeds = final_embeds
329 | final_embeds = self.final_embeds
330 | pck_embeds = final_embeds[nodes]
331 | class_embeds = final_embeds[-args.class_num:]
332 | preds = pck_embeds @ class_embeds.T
333 | return t.argmax(preds, dim=-1)
334 |
335 | class ALRS:
336 | def __init__(self, optimizer, loss_threshold=0.01, loss_ratio_threshold=0.01, decay_rate=0.97):
337 | self.optimizer = optimizer
338 | self.loss_threshold = loss_threshold
339 | self.decay_rate = decay_rate
340 | self.loss_ratio_threshold = loss_ratio_threshold
341 | self.last_loss = 1e9
342 |
343 | def step(self, loss):
344 | delta = self.last_loss - loss
345 | if delta < self.loss_threshold and delta / self.last_loss < self.loss_ratio_threshold:
346 | for group in self.optimizer.param_groups:
347 | group['lr'] *= self.decay_rate
348 | self.last_loss = loss
349 |
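
`cal_loss_node` and `pred_for_node_test` treat the last `args.class_num` rows of the node-plus-class graph as class prototypes, so zero-shot classification reduces to an argmax over dot products between node embeddings and class embeddings. A toy illustration of that scoring step, with random tensors standing in for `final_embeds`:

```python
import torch as t

latdim, node_num, class_num = 8, 10, 3
final_embeds = t.randn(node_num + class_num, latdim)   # graph nodes followed by class tokens

nodes = t.tensor([0, 4, 7])                 # nodes to classify
pck_embeds = final_embeds[nodes]
class_embeds = final_embeds[-class_num:]    # class prototypes sit at the end of the table

preds = t.argmax(pck_embeds @ class_embeds.T, dim=-1)
print(preds)                                # predicted class index per picked node
```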
--------------------------------------------------------------------------------
/node_classification/params.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def parse_args():
4 | parser = argparse.ArgumentParser(description='Model Parameters')
5 | parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
6 | parser.add_argument('--batch', default=1024, type=int, help='training batch size')
7 | parser.add_argument('--tst_batch', default=256, type=int, help='testing batch size (number of nodes)')
8 | parser.add_argument('--shot', default=-1, type=int, help='number of shots for each node')
9 | parser.add_argument('--epoch', default=50, type=int, help='number of epochs')
10 | parser.add_argument('--tune_step', default=100, type=int, help='number of projection tuning steps')
11 | parser.add_argument('--save_path', default='tem', help='file name to save model and training record')
12 | parser.add_argument('--load_model', default=None, help='model name to load')
13 | parser.add_argument('--tstdata', default='', type=str, help='name of test dataset')
14 | parser.add_argument('--tst_epoch', default=1, type=int, help='interval (in epochs) between evaluations during training')
15 | parser.add_argument('--gpu', default='0', type=str, help='indicates which gpu to use')
16 | parser.add_argument('--topk', default=20, type=int, help='topk in evaluation')
17 | parser.add_argument('--cache_adj', default=0, type=int, help='indicates whether to cache the bidirectional adjacency matrices')
18 | parser.add_argument('--cache_proj', default=1, type=int, help='indicates whether to cache the projector and projection matrices')
19 | parser.add_argument('--epoch_max_step', default=-1, type=int, help='indicates the maximum number of steps in one epoch, -1 denotes full steps')
20 | parser.add_argument('--data_dir', default='../datasets', type=str, help='dataset directory')
21 |
22 | parser.add_argument('--niter', default=2, type=int, help='number of iterations in SVD')
23 | parser.add_argument('--reg', default=1e-7, type=float, help='weight decay regularizer')
24 | parser.add_argument('--drop_rate', default=0.1, type=float, help='dropout rate')
25 | parser.add_argument('--scale_layer', default=10, type=float, help='per-layer scale factor')
26 | parser.add_argument('--clamp', default=-1, type=float, help='absolute value for the limit of prediction scores while training')
27 | parser.add_argument('--mask_method', default='trn', type=str, help='which graph masking method to use')
28 | parser.add_argument('--mask_alg', default='linear', type=str, help='which graph masking algorithm in trn mask_method')
29 | parser.add_argument('--random_mask_rate', default=0.5, type=float, help='mask ratio in random mask_method')
30 | parser.add_argument('--act', default='leaky', type=str, help='activation function')
31 | parser.add_argument('--leaky', default=0.5, type=float, help='slope of leaky relu activation')
32 | parser.add_argument('--latdim', default=1024, type=int, help='latent dimensionality')
33 | parser.add_argument('--head', default=4, type=int, help='number of attention heads')
34 | parser.add_argument('--selfloop', default=0, type=int, help='whether to add self-loops to the adjacency matrix')
35 | parser.add_argument('--gnn_layer', default=3, type=int, help='number of gnn iterations')
36 | parser.add_argument('--gt_layer', default=4, type=int, help='number of graph transformer layers')
37 | parser.add_argument('--proj_method', default='svd', type=str, help='initial projection method')
38 | parser.add_argument('--mask', default='trn', type=str, help='indicating which mask strategy to apply')
39 | parser.add_argument('--loss', default='ce', type=str, help='loss function')
40 | parser.add_argument('--mask_bat', default=512, type=int, help='batch size for masking')
41 | parser.add_argument('--anchor', default=256, type=int, help='number of anchor nodes in the compressed graph transformer')
42 | parser.add_argument('--pred_iter', default=1, type=int, help='number of prediction iterations')
43 | parser.add_argument('--proj_trn_steps', default=10, type=int, help='number of training steps for one initial projection')
44 | return parser.parse_args()
45 | args = parse_args()
--------------------------------------------------------------------------------