├── .gitattributes ├── .gitignore ├── Hypergraph.jpeg ├── LICENSE ├── README.md ├── SIGIR2023_STHGCN.pdf ├── STHGCN.png ├── conf ├── README.md ├── ablation_conf │ ├── ca_wo_hyper_collab.yml │ ├── ca_wo_hypergraph.yml │ ├── ca_wo_st_info.yml │ ├── nyc_wo_hyper_collab.yml │ ├── nyc_wo_hypergraph.yml │ ├── nyc_wo_st_info.yml │ ├── tky_wo_hyper_collab.yml │ ├── tky_wo_hypergraph.yml │ └── tky_wo_st_info.yml └── best_conf │ ├── ca.yml │ ├── nyc.yml │ └── tky.yml ├── data ├── ca │ ├── raw.zip │ └── us_state_polygon_json.json ├── nyc │ └── raw.zip └── tky │ └── raw.zip ├── dataset ├── __init__.py └── lbsn_dataset.py ├── generate_ca_raw.py ├── layer ├── __init__.py ├── conv.py ├── embedding_layer.py ├── sampler.py └── st_encoder.py ├── metric ├── __init__.py └── rank_metric.py ├── model ├── __init__.py ├── seq_transformer.py └── sthgcn.py ├── multiple_run.py ├── preprocess ├── __init__.py ├── file_reader.py ├── generate_hypergraph.py ├── preprocess_fn.py └── preprocess_main.py ├── requirements.txt ├── run.py └── utils ├── __init__.py ├── conf_util.py ├── math_util.py ├── pipeline_util.py └── sys_util.py /.gitattributes: -------------------------------------------------------------------------------- 1 | data/ca/raw.zip filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Added by user 3 | .DS_Store 4 | data/nyc/raw 5 | data/tky/raw 6 | data/ca/raw 7 | data/nyc/preprocessed 8 | data/tky/preprocessed 9 | data/ca/preprocessed 10 | log/ 11 | checkpoint/ 12 | tensorboard/ 13 | logs/ 14 | test/ 15 | .idea/ 16 | 17 | # Byte-compiled / optimized / DLL files 18 | __pycache__/ 19 | *.py[cod] 20 | *$py.class 21 | 22 | # C extensions 23 | *.so 24 | 25 | # Distribution / packaging 26 | .Python 27 | build/ 28 | develop-eggs/ 29 | dist/ 30 | downloads/ 31 | eggs/ 32 | .eggs/ 33 | lib/ 34 | lib64/ 35 | parts/ 36 | sdist/ 37 | var/ 38 | wheels/ 39 | *.egg-info/ 40 | .installed.cfg 41 | *.egg 42 | MANIFEST 43 | 44 | # PyInstaller 45 | # Usually these files are written by a python script from a template 46 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
47 | *.manifest 48 | *.spec 49 | 50 | # Installer logs 51 | pip-log.txt 52 | pip-delete-this-directory.txt 53 | 54 | # Unit test / coverage reports 55 | htmlcov/ 56 | .tox/ 57 | .coverage 58 | .coverage.* 59 | .cache 60 | nosetests.xml 61 | coverage.xml 62 | *.cover 63 | .hypothesis/ 64 | .pytest_cache/ 65 | 66 | # Translations 67 | *.mo 68 | *.pot 69 | 70 | # Django stuff: 71 | *.log 72 | local_settings.py 73 | db.sqlite3 74 | 75 | # Flask stuff: 76 | instance/ 77 | .webassets-cache 78 | 79 | # Scrapy stuff: 80 | .scrapy 81 | 82 | # Sphinx documentation 83 | docs/_build/ 84 | 85 | # PyBuilder 86 | target/ 87 | 88 | # Jupyter Notebook 89 | .ipynb_checkpoints 90 | 91 | # pyenv 92 | .python-version 93 | 94 | # celery beat schedule file 95 | celerybeat-schedule 96 | 97 | # SageMath parsed files 98 | *.sage.py 99 | 100 | # Environments 101 | .env 102 | .venv 103 | env/ 104 | venv/ 105 | ENV/ 106 | env.bak/ 107 | venv.bak/ 108 | 109 | # Spyder project settings 110 | .spyderproject 111 | .spyproject 112 | 113 | # Rope project settings 114 | .ropeproject 115 | 116 | # mkdocs documentation 117 | /site 118 | 119 | # mypy 120 | .mypy_cache/ 121 | -------------------------------------------------------------------------------- /Hypergraph.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alipay/Spatio-Temporal-Hypergraph-Model/27b595846d29019799485985bff49f9ed02c4ade/Hypergraph.jpeg -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # STHGCN 2 | This repository includes the implementation details of the method STHGCN and helps readers to reproduce the results in the paper 3 | **Spatio-Temporal Hypergraph Learning for Next POI Recommendation**, whose objective is to utilize hypergraph convolution networks to 4 | model diverse user behaviors in the next-POI recommendation task. 
5 | 
6 | In particular, we introduce a hypergraph to model the complex
7 | structure of check-ins and trajectories. We develop hypergraph transformer layers to capture high-order, heterogeneous inter-user
8 | and intra-user trajectory correlations while incorporating spatio-temporal contexts.
9 | Comprehensive experiments conducted on three real-world datasets have demonstrated the superiority of STHGCN in the task of next-POI recommendation,
10 | outperforming baseline models by a large margin. For more information, please see our paper **Spatio-Temporal Hypergraph Learning for Next POI Recommendation** (Yan *et al. 2023*).
11 | 
12 | + Multi-Level Hypergraph
13 | ![Multi-Level Hypergraph](Hypergraph.jpeg)
14 | 
15 | + Model Framework
16 | ![STHGCN Overall Framework](STHGCN.png)
17 | 
18 | ## Installation
19 | 1. Clone the repository (if you see a permission error, first [add a new SSH key to your GitHub account](https://docs.github.com/en/authentication/connecting-to-github-with-ssh/adding-a-new-ssh-key-to-your-github-account)):
20 | ```shell
21 | git clone https://github.com/ant-research/Spatio-Temporal-Hypergraph-Model.git
22 | ```
23 | 2. The repository has the important dependencies below; we don't guarantee that the code still works with higher versions.
24 | Please refer to the respective pages to install:
25 |    + [Pytorch](https://pytorch.org/) == 1.7.0
26 |    + [pytorch_geometric](https://github.com/pyg-team/pytorch_geometric#installation) == 1.7.2
27 |    + torch-scatter == 2.0.7
28 |    + torch-sparse == 0.6.9
29 |    + torch-cluster == 1.5.9
30 |    + torch-spline-conv == 1.2.1
31 | 3. Install the other dependencies in `requirements.txt`:
32 | ```shell
33 | pip install -r requirements.txt
34 | ```
35 | 
36 | ## Hardware
37 | Here are the minimum hardware requirements, including CPU and GPU.
38 | 
39 | + CPU: 16 cores, 30GB memory
40 | + GPU: Tesla V100 (16GB)
41 | 
42 | Important things you need to know:
43 | 1. For the Foursquare-TKY dataset, if the first value of the hyperparameter `sizes` is more than 300 (e.g. "400-240"), please use a 32GB GPU;
44 | 2. If you use another type of GPU, such as a Tesla P100, it may produce different results even with the best configuration file and the same random seed.
45 | 
46 | ## Dataset
47 | The Foursquare-NYC dataset is collected in New York City and the Foursquare-TKY dataset is collected in Tokyo over 11 months
48 | (from April 2012 to February 2013) from Foursquare. The Gowalla-CA dataset, with a longer time period (from March 2009 to October 2010) and
49 | broader geographical coverage, is collected in California and Nevada on the Gowalla platform.
50 | 
51 | ### Preprocess
52 | The preprocess step includes creating train/validate/test sample files and generating hypergraph-related PyG data files.
53 | The preprocess step is executed the first time we train our model using `run.py`; it is skipped if the
54 | same data has already been preprocessed.
55 | 
56 | To get the raw data:
57 | + Unzip `data/nyc/raw.zip` to `data/nyc`. Then you will get a `raw` directory with three .csv files: the training sample file `NYC_train.csv`,
58 | the validation sample file `NYC_val.csv` and the testing sample file `NYC_test.csv`.
59 | + Unzip `data/tky/raw.zip` to `data/tky`. Then you will get a `raw` directory with the file `dataset_TSMC2014_TKY.txt` containing
60 | all the check-in information.
61 | + Unzip `data/ca/raw.zip` to `data/ca`. Then you will get a `raw` directory with the file `loc-gowalla_totalCheckins.txt` containing raw check-in records,
62 | `gowalla_spots_subset1.csv` containing category-involved check-in records, and `us_state_polygon_json.json` containing the
63 | polygon of every state in the U.S. To get the CA raw data used for preprocessing, including the correct category information, run
64 | ```shell
65 | python generate_ca_raw.py
66 | ```
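After preprocessing has run, you can sanity-check the generated files by recomputing the dataset statistics directly. A minimal sketch (assuming the default `data/{dataset_name}/preprocessed/` layout; the column names follow `dataset/lbsn_dataset.py`):

```python
import pandas as pd

name = 'nyc'  # or 'tky', 'ca'
df = pd.read_csv(f'data/{name}/preprocessed/sample.csv')

print('#user:       ', df['UserId'].nunique())
print('#poi:        ', df['PoiId'].nunique())
print('#category:   ', df['PoiCategoryId'].nunique())
print('#check-in:   ', df.shape[0])
print('#trajectory: ', df['pseudo_session_trajectory_id'].nunique())

# Per-split sample counts come from the three split files.
for split in ('train', 'validate', 'test'):
    n = len(pd.read_csv(f'data/{name}/preprocessed/{split}_sample.csv'))
    print(f'#{split} sample:', n)
```

The output should line up with the statistics table below.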
67 | 
68 | If you want to compare your model with our work under the same preprocess setting,
69 | we strongly suggest directly using the `sample.csv`, `train_sample.csv`, `validate_sample.csv`, and `test_sample.csv`
70 | sample files in the preprocessed directories.
71 | 
72 | ### Statistical Information
73 | After preprocessing (some works only report the statistics before preprocessing), the key statistics of the three datasets are shown below.
74 | The first 6 columns are calculated based on `sample.csv`, which contains all the samples before removing unseen users or POIs. The last
75 | 3 columns are calculated based on `train_sample.csv`, `validate_sample.csv` and `test_sample.csv`.
76 | 
77 | | Dataset Name   | #user | #poi  | #category | #check-in | #trajectory | #training sample | #validation sample | #testing sample |
78 | |----------------|-------|-------|-----------|-----------|-------------|------------------|--------------------|-----------------|
79 | | Foursquare-NYC | 1,048 | 4,981 | 318       | 103,941   | 14,130      | 72,206           | 1,400              | 1,347           |
80 | | Foursquare-TKY | 2,282 | 7,833 | 290       | 405,000   | 65,499      | 274,597          | 6,868              | 7,038           |
81 | | Gowalla-CA     | 3,957 | 9,690 | 296       | 238,369   | 45,123      | 154,253          | 3,529              | 2,780           |
82 | 
83 | ### Original Link
84 | In case some readers are confused about the data provided by our work versus other works, here we explain where and how our data was obtained.
85 | 
86 | We struggled to find valid data from previous works. [STAN](https://github.com/yingtaoluo/Spatial-Temporal-Attention-Network-for-POI-Recommendation)
87 | provides the **Raw** data of NYC, TKY and CA, but these data lack category information. [GETNext](https://github.com/songyangme/GETNext) only
88 | provides the **Preprocessed** data of NYC, while TKY and CA are missing. We use the raw NYC data from STAN's link and preprocess it based on the
89 | description from GETNext. Unfortunately, there is still a minor gap between our preprocessed NYC data and what is provided by GETNext.
90 | 
91 | For a fair comparison of different models: for NYC, we download the preprocessed files from GETNext; for TKY and CA, we download them from the link provided
92 | by STAN, and we fetch the category information from [IRenMF](https://dl.acm.org/doi/10.1145/2661829.2662002). We ran the GETNext model with our preprocessed data and
93 | its performance is only slightly worse (~0.01 on Acc@1) than reported in their paper, so we simply quote the numbers from their paper.
94 | 
95 | Thanks to all the data providers.
96 | 
97 | + NYC:
98 |   + Preprocessed: https://github.com/songyangme/GETNext/blob/master/dataset/NYC.zip
99 | + TKY:
100 |   + Raw: http://www-public.imtbs-tsp.eu/~zhang_da/pub/dataset_tsmc2014.zip
101 | + CA:
102 |   + Raw: http://snap.stanford.edu/data/loc-gowalla.html
103 |   + Category information: https://www.yongliu.org/datasets.html
104 | 
105 | ## Main Experimental Results
106 | To reproduce the main results in our paper, please follow the steps below.
107 | 
108 | ### Main Performance
109 | To know the meaning of every config in the yaml files, please refer to [conf/README.md](https://github.com/ant-research/Spatio-Temporal-Hypergraph-Model/blob/main/conf/README.md).
110 | 
111 | We can reproduce the best performance of our model with the script below. Please choose 'nyc', 'tky',
112 | or 'ca' for *{dataset_name}*.
113 | ```shell
114 | python run.py -f best_conf/{dataset_name}.yml
115 | ```
116 | 
117 | | Dataset Name | Acc@1  | Acc@5  | Acc@10 | MRR    | #Parameters | Training Speed<br>(per epoch) |
118 | |--------------|--------|--------|--------|--------|-------------|-------------------------------|
119 | | NYC          | 0.2734 | 0.5361 | 0.6244 | 0.3915 | 27,820,020  | 3m24s                         |
120 | | TKY          | 0.2950 | 0.5207 | 0.5980 | 0.3986 | 30,167,576  | 59m31s                        |
121 | | CA           | 0.1730 | 0.3529 | 0.4191 | 0.2558 | 31,810,778  | 15m40s                        |
122 | 
123 | The average performance over 10 runs can be obtained using the script below. `-n` denotes the total number of experiments and `-g`
124 | denotes the gpu id (default 0).
125 | ```shell
126 | python multiple_run.py -f best_conf/{dataset_name}.yml -n 10 -g 0
127 | ```
128 | 
129 | | Dataset Name | Acc@1           | Acc@5           | Acc@10          | MRR             |
130 | |--------------|-----------------|-----------------|-----------------|-----------------|
131 | | NYC          | 0.2625 ± 0.0054 | 0.5226 ± 0.0033 | 0.6117 ± 0.0044 | 0.3798 ± 0.0041 |
132 | | TKY          | 0.2905 ± 0.0035 | 0.5184 ± 0.0035 | 0.5969 ± 0.0030 | 0.3951 ± 0.0025 |
133 | | CA           | 0.1652 ± 0.0036 | 0.3405 ± 0.0041 | 0.4177 ± 0.0038 | 0.2491 ± 0.0026 |
134 | 
135 | ### Ablation Study
136 | Starting from the best-model configurations in `conf/best_conf/`, we modify only a few key configs for the ablation study.
137 | 
138 | + For *w/o Hypergraph*, we set `model_name: seq_transformer`. It's worth mentioning that we use the same NeighborSampler and
139 | transform the adjacency list into a sequential input for the Transformer model, so the inputs are the same as for a vanilla Transformer. Please run
140 | ```shell
141 | python multiple_run.py -f ablation_conf/{dataset_name}_wo_hypergraph.yml -n 10 -g 0
142 | ```
143 | + For *w/o ST Information*, we set `time_fusion_mode: ` (empty). Please run
144 | ```shell
145 | python multiple_run.py -f ablation_conf/{dataset_name}_wo_st_info.yml -n 10 -g 0
146 | ```
147 | + For *w/o Hyperedge Collaboration*, we set `do_traj2traj: False`. Please run
148 | ```shell
149 | python multiple_run.py -f ablation_conf/{dataset_name}_wo_hyper_collab.yml -n 10 -g 0
150 | ```
151 | 
152 | The ablation results are listed below; they are consistent with, though slightly different from, Table 4 in our paper:
153 | 
154 | + NYC:
155 | 
156 | |                             | Acc@1           | Acc@5           | Acc@10          | MRR             |
157 | |-----------------------------|-----------------|-----------------|-----------------|-----------------|
158 | | Full Model                  | 0.2625 ± 0.0054 | 0.5226 ± 0.0033 | 0.6117 ± 0.0044 | 0.3798 ± 0.0041 |
159 | | w/o Hypergraph              | 0.2391 ± 0.0068 | 0.5137 ± 0.0094 | 0.6069 ± 0.0091 | 0.3618 ± 0.0050 |
160 | | w/o ST Information          | 0.2332 ± 0.0048 | 0.5113 ± 0.0039 | 0.6091 ± 0.0069 | 0.3591 ± 0.0034 |
161 | | w/o Hyperedge Collaboration | 0.2490 ± 0.0058 | 0.5048 ± 0.0098 | 0.5885 ± 0.0052 | 0.3641 ± 0.0055 |
162 | 
163 | + TKY:
164 | 
165 | |                             | Acc@1           | Acc@5           | Acc@10          | MRR             |
166 | |-----------------------------|-----------------|-----------------|-----------------|-----------------|
167 | | Full Model                  | 0.2905 ± 0.0035 | 0.5184 ± 0.0035 | 0.5969 ± 0.0030 | 0.3951 ± 0.0025 |
168 | | w/o Hypergraph              | 0.2368 ± 0.0018 | 0.4453 ± 0.0016 | 0.5222 ± 0.0015 | 0.3337 ± 0.0014 |
169 | | w/o ST Information          | 0.2629 ± 0.0051 | 0.4941 ± 0.0039 | 0.5770 ± 0.0030 | 0.3689 ± 0.0037 |
170 | | w/o Hyperedge Collaboration | 0.2455 ± 0.0027 | 0.4589 ± 0.0031 | 0.5361 ± 0.0026 | 0.3446 ± 0.0018 |
171 | 
172 | + CA:
173 | 
174 | |                             | Acc@1           | Acc@5           | Acc@10          | MRR             |
175 | |-----------------------------|-----------------|-----------------|-----------------|-----------------|
176 | | Full Model                  | 0.1652 ± 0.0036 | 0.3405 ± 0.0041 | 0.4177 ± 0.0038 | 0.2491 ± 0.0026 |
177 | | w/o Hypergraph              | 0.1476 ± 0.0031 | 0.3146 ± 0.0024 | 0.3859 ± 0.0051 | 0.2270 ± 0.0028 |
178 | | w/o ST Information          | 0.1578 ± 0.0042 | 0.3242 ± 0.0037 | 0.4021 ± 0.0054 | 0.2384 ± 0.0032 |
179 | | w/o Hyperedge Collaboration | 0.1538 ± 0.0028 | 0.3227 ± 0.0046 | 0.3917 ± 0.0045 | 0.2341 ± 0.0028 |
180 | 
181 | ## Tensorboard
182 | All the measurements and visualizations can be displayed via the TensorBoard tool. The tensorboard files are
183 | in the `tensorboard` directory.
184 | 
185 | ```shell
186 | tensorboard --logdir {tensorboard_directory}
187 | ```
188 | 
189 | If you want to analyze the experimental results, such as measuring the mean and
190 | standard deviation of 10 runs with the same hyper-parameter setting, you can download the **TABLE VIEW** data as csv files
191 | under the **HPARAMS** tab.
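The exported csv can then be aggregated with a few lines of pandas. A minimal sketch (the file name and metric column names below are placeholders; use the headers that actually appear in your exported csv):

```python
import pandas as pd

# Aggregate N runs of one hyper-parameter setting into mean ± std.
runs = pd.read_csv('hparams_table_view.csv')  # exported from the HPARAMS tab
for metric in ['acc@1', 'acc@5', 'acc@10', 'mrr']:  # placeholder headers
    print(f'{metric}: {runs[metric].mean():.4f} ± {runs[metric].std():.4f}')
```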
192 | 
193 | ## Citation
194 | If you compare with, build on, or use aspects of STHGCN, please cite the following:
195 | 
196 | ```text
197 | @inproceedings{sigir2023sthgcn,
198 |   title={Spatio-Temporal Hypergraph Learning for Next POI Recommendation},
199 |   author={Yan, Xiaodong and Song, Tengwei and Jiao, Yifeng and He, Jianshan and Wang, Jiaotuan and Li, Ruopeng and Chu, Wei},
200 |   booktitle={Proceedings of the 46th International ACM SIGIR Conference on Research and Development in Information Retrieval},
201 |   year={2023},
202 |   series={SIGIR '23}
203 | }
204 | ```
205 | 
--------------------------------------------------------------------------------
/SIGIR2023_STHGCN.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alipay/Spatio-Temporal-Hypergraph-Model/27b595846d29019799485985bff49f9ed02c4ade/SIGIR2023_STHGCN.pdf
--------------------------------------------------------------------------------
/STHGCN.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alipay/Spatio-Temporal-Hypergraph-Model/27b595846d29019799485985bff49f9ed02c4ade/STHGCN.png
--------------------------------------------------------------------------------
/conf/README.md:
--------------------------------------------------------------------------------
1 | | group | argument | definition |
2 | |-------|----------|------------|
3 | |dataset_args|dataset_name|dataset name; choose from 'nyc', 'tky', and 'ca'|
4 | | |min_poi_freq|the minimum number of check-in records for a POI; POIs with counts less than or equal to this value are removed|
5 | | |min_user_freq|the minimum number of check-in records for a user; users with counts less than or equal to this value are removed|
6 | | |session_time_interval|the time-interval threshold for consecutive check-in records in every trajectory; intervals should be larger than or equal to this value|
7 | | |threshold|the similarity threshold of two trajectories when building the hypergraph; traj2traj relations with similarity below this value are removed|
8 | | |filter_mode|the similarity metric; choose from 'jaccard' and 'min size'|
9 | | |num_spatial_slots|the total number of slots for discretizing continuous distance values|
10 | | |spatial_slot_type|how distance slots are constructed automatically from the min and max distance values; choose from 'linear' and 'exp'|
11 | | |do_label_encode|whether to encode the ids via LabelEncoder|
12 | | |only_last_metric|whether to use only the last check-in of every trajectory as a sample to evaluate the model|
13 | | |max_d_epsilon|added to the maximum distance to avoid boundary bugs|
14 | |model_args|model_name|model name; choose from 'sthgcn' (our model) and 'seq_transformer' (for ablation study)|
15 | | |intra_jaccard_threshold|the intra-user similarity threshold for hyperedge2hyperedge collaboration; only collaborations whose similarities are larger than this value are kept|
16 | | |inter_jaccard_threshold|the inter-user similarity threshold for hyperedge2hyperedge collaboration; only collaborations whose similarities are larger than this value are kept|
17 | | |sizes|sample sizes for the different hops; the last element is for checkin2trajectory, the other elements are for multi-hop trajectory2trajectory. E.g. with sizes=[10, 20, 30], [10, 20] is for 2-hop traj2traj sampling and [30] is for ci2traj.|
18 | | |dropout_rate|the dropout rate|
19 | | |num_edge_type|the total number of edge types|
20 | | |generate_edge_attr|whether to generate edge attribute embeddings based on edge type|
21 | | |embed_fusion_type|embedding fusion type; choose from 'concat' and 'add'|
22 | | |embed_size|the embedding size of the id embeddings and of the hidden trajectory representation|
23 | | |st_embed_size|the embedding size of the spatial and temporal embeddings|
24 | | |activation|the activation function; choose from 'elu', 'relu', 'leaky_relu' and 'tanh'|
25 | | |phase_factor|phase factor for the time encoder|
26 | | |use_linear_trans|whether to use a linear transformation before output for the time encoder|
27 | | |do_traj2traj|whether to use hyperedge2hyperedge collaboration|
28 | | |distance_encoder_type|encoder type for distance; choose from 'time', 'hstlstm', 'stan' and 'simple'. In particular, 'time' means using the TimeEncoder to handle distance values|
29 | | |quantile|clip the maximum distance value with clip(0, max_d*quantile); the code in dataset/lbsn_dataset must be modified to make this work|
30 | |conv_args|num_attention_heads|the total number of attention heads|
31 | | |residual_beta|the residual weight of the initial representation for the gated residual module|
32 | | |learn_beta|whether to learn residual beta automatically|
33 | | |conv_dropout_rate|the dropout rate for the hypergraph transformer|
34 | | |trans_method|the translation method of the message assembler; choose from 'corr', 'sub', 'add', 'multi' and 'concat'|
35 | | |edge_fusion_mode|the fusion mode of the edge vector; choose from 'concat' and 'add'|
36 | | |head_fusion_mode|the fusion mode of the multi-head outputs; choose from 'concat' and 'add'|
37 | | |time_fusion_mode|the fusion mode of the time vector; choose from 'concat' and 'add'|
38 | | |residual_fusion_mode|the fusion mode of the gated residual module; choose from 'concat' and 'add'|
39 | | |negative_slope|the negative slope for the leaky_relu activation function|
40 | |run_args|seed|random seed for random number generation, to reproduce the experiments. Not used in the multiple-run setting.|
41 | | |gpu|gpu index; use cpu if set to -1|
42 | | |batch_size|training batch size|
43 | | |eval_batch_size|evaluation batch size|
44 | | |learning_rate|the learning rate|
45 | | |do_train|whether to do training|
46 | | |do_validate|whether to do validation|
47 | | |do_test|whether to do testing|
48 | | |warm_up_steps|the number of warm-up steps with a constant initial learning rate|
49 | | |cooldown_rate|the cooldown rate for learning rate scheduling; makes the learning rate approximate an exponential decay curve with respect to the global step|
50 | | |max_steps|the max number of training steps|
51 | | |epoch|the number of training epochs|
52 | | |valid_steps|run evaluation every valid_steps steps|
53 | | |num_workers|the total number of workers for the dataloader|
54 | | |init_checkpoint|the checkpoint path|
55 | |seq_transformer_args| |the arguments below only work when `model_args.model_name == seq_transformer`|
56 | | |sequence_length|the max length of the sequences|
57 | | |header_num|the number of attention heads|
58 | | |encoder_layers_num|the total number of encoder layers|
59 | | |hidden_size|the embedding size of the hidden representation|
60 | | |dropout|the dropout rate|
61 | | |do_positional_encoding|whether to use positional encoding|
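To make `num_spatial_slots` and `spatial_slot_type` concrete, here is a minimal sketch of how the distance slots could be constructed (an illustration of the behavior described above, not necessarily the exact `utils.construct_slots` implementation used by the repository):

```python
import numpy as np

def construct_slots_sketch(min_d, max_d, num_slots, slot_type):
    """Slot boundaries over [min_d, max_d] for discretizing distances."""
    if slot_type == 'linear':
        # Evenly spaced boundaries.
        return np.linspace(min_d, max_d, num_slots).tolist()
    if slot_type == 'exp':
        # Exponentially spaced: dense near min_d, sparse near max_d.
        span = max_d - min_d
        return (min_d + np.logspace(0, np.log10(span + 1), num_slots) - 1).tolist()
    raise ValueError(f'unknown slot_type: {slot_type}')
```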
62 | 
--------------------------------------------------------------------------------
/conf/ablation_conf/ca_wo_hyper_collab.yml:
--------------------------------------------------------------------------------
1 | dataset_args:
2 |   dataset_name: ca
3 |   min_poi_freq: 9
4 |   min_user_freq: 9
5 |   session_time_interval: 1440
6 |   threshold: 0.005
7 |   filter_mode: jaccard
8 |   num_spatial_slots: 3000
9 |   spatial_slot_type: linear
10 |   do_label_encode: True
11 |   only_last_metric: True
12 |   max_d_epsilon: 5
13 | model_args:
14 |   model_name: sthgcn
15 |   intra_jaccard_threshold: 0.0
16 |   inter_jaccard_threshold: 0.01
17 |   sizes: 300-600
18 |   dropout_rate: 0.1
19 |   num_edge_type: 2
20 |   generate_edge_attr: True
21 |   embed_fusion_type: concat
22 |   embed_size: 128
23 |   st_embed_size: 128
24 |   activation: relu
25 |   phase_factor: 5
26 |   use_linear_trans: True
27 |   do_traj2traj: False
28 |   distance_encoder_type: time
29 |   quantile: 0.85
30 | seq_transformer_args:
31 |   sequence_length: 20
32 |   header_num: 2
33 |   encoder_layers_num: 2
34 |   hidden_size: 512
35 |   dropout: 0.3
36 |   do_positional_encoding: True
37 | conv_args:
38 |   num_attention_heads: 4
39 |   residual_beta: 0.5
40 |   learn_beta: false
41 |   conv_dropout_rate: 0.1
42 |   trans_method: add
43 |   edge_fusion_mode: add
44 |   head_fusion_mode: concat
45 |   time_fusion_mode: add
46 |   residual_fusion_mode: add
47 |   negative_slope: 0.2
48 | run_args:
49 |   seed:
50 |   gpu: 0
51 |   batch_size: 64
52 |   eval_batch_size: 64
53 |   learning_rate: 0.0001
54 |   do_train: True
55 |   do_validate: True
56 |   do_test: True
57 |   warm_up_steps: 14000
58 |   cooldown_rate: 1.4
59 |   max_steps: 100000
60 |   epoch: 20
61 |   valid_steps: 500
62 |   num_workers: 4
63 |   init_checkpoint:
64 | 
--------------------------------------------------------------------------------
/conf/ablation_conf/ca_wo_hypergraph.yml:
--------------------------------------------------------------------------------
1 | dataset_args:
2 |   dataset_name: ca
3 |   min_poi_freq: 9
4 |   min_user_freq: 9
5 |   session_time_interval: 1440
6 |   threshold: 0.005
7 |   filter_mode: jaccard
8 |   num_spatial_slots: 3000
9 |   spatial_slot_type: linear
10 |   do_label_encode: True
11 |   only_last_metric: True
12 |   max_d_epsilon: 5
13 | model_args:
14 |   model_name: seq_transformer
15 |   intra_jaccard_threshold: 0.0
16 |
inter_jaccard_threshold: 0.01 17 | sizes: 300-600 18 | dropout_rate: 0.1 19 | num_edge_type: 2 20 | generate_edge_attr: True 21 | embed_fusion_type: concat 22 | embed_size: 128 23 | st_embed_size: 128 24 | activation: relu 25 | phase_factor: 5 26 | use_linear_trans: True 27 | do_traj2traj: True 28 | distance_encoder_type: time 29 | quantile: 0.85 30 | seq_transformer_args: 31 | sequence_length: 20 32 | header_num: 2 33 | encoder_layers_num: 2 34 | hidden_size: 512 35 | dropout: 0.3 36 | do_positional_encoding: True 37 | conv_args: 38 | num_attention_heads: 4 39 | residual_beta: 0.5 40 | learn_beta: false 41 | conv_dropout_rate: 0.1 42 | trans_method: add 43 | edge_fusion_mode: add 44 | head_fusion_mode: concat 45 | time_fusion_mode: add 46 | residual_fusion_mode: add 47 | negative_slope: 0.2 48 | run_args: 49 | seed: 50 | gpu: 0 51 | batch_size: 64 52 | eval_batch_size: 64 53 | learning_rate: 0.0001 54 | do_train: True 55 | do_validate: True 56 | do_test: True 57 | warm_up_steps: 14000 58 | cooldown_rate: 1.4 59 | max_steps: 100000 60 | epoch: 20 61 | valid_steps: 500 62 | num_workers: 4 63 | init_checkpoint: 64 | -------------------------------------------------------------------------------- /conf/ablation_conf/ca_wo_st_info.yml: -------------------------------------------------------------------------------- 1 | dataset_args: 2 | dataset_name: ca 3 | min_poi_freq: 9 4 | min_user_freq: 9 5 | session_time_interval: 1440 6 | threshold: 0.005 7 | filter_mode: jaccard 8 | num_spatial_slots: 3000 9 | spatial_slot_type: linear 10 | do_label_encode: True 11 | only_last_metric: True 12 | max_d_epsilon: 5 13 | model_args: 14 | model_name: sthgcn 15 | intra_jaccard_threshold: 0.0 16 | inter_jaccard_threshold: 0.01 17 | sizes: 300-600 18 | dropout_rate: 0.1 19 | num_edge_type: 2 20 | generate_edge_attr: True 21 | embed_fusion_type: concat 22 | embed_size: 128 23 | st_embed_size: 128 24 | activation: relu 25 | phase_factor: 5 26 | use_linear_trans: True 27 | do_traj2traj: True 28 | distance_encoder_type: time 29 | quantile: 0.85 30 | seq_transformer_args: 31 | sequence_length: 20 32 | header_num: 2 33 | encoder_layers_num: 2 34 | hidden_size: 512 35 | dropout: 0.3 36 | do_positional_encoding: True 37 | conv_args: 38 | num_attention_heads: 4 39 | residual_beta: 0.5 40 | learn_beta: false 41 | conv_dropout_rate: 0.1 42 | trans_method: add 43 | edge_fusion_mode: add 44 | head_fusion_mode: concat 45 | time_fusion_mode: 46 | residual_fusion_mode: add 47 | negative_slope: 0.2 48 | run_args: 49 | seed: 50 | gpu: 0 51 | batch_size: 64 52 | eval_batch_size: 64 53 | learning_rate: 0.0001 54 | do_train: True 55 | do_validate: True 56 | do_test: True 57 | warm_up_steps: 14000 58 | cooldown_rate: 1.4 59 | max_steps: 100000 60 | epoch: 20 61 | valid_steps: 500 62 | num_workers: 4 63 | init_checkpoint: 64 | -------------------------------------------------------------------------------- /conf/ablation_conf/nyc_wo_hyper_collab.yml: -------------------------------------------------------------------------------- 1 | dataset_args: 2 | dataset_name: nyc 3 | min_poi_freq: 9 4 | min_user_freq: 9 5 | session_time_interval: 1440 6 | threshold: 0.005 7 | filter_mode: jaccard 8 | num_spatial_slots: 3000 9 | spatial_slot_type: linear 10 | do_label_encode: True 11 | only_last_metric: True 12 | max_d_epsilon: 5 13 | model_args: 14 | model_name: sthgcn 15 | intra_jaccard_threshold: 0.0 16 | inter_jaccard_threshold: 0.005 17 | sizes: 300-500 18 | dropout_rate: 0.1 19 | num_edge_type: 2 20 | generate_edge_attr: True 21 | 
embed_fusion_type: concat 22 | embed_size: 128 23 | st_embed_size: 128 24 | activation: relu 25 | phase_factor: 5 26 | use_linear_trans: True 27 | do_traj2traj: False 28 | distance_encoder_type: time 29 | quantile: 0.85 30 | seq_transformer_args: 31 | sequence_length: 20 32 | header_num: 2 33 | encoder_layers_num: 2 34 | hidden_size: 512 35 | dropout: 0.3 36 | do_positional_encoding: True 37 | conv_args: 38 | num_attention_heads: 4 39 | residual_beta: 0.5 40 | learn_beta: false 41 | conv_dropout_rate: 0.1 42 | trans_method: add 43 | edge_fusion_mode: add 44 | head_fusion_mode: concat 45 | time_fusion_mode: add 46 | residual_fusion_mode: add 47 | negative_slope: 0.2 48 | run_args: 49 | seed: 50 | gpu: 0 51 | batch_size: 64 52 | eval_batch_size: 64 53 | learning_rate: 0.0001 54 | do_train: True 55 | do_validate: True 56 | do_test: True 57 | warm_up_steps: 8000 58 | cooldown_rate: 1.5 59 | max_steps: 100000 60 | epoch: 20 61 | valid_steps: 500 62 | num_workers: 4 63 | init_checkpoint: 64 | -------------------------------------------------------------------------------- /conf/ablation_conf/nyc_wo_hypergraph.yml: -------------------------------------------------------------------------------- 1 | dataset_args: 2 | dataset_name: nyc 3 | min_poi_freq: 9 4 | min_user_freq: 9 5 | session_time_interval: 1440 6 | threshold: 0.005 7 | filter_mode: jaccard 8 | num_spatial_slots: 3000 9 | spatial_slot_type: linear 10 | do_label_encode: True 11 | only_last_metric: True 12 | max_d_epsilon: 5 13 | model_args: 14 | model_name: seq_transformer 15 | intra_jaccard_threshold: 0.0 16 | inter_jaccard_threshold: 0.005 17 | sizes: 300-500 18 | dropout_rate: 0.1 19 | num_edge_type: 2 20 | generate_edge_attr: True 21 | embed_fusion_type: concat 22 | embed_size: 128 23 | st_embed_size: 128 24 | activation: relu 25 | phase_factor: 5 26 | use_linear_trans: True 27 | do_traj2traj: True 28 | distance_encoder_type: time 29 | quantile: 0.85 30 | seq_transformer_args: 31 | sequence_length: 20 32 | header_num: 2 33 | encoder_layers_num: 2 34 | hidden_size: 512 35 | dropout: 0.3 36 | do_positional_encoding: True 37 | conv_args: 38 | num_attention_heads: 4 39 | residual_beta: 0.5 40 | learn_beta: false 41 | conv_dropout_rate: 0.1 42 | trans_method: add 43 | edge_fusion_mode: add 44 | head_fusion_mode: concat 45 | time_fusion_mode: add 46 | residual_fusion_mode: add 47 | negative_slope: 0.2 48 | run_args: 49 | seed: 29364979 50 | gpu: 0 51 | batch_size: 64 52 | eval_batch_size: 64 53 | learning_rate: 0.0001 54 | do_train: True 55 | do_validate: True 56 | do_test: True 57 | warm_up_steps: 8000 58 | cooldown_rate: 1.5 59 | max_steps: 100000 60 | epoch: 20 61 | valid_steps: 500 62 | num_workers: 4 63 | init_checkpoint: 64 | -------------------------------------------------------------------------------- /conf/ablation_conf/nyc_wo_st_info.yml: -------------------------------------------------------------------------------- 1 | dataset_args: 2 | dataset_name: nyc 3 | min_poi_freq: 9 4 | min_user_freq: 9 5 | session_time_interval: 1440 6 | threshold: 0.005 7 | filter_mode: jaccard 8 | num_spatial_slots: 3000 9 | spatial_slot_type: linear 10 | do_label_encode: True 11 | only_last_metric: True 12 | max_d_epsilon: 5 13 | model_args: 14 | model_name: sthgcn 15 | intra_jaccard_threshold: 0.0 16 | inter_jaccard_threshold: 0.005 17 | sizes: 300-500 18 | dropout_rate: 0.1 19 | num_edge_type: 2 20 | generate_edge_attr: True 21 | embed_fusion_type: concat 22 | embed_size: 128 23 | st_embed_size: 128 24 | activation: relu 25 | phase_factor: 5 26 
| use_linear_trans: True 27 | do_traj2traj: True 28 | distance_encoder_type: time 29 | quantile: 0.85 30 | seq_transformer_args: 31 | sequence_length: 20 32 | header_num: 2 33 | encoder_layers_num: 2 34 | hidden_size: 512 35 | dropout: 0.3 36 | do_positional_encoding: True 37 | conv_args: 38 | num_attention_heads: 4 39 | residual_beta: 0.5 40 | learn_beta: false 41 | conv_dropout_rate: 0.1 42 | trans_method: add 43 | edge_fusion_mode: add 44 | head_fusion_mode: concat 45 | time_fusion_mode: 46 | residual_fusion_mode: add 47 | negative_slope: 0.2 48 | run_args: 49 | seed: 50 | gpu: 0 51 | batch_size: 64 52 | eval_batch_size: 64 53 | learning_rate: 0.0001 54 | do_train: True 55 | do_validate: True 56 | do_test: True 57 | warm_up_steps: 8000 58 | cooldown_rate: 1.5 59 | max_steps: 100000 60 | epoch: 20 61 | valid_steps: 500 62 | num_workers: 4 63 | init_checkpoint: 64 | -------------------------------------------------------------------------------- /conf/ablation_conf/tky_wo_hyper_collab.yml: -------------------------------------------------------------------------------- 1 | dataset_args: 2 | dataset_name: tky 3 | min_poi_freq: 9 4 | min_user_freq: 9 5 | session_time_interval: 1440 6 | threshold: 0.005 7 | filter_mode: jaccard 8 | num_spatial_slots: 3000 9 | spatial_slot_type: linear 10 | do_label_encode: True 11 | only_last_metric: True 12 | max_d_epsilon: 5 13 | model_args: 14 | model_name: sthgcn 15 | intra_jaccard_threshold: 0.0 16 | inter_jaccard_threshold: 0.01 17 | sizes: 400-240 18 | dropout_rate: 0.1 19 | num_edge_type: 2 20 | generate_edge_attr: True 21 | embed_fusion_type: concat 22 | embed_size: 128 23 | st_embed_size: 128 24 | activation: relu 25 | phase_factor: 5 26 | use_linear_trans: True 27 | do_traj2traj: False 28 | distance_encoder_type: time 29 | quantile: 0.85 30 | seq_transformer_args: 31 | sequence_length: 20 32 | header_num: 2 33 | encoder_layers_num: 2 34 | hidden_size: 512 35 | dropout: 0.3 36 | do_positional_encoding: True 37 | conv_args: 38 | num_attention_heads: 4 39 | residual_beta: 0.5 40 | learn_beta: false 41 | conv_dropout_rate: 0.1 42 | trans_method: add 43 | edge_fusion_mode: add 44 | head_fusion_mode: concat 45 | time_fusion_mode: add 46 | residual_fusion_mode: add 47 | negative_slope: 0.2 48 | run_args: 49 | seed: 50 | gpu: 0 51 | batch_size: 64 52 | eval_batch_size: 64 53 | learning_rate: 0.0001 54 | do_train: True 55 | do_validate: True 56 | do_test: True 57 | warm_up_steps: 48000 58 | cooldown_rate: 1.5 59 | max_steps: 100000 60 | epoch: 20 61 | valid_steps: 4000 62 | num_workers: 4 63 | init_checkpoint: 64 | -------------------------------------------------------------------------------- /conf/ablation_conf/tky_wo_hypergraph.yml: -------------------------------------------------------------------------------- 1 | dataset_args: 2 | dataset_name: tky 3 | min_poi_freq: 9 4 | min_user_freq: 9 5 | session_time_interval: 1440 6 | threshold: 0.005 7 | filter_mode: jaccard 8 | num_spatial_slots: 3000 9 | spatial_slot_type: linear 10 | do_label_encode: True 11 | only_last_metric: True 12 | max_d_epsilon: 5 13 | model_args: 14 | model_name: seq_transformer 15 | intra_jaccard_threshold: 0.0 16 | inter_jaccard_threshold: 0.01 17 | sizes: 400-240 18 | dropout_rate: 0.1 19 | num_edge_type: 2 20 | generate_edge_attr: True 21 | embed_fusion_type: concat 22 | embed_size: 128 23 | st_embed_size: 128 24 | activation: relu 25 | phase_factor: 5 26 | use_linear_trans: True 27 | do_traj2traj: True 28 | distance_encoder_type: time 29 | quantile: 0.85 30 | 
seq_transformer_args: 31 | sequence_length: 20 32 | header_num: 2 33 | encoder_layers_num: 2 34 | hidden_size: 512 35 | dropout: 0.3 36 | do_positional_encoding: True 37 | conv_args: 38 | num_attention_heads: 4 39 | residual_beta: 0.5 40 | learn_beta: false 41 | conv_dropout_rate: 0.1 42 | trans_method: add 43 | edge_fusion_mode: add 44 | head_fusion_mode: concat 45 | time_fusion_mode: add 46 | residual_fusion_mode: add 47 | negative_slope: 0.2 48 | run_args: 49 | seed: 50 | gpu: 0 51 | batch_size: 64 52 | eval_batch_size: 64 53 | learning_rate: 0.0001 54 | do_train: True 55 | do_validate: True 56 | do_test: True 57 | warm_up_steps: 48000 58 | cooldown_rate: 1.5 59 | max_steps: 100000 60 | epoch: 20 61 | valid_steps: 4000 62 | num_workers: 4 63 | init_checkpoint: 64 | -------------------------------------------------------------------------------- /conf/ablation_conf/tky_wo_st_info.yml: -------------------------------------------------------------------------------- 1 | dataset_args: 2 | dataset_name: tky 3 | min_poi_freq: 9 4 | min_user_freq: 9 5 | session_time_interval: 1440 6 | threshold: 0.005 7 | filter_mode: jaccard 8 | num_spatial_slots: 3000 9 | spatial_slot_type: linear 10 | do_label_encode: True 11 | only_last_metric: True 12 | max_d_epsilon: 5 13 | model_args: 14 | model_name: sthgcn 15 | intra_jaccard_threshold: 0.0 16 | inter_jaccard_threshold: 0.01 17 | sizes: 400-240 18 | dropout_rate: 0.1 19 | num_edge_type: 2 20 | generate_edge_attr: True 21 | embed_fusion_type: concat 22 | embed_size: 128 23 | st_embed_size: 128 24 | activation: relu 25 | phase_factor: 5 26 | use_linear_trans: True 27 | do_traj2traj: True 28 | distance_encoder_type: time 29 | quantile: 0.85 30 | seq_transformer_args: 31 | sequence_length: 20 32 | header_num: 2 33 | encoder_layers_num: 2 34 | hidden_size: 512 35 | dropout: 0.3 36 | do_positional_encoding: True 37 | conv_args: 38 | num_attention_heads: 4 39 | residual_beta: 0.5 40 | learn_beta: false 41 | conv_dropout_rate: 0.1 42 | trans_method: add 43 | edge_fusion_mode: add 44 | head_fusion_mode: concat 45 | time_fusion_mode: 46 | residual_fusion_mode: add 47 | negative_slope: 0.2 48 | run_args: 49 | seed: 50 | gpu: 0 51 | batch_size: 64 52 | eval_batch_size: 64 53 | learning_rate: 0.0001 54 | do_train: True 55 | do_validate: True 56 | do_test: True 57 | warm_up_steps: 48000 58 | cooldown_rate: 1.5 59 | max_steps: 100000 60 | epoch: 20 61 | valid_steps: 4000 62 | num_workers: 4 63 | init_checkpoint: 64 | -------------------------------------------------------------------------------- /conf/best_conf/ca.yml: -------------------------------------------------------------------------------- 1 | dataset_args: 2 | dataset_name: ca 3 | min_poi_freq: 9 4 | min_user_freq: 9 5 | session_time_interval: 1440 6 | threshold: 0.005 7 | filter_mode: jaccard 8 | num_spatial_slots: 3000 9 | spatial_slot_type: linear 10 | do_label_encode: True 11 | only_last_metric: True 12 | max_d_epsilon: 5 13 | model_args: 14 | model_name: sthgcn 15 | intra_jaccard_threshold: 0.0 16 | inter_jaccard_threshold: 0.01 17 | sizes: 300-600 18 | dropout_rate: 0.1 19 | num_edge_type: 2 20 | generate_edge_attr: True 21 | embed_fusion_type: concat 22 | embed_size: 128 23 | st_embed_size: 128 24 | activation: relu 25 | phase_factor: 5 26 | use_linear_trans: True 27 | do_traj2traj: True 28 | distance_encoder_type: time 29 | quantile: 0.85 30 | seq_transformer_args: 31 | sequence_length: 20 32 | header_num: 2 33 | encoder_layers_num: 2 34 | hidden_size: 512 35 | dropout: 0.3 36 | 
do_positional_encoding: True 37 | conv_args: 38 | num_attention_heads: 4 39 | residual_beta: 0.5 40 | learn_beta: false 41 | conv_dropout_rate: 0.1 42 | trans_method: add 43 | edge_fusion_mode: add 44 | head_fusion_mode: concat 45 | time_fusion_mode: add 46 | residual_fusion_mode: add 47 | negative_slope: 0.2 48 | run_args: 49 | seed: 27486607 50 | gpu: 0 51 | batch_size: 64 52 | eval_batch_size: 64 53 | learning_rate: 0.0001 54 | do_train: True 55 | do_validate: True 56 | do_test: True 57 | warm_up_steps: 14000 58 | cooldown_rate: 1.4 59 | max_steps: 100000 60 | epoch: 20 61 | valid_steps: 500 62 | num_workers: 4 63 | init_checkpoint: 64 | -------------------------------------------------------------------------------- /conf/best_conf/nyc.yml: -------------------------------------------------------------------------------- 1 | dataset_args: 2 | dataset_name: nyc 3 | min_poi_freq: 9 4 | min_user_freq: 9 5 | session_time_interval: 1440 6 | threshold: 0.005 7 | filter_mode: jaccard 8 | num_spatial_slots: 3000 9 | spatial_slot_type: linear 10 | do_label_encode: True 11 | only_last_metric: True 12 | max_d_epsilon: 5 13 | model_args: 14 | model_name: sthgcn 15 | intra_jaccard_threshold: 0.0 16 | inter_jaccard_threshold: 0.005 17 | sizes: 300-500 18 | dropout_rate: 0.1 19 | num_edge_type: 2 20 | generate_edge_attr: True 21 | embed_fusion_type: concat 22 | embed_size: 128 23 | st_embed_size: 128 24 | activation: relu 25 | phase_factor: 5 26 | use_linear_trans: True 27 | do_traj2traj: True 28 | distance_encoder_type: time 29 | quantile: 0.85 30 | seq_transformer_args: 31 | sequence_length: 20 32 | header_num: 2 33 | encoder_layers_num: 2 34 | hidden_size: 512 35 | dropout: 0.3 36 | do_positional_encoding: True 37 | conv_args: 38 | num_attention_heads: 4 39 | residual_beta: 0.5 40 | learn_beta: false 41 | conv_dropout_rate: 0.1 42 | trans_method: add 43 | edge_fusion_mode: add 44 | head_fusion_mode: concat 45 | time_fusion_mode: add 46 | residual_fusion_mode: add 47 | negative_slope: 0.2 48 | run_args: 49 | seed: 80786525 50 | gpu: 0 51 | batch_size: 64 52 | eval_batch_size: 64 53 | learning_rate: 0.0001 54 | do_train: True 55 | do_validate: True 56 | do_test: True 57 | warm_up_steps: 8000 58 | cooldown_rate: 1.5 59 | max_steps: 100000 60 | epoch: 20 61 | valid_steps: 500 62 | num_workers: 4 63 | init_checkpoint: 64 | -------------------------------------------------------------------------------- /conf/best_conf/tky.yml: -------------------------------------------------------------------------------- 1 | dataset_args: 2 | dataset_name: tky 3 | min_poi_freq: 9 4 | min_user_freq: 9 5 | session_time_interval: 1440 6 | threshold: 0.005 7 | filter_mode: jaccard 8 | num_spatial_slots: 3000 9 | spatial_slot_type: linear 10 | do_label_encode: True 11 | only_last_metric: True 12 | max_d_epsilon: 5 13 | model_args: 14 | model_name: sthgcn 15 | intra_jaccard_threshold: 0.0 16 | inter_jaccard_threshold: 0.01 17 | sizes: 400-240 18 | dropout_rate: 0.1 19 | num_edge_type: 2 20 | generate_edge_attr: True 21 | embed_fusion_type: concat 22 | embed_size: 128 23 | st_embed_size: 128 24 | activation: relu 25 | phase_factor: 5 26 | use_linear_trans: True 27 | do_traj2traj: True 28 | distance_encoder_type: time 29 | quantile: 0.85 30 | seq_transformer_args: 31 | sequence_length: 20 32 | header_num: 2 33 | encoder_layers_num: 2 34 | hidden_size: 512 35 | dropout: 0.3 36 | do_positional_encoding: True 37 | conv_args: 38 | num_attention_heads: 4 39 | residual_beta: 0.5 40 | learn_beta: false 41 | conv_dropout_rate: 0.1 42 
| trans_method: add 43 | edge_fusion_mode: add 44 | head_fusion_mode: concat 45 | time_fusion_mode: add 46 | residual_fusion_mode: add 47 | negative_slope: 0.2 48 | run_args: 49 | seed: 54607333 50 | gpu: 0 51 | batch_size: 64 52 | eval_batch_size: 64 53 | learning_rate: 0.0001 54 | do_train: True 55 | do_validate: True 56 | do_test: True 57 | warm_up_steps: 48000 58 | cooldown_rate: 1.5 59 | max_steps: 100000 60 | epoch: 20 61 | valid_steps: 4000 62 | num_workers: 4 63 | init_checkpoint: 64 | -------------------------------------------------------------------------------- /data/ca/raw.zip: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:3ab6e3f7a545f4f3826121aa3a3f84593b8b2ca299fec7293d9727dcb1784c5f 3 | size 183298955 4 | -------------------------------------------------------------------------------- /data/nyc/raw.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alipay/Spatio-Temporal-Hypergraph-Model/27b595846d29019799485985bff49f9ed02c4ade/data/nyc/raw.zip -------------------------------------------------------------------------------- /data/tky/raw.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alipay/Spatio-Temporal-Hypergraph-Model/27b595846d29019799485985bff49f9ed02c4ade/data/tky/raw.zip -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from dataset.lbsn_dataset import LBSNDataset 2 | 3 | __all__ = ["LBSNDataset"] 4 | -------------------------------------------------------------------------------- /dataset/lbsn_dataset.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | import pandas as pd 4 | import os.path as osp 5 | from utils import get_root_dir, construct_slots 6 | 7 | 8 | class LBSNDataset: 9 | def __init__(self, cfg): 10 | self.data_path = osp.join(get_root_dir(), 'data', cfg.dataset_args.dataset_name, 'preprocessed') 11 | self.padding_poi_id, self.padding_poi_category, self.padding_user_id = 0, 0, 0 12 | self.padding_hour_id, self.padding_weekday_id = 0, 0 13 | self.num_user, self.num_poi, self.num_category, self.num_checkin, self.num_traj = 0, 0, 0, 0, 0 14 | 15 | df, df_train, df_valid, df_test, ci2traj, traj2traj = self.read() 16 | self.x = [traj2traj.x, ci2traj.x] 17 | self.edge_index = [traj2traj.edge_index, ci2traj.edge_index] 18 | self.edge_attr = [traj2traj.edge_attr, None] 19 | self.edge_t = [None, ci2traj.edge_t] 20 | self.edge_delta_t = [traj2traj.edge_delta_t, ci2traj.edge_delta_t] 21 | self.edge_delta_s = [traj2traj.edge_delta_s, ci2traj.edge_delta_s] 22 | self.edge_type = [traj2traj.edge_type, None] 23 | 24 | self.checkin_offset = torch.as_tensor([df.check_ins_id.max() + 1], dtype=torch.long) 25 | self.node_idx_train = self.get_node_id(df_train) 26 | self.node_idx_valid = self.get_node_id(df_valid) 27 | self.node_idx_test = self.get_node_id(df_test) 28 | self.max_time_train = self.get_max_time(df_train) 29 | self.max_time_valid = self.get_max_time(df_valid) 30 | self.max_time_test = self.get_max_time(df_test) 31 | self.label_train = self.get_label(df_train) 32 | self.label_valid = self.get_label(df_valid) 33 | self.label_test = self.get_label(df_test) 34 | self.sample_idx_train = self.get_sample_id(df_train) 35 | 
self.sample_idx_valid = self.get_sample_id(df_valid) 36 | self.sample_idx_test = self.get_sample_id(df_test) 37 | 38 | self.min_d, self.max_d = 1e8, 0. 39 | delta_s = torch.cat([ci2traj.edge_delta_s, traj2traj.edge_delta_s], dim=0) 40 | 41 | self.min_d = min(self.min_d, delta_s.min()) 42 | self.max_d_chj2traj = max(self.max_d, ci2traj.edge_delta_s.max()) 43 | self.max_d_tj2traj = max(self.max_d, traj2traj.edge_delta_s.max()) 44 | self.max_d_tj2traj += cfg.dataset_args.max_d_epsilon 45 | 46 | if cfg.model_args.distance_encoder_type == 'hstlstm': 47 | self.spatial_slots = construct_slots( 48 | self.min_d, 49 | self.max_d, 50 | cfg.dataset_args.num_spatial_slots, 51 | cfg.dataset_args.spatial_slot_type 52 | ) 53 | else: 54 | self.spatial_slots = self.min_d, self.max_d_chj2traj, self.max_d_tj2traj 55 | 56 | logging.info(f'[Initialize Dataset] #user: {self.num_user}') 57 | logging.info(f'[Initialize Dataset] #poi: {self.num_poi}') 58 | logging.info(f'[Initialize Dataset] #category: {self.num_category}') 59 | logging.info(f'[Initialize Dataset] #checkin: {self.num_checkin}') 60 | logging.info(f'[Initialize Dataset] #trajectory: {self.num_traj}') 61 | logging.info(f'[Initialize Dataset] #training_sample: {self.sample_idx_train.shape[0]}') 62 | logging.info(f'[Initialize Dataset] #validation_sample: {self.sample_idx_valid.shape[0]}') 63 | logging.info(f'[Initialize Dataset] #testing_sample: {self.sample_idx_test.shape[0]}') 64 | 65 | def read(self): 66 | df = pd.read_csv(osp.join(self.data_path, 'sample.csv')).reset_index(drop=True) 67 | le_data = pd.read_pickle(osp.join(self.data_path, 'label_encoding.pkl')) 68 | self.padding_poi_id = le_data[5] 69 | self.padding_poi_category = le_data[6] 70 | self.padding_user_id = le_data[7] 71 | self.padding_hour_id = le_data[8] 72 | self.padding_weekday_id = le_data[9] 73 | 74 | self.num_user = df['UserId'].nunique() 75 | self.num_poi = df['PoiId'].nunique() 76 | self.num_category = df['PoiCategoryId'].nunique() 77 | self.num_checkin = df.shape[0] 78 | self.num_traj = df['pseudo_session_trajectory_id'].nunique() 79 | 80 | df_train = pd.read_csv(osp.join(self.data_path, 'train_sample.csv'), sep=',') 81 | df_valid = pd.read_csv(osp.join(self.data_path, 'validate_sample.csv'), sep=',') 82 | df_test = pd.read_csv(osp.join(self.data_path, 'test_sample.csv'), sep=',') 83 | 84 | ci2traj = torch.load(osp.join(self.data_path, 'ci2traj_pyg_data.pt')) 85 | traj2traj = torch.load(osp.join(self.data_path, 'traj2traj_pyg_data.pt')) 86 | 87 | return df, df_train, df_valid, df_test, ci2traj, traj2traj 88 | 89 | def get_node_id(self, df): 90 | query_id = torch.tensor(df.query_pseudo_session_trajectory_id, dtype=torch.long) 91 | node_id = query_id + self.checkin_offset 92 | return node_id 93 | 94 | @staticmethod 95 | def get_max_time(df): 96 | max_time = torch.tensor(df.last_checkin_epoch_time, dtype=torch.long) 97 | return max_time 98 | 99 | @staticmethod 100 | def get_label(df): 101 | poi_id = torch.tensor(df.PoiId, dtype=torch.long) 102 | cate_id = torch.tensor(df.PoiCategoryId, dtype=torch.long) 103 | longitude = torch.tensor(df.Longitude, dtype=torch.float) 104 | latitude = torch.tensor(df.Latitude, dtype=torch.float) 105 | time_hour = torch.tensor(pd.to_datetime(df['UTCTimeOffset']).dt.hour / 24, dtype=torch.float) 106 | y = torch.stack([poi_id, cate_id, longitude, latitude, time_hour], dim=-1) 107 | return y 108 | 109 | @staticmethod 110 | def get_sample_id(df): 111 | sample_id = torch.tensor(df.index, dtype=torch.long) 112 | return sample_id 113 | 
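# --- Illustrative sketch (not part of the repository) -------------------------
# Column layout of the label tensor built by ``get_label`` above; the order
# mirrors the ``torch.stack`` call. ``df_test`` stands in for one of the split
# DataFrames returned by ``read``.
#
#   y = LBSNDataset.get_label(df_test)                # [N, 5]
#   poi_id, cate_id = y[:, 0].long(), y[:, 1].long()
#   lon, lat = y[:, 2], y[:, 3]
#   hour_frac = y[:, 4]                               # hour of day / 24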
-------------------------------------------------------------------------------- /generate_ca_raw.py: --------------------------------------------------------------------------------
1 | import os.path as osp
2 | import pandas as pd
3 | import json
4 | import shapely
5 | from utils import get_root_dir
6 | 
7 | data_path = osp.join(get_root_dir(), 'data', 'ca', 'raw')
8 | raw_checkins = pd.read_csv(osp.join(data_path, 'loc-gowalla_totalCheckins.txt'), sep='\t', header=None)
9 | raw_checkins.columns = ['userid', 'datetime', 'checkins_lat', 'checkins_lng', 'id']
10 | subset1 = pd.read_csv(osp.join(data_path, 'gowalla_spots_subset1.csv'))
11 | raw_checkins_subset1 = raw_checkins.merge(subset1, on='id')
12 | 
13 | with open(osp.join(data_path, 'us_state_polygon_json.json'), 'r') as f:
14 |     us_state_polygon = json.load(f)
15 | 
16 | for i in us_state_polygon['features']:
17 |     if i['properties']['name'].lower() == 'california':
18 |         california = shapely.polygons(i['geometry']['coordinates'][0])
19 |     if i['properties']['name'].lower() == 'nevada':
20 |         nevada = shapely.polygons(i['geometry']['coordinates'][0])
21 | 
22 | # check whether the check-in took place in California or Nevada
23 | raw_checkins_subset1['is_ca'] = raw_checkins_subset1.apply(
24 |     lambda x: nevada.intersects(
25 |         shapely.geometry.Point(x['checkins_lng'], x['checkins_lat'])) or california.intersects(
26 |         shapely.geometry.Point(x['checkins_lng'], x['checkins_lat'])), axis=1
27 | )
28 | raw_checkins_subset1 = raw_checkins_subset1[raw_checkins_subset1['is_ca']]
29 | 
30 | df = raw_checkins_subset1[['userid', 'id', 'spot_categories', 'checkins_lat', 'checkins_lng', 'datetime']]
31 | df.columns = ['UserId', 'PoiId', 'PoiCategoryId', 'Latitude', 'Longitude', 'UTCTime']
32 | df.to_csv(osp.join(data_path, 'dataset_gowalla_ca_ne.csv'), index=False)
33 | -------------------------------------------------------------------------------- /layer/__init__.py: --------------------------------------------------------------------------------
1 | from layer.conv import HypergraphTransformer
2 | from layer.sampler import NeighborSampler
3 | from layer.embedding_layer import (
4 |     CheckinEmbedding,
5 |     EdgeEmbedding
6 | )
7 | from layer.st_encoder import (
8 |     PositionEncoder,
9 |     TimeEncoder,
10 |     DistanceEncoderHSTLSTM,
11 |     DistanceEncoderSTAN,
12 |     DistanceEncoderSimple
13 | )
14 | 
15 | 
16 | __all__ = [
17 |     "HypergraphTransformer",
18 |     "NeighborSampler",
19 |     "PositionEncoder",
20 |     "CheckinEmbedding",
21 |     "EdgeEmbedding",
22 |     "TimeEncoder",
23 |     "DistanceEncoderHSTLSTM",
24 |     "DistanceEncoderSTAN",
25 |     "DistanceEncoderSimple"
26 | ]
27 | -------------------------------------------------------------------------------- /layer/conv.py: --------------------------------------------------------------------------------
1 | import math
2 | from typing import Union, Tuple, Optional
3 | from torch import Tensor, cat
4 | from torch.nn import init, Parameter, Linear, LayerNorm
5 | import torch.nn.functional as F
6 | from torch_sparse import SparseTensor
7 | from torch_geometric.typing import OptPairTensor, Adj, OptTensor
8 | from torch_geometric.nn.conv import MessagePassing
9 | from torch_geometric.utils import softmax
10 | from utils import ccorr
11 | 
12 | 
13 | class HypergraphTransformer(MessagePassing):
14 |     r"""Hypergraph convolution combining relation transform, edge fusion (including time fusion),
15 |     self-attention, and a gated residual connection (or skip connection).
16 | 
17 |     ..
math:: 18 | \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + 19 | \sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \mathbf{W}_2 \mathbf{x}_{j}, 20 | where the attention coefficients :math:`\alpha_{i,j}` are computed via 21 | """ 22 | 23 | def __init__( 24 | self, 25 | in_channels: Union[int, Tuple[int, int]], 26 | out_channels: int, 27 | attn_heads: int = 4, 28 | residual_beta: Optional[float] = None, 29 | learn_beta: bool = False, 30 | dropout: float = 0., 31 | negative_slope: float = 0.2, 32 | bias: bool = True, 33 | trans_method: str = 'add', 34 | edge_fusion_mode: str = 'add', 35 | time_fusion_mode: str = None, 36 | head_fusion_mode: str = 'concat', 37 | residual_fusion_mode: str = None, 38 | edge_dim: int = None, 39 | rel_embed_dim: int = None, 40 | time_embed_dim: int = 0, 41 | dist_embed_dim: int = 0, 42 | normalize: bool = True, 43 | message_mode: str = 'node_edge', 44 | have_query_feature: bool = False, 45 | **kwargs 46 | ): 47 | super(HypergraphTransformer, self).__init__(aggr='add', node_dim=0, **kwargs) 48 | 49 | self.in_channels = in_channels 50 | self.out_channels = out_channels 51 | self.attn_heads = attn_heads 52 | self.learn_beta = learn_beta 53 | self.residual_beta = residual_beta 54 | self.dropout = dropout 55 | self.negative_slope = negative_slope 56 | self.trans_method = trans_method 57 | self.edge_dim = edge_dim 58 | self.rel_embed_dim = rel_embed_dim 59 | self.time_embed_dim = time_embed_dim 60 | self.dist_embed_dim = dist_embed_dim 61 | self.edge_fusion_mode = edge_fusion_mode 62 | self.time_fusion_mode = time_fusion_mode 63 | self.head_fusion_mode = head_fusion_mode 64 | self.residual_fusion_mode = residual_fusion_mode 65 | self.normalize = normalize 66 | self.message_mode = message_mode 67 | self.trans_flag = False 68 | self.have_query_feature = have_query_feature 69 | 70 | if isinstance(in_channels, int): 71 | in_channels = (in_channels, in_channels) 72 | self.in_channels = in_channels 73 | if in_channels[0] != out_channels and self.trans_method == 'add': 74 | self.trans_flag = True 75 | self.lin_trans_x = Linear(in_channels[0], in_channels[1]) 76 | 77 | if not self.have_query_feature: 78 | # V -> hyperedge (E) situation 79 | self.att_r = Parameter(Tensor(1, attn_heads, out_channels)) 80 | 81 | self.attn_in_dim, self.attn_out_dim = self._check_attn_dim(in_channels[1], out_channels) 82 | 83 | self.lin_key = Linear(self.attn_in_dim, attn_heads * out_channels) 84 | self.lin_query = Linear(in_channels[1], attn_heads * out_channels) 85 | if self.message_mode == 'node_edge': 86 | self.lin_value = Linear(self.attn_in_dim, attn_heads * out_channels) 87 | else: 88 | self.lin_value = Linear(in_channels[1], attn_heads * out_channels) 89 | 90 | if self.residual_fusion_mode == 'concat': 91 | self.lin_ffn_0 = Linear(in_channels[1] + self.attn_out_dim, out_channels + 128) 92 | self.lin_ffn_1 = Linear(out_channels + 128, out_channels) 93 | elif residual_fusion_mode == 'add': 94 | if head_fusion_mode == 'concat': 95 | self.lin_ffn_1 = Linear(attn_heads * out_channels, out_channels, bias=bias) 96 | self.lin_skip = Linear(in_channels[0], attn_heads * out_channels, bias=bias) 97 | if learn_beta: 98 | self.lin_beta = Linear(3 * attn_heads * out_channels, 1, bias=False) 99 | else: 100 | self.lin_skip = Linear(in_channels[0], out_channels, bias=bias) 101 | if learn_beta: 102 | self.lin_beta = Linear(3 * out_channels, 1, bias=False) 103 | else: 104 | self.lin_ffn_0 = Linear(self.attn_out_dim, out_channels + 128) 105 | self.lin_ffn_1 = Linear(out_channels + 128, out_channels) 106 | if 
self.head_fusion_mode == 'add': 107 | self.layer_norm = LayerNorm(out_channels) 108 | else: 109 | self.layer_norm = LayerNorm(out_channels * attn_heads) 110 | 111 | self.reset_parameters() 112 | 113 | def reset_parameters(self): 114 | if self.trans_flag: 115 | self.lin_trans_x.reset_parameters() 116 | self.lin_key.reset_parameters() 117 | self.lin_query.reset_parameters() 118 | self.lin_value.reset_parameters() 119 | if not self.have_query_feature: 120 | init.xavier_uniform_(self.att_r) 121 | if self.residual_fusion_mode == 'add': 122 | self.lin_skip.reset_parameters() 123 | if self.head_fusion_mode == 'concat': 124 | self.lin_ffn_1.reset_parameters() 125 | if self.learn_beta: 126 | self.lin_beta.reset_parameters() 127 | else: 128 | self.lin_ffn_0.reset_parameters() 129 | self.lin_ffn_1.reset_parameters() 130 | if not self.residual_fusion_mode: 131 | self.layer_norm.reset_parameters() 132 | 133 | # the edge_type are stored as edge_index value 134 | def forward( 135 | self, 136 | x: Union[Tensor, OptPairTensor], 137 | edge_index: Adj, 138 | edge_time_embed: Tensor, 139 | edge_dist_embed: Tensor, 140 | edge_type_embed: Tensor, 141 | edge_attr_embed: Tensor, 142 | ): 143 | 144 | if isinstance(x, Tensor): 145 | x: OptPairTensor = (x, x) 146 | 147 | if isinstance(edge_index, SparseTensor): 148 | out = self.propagate( 149 | edge_index, 150 | x=(x[0][edge_index.storage.col()], x[1][edge_index.storage.row()]), 151 | edge_attr_embed=edge_attr_embed, 152 | edge_time_embed=edge_time_embed, 153 | edge_dist_embed=edge_dist_embed, 154 | edge_type_embed=edge_type_embed, 155 | have_query_feature=self.have_query_feature, 156 | size=None 157 | ) 158 | else: 159 | out = self.propagate( 160 | edge_index, 161 | x=(x[0][edge_index[0]], x[1][edge_index[1]]), 162 | edge_attr_embed=edge_attr_embed, 163 | edge_time_embed=edge_time_embed, 164 | edge_dist_embed=edge_dist_embed, 165 | edge_type_embed=edge_type_embed, 166 | have_query_feature=self.have_query_feature, 167 | size=None 168 | ) 169 | 170 | if not self.have_query_feature: 171 | out += self.att_r 172 | 173 | if self.head_fusion_mode == 'concat': 174 | out = out.view(-1, self.attn_heads * self.out_channels) 175 | else: 176 | out = out.mean(dim=1) 177 | 178 | # todo dont use two FC before out 179 | if self.residual_fusion_mode == 'concat': 180 | out = cat([out, x[1]], dim=-1) 181 | out = self.lin_ffn_0(out) 182 | out = F.relu(out) 183 | out = self.lin_ffn_1(out) 184 | elif self.residual_fusion_mode == 'add': 185 | x_skip = self.lin_skip(x[1]) 186 | 187 | if self.learn_beta: 188 | beta = self.lin_beta(cat([out, x_skip, out - x_skip], -1)) 189 | beta = beta.sigmoid() 190 | out = beta * x_skip + (1 - beta) * out 191 | else: 192 | if self.residual_beta is not None: 193 | out = self.residual_beta * x_skip + (1 - self.residual_beta) * out 194 | else: 195 | out += x_skip 196 | if self.head_fusion_mode == 'concat': 197 | out = self.lin_ffn_1(out) 198 | else: 199 | out = self.layer_norm(out) 200 | out = self.lin_ffn_0(out) 201 | out = F.relu(out) 202 | out = self.lin_ffn_1(out) 203 | if self.normalize: 204 | out = F.normalize(out, p=2., dim=-1) 205 | return out 206 | 207 | def message( 208 | self, 209 | x: OptPairTensor, 210 | edge_attr_embed: Tensor, 211 | edge_time_embed: Tensor, 212 | edge_dist_embed: Tensor, 213 | edge_type_embed: Tensor, 214 | index: Tensor, 215 | ptr: OptTensor, 216 | have_query_feature: bool, 217 | size_i: Optional[int] 218 | ) -> Tensor: 219 | x_j, x_i = x 220 | 221 | if self.trans_flag: 222 | if have_query_feature: 223 | x_i = 
self.lin_trans_x(x_i) 224 | x_j_raw = self.lin_trans_x(x_j) 225 | x_j = self.lin_trans_x(x_j) 226 | else: 227 | x_j_raw = x_j 228 | if edge_type_embed is not None: 229 | x_j = self.rel_transform(x_j, edge_type_embed) 230 | 231 | if self.time_fusion_mode == 'concat': 232 | x_j = cat([x_j, edge_time_embed, edge_dist_embed], dim=-1) 233 | elif self.time_fusion_mode == 'add': 234 | x_j += edge_time_embed + edge_dist_embed 235 | 236 | if edge_attr_embed is not None: 237 | if self.edge_fusion_mode == 'concat': 238 | x_j = cat([x_j, edge_attr_embed], dim=-1) 239 | else: 240 | x_j += edge_attr_embed 241 | 242 | key = self.lin_key(x_j).view(-1, self.attn_heads, self.out_channels) 243 | if not have_query_feature: 244 | query = self.att_r 245 | alpha = (key * query).sum(dim=-1) 246 | alpha = F.leaky_relu(alpha, self.negative_slope) 247 | else: 248 | query = self.lin_query(x_i).view(-1, self.attn_heads, self.out_channels) 249 | alpha = (query * key).sum(dim=-1) / math.sqrt(self.out_channels) 250 | 251 | alpha = softmax(alpha, index, ptr, size_i) 252 | alpha = F.dropout(alpha, p=self.dropout, training=self.training) 253 | 254 | if self.message_mode == 'node_edge': 255 | out = self.lin_value(x_j).view(-1, self.attn_heads, self.out_channels) 256 | else: 257 | out = self.lin_value(x_j_raw).view(-1, self.attn_heads, self.out_channels) 258 | 259 | out *= alpha.view(-1, self.attn_heads, 1) 260 | 261 | return out 262 | 263 | def rel_transform(self, ent_embed, edge_type_embed): 264 | if self.trans_method == "corr": 265 | trans_embed = ccorr(ent_embed, edge_type_embed) 266 | elif self.trans_method == "sub": 267 | trans_embed = ent_embed - edge_type_embed 268 | elif self.trans_method == "multi": 269 | trans_embed = ent_embed * edge_type_embed 270 | elif self.trans_method == 'add': 271 | trans_embed = ent_embed + edge_type_embed 272 | elif self.trans_method == 'concat': 273 | trans_embed = cat([ent_embed, edge_type_embed], dim=1) 274 | else: 275 | raise NotImplementedError 276 | return trans_embed 277 | 278 | def _check_attn_dim(self, in_channels, out_channels): 279 | attn_in_dim = in_channels 280 | attn_out_dim = out_channels * self.attn_heads if self.head_fusion_mode == 'concat' else out_channels 281 | if self.trans_method == 'concat': 282 | attn_in_dim += self.rel_embed_dim 283 | else: 284 | assert attn_in_dim == self.rel_embed_dim, \ 285 | "[HypergraphTransformer >> Translation Error] Node embedding dimension {} is not equal with relation " \ 286 | "embedding dimension {} when you are using '{}' translation method" \ 287 | ".".format(attn_in_dim, self.rel_embed_dim, self.trans_method) 288 | 289 | if self.time_fusion_mode: 290 | if self.time_fusion_mode == 'concat': 291 | attn_in_dim += self.time_embed_dim + self.dist_embed_dim 292 | else: 293 | assert attn_in_dim == self.time_embed_dim, \ 294 | "[HypergraphTransformer >> Time Fusion Error] Time embedding dimension {} is " \ 295 | "not equal with edge fusion result embedding dimension {} when you are using '{}' " \ 296 | "time fusion mode.".format(self.time_embed_dim, attn_in_dim, self.time_fusion_mode) 297 | assert attn_in_dim == self.dist_embed_dim, \ 298 | "[HypergraphTransformer >> Time Fusion Error] Time embedding dimension {} is " \ 299 | "not equal with edge fusion result embedding dimension {} when you are using '{}' " \ 300 | "time fusion mode.".format(self.dist_embed_dim, attn_in_dim, self.time_fusion_mode) 301 | 302 | if self.edge_fusion_mode == 'concat' and self.edge_dim is not None: 303 | attn_in_dim += self.edge_dim 304 | else: 305 | if 
self.edge_dim is not None: 306 | assert attn_in_dim == self.edge_dim, \ 307 | "[HypergraphTransformer >> Edge Fusion Error] Edge embedding dimension {} is " \ 308 | "not equal with translation result embedding dimension {} when you are using '{}' " \ 309 | "edge fusion mode.".format(self.edge_dim, attn_in_dim, self.edge_fusion_mode) 310 | return attn_in_dim, attn_out_dim 311 | 312 | def __repr__(self): 313 | return '{}(in_channels={}, out_channels={}, attn_heads={})'.format( 314 | self.__class__.__name__, self.in_channels, self.out_channels, self.attn_heads) 315 | -------------------------------------------------------------------------------- /layer/embedding_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class CheckinEmbedding(nn.Module): 6 | def __init__( 7 | self, 8 | embed_size, 9 | fusion_type, 10 | dataset_args 11 | ): 12 | super(CheckinEmbedding, self).__init__() 13 | self.embed_size = embed_size 14 | self.fusion_type = fusion_type 15 | self.user_embedding = nn.Embedding( 16 | dataset_args.num_user + 1, 17 | self.embed_size, 18 | padding_idx=dataset_args.padding_user_id 19 | ) 20 | self.poi_embedding = nn.Embedding( 21 | dataset_args.num_poi + 1, 22 | self.embed_size, 23 | padding_idx=dataset_args.padding_poi_id 24 | ) 25 | self.category_embedding = nn.Embedding( 26 | dataset_args.num_category + 1, 27 | self.embed_size, 28 | padding_idx=dataset_args.padding_poi_category 29 | ) 30 | self.dayofweek_embedding = nn.Embedding(8, self.embed_size, padding_idx=dataset_args.padding_weekday_id) 31 | self.hourofday_embedding = nn.Embedding(25, self.embed_size, padding_idx=dataset_args.padding_hour_id) 32 | if self.fusion_type == 'concat': 33 | self.output_embed_size = 5 * self.embed_size 34 | elif self.fusion_type == 'add': 35 | self.output_embed_size = embed_size 36 | else: 37 | raise ValueError(f"Get wrong fusion type {self.fusion_type}") 38 | 39 | def forward(self, data): 40 | embedding_list = [ 41 | self.user_embedding(data[..., 0].long()), 42 | self.poi_embedding(data[..., 1].long()), 43 | self.category_embedding(data[..., 2].long()), 44 | self.dayofweek_embedding(data[..., 6].long()), 45 | self.hourofday_embedding(data[..., 7].long()) 46 | ] 47 | if self.fusion_type == 'concat': 48 | self.output_embed_size = len(embedding_list) * self.embed_size 49 | return torch.cat(embedding_list, -1) 50 | elif self.fusion_type == 'add': 51 | return torch.squeeze(sum(embedding_list)) 52 | else: 53 | raise ValueError(f"Get wrong fusion type {self.fusion_type}") 54 | 55 | 56 | class EdgeEmbedding(torch.nn.Module): 57 | def __init__(self, embed_size, fusion_type, num_edge_type): 58 | super(EdgeEmbedding, self).__init__() 59 | self.embed_size = embed_size 60 | self.fusion_type = fusion_type 61 | self.edge_type_embedding = nn.Embedding(num_edge_type, self.embed_size) 62 | self.output_embed_size = self.embed_size 63 | 64 | def forward(self, data): 65 | embedding_list = [self.edge_type_embedding(data.long())] 66 | 67 | if self.fusion_type == 'concat': 68 | self.output_embed_size = len(embedding_list) * self.embed_size 69 | return torch.cat(embedding_list, -1) 70 | elif self.fusion_type == 'add': 71 | return sum(embedding_list) 72 | else: 73 | raise ValueError(f"Get wrong fusion type {self.fusion_type}") 74 | -------------------------------------------------------------------------------- /layer/sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 
from typing import List, Optional, NamedTuple
3 | from scipy.sparse import coo_matrix
4 | import torch
5 | from torch import Tensor
6 | from torch_sparse import SparseTensor
7 | from torch_scatter import scatter_mean, scatter_max, scatter_add, scatter_min
8 | from utils import haversine
9 | import logging
10 | 
11 | 
12 | class NeighborSampler(torch.utils.data.DataLoader):
13 |     """
14 |     Hypergraph sampler with a two-level hypergraph for the next-POI task.
15 | 
16 |     Args:
17 |         x: feature matrix.
18 |         edge_index: a two-element list of tensors; the first is the edge_index of traj2traj(index 0), the second is the
19 |             edge_index of ci2traj(index 1).
20 |         edge_attr: traj2traj jaccard similarity, source hyperedge size and target hyperedge size.
21 |         edge_t: actual time of each check-in event; traj2traj(index 0) doesn't contain this edge_t.
22 |         edge_delta_t: relative time of each edge (within a trajectory for ci2traj(index 1), between trajectories for traj2traj(index 0)).
23 |         edge_type: intra-user(0) or inter-user(1) indicator; ci2traj(index 1) doesn't contain this value.
24 |         sizes: the last element is for ci2traj, the other elements are for multi-hop traj2traj. e.g. for sizes=[10, 20, 30],
25 |             [10, 20] is for traj2traj 2-hop sampling, [30] is for ci2traj.
26 |         sample_idx: sample id, torch.long.
27 |         node_idx: query trajectory id, torch.long.
28 |         label: task label for loss computation, tensor with 5 columns (poi_id, poi_cat_id, poi_lon, poi_lat, time_hour).
29 |         edge_delta_s: relative distance of each edge (within a trajectory for ci2traj(index 1), between trajectories for traj2traj(index 0)).
30 |         max_time: target time of every sample, i.e. the last check-in time before the candidate check-in.
31 |         num_nodes: max trajectory index, i.e. the number of nodes in the graph.
32 |             (default: :obj:`None`)
33 |         intra_jaccard_threshold: filter out intra-user traj2traj edges whose jaccard similarity is less than this value.
34 |         inter_jaccard_threshold: filter out inter-user traj2traj edges whose jaccard similarity is less than this value.
35 |         transform: A function/transform that takes in a sampled mini-batch and returns a transformed version.
36 |         **kwargs: Additional arguments of
37 |             :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`,
38 |             :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`.
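
    Example (illustrative sketch only; the tensor inputs and hyper-parameters
    shown are assumptions, not values shipped with this repository)::

        loader = NeighborSampler(
            x, edge_index, edge_attr, edge_t, edge_delta_t, edge_type,
            sizes=[300, 500],
            sample_idx=sample_idx_train, node_idx=node_idx_train,
            label=label_train, edge_delta_s=edge_delta_s,
            max_time=max_time_train,
            batch_size=64, shuffle=True, num_workers=4,
        )
        for batch in loader:  # each item is a ``Batch`` NamedTuple (defined below)
            batch = batch.to('cuda:0')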
39 | """ 40 | 41 | def __init__( 42 | self, 43 | x: List[Tensor], 44 | edge_index: List[Tensor], 45 | edge_attr: List[Tensor], 46 | edge_t: List[Tensor], 47 | edge_delta_t: List[Tensor], 48 | edge_type: List[Tensor], 49 | sizes: List[int], 50 | sample_idx: Tensor, 51 | node_idx: Tensor, 52 | label: Tensor, 53 | edge_delta_s: List[Tensor] = None, 54 | max_time: Optional[Tensor] = None, 55 | num_nodes: Optional[int] = None, 56 | intra_jaccard_threshold: float = 0.0, 57 | inter_jaccard_threshold: float = 0.0, 58 | **kwargs 59 | ): 60 | # raw feature 61 | traj_x = x[0] 62 | ci_x = x[1] 63 | 64 | # traj2traj related 65 | traj2traj_edge_attr = edge_attr[0] 66 | traj2traj_edge_index = edge_index[0] 67 | traj2traj_edge_type = edge_type[0] 68 | traj2traj_edge_delta_t = edge_delta_t[0] 69 | traj2traj_edge_delta_s = edge_delta_s[0] 70 | 71 | # ci2traj related 72 | ci2traj_edge_index = edge_index[1] 73 | ci2traj_edge_t = edge_t[1] 74 | ci2traj_edge_delta_t = edge_delta_t[1] 75 | ci2traj_edge_delta_s = edge_delta_s[1] 76 | 77 | # to cpu 78 | traj_x = traj_x.to('cpu') 79 | ci_x = ci_x.to('cpu') 80 | traj2traj_edge_attr = traj2traj_edge_attr.to('cpu') 81 | traj2traj_edge_index = traj2traj_edge_index.to('cpu') 82 | traj2traj_edge_type = traj2traj_edge_type.to('cpu') 83 | traj2traj_edge_delta_t = traj2traj_edge_delta_t.to('cpu') 84 | traj2traj_edge_delta_s = traj2traj_edge_delta_s.to('cpu') 85 | ci2traj_edge_index = ci2traj_edge_index.to('cpu') 86 | ci2traj_edge_t = ci2traj_edge_t.to('cpu') 87 | ci2traj_edge_delta_t = ci2traj_edge_delta_t.to('cpu') 88 | ci2traj_edge_delta_s = ci2traj_edge_delta_s.to('cpu') 89 | 90 | if "collate_fn" in kwargs: 91 | del kwargs["collate_fn"] 92 | 93 | self.traj_x = traj_x 94 | self.ci_x = ci_x 95 | self.max_traj_size = self.traj_x[:, 0].max().item() 96 | self.x = torch.cat([ci_x, traj_x], dim=0) 97 | self.ci_offset = ci_x.shape[0] 98 | self.traj2traj_edge_attr = traj2traj_edge_attr 99 | self.traj2traj_edge_type = traj2traj_edge_type 100 | self.traj2traj_edge_delta_t = traj2traj_edge_delta_t 101 | self.traj2traj_edge_delta_s = traj2traj_edge_delta_s 102 | self.ci2traj_edge_t = ci2traj_edge_t 103 | self.ci2traj_edge_delta_t = ci2traj_edge_delta_t 104 | self.ci2traj_edge_delta_s = ci2traj_edge_delta_s 105 | 106 | self.y = label # load from data 107 | self.node_idx = node_idx # target trajectory index, used as query 108 | self.max_time = max_time 109 | self.sizes = sizes 110 | self.he2he_jaccard = None 111 | self.intra_jaccard_threshold = intra_jaccard_threshold 112 | self.inter_jaccard_threshold = inter_jaccard_threshold 113 | 114 | # Obtain a *transposed* SparseTensor instance. 
115 | if int(node_idx.max()) > traj2traj_edge_index.max(): 116 | raise ValueError('Query node index is not in graph.') 117 | if num_nodes is None: 118 | num_nodes = max(int(traj2traj_edge_index.max()), int(ci2traj_edge_index.max())) + 1 119 | 120 | self.traj2traj_adj_t = SparseTensor( 121 | row=traj2traj_edge_index[0], 122 | col=traj2traj_edge_index[1], 123 | value=torch.arange(traj2traj_edge_index.size(1)), 124 | sparse_sizes=(num_nodes, num_nodes) 125 | ).t() 126 | self.ci2traj_adj_t = SparseTensor( 127 | row=ci2traj_edge_index[0], 128 | col=ci2traj_edge_index[1], 129 | value=torch.arange(ci2traj_edge_index.size(1)), 130 | sparse_sizes=(num_nodes, num_nodes) 131 | ).t() 132 | 133 | self.traj2traj_adj_t.storage.rowptr() 134 | self.ci2traj_adj_t.storage.rowptr() 135 | 136 | super(NeighborSampler, self).__init__(sample_idx.view(-1).tolist(), collate_fn=self.sample, **kwargs) 137 | 138 | def sample(self, batch): 139 | """ 140 | adjs: traj2traj multi-hop + ci2traj one-hop adj_t data 141 | 142 | :param batch: 143 | :return: 144 | """ 145 | if not isinstance(batch, Tensor): 146 | batch = torch.tensor(batch) 147 | 148 | batch_size: int = len(batch) 149 | 150 | adjs = [] 151 | sample_idx = batch 152 | n_id = self.node_idx[sample_idx] 153 | max_time = self.max_time[sample_idx] 154 | 155 | # Sample traj2traj multi-hop dynamic relation 156 | for i, size in enumerate(self.sizes): 157 | # n_id is the original node_id, the row and col idx in adj_t is mapped to 0~(len(n_id)-1) 158 | if i == len(self.sizes) - 1: 159 | # Sample ci2traj one-hop checkin relation 160 | adj_t, n_id = self.ci2traj_adj_t.sample_adj(n_id, size, replace=False) 161 | row, col, e_id = adj_t.coo() 162 | edge_attr = None 163 | edge_t = self.ci2traj_edge_t[e_id] 164 | edge_type = None 165 | edge_delta_t = self.ci2traj_edge_delta_t[e_id] 166 | edge_delta_s = self.ci2traj_edge_delta_s[e_id] 167 | else: 168 | # Sample traj2traj multi-hop relation 169 | adj_t, n_id = self.traj2traj_adj_t.sample_adj(n_id, size, replace=False) 170 | row, col, e_id = adj_t.coo() 171 | edge_attr = self.traj2traj_edge_attr[e_id] 172 | edge_t = None 173 | edge_type = self.traj2traj_edge_type[e_id] 174 | edge_delta_t = self.traj2traj_edge_delta_t[e_id] 175 | edge_delta_s = self.traj2traj_edge_delta_s[e_id] 176 | 177 | size = adj_t.sparse_sizes()[::-1] 178 | 179 | if adj_t.nnz(): 180 | assert size[0] >= col.max() + 1, '[NeighborSampler] adj_t source max index exceed sparse_sizes[1]' 181 | else: 182 | # empty subgraph 183 | adj_t, edge_attr, edge_t, edge_type, edge_delta_t, edge_delta_s = None, None, None, None, None, None 184 | adjs.append((adj_t, edge_attr, edge_t, edge_type, edge_delta_t, edge_delta_s, e_id, size)) 185 | 186 | # Mask ci2traj for target traj: filter only ci2traj edge beyond target time 187 | target_mask = row < batch_size 188 | edge_max_time = max_time[row[target_mask]] 189 | length = torch.sum(target_mask) 190 | time_mask = edge_t[target_mask] <= edge_max_time 191 | target_mask[:length] = time_mask 192 | 193 | if row[target_mask].size(0) == 0: 194 | raise ValueError( 195 | f'[NeighborSampler] All trajectories have no checkin before target time!!' 
196 |             )
197 |         adj_t = SparseTensor(
198 |             row=row[target_mask],
199 |             col=col[target_mask],
200 |             sparse_sizes=(batch_size, adj_t.sparse_sizes()[1])
201 |         )
202 |         edge_t = edge_t[target_mask]
203 |         edge_type = None
204 |         edge_delta_t = edge_delta_t[target_mask]
205 |         edge_delta_s = edge_delta_s[target_mask]
206 |         e_id = e_id[target_mask]
207 |         adjs.append((adj_t, edge_attr, edge_t, edge_type, edge_delta_t, edge_delta_s, e_id, size))
208 | 
209 |         # Filter traj2traj with leakage
210 |         target_mask[length:] = True
211 |         he_poi = self.ci_x[col[target_mask]][:, 1]
212 |         im = coo_matrix((
213 |             np.ones(row[target_mask].shape[0]),
214 |             (he_poi.numpy().astype(np.int64), row[target_mask].numpy())  # np.long was removed in NumPy >= 1.24
215 |         )).tocsr()
216 |         self.he2he_jaccard = im.T * im
217 |         self.he2he_jaccard = self.he2he_jaccard.tocoo()
218 | 
219 |         # Calculate jaccard similarity of traj2traj
220 |         filtered_traj_size = self.he2he_jaccard.diagonal()
221 |         source_size = filtered_traj_size[self.he2he_jaccard.col]
222 |         target_size = filtered_traj_size[self.he2he_jaccard.row]
223 |         self.he2he_jaccard.data = self.he2he_jaccard.data / (source_size + target_size - self.he2he_jaccard.data)
224 | 
225 |         # Only consider the traj2traj data
226 |         for i, adj in enumerate(adjs[:-2]):
227 |             if not i:
228 |                 adjs[i] = self.filter_traj2traj_with_leakage(adj, traj_size=filtered_traj_size, mode=1)
229 |             else:
230 |                 adjs[i] = self.filter_traj2traj_with_leakage(adj, traj_size=None, mode=2)
231 | 
232 |         # Trajectories without check-in neighbors are not allowed!
233 |         if adj_t.storage.row().unique().shape[0] != batch_size:
234 |             diff_node = list(set(range(batch_size)) - set(adj_t.storage.row().unique().tolist()))
235 |             raise ValueError(
236 |                 f'[NeighborSampler] Trajectories without checkin neighbors after filtering by max_time are not allowed!\n'
237 |                 f'Those samples are sample_idx:{sample_idx[diff_node]},\n'
238 |                 f'and the corresponding query trajectories are: {n_id[diff_node]},\n'
239 |                 f'the original query trajectories are: {n_id[diff_node] - self.ci_offset}.'
240 |             )
241 | 
242 |         adjs = adjs[0] if len(adjs) == 1 else adjs[::-1]
243 |         out = (sample_idx, n_id, adjs)
244 |         out = self.convert_batch(*out)
245 |         return out
246 | 
247 |     def filter_traj2traj_with_leakage(self, adj, traj_size, mode=1):
248 |         """The original traj2traj topology is stored in adj_t. We set its values to all ones (plus an
249 |         epsilon), subtract the traj2traj Jaccard similarities computed in ``sample``, and keep only entries within [0, 1].
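        An edge therefore survives the thresholding below exactly when its
        Jaccard similarity is at least the corresponding intra-/inter-user
        threshold, so weakly-overlapping hyperedge pairs are pruned.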
250 | 251 | :param adj: tuple data with traj2traj information 252 | :param traj_size: calculated out of this function, only take into effect when mode=1 253 | :param mode: 1, use self.he2he_jaccard to filter, 2, use edge_attr[:, 2] to filter 254 | :return: 255 | """ 256 | adj_t, edge_attr, edge_t, edge_type, edge_delta_t, edge_delta_s, e_id, size = adj 257 | 258 | if adj_t is None: 259 | return adj 260 | 261 | row, col, value = adj_t.coo() 262 | if mode == 1: 263 | # Add epsilon in case we delete the full-overlap traj2traj 264 | epsilon = 1e-6 265 | he2he = coo_matrix(( 266 | np.ones(adj_t.nnz()) + epsilon, 267 | (row.numpy(), col.numpy()) 268 | )) 269 | size_i = he2he.shape[0] 270 | size_j = he2he.shape[1] 271 | he2he = he2he - self.he2he_jaccard.tocsc()[:size_i, :size_j].tocoo() 272 | he2he = he2he.tocoo() 273 | 274 | # Valid within [0, 1] 275 | valid_mask = he2he.data >= 0 276 | he2he = SparseTensor( 277 | row=torch.tensor(he2he.row[valid_mask], dtype=torch.long), 278 | col=torch.tensor(he2he.col[valid_mask], dtype=torch.long), 279 | value=torch.tensor(he2he.data[valid_mask]) 280 | ) 281 | 282 | if adj_t.nnz() != he2he.nnz(): 283 | raise ValueError(f"[NeighborSampler] he2he filtered size not equal.") 284 | 285 | # Keep intra-user and overlaped inter-user traj2traj 286 | inter_threshold_mask = he2he.storage.value() <= (1 - self.inter_jaccard_threshold + epsilon) 287 | intra_threshold_mask = he2he.storage.value() <= (1 - self.intra_jaccard_threshold + epsilon) 288 | inter_user_mask = (edge_type == 1) & inter_threshold_mask 289 | intra_user_mask = (edge_type == 0) & intra_threshold_mask 290 | mask = intra_user_mask | inter_user_mask 291 | keep_num = torch.sum(mask).item() 292 | if keep_num == 0: 293 | adj = (None, None, None, None, None, None, e_id, size) 294 | return adj 295 | else: 296 | # logging.info( 297 | # f"[NeighborSampler] Remaining {keep_num} traj2traj[0] edges, including " 298 | # f"{intra_user_mask.sum().item()} intra-user traj2traj edges and {inter_user_mask.sum().item()} " 299 | # f"inter-user traj2traj edges." 300 | # ) 301 | # save jaccard metric to value 302 | adj_t = SparseTensor( 303 | row=row[mask], 304 | col=col[mask], 305 | value=he2he.storage.value()[mask], 306 | sparse_sizes=adj_t.sparse_sizes() 307 | ) 308 | edge_t = edge_t[mask] if edge_t is not None else None 309 | 310 | # recover similarity metric, and calculate edge_attr 311 | row, col, value = adj_t.coo() 312 | edge_attr = (1 + epsilon) - value 313 | source_traj_size = torch.tensor(traj_size[row]) / self.max_traj_size 314 | target_traj_size = torch.tensor(traj_size[col]) / self.max_traj_size 315 | edge_attr = torch.stack([source_traj_size, target_traj_size, edge_attr], dim=1) 316 | else: 317 | inter_threshold_mask = edge_attr[:, 2] >= self.inter_jaccard_threshold 318 | intra_threshold_mask = edge_attr[:, 2] >= self.intra_jaccard_threshold 319 | inter_user_mask = (edge_type == 1) & inter_threshold_mask 320 | intra_user_mask = (edge_type == 0) & intra_threshold_mask 321 | mask = intra_user_mask | inter_user_mask 322 | keep_num = torch.sum(mask).item() 323 | if keep_num == 0: 324 | adj = (None, None, None, None, None, None, e_id, size) 325 | return adj 326 | else: 327 | # logging.info( 328 | # f"[NeighborSampler] Remaining {keep_num} traj2traj[>0] edges, including " 329 | # f"{intra_user_mask.sum().item()} intra-user traj2traj edges and {inter_user_mask.sum().item()} " 330 | # f"inter-user traj2traj edges." 
331 | # ) 332 | edge_attr = edge_attr[mask] 333 | adj_t = SparseTensor( 334 | row=row[mask], 335 | col=col[mask], 336 | value=value[mask], 337 | sparse_sizes=adj_t.sparse_sizes() 338 | ) 339 | 340 | edge_type = edge_type[mask] 341 | edge_delta_t = edge_delta_t[mask] 342 | edge_delta_s = edge_delta_s[mask] 343 | e_id = e_id[mask] 344 | adj = (adj_t, edge_attr, edge_t, edge_type, edge_delta_t, edge_delta_s, e_id, size) 345 | return adj 346 | 347 | def convert_batch(self, sample_idx, n_id, adjs): 348 | """ 349 | Add target label for batch data, and update target trajectory mean_time, mean_lon, mean_lat, last_lon, last_lat. 350 | 351 | :param sample_idx: sample_idx from label table; 352 | :param n_id: original index of nodes in hypergraph; 353 | :param adjs: Adj data of multi-hop neighbors; 354 | :return: Batch data 355 | """ 356 | adjs_t, edge_attrs, edge_ts, edge_types, edge_delta_ts, edge_delta_ss = [], [], [], [], [], [] 357 | y = self.y[sample_idx] 358 | 359 | x_target = None 360 | 361 | # checkin_feature 'user_id', 'poi_id', 'poi_cat', 'time', 'poi_lon', 'poi_lat' 362 | # trajectory_feature 'size', 'mean_lon', 'mean_lat', 'mean_time', 'start_time', 'end_time' 363 | i = 0 364 | for adj_t, edge_attr, edge_t, edge_type, edge_delta_t, edge_delta_s, _, _ in adjs: 365 | 366 | if adj_t is None: 367 | pass 368 | 369 | else: 370 | col, row, _ = adj_t.coo() 371 | if not i: 372 | # Update filtered_ci2traj edge information and generate x feature for target trajectory (x_target) 373 | source_checkin_lon_lat = self.x[n_id[row]][:, 4:6] # [#edge, 2] 374 | traj_min_time, _ = scatter_min(edge_t, col, dim=-1) # [N, ] 375 | traj_max_time, e_id = scatter_max(edge_t, col, dim=-1) 376 | traj_mean_time = scatter_mean(edge_t, col, dim=-1) # [N, ] 377 | traj_last_lon_lat = source_checkin_lon_lat[e_id] # [N, 2] 378 | traj_mean_lon_lat = scatter_mean(source_checkin_lon_lat, col, dim=0) # [N, 2] 379 | traj_size = scatter_add(torch.ones_like(edge_t), col, dim=-1) # [N, ] 380 | 381 | edge_delta_t = traj_max_time[col] - edge_t 382 | edge_delta_s = torch.cat([traj_last_lon_lat[col], source_checkin_lon_lat], dim=-1) 383 | edge_delta_s = haversine( 384 | edge_delta_s[:, 0], 385 | edge_delta_s[:, 1], 386 | edge_delta_s[:, 2], 387 | edge_delta_s[:, 3] 388 | ) 389 | x_target = torch.cat([ 390 | traj_size.unsqueeze(1), 391 | traj_mean_lon_lat, 392 | traj_mean_time.unsqueeze(1), 393 | traj_min_time.unsqueeze(1), 394 | traj_max_time.unsqueeze(1)], 395 | dim=-1 396 | ) 397 | elif i == len(adjs) - 1: 398 | # Update traj2traj edge information for one-hop neighbor -> target 399 | edge_delta_t = x_target[col][:, 3] - self.x[n_id[row]][:, 3] 400 | edge_delta_s = torch.cat([self.x[n_id[row]][:, 1:3], x_target[col][:, 1:3]], dim=-1) 401 | edge_delta_s = haversine( 402 | edge_delta_s[:, 0], 403 | edge_delta_s[:, 1], 404 | edge_delta_s[:, 2], 405 | edge_delta_s[:, 3] 406 | ) 407 | else: 408 | pass 409 | 410 | adjs_t.append(adj_t) 411 | edge_ts.append(edge_t) 412 | edge_attrs.append(edge_attr) 413 | edge_types.append(edge_type) 414 | edge_delta_ts.append(edge_delta_t) 415 | edge_delta_ss.append(edge_delta_s) 416 | i += 1 417 | 418 | result = Batch( 419 | sample_idx=sample_idx, 420 | x=self.x[n_id], 421 | x_target=x_target, 422 | y=y, 423 | adjs_t=adjs_t, 424 | edge_attrs=edge_attrs, 425 | edge_ts=edge_ts, 426 | edge_types=edge_types, 427 | edge_delta_ts=edge_delta_ts, 428 | edge_delta_ss=edge_delta_ss 429 | ) 430 | return result 431 | 432 | def __repr__(self): 433 | return '{}(sizes={})'.format(self.__class__.__name__, self.sizes) 434 | 
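# --- Illustrative sketch (not part of the repository) -------------------------
# Minimal, self-contained demo of the incidence-matrix trick used in
# ``NeighborSampler.sample`` above: hyperedge sizes appear on the diagonal of
# im.T * im, and pairwise Jaccard similarity follows directly. The toy ``poi``
# and ``traj`` arrays are assumptions for the demo only.
#
#   import numpy as np
#   from scipy.sparse import coo_matrix
#
#   poi = np.array([0, 1, 1, 2, 2, 3])   # POI id of each check-in
#   traj = np.array([0, 0, 1, 1, 2, 2])  # hyperedge (trajectory) of each check-in
#   im = coo_matrix((np.ones_like(poi, dtype=float), (poi, traj))).tocsr()
#   inter = (im.T * im).tocoo()          # shared-POI counts per trajectory pair
#   size = inter.diagonal()              # hyperedge sizes
#   jaccard = inter.data / (size[inter.row] + size[inter.col] - inter.data)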
435 | 436 | class Batch(NamedTuple): 437 | sample_idx: Tensor 438 | x: Tensor 439 | x_target: Tensor 440 | y: Tensor 441 | adjs_t: List[SparseTensor] 442 | edge_attrs: List[Tensor] 443 | edge_ts: List[Tensor] 444 | edge_types: List[Tensor] 445 | edge_delta_ts: List[Tensor] 446 | edge_delta_ss: List[Tensor] 447 | 448 | def to(self, *args, **kwargs): 449 | return Batch( 450 | sample_idx=self.sample_idx.to(*args, **kwargs), 451 | x=self.x.to(*args, **kwargs), 452 | x_target=self.x_target.to(*args, **kwargs), 453 | y=self.y.to(*args, **kwargs), 454 | adjs_t=[adj_t.to(*args, **kwargs) if adj_t is not None else None for adj_t in self.adjs_t], 455 | edge_attrs=[ 456 | edge_attr.to(*args, **kwargs) 457 | if edge_attr is not None 458 | else None 459 | for edge_attr in self.edge_attrs 460 | ], 461 | edge_ts=[ 462 | edge_t.to(*args, **kwargs) 463 | if edge_t is not None 464 | else None 465 | for edge_t in self.edge_ts 466 | ], 467 | edge_types=[ 468 | edge_type.to(*args, **kwargs) 469 | if edge_type is not None 470 | else None 471 | for edge_type in self.edge_types 472 | ], 473 | edge_delta_ts=[ 474 | edge_delta_t.to(*args, **kwargs) 475 | if edge_delta_t is not None 476 | else None 477 | for edge_delta_t in self.edge_delta_ts 478 | ], 479 | edge_delta_ss=[ 480 | edge_delta_s.to(*args, **kwargs) 481 | if edge_delta_s is not None 482 | else None 483 | for edge_delta_s in self.edge_delta_ss 484 | ] 485 | ) 486 | -------------------------------------------------------------------------------- /layer/st_encoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | import numpy as np 5 | from utils import cal_slot_distance_batch 6 | 7 | 8 | class PositionEncoder(nn.Module): 9 | def __init__(self, d_model, device, dropout=0.1, max_len=500): 10 | super(PositionEncoder, self).__init__() 11 | self.dropout = torch.nn.Dropout(p=dropout) 12 | 13 | pe = torch.zeros(max_len, d_model, device=device) 14 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 15 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (- math.log(10000.0) / d_model)) 16 | pe[:, 0::2] = torch.sin(position * div_term) 17 | pe[:, 1::2] = torch.cos(position * div_term) 18 | self.pe = pe.unsqueeze(0).transpose(0, 1) 19 | 20 | def forward(self, x): 21 | x = x + self.pe[:x.size(1), :].transpose(0, 1) 22 | return self.dropout(x) 23 | 24 | 25 | class TimeEncoder(nn.Module): 26 | r""" 27 | This is a trainable encoder to map continuous time value into a low-dimension time vector. 28 | Ref: https://github.com/StatsDLMathsRecomSys/Inductive-representation-learning-on-temporal-graphs/blob/master/module.py 29 | 30 | The input of ts should be like [E, 1] with all time interval as values. 
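    Example (illustrative sketch; ``args`` is assumed to carry the
    ``phase_factor`` and ``use_linear_trans`` fields read in ``__init__``)::

        encoder = TimeEncoder(args, embedding_dim=128)
        ts = torch.rand(1024)     # [E] relative time intervals
        harmonic = encoder(ts)    # [E, 128], cos(ts * basis_freq + phase)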
31 |     """
32 | 
33 |     def __init__(self, args, embedding_dim):
34 |         super(TimeEncoder, self).__init__()
35 |         self.time_dim = embedding_dim
36 |         self.expand_dim = self.time_dim
37 |         self.factor = args.phase_factor
38 |         self.use_linear_trans = args.use_linear_trans
39 | 
40 |         self.basis_freq = nn.Parameter((torch.from_numpy(1 / 10 ** np.linspace(0, 9, self.time_dim))).float())
41 |         self.phase = nn.Parameter(torch.zeros(self.time_dim).float())
42 |         if self.use_linear_trans:
43 |             self.dense = nn.Linear(self.time_dim, self.expand_dim, bias=False)
44 |             nn.init.xavier_normal_(self.dense.weight)
45 | 
46 |     def forward(self, ts):
47 |         if ts.dim() == 1:
48 |             dim = 1
49 |             edge_len = ts.size().numel()
50 |         else:
51 |             edge_len, dim = ts.size()
52 |         ts = ts.view(edge_len, dim)
53 |         map_ts = ts * self.basis_freq.view(1, -1)
54 |         map_ts += self.phase.view(1, -1)
55 |         harmonic = torch.cos(map_ts)
56 |         if self.use_linear_trans:
57 |             harmonic = harmonic.type(self.dense.weight.dtype)
58 |             harmonic = self.dense(harmonic)
59 |         return harmonic
60 | 
61 | 
62 | class DistanceEncoderHSTLSTM(nn.Module):
63 |     r"""
64 |     This is a trainable encoder to map a continuous distance value into a low-dimensional vector.
65 |     Ref: HST-LSTM
66 | 
67 |     First determine the positions of the different slot bins, then do linear interpolation within each slot,
68 |     with the slot embeddings as trainable parameters.
69 |     """
70 | 
71 |     def __init__(self, args, embedding_dim, spatial_slots):
72 |         super(DistanceEncoderHSTLSTM, self).__init__()
73 |         self.dist_dim = embedding_dim
74 |         self.spatial_slots = spatial_slots
75 |         self.embed_q = nn.Embedding(len(spatial_slots), self.dist_dim)
76 |         self.device = args.gpu
77 | 
78 |     def place_parameters(self, ld, hd, l, h):
79 |         if self.device == 'cpu':
80 |             ld = torch.from_numpy(np.array(ld)).type(torch.FloatTensor)
81 |             hd = torch.from_numpy(np.array(hd)).type(torch.FloatTensor)
82 |             l = torch.from_numpy(np.array(l)).type(torch.LongTensor)
83 |             h = torch.from_numpy(np.array(h)).type(torch.LongTensor)
84 |         else:
85 |             ld = torch.from_numpy(np.array(ld, dtype=np.float16)).type(torch.FloatTensor).to(self.device)
86 |             hd = torch.from_numpy(np.array(hd, dtype=np.float16)).type(torch.FloatTensor).to(self.device)
87 |             l = torch.from_numpy(np.array(l, dtype=np.float16)).type(torch.LongTensor).to(self.device)
88 |             h = torch.from_numpy(np.array(h, dtype=np.float16)).type(torch.LongTensor).to(self.device)
89 |         return ld, hd, l, h
90 | 
91 |     def cal_inter(self, ld, hd, l, h, embed):
92 |         """
93 |         Calculate a linear interpolation.
94 |         :param ld: Distances to lower bound, shape (batch_size, step)
95 |         :param hd: Distances to higher bound, shape (batch_size, step)
96 |         :param l: Lower bound indexes, shape (batch_size, step)
97 |         :param h: Higher bound indexes, shape (batch_size, step)
98 |         """
99 |         # Fetch the embed of higher and lower bound.
100 |         # Each result shape (batch_size, step, input_size)
101 |         l_embed = embed(l)
102 |         h_embed = embed(h)
103 |         return torch.stack([hd], -1) * l_embed + torch.stack([ld], -1) * h_embed
104 | 
105 |     def forward(self, dist):
106 |         self.spatial_slots = sorted(self.spatial_slots)
107 |         d_ld, d_hd, d_l, d_h = self.place_parameters(*cal_slot_distance_batch(dist, self.spatial_slots))
108 |         batch_q = self.cal_inter(d_ld, d_hd, d_l, d_h, self.embed_q)
109 |         return batch_q
110 | 
111 | 
112 | class DistanceEncoderSTAN(nn.Module):
113 |     r"""
114 |     This is a trainable encoder to map a continuous distance value into a low-dimensional vector.
115 |     Ref: STAN
116 | 
117 |     Interpolates between the minimum and maximum distance values; only the minimum- and maximum-distance
118 |     embeddings need to be initialised.
119 |     """
120 | 
121 |     def __init__(self, args, embedding_dim, spatial_slots):
122 |         super(DistanceEncoderSTAN, self).__init__()
123 |         self.dist_dim = embedding_dim
124 |         self.min_d, self.max_d_ch2tj, self.max_d_tj2tj = spatial_slots
125 |         self.embed_min = nn.Embedding(1, self.dist_dim)
126 |         self.embed_max = nn.Embedding(1, self.dist_dim)
127 |         self.embed_max_traj = nn.Embedding(1, self.dist_dim)
128 |         self.quantile = args.quantile
129 | 
130 |     def forward(self, dist, dist_type):
131 |         if dist_type == 'ch2tj':
132 |             emb_low, emb_high = self.embed_min.weight, self.embed_max.weight
133 |             max_d = self.max_d_ch2tj
134 |         else:
135 |             emb_low, emb_high = self.embed_min.weight, self.embed_max_traj.weight
136 |             max_d = self.max_d_tj2tj
137 | 
138 |         # if you want to clip in case an outlier maximum exists, please uncomment the line below
139 |         # max_d = torch.quantile(dist, self.quantile)
140 |         dist = dist.clip(0, max_d)
141 |         vsl = (dist - self.min_d).unsqueeze(-1).expand(-1, self.dist_dim)
142 |         vsu = (max_d - dist).unsqueeze(-1).expand(-1, self.dist_dim)
143 | 
144 |         space_interval = (emb_low * vsu + emb_high * vsl) / (max_d - self.min_d)
145 |         return space_interval
146 | 
147 | 
148 | class DistanceEncoderSimple(nn.Module):
149 |     r"""
150 |     This is a trainable encoder to map a continuous distance value into a low-dimensional vector.
151 | 
152 |     Only a single unit embedding needs to be initialised; the distance is applied as a direct scalar*vector multiplication.
153 |     """
154 |     def __init__(self, args, embedding_dim, spatial_slots):
155 |         super(DistanceEncoderSimple, self).__init__()
156 |         self.args = args
157 |         self.dist_dim = embedding_dim
158 |         self.min_d, self.max_d, self.max_d_traj = spatial_slots
159 |         self.embed_unit = nn.Embedding(1, self.dist_dim)
160 | 
161 |     def forward(self, dist):
162 |         dist = dist.unsqueeze(-1).expand(-1, self.dist_dim)
163 |         return dist * self.embed_unit.weight
164 | -------------------------------------------------------------------------------- /metric/__init__.py: --------------------------------------------------------------------------------
1 | from metric.rank_metric import (
2 |     recall,
3 |     ndcg,
4 |     map_k,
5 |     mrr
6 | )
7 | 
8 | __all__ = [
9 |     "recall",
10 |     "ndcg",
11 |     "map_k",
12 |     "mrr"
13 | ]
14 | -------------------------------------------------------------------------------- /metric/rank_metric.py: --------------------------------------------------------------------------------
1 | import torch
2 | 
3 | 
4 | def recall(lab, prd, k):
5 |     return torch.sum(torch.sum(lab == prd[:, :k], dim=1)) / lab.shape[0]
6 | 
7 | 
8 | def ndcg(lab, prd, k):
9 |     exist_pos = torch.nonzero(prd[:, :k] == lab, as_tuple=False)[:, 1] + 1
10 |     dcg = 1 / torch.log2(exist_pos.float() + 1)
11 |     return torch.sum(dcg) / lab.shape[0]
12 | 
13 | 
14 | def map_k(lab, prd, k):
15 |     exist_pos = torch.nonzero(prd[:, :k] == lab, as_tuple=False)[:, 1] + 1
16 |     map_tmp = 1 / exist_pos
17 |     return torch.sum(map_tmp) / lab.shape[0]
18 | 
19 | 
20 | def mrr(lab, prd):
21 |     exist_pos = torch.nonzero(prd == lab, as_tuple=False)[:, 1] + 1
22 |     mrr_tmp = 1 / exist_pos
23 |     return torch.sum(mrr_tmp) / lab.shape[0]
24 | -------------------------------------------------------------------------------- /model/__init__.py: --------------------------------------------------------------------------------
1 | from model.sthgcn import STHGCN
2 | from model.seq_transformer import SequentialTransformer
3 | 
4 | __all__ = [
5 |     "STHGCN",
6 |     "SequentialTransformer"
7 | ]
8 | -------------------------------------------------------------------------------- /model/seq_transformer.py: --------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from layer import CheckinEmbedding, PositionEncoder
4 | 
5 | 
6 | class SequentialTransformer(nn.Module):
7 |     def __init__(self, cfg):
8 |         super(SequentialTransformer, self).__init__()
9 |         self.dataset_args = cfg.dataset_args
10 |         self.device = cfg.run_args.device
11 |         self.do_positional_encoding = cfg.seq_transformer_args.do_positional_encoding
12 |         self.embed_fusion_type = cfg.model_args.embed_fusion_type
13 |         self.checkin_embedding_layer = CheckinEmbedding(
14 |             embed_size=cfg.model_args.embed_size,
15 |             fusion_type=self.embed_fusion_type,
16 |             dataset_args=self.dataset_args
17 |         )
18 |         self.checkin_embed_size = self.checkin_embedding_layer.output_embed_size
19 | 
20 |         encoder_layers = nn.TransformerEncoderLayer(
21 |             d_model=self.checkin_embed_size,
22 |             nhead=cfg.seq_transformer_args.header_num,
23 |             dim_feedforward=cfg.seq_transformer_args.hidden_size,
24 |             dropout=cfg.seq_transformer_args.dropout
25 |         )
26 |         self.transformer_encoder = nn.TransformerEncoder(
27 |             encoder_layers,
28 |             num_layers=cfg.seq_transformer_args.encoder_layers_num
29 |         )
30 |         self.transformer_positional_encoding = PositionEncoder(
31 |             self.checkin_embed_size,
32 |             self.device,
33 |             cfg.seq_transformer_args.dropout
34 |         )
35 |         self.sequence_length = cfg.seq_transformer_args.sequence_length
36 | 
37 |         self.linear = torch.nn.Linear(self.checkin_embed_size, self.dataset_args.num_poi)
38 |         self.loss_func = torch.nn.CrossEntropyLoss()
39 | 
40 |     def forward(self, data, label=None, mode='train'):
41 |         # Generate seq input
42 |         split_idx = data['split_index']
43 |         check_in_x = data['x'][split_idx + 1:]
44 | 
45 |         # mask checkins based on batch: [N, #checkin, d]
46 |         edge_index_tmp = data['edge_index'][0][:, (split_idx + 1):].to_dense()
47 |         checkin_sequential_feature = torch.unsqueeze(edge_index_tmp, dim=-1) * torch.unsqueeze(
48 |             check_in_x, dim=0).repeat(edge_index_tmp.shape[0], 1, 1)
49 | 
50 |         # remove zero tensor and make sequence: [N, #checkin, d] (embedding layer input)
51 |         check_in_sequential_input, check_in_sequential_mask = self.generate_sequential_input(
52 |             checkin_sequential_feature,
53 |             device=check_in_x.device,
54 |             max_length=self.sequence_length
55 |         )
56 | 
57 |         sequential_feature = self.checkin_embedding_layer(check_in_sequential_input)
58 |         if self.do_positional_encoding:
59 |             sequential_feature = self.transformer_positional_encoding(sequential_feature)  # assign the result; PositionEncoder is not in-place
60 |         sequential_feature = sequential_feature.transpose(1, 0)
61 |         sequential_out = self.transformer_encoder(sequential_feature, src_key_padding_mask=check_in_sequential_mask)
62 |         sequential_out = torch.mean(sequential_out, dim=0)
63 | 
64 |         logits = self.linear(sequential_out)
65 |         loss = self.loss_func(logits, label.long())
66 |         return logits, loss
67 | 
68 |     def generate_sequential_input(self, sequential_feature, device, max_length):
69 |         """
70 |         Generate the padded sequential input and its padding mask for the transformer encoder.
71 |         """
72 |         padding_tensor = torch.tensor([
73 |             self.dataset_args.padding_user_id,
74 |             self.dataset_args.padding_poi_id,
75 |             self.dataset_args.padding_poi_category,
76 |             0,
77 |             0,
78 |             0,
79 |             self.dataset_args.padding_weekday_id,
80 |             self.dataset_args.padding_hour_id
81 |         ], dtype=torch.float64, device=device)
82 |         input_ids = torch.unsqueeze(torch.unsqueeze(padding_tensor,
dim=0).repeat(max_length, 1), dim=0).repeat( 83 | sequential_feature.size()[0], 1, 1) 84 | mask = torch.ones(torch.Size([sequential_feature.size()[0], max_length]), dtype=torch.bool, device=device) 85 | nonzero_index = torch.nonzero(sequential_feature) 86 | nonzero_index_tmp = nonzero_index[:, :2].unique(dim=0).cpu().detach().numpy() 87 | sequential_time_feature = [ 88 | (m, n, sequential_feature[m, n, 3].cpu().detach().numpy().tolist()) for m, n in nonzero_index_tmp] 89 | 90 | nonzero_index_tmp = sorted(sequential_time_feature, key=lambda x: (x[0], -x[-1])) 91 | batch_len = 0 92 | batch_idx_tmp = 0 93 | nonzero_index_tmp_copy = [] 94 | for m, n, t in nonzero_index_tmp: 95 | if batch_idx_tmp != m: 96 | batch_idx_tmp += 1 97 | batch_len = 0 98 | if batch_len >= max_length: 99 | continue 100 | nonzero_index_tmp_copy.append((m, n, t)) 101 | batch_len += 1 102 | 103 | batch_len = 0 104 | batch_idx_tmp = 0 105 | nonzero_index_tmp_copy = sorted(nonzero_index_tmp_copy, key=lambda x: (x[0], x[-1])) 106 | for m, n, _ in nonzero_index_tmp_copy: 107 | if batch_idx_tmp != m: 108 | batch_idx_tmp += 1 109 | batch_len = 0 110 | if batch_len >= max_length: 111 | continue 112 | input_ids[batch_idx_tmp, batch_len] = sequential_feature[m, n] 113 | mask[batch_idx_tmp, batch_len] = False 114 | batch_len += 1 115 | return input_ids, mask 116 | -------------------------------------------------------------------------------- /model/sthgcn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from layer import ( 4 | CheckinEmbedding, 5 | EdgeEmbedding, 6 | HypergraphTransformer, 7 | TimeEncoder, 8 | DistanceEncoderHSTLSTM, 9 | DistanceEncoderSTAN, 10 | DistanceEncoderSimple 11 | ) 12 | 13 | 14 | class STHGCN(nn.Module): 15 | def __init__(self, cfg): 16 | super(STHGCN, self).__init__() 17 | self.device = cfg.run_args.device 18 | self.batch_size = cfg.run_args.batch_size 19 | self.eval_batch_size = cfg.run_args.eval_batch_size 20 | self.do_traj2traj = cfg.model_args.do_traj2traj 21 | self.distance_encoder_type = cfg.model_args.distance_encoder_type 22 | self.dropout_rate = cfg.model_args.dropout_rate 23 | self.generate_edge_attr = cfg.model_args.generate_edge_attr 24 | self.num_conv_layers = len(cfg.model_args.sizes) 25 | self.num_poi = cfg.dataset_args.num_poi 26 | self.embed_fusion_type = cfg.model_args.embed_fusion_type 27 | self.checkin_embedding_layer = CheckinEmbedding( 28 | embed_size=cfg.model_args.embed_size, 29 | fusion_type=self.embed_fusion_type, 30 | dataset_args=cfg.dataset_args 31 | ) 32 | self.checkin_embed_size = self.checkin_embedding_layer.output_embed_size 33 | self.edge_type_embedding_layer = EdgeEmbedding( 34 | embed_size=self.checkin_embed_size, 35 | fusion_type=self.embed_fusion_type, 36 | num_edge_type=cfg.model_args.num_edge_type 37 | ) 38 | 39 | if cfg.model_args.activation == 'elu': 40 | self.act = nn.ELU() 41 | elif cfg.model_args.activation == 'relu': 42 | self.act = nn.RReLU() 43 | elif cfg.model_args.activation == 'leaky_relu': 44 | self.act = nn.LeakyReLU() 45 | else: 46 | self.act = torch.tanh 47 | 48 | if cfg.conv_args.time_fusion_mode == 'add': 49 | continuous_encoder_dim = self.checkin_embed_size 50 | else: 51 | continuous_encoder_dim = cfg.model_args.st_embed_size 52 | 53 | if self.generate_edge_attr: 54 | # use edge_type to create edge_attr_embed 55 | self.edge_attr_embedding_layer = EdgeEmbedding( 56 | embed_size=self.checkin_embed_size, 57 | fusion_type=self.embed_fusion_type, 58 | 
num_edge_type=cfg.model_args.num_edge_type 59 | ) 60 | else: 61 | # source_traj_size, target_traj_size, jaccard similarity as the raw features, and do linear transformation 62 | if cfg.conv_args.edge_fusion_mode == 'add': 63 | self.edge_attr_embedding_layer = nn.Linear(3, self.checkin_embed_size) 64 | else: 65 | self.edge_attr_embedding_layer = None 66 | 67 | self.conv_list = nn.ModuleList() 68 | 69 | # conv for ci2traj within which some ci2traj relations have been removed by time to prevent data leakage 70 | self.conv_for_time_filter = HypergraphTransformer( 71 | in_channels=self.checkin_embed_size, 72 | out_channels=self.checkin_embed_size, 73 | attn_heads=cfg.conv_args.num_attention_heads, 74 | residual_beta=cfg.conv_args.residual_beta, 75 | learn_beta=cfg.conv_args.learn_beta, 76 | dropout=cfg.conv_args.conv_dropout_rate, 77 | trans_method=cfg.conv_args.trans_method, 78 | edge_fusion_mode=cfg.conv_args.edge_fusion_mode, 79 | time_fusion_mode=cfg.conv_args.time_fusion_mode, 80 | head_fusion_mode=cfg.conv_args.head_fusion_mode, 81 | residual_fusion_mode=None, 82 | edge_dim=None, 83 | rel_embed_dim=self.checkin_embed_size, 84 | time_embed_dim=continuous_encoder_dim, 85 | dist_embed_dim=continuous_encoder_dim, 86 | negative_slope=cfg.conv_args.negative_slope, 87 | have_query_feature=False 88 | ) 89 | self.norms_for_time_filter = nn.BatchNorm1d(self.checkin_embed_size) 90 | self.dropout_for_time_filter = nn.Dropout(self.dropout_rate) 91 | 92 | if self.do_traj2traj: 93 | for i in range(self.num_conv_layers): 94 | if i == 0: 95 | # ci2traj full 96 | have_query_feature = False 97 | residual_fusion_mode = None 98 | edge_size = None 99 | else: 100 | # traj2traj 101 | have_query_feature = True 102 | residual_fusion_mode = cfg.conv_args.residual_fusion_mode 103 | if self.edge_attr_embedding_layer is None: 104 | edge_size = 3 105 | else: 106 | edge_size = self.checkin_embed_size 107 | 108 | self.conv_list.append( 109 | HypergraphTransformer( 110 | in_channels=self.checkin_embed_size, 111 | out_channels=self.checkin_embed_size, 112 | attn_heads=cfg.conv_args.num_attention_heads, 113 | residual_beta=cfg.conv_args.residual_beta, 114 | learn_beta=cfg.conv_args.learn_beta, 115 | dropout=cfg.conv_args.conv_dropout_rate, 116 | trans_method=cfg.conv_args.trans_method, 117 | edge_fusion_mode=cfg.conv_args.edge_fusion_mode, 118 | time_fusion_mode=cfg.conv_args.time_fusion_mode, 119 | head_fusion_mode=cfg.conv_args.head_fusion_mode, 120 | residual_fusion_mode=residual_fusion_mode, 121 | edge_dim=edge_size, 122 | rel_embed_dim=self.checkin_embed_size, 123 | time_embed_dim=continuous_encoder_dim, 124 | dist_embed_dim=continuous_encoder_dim, 125 | negative_slope=cfg.conv_args.negative_slope, 126 | have_query_feature=have_query_feature 127 | ) 128 | ) 129 | self.norms_list = nn.ModuleList() 130 | for i in range(self.num_conv_layers): 131 | self.norms_list.append(nn.BatchNorm1d(self.checkin_embed_size)) 132 | 133 | self.dropout_list = nn.ModuleList() 134 | for i in range(self.num_conv_layers): 135 | self.dropout_list.append(nn.Dropout(self.dropout_rate)) 136 | 137 | self.continuous_time_encoder = TimeEncoder(cfg.model_args, continuous_encoder_dim) 138 | 139 | if self.distance_encoder_type == 'stan': 140 | self.continuous_distance_encoder = DistanceEncoderSTAN( 141 | cfg.model_args, 142 | continuous_encoder_dim, 143 | cfg.dataset_args.spatial_slots 144 | ) 145 | elif self.distance_encoder_type == 'time': 146 | self.continuous_distance_encoder = TimeEncoder(cfg.model_args, continuous_encoder_dim) 147 | elif 
self.distance_encoder_type == 'hstlstm': 148 | self.continuous_distance_encoder = DistanceEncoderHSTLSTM( 149 | cfg.model_args, 150 | continuous_encoder_dim, 151 | cfg.dataset_args.spatial_slots 152 | ) 153 | elif self.distance_encoder_type == 'simple': 154 | self.continuous_distance_encoder = DistanceEncoderSimple( 155 | cfg.model_args, 156 | continuous_encoder_dim, 157 | cfg.dataset_args.spatial_slots 158 | ) 159 | else: 160 | raise ValueError(f"Unknown distance_encoder_type argument: {cfg.model_args.distance_encoder_type}!") 161 | 162 | self.linear = nn.Linear(self.checkin_embed_size, self.num_poi) 163 | self.loss_func = nn.CrossEntropyLoss() 164 | 165 | def forward(self, data, label=None, mode='train'): 166 | input_x = data['x'] # [?, 8] 167 | split_idx = data['split_index'] 168 | 169 | check_in_x = input_x[split_idx+1:] 170 | checkin_feature = self.checkin_embedding_layer(check_in_x) 171 | trajectory_feature = torch.zeros( 172 | split_idx+1, 173 | self.checkin_embed_size, 174 | device=checkin_feature.device 175 | ) 176 | x = torch.cat([trajectory_feature, checkin_feature], dim=0) 177 | 178 | # change this line if you want to modify the time granularity when encoding (delta_ts is in seconds, encoded here in hours) 179 | edge_time_embed = self.continuous_time_encoder(data['delta_ts'][0] / (60 * 60)) 180 | if self.distance_encoder_type == 'stan': 181 | edge_distance_embed = self.continuous_distance_encoder(data['delta_ss'][0], dist_type='ch2tj') 182 | else: 183 | edge_distance_embed = self.continuous_distance_encoder(data['delta_ss'][0]) 184 | 185 | edge_attr_embed, edge_type_embed = None, None 186 | if data['edge_type'][0] is not None: 187 | if self.generate_edge_attr: 188 | edge_attr_embed = self.edge_attr_embedding_layer(data['edge_type'][0]) 189 | edge_type_embed = self.edge_type_embedding_layer(data['edge_type'][0]) 190 | 191 | x_for_time_filter = self.conv_for_time_filter( 192 | x, 193 | edge_index=data['edge_index'][0], 194 | edge_attr_embed=edge_attr_embed, 195 | edge_time_embed=edge_time_embed, 196 | edge_dist_embed=edge_distance_embed, 197 | edge_type_embed=edge_type_embed 198 | ) 199 | x_for_time_filter = self.norms_for_time_filter(x_for_time_filter) 200 | x_for_time_filter = self.act(x_for_time_filter) 201 | x_for_time_filter = self.dropout_for_time_filter(x_for_time_filter) 202 | 203 | if data['edge_index'][-1] is not None and self.do_traj2traj: 204 | # run all conv layers: ci2traj full first, then traj2traj 205 | for idx, (edge_index, edge_attr, delta_ts, delta_dis, edge_type) in enumerate( 206 | zip(data["edge_index"][1:], data["edge_attr"][1:], data["delta_ts"][1:], data["delta_ss"][1:], 207 | data["edge_type"][1:]) 208 | ): 209 | edge_time_embed = self.continuous_time_encoder(delta_ts / (60 * 60)) 210 | if self.distance_encoder_type == 'stan': 211 | edge_distance_embed = self.continuous_distance_encoder(delta_dis, dist_type='tj2tj') 212 | else: 213 | edge_distance_embed = self.continuous_distance_encoder(delta_dis) 214 | 215 | edge_attr_embed, edge_type_embed = None, None 216 | if edge_type is not None: 217 | edge_type_embed = self.edge_type_embedding_layer(edge_type) 218 | if self.generate_edge_attr: 219 | edge_attr_embed = self.edge_attr_embedding_layer(edge_type) 220 | else: 221 | if self.edge_attr_embedding_layer: 222 | edge_attr_embed = self.edge_attr_embedding_layer(edge_attr.to(torch.float32)) 223 | else: 224 | edge_attr_embed = edge_attr.to(torch.float32) 225 | 226 | if idx == len(data['edge_index']) - 2: 227 | if mode in ('test', 'validate'): 228 | batch_size = self.eval_batch_size 229 | else: 230 | batch_size = self.batch_size 231 | x_target = 
x_for_time_filter[:batch_size] 232 | else: 233 | x_target = x[:edge_index.sparse_sizes()[0]] 234 | 235 | x = self.conv_list[idx]( 236 | (x, x_target), 237 | edge_index=edge_index, 238 | edge_attr_embed=edge_attr_embed, 239 | edge_time_embed=edge_time_embed, 240 | edge_dist_embed=edge_distance_embed, 241 | edge_type_embed=edge_type_embed 242 | ) 243 | x = self.norms_list[idx](x) 244 | x = self.act(x) 245 | x = self.dropout_list[idx](x) 246 | else: 247 | x = x_for_time_filter 248 | 249 | logits = self.linear(x) 250 | loss = self.loss_func(logits, label.long()) 251 | return logits, loss 252 | -------------------------------------------------------------------------------- /multiple_run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | 5 | # Parse arguments 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('-n', '--num_run', help='The total number of experiments.', required=True) 8 | parser.add_argument('-f', '--yaml_file', help='The configuration file.', required=True) 9 | parser.add_argument('-g', '--gpu_id', help='The gpu index.', default=0, required=False) 10 | args = parser.parse_args() 11 | 12 | for i in range(int(args.num_run)): 13 | print(f"Start carrying out experiment {i+1}/{args.num_run}...") 14 | exec_str = f"CUDA_VISIBLE_DEVICES={args.gpu_id} python run.py -f {args.yaml_file} --multi_run_mode" 15 | os.system(exec_str) 16 | print("\n\n") 17 | -------------------------------------------------------------------------------- /preprocess/__init__.py: -------------------------------------------------------------------------------- 1 | from preprocess.generate_hypergraph import ( 2 | generate_hypergraph_from_file, 3 | generate_hyperedge_stat, 4 | generate_traj2traj_data, 5 | generate_ci2traj_pyg_data, 6 | merge_traj2traj_data, 7 | filter_chunk 8 | ) 9 | from preprocess.preprocess_fn import ( 10 | remove_unseen_user_poi, 11 | id_encode, 12 | ignore_first, 13 | only_keep_last 14 | ) 15 | from preprocess.file_reader import ( 16 | FileReaderBase, 17 | FileReader 18 | ) 19 | from preprocess.preprocess_main import ( 20 | preprocess 21 | ) 22 | 23 | __all__ = [ 24 | "FileReaderBase", 25 | "FileReader", 26 | "generate_hypergraph_from_file", 27 | "generate_hyperedge_stat", 28 | "generate_traj2traj_data", 29 | "generate_ci2traj_pyg_data", 30 | "merge_traj2traj_data", 31 | "filter_chunk", 32 | "remove_unseen_user_poi", 33 | "id_encode", 34 | "ignore_first", 35 | "only_keep_last", 36 | "preprocess" 37 | ] 38 | -------------------------------------------------------------------------------- /preprocess/file_reader.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import pickle 3 | import pandas as pd 4 | from datetime import datetime 5 | from datetime import timedelta 6 | from tqdm import tqdm 7 | import logging 8 | from utils import get_root_dir 9 | from preprocess import id_encode, ignore_first, only_keep_last 10 | 11 | 12 | class FileReaderBase: 13 | root_path = get_root_dir() 14 | 15 | @classmethod 16 | def read_dataset(cls, file_name, dataset_name): 17 | raise NotImplementedError 18 | 19 | 20 | class FileReader(FileReaderBase): 21 | @classmethod 22 | def read_dataset(cls, file_name, dataset_name): 23 | file_path = osp.join(cls.root_path, 'raw', file_name) 24 | if dataset_name == 'ca': 25 | df = pd.read_csv(file_path, sep=',') 26 | df['UTCTimeOffset'] = df['UTCTime'].apply(lambda x: datetime.strptime(x, "%Y-%m-%dT%H:%M:%SZ")) 27 | 
df['PoiCategoryName'] = df['PoiCategoryId'] 28 | else: 29 | df = pd.read_csv(file_path, sep='\t', encoding='latin-1', header=None) 30 | df.columns = [ 31 | 'UserId', 'PoiId', 'PoiCategoryId', 'PoiCategoryName', 'Latitude', 'Longitude', 'TimezoneOffset', 32 | 'UTCTime' 33 | ] 34 | df['UTCTime'] = df['UTCTime'].apply(lambda x: datetime.strptime(x, "%a %b %d %H:%M:%S +0000 %Y")) 35 | df['UTCTimeOffset'] = df['UTCTime'] + df['TimezoneOffset'].apply(lambda x: timedelta(hours=x/60)) 36 | df['UTCTimeOffsetEpoch'] = df['UTCTimeOffset'].apply(lambda x: x.strftime('%s')) 37 | df['UTCTimeOffsetWeekday'] = df['UTCTimeOffset'].apply(lambda x: x.weekday()) 38 | df['UTCTimeOffsetHour'] = df['UTCTimeOffset'].apply(lambda x: x.hour) 39 | df['UTCTimeOffsetDay'] = df['UTCTimeOffset'].apply(lambda x: x.strftime('%Y-%m-%d')) 40 | df['UserRank'] = df.groupby('UserId')['UTCTimeOffset'].rank(method='first') 41 | 42 | logging.info( 43 | f'[Preprocess - Load Raw Data] min UTCTimeOffset: {min(df["UTCTimeOffset"])}, ' 44 | f'max UTCTimeOffset: {max(df["UTCTimeOffset"])}, #User: {df["UserId"].nunique()}, ' 45 | f'#POI: {df["PoiId"].nunique()}, #check-in: {df.shape[0]}' 46 | ) 47 | return df 48 | 49 | @classmethod 50 | def do_filter(cls, df, poi_min_freq, user_min_freq): 51 | poi_count = df.groupby('PoiId')['UserId'].count().reset_index() 52 | df = df[df['PoiId'].isin(poi_count[poi_count['UserId'] > poi_min_freq]['PoiId'])] 53 | user_count = df.groupby('UserId')['PoiId'].count().reset_index() 54 | df = df[df['UserId'].isin(user_count[user_count['PoiId'] > user_min_freq]['UserId'])] 55 | 56 | logging.info( 57 | f"[Preprocess - Filter Low Frequency User] User count: {len(user_count)}, " 58 | f"Low frequency user count: {len(user_count[user_count['PoiId'] <= user_min_freq])}, " 59 | f"ratio: {len(user_count[user_count['PoiId'] <= user_min_freq]) / len(user_count):.5f}" 60 | ) 61 | logging.info( 62 | f"[Preprocess - Filter Low Frequency POI] POI count: {len(poi_count)}, " 63 | f"Low frequency POI count: {len(poi_count[poi_count['UserId'] <= poi_min_freq])}, " 64 | f"ratio: {len(poi_count[poi_count['UserId'] <= poi_min_freq]) / len(poi_count):.5f}" 65 | ) 66 | return df 67 | 68 | @classmethod 69 | def split_train_test(cls, df, is_sorted=False): 70 | if not is_sorted: 71 | df = df.sort_values(by=['UserId', 'UTCTimeOffset'], ascending=True) 72 | 73 | df['UserRank'] = df.groupby('UserId')['UTCTimeOffset'].rank(method='first') 74 | df['SplitTag'] = 'train' 75 | total_len = df.shape[0] 76 | validation_index = int(total_len * 0.8) 77 | test_index = int(total_len * 0.9) 78 | df = df.sort_values(by='UTCTimeOffset', ascending=True) 79 | df.iloc[validation_index:test_index, df.columns.get_loc('SplitTag')] = 'validation' 80 | df.iloc[test_index:, df.columns.get_loc('SplitTag')] = 'test' 81 | df['UserRank'] = df.groupby('UserId')['UTCTimeOffset'].rank(method='first') 82 | 83 | # Filter out check-in records whose gaps to their previous check-in and later check-in 84 | # are both larger than 24 hours 85 | df = df.sort_values(by=['UserId', 'UTCTimeOffset'], ascending=True) 86 | isolated_index = [] 87 | for idx, diff1, diff2, user, user1, user2 in zip( 88 | df.index, 89 | df['UTCTimeOffset'].diff(1), 90 | df['UTCTimeOffset'].diff(-1), 91 | df['UserId'], 92 | df['UserId'].shift(1), 93 | df['UserId'].shift(-1) 94 | ): 95 | if pd.isna(diff1) and abs(diff2.total_seconds()) > 86400 and user == user2: 96 | isolated_index.append(idx) 97 | elif pd.isna(diff2) and abs(diff1.total_seconds()) > 86400 and user == user1: 98 | isolated_index.append(idx) 99 | if 
abs(diff1.total_seconds()) > 86400 and abs(diff2.total_seconds()) > 86400 and user == user1 and user == user2: 100 | isolated_index.append(idx) 101 | elif abs(diff2.total_seconds()) > 86400 and user == user2 and user != user1: 102 | isolated_index.append(idx) 103 | elif abs(diff1.total_seconds()) > 86400 and user == user1 and user != user2: 104 | isolated_index.append(idx) 105 | df = df[~df.index.isin(set(isolated_index))] 106 | 107 | logging.info('[Preprocess - Train/Validate/Test Split] Done.') 108 | return df 109 | 110 | @classmethod 111 | def generate_id(cls, df, session_time_interval, do_label_encode=True, only_last_metric=True): 112 | df = df.sort_values(by=['UserId', 'UTCTimeOffset'], ascending=True) 113 | 114 | # generate pseudo-session trajectory ids (temporal) 115 | start_id = 0 116 | pseudo_session_trajectory_id = [start_id] 117 | start_user = df['UserId'].tolist()[0] 118 | time_interval = [] 119 | for user, time_diff in tqdm(zip(df['UserId'], df['UTCTimeOffset'].diff())): 120 | if pd.isna(time_diff): 121 | time_interval.append(None) 122 | continue 123 | elif start_user != user: 124 | # different user 125 | start_id += 1 126 | start_user = user 127 | elif time_diff.total_seconds() / 60 > session_time_interval: 128 | # same user, beyond interval 129 | start_id += 1 130 | time_interval.append(time_diff.total_seconds() / 60) 131 | pseudo_session_trajectory_id.append(start_id) 132 | 133 | assert len(pseudo_session_trajectory_id) == len(df) 134 | 135 | # do label encoding 136 | if do_label_encode: 137 | df_train = df[df['SplitTag'] == 'train'] 138 | # todo check if result will be influenced by padding id (nyc uses len(), but tky and ca use 0) 139 | poi_id_le, padding_poi_id = id_encode(df_train, df, 'PoiId', padding=0) 140 | poi_category_le, padding_poi_category = id_encode(df_train, df, 'PoiCategoryId', padding=0) 141 | user_id_le, padding_user_id = id_encode(df_train, df, 'UserId', padding=0) 142 | hour_id_le, padding_hour_id = id_encode(df_train, df, 'UTCTimeOffsetHour', padding=0) 143 | weekday_id_le, padding_weekday_id = id_encode(df_train, df, 'UTCTimeOffsetWeekday', padding=0) 144 | 145 | with open(osp.join(cls.root_path, 'preprocessed', 'label_encoding.pkl'), 'wb') as f: 146 | pickle.dump([ 147 | poi_id_le, poi_category_le, user_id_le, hour_id_le, weekday_id_le, 148 | padding_poi_id, padding_poi_category, padding_user_id, padding_hour_id, padding_weekday_id 149 | ], f) 150 | 151 | df['check_ins_id'] = df['UTCTimeOffset'].rank(ascending=True, method='first') - 1 152 | df['time_interval'] = time_interval 153 | df['pseudo_session_trajectory_id'] = pseudo_session_trajectory_id 154 | 155 | # Ignore the first check-in of every trajectory when creating samples 156 | df = ignore_first(df) 157 | 158 | if only_last_metric: 159 | df = only_keep_last(df) 160 | 161 | ignore_num = len(df[df["SplitTag"] == "ignore"]) 162 | logging.info(f'[Preprocess] ignore sample num: {ignore_num}, ratio: {ignore_num/df.shape[0]}.') 163 | 164 | trajectory_id_count = df.groupby(['pseudo_session_trajectory_id'])['check_ins_id'].count().reset_index() 165 | check_ins_count = trajectory_id_count[trajectory_id_count['check_ins_id'] == 1] 166 | 167 | logging.info( 168 | f"[Preprocess] pseudo-session trajectories with a single check-in: {len(check_ins_count)}, " 169 | f"ratio: {len(check_ins_count) / len(trajectory_id_count)}." 
170 | ) 171 | return df 172 | -------------------------------------------------------------------------------- /preprocess/generate_hypergraph.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import pandas as pd 3 | import numpy as np 4 | from scipy.sparse import coo_matrix 5 | import torch 6 | from torch_sparse import SparseTensor 7 | from torch_geometric.data import Data 8 | from utils import haversine 9 | import os 10 | import os.path as osp 11 | import logging 12 | 13 | 14 | def generate_hypergraph_from_file(input_file, output_path, args): 15 | """ 16 | Construct the incidence matrix of [Checkin -> Trajectory] and the adjacency list of [Trajectory -> Trajectory] 17 | from the raw records; the edge_index will look like 18 | [[ -CheckIn- ] 19 | [ -Trajectory(hyperedge)- ]] 20 | and 21 | [[ -Trajectory(hyperedge)- ] 22 | [ -Trajectory(hyperedge)- ]] 23 | respectively. 24 | 25 | Use these columns of the txt file for the next-poi task: 26 | UserId, check_ins_id, PoiId, Latitude, Longitude, PoiCategoryId, UTCTimeOffsetEpoch, 27 | pseudo_session_trajectory_id, UTCTimeOffsetWeekday, UTCTimeOffsetHour. 28 | 29 | The two parts will be saved as two .pt files. 30 | 31 | :param input_file: path of the raw sample file used to build the hypergraph 32 | :param output_path: pyg_data.pt output directory 33 | :param args: parameters parsed for input 34 | :return: None 35 | """ 36 | usecols = [ 37 | 'UserId', 'PoiId', 'PoiCategoryId', 'Latitude', 'Longitude', 'UTCTimeOffsetEpoch', 'UTCTimeOffsetWeekday', 38 | 'UTCTimeOffsetHour', 'check_ins_id', 'pseudo_session_trajectory_id' 39 | ] 40 | threshold = args.threshold 41 | filter_mode = args.filter_mode 42 | data = pd.read_csv(input_file, usecols=usecols) 43 | 44 | traj_column = 'pseudo_session_trajectory_id' 45 | 46 | # If True, shift traj_id by the offset #check_ins_id before saving to pyg data, so that check-in indices are 47 | # in range [0, #checkin - 1] and trajectory indices are in range [#checkin, #trajectory + #checkin - 1] 48 | traj_offset = True 49 | if traj_offset: 50 | checkin_offset = torch.as_tensor([data.check_ins_id.max() + 1], dtype=torch.long) 51 | else: 52 | checkin_offset = torch.as_tensor([0], dtype=torch.long) 53 | 54 | traj_stat = generate_hyperedge_stat(data, traj_column) 55 | ci2traj_pyg_data = generate_ci2traj_pyg_data(data, traj_stat, traj_column, checkin_offset) 56 | 57 | traj2traj_intra_u_data = generate_traj2traj_data( 58 | data, 59 | traj_stat, 60 | traj_column, 61 | threshold, 62 | filter_mode=filter_mode, 63 | relation_type='intra' 64 | ) 65 | traj2traj_inter_u_data = generate_traj2traj_data( 66 | data, 67 | traj_stat, 68 | traj_column, 69 | threshold, 70 | filter_mode=filter_mode, 71 | relation_type='inter' 72 | ) 73 | traj2traj_pyg_data = merge_traj2traj_data(traj_stat, traj2traj_intra_u_data, traj2traj_inter_u_data, checkin_offset) 74 | 75 | # save pyg data 76 | if not osp.isdir(output_path): 77 | os.makedirs(output_path) 78 | ci2traj_out_file = osp.join(output_path, 'ci2traj_pyg_data.pt') 79 | traj2traj_out_file = osp.join(output_path, 'traj2traj_pyg_data.pt') 80 | torch.save(ci2traj_pyg_data, ci2traj_out_file) 81 | torch.save(traj2traj_pyg_data, traj2traj_out_file) 82 | 83 | logging.info( 84 | f'[Preprocess - Generate Hypergraph] Done saving checkin2trajectory pyg data to {ci2traj_out_file}' 85 | f' and trajectory2trajectory pyg data to {traj2traj_out_file}.' 
86 | ) 87 | return 88 | 89 | 90 | def generate_hyperedge_stat(data, traj_column): 91 | """ 92 | Generate trajectory hyperedge statistics data (pd.DataFrame) 93 | 94 | :param data: raw pseudo-session trajectory data 95 | :param traj_column: trajectory column name 96 | :return: hyperedge(trajectory) statistics as a pd.DataFrame indexed by trajectory id 97 | """ 98 | traj_stat = pd.DataFrame() 99 | traj_stat['size'] = data.groupby(traj_column)['UTCTimeOffsetEpoch'].apply(len) 100 | traj_stat['mean_lon'] = data.groupby(traj_column)['Longitude'].apply(sum) / traj_stat['size'] 101 | traj_stat['mean_lat'] = data.groupby(traj_column)['Latitude'].apply(sum) / traj_stat['size'] 102 | traj_stat[['last_lon', 'last_lat']] = \ 103 | data.sort_values([traj_column, 'UTCTimeOffsetEpoch']).groupby(traj_column).last()[['Longitude', 'Latitude']] 104 | 105 | traj_stat['start_time'] = data.groupby(traj_column)['UTCTimeOffsetEpoch'].apply(min) 106 | traj_stat['end_time'] = data.groupby(traj_column)['UTCTimeOffsetEpoch'].apply(max) 107 | traj_stat['mean_time'] = data.groupby(traj_column)['UTCTimeOffsetEpoch'].apply(sum) / traj_stat['size'] 108 | traj_stat['time_window_in_hour'] = (traj_stat.end_time - traj_stat.start_time) / (60*60) 109 | logging.info(f'[Preprocess - Generate Hypergraph] Number of hyperedges (trajectories): {traj_stat.shape[0]}.') 110 | logging.info( 111 | f'[Preprocess - Generate Hypergraph] The min, mean, max size of hyperedges are: ' 112 | f'{traj_stat["size"].min()}, {traj_stat["size"].mean()}, {traj_stat["size"].max()}.' 113 | ) 114 | logging.info( 115 | f'[Preprocess - Generate Hypergraph] The min, mean, max time window of hyperedges are: ' 116 | f'{traj_stat.time_window_in_hour.min()}, {traj_stat.time_window_in_hour.mean()}, ' 117 | f'{traj_stat.time_window_in_hour.max()}.' 118 | ) 119 | return traj_stat 120 | 121 | 122 | def generate_ci2traj_pyg_data(data, traj_stat, traj_column, checkin_offset): 123 | """ 124 | Generate the checkin2trajectory incidence matrix, the check-in (ci is short for check-in) feature matrix, and 125 | edge_delta_t and edge_delta_s, then store them into pyg data. 126 | edge_delta_t is calculated as (traj(max_time) - current_time) 127 | edge_delta_s is calculated as (geodis(traj(last_lbs), current_lbs)) 128 | 129 | :param data: raw trajectory data; 130 | :param traj_stat: hyperedge(trajectory) statistics; 131 | :param traj_column: trajectory column name; 132 | :param checkin_offset: max checkin index plus 1; 133 | :return: pyg_data including the incidence matrix, the check-in feature matrix and other edge information. 
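    For example (illustrative numbers, not from the repository's data): a check-in at epoch time 1000
    inside a trajectory whose last check-in occurs at epoch time 1600 gets edge_delta_t = 600 seconds,
    and edge_delta_s is the haversine distance (in km) between that check-in's location and the
    trajectory's last location.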
134 | """ 135 | checkin_feature_columns = [ 136 | 'UserId', 137 | 'PoiId', 138 | 'PoiCategoryId', 139 | 'UTCTimeOffsetEpoch', 140 | 'Longitude', 141 | 'Latitude', 142 | 'UTCTimeOffsetWeekday', 143 | 'UTCTimeOffsetHour' 144 | ] 145 | checkin_feature = data.sort_values('check_ins_id')[checkin_feature_columns].to_numpy() 146 | assert data.check_ins_id.unique().shape[0] == data.check_ins_id.max() + 1, \ 147 | 'check_ins_id is not in chronological order in raw data' 148 | 149 | # Calculate distance between the trajectory's last poi location and the current poi location 150 | delta_s_in_traj = data.join(traj_stat, on=traj_column, how='left')[ 151 | ['Longitude', 'Latitude', 'last_lon', 'last_lat'] 152 | ] 153 | delta_s_in_traj['distance_km'] = haversine( 154 | delta_s_in_traj.Longitude, 155 | delta_s_in_traj.Latitude, 156 | delta_s_in_traj.last_lon, 157 | delta_s_in_traj.last_lat 158 | ) 159 | 160 | # Create incidence matrix for check-in -> trajectory 161 | ci2traj_adj_t = SparseTensor( 162 | row=torch.as_tensor(data[traj_column].tolist(), dtype=torch.long), 163 | col=torch.as_tensor(data.check_ins_id.tolist(), dtype=torch.long), 164 | value=torch.as_tensor(range(0, data.shape[0]), dtype=torch.long) 165 | ) 166 | perm = ci2traj_adj_t.storage.value() 167 | ci2traj_edge_t = torch.tensor(data.UTCTimeOffsetEpoch.tolist())[perm] 168 | ci2traj_edge_delta_t = torch.tensor( 169 | traj_stat.end_time[data[traj_column].tolist()].values - data.UTCTimeOffsetEpoch.values 170 | )[perm] 171 | ci2traj_edge_delta_s = torch.tensor(delta_s_in_traj.distance_km.tolist())[perm] 172 | 173 | ci2traj_edge_index = torch.stack([ci2traj_adj_t.storage.col(), ci2traj_adj_t.storage.row() + checkin_offset]) 174 | 175 | ci2traj_pyg_data = Data( 176 | edge_index=ci2traj_edge_index, 177 | x=torch.tensor(checkin_feature), 178 | edge_t=ci2traj_edge_t, 179 | edge_delta_t=ci2traj_edge_delta_t, 180 | edge_delta_s=ci2traj_edge_delta_s 181 | ) 182 | ci2traj_pyg_data.num_hyperedges = traj_stat.shape[0] 183 | return ci2traj_pyg_data 184 | 185 | 186 | def generate_traj2traj_data( 187 | data, 188 | traj_stat, 189 | traj_column, 190 | threshold=0.02, 191 | filter_mode='min size', 192 | chunk_num=10, 193 | relation_type='intra' 194 | ): 195 | """ 196 | Generate hyperedge2hyperedge (traj2traj) dynamic relation. 197 | 198 | :param data: raw trajectory data; 199 | :param traj_stat: hyperedge(trajectory) statistics; 200 | :param traj_column: trajectory column name; 201 | :param threshold: threshold for filtering noisy relations; 202 | :param filter_mode: filter mode for filtering noisy relations; 203 | :param chunk_num: number of chunks for fast filtering; 204 | :param relation_type: 'intra' or 'inter', switches between the two types of hyperedge2hyperedge relation. 205 | :return: hyperedge2hyperedge tuple (edge_index (coo), original metric, edge_type, edge_delta_t, edge_delta_s). 
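    For example (illustrative numbers): with filter_mode='jaccard' and threshold=0.02, two trajectories
    whose POI sets have sizes 3 and 4 and share 2 POIs yield a metric of 2 / (3 + 4 - 2) = 0.4 >= 0.02,
    so the relation is kept.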
206 | """ 207 | traj2traj_original_metric = None 208 | # First create a sparse matrix for trajectory -> poi (or trajectory -> user), then generate the adjacency list; 209 | # a trajectory may contain the same poi_id multiple times, so we drop the duplicates first 210 | traj_user_map = data[['UserId', traj_column]].drop_duplicates().set_index(traj_column) 211 | traj_size_adjust = None 212 | if relation_type == 'inter': 213 | traj_poi_map = data[['PoiId', traj_column]].drop_duplicates() 214 | traj2node = coo_matrix(( 215 | np.ones(traj_poi_map.shape[0]), 216 | (np.array(traj_poi_map['PoiId'], dtype=np.int64), np.array(traj_poi_map[traj_column], dtype=np.int64)) 217 | )).tocsr() 218 | 219 | # adjust the traj_id size based on new traj_poi_map 220 | traj_size_adjust = traj_poi_map.groupby(traj_column).apply(len).tolist() 221 | else: 222 | traj2node = coo_matrix(( 223 | np.ones(traj_user_map.shape[0]), 224 | (np.array(traj_user_map['UserId'], dtype=np.int64), np.array(traj_user_map.index, dtype=np.int64)) 225 | )).tocsr() 226 | 227 | node2traj = traj2node.T 228 | traj2traj = node2traj * traj2node 229 | traj2traj = traj2traj.tocoo() 230 | 231 | # for inter-user type, save the original similarity metric 232 | if relation_type == 'inter': 233 | row_filtered, col_filtered, data_filtered = filter_chunk( 234 | row=traj2traj.row, 235 | col=traj2traj.col, 236 | data=traj2traj.data, 237 | chunk_num=chunk_num, 238 | he_size=traj_size_adjust, 239 | threshold=0, 240 | filter_mode=filter_mode 241 | ) 242 | traj2traj_original_metric = coo_matrix((data_filtered, (row_filtered, col_filtered)), shape=traj2traj.shape) 243 | 244 | # Filter 1: filter based on pre-defined conditions 245 | # 1. different trajectory 2. source_endtime <= target_starttime 246 | mask_1 = traj2traj.row != traj2traj.col 247 | mask_2 = traj_stat.end_time[traj2traj.col].values <= traj_stat.start_time[traj2traj.row].values 248 | mask = mask_1 & mask_2 249 | if relation_type == 'inter': 250 | # 3. different user
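            # Worked example (illustrative, not part of the original pipeline): if the source
            # trajectory (traj2traj.col) of user A ends at t=100 and the target trajectory
            # (traj2traj.row) of user B starts at t=120, then mask_1 (different trajectories),
            # mask_2 (100 <= 120) and mask_3 (A != B) all hold, so the directed relation
            # source -> target is kept.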
251 | mask_3 = traj_user_map['UserId'][traj2traj.row].values != traj_user_map['UserId'][traj2traj.col].values 252 | mask = mask & mask_3 253 | 254 | traj2traj.row = traj2traj.row[mask] 255 | traj2traj.col = traj2traj.col[mask] 256 | traj2traj.data = traj2traj.data[mask] 257 | 258 | if relation_type == 'inter': 259 | # Filter 2: filter based on a pre-defined metric threshold 260 | row_filtered, col_filtered, data_filtered = filter_chunk( 261 | row=traj2traj.row, 262 | col=traj2traj.col, 263 | data=traj2traj.data, 264 | chunk_num=chunk_num, 265 | he_size=traj_size_adjust, 266 | threshold=threshold, 267 | filter_mode=filter_mode 268 | ) 269 | traj2traj.row = row_filtered 270 | traj2traj.col = col_filtered 271 | traj2traj.data = data_filtered 272 | edge_type = np.ones_like(traj2traj.row) 273 | else: 274 | edge_type = np.zeros_like(traj2traj.row) 275 | 276 | # Calculate edge_delta_t and edge_delta_s 277 | edge_delta_t = traj_stat.mean_time[traj2traj.row].values - traj_stat.mean_time[traj2traj.col].values 278 | edge_delta_s = np.stack([ 279 | traj_stat.mean_lon[traj2traj.row].values, 280 | traj_stat.mean_lat[traj2traj.row].values, 281 | traj_stat.mean_lon[traj2traj.col].values, 282 | traj_stat.mean_lat[traj2traj.col].values], 283 | axis=1 284 | ) 285 | 286 | edge_delta_s = torch.tensor(edge_delta_s) 287 | edge_delta_s = haversine(edge_delta_s[:, 0], edge_delta_s[:, 1], edge_delta_s[:, 2], edge_delta_s[:, 3]) 288 | 289 | logging.info( 290 | f"[Preprocess - Generate Hypergraph] Number of {relation_type}-user hyperedge2hyperedge(traj2traj) " 291 | f"relations generated: {traj2traj.row.shape[0]}, with threshold={threshold} and mode={filter_mode}." 292 | ) 293 | 294 | return traj2traj, traj2traj_original_metric, edge_type, edge_delta_t, edge_delta_s.numpy() 295 | 296 | 297 | def merge_traj2traj_data(traj_stat, intra_u_data, inter_u_data, checkin_offset): 298 | """ 299 | Merge intra-user and inter-user hyperedge2hyperedge(traj2traj) dynamic relations. 300 | 301 | 302 | :param traj_stat: hyperedge(trajectory) statistics; 303 | :param intra_u_data: hyperedge2hyperedge(traj2traj) relation within the same user, composed of a tuple of 304 | edge_index(coo), original_metric(None for intra), edge_type(np.array), edge_delta_t(np.array), edge_delta_s(np.array); 305 | :param inter_u_data: hyperedge2hyperedge(traj2traj) relation between different users, composed of a tuple like 306 | intra_u_data. 
307 | :param checkin_offset: max checkin index plus 1; 308 | :return: pyg data of traj2traj 309 | """ 310 | traj_feature = traj_stat[['size', 'mean_lon', 'mean_lat', 'mean_time', 'start_time', 'end_time']].to_numpy() 311 | 312 | # add two extra feature columns to make sure the traj feature has the same dimension size as the ci feature 313 | padding_feature = np.zeros([traj_feature.shape[0], 2]) 314 | traj_feature = np.concatenate([traj_feature, padding_feature], axis=1) 315 | 316 | intra_edge_index, _, intra_edge_type, intra_edge_delta_t, intra_edge_delta_s = intra_u_data 317 | inter_edge_index, traj2traj_original_metric, inter_edge_type, inter_edge_delta_t, inter_edge_delta_s = inter_u_data 318 | row = np.concatenate([intra_edge_index.row, inter_edge_index.row]) 319 | col = np.concatenate([intra_edge_index.col, inter_edge_index.col]) 320 | 321 | # replace data with metric value 322 | metric_data = coo_matrix((np.ones(row.shape[0]), (row, col)), shape=traj2traj_original_metric.shape) 323 | epsilon = coo_matrix((np.zeros(row.shape[0]) + 1e-6, (row, col)), shape=traj2traj_original_metric.shape) 324 | metric_data = metric_data.multiply(traj2traj_original_metric) 325 | metric_data += epsilon 326 | 327 | adj_t = SparseTensor( 328 | row=torch.as_tensor(row, dtype=torch.long), 329 | col=torch.as_tensor(col, dtype=torch.long), 330 | value=torch.as_tensor(range(0, row.shape[0]), dtype=torch.long) 331 | ) 332 | perm = adj_t.storage.value() 333 | 334 | x = torch.tensor(traj_feature) 335 | edge_type = torch.tensor(np.concatenate([intra_edge_type, inter_edge_type]))[perm] 336 | edge_delta_t = torch.tensor(np.concatenate([intra_edge_delta_t, inter_edge_delta_t]))[perm] 337 | edge_delta_s = torch.tensor(np.concatenate([intra_edge_delta_s, inter_edge_delta_s]))[perm] 338 | 339 | edge_index = torch.stack([ 340 | adj_t.storage.col() + checkin_offset, 341 | adj_t.storage.row() + checkin_offset 342 | ]) 343 | 344 | # edge_attr: source_size, target_size, jaccard_similarity 345 | source_size = x[edge_index[0] - checkin_offset][:, 0] / x[:, 0].max() 346 | target_size = x[edge_index[1] - checkin_offset][:, 0] / x[:, 0].max() 347 | edge_attr = torch.stack([source_size, target_size, torch.tensor(metric_data.data)], dim=1) 348 | 349 | traj2traj_pyg_data = Data( 350 | edge_index=edge_index, 351 | x=x, 352 | edge_attr=edge_attr, 353 | edge_type=edge_type, 354 | edge_delta_t=edge_delta_t, 355 | edge_delta_s=edge_delta_s 356 | ) 357 | return traj2traj_pyg_data 358 | 359 | 360 | def filter_chunk(row, col, data, he_size, chunk_num=10, threshold=0.02, filter_mode='min size'): 361 | """ 362 | Filter noisy hyperedge2hyperedge connections based on a metric threshold 363 | 364 | :param row: row, hyperedge2hyperedge scipy.sparse coo matrix 365 | :param col: col, hyperedge2hyperedge scipy.sparse coo matrix 366 | :param data: data, hyperedge2hyperedge scipy.sparse coo matrix 367 | :param he_size: hyperedge size list (duplicates dropped) 368 | :param chunk_num: number of chunks to prevent OOM issues 369 | :param threshold: metric threshold, a relation is kept only if its metric value is at least the threshold 370 | :param filter_mode: 'min size' - proportional to minimum size, 'jaccard' - jaccard similarity 371 | min size: E2E_{ij} is kept when E2E_{ij} \ge \theta\min(|\mathcal{E}_i|,|\mathcal{E}_j|) 372 | jaccard: E2E_{ij} is kept when \frac{E2E_{ij}}{|\mathcal{E}_i|+|\mathcal{E}_j| - E2E_{ij}} \ge \theta 373 | :return: filtered (row, col, metric) arrays, concatenated over all chunks 374 | """ 375 | # Split the data into multiple chunks for large data 376 | chunk_bin = np.linspace(0, row.shape[0], chunk_num, 
dtype=np.int64) 377 | rows, cols, datas = [], [], [] 378 | for i in tqdm(range(len(chunk_bin) - 1)): 379 | row_chunk = row[chunk_bin[i]:chunk_bin[i + 1]] 380 | col_chunk = col[chunk_bin[i]:chunk_bin[i + 1]] 381 | data_chunk = data[chunk_bin[i]:chunk_bin[i + 1]] 382 | source_size = np.array(list(map(he_size.__getitem__, row_chunk.tolist()))) 383 | target_size = np.array(list(map(he_size.__getitem__, col_chunk.tolist()))) 384 | if filter_mode == 'min size': 385 | # proportional to minimum size 386 | metric = data_chunk / np.minimum(source_size, target_size) 387 | else: 388 | # jaccard similarity 389 | metric = data_chunk / (source_size + target_size - data_chunk) 390 | filter_mask = metric >= threshold 391 | rows.append(row_chunk[filter_mask]) 392 | cols.append(col_chunk[filter_mask]) 393 | datas.append(metric[filter_mask]) 394 | 395 | return np.concatenate(rows), np.concatenate(cols), np.concatenate(datas) 396 | -------------------------------------------------------------------------------- /preprocess/preprocess_fn.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from typing import Dict, Tuple 3 | from sklearn.preprocessing import LabelEncoder 4 | import logging 5 | 6 | 7 | def id_encode( 8 | fit_df: pd.DataFrame, 9 | encode_df: pd.DataFrame, 10 | column: str, 11 | padding: int = -1 12 | ) -> Tuple[LabelEncoder, int]: 13 | """ 14 | Fit a LabelEncoder on fit_df and use it to encode the given column of encode_df. 15 | :param fit_df: only the data in fit_df is used to construct the LabelEncoder instance 16 | :param encode_df: the dataframe whose values are encoded with the constructed LabelEncoder instance 17 | :param column: the column to be encoded 18 | :param padding: padding id for unseen values; with padding=0, encoded ids are shifted up by 1 and 0 is reserved for padding, otherwise len(classes) is used as the padding id 19 | :return: the fitted LabelEncoder and the padding id 20 | """ 21 | id_le = LabelEncoder() 22 | id_le = id_le.fit(fit_df[column].values.tolist()) 23 | if padding == 0: 24 | padding_id = padding 25 | encode_df[column] = [ 26 | id_le.transform([i])[0] + 1 if i in id_le.classes_ else padding_id 27 | for i in encode_df[column].values.tolist() 28 | ] 29 | else: 30 | padding_id = len(id_le.classes_) 31 | encode_df[column] = [ 32 | id_le.transform([i])[0] if i in id_le.classes_ else padding_id 33 | for i in encode_df[column].values.tolist() 34 | ] 35 | return id_le, padding_id 36 | 37 | 38 | def ignore_first(df: pd.DataFrame) -> pd.DataFrame: 39 | """ 40 | Ignore the first check-in sample of every trajectory because it has no historical check-in. 41 | 42 | """ 43 | df['pseudo_session_trajectory_rank'] = df.groupby( 44 | 'pseudo_session_trajectory_id')['UTCTimeOffset'].rank(method='first') 45 | df['query_pseudo_session_trajectory_id'] = df['pseudo_session_trajectory_id'].shift() 46 | df.loc[df['pseudo_session_trajectory_rank'] == 1, 'query_pseudo_session_trajectory_id'] = None 47 | df['last_checkin_epoch_time'] = df['UTCTimeOffsetEpoch'].shift() 48 | df.loc[df['pseudo_session_trajectory_rank'] == 1, 'last_checkin_epoch_time'] = None 49 | df.loc[df['UserRank'] == 1, 'SplitTag'] = 'ignore' 50 | df.loc[df['pseudo_session_trajectory_rank'] == 1, 'SplitTag'] = 'ignore' 51 | return df 52 | 53 | 54 | def only_keep_last(df: pd.DataFrame) -> pd.DataFrame: 55 | """ 56 | Only keep the last check-in sample of each trajectory in validation and testing for measuring model performance. 
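    For example (illustrative): a validation trajectory with 5 check-ins keeps only the sample whose
    rank equals the trajectory count (the 5th one); the other samples are re-tagged as 'ignore'.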
57 | 58 | """ 59 | df['pseudo_session_trajectory_count'] = df.groupby( 60 | 'pseudo_session_trajectory_id')['UTCTimeOffset'].transform('count') 61 | df.loc[(df['SplitTag'] == 'validation') & ( 62 | df['pseudo_session_trajectory_count'] != df['pseudo_session_trajectory_rank'] 63 | ), 'SplitTag'] = 'ignore' 64 | df.loc[(df['SplitTag'] == 'test') & ( 65 | df['pseudo_session_trajectory_count'] != df['pseudo_session_trajectory_rank'] 66 | ), 'SplitTag'] = 'ignore' 67 | return df 68 | 69 | 70 | def remove_unseen_user_poi(df: pd.DataFrame) -> Dict: 71 | """ 72 | Remove validation and test samples whose POIs or users did not appear in the training samples 73 | 74 | """ 75 | preprocess_result = dict() 76 | df_train = df[df['SplitTag'] == 'train'] 77 | df_validate = df[df['SplitTag'] == 'validation'] 78 | df_test = df[df['SplitTag'] == 'test'] 79 | 80 | train_user_set = set(df_train['UserId']) 81 | train_poi_set = set(df_train['PoiId']) 82 | df_validate = df_validate[ 83 | (df_validate['UserId'].isin(train_user_set)) & (df_validate['PoiId'].isin(train_poi_set))].reset_index() 84 | df_test = df_test[(df_test['UserId'].isin(train_user_set)) & (df_test['PoiId'].isin(train_poi_set))].reset_index() 85 | 86 | preprocess_result['sample'] = df 87 | preprocess_result['train_sample'] = df_train 88 | preprocess_result['validate_sample'] = df_validate 89 | preprocess_result['test_sample'] = df_test 90 | 91 | logging.info( 92 | f"[Preprocess] train shape: {df_train.shape}, validation shape: {df_validate.shape}, " 93 | f"test shape: {df_test.shape}" 94 | ) 95 | return preprocess_result 96 | -------------------------------------------------------------------------------- /preprocess/preprocess_main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import pandas as pd 4 | from datetime import datetime 5 | import os.path as osp 6 | from utils import Cfg, get_root_dir 7 | from preprocess import ( 8 | ignore_first, 9 | only_keep_last, 10 | id_encode, 11 | remove_unseen_user_poi, 12 | FileReader, 13 | generate_hypergraph_from_file 14 | ) 15 | import logging 16 | 17 | 18 | def preprocess_nyc(path: str, preprocessed_path: str) -> pd.DataFrame: 19 | raw_path = osp.join(path, 'raw') 20 | 21 | df_train = pd.read_csv(osp.join(raw_path, 'NYC_train.csv')) 22 | df_val = pd.read_csv(osp.join(raw_path, 'NYC_val.csv')) 23 | df_test = pd.read_csv(osp.join(raw_path, 'NYC_test.csv')) 24 | df_train['SplitTag'] = 'train' 25 | df_val['SplitTag'] = 'validation' 26 | df_test['SplitTag'] = 'test' 27 | df = pd.concat([df_train, df_val, df_test]) 28 | df.columns = [ 29 | 'UserId', 'PoiId', 'PoiCategoryId', 'PoiCategoryCode', 'PoiCategoryName', 'Latitude', 'Longitude', 30 | 'TimezoneOffset', 'UTCTime', 'UTCTimeOffset', 'UTCTimeOffsetWeekday', 'UTCTimeOffsetNormInDayTime', 31 | 'pseudo_session_trajectory_id', 'UTCTimeOffsetNormDayShift', 'UTCTimeOffsetNormRelativeTime', 'SplitTag' 32 | ] 33 | 34 | # data transformation 35 | df['trajectory_id'] = df['pseudo_session_trajectory_id'] 36 | df['UTCTimeOffset'] = df['UTCTimeOffset'].apply(lambda x: datetime.strptime(x[:19], "%Y-%m-%d %H:%M:%S")) 37 | df['UTCTimeOffsetEpoch'] = df['UTCTimeOffset'].apply(lambda x: x.strftime('%s')) 38 | df['UTCTimeOffsetWeekday'] = df['UTCTimeOffset'].apply(lambda x: x.weekday()) 39 | df['UTCTimeOffsetHour'] = df['UTCTimeOffset'].apply(lambda x: x.hour) 40 | df['UTCTimeOffsetDay'] = df['UTCTimeOffset'].apply(lambda x: x.strftime('%Y-%m-%d')) 41 | df['UserRank'] = 
df.groupby('UserId')['UTCTimeOffset'].rank(method='first') 42 | df = df.sort_values(by=['UserId', 'UTCTimeOffset'], ascending=True) 43 | 44 | # id encoding 45 | df['check_ins_id'] = df['UTCTimeOffset'].rank(ascending=True, method='first') - 1 46 | traj_id_le, padding_traj_id = id_encode(df, df, 'pseudo_session_trajectory_id') 47 | 48 | df_train = df[df['SplitTag'] == 'train'] 49 | poi_id_le, padding_poi_id = id_encode(df_train, df, 'PoiId') 50 | poi_category_le, padding_poi_category = id_encode(df_train, df, 'PoiCategoryId') 51 | user_id_le, padding_user_id = id_encode(df_train, df, 'UserId') 52 | hour_id_le, padding_hour_id = id_encode(df_train, df, 'UTCTimeOffsetHour') 53 | weekday_id_le, padding_weekday_id = id_encode(df_train, df, 'UTCTimeOffsetWeekday') 54 | 55 | # save mapping logic 56 | with open(osp.join(preprocessed_path, 'label_encoding.pkl'), 'wb') as f: 57 | pickle.dump([ 58 | poi_id_le, poi_category_le, user_id_le, hour_id_le, weekday_id_le, 59 | padding_poi_id, padding_poi_category, padding_user_id, padding_hour_id, padding_weekday_id 60 | ], f) 61 | 62 | # ignore the first for train/validate/test and keep the last for validate/test 63 | df = ignore_first(df) 64 | df = only_keep_last(df) 65 | return df 66 | 67 | 68 | def preprocess_tky_ca(cfg: Cfg, path: str) -> pd.DataFrame: 69 | if cfg.dataset_args.dataset_name == 'tky': 70 | raw_file = 'dataset_TSMC2014_TKY.txt' 71 | else: 72 | raw_file = 'dataset_gowalla_ca_ne.csv' 73 | 74 | FileReader.root_path = path 75 | data = FileReader.read_dataset(raw_file, cfg.dataset_args.dataset_name) 76 | data = FileReader.do_filter(data, cfg.dataset_args.min_poi_freq, cfg.dataset_args.min_user_freq) 77 | data = FileReader.split_train_test(data) 78 | 79 | # for the ca dataset, one round of filtering still leaves many low-frequency POIs and users, so filter twice 80 | if cfg.dataset_args.dataset_name == 'ca': 81 | data = FileReader.do_filter(data, cfg.dataset_args.min_poi_freq, cfg.dataset_args.min_user_freq) 82 | data = FileReader.split_train_test(data) 83 | 84 | data = FileReader.generate_id( 85 | data, 86 | cfg.dataset_args.session_time_interval, 87 | cfg.dataset_args.do_label_encode, 88 | cfg.dataset_args.only_last_metric 89 | ) 90 | return data 91 | 92 | 93 | def preprocess(cfg: Cfg): 94 | root_path = get_root_dir() 95 | dataset_name = cfg.dataset_args.dataset_name 96 | data_path = osp.join(root_path, 'data', dataset_name) 97 | preprocessed_path = osp.join(data_path, 'preprocessed') 98 | sample_file = osp.join(preprocessed_path, 'sample.csv') 99 | train_file = osp.join(preprocessed_path, 'train_sample.csv') 100 | validate_file = osp.join(preprocessed_path, 'validate_sample.csv') 101 | test_file = osp.join(preprocessed_path, 'test_sample.csv') 102 | 103 | keep_cols = [ 104 | 'check_ins_id', 'UTCTimeOffset', 'UTCTimeOffsetEpoch', 'pseudo_session_trajectory_id', 105 | 'query_pseudo_session_trajectory_id', 'UserId', 'Latitude', 'Longitude', 'PoiId', 'PoiCategoryId', 106 | 'PoiCategoryName', 'last_checkin_epoch_time' 107 | ] 108 | 109 | if not osp.exists(preprocessed_path): 110 | os.makedirs(preprocessed_path) 111 | 112 | # Step 1. preprocess raw files and create sample files including 113 | # 1. data transformation; 2. id encoding; 3. train/validate/test splitting; 4. 
remove unseen user or poi 114 | if not osp.exists(sample_file): 115 | if 'nyc' == dataset_name: 116 | keep_cols += ['trajectory_id'] 117 | preprocessed_data = preprocess_nyc(data_path, preprocessed_path) 118 | elif 'tky' == dataset_name or 'ca' == dataset_name: 119 | preprocessed_data = preprocess_tky_ca(cfg, data_path) 120 | else: 121 | raise ValueError(f'Wrong dataset name: {dataset_name} ') 122 | 123 | preprocessed_result = remove_unseen_user_poi(preprocessed_data) 124 | preprocessed_result['sample'].to_csv(sample_file, index=False) 125 | preprocessed_result['train_sample'][keep_cols].to_csv(train_file, index=False) 126 | preprocessed_result['validate_sample'][keep_cols].to_csv(validate_file, index=False) 127 | preprocessed_result['test_sample'][keep_cols].to_csv(test_file, index=False) 128 | 129 | # Step 2. generate hypergraph related data 130 | if not osp.exists(osp.join(preprocessed_path, 'ci2traj_pyg_data.pt')): 131 | generate_hypergraph_from_file(sample_file, preprocessed_path, cfg.dataset_args) 132 | 133 | logging.info('[Preprocess] Done preprocessing.') 134 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm==4.54.1 2 | pyyaml==6.0 3 | pandas==0.24.2 4 | numba==0.47.0 5 | scipy==1.4.1 6 | tensorboard==1.15.0 7 | scikit-learn==0.23.2 8 | shapely==2.0.1 -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import os.path as osp 4 | import datetime 5 | import torch 6 | import random 7 | from tqdm import tqdm 8 | import argparse 9 | from torch.utils.tensorboard import SummaryWriter 10 | from preprocess import preprocess 11 | from utils import seed_torch, set_logger, Cfg, count_parameters, test_step, save_model 12 | from layer import NeighborSampler 13 | from dataset import LBSNDataset 14 | from model import STHGCN, SequentialTransformer 15 | 16 | 17 | if __name__ == '__main__': 18 | # Parse arguments 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('-f', '--yaml_file', help='The configuration file.', required=True) 21 | parser.add_argument('--multi_run_mode', help='Run multiple experiments with the same config.', action='store_true') 22 | args = parser.parse_args() 23 | conf_file = args.yaml_file 24 | 25 | cfg = Cfg(conf_file) 26 | 27 | sizes = [int(i) for i in cfg.model_args.sizes.split('-')] 28 | cfg.model_args.sizes = sizes 29 | 30 | # cuda setting 31 | if int(cfg.run_args.gpu) >= 0: 32 | device = 'cuda:' + str(cfg.run_args.gpu) 33 | else: 34 | device = 'cpu' 35 | cfg.run_args.device = device 36 | 37 | # for multiple runs, seed is replaced with random value 38 | if args.multi_run_mode: 39 | cfg.run_args.seed = None 40 | if cfg.run_args.seed is None: 41 | seed = random.randint(0, 100000000) 42 | else: 43 | seed = int(cfg.run_args.seed) 44 | 45 | seed_torch(seed) 46 | 47 | current_time = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') 48 | cfg.run_args.save_path = f'tensorboard/{current_time}/{cfg.dataset_args.dataset_name}' 49 | cfg.run_args.log_path = f'log/{current_time}/{cfg.dataset_args.dataset_name}' 50 | 51 | if not osp.isdir(cfg.run_args.save_path): 52 | os.makedirs(cfg.run_args.save_path) 53 | if not osp.isdir(cfg.run_args.log_path): 54 | os.makedirs(cfg.run_args.log_path) 55 | 56 | set_logger(cfg.run_args) 57 | summary_writer = 
SummaryWriter(log_dir=cfg.run_args.save_path) 58 | 59 | hparam_dict = {} 60 | for group, hparam in cfg.__dict__.items(): 61 | hparam_dict.update(hparam.__dict__) 62 | hparam_dict['seed'] = seed 63 | hparam_dict['sizes'] = '-'.join([str(item) for item in cfg.model_args.sizes]) 64 | 65 | # Preprocess data 66 | preprocess(cfg) 67 | 68 | # Initialize dataset 69 | lbsn_dataset = LBSNDataset(cfg) 70 | cfg.dataset_args.spatial_slots = lbsn_dataset.spatial_slots 71 | cfg.dataset_args.num_user = lbsn_dataset.num_user 72 | cfg.dataset_args.num_poi = lbsn_dataset.num_poi 73 | cfg.dataset_args.num_category = lbsn_dataset.num_category 74 | cfg.dataset_args.padding_poi_id = lbsn_dataset.padding_poi_id 75 | cfg.dataset_args.padding_user_id = lbsn_dataset.padding_user_id 76 | cfg.dataset_args.padding_poi_category = lbsn_dataset.padding_poi_category 77 | cfg.dataset_args.padding_hour_id = lbsn_dataset.padding_hour_id 78 | cfg.dataset_args.padding_weekday_id = lbsn_dataset.padding_weekday_id 79 | 80 | # Initialize neighbor sampler(dataloader) 81 | sampler_train, sampler_validate, sampler_test = None, None, None 82 | 83 | if cfg.run_args.do_train: 84 | sampler_train = NeighborSampler( 85 | lbsn_dataset.x, 86 | lbsn_dataset.edge_index, 87 | lbsn_dataset.edge_attr, 88 | intra_jaccard_threshold=cfg.model_args.intra_jaccard_threshold, 89 | inter_jaccard_threshold=cfg.model_args.inter_jaccard_threshold, 90 | edge_t=lbsn_dataset.edge_t, 91 | edge_delta_t=lbsn_dataset.edge_delta_t, 92 | edge_type=lbsn_dataset.edge_type, 93 | sizes=sizes, 94 | sample_idx=lbsn_dataset.sample_idx_train, 95 | node_idx=lbsn_dataset.node_idx_train, 96 | edge_delta_s=lbsn_dataset.edge_delta_s, 97 | max_time=lbsn_dataset.max_time_train, 98 | label=lbsn_dataset.label_train, 99 | batch_size=cfg.run_args.batch_size, 100 | num_workers=0 if device == 'cpu' else cfg.run_args.num_workers, 101 | shuffle=True, 102 | pin_memory=True 103 | ) 104 | 105 | if cfg.run_args.do_validate: 106 | sampler_validate = NeighborSampler( 107 | lbsn_dataset.x, 108 | lbsn_dataset.edge_index, 109 | lbsn_dataset.edge_attr, 110 | intra_jaccard_threshold=cfg.model_args.intra_jaccard_threshold, 111 | inter_jaccard_threshold=cfg.model_args.inter_jaccard_threshold, 112 | edge_t=lbsn_dataset.edge_t, 113 | edge_delta_t=lbsn_dataset.edge_delta_t, 114 | edge_type=lbsn_dataset.edge_type, 115 | sizes=sizes, 116 | sample_idx=lbsn_dataset.sample_idx_valid, 117 | node_idx=lbsn_dataset.node_idx_valid, 118 | edge_delta_s=lbsn_dataset.edge_delta_s, 119 | max_time=lbsn_dataset.max_time_valid, 120 | label=lbsn_dataset.label_valid, 121 | batch_size=cfg.run_args.eval_batch_size, 122 | num_workers=0 if device == 'cpu' else cfg.run_args.num_workers, 123 | shuffle=False, 124 | pin_memory=True 125 | ) 126 | 127 | if cfg.run_args.do_test: 128 | sampler_test = NeighborSampler( 129 | lbsn_dataset.x, 130 | lbsn_dataset.edge_index, 131 | lbsn_dataset.edge_attr, 132 | intra_jaccard_threshold=cfg.model_args.intra_jaccard_threshold, 133 | inter_jaccard_threshold=cfg.model_args.inter_jaccard_threshold, 134 | edge_t=lbsn_dataset.edge_t, 135 | edge_delta_t=lbsn_dataset.edge_delta_t, 136 | edge_type=lbsn_dataset.edge_type, 137 | sizes=sizes, 138 | sample_idx=lbsn_dataset.sample_idx_test, 139 | node_idx=lbsn_dataset.node_idx_test, 140 | edge_delta_s=lbsn_dataset.edge_delta_s, 141 | max_time=lbsn_dataset.max_time_test, 142 | label=lbsn_dataset.label_test, 143 | batch_size=cfg.run_args.eval_batch_size, 144 | num_workers=0 if device == 'cpu' else cfg.run_args.num_workers, 145 | shuffle=False, 146 | 
pin_memory=True 147 | ) 148 | 149 | if cfg.model_args.model_name == 'sthgcn': 150 | model = STHGCN(cfg) 151 | elif cfg.model_args.model_name == 'seq_transformer': 152 | model = SequentialTransformer(cfg) 153 | else: 154 | raise NotImplementedError( 155 | f'[Training] Model {cfg.model_args.model_name}, please choose from ["sthgcn", "seq_transformer"]' 156 | ) 157 | 158 | model = model.to(device) 159 | logging.info(f'[Training] Seed: {seed}') 160 | logging.info('[Training] Model Parameter Configuration:') 161 | for name, param in model.named_parameters(): 162 | logging.info(f'[Training] Parameter {name}: {param.size()}, require_grad = {param.requires_grad}') 163 | logging.info(f'[Training] #Parameters: {count_parameters(model)}') 164 | 165 | if cfg.run_args.do_train: 166 | current_learning_rate = cfg.run_args.learning_rate 167 | optimizer = torch.optim.Adam( 168 | filter(lambda p: p.requires_grad, model.parameters()), 169 | lr=current_learning_rate 170 | ) 171 | if cfg.run_args.warm_up_steps: 172 | warm_up_steps = cfg.run_args.warm_up_steps 173 | else: 174 | warm_up_steps = cfg.run_args.max_steps // 2 175 | 176 | init_step = 0 177 | if cfg.run_args.init_checkpoint: 178 | # Restore model from checkpoint directory 179 | # manually set in yml 180 | logging.info(f'[Training] Loading checkpoint {cfg.run_args.init_checkpoint}...') 181 | checkpoint = torch.load(osp.join(cfg.run_args.init_checkpoint, 'checkpoint.pt')) 182 | init_step = checkpoint['step'] 183 | model.load_state_dict(checkpoint['model_state_dict']) 184 | current_learning_rate = checkpoint['current_learning_rate'] 185 | warm_up_steps = checkpoint['warm_up_steps'] 186 | cooldown_rate = checkpoint['cooldown_rate'] 187 | sizes = checkpoint['sizes'] 188 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 189 | else: 190 | logging.info('[Training] Randomly Initializing Model...') 191 | init_step = 0 192 | step = init_step 193 | 194 | # Set valid dataloader as it would be evaluated during training 195 | logging.info(f'[Training] Initial learning rate: {current_learning_rate}') 196 | 197 | # Training Loop 198 | best_metrics = 0.0 199 | global_step = 0 200 | for eph in range(cfg.run_args.epoch): 201 | training_logs = [] 202 | if global_step >= cfg.run_args.max_steps: 203 | break 204 | for data in tqdm(sampler_train): 205 | model.train() 206 | split_index = torch.max(data.adjs_t[1].storage.row()).tolist() 207 | data = data.to(device) 208 | input_data = { 209 | 'x': data.x, 210 | 'edge_index': data.adjs_t, 211 | 'edge_attr': data.edge_attrs, 212 | 'split_index': split_index, 213 | 'delta_ts': data.edge_delta_ts, 214 | 'delta_ss': data.edge_delta_ss, 215 | 'edge_type': data.edge_types 216 | } 217 | 218 | out, loss = model(input_data, label=data.y[:, 0]) 219 | training_logs.append(loss) 220 | optimizer.zero_grad() 221 | loss.backward() 222 | optimizer.step() 223 | summary_writer.add_scalar(f'train/loss_step', loss, global_step) 224 | 225 | if cfg.run_args.do_validate and global_step % cfg.run_args.valid_steps == 0: 226 | logging.info(f'[Evaluating] Evaluating on Valid Dataset...') 227 | 228 | logging.info(f'[Evaluating] Epoch {eph}, step {global_step}:') 229 | recall_res, ndcg_res, map_res, mrr_res, eval_loss = test_step(model, data=sampler_validate) 230 | summary_writer.add_scalar(f'validate/Recall@1', 100*recall_res[1], global_step) 231 | summary_writer.add_scalar(f'validate/Recall@5', 100*recall_res[5], global_step) 232 | summary_writer.add_scalar(f'validate/Recall@10', 100*recall_res[10], global_step) 233 | 
summary_writer.add_scalar(f'validate/Recall@20', 100*recall_res[20], global_step) 234 | summary_writer.add_scalar(f'validate/MRR', mrr_res, global_step) 235 | summary_writer.add_scalar(f'validate/eval_loss', eval_loss, global_step) 236 | summary_writer.add_scalar('train/learning_rate', current_learning_rate, global_step) 237 | 238 | metrics = 4 * recall_res[1] + recall_res[20] 239 | 240 | # save model based on compositional recall metrics 241 | if metrics > best_metrics: 242 | save_variable_list = { 243 | 'step': global_step, 244 | 'current_learning_rate': current_learning_rate, 245 | 'warm_up_steps': warm_up_steps, 246 | 'cooldown_rate': cfg.run_args.cooldown_rate, 247 | 'sizes': sizes 248 | } 249 | logging.info(f'[Training] Save model at step {global_step} epoch {eph}') 250 | save_model(model, optimizer, save_variable_list, cfg.run_args, hparam_dict) 251 | best_metrics = metrics 252 | 253 | # learning rate schedule 254 | if global_step >= warm_up_steps: 255 | current_learning_rate = current_learning_rate / 10 256 | logging.info(f'[Training] Change learning_rate to {current_learning_rate} at step {global_step}') 257 | optimizer = torch.optim.Adam( 258 | filter(lambda p: p.requires_grad, model.parameters()), 259 | lr=current_learning_rate 260 | ) 261 | warm_up_steps = warm_up_steps * cfg.run_args.cooldown_rate 262 | 263 | if global_step >= cfg.run_args.max_steps: 264 | break 265 | global_step += 1 266 | 267 | epoch_loss = sum([loss for loss in training_logs]) / len(training_logs) 268 | logging.info(f'[Training] Average train loss at step {global_step} is {epoch_loss}:') 269 | summary_writer.add_scalar('train/loss_epoch', epoch_loss, eph) 270 | 271 | if cfg.run_args.do_test: 272 | logging.info('[Evaluating] Start evaluating on test set...') 273 | 274 | checkpoint = torch.load(osp.join(cfg.run_args.save_path, 'checkpoint.pt')) 275 | model.load_state_dict(checkpoint['model_state_dict']) 276 | recall_res, ndcg_res, map_res, mrr_res, loss = test_step(model, sampler_test) 277 | num_params = count_parameters(model) 278 | metric_dict = { 279 | 'hparam/num_params': num_params, 280 | 'hparam/Recall@1': recall_res[1], 281 | 'hparam/Recall@5': recall_res[5], 282 | 'hparam/Recall@10': recall_res[10], 283 | 'hparam/Recall@20': recall_res[20], 284 | 'hparam/NDCG@1': ndcg_res[1], 285 | 'hparam/NDCG@5': ndcg_res[5], 286 | 'hparam/NDCG@10': ndcg_res[10], 287 | 'hparam/NDCG@20': ndcg_res[20], 288 | 'hparam/MAP@1': map_res[1], 289 | 'hparam/MAP@5': map_res[5], 290 | 'hparam/MAP@10': map_res[10], 291 | 'hparam/MAP@20': map_res[20], 292 | 'hparam/MRR': mrr_res, 293 | } 294 | logging.info(f'[Evaluating] Test evaluation result : {metric_dict}') 295 | summary_writer.add_hparams(hparam_dict, metric_dict) 296 | summary_writer.close() 297 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from utils.math_util import ( 2 | cal_slot_distance, 3 | cal_slot_distance_batch, 4 | construct_slots, 5 | delta_t_calculate, 6 | ccorr, 7 | haversine 8 | ) 9 | from utils.sys_util import ( 10 | get_root_dir, 11 | set_logger, 12 | seed_torch 13 | ) 14 | from utils.pipeline_util import ( 15 | save_model, 16 | count_parameters, 17 | test_step 18 | ) 19 | from utils.conf_util import DictToObject, Cfg 20 | 21 | __all__ = [ 22 | "DictToObject", 23 | "Cfg", 24 | "cal_slot_distance", 25 | "cal_slot_distance_batch", 26 | "construct_slots", 27 | "delta_t_calculate", 28 | "ccorr", 29 | "haversine", 30 | 
"get_root_dir", 31 | "set_logger", 32 | "seed_torch", 33 | "save_model", 34 | "count_parameters", 35 | "test_step" 36 | ] 37 | -------------------------------------------------------------------------------- /utils/conf_util.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import os.path as osp 3 | from utils import get_root_dir 4 | 5 | 6 | class DictToObject(object): 7 | def __init__(self, d): 8 | self.__dict__.update(d) 9 | 10 | def __repr__(self): 11 | return str(self.__dict__) 12 | 13 | 14 | class Cfg: 15 | def __init__(self, file_name): 16 | file_path = osp.join(get_root_dir(), 'conf', file_name) 17 | with open(file_path, "r") as f: 18 | conf = yaml.safe_load(f) 19 | self.model_args = DictToObject(conf.get('model_args', {})) 20 | self.conv_args = DictToObject(conf.get('conv_args', {})) 21 | self.seq_transformer_args = DictToObject(conf.get('seq_transformer_args', {})) 22 | self.run_args = DictToObject(conf.get('run_args', {})) 23 | self.dataset_args = DictToObject(conf.get('dataset_args', {})) 24 | -------------------------------------------------------------------------------- /utils/math_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from torch_sparse.tensor import SparseTensor 4 | from math import radians, cos, sin, asin, sqrt, exp 5 | import pandas as pd 6 | from bisect import bisect 7 | import time 8 | 9 | 10 | def com_mult(a, b): 11 | r1, i1 = a[..., 0], a[..., 1] 12 | r2, i2 = b[..., 0], b[..., 1] 13 | return torch.stack([r1 * r2 - i1 * i2, r1 * i2 + i1 * r2], dim=-1) 14 | 15 | 16 | def conj(a): 17 | a[..., 1] = -a[..., 1] 18 | return a 19 | 20 | 21 | def ccorr(a, b): 22 | return torch.irfft(com_mult(conj(torch.rfft(a, 1)), torch.rfft(b, 1)), 1, signal_sizes=(a.shape[-1],)) 23 | 24 | 25 | def cal_slot_distance(value, slots): 26 | """ 27 | Calculate a value's distance with nearest lower bound and higher bound in slots. 28 | :param value: The value to be calculated. 29 | :param slots: values of slots, needed to be sorted. 30 | :return: normalized distance with lower bound and higher bound, 31 | and index of lower bound and higher bound. 32 | """ 33 | time1 = time.time() 34 | higher_bound = bisect(slots, value) 35 | time2 = time.time() 36 | lower_bound = higher_bound - 1 37 | if higher_bound == len(slots): 38 | return 1., 0., lower_bound, lower_bound, time2 - time1 39 | else: 40 | lower_value = slots[lower_bound] 41 | higher_value = slots[higher_bound] 42 | total_distance = higher_value - lower_value 43 | return (value - lower_value) / total_distance, ( 44 | higher_value - value) / total_distance, lower_bound, higher_bound, time2 - time1 45 | 46 | 47 | def cal_slot_distance_batch(batch_value, slots): 48 | """ 49 | Proceed `cal_slot_distance` on a batch of data. 50 | :param batch_value: a batch of value, size (batch_size, step) 51 | :param slots: values of slots, needed to be sorted. 52 | :return: batch of distances and indexes. All with shape (batch_size, step). 53 | """ 54 | # Lower bound distance, higher bound distance, lower bound, higher bound. 
    ld, hd, l, h = [], [], [], []
    for step in batch_value:
        ld_one, hd_one, l_one, h_one = cal_slot_distance(step, slots)
        ld.append(ld_one)
        hd.append(hd_one)
        l.append(l_one)
        h.append(h_one)

    return torch.tensor(ld), torch.tensor(hd), torch.tensor(l), torch.tensor(h)


def construct_slots(min_value, max_value, num_slots, type):
    """
    Construct slot boundary values given a minimum and a maximum value.
    :param min_value: minimum value.
    :param max_value: maximum value.
    :param num_slots: number of slots to construct.
    :param type: type of slots to construct, 'linear' or 'exp'.
    :return: values of slots.
    """
    if type == 'exp':
        n = (max_value - min_value) / (exp(num_slots - 1) - 1)
        slots = [n * (exp(x) - 1) + min_value for x in range(num_slots)]
        # sentinel slot meant to catch values beyond max_value; note that for
        # large num_slots this expression can fall below the last exponential
        # slot, so num_slots should be chosen to keep the list sorted
        slots.append(n * (num_slots - 1) + n * 100 + min_value)
        return slots
    elif type == 'linear':
        n = (max_value - min_value) / (num_slots - 1)
        slots = [n * x + min_value for x in range(num_slots - 1)]
        # far-out sentinel slot to catch values beyond max_value
        slots.append(n * (num_slots - 1) * 100 + min_value)
        return slots
    else:
        raise ValueError(f"Unknown slot type: {type}")


def delta_t_calculate(x_year: Tensor, adj_t: SparseTensor):
    src_years = x_year[adj_t.storage.col()]
    tar_years = x_year[adj_t.storage.row()]
    delta_ts_pre = tar_years - src_years
    # keep the delta only when both endpoints carry a nonzero timestamp
    src_tar_mult = src_years * tar_years
    delta_ts = torch.where(src_tar_mult == 0, src_tar_mult, delta_ts_pre)
    return delta_ts


def haversine(lon1, lat1, lon2, lat2):
    """
    Calculate the great-circle distance in kilometers between two points
    on the earth (specified in decimal degrees).
    """

    def row_wise(lon1, lat1, lon2, lat2):
        lon1, lat1, lon2, lat2 = map(radians, [lon1, lat1, lon2, lat2])
        dlon = lon2 - lon1
        dlat = lat2 - lat1
        a = sin(dlat / 2) ** 2 + cos(lat1) * cos(lat2) * sin(dlon / 2) ** 2
        c = 2 * asin(sqrt(a))
        return c

    if isinstance(lon1, torch.Tensor):
        if not lon1.numel():
            return None
        lon1 = torch.deg2rad(lon1)
        lat1 = torch.deg2rad(lat1)
        lon2 = torch.deg2rad(lon2)
        lat2 = torch.deg2rad(lat2)
        dlon = lon2 - lon1
        dlat = lat2 - lat1
        a = torch.sin(dlat / 2) ** 2 + torch.cos(lat1) * torch.cos(lat2) * torch.sin(dlon / 2) ** 2
        c = 2 * torch.asin(torch.sqrt(a))
    elif isinstance(lon1, pd.Series):
        if not lon1.shape[0]:
            return None
        lon_lat = pd.concat([lon1, lat1, lon2, lat2], axis=1)
        # positional .iloc access: the concatenated columns keep their original
        # labels, so integer-key lookups like x[0] would break on newer pandas
        c = lon_lat.apply(lambda x: row_wise(x.iloc[0], x.iloc[1], x.iloc[2], x.iloc[3]), axis=1)
    else:
        if pd.isna(lon1) or pd.isna(lat1) or pd.isna(lon2) or pd.isna(lat2):
            return None
        c = row_wise(lon1, lat1, lon2, lat2)

    r = 6371  # mean Earth radius in kilometers
    return c * r
--------------------------------------------------------------------------------
/utils/pipeline_util.py:
--------------------------------------------------------------------------------
import json
import torch
import logging
import numpy as np
from tqdm import tqdm
import os.path as osp
from metric import (
    recall,
    ndcg,
    map_k,
    mrr
)


def save_model(model, optimizer, save_variable_list, run_args, argparse_dict):
    """
    Save the parameters of the model and the optimizer,
    as well as some other variables such as step and learning_rate.
    """
    with open(osp.join(run_args.log_path, 'config.json'), 'w') as fjson:
        # tensors are not JSON-serializable, so convert them to plain lists first
        for key, value in argparse_dict.items():
            if isinstance(value, torch.Tensor):
                argparse_dict[key] = value.numpy().tolist()
        json.dump(argparse_dict, fjson)

    torch.save({
        **save_variable_list,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()},
        osp.join(run_args.save_path, 'checkpoint.pt')
    )


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def test_step(model, data, ks=(1, 5, 10, 20)):
    model.eval()
    loss_list = []
    pred_list = []
    label_list = []
    with torch.no_grad():
        for row in tqdm(data):
            split_index = torch.max(row.adjs_t[1].storage.row()).tolist()
            row = row.to(model.device)

            input_data = {
                'x': row.x,
                'edge_index': row.adjs_t,
                'edge_attr': row.edge_attrs,
                'split_index': split_index,
                'delta_ts': row.edge_delta_ts,
                'delta_ss': row.edge_delta_ss,
                'edge_type': row.edge_types
            }

            out, loss = model(input_data, label=row.y[:, 0], mode='test')
            loss_list.append(loss.cpu().detach().numpy().tolist())
            # rank candidates by predicted score, highest first
            ranking = torch.sort(out, descending=True)[1]
            pred_list.append(ranking.cpu().detach())
            label_list.append(row.y[:, :1].cpu())
    pred_ = torch.cat(pred_list, dim=0)
    label_ = torch.cat(label_list, dim=0)
    recalls, NDCGs, MAPs = {}, {}, {}
    logging.info(f"[Evaluating] Average loss: {np.mean(loss_list)}")
    for k_ in ks:
        recalls[k_] = recall(label_, pred_, k_).cpu().detach().numpy().tolist()
        NDCGs[k_] = ndcg(label_, pred_, k_).cpu().detach().numpy().tolist()
        MAPs[k_] = map_k(label_, pred_, k_).cpu().detach().numpy().tolist()
        logging.info(f"[Evaluating] Recall@{k_} : {recalls[k_]},\tNDCG@{k_} : {NDCGs[k_]},\tMAP@{k_} : {MAPs[k_]}")
    mrr_res = mrr(label_, pred_).cpu().detach().numpy().tolist()
    logging.info(f"[Evaluating] MRR : {mrr_res}")
    return recalls, NDCGs, MAPs, mrr_res, np.mean(loss_list)
--------------------------------------------------------------------------------
/utils/sys_util.py:
--------------------------------------------------------------------------------
import os
import random
import logging
import torch
import numpy as np
import os.path as osp


def get_root_dir():
    # locate the repository root by its directory name; assumes the current
    # working directory is somewhere inside the checkout
    dirname = os.getcwd()
    dirname_split = dirname.split(os.sep)
    index = dirname_split.index("Spatio-Temporal-Hypergraph-Model")
    dirname = os.sep.join(dirname_split[:index + 1])
    return dirname


def set_logger(args):
    """
    Write logs to a train or test log file under the log directory
    (or the init-checkpoint directory).
    """
    if args.do_train:
        log_file = osp.join(args.log_path or args.init_checkpoint, 'train.log')
    else:
        log_file = osp.join(args.log_path or args.init_checkpoint, 'test.log')

    # remove all handlers associated with the root logger object;
    # iterate over a copy since removeHandler mutates the list
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)

    logging.basicConfig(
        format='%(asctime)s %(levelname)-8s %(message)s',
        level=logging.INFO,
        datefmt='%Y-%m-%d %H:%M:%S',
        filename=log_file,
        filemode='w+'
    )


def seed_torch(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
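    # Make cuDNN deterministic: turning off benchmark-mode autotuning trades
    # some speed for run-to-run reproducibility of convolution kernels.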
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.enabled = True
--------------------------------------------------------------------------------
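A quick, self-contained usage sketch (not a repository file) may help tie the utilities above together. It assumes the repository dependencies are installed and that it is run from the repository root; the slot range, slot count, and coordinates below are illustrative values only:

    # usage_sketch.py -- illustrative only, not part of the repository
    import torch
    from utils.math_util import construct_slots, cal_slot_distance_batch, haversine

    # hypothetical setup: time deltas in [0, 86400] seconds, 10 linear slots
    slots = construct_slots(min_value=0.0, max_value=86400.0, num_slots=10, type='linear')

    # soft-assign a batch of delta-t values to their enclosing slot boundaries
    ld, hd, l, h = cal_slot_distance_batch([30.0, 3600.0, 80000.0], slots)
    assert torch.allclose(ld + hd, torch.ones(3))  # interpolation weights sum to 1

    # great-circle distance (arguments are lon/lat in decimal degrees, result in km)
    print(haversine(-74.0060, 40.7128, -73.7781, 40.6413))  # NYC -> JFK, about 21 km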