├── .coveragerc ├── .flake8 ├── .github └── workflows │ └── publish.yml ├── .gitignore ├── .style.yapf ├── LICENSE ├── MANIFEST.in ├── README.md ├── azure-pipelines.yml ├── requirements.txt ├── setup.py ├── srdatasets ├── __init__.py ├── __main__.py ├── dataloader.py ├── dataloader_pytorch.py ├── datasets │ ├── __init__.py │ ├── amazon.py │ ├── citeulike.py │ ├── dataset.py │ ├── foursquare.py │ ├── gowalla.py │ ├── lastfm1k.py │ ├── movielens20m.py │ ├── retailrocket.py │ ├── tafeng.py │ ├── taobao.py │ ├── tmall.py │ ├── utils.py │ └── yelp.py ├── download.py ├── process.py └── utils.py └── tests ├── datasets ├── test_amazon.py ├── test_citeulike.py ├── test_foursquare.py ├── test_gowalla.py ├── test_lastfm1k.py ├── test_movielens20m.py └── test_tafeng.py ├── test_dataloader.py └── test_process.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | source = srdatasets 3 | omit = srdatasets/__main__.py 4 | 5 | [report] 6 | exclude_lines = 7 | pragma: no cover 8 | raise 9 | except 10 | if __name__ == .__main__.: -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = E501, E203, W503, E722, W293, E125, E126 3 | max-line-length = 120 -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: publish 2 | 3 | on: 4 | push: 5 | tags: 6 | - "*" 7 | 8 | jobs: 9 | publish: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@master 13 | - name: Setup Python 14 | uses: actions/setup-python@v1 15 | with: 16 | python-version: 3.6 17 | - name: Build 18 | run: | 19 | pip install setuptools wheel 20 | python setup.py sdist bdist_wheel 21 | - name: Publish 22 | uses: pypa/gh-action-pypi-publish@master 23 | with: 24 | user: __token__ 25 | password: ${{ secrets.pypi_password }} 26 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | .vscode 106 | -------------------------------------------------------------------------------- /.style.yapf: -------------------------------------------------------------------------------- 1 | [style] 2 | column_limit = 120 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2019 Cheng Guo 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | 
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include LICENSE requirements.txt version.txt
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [![Build Status](https://dev.azure.com/guocheng672/sequential-recommendation-datasets/_apis/build/status/guocheng2018.sequential-recommendation-datasets?branchName=master)](https://dev.azure.com/guocheng672/sequential-recommendation-datasets/_build/latest?definitionId=1&branchName=master)
2 | [![codebeat badge](https://codebeat.co/badges/a6b38c4a-dffd-4318-9e69-969f57526b77)](https://codebeat.co/projects/github-com-guocheng2018-sequential-recommendation-datasets-master)
3 | 
4 | # Sequential Recommendation Datasets
5 | 
6 | This repository collects some commonly used sequential recommendation datasets from recent research papers and provides a tool for downloading, preprocessing and batch-loading those datasets. The preprocessing can be customized for the task at hand, for example short-term recommendation (including session-based recommendation) or long-short term recommendation. A faster loading variant that integrates the PyTorch DataLoader is also available.
7 | 
8 | ## Datasets
9 | 
10 | - [Amazon-Books](http://jmcauley.ucsd.edu/data/amazon/)
11 | - [Amazon-Electronics](http://jmcauley.ucsd.edu/data/amazon/)
12 | - [Amazon-Movies](http://jmcauley.ucsd.edu/data/amazon/)
13 | - [Amazon-CDs](http://jmcauley.ucsd.edu/data/amazon/)
14 | - [Amazon-Clothing](http://jmcauley.ucsd.edu/data/amazon/)
15 | - [Amazon-Home](http://jmcauley.ucsd.edu/data/amazon/)
16 | - [Amazon-Kindle](http://jmcauley.ucsd.edu/data/amazon/)
17 | - [Amazon-Sports](http://jmcauley.ucsd.edu/data/amazon/)
18 | - [Amazon-Phones](http://jmcauley.ucsd.edu/data/amazon/)
19 | - [Amazon-Health](http://jmcauley.ucsd.edu/data/amazon/)
20 | - [Amazon-Toys](http://jmcauley.ucsd.edu/data/amazon/)
21 | - [Amazon-VideoGames](http://jmcauley.ucsd.edu/data/amazon/)
22 | - [Amazon-Tools](http://jmcauley.ucsd.edu/data/amazon/)
23 | - [Amazon-Beauty](http://jmcauley.ucsd.edu/data/amazon/)
24 | - [Amazon-Apps](http://jmcauley.ucsd.edu/data/amazon/)
25 | - [Amazon-Office](http://jmcauley.ucsd.edu/data/amazon/)
26 | - [Amazon-Pet](http://jmcauley.ucsd.edu/data/amazon/)
27 | - [Amazon-Automotive](http://jmcauley.ucsd.edu/data/amazon/)
28 | - [Amazon-Grocery](http://jmcauley.ucsd.edu/data/amazon/)
29 | - [Amazon-Patio](http://jmcauley.ucsd.edu/data/amazon/)
30 | - [Amazon-Baby](http://jmcauley.ucsd.edu/data/amazon/)
31 | - [Amazon-Music](http://jmcauley.ucsd.edu/data/amazon/)
32 | - [Amazon-MusicalInstruments](http://jmcauley.ucsd.edu/data/amazon/)
33 | - [Amazon-InstantVideo](http://jmcauley.ucsd.edu/data/amazon/)
34 | - [CiteULike](http://konect.cc/networks/citeulike-ut/)
35 | - [FourSquare-NYC](https://sites.google.com/site/yangdingqi/home/foursquare-dataset)
36 | - [FourSquare-Tokyo](https://sites.google.com/site/yangdingqi/home/foursquare-dataset)
37 | - [Gowalla](https://snap.stanford.edu/data/loc-Gowalla.html)
38 | - [Lastfm1K](http://ocelma.net/MusicRecommendationDataset/lastfm-1K.html)
39 | - [MovieLens20M](https://grouplens.org/datasets/movielens/)
40 | - [Retailrocket](https://www.kaggle.com/retailrocket/ecommerce-dataset)
41 | - [TaFeng](https://stackoverflow.com/a/25460645/8810037)
42 | - [Taobao](https://tianchi.aliyun.com/dataset/dataDetail?dataId=649)
43 | - [Tmall](https://tianchi.aliyun.com/dataset/dataDetail?dataId=47)
44 | - [Yelp](https://www.yelp.com/dataset)
45 | 
46 | ## Install this tool
47 | 
48 | Stable version
49 | ```bash
50 | pip install -U srdatasets --user
51 | ```
52 | 
53 | Latest version
54 | ```bash
55 | pip install git+https://github.com/guocheng2018/sequential-recommendation-datasets.git --user
56 | ```
57 | 
58 | ## Download datasets
59 | 
60 | Run the command below to download a dataset. Since some datasets are not directly accessible, you will be warned to download them manually and told where to place the files.
61 | 
62 | ```bash
63 | srdatasets download --dataset=[dataset_name]
64 | ```
65 | 
66 | To see the download and processing status of all datasets, run
67 | 
68 | ```bash
69 | srdatasets info
70 | ```
71 | 
72 | ## Process datasets
73 | 
74 | The generic processing command is
75 | 
76 | ```bash
77 | srdatasets process --dataset=[dataset_name] [--options]
78 | ```
79 | 
80 | ### Splitting options
81 | 
82 | Two dataset splitting methods are provided: **user-based** and **time-based**. User-based splitting divides every user's behavior sequence according to the given validation and test ratios, while time-based splitting divides the data by the dates of user behaviors. After a dataset is split, two processed datasets are generated: one for development, which uses the validation set as the test set, and one for test, which uses the full training set.
83 | 
84 | ```code
85 | --split-by    User or time (default: user)
86 | --test-split  Proportion of test set to full dataset (default: 0.2)
87 | --dev-split   Proportion of validation set to full training set (default: 0.1)
88 | ```
89 | 
90 | **NOTE**: time-based splitting asks you to enter the split point in days manually at the console; the total number of days in the dataset is shown to help, since you may not know it.
91 | 
92 | ### Task related options
93 | 
94 | For the **short term** recommendation task, you use the previous `input-len` items to predict the next `target-len` items. To make user interests more focused, user behavior sequences can also be cut into sessions if `session-interval` is given. If the number of previous items is smaller than `input-len`, 0 is padded to the left.
95 | 
96 | For the **long and short term** recommendation task, you use `pre-sessions` previous sessions and the current session to predict `target-len` items. The target items are picked either randomly or from the end of the current session, so the length of the current session is `max-session-len` - `target-len` while the length of any previous session is `max-session-len`. If any previous session or the current session is shorter than the preset length, 0 is padded to the left.
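For illustration, here are two hypothetical processing runs, one per task, combining the options listed below; the dataset name and option values are examples only, not recommended settings:

```bash
# Short-term task: user-based split, 30-minute sessions, predict the next item from the previous 5
srdatasets process --dataset=Gowalla --task=short --split-by=user --input-len=5 --target-len=1 --session-interval=30

# Long-short-term task: time-based split (you will be prompted for the split days), 60-minute sessions
srdatasets process --dataset=Gowalla --task=long-short --split-by=time --pre-sessions=10 --session-interval=60
```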
97 | 
98 | ```code
99 | --task              Short or long-short (default: short)
100 | --input-len         Number of previous items (default: 5)
101 | --target-len        Number of target items (default: 1)
102 | --pre-sessions      Number of previous sessions (default: 10)
103 | --pick-targets      Pick target items randomly or from the end of the current session (default: random)
104 | --session-interval  Session splitting interval (minutes) (default: 0)
105 | --min-session-len   Sessions shorter than this are dropped (default: 2)
106 | --max-session-len   Sessions longer than this are truncated (default: 20)
107 | ```
108 | 
109 | ### Common options
110 | 
111 | ```code
112 | --min-freq-item      Items appearing fewer times than this are dropped (default: 5)
113 | --min-freq-user      Users with fewer interactions than this are dropped (default: 5)
114 | --no-augment         Do not use data augmentation (default: False)
115 | --remove-duplicates  Remove duplicated items in user sequences or user sessions (if split into sessions) (default: False)
116 | ```
117 | 
118 | ### Dataset related options
119 | 
120 | ```code
121 | --rating-threshold  Interactions with a rating lower than this are dropped (Amazon, Movielens, Yelp) (default: 4)
122 | --item-type         Recommend artists or songs (Lastfm) (default: song)
123 | ```
124 | 
125 | ### Version
126 | 
127 | Depending on the options used, a dataset can have many processed versions. Run the command below to get the configurations and statistics of all processed versions of a dataset. The `config id` shown in the output is a required argument of `DataLoader`.
128 | 
129 | ```bash
130 | srdatasets info --dataset=[dataset_name]
131 | ```
132 | 
133 | ## DataLoader
134 | 
135 | DataLoader is a built-in class that makes loading processed datasets easy. Once you have initialized a dataloader with the dataset name, the processed version (config id), the batch size and a flag for loading training or test data, you can loop over it to get batch data. Since some models use rank-based learning, negative sampling is integrated into DataLoader: negatives are sampled from all items except those in the current data, according to popularity, and it is turned off by default (`negatives_per_target=0`). Also, the timestamps of user behaviors are sometimes an important feature; you can include them in the batch data by setting `include_timestamp` to True.
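As a rough sketch of how the sampled negatives might be consumed, the snippet below trains a toy popularity-free scoring model with a BPR-style pairwise loss. It is not part of the library: the embedding table, dimension and loss are illustrative assumptions, and it assumes the PyTorch wrapper with the short-term task, `include_timestamp=False`, `negatives_per_target > 0` and a `trainloader` built as in the initialization example further below.

```python
import torch
import torch.nn.functional as F

# Hypothetical toy model: mean of input-item embeddings dotted with candidate items.
# num_items excludes the padding item 0, hence the +1; padding_idx keeps item 0 at a zero vector.
item_emb = torch.nn.Embedding(trainloader.num_items + 1, 64, padding_idx=0)
optimizer = torch.optim.Adam(item_emb.parameters(), lr=1e-3)

for users, input_items, target_items, negatives in trainloader:
    user_vec = item_emb(input_items).mean(dim=1, keepdim=True)    # (batch_size, 1, dim)
    pos = (user_vec * item_emb(target_items)).sum(-1)             # (batch_size, target_len)
    neg = (user_vec.unsqueeze(2) * item_emb(negatives)).sum(-1)   # (batch_size, target_len, negatives_per_target)
    loss = -F.logsigmoid(pos.unsqueeze(-1) - neg).mean()          # BPR-style pairwise ranking loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```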
136 | 
137 | ### Arguments
138 | 
139 | - `dataset_name`: dataset name (case insensitive)
140 | - `config_id`: configuration id
141 | - `batch_size`: batch size (default: 1)
142 | - `train`: load the training dataset (default: True)
143 | - `development`: load the version of the dataset intended for development (hyperparameter tuning) (default: False)
144 | - `negatives_per_target`: number of negative samples per target (default: 0)
145 | - `include_timestamp`: add timestamps to batch data (default: False)
146 | - `drop_last`: drop the last incomplete batch (default: False)
147 | 
148 | ### Attributes
149 | 
150 | - `num_users`: total number of users in the training dataset
151 | - `num_items`: total number of items in the training dataset (not including the padding item 0)
152 | 
153 | ### Initialization example
154 | 
155 | ```python
156 | from srdatasets.dataloader import DataLoader
157 | 
158 | trainloader = DataLoader("amazon-books", "c1574673118829", batch_size=32, train=True, negatives_per_target=5, include_timestamp=True)
159 | testloader = DataLoader("amazon-books", "c1574673118829", batch_size=32, train=False, include_timestamp=True)
160 | ```
161 | 
162 | For PyTorch users, there is a wrapper implementation based on `torch.utils.data.DataLoader`, so you can pass keyword arguments such as `num_workers` and `pin_memory` to speed up data loading.
163 | 
164 | ```python
165 | from srdatasets.dataloader_pytorch import DataLoader
166 | 
167 | trainloader = DataLoader("amazon-books", "c1574673118829", batch_size=32, train=True, negatives_per_target=5, include_timestamp=True, num_workers=8, pin_memory=True)
168 | testloader = DataLoader("amazon-books", "c1574673118829", batch_size=32, train=False, include_timestamp=True, num_workers=8, pin_memory=True)
169 | ```
170 | 
171 | ### Iteration template
172 | 
173 | For the short term recommendation task
174 | 
175 | ```python
176 | for epoch in range(10):
177 |     # Train
178 |     for users, input_items, target_items, input_item_timestamps, target_item_timestamps, negative_samples in trainloader:
179 |         # Shape
180 |         # users: (batch_size,)
181 |         # input_items: (batch_size, input_len)
182 |         # target_items: (batch_size, target_len)
183 |         # input_item_timestamps: (batch_size, input_len)
184 |         # target_item_timestamps: (batch_size, target_len)
185 |         # negative_samples: (batch_size, target_len, negatives_per_target)
186 |         #
187 |         # DataType
188 |         # numpy.ndarray or torch.LongTensor
189 |         pass
190 | 
191 |     # Test
192 |     for users, input_items, target_items, input_item_timestamps, target_item_timestamps in testloader:
193 |         pass
194 | ```
195 | 
196 | For the long and short term recommendation task
197 | 
198 | ```python
199 | for epoch in range(10):
200 |     # Train
201 |     for users, pre_sessions_items, cur_session_items, target_items, pre_sessions_item_timestamps, cur_session_item_timestamps, target_item_timestamps, negative_samples in trainloader:
202 |         # Shape
203 |         # users: (batch_size,)
204 |         # pre_sessions_items: (batch_size, pre_sessions * max_session_len)
205 |         # cur_session_items: (batch_size, max_session_len - target_len)
206 |         # target_items: (batch_size, target_len)
207 |         # pre_sessions_item_timestamps: (batch_size, pre_sessions * max_session_len)
208 |         # cur_session_item_timestamps: (batch_size, max_session_len - target_len)
209 |         # target_item_timestamps: (batch_size, target_len)
210 |         # negative_samples: (batch_size, target_len, negatives_per_target)
211 |         #
212 |         # DataType
213 |         # numpy.ndarray or torch.LongTensor
214 |         pass
215 | 
216 |     # Test
217 |     for users, pre_sessions_items, cur_session_items, target_items, pre_sessions_item_timestamps,
cur_session_item_timestamps, target_item_timestamps in testloader: 218 | pass 219 | ``` 220 | 221 | ## Disclaimers 222 | 223 | This repo does not host or distribute any of the datasets, it is your responsibility to determine whether you have permission to use the dataset under the dataset's license. 224 | -------------------------------------------------------------------------------- /azure-pipelines.yml: -------------------------------------------------------------------------------- 1 | trigger: 2 | - master 3 | 4 | pool: 5 | vmImage: "ubuntu-latest" 6 | strategy: 7 | matrix: 8 | Python36: 9 | python.version: "3.6" 10 | 11 | steps: 12 | - task: UsePythonVersion@0 13 | inputs: 14 | versionSpec: "$(python.version)" 15 | displayName: "Use Python $(python.version)" 16 | 17 | - script: | 18 | python -m pip install --upgrade pip 19 | pip install -r requirements.txt 20 | pip install torch==1.3.1+cpu torchvision==0.4.2+cpu -f https://download.pytorch.org/whl/torch_stable.html 21 | pip install -e . 22 | displayName: "Install dependencies" 23 | 24 | - script: | 25 | pip install flake8 26 | flake8 srdatasets tests 27 | displayName: "Run lint tests" 28 | 29 | - script: | 30 | pip install pytest pytest-azurepipelines pytest-cov 31 | pytest --ignore=tests/datasets/test_citeulike.py --cov=srdatasets --cov-report=html tests 32 | displayName: "pytest" 33 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pandas>=0.25.0 2 | tqdm>=4.33.0 3 | tabulate>=0.8.5 4 | numpy>=1.16.4 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | with open("README.md", "r") as f: 4 | long_description = f.read() 5 | 6 | with open("requirements.txt", "r") as f: 7 | install_requires = f.read().splitlines() 8 | 9 | setup(name="srdatasets", 10 | version="0.1.4", 11 | author="Cheng Guo", 12 | author_email="guocheng672@gmail.com", 13 | description="A collection of research ready datasets for sequential recommendation", 14 | long_description=long_description, 15 | long_description_content_type="text/markdown", 16 | url="https://github.com/guocheng2018/sequential-recommendation-datasets", 17 | packages=find_packages(), 18 | python_requires=">=3.6", 19 | install_requires=install_requires, 20 | classifiers=[ 21 | "Development Status :: 4 - Beta", 22 | "Intended Audience :: Developers", 23 | "Intended Audience :: Science/Research", 24 | "License :: OSI Approved :: Apache Software License", 25 | ], 26 | entry_points={"console_scripts": ["srdatasets=srdatasets.__main__:main"]}) 27 | -------------------------------------------------------------------------------- /srdatasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guocheng2025/Sequential-Recommendation-Datasets/198c43962e5f15a9d99d0b17c156a9603d91a1c2/srdatasets/__init__.py -------------------------------------------------------------------------------- /srdatasets/__main__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import sys 4 | 5 | from pandas.io.json import json_normalize 6 | from tabulate import tabulate 7 | 8 | from srdatasets.datasets import __datasets__ 9 | from srdatasets.download import _download 10 | 
from srdatasets.process import _process 11 | from srdatasets.utils import __warehouse__, get_datasetname, get_downloaded_datasets, get_processed_datasets, read_json 12 | 13 | logging.basicConfig(level=logging.INFO, 14 | format="%(asctime)s %(levelname)s %(filename)s:%(lineno)d] %(message)s", 15 | datefmt="%m/%d/%Y %H:%M:%S") 16 | 17 | 18 | def read_arguments(): 19 | parser = argparse.ArgumentParser("srdatasets | python -m srdatasets") 20 | subparsers = parser.add_subparsers(help="commands", dest="command") 21 | # info 22 | parser_i = subparsers.add_parser("info", help="print local datasets info") 23 | parser_i.add_argument("--dataset", type=str, default=None, help="dataset name") 24 | 25 | # download 26 | parser_d = subparsers.add_parser("download", help="download datasets") 27 | parser_d.add_argument("--dataset", type=str, required=True, help="dataset name") 28 | 29 | # process 30 | parser_g = subparsers.add_parser("process", 31 | help="process datasets", 32 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 33 | parser_g.add_argument("--dataset", type=str, required=True, help="dataset name") 34 | parser_g.add_argument("--min-freq-item", type=int, default=5, help="minimum occurrence times of item") 35 | parser_g.add_argument("--min-freq-user", type=int, default=5, help="minimum occurrence times of user") 36 | parser_g.add_argument("--task", 37 | type=str, 38 | choices=["short", "long-short"], 39 | default="short", 40 | help="short-term task or long-short-term task") 41 | parser_g.add_argument("--split-by", 42 | type=str, 43 | choices=["user", "time"], 44 | default="user", 45 | help="user-based or time-based dataset splitting") 46 | parser_g.add_argument("--dev-split", 47 | type=float, 48 | default=0.1, 49 | help="[user-split] the fraction of developemnt dataset") 50 | parser_g.add_argument("--test-split", type=float, default=0.2, help="[user-split] the fraction of test dataset") 51 | parser_g.add_argument("--input-len", type=int, default=5, help="[short] input sequence length") 52 | parser_g.add_argument("--target-len", type=int, default=1, help="target sequence length") 53 | parser_g.add_argument("--no-augment", action="store_true", help="do not use data augmentation") 54 | parser_g.add_argument("--remove-duplicates", action="store_true", help="remove duplicate items in user sequence") 55 | parser_g.add_argument("--session-interval", 56 | type=int, 57 | default=0, 58 | help="[short-optional, long-short-required] split user sequences into sessions (minutes)") 59 | parser_g.add_argument("--max-session-len", type=int, default=20, help="max session length") 60 | parser_g.add_argument("--min-session-len", type=int, default=2, help="min session length") 61 | parser_g.add_argument("--pre-sessions", type=int, default=10, help="[long-short] number of previous sessions") 62 | parser_g.add_argument("--pick-targets", 63 | type=str, 64 | choices=["last", "random"], 65 | default="random", 66 | help="[long-short] pick T random or last items from current session as targets") 67 | parser_g.add_argument("--rating-threshold", 68 | type=int, 69 | default=4, 70 | help="[Amazon-X, Movielens20M, Yelp] ratings great or equal than this are treated as valid") 71 | parser_g.add_argument("--item-type", 72 | type=str, 73 | choices=["song", "artist"], 74 | default="song", 75 | help="[Lastfm1K] set item to song or artist") 76 | args = parser.parse_args() 77 | return args, parser 78 | 79 | 80 | def handle_dowload(args, downloaded_datasets): 81 | if args.dataset not in __datasets__: 82 | raise 
ValueError("Supported datasets: {}".format(", ".join(__datasets__))) 83 | if args.dataset in downloaded_datasets: 84 | raise ValueError("{} has been downloaded".format(args.dataset)) 85 | _download(args.dataset) 86 | 87 | 88 | def handle_process(args, downloaded_datasets, processed_datasets): 89 | if args.dataset not in __datasets__: 90 | raise ValueError("Supported datasets: {}".format(", ".join(__datasets__))) 91 | if args.dataset not in downloaded_datasets: 92 | raise ValueError("{} has not been downloaded".format(args.dataset)) 93 | 94 | if args.split_by == "user": 95 | if args.dev_split <= 0 or args.dev_split >= 1: 96 | raise ValueError("dev split ratio should be in (0, 1)") 97 | if args.test_split <= 0 or args.test_split >= 1: 98 | raise ValueError("test split ratio should be in (0, 1)") 99 | 100 | if args.task == "short": 101 | if args.input_len <= 0: 102 | raise ValueError("input length must > 0") 103 | if args.session_interval < 0: 104 | raise ValueError("session interval must >= 0 minutes") 105 | else: 106 | if args.session_interval <= 0: 107 | raise ValueError("session interval must > 0 minutes") 108 | if args.pre_sessions < 1: 109 | raise ValueError("number of previous sessions must > 0") 110 | 111 | if args.target_len <= 0: 112 | raise ValueError("target length must > 0") 113 | 114 | if args.session_interval > 0: 115 | if args.min_session_len <= args.target_len: 116 | raise ValueError("min session length must > target length") 117 | if args.max_session_len < args.min_session_len: 118 | raise ValueError("max session length must >= min session length") 119 | 120 | if args.dataset in processed_datasets: 121 | # TODO Improve processed check when some arguments are not used 122 | time_splits = {} 123 | for c in processed_datasets[args.dataset]: 124 | config = read_json(__warehouse__.joinpath(args.dataset, "processed", c, "config.json")) 125 | if args.split_by == "user" and all([args.__dict__[k] == v for k, v in config.items()]): 126 | print("You have run this config, the config id is: {}".format(c)) 127 | sys.exit(1) 128 | if args.split_by == "time" and all( 129 | [args.__dict__[k] == v for k, v in config.items() if k not in ["dev_split", "test_split"]]): 130 | time_splits[(config["dev_split"], config["test_split"])] = c 131 | args.time_splits = time_splits 132 | _process(args) 133 | 134 | 135 | def handle_info(args, downloaded_datasets, processed_datasets): 136 | if args.dataset is None: 137 | table = [[ 138 | d, "Y" if d in downloaded_datasets else "", 139 | len(processed_datasets[d]) if d in processed_datasets else "" 140 | ] for d in __datasets__] 141 | print(tabulate(table, headers=["name", "downloaded", "processed configs"], tablefmt="psql")) 142 | else: 143 | if args.dataset not in __datasets__: 144 | raise ValueError("Supported datasets: {}".format(", ".join(__datasets__))) 145 | if args.dataset not in downloaded_datasets: 146 | print("{} has not been downloaded".format(args.dataset)) 147 | else: 148 | if args.dataset not in processed_datasets: 149 | print("{} has not been processed".format(args.dataset)) 150 | else: 151 | configs = json_normalize([ 152 | read_json(__warehouse__.joinpath(args.dataset, "processed", c, "config.json")) 153 | for c in processed_datasets[args.dataset] 154 | ]) 155 | print("Configs") 156 | configs_part1 = configs.iloc[:, :8] 157 | configs_part1.insert(0, "config id", processed_datasets[args.dataset]) 158 | print(tabulate(configs_part1, headers="keys", showindex=False, tablefmt="psql")) 159 | print() 160 | configs_part2 = configs.iloc[:, 8:] 161 | 
configs_part2.insert(0, "config id", processed_datasets[args.dataset]) 162 | print(tabulate(configs_part2, headers="keys", showindex=False, tablefmt="psql")) 163 | print("\nStats") 164 | stats = json_normalize([ 165 | read_json(__warehouse__.joinpath(args.dataset, "processed", c, m, "stats.json")) 166 | for c in processed_datasets[args.dataset] for m in ["dev", "test"] 167 | ]) 168 | modes = ["development", "test"] * len(processed_datasets[args.dataset]) 169 | stats.insert(0, "mode", modes) 170 | ids = [] 171 | for c in processed_datasets[args.dataset]: 172 | ids.extend([c, ""]) 173 | stats.insert(0, "config id", ids) 174 | print(tabulate(stats, headers="keys", showindex=False, tablefmt="psql")) 175 | 176 | 177 | def main(): 178 | args, parser = read_arguments() 179 | if "dataset" in args and args.dataset is not None: 180 | args.dataset = get_datasetname(args.dataset) 181 | 182 | if args.command is None: 183 | parser.print_help() 184 | else: 185 | downloaded_datasets = get_downloaded_datasets() 186 | processed_datasets = get_processed_datasets() 187 | 188 | if args.command == "download": 189 | handle_dowload(args, downloaded_datasets) 190 | elif args.command == "process": 191 | handle_process(args, downloaded_datasets, processed_datasets) 192 | else: 193 | handle_info(args, downloaded_datasets, processed_datasets) 194 | 195 | 196 | if __name__ == "__main__": 197 | main() 198 | -------------------------------------------------------------------------------- /srdatasets/dataloader.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | import logging 4 | import math 5 | import pickle 6 | import random 7 | from collections import Counter 8 | 9 | import numpy as np 10 | 11 | from srdatasets.datasets import __datasets__ 12 | from srdatasets.utils import __warehouse__, get_datasetname, get_processed_datasets 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | class DataLoader: 18 | def __init__(self, 19 | dataset_name: str, 20 | config_id: str, 21 | batch_size: int = 1, 22 | train: bool = True, 23 | development: bool = False, 24 | negatives_per_target: int = 0, 25 | include_timestamp: bool = False, 26 | drop_last: bool = False): 27 | """Loader of sequential recommendation datasets 28 | 29 | Args: 30 | dataset_name (str): dataset name. 31 | config_id (str): dataset config id 32 | batch_size (int): batch_size 33 | train (bool, optional): load training data 34 | development (bool, optional): use the dataset for hyperparameter tuning 35 | negatives_per_target (int, optional): number of negative samples per target 36 | include_timestamp (bool, optional): add timestamps to batch data 37 | drop_last (bool, optional): drop last incomplete batch 38 | 39 | Note: training data is shuffled automatically. 
40 | """ 41 | dataset_name = get_datasetname(dataset_name) 42 | 43 | if dataset_name not in __datasets__: 44 | raise ValueError("Unrecognized dataset, currently supported datasets: {}".format(", ".join(__datasets__))) 45 | 46 | _processed_datasets = get_processed_datasets() 47 | if dataset_name not in _processed_datasets: 48 | raise ValueError("{} has not been processed, currently processed datasets: {}".format( 49 | dataset_name, ", ".join(_processed_datasets) if _processed_datasets else "none")) 50 | 51 | if config_id not in _processed_datasets[dataset_name]: 52 | raise ValueError("Unrecognized config id, existing config ids: {}".format(", ".join( 53 | _processed_datasets[dataset_name]))) 54 | 55 | if negatives_per_target < 0: 56 | negatives_per_target = 0 57 | logger.warning("Number of negative samples per target should >= 0, reset to 0") 58 | 59 | if not train and negatives_per_target > 0: 60 | logger.warning( 61 | "Negative samples are used for training, set negatives_per_target has no effect when testing") 62 | 63 | dataset_dir = __warehouse__.joinpath(dataset_name, "processed", config_id, "dev" if development else "test") 64 | with open(dataset_dir.joinpath("stats.json"), "r") as f: 65 | self.stats = json.load(f) 66 | 67 | dataset_path = dataset_dir.joinpath("train.pkl" if train else "test.pkl") 68 | with open(dataset_path, "rb") as f: 69 | self.dataset = pickle.load(f) # list 70 | 71 | if train: 72 | counter = Counter() 73 | for data in self.dataset: 74 | if len(data) > 5: 75 | counter.update(data[1] + data[2] + data[3]) 76 | else: 77 | counter.update(data[1] + data[2]) 78 | self.item_counts = np.array([counter[i] for i in range(max(counter.keys()) + 1)]) 79 | 80 | if batch_size <= 0: 81 | raise ValueError("batch_size should >= 1") 82 | if batch_size > len(self.dataset): 83 | raise ValueError("batch_size exceeds the dataset size") 84 | 85 | self.batch_size = batch_size 86 | self.train = train 87 | self.include_timestamp = include_timestamp 88 | self.negatives_per_target = negatives_per_target 89 | self.drop_last = drop_last 90 | self._batch_idx = 0 91 | 92 | @property 93 | def num_users(self): 94 | return self.stats["users"] 95 | 96 | @property 97 | def num_items(self): 98 | return self.stats["items"] 99 | 100 | def __iter__(self): 101 | return self 102 | 103 | def __len__(self): 104 | """Number of batches 105 | """ 106 | if self.drop_last: 107 | return math.floor(len(self.dataset) / self.batch_size) 108 | else: 109 | return math.ceil(len(self.dataset) / self.batch_size) 110 | 111 | def sample_negatives(self, batch_items_list): 112 | negatives = [] 113 | for b in np.concatenate(batch_items_list, 1): 114 | item_counts = copy.deepcopy(self.item_counts) 115 | item_counts[b] = 0 116 | item_counts[0] = 0 117 | probs = item_counts / item_counts.sum() 118 | _negatives = np.random.choice(len(item_counts), 119 | size=self.negatives_per_target * batch_items_list[-1].shape[1], 120 | replace=False, 121 | p=probs) 122 | _negatives = _negatives.reshape((-1, self.negatives_per_target)) 123 | negatives.append(_negatives) 124 | return np.stack(negatives) 125 | 126 | def __next__(self): 127 | """ 128 | Returns: 129 | user_ids: (batch_size,) 130 | input sequences: (batch_size, input_len) 131 | target sequences: (batch_size, target_len) 132 | """ 133 | if self._batch_idx == len(self): 134 | self._batch_idx = 0 135 | raise StopIteration 136 | else: 137 | if self._batch_idx == 0 and self.train: 138 | random.shuffle(self.dataset) 139 | batch = self.dataset[self._batch_idx * 
self.batch_size:(self._batch_idx + 1) * self.batch_size] 140 | self._batch_idx += 1 141 | batch_data = [np.array(b) for b in zip(*batch)] 142 | # Diff task 143 | target_idx = 3 if len(batch_data) > 5 else 2 144 | if not self.include_timestamp: 145 | batch_data = batch_data[:target_idx + 1] 146 | # Sampling negatives 147 | if self.train and self.negatives_per_target > 0: 148 | negatives = self.sample_negatives(batch_data[1:target_idx + 1]) 149 | batch_data.append(negatives) 150 | return batch_data 151 | -------------------------------------------------------------------------------- /srdatasets/dataloader_pytorch.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import pickle 4 | from collections import Counter 5 | 6 | import torch 7 | import torch.utils.data 8 | 9 | from srdatasets.datasets import __datasets__ 10 | from srdatasets.utils import __warehouse__, get_datasetname, get_processed_datasets 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class Dataset(torch.utils.data.Dataset): 16 | def __init__(self, name: str, config_id: str, train: bool, development: bool): 17 | super(Dataset, self).__init__() 18 | datadir = __warehouse__.joinpath(name, "processed", config_id, "dev" if development else "test") 19 | datapath = datadir.joinpath("train.pkl" if train else "test.pkl") 20 | if datapath.exists(): 21 | with open(datapath, "rb") as f: 22 | self.dataset = pickle.load(f) 23 | else: 24 | raise ValueError("{} does not exist!".format(datapath)) 25 | with open(datadir.joinpath("stats.json"), "r") as f: 26 | self.stats = json.load(f) 27 | 28 | if train: 29 | self.item_counts = Counter() 30 | for data in self.dataset: 31 | if len(data) > 5: 32 | self.item_counts.update(data[1] + data[2] + data[3]) 33 | else: 34 | self.item_counts.update(data[1] + data[2]) 35 | 36 | def __len__(self): 37 | return len(self.dataset) 38 | 39 | def __getitem__(self, idx): 40 | return tuple(map(torch.tensor, self.dataset[idx])) 41 | 42 | 43 | class DataLoader(torch.utils.data.DataLoader): 44 | def collate_fn(self, batch): 45 | """ Negative sampling and Timestamps removal or adding 46 | """ 47 | target_pos = 3 if len(batch[0]) > 5 else 2 48 | 49 | batch_data = list(zip(*batch)) 50 | if self.include_timestamp: 51 | batch_data = list(map(torch.stack, batch_data)) 52 | else: 53 | batch_data = list(map(torch.stack, batch_data[:target_pos + 1])) 54 | 55 | if self.train and self.negatives_per_target > 0: 56 | batch_item_counts = self.item_counts.repeat(len(batch), 57 | 1).scatter(1, torch.cat(batch_data[1:target_pos + 1], 1), 0) 58 | # Prevent padding item 0 from being negative samples 59 | batch_item_counts[:, 0] = 0 60 | negatives = torch.multinomial( 61 | batch_item_counts, 62 | self.negatives_per_target * batch_data[target_pos].size(1), 63 | ) 64 | negatives = negatives.view(len(batch), -1, self.negatives_per_target) 65 | batch_data.append(negatives) 66 | return batch_data 67 | 68 | @property 69 | def num_users(self): 70 | return self.dataset.stats["users"] 71 | 72 | @property 73 | def num_items(self): 74 | return self.dataset.stats["items"] 75 | 76 | def __init__(self, 77 | dataset_name: str, 78 | config_id: str, 79 | batch_size: int = 1, 80 | train: bool = True, 81 | development: bool = False, 82 | negatives_per_target: int = 0, 83 | include_timestamp: bool = False, 84 | **kwargs): 85 | """Loader of sequential recommendation datasets 86 | 87 | Args: 88 | dataset_name (str): dataset name 89 | config_id (str): dataset config id 90 | 
batch_size (int): batch_size 91 | train (bool, optional): load training dataset 92 | development (bool, optional): use the dataset for hyperparameter tuning 93 | negatives_per_target (int, optional): number of negative samples per target 94 | include_timestamp (bool, optional): add timestamps to batch data 95 | 96 | Note: training data is shuffled automatically. 97 | """ 98 | dataset_name = get_datasetname(dataset_name) 99 | 100 | if dataset_name not in __datasets__: 101 | raise ValueError("Unrecognized dataset, currently supported datasets: {}".format(", ".join(__datasets__))) 102 | 103 | _processed_datasets = get_processed_datasets() 104 | if dataset_name not in _processed_datasets: 105 | raise ValueError("{} has not been processed, currently processed datasets: {}".format( 106 | dataset_name, ", ".join(_processed_datasets) if _processed_datasets else "none")) 107 | 108 | if config_id not in _processed_datasets[dataset_name]: 109 | raise ValueError("Unrecognized config id, existing config ids: {}".format(", ".join( 110 | _processed_datasets[dataset_name]))) 111 | 112 | if negatives_per_target < 0: 113 | negatives_per_target = 0 114 | logger.warning("Number of negative samples per target should >= 0, reset to 0") 115 | 116 | if not train and negatives_per_target > 0: 117 | logger.warning( 118 | "Negative samples are used for training, set negatives_per_target has no effect when testing") 119 | 120 | self.train = train 121 | self.include_timestamp = include_timestamp 122 | self.negatives_per_target = negatives_per_target 123 | 124 | self.dataset = Dataset(dataset_name, config_id, train, development) 125 | if train: 126 | self.item_counts = torch.tensor( 127 | [self.dataset.item_counts[i] for i in range(max(self.dataset.item_counts.keys()) + 1)], 128 | dtype=torch.float) 129 | 130 | super().__init__(self.dataset, batch_size=batch_size, shuffle=train, collate_fn=self.collate_fn, **kwargs) 131 | -------------------------------------------------------------------------------- /srdatasets/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from srdatasets.datasets.amazon import Amazon 2 | from srdatasets.datasets.citeulike import CiteULike 3 | from srdatasets.datasets.foursquare import FourSquare 4 | from srdatasets.datasets.gowalla import Gowalla 5 | from srdatasets.datasets.lastfm1k import Lastfm1K 6 | from srdatasets.datasets.movielens20m import MovieLens20M 7 | from srdatasets.datasets.retailrocket import Retailrocket 8 | from srdatasets.datasets.tafeng import TaFeng 9 | from srdatasets.datasets.taobao import Taobao 10 | from srdatasets.datasets.tmall import Tmall 11 | from srdatasets.datasets.yelp import Yelp 12 | 13 | dataset_classes = { 14 | "Amazon": Amazon, 15 | "CiteULike": CiteULike, 16 | "FourSquare": FourSquare, 17 | "Gowalla": Gowalla, 18 | "Lastfm1K": Lastfm1K, 19 | "MovieLens20M": MovieLens20M, 20 | "Retailrocket": Retailrocket, 21 | "TaFeng": TaFeng, 22 | "Taobao": Taobao, 23 | "Tmall": Tmall, 24 | "Yelp": Yelp 25 | } 26 | 27 | amazon_datasets = ["Amazon-" + c for c in Amazon.__corefile__.keys()] 28 | foursquare_datasets = ["FourSquare-" + c for c in FourSquare.__corefile__.keys()] 29 | 30 | __datasets__ = (amazon_datasets + ["CiteULike"] + foursquare_datasets + [ 31 | "Gowalla", 32 | "Lastfm1K", 33 | "MovieLens20M", 34 | "TaFeng", 35 | "Retailrocket", 36 | "Taobao", 37 | "Tmall", 38 | "Yelp", 39 | ]) 40 | -------------------------------------------------------------------------------- /srdatasets/datasets/amazon.py: 
-------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | from srdatasets.datasets.dataset import Dataset 4 | from srdatasets.datasets.utils import download_url 5 | 6 | 7 | class Amazon(Dataset): 8 | 9 | __corefile__ = { 10 | "Books": "ratings_Books.csv", 11 | "Electronics": "ratings_Electronics.csv", 12 | "Movies": "ratings_Movies_and_TV.csv", 13 | "CDs": "ratings_CDs_and_Vinyl.csv", 14 | "Clothing": "ratings_Clothing_Shoes_and_Jewelry.csv", 15 | "Home": "ratings_Home_and_Kitchen.csv", 16 | "Kindle": "ratings_Kindle_Store.csv", 17 | "Sports": "ratings_Sports_and_Outdoors.csv", 18 | "Phones": "ratings_Cell_Phones_and_Accessories.csv", 19 | "Health": "ratings_Health_and_Personal_Care.csv", 20 | "Toys": "ratings_Toys_and_Games.csv", 21 | "VideoGames": "ratings_Video_Games.csv", 22 | "Tools": "ratings_Tools_and_Home_Improvement.csv", 23 | "Beauty": "ratings_Beauty.csv", 24 | "Apps": "ratings_Apps_for_Android.csv", 25 | "Office": "ratings_Office_Products.csv", 26 | "Pet": "ratings_Pet_Supplies.csv", 27 | "Automotive": "ratings_Automotive.csv", 28 | "Grocery": "ratings_Grocery_and_Gourmet_Food.csv", 29 | "Patio": "ratings_Patio_Lawn_and_Garden.csv", 30 | "Baby": "ratings_Baby.csv", 31 | "Music": "ratings_Digital_Music.csv", 32 | "MusicalInstruments": "ratings_Musical_Instruments.csv", 33 | "InstantVideo": "ratings_Amazon_Instant_Video.csv" 34 | } 35 | 36 | url_prefix = "http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/" 37 | 38 | def download(self, category): 39 | download_url(self.url_prefix + self.__corefile__[category], self.rootdir.joinpath(self.__corefile__[category])) 40 | 41 | def transform(self, category, rating_threshold): 42 | """ Records with rating less than `rating_threshold` are dropped 43 | """ 44 | df = pd.read_csv(self.rootdir.joinpath(self.__corefile__[category]), 45 | header=None, 46 | names=["user_id", "item_id", "rating", "timestamp"]) 47 | df = df[df["rating"] >= rating_threshold].drop("rating", axis=1) 48 | return df 49 | -------------------------------------------------------------------------------- /srdatasets/datasets/citeulike.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pandas as pd 4 | 5 | from srdatasets.datasets.dataset import Dataset 6 | from srdatasets.datasets.utils import download_url, extract 7 | 8 | 9 | class CiteULike(Dataset): 10 | 11 | __url__ = "http://konect.cc/files/download.tsv.citeulike-ut.tar.bz2" 12 | __corefile__ = os.path.join("citeulike-ut", "out.citeulike-ut") 13 | 14 | def download(self): 15 | download_url(self.__url__, self.rawpath) 16 | extract(self.rawpath, self.rootdir) 17 | 18 | def transform(self): 19 | df = pd.read_csv(self.rootdir.joinpath(self.__corefile__), 20 | sep=" ", 21 | header=None, 22 | index_col=False, 23 | names=["user_id", "tag_id", "positive", "timestamp"], 24 | usecols=[0, 1, 3], 25 | comment="%", 26 | converters={"timestamp": lambda x: int(float(x))}) 27 | df = df.rename(columns={"tag_id": "item_id"}) 28 | return df 29 | -------------------------------------------------------------------------------- /srdatasets/datasets/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from abc import ABC, abstractmethod 3 | from urllib.parse import urlparse 4 | 5 | 6 | class Dataset(ABC): 7 | """ Base dataset of SR datasets 8 | """ 9 | def __init__(self, rootdir): 10 | """ `rootdir` is the directory of the raw dataset """ 11 | 
self.rootdir = rootdir 12 | 13 | @property 14 | def rawpath(self): 15 | if hasattr(self, "__url__"): 16 | return self.rootdir.joinpath(os.path.basename(urlparse(self.__url__).path)) 17 | else: 18 | return "" 19 | 20 | @abstractmethod 21 | def download(self): 22 | """ Download and extract the raw dataset """ 23 | pass 24 | 25 | @abstractmethod 26 | def transform(self): 27 | """ Transform to the general data format, which is 28 | a pd.DataFrame instance that contains three columns: [user_id, item_id, timestamp] 29 | """ 30 | pass 31 | -------------------------------------------------------------------------------- /srdatasets/datasets/foursquare.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime 3 | 4 | import pandas as pd 5 | 6 | from srdatasets.datasets.dataset import Dataset 7 | from srdatasets.datasets.utils import download_url, extract 8 | 9 | 10 | class FourSquare(Dataset): 11 | 12 | __url__ = "http://www-public.imtbs-tsp.eu/~zhang_da/pub/dataset_tsmc2014.zip" 13 | 14 | __corefile__ = { 15 | "NYC": os.path.join("dataset_tsmc2014", "dataset_TSMC2014_NYC.txt"), 16 | "Tokyo": os.path.join("dataset_tsmc2014", "dataset_TSMC2014_TKY.txt") 17 | } 18 | 19 | def download(self): 20 | download_url(self.__url__, self.rawpath) 21 | extract(self.rawpath, self.rootdir) 22 | 23 | def transform(self, city): 24 | """ city: `NYC` or `Tokyo` 25 | """ 26 | df = pd.read_csv( 27 | self.rootdir.joinpath(self.__corefile__[city]), 28 | sep="\t", 29 | header=None, 30 | names=[ 31 | "user_id", "venue_id", "venue_category_id", "venue_category_name", "latitude", "longtitude", 32 | "timezone_offset", "utc_time" 33 | ], 34 | usecols=[0, 1, 7], 35 | converters={"utc_time": lambda x: int(datetime.strptime(x, "%a %b %d %H:%M:%S %z %Y").timestamp())}) 36 | df = df.rename(columns={"venue_id": "item_id", "utc_time": "timestamp"}) 37 | return df 38 | -------------------------------------------------------------------------------- /srdatasets/datasets/gowalla.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | import pandas as pd 4 | 5 | from srdatasets.datasets.dataset import Dataset 6 | from srdatasets.datasets.utils import download_url, extract 7 | 8 | 9 | class Gowalla(Dataset): 10 | 11 | __url__ = "https://snap.stanford.edu/data/loc-gowalla_totalCheckins.txt.gz" 12 | __corefile__ = "loc-gowalla_totalCheckins.txt" 13 | 14 | def download(self): 15 | download_url(self.__url__, self.rawpath) 16 | extract(self.rawpath, self.rootdir.joinpath("loc-gowalla_totalCheckins.txt")) 17 | 18 | def transform(self): 19 | """ Time: yyyy-mm-ddThh:mm:ssZ -> timestamp """ 20 | df = pd.read_csv( 21 | self.rootdir.joinpath(self.__corefile__), 22 | sep="\t", 23 | names=["user_id", "check_in_time", "latitude", "longtitude", "location_id"], 24 | usecols=[0, 1, 4], 25 | converters={"check_in_time": lambda x: int(datetime.strptime(x, "%Y-%m-%dT%H:%M:%SZ").timestamp())}) 26 | df = df.rename(columns={"location_id": "item_id", "check_in_time": "timestamp"}) 27 | return df 28 | -------------------------------------------------------------------------------- /srdatasets/datasets/lastfm1k.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime 3 | 4 | import pandas as pd 5 | 6 | from srdatasets.datasets.dataset import Dataset 7 | from srdatasets.datasets.utils import download_url, extract 8 | 9 | 10 | class 
Lastfm1K(Dataset): 11 | 12 | __url__ = "http://mtg.upf.edu/static/datasets/last.fm/lastfm-dataset-1K.tar.gz" 13 | __corefile__ = os.path.join("lastfm-dataset-1K", "userid-timestamp-artid-artname-traid-traname.tsv") 14 | 15 | def download(self): 16 | download_url(self.__url__, self.rawpath) 17 | extract(self.rawpath, self.rootdir) 18 | 19 | def transform(self, item_type): 20 | """ item_type can be `artist` or `song` 21 | """ 22 | df = pd.read_csv( 23 | self.rootdir.joinpath(self.__corefile__), 24 | sep="\t", 25 | names=["user_id", "timestamp", "artist_id", "artist_name", "song_id", "song_name"], 26 | usecols=[0, 1, 2, 4], 27 | converters={"timestamp": lambda x: int(datetime.strptime(x, "%Y-%m-%dT%H:%M:%SZ").timestamp())}) 28 | if item_type == "song": 29 | df = df.drop("artist_id", axis=1).rename(columns={"song_id": "item_id"}) 30 | else: 31 | df = df.drop("song_id", axis=1).rename(columns={"artist_id": "item_id"}) 32 | return df 33 | -------------------------------------------------------------------------------- /srdatasets/datasets/movielens20m.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pandas as pd 4 | 5 | from srdatasets.datasets.dataset import Dataset 6 | from srdatasets.datasets.utils import download_url, extract 7 | 8 | 9 | class MovieLens20M(Dataset): 10 | 11 | __url__ = "http://files.grouplens.org/datasets/movielens/ml-20m.zip" 12 | __corefile__ = os.path.join("ml-20m", "ratings.csv") 13 | 14 | def download(self): 15 | download_url(self.__url__, self.rawpath) 16 | extract(self.rawpath, self.rootdir) 17 | 18 | def transform(self, rating_threshold): 19 | """ Records with rating less than `rating_threshold` are dropped 20 | """ 21 | df = pd.read_csv(self.rootdir.joinpath(self.__corefile__), 22 | header=0, 23 | names=["user_id", "movie_id", "rating", "timestamp"]) 24 | df = df.rename(columns={"movie_id": "item_id"}) 25 | df = df[df.rating >= rating_threshold].drop("rating", axis=1) 26 | return df 27 | -------------------------------------------------------------------------------- /srdatasets/datasets/retailrocket.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import pandas as pd 4 | 5 | from srdatasets.datasets.dataset import Dataset 6 | from srdatasets.datasets.utils import extract 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | class Retailrocket(Dataset): 12 | 13 | __corefile__ = "events.csv" 14 | 15 | def download(self): 16 | if not self.rootdir.joinpath("ecommerce-dataset.zip").exists(): 17 | logger.warning( 18 | "Since RetailRocket dataset is not directly accessible, please visit \ 19 | https://www.kaggle.com/retailrocket/ecommerce-dataset and download \ 20 | it manually, after downloaded, place file 'ecommerce-dataset.zip' \ 21 | under %s and run this command again", self.rootdir) 22 | else: 23 | extract(self.rootdir.joinpath("ecommerce-dataset.zip"), self.rootdir) 24 | 25 | def transform(self): 26 | df = pd.read_csv(self.rootdir.joinpath(self.__corefile__), 27 | header=0, 28 | index_col=False, 29 | usecols=[0, 1, 3], 30 | converters={"timestamp": lambda x: int(int(x) / 1000)}) 31 | df = df.rename(columns={"visitorid": "user_id", "itemid": "item_id"}) 32 | return df 33 | -------------------------------------------------------------------------------- /srdatasets/datasets/tafeng.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | import pandas as pd 4 | 5 | 
from srdatasets.datasets.dataset import Dataset 6 | from srdatasets.datasets.utils import download_url, extract 7 | 8 | 9 | class TaFeng(Dataset): 10 | 11 | __url__ = "https://sites.google.com/site/dataminingcourse2009/spring2016/annoucement2016/assignment3/D11-02.ZIP" 12 | __corefile__ = ["D11", "D12", "D01", "D02"] 13 | 14 | def download(self): 15 | download_url(self.__url__, self.rawpath) 16 | extract(self.rawpath, self.rootdir) 17 | 18 | def transform(self): 19 | dfs = [] 20 | for cf in self.__corefile__: 21 | df = pd.read_csv( 22 | self.rootdir.joinpath(cf), 23 | sep=";", 24 | header=0, 25 | index_col=False, 26 | names=["timestamp", "user_id", "age", "area", "pcate", "item_id", "number", "cost", "price"], 27 | usecols=[0, 1, 5], 28 | encoding="big5", 29 | converters={"timestamp": lambda x: int(datetime.strptime(x, "%Y-%m-%d %H:%M:%S").timestamp())}) 30 | dfs.append(df) 31 | return pd.concat(dfs, ignore_index=True) 32 | -------------------------------------------------------------------------------- /srdatasets/datasets/taobao.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import pandas as pd 4 | 5 | from srdatasets.datasets.dataset import Dataset 6 | from srdatasets.datasets.utils import extract 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | class Taobao(Dataset): 12 | 13 | __corefile__ = "UserBehavior.csv" 14 | 15 | def download(self): 16 | if not self.rootdir.joinpath("UserBehavior.csv.zip").exists(): 17 | logger.warning( 18 | "Since Taobao dataset is not directly accessible, please visit \ 19 | https://tianchi.aliyun.com/dataset/dataDetail?dataId=649 and \ 20 | download it manually, after downloaded, place file \ 21 | 'UserBehavior.csv.zip' under %s and run this command again", self.rootdir) 22 | else: 23 | extract(self.rootdir.joinpath("UserBehavior.csv.zip"), self.rootdir) 24 | 25 | def transform(self): 26 | df = pd.read_csv(self.rootdir.joinpath(self.__corefile__), 27 | header=None, 28 | index_col=False, 29 | names=["user_id", "item_id", "category_id", "behavior_type", "timestamp"], 30 | usecols=[0, 1, 4]) 31 | return df 32 | -------------------------------------------------------------------------------- /srdatasets/datasets/tmall.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from datetime import datetime 4 | 5 | import pandas as pd 6 | 7 | from srdatasets.datasets.dataset import Dataset 8 | from srdatasets.datasets.utils import extract 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | class Tmall(Dataset): 14 | 15 | __corefile__ = os.path.join("data_format1", "user_log_format1.csv") 16 | 17 | def download(self): 18 | if not self.rootdir.joinpath("data_format1.zip").exists(): 19 | logger.warning( 20 | "Since Tmall dataset is not directly accessible, please visit \ 21 | https://tianchi.aliyun.com/dataset/dataDetail?dataId=47 and \ 22 | download it manually, after downloaded, place file \ 23 | 'data_format1.zip' under %s and run this command again", self.rootdir) 24 | else: 25 | extract(self.rootdir.joinpath("data_format1.zip"), self.rootdir) 26 | 27 | def transform(self): 28 | df = pd.read_csv(self.rootdir.joinpath(self.__corefile__), 29 | header=0, 30 | index_col=False, 31 | usecols=[0, 1, 5], 32 | converters={"time_stamp": lambda x: int(datetime.strptime("2015" + x, "%Y%m%d").timestamp())}) 33 | df = df.rename(columns={"time_stamp": "timestamp"}) 34 | return df 35 | 
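Every dataset class above follows the same two-step contract defined in srdatasets/datasets/dataset.py: download() places the raw files under self.rootdir, and transform() reduces them to a pandas DataFrame with exactly the columns user_id, item_id and timestamp. As a rough sketch of how an additional source could plug into that interface (the class name MyDataset, its URL, and the raw column names "user", "item", "time" are hypothetical, used only for illustration):

import pandas as pd

from srdatasets.datasets.dataset import Dataset
from srdatasets.datasets.utils import download_url, extract


class MyDataset(Dataset):

    # Hypothetical archive URL and core file, for illustration only
    __url__ = "https://example.com/mydataset.zip"
    __corefile__ = "interactions.csv"

    def download(self):
        # Fetch the archive to self.rawpath and unpack it into the raw directory
        download_url(self.__url__, self.rawpath)
        extract(self.rawpath, self.rootdir)

    def transform(self):
        # Reduce the raw log to the general format: user_id, item_id, timestamp
        df = pd.read_csv(self.rootdir.joinpath(self.__corefile__), header=0)
        df = df.rename(columns={"user": "user_id", "item": "item_id", "time": "timestamp"})
        return df[["user_id", "item_id", "timestamp"]]

Anything beyond those three columns is renamed or dropped inside transform(), which keeps the downstream splitting and preprocessing in srdatasets/process.py dataset-agnostic.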
-------------------------------------------------------------------------------- /srdatasets/datasets/utils.py: -------------------------------------------------------------------------------- 1 | import bz2 2 | import gzip 3 | import logging 4 | import os 5 | import shutil 6 | import tarfile 7 | import urllib.request 8 | from zipfile import ZipFile 9 | 10 | from tqdm import tqdm 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class DownloadProgressBar(tqdm): 16 | """ From https://stackoverflow.com/a/53877507/8810037 17 | """ 18 | def update_to(self, b=1, bsize=1, tsize=None): 19 | if tsize is not None: 20 | self.total = tsize 21 | self.update(b * bsize - self.n) 22 | 23 | 24 | def download_url(url, output_path): 25 | try: 26 | with DownloadProgressBar(unit="B", unit_scale=True, miniters=1, desc=url.split("/")[-1]) as t: 27 | urllib.request.urlretrieve(url, filename=output_path, reporthook=t.update_to) 28 | logger.info("Download successful") 29 | except: 30 | logger.exception("Download failed, please try again") 31 | if output_path.exists(): 32 | os.remove(output_path) 33 | 34 | 35 | def extract(filepath, out): 36 | """ out: a file or a directory 37 | """ 38 | logger.info("Unzipping...") 39 | filename = filepath.as_posix() 40 | 41 | if filename.endswith(".zip") or filename.endswith(".ZIP"): 42 | with ZipFile(filepath) as zipObj: 43 | zipObj.extractall(out) 44 | 45 | elif filename.endswith(".tar.gz"): 46 | with tarfile.open(filepath, "r:gz") as tar: 47 | tar.extractall(out) 48 | 49 | elif filename.endswith(".tar.bz2"): 50 | with tarfile.open(filepath, "r:bz2") as tar: 51 | tar.extractall(out) 52 | 53 | elif filename.endswith(".gz"): 54 | with gzip.open(filepath, "rb") as fin: 55 | with open(out, "wb") as fout: 56 | shutil.copyfileobj(fin, fout) 57 | 58 | elif filename.endswith(".bz2"): 59 | with bz2.open(filepath, "rb") as fin: 60 | with open(out, "wb") as fout: 61 | shutil.copyfileobj(fin, fout) 62 | else: 63 | logger.error("Unrecognized compressing format of %s", filepath) 64 | return 65 | 66 | logger.info("OK") 67 | -------------------------------------------------------------------------------- /srdatasets/datasets/yelp.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import pandas as pd 4 | 5 | from srdatasets.datasets.dataset import Dataset 6 | from srdatasets.datasets.utils import extract 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | class Yelp(Dataset): 12 | 13 | __corefile__ = "review.json" 14 | 15 | def download(self): 16 | if not self.rootdir.joinpath("yelp_dataset.tar.gz").exists(): 17 | logger.warning( 18 | "Since Yelp dataset is not directly accessible, please visit \ 19 | https://www.yelp.com/dataset/download and download it manually, \ 20 | after downloaded, place file 'yelp_dataset.tar.gz' \ 21 | under %s and run this command again", self.rootdir) 22 | else: 23 | extract(self.rootdir.joinpath("yelp_dataset.tar.gz"), self.rootdir) 24 | 25 | def transform(self, stars_threshold): 26 | df = pd.read_json(self.rootdir.joinpath(self.__corefile__), orient="records", lines=True) 27 | df = df[["user_id", "business_id", "stars", "date"]] 28 | df["date"] = df["date"].apply(lambda x: int(x.timestamp())) 29 | df = df[df["stars"] >= stars_threshold].drop("stars", axis=1) 30 | df = df.rename(columns={"business_id": "item_id", "date": "timestamp"}) 31 | return df 32 | -------------------------------------------------------------------------------- /srdatasets/download.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | 3 | from srdatasets.utils import __warehouse__ 4 | from srdatasets.datasets import dataset_classes 5 | 6 | 7 | def _download(dataset_name): 8 | _rawdir = __warehouse__.joinpath(dataset_name, "raw") 9 | os.makedirs(_rawdir, exist_ok=True) 10 | 11 | if dataset_name.startswith("Amazon"): 12 | dataset_classes["Amazon"](_rawdir).download(dataset_name.split("-")[1]) 13 | elif dataset_name.startswith("FourSquare"): 14 | dataset_classes["FourSquare"](_rawdir).download() 15 | else: 16 | dataset_classes[dataset_name](_rawdir).download() 17 | -------------------------------------------------------------------------------- /srdatasets/process.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import math 4 | import os 5 | import pickle 6 | import random 7 | import sys 8 | import time 9 | from collections import defaultdict 10 | from datetime import datetime 11 | 12 | from tqdm import tqdm 13 | 14 | from srdatasets.datasets import dataset_classes 15 | from srdatasets.utils import __warehouse__ 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | def _process(args): 21 | if "-" in args.dataset: 22 | classname, sub = args.dataset.split("-") 23 | else: 24 | classname = args.dataset 25 | d = dataset_classes[classname](__warehouse__.joinpath(args.dataset, "raw")) 26 | 27 | config = { 28 | "min_freq_user": args.min_freq_user, 29 | "min_freq_item": args.min_freq_item, 30 | "input_len": args.input_len, 31 | "target_len": args.target_len, 32 | "no_augment": args.no_augment, 33 | "remove_duplicates": args.remove_duplicates, 34 | "session_interval": args.session_interval, 35 | "min_session_len": args.min_session_len, 36 | "max_session_len": args.max_session_len, 37 | "split_by": args.split_by, 38 | "dev_split": args.dev_split, 39 | "test_split": args.test_split, 40 | "task": args.task, 41 | "pre_sessions": args.pre_sessions, 42 | "pick_targets": args.pick_targets 43 | } 44 | if classname in ["Amazon", "MovieLens20M", "Yelp"]: 45 | config["rating_threshold"] = args.rating_threshold 46 | elif classname == "Lastfm1K": 47 | config["item_type"] = args.item_type 48 | 49 | logger.info("Transforming...") 50 | if classname == "Amazon": 51 | df = d.transform(sub, args.rating_threshold) 52 | elif classname in ["MovieLens20M", "Yelp"]: 53 | df = d.transform(args.rating_threshold) 54 | elif classname == "FourSquare": 55 | df = d.transform(sub) 56 | elif classname == "Lastfm1K": 57 | df = d.transform(args.item_type) 58 | else: 59 | df = d.transform() 60 | 61 | if args.split_by == "time": 62 | config["dev_split"], config["test_split"] = access_split_days(df) 63 | # Processed check 64 | if ("time_splits" in args and (config["dev_split"], config["test_split"]) in args.time_splits): 65 | logger.warning("You have run this config, the config id is {}".format( 66 | args.time_splits[(config["dev_split"], config["test_split"])])) 67 | sys.exit(1) 68 | config["max_timestamp"] = df["timestamp"].max() 69 | 70 | preprocess_and_save(df, args.dataset, config) 71 | 72 | 73 | def access_split_days(df): 74 | min_timestamp = df["timestamp"].min() 75 | max_timestamp = df["timestamp"].max() 76 | first_day = datetime.fromtimestamp(min_timestamp).strftime("%Y-%m-%d") 77 | last_day = datetime.fromtimestamp(max_timestamp).strftime("%Y-%m-%d") 78 | total_days = math.ceil((max_timestamp - min_timestamp) / 86400) 79 | print("Date range: {} ~ {}, total days: 
{}".format(first_day, last_day, total_days)) 80 | while True: 81 | try: 82 | test_last_days = int(input("Last N days for test: ")) 83 | dev_last_days = int(input("Last N days for dev: ")) 84 | if test_last_days <= 0 or dev_last_days <= 0: 85 | raise ValueError 86 | elif test_last_days + dev_last_days >= total_days: 87 | raise AssertionError 88 | else: 89 | break 90 | except ValueError: 91 | print("Please input a positive integer!") 92 | except AssertionError: 93 | print("test_last_days + dev_last_days < total_days") 94 | return dev_last_days, test_last_days 95 | 96 | 97 | def preprocess_and_save(df, dname, config): 98 | """General preprocessing method 99 | 100 | Args: 101 | df (DataFrame): columns: `user_id`, `item_id`, `timestamp`. 102 | args (Namespace): arguments. 103 | """ 104 | # Generate sequences 105 | logger.info("Generating user sequences...") 106 | seqs = generate_sequences(df, config) 107 | 108 | # Split sequences in different ways 109 | if config["session_interval"] > 0: 110 | split = split_sequences_session 111 | else: 112 | split = split_sequences 113 | 114 | logger.info("Splitting user sequences into train/test...") 115 | train_seqs, test_seqs = split(seqs, config, 0) 116 | 117 | logger.info("Splitting train into dev-train/dev-test...") 118 | dev_train_seqs, dev_test_seqs = split(train_seqs, config, 1) 119 | 120 | # Remove duplicates (optional) 121 | if config["remove_duplicates"]: 122 | logger.info("Removing duplicates...") 123 | train_seqs, test_seqs, dev_train_seqs, dev_test_seqs = [ 124 | remove_duplicates(seqs, config) for seqs in [train_seqs, test_seqs, dev_train_seqs, dev_test_seqs] 125 | ] 126 | 127 | # Do not use data augmentation (optional) 128 | if config["no_augment"]: 129 | logger.info("Enabling no data augmentation...") 130 | train_seqs, test_seqs, dev_train_seqs, dev_test_seqs = [ 131 | enable_no_augment(seqs, config) for seqs in [train_seqs, test_seqs, dev_train_seqs, dev_test_seqs] 132 | ] 133 | 134 | # Remove unknowns 135 | logger.info("Removing unknowns in test...") 136 | test_seqs = remove_unknowns(train_seqs, test_seqs, config) 137 | 138 | logger.info("Removing unknowns in dev-test...") 139 | dev_test_seqs = remove_unknowns(dev_train_seqs, dev_test_seqs, config) 140 | 141 | # Reassign user and item ids 142 | logger.info("Reassigning ids (train/test)...") 143 | train_seqs, test_seqs = reassign_ids(train_seqs, test_seqs) 144 | 145 | logger.info("Reassigning ids (dev-train/dev-test)...") 146 | dev_train_seqs, dev_test_seqs = reassign_ids(dev_train_seqs, dev_test_seqs) 147 | 148 | # Make datasets based on task 149 | if config["task"] == "short": 150 | make_dataset = make_dataset_short 151 | else: 152 | make_dataset = make_dataset_long_short 153 | 154 | logger.info("Making datasets...") 155 | train_data, test_data, dev_train_data, dev_test_data = [ 156 | make_dataset(seqs, config) for seqs in [train_seqs, test_seqs, dev_train_seqs, dev_test_seqs] 157 | ] 158 | 159 | # Dump to disk 160 | logger.info("Dumping...") 161 | processed_path = __warehouse__.joinpath(dname, "processed", "c" + str(int(time.time() * 1000))) 162 | dump(processed_path, train_data, test_data, 0) 163 | dump(processed_path, dev_train_data, dev_test_data, 1) 164 | 165 | # Save config 166 | save_config(processed_path, config) 167 | logger.info("OK, the config id is: %s", processed_path.stem) 168 | 169 | 170 | def enable_no_augment(seqs, config): 171 | """ 172 | For short-term task: keep most recent (input_len + target_len) items, 173 | For long-short-term task: keep most recent (pre_sessions + 
1) sessions 174 | """ 175 | seqs_ = [] 176 | if config["task"] == "short": 177 | for user_id, seq in tqdm(seqs): 178 | seqs_.append((user_id, seq[-config["input_len"] - config["target_len"]:])) 179 | else: 180 | user_sessions = defaultdict(list) 181 | for user_id, seq in seqs: 182 | user_sessions[user_id].append(seq) 183 | for user_id, sessions in tqdm(user_sessions.items()): 184 | seqs_.extend((user_id, s) for s in sessions[-config["pre_sessions"] - 1:])  # keep the most recent (pre_sessions + 1) sessions 185 | return seqs_ 186 | 187 | 188 | def reassign_ids(train_seqs, test_seqs): 189 | user_to_idx = {} 190 | item_to_idx = {} # starts from 1, 0 for padding 191 | train_seqs_ = [] 192 | test_seqs_ = [] 193 | for user_id, seq in tqdm(train_seqs): 194 | # Build dicts 195 | if user_id not in user_to_idx: 196 | user_to_idx[user_id] = len(user_to_idx) 197 | for item, timestamp in seq: 198 | if item not in item_to_idx: 199 | item_to_idx[item] = len(item_to_idx) + 1 200 | # Reassign 201 | train_seqs_.append((user_to_idx[user_id], [(item_to_idx[i], t) for i, t in seq])) 202 | for user_id, seq in tqdm(test_seqs): 203 | test_seqs_.append((user_to_idx[user_id], [(item_to_idx[i], t) for i, t in seq])) 204 | return train_seqs_, test_seqs_ 205 | 206 | 207 | def generate_sequences(df, config): 208 | logger.info("Dropping items (freq < %s)...", config["min_freq_item"]) 209 | df = drop_items(df, config["min_freq_item"]) 210 | 211 | logger.info("Dropping users (freq < %s)...", config["min_freq_user"]) 212 | df = drop_users(df, config["min_freq_user"]) 213 | 214 | logger.info("Grouping items by user...") 215 | df = df.sort_values("timestamp", ascending=True) 216 | df["item_and_time"] = list(zip(df["item_id"], df["timestamp"])) 217 | seqs = df.groupby("user_id")["item_and_time"].apply(list) 218 | seqs = list(zip(seqs.index, seqs)) 219 | 220 | logger.info("Dropping too short user sequences...") 221 | seqs = [s for s in tqdm(seqs) if len(s[1]) > config["target_len"]] 222 | 223 | if config["session_interval"] > 0: 224 | logger.info("Splitting sessions...") 225 | _seqs = [] 226 | for user_id, seq in tqdm(seqs): 227 | seq_buffer = [] 228 | for i, (item_id, timestamp) in enumerate(seq): 229 | if i == 0: 230 | seq_buffer.append((item_id, timestamp)) 231 | else: 232 | if timestamp - seq[i - 1][1] > config["session_interval"] * 60: 233 | if len(seq_buffer) >= config["min_session_len"]: 234 | _seqs.append((user_id, seq_buffer[-config["max_session_len"]:])) 235 | seq_buffer = [(item_id, timestamp)] 236 | else: 237 | seq_buffer.append((item_id, timestamp)) 238 | if len(seq_buffer) >= config["min_session_len"]: 239 | _seqs.append((user_id, seq_buffer[-config["max_session_len"]:])) 240 | seqs = _seqs 241 | return seqs 242 | 243 | 244 | def split_sequences(user_seq, config, mode): 245 | """ Without sessions 246 | """ 247 | if config["split_by"] == "user": 248 | test_ratio = config["dev_split"] if mode else config["test_split"] 249 | else: 250 | last_days = (config["dev_split"] + config["test_split"] if mode else config["test_split"]) 251 | split_timestamp = config["max_timestamp"] - last_days * 86400 252 | train_seqs = [] 253 | test_seqs = [] 254 | for user_id, seq in tqdm(user_seq): 255 | train_num = 0 256 | if config["split_by"] == "user": 257 | train_num = math.ceil(len(seq) * (1 - test_ratio)) 258 | else: 259 | for item, timestamp in seq: 260 | if timestamp < split_timestamp: 261 | train_num += 1 262 | if train_num > config["target_len"]: 263 | train_seqs.append((user_id, seq[:train_num])) 264 | if len(seq) - train_num > config["target_len"]: 265 | 
test_seqs.append((user_id, seq[train_num:])) 266 | return train_seqs, test_seqs 267 | 268 | 269 | def split_sequences_session(user_seq, config, mode): 270 | """ With sessions, when number of sessions is small, len of test_seqs can be 0 271 | """ 272 | if config["split_by"] == "user": 273 | test_ratio = config["dev_split"] if mode else config["test_split"] 274 | else: 275 | last_days = (config["dev_split"] + config["test_split"] if mode else config["test_split"]) 276 | split_timestamp = config["max_timestamp"] - last_days * 86400 277 | user_sessions = defaultdict(list) 278 | for user_id, seq in user_seq: 279 | user_sessions[user_id].append(seq) 280 | train_seqs = [] 281 | test_seqs = [] 282 | for user_id, sessions in tqdm(user_sessions.items()): 283 | if config["split_by"] == "user": 284 | train_num = math.ceil((1 - test_ratio) * len(sessions)) 285 | else: 286 | train_num = 0 287 | for s in sessions: 288 | if s[0][1] < split_timestamp: 289 | train_num += 1 290 | if train_num > 0: 291 | train_seqs.extend((user_id, s) for s in sessions[:train_num]) 292 | test_seqs.extend((user_id, s) for s in sessions[train_num:]) 293 | return train_seqs, test_seqs 294 | 295 | 296 | def remove_duplicates(user_seq, config): 297 | """ By default, we keep the first 298 | """ 299 | user_seq_ = [] 300 | for user_id, seq in tqdm(user_seq): 301 | seq_ = [] 302 | shown_items = set() 303 | for item, timestamp in seq: 304 | if item not in shown_items: 305 | shown_items.add(item) 306 | seq_.append((item, timestamp)) 307 | if config["session_interval"] > 0: 308 | if len(seq_) >= config["min_session_len"]: 309 | user_seq_.append((user_id, seq_)) 310 | else: 311 | if len(seq_) > config["target_len"]: 312 | user_seq_.append((user_id, seq_)) 313 | return user_seq_ 314 | 315 | 316 | def remove_unknowns(train_seqs, test_seqs, config): 317 | """ Remove users and items in test_seqs that are not shown in train_seqs 318 | """ 319 | users = set() 320 | items = set() 321 | for user_id, seq in train_seqs: 322 | users.add(user_id) 323 | items.update([i for i, t in seq]) 324 | test_seqs_ = [] 325 | for user_id, seq in tqdm(test_seqs): 326 | if user_id in users: 327 | seq_ = [(i, t) for i, t in seq if i in items] 328 | if config["session_interval"] > 0: 329 | if len(seq_) >= config["min_session_len"]: 330 | test_seqs_.append((user_id, seq_)) 331 | else: 332 | if len(seq_) > config["target_len"]: 333 | test_seqs_.append((user_id, seq_)) 334 | return test_seqs_ 335 | 336 | 337 | def make_targets(seq, config): 338 | """ For long-short-term task 339 | """ 340 | if config["pick_targets"] == "random": 341 | indices = list(range(len(seq))) 342 | random.shuffle(indices) 343 | cur_session_indices = sorted(indices[config["target_len"]:]) 344 | target_indices = sorted(indices[:config["target_len"]]) 345 | cur_session = [seq[i] for i in cur_session_indices] 346 | targets = [seq[i] for i in target_indices] 347 | else: 348 | cur_session = seq[:-config["target_len"]] 349 | targets = seq[-config["target_len"]:] 350 | # Padding 351 | cur_session = [(0, -1)] * (config["max_session_len"] - config["target_len"] - len(cur_session)) + cur_session 352 | return cur_session, targets 353 | 354 | 355 | def make_dataset_long_short(user_seq, config): 356 | """ 357 | len of pre_sessions: max_session_len * pre_sessions 358 | len of cur_session: max_session_len - target_len 359 | """ 360 | max_session_len = config["max_session_len"] 361 | n_pre_sessions = config["pre_sessions"] 362 | dataset = [] 363 | user_sessions = defaultdict(list) 364 | for user_id, seq in 
user_seq: 365 | user_sessions[user_id].append(seq) 366 | for user_id, sessions in tqdm(user_sessions.items()): 367 | d = len(sessions) - 1 - n_pre_sessions 368 | if d <= 0: 369 | pre_sessions = [(0, -1)] * max_session_len * (-d) 370 | for s in sessions[:-1]: 371 | pre_sessions += [(0, -1)] * (max_session_len - len(s)) + s 372 | cur_session, targets = make_targets(sessions[-1], config) 373 | dataset.append((user_id, pre_sessions, cur_session, targets)) 374 | else: 375 | for i in range(d): 376 | pre_sessions = [] 377 | for s in sessions[i:i + n_pre_sessions]: 378 | pre_sessions += [(0, -1)] * (max_session_len - len(s)) + s 379 | cur_session, targets = make_targets(sessions[i + n_pre_sessions], config) 380 | dataset.append((user_id, pre_sessions, cur_session, targets)) 381 | dataset_ = [] 382 | for data in dataset: 383 | pre_items, pre_times = list(zip(*data[1])) 384 | cur_items, cur_times = list(zip(*data[2])) 385 | target_items, target_times = list(zip(*data[3])) 386 | dataset_.append((data[0], pre_items, cur_items, target_items, pre_times, cur_times, target_times)) 387 | return dataset_ 388 | 389 | 390 | def make_dataset_short(user_seq, config): 391 | """ Build dataset for short-term task 392 | """ 393 | input_len = config["input_len"] 394 | target_len = config["target_len"] 395 | dataset = [] 396 | for user_id, seq in tqdm(user_seq): 397 | if len(seq) <= input_len + target_len: 398 | padding_num = input_len + target_len - len(seq) 399 | dataset.append((user_id, [(0, -1)] * padding_num + seq[:-target_len], seq[-target_len:])) 400 | else: 401 | augmented_seqs = [(user_id, seq[i:i + input_len], seq[i + input_len:i + input_len + target_len]) 402 | for i in range(len(seq) - input_len - target_len + 1)] 403 | dataset.extend(augmented_seqs) 404 | dataset_ = [] 405 | for data in dataset: 406 | input_items, input_times = list(zip(*data[1])) 407 | target_items, target_times = list(zip(*data[2])) 408 | dataset_.append((data[0], input_items, target_items, input_times, target_times)) 409 | return dataset_ 410 | 411 | 412 | def cal_stats(train_data, test_data): 413 | users = set() 414 | items = set() 415 | interactions = 0 416 | for data in train_data: 417 | users.add(data[0]) 418 | if len(data) > 5: 419 | items_ = data[1] + data[2] + data[3] 420 | else: 421 | items_ = data[1] + data[2] 422 | for item in items_: 423 | if item > 0: # reassigned 424 | items.add(item) 425 | interactions += 1 426 | stats = { 427 | "users": len(users), 428 | "items": len(items), 429 | "interactions": interactions, 430 | "density": interactions / len(users) / len(items), 431 | "train size": len(train_data), 432 | "test size": len(test_data) 433 | } 434 | return stats 435 | 436 | 437 | def drop_users(df, min_freq): 438 | counts = df["user_id"].value_counts() 439 | df = df[df["user_id"].isin(counts[counts >= min_freq].index)] 440 | return df 441 | 442 | 443 | def drop_items(df, min_freq): 444 | counts = df["item_id"].value_counts() 445 | df = df[df["item_id"].isin(counts[counts >= min_freq].index)] 446 | return df 447 | 448 | 449 | def save_config(path, config): 450 | if "max_timestamp" in config: 451 | del config["max_timestamp"] 452 | with open(path.joinpath("config.json"), "w") as f: 453 | json.dump(config, f) 454 | 455 | 456 | def dump(path, train_data, test_data, mode): 457 | """ Save preprocessed datasets """ 458 | dirname = "dev" if mode else "test" 459 | os.makedirs(path.joinpath(dirname)) 460 | with open(path.joinpath(dirname, "train.pkl"), "wb") as f: 461 | pickle.dump(train_data, f) 462 | with 
open(path.joinpath(dirname, "test.pkl"), "wb") as f: 463 | pickle.dump(test_data, f) 464 | stats = cal_stats(train_data, test_data) 465 | with open(path.joinpath(dirname, "stats.json"), "w") as f: 466 | json.dump(stats, f) 467 | 468 | 469 | # ====== TODO API for custom dataset ====== # 470 | -------------------------------------------------------------------------------- /srdatasets/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | from pathlib import Path 5 | 6 | from srdatasets.datasets import __datasets__, dataset_classes 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | __warehouse__ = Path(os.path.expanduser("~")).joinpath(".srdatasets") 11 | 12 | 13 | def get_processed_datasets(): 14 | P = [ 15 | "dev/train.pkl", 16 | "dev/test.pkl", 17 | "dev/stats.json", 18 | "test/train.pkl", 19 | "test/test.pkl", 20 | "test/stats.json", 21 | "config.json", 22 | ] 23 | D = {} 24 | if __warehouse__.exists(): 25 | for d in __warehouse__.iterdir(): # loop datasets 26 | if d.joinpath("processed").exists(): 27 | configs = [] 28 | for c in d.joinpath("processed").iterdir(): # loop configs 29 | if all(c.joinpath(p).exists() for p in P): 30 | configs.append(c.stem) 31 | if configs: 32 | D[d.stem] = configs 33 | return D 34 | 35 | 36 | def get_downloaded_datasets(): 37 | """ Simple check based on the existences of corefiles 38 | """ 39 | D = [] 40 | for d in __datasets__: 41 | if "-" in d: 42 | corefile = dataset_classes[d.split("-")[0]].__corefile__[d.split("-")[1]] 43 | else: 44 | corefile = dataset_classes[d].__corefile__ 45 | if isinstance(corefile, list): 46 | if all(__warehouse__.joinpath(d, "raw", cf).exists() for cf in corefile): 47 | D.append(d) 48 | else: 49 | if __warehouse__.joinpath(d, "raw", corefile).exists(): 50 | D.append(d) 51 | return D 52 | 53 | 54 | def read_json(path): 55 | content = {} 56 | if path.exists(): 57 | with open(path, "r") as f: 58 | try: 59 | content = json.load(f) 60 | except: 61 | logger.exception("Read json file failed") 62 | return content 63 | 64 | 65 | def get_datasetname(name): 66 | return {d.lower(): d for d in __datasets__}.get(name.lower(), name) 67 | -------------------------------------------------------------------------------- /tests/datasets/test_amazon.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from srdatasets.datasets import Amazon 4 | from srdatasets.utils import __warehouse__ 5 | 6 | 7 | def test_download_and_trandform(): 8 | rawdir = __warehouse__.joinpath("Amazon", "raw") 9 | os.makedirs(rawdir, exist_ok=True) 10 | amazon = Amazon(rawdir) 11 | category = "Pet" 12 | amazon.download(category) 13 | assert rawdir.joinpath(amazon.__corefile__[category]).exists() 14 | df = amazon.transform(category, 4) 15 | assert all(c in df.columns for c in ["user_id", "item_id", "timestamp"]) 16 | assert len(df) > 0 17 | -------------------------------------------------------------------------------- /tests/datasets/test_citeulike.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from srdatasets.datasets import CiteULike 4 | from srdatasets.utils import __warehouse__ 5 | 6 | 7 | def test_download_and_trandform(): 8 | rawdir = __warehouse__.joinpath("CiteULike", "raw") 9 | os.makedirs(rawdir, exist_ok=True) 10 | citeulike = CiteULike(rawdir) 11 | citeulike.download() 12 | assert rawdir.joinpath(citeulike.__corefile__).exists() 13 | df = 
citeulike.transform() 14 | assert all(c in df.columns for c in ["user_id", "item_id", "timestamp"]) 15 | assert len(df) > 0 16 | -------------------------------------------------------------------------------- /tests/datasets/test_foursquare.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from srdatasets.datasets import FourSquare 4 | from srdatasets.utils import __warehouse__ 5 | 6 | 7 | def test_download_and_trandform(): 8 | rawdir = __warehouse__.joinpath("FourSquare", "raw") 9 | os.makedirs(rawdir, exist_ok=True) 10 | foursquare = FourSquare(rawdir) 11 | cities = ["NYC", "Tokyo"] 12 | foursquare.download() 13 | for c in cities: 14 | assert rawdir.joinpath(foursquare.__corefile__[c]).exists() 15 | for c in cities: 16 | df = foursquare.transform(c) 17 | assert all(c in df.columns for c in ["user_id", "item_id", "timestamp"]) 18 | assert len(df) > 0 19 | -------------------------------------------------------------------------------- /tests/datasets/test_gowalla.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from srdatasets.datasets import Gowalla 4 | from srdatasets.utils import __warehouse__ 5 | 6 | 7 | def test_download_and_trandform(): 8 | rawdir = __warehouse__.joinpath("Gowalla", "raw") 9 | os.makedirs(rawdir, exist_ok=True) 10 | gowalla = Gowalla(rawdir) 11 | gowalla.download() 12 | assert rawdir.joinpath(gowalla.__corefile__).exists() 13 | df = gowalla.transform() 14 | assert all(c in df.columns for c in ["user_id", "item_id", "timestamp"]) 15 | assert len(df) > 0 16 | -------------------------------------------------------------------------------- /tests/datasets/test_lastfm1k.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from srdatasets.datasets import Lastfm1K 4 | from srdatasets.utils import __warehouse__ 5 | 6 | 7 | def test_download_and_trandform(): 8 | rawdir = __warehouse__.joinpath("Lastfm1K", "raw") 9 | os.makedirs(rawdir, exist_ok=True) 10 | lastfm1k = Lastfm1K(rawdir) 11 | lastfm1k.download() 12 | assert rawdir.joinpath(lastfm1k.__corefile__).exists() 13 | df = lastfm1k.transform("song") 14 | assert all(c in df.columns for c in ["user_id", "item_id", "timestamp"]) 15 | assert len(df) > 0 16 | -------------------------------------------------------------------------------- /tests/datasets/test_movielens20m.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from srdatasets.datasets import MovieLens20M 4 | from srdatasets.utils import __warehouse__ 5 | 6 | 7 | def test_download_and_trandform(): 8 | rawdir = __warehouse__.joinpath("MovieLens20M", "raw") 9 | os.makedirs(rawdir, exist_ok=True) 10 | movielens20m = MovieLens20M(rawdir) 11 | movielens20m.download() 12 | assert rawdir.joinpath(movielens20m.__corefile__).exists() 13 | df = movielens20m.transform(4) 14 | assert all(c in df.columns for c in ["user_id", "item_id", "timestamp"]) 15 | assert len(df) > 0 16 | -------------------------------------------------------------------------------- /tests/datasets/test_tafeng.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from srdatasets.datasets import TaFeng 4 | from srdatasets.utils import __warehouse__ 5 | 6 | 7 | def test_download_and_trandform(): 8 | rawdir = __warehouse__.joinpath("TaFeng", "raw") 9 | os.makedirs(rawdir, exist_ok=True) 10 | tafeng = TaFeng(rawdir) 11 | 
tafeng.download() 12 | assert all(rawdir.joinpath(cf).exists() for cf in tafeng.__corefile__) 13 | df = tafeng.transform() 14 | assert all(c in df.columns for c in ["user_id", "item_id", "timestamp"]) 15 | assert len(df) > 0 16 | -------------------------------------------------------------------------------- /tests/test_dataloader.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import shutil 3 | from argparse import Namespace 4 | 5 | import srdatasets.dataloader 6 | import srdatasets.dataloader_pytorch 7 | from srdatasets.download import _download 8 | from srdatasets.process import _process 9 | from srdatasets.utils import __warehouse__, get_downloaded_datasets, get_processed_datasets 10 | 11 | args = Namespace(dataset="FourSquare-NYC", 12 | min_freq_item=10, 13 | min_freq_user=10, 14 | task="short", 15 | split_by="user", 16 | dev_split=0.1, 17 | test_split=0.2, 18 | input_len=9, 19 | target_len=1, 20 | session_interval=0, 21 | max_session_len=10, 22 | min_session_len=2, 23 | pre_sessions=10, 24 | pick_targets="random", 25 | no_augment=False, 26 | remove_duplicates=False) 27 | 28 | if args.dataset not in get_downloaded_datasets(): 29 | _download(args.dataset) 30 | 31 | if args.dataset in get_processed_datasets(): 32 | shutil.rmtree(__warehouse__.joinpath(args.dataset, "processed")) 33 | 34 | # For short term task 35 | short_args = copy.deepcopy(args) 36 | _process(short_args) 37 | 38 | # For long-short term task 39 | long_short_args = copy.deepcopy(args) 40 | long_short_args.task = "long-short" 41 | long_short_args.session_interval = 60 42 | _process(long_short_args) 43 | 44 | 45 | def test_dataloader(): 46 | config_ids = get_processed_datasets()[args.dataset] 47 | for cid in config_ids: 48 | for DataLoader in [srdatasets.dataloader.DataLoader, srdatasets.dataloader_pytorch.DataLoader]: 49 | dataloader = DataLoader(args.dataset, 50 | cid, 51 | batch_size=32, 52 | negatives_per_target=5, 53 | include_timestamp=True, 54 | drop_last=True) 55 | if len(dataloader.dataset[0]) > 5: 56 | for users, pre_sess_items, cur_sess_items, target_items, pre_sess_timestamps, cur_sess_timestamps, \ 57 | target_timestamps, negatives in dataloader: 58 | assert users.shape == (32, ) 59 | assert pre_sess_items.shape == (32, args.pre_sessions * args.max_session_len) 60 | assert cur_sess_items.shape == (32, args.max_session_len - args.target_len) 61 | assert target_items.shape == (32, args.target_len) 62 | assert pre_sess_timestamps.shape == (32, args.pre_sessions * args.max_session_len) 63 | assert cur_sess_timestamps.shape == (32, args.max_session_len - args.target_len) 64 | assert target_timestamps.shape == (32, args.target_len) 65 | assert negatives.shape == (32, args.target_len, 5) 66 | else: 67 | for users, in_items, out_items, in_timestamps, out_timestamps, negatives in dataloader: 68 | assert users.shape == (32, ) 69 | assert in_items.shape == (32, args.input_len) 70 | assert out_items.shape == (32, args.target_len) 71 | assert in_timestamps.shape == (32, args.input_len) 72 | assert out_timestamps.shape == (32, args.target_len) 73 | assert negatives.shape == (32, args.target_len, 5) 74 | 75 | 76 | # TODO Test Pytorch version DataLoader 77 | -------------------------------------------------------------------------------- /tests/test_process.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from argparse import Namespace 3 | 4 | from srdatasets.download import _download 5 | from srdatasets.process 
import _process 6 | from srdatasets.utils import get_downloaded_datasets 7 | 8 | # ===== Integration testing ===== 9 | 10 | args = Namespace(dataset="FourSquare-NYC", 11 | min_freq_item=10, 12 | min_freq_user=10, 13 | task="short", 14 | split_by="user", 15 | dev_split=0.1, 16 | test_split=0.2, 17 | input_len=9, 18 | target_len=1, 19 | session_interval=0, 20 | max_session_len=10, 21 | min_session_len=2, 22 | pre_sessions=10, 23 | pick_targets="random", 24 | no_augment=False, 25 | remove_duplicates=False) 26 | 27 | if args.dataset not in get_downloaded_datasets(): 28 | _download(args.dataset) 29 | 30 | 31 | def test_process_short_user(): 32 | local_args = copy.deepcopy(args) 33 | _process(local_args) 34 | 35 | 36 | def test_process_short_user_session(): 37 | local_args = copy.deepcopy(args) 38 | local_args.session_interval = 60 39 | _process(local_args) 40 | 41 | 42 | def test_process_short_time(monkeypatch): 43 | monkeypatch.setattr("builtins.input", lambda prompt="": 10) 44 | local_args = copy.deepcopy(args) 45 | local_args.split_by = "time" 46 | _process(local_args) 47 | 48 | 49 | def test_process_short_time_session(monkeypatch): 50 | monkeypatch.setattr("builtins.input", lambda prompt="": 10) 51 | local_args = copy.deepcopy(args) 52 | local_args.split_by = "time" 53 | local_args.session_interval = 60 54 | _process(local_args) 55 | 56 | 57 | def test_process_long_short_user(): 58 | local_args = copy.deepcopy(args) 59 | local_args.session_interval = 60 60 | local_args.task = "long-short" 61 | _process(local_args) 62 | 63 | 64 | def test_process_long_short_time(monkeypatch): 65 | monkeypatch.setattr("builtins.input", lambda prompt="": 10) 66 | local_args = copy.deepcopy(args) 67 | local_args.split_by = "time" 68 | local_args.session_interval = 60 69 | local_args.task = "long-short" 70 | _process(local_args) 71 | 72 | 73 | def test_no_augment_and_remove_duplicates(): 74 | local_args = copy.deepcopy(args) 75 | local_args.no_augment = True 76 | local_args.remove_duplicates = True 77 | _process(local_args) 78 | 79 | 80 | # ===== TODO Unit testing ===== 81 | --------------------------------------------------------------------------------
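Taken together, the tests trace the intended workflow: download the raw data, build one processed config, then stream mini-batches. Below is a rough usage sketch in the same spirit; it calls the internal helpers exactly as the tests do, the batch size and negative-sample count are arbitrary example values, and in normal use the package's command-line entry point (srdatasets/__main__.py) is meant to drive the download and process steps.

from argparse import Namespace

from srdatasets.dataloader import DataLoader
from srdatasets.download import _download
from srdatasets.process import _process
from srdatasets.utils import get_downloaded_datasets, get_processed_datasets

# Same short-term configuration the tests use
args = Namespace(dataset="FourSquare-NYC", min_freq_item=10, min_freq_user=10,
                 task="short", split_by="user", dev_split=0.1, test_split=0.2,
                 input_len=9, target_len=1, session_interval=0, max_session_len=10,
                 min_session_len=2, pre_sessions=10, pick_targets="random",
                 no_augment=False, remove_duplicates=False)

if args.dataset not in get_downloaded_datasets():
    _download(args.dataset)  # raw files land under ~/.srdatasets/FourSquare-NYC/raw
_process(args)               # writes dev/ and test/ splits plus config.json and stats.json

config_id = get_processed_datasets()[args.dataset][0]
loader = DataLoader(args.dataset, config_id, batch_size=32,
                    negatives_per_target=5, include_timestamp=True, drop_last=True)
for users, in_items, out_items, in_timestamps, out_timestamps, negatives in loader:
    pass  # users: (32,), in_items: (32, 9), out_items: (32, 1), negatives: (32, 1, 5)

For the long-short task (task="long-short" with session_interval > 0) the loader instead yields pre-session, current-session and target tensors plus their timestamps and negatives, as exercised in tests/test_dataloader.py.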