├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ ├── bug_report_CN.md │ ├── feature_request.md │ └── feature_request_CN.md └── workflows │ └── python-package.yml ├── .gitignore ├── LICENSE ├── README.md ├── asset ├── arch.png ├── diginetica.png ├── ml-1m.png └── recbole-gnn-logo.png ├── recbole_gnn ├── config.py ├── data │ ├── __init__.py │ ├── dataloader.py │ ├── dataset.py │ └── transform.py ├── model │ ├── abstract_recommender.py │ ├── general_recommender │ │ ├── __init__.py │ │ ├── directau.py │ │ ├── hmlet.py │ │ ├── lightgcl.py │ │ ├── lightgcn.py │ │ ├── ncl.py │ │ ├── ngcf.py │ │ ├── sgl.py │ │ ├── simgcl.py │ │ ├── ssl4rec.py │ │ └── xsimgcl.py │ ├── layers.py │ ├── sequential_recommender │ │ ├── __init__.py │ │ ├── gcegnn.py │ │ ├── gcsan.py │ │ ├── lessr.py │ │ ├── niser.py │ │ ├── sgnnhn.py │ │ ├── srgnn.py │ │ └── tagnn.py │ └── social_recommender │ │ ├── __init__.py │ │ ├── diffnet.py │ │ ├── mhcn.py │ │ └── sept.py ├── properties │ ├── model │ │ ├── DiffNet.yaml │ │ ├── DirectAU.yaml │ │ ├── GCEGNN.yaml │ │ ├── GCSAN.yaml │ │ ├── HMLET.yaml │ │ ├── LESSR.yaml │ │ ├── LightGCL.yaml │ │ ├── LightGCN.yaml │ │ ├── MHCN.yaml │ │ ├── NCL.yaml │ │ ├── NGCF.yaml │ │ ├── NISER.yaml │ │ ├── SEPT.yaml │ │ ├── SGL.yaml │ │ ├── SGNNHN.yaml │ │ ├── SRGNN.yaml │ │ ├── SSL4REC.yaml │ │ ├── SimGCL.yaml │ │ ├── TAGNN.yaml │ │ └── XSimGCL.yaml │ └── quick_start_config │ │ ├── sequential_base.yaml │ │ └── social_base.yaml ├── quick_start.py ├── trainer.py └── utils.py ├── results ├── README.md ├── general │ └── ml-1m.md ├── sequential │ └── diginetica.md └── social │ └── lastfm.md ├── run_hyper.py ├── run_recbole_gnn.py ├── run_test.sh └── tests ├── test_data └── test │ ├── test.inter │ └── test.net ├── test_model.py └── test_model.yaml /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: "[\U0001F41BBUG] Describe your problem in one sentence." 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. extra yaml file 16 | 2. your code 17 | 3. script for running 18 | 19 | **Expected behavior** 20 | A clear and concise description of what you expected to happen. 21 | 22 | **Screenshots** 23 | If applicable, add screenshots to help explain your problem. 24 | 25 | **Colab Links** 26 | If applicable, add links to Colab or other Jupyter laboratory platforms that can reproduce the bug. 27 | 28 | **Desktop (please complete the following information):** 29 | - OS: [e.g. Linux, macOS or Windows] 30 | - RecBole Version [e.g. 0.1.0] 31 | - Python Version [e.g. 3.79] 32 | - PyTorch Version [e.g. 1.60] 33 | - cudatoolkit Version [e.g. 9.2, none] 34 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report_CN.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug 报告 3 | about: 提交一份 bug 报告,帮助 RecBole-GNN 变得更好 4 | title: "[\U0001F41BBUG] 用一句话描述您的问题。" 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | **描述这个 bug** 11 | 对 bug 作一个清晰简明的描述。 12 | 13 | **如何复现** 14 | 复现这个 bug 的步骤: 15 | 1. 您引入的额外 yaml 文件 16 | 2. 您的代码 17 | 3. 
您的运行脚本 18 | 19 | **预期** 20 | 对您的预期作清晰简明的描述。 21 | 22 | **屏幕截图** 23 | 添加屏幕截图以帮助解释您的问题。(可选) 24 | 25 | **链接** 26 | 添加能够复现 bug 的代码链接,如 Colab 或者其他在线 Jupyter 平台。(可选) 27 | 28 | **实验环境(请补全下列信息):** 29 | - 操作系统: [如 Linux, macOS 或 Windows] 30 | - RecBole 版本 [如 0.1.0] 31 | - Python 版本 [如 3.79] 32 | - PyTorch 版本 [如 1.60] 33 | - cudatoolkit 版本 [如 9.2, none] 34 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: "[\U0001F4A1SUG] Description of what you want to happen in one sentence" 5 | labels: enhancement 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request_CN.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: 请求添加新功能 3 | about: 提出一个关于本项目新功能/新特性的建议 4 | title: "[\U0001F4A1SUG] 一句话描述您希望新增的功能或特性" 5 | labels: enhancement 6 | assignees: '' 7 | 8 | --- 9 | 10 | **您希望添加的功能是否与某个问题相关?** 11 | 关于这个问题的简洁清晰的描述,例如,当 [...] 时,我总是很沮丧。 12 | 13 | **描述您希望的解决方案** 14 | 关于解决方案的简洁清晰的描述。 15 | 16 | **描述您考虑的替代方案** 17 | 关于您考虑的,能实现这个功能的其他替代方案的简洁清晰的描述。 18 | 19 | **其他** 20 | 您可以添加其他任何的资料、链接或者屏幕截图,以帮助我们理解这个新功能。 21 | -------------------------------------------------------------------------------- /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | name: RecBole-GNN tests 2 | 3 | # Controls when the action will run. 
4 | on: 5 | # Triggers the workflow on push or pull request events but only for the master branch 6 | push: 7 | pull_request: 8 | 9 | # Allows you to run this workflow manually from the Actions tab 10 | workflow_dispatch: 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | strategy: 17 | matrix: 18 | python-version: [3.9] 19 | torch-version: [2.0.0] 20 | defaults: 21 | run: 22 | shell: bash -l {0} 23 | 24 | steps: 25 | - uses: actions/checkout@v2 26 | - name: Setup Miniconda 27 | uses: conda-incubator/setup-miniconda@v2 28 | with: 29 | python-version: ${{ matrix.python-version }} 30 | channels: conda-forge 31 | channel-priority: true 32 | auto-activate-base: true 33 | # install setuptools as a interim solution for bugs in PyTorch 1.10.2 (#69904) 34 | - name: Install dependencies 35 | run: | 36 | python -m pip install --upgrade pip 37 | pip install pytest 38 | pip install dgl 39 | pip install torch==${{ matrix.torch-version}}+cpu -f https://download.pytorch.org/whl/torch_stable.html 40 | pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-${{ matrix.torch-version }}+cpu.html 41 | pip install recbole==1.1.1 42 | conda install -c conda-forge faiss-cpu 43 | # Use "python -m pytest" instead of "pytest" to fix imports 44 | - name: Test model 45 | run: | 46 | python -m pytest -v tests/test_model.py 47 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # RecBole 132 | log_tensorboard/ 133 | saved/ 134 | dataset/ 135 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 RUCAIBox 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RecBole-GNN 2 | 3 | ![](asset/recbole-gnn-logo.png) 4 | 5 | ----- 6 | 7 | *Updates*: 8 | 9 | * [Oct 29, 2023] Add [SSL4Rec](https://github.com/RUCAIBox/RecBole-GNN/blob/main/recbole_gnn/model/general_recommender/ssl4rec.py). (https://github.com/RUCAIBox/RecBole-GNN/pull/76, by [@downeykking](https://github.com/downeykking)) 10 | * [Oct 23, 2023] Add sparse tensor support, accelerating LightGCN & NGCF by ~5x, with 1/6 GPU memories. (https://github.com/RUCAIBox/RecBole-GNN/pull/75, by [@downeykking](https://github.com/downeykking)) 11 | * [Oct 20, 2023] Add [DirectAU](https://github.com/RUCAIBox/RecBole-GNN/blob/main/recbole_gnn/model/general_recommender/directau.py). (https://github.com/RUCAIBox/RecBole-GNN/pull/74, by [@downeykking](https://github.com/downeykking)) 12 | * [Oct 16, 2023] Add [XSimGCL](https://github.com/RUCAIBox/RecBole-GNN/blob/main/recbole_gnn/model/general_recommender/xsimgcl.py). (https://github.com/RUCAIBox/RecBole-GNN/pull/72, by [@downeykking](https://github.com/downeykking)) 13 | * [Apr 12, 2023] Add [LightGCL](https://github.com/RUCAIBox/RecBole-GNN/blob/main/recbole_gnn/model/general_recommender/lightgcl.py). (https://github.com/RUCAIBox/RecBole-GNN/pull/63, by [@wending0417](https://github.com/wending0417)) 14 | * [Oct 29, 2022] Adaptation to RecBole 1.1.1. 
(https://github.com/RUCAIBox/RecBole-GNN/pull/53) 15 | * [Jun 15, 2022] Add [MultiBehaviorDataset](https://github.com/RUCAIBox/RecBole-GNN/blob/8c61463451b294dce9af2d1939a5e054f7955e0f/recbole_gnn/data/dataset.py#L145). (https://github.com/RUCAIBox/RecBole-GNN/pull/43, by [@Tokkiu](https://github.com/Tokkiu)) 16 | 17 | ----- 18 | 19 | **RecBole-GNN** is a library built upon [PyTorch](https://pytorch.org) and [RecBole](https://github.com/RUCAIBox/RecBole) for reproducing and developing recommendation algorithms based on graph neural networks (GNNs). Our library includes algorithms covering three major categories: 20 | * **General Recommendation** with user-item interaction graphs; 21 | * **Sequential Recommendation** with session/sequence graphs; 22 | * **Social Recommendation** with social networks. 23 | 24 | ![](asset/arch.png) 25 | 26 | ## Highlights 27 | 28 | * **Easy-to-use and unified API**: 29 | Our library shares unified API and input (atomic files) as RecBole. 30 | * **Efficient and reusable graph processing**: 31 | We provide highly efficient and reusable basic datasets, dataloaders and layers for graph processing and learning. 32 | * **Extensive graph library**: 33 | Graph neural networks from widely-used library like [PyG](https://github.com/pyg-team/pytorch_geometric) are incorporated. Recently proposed graph algorithms can be easily equipped and compared with existing methods. 34 | 35 | ## Requirements 36 | 37 | ``` 38 | recbole==1.1.1 39 | pyg>=2.0.4 40 | pytorch>=1.7.0 41 | python>=3.7.0 42 | ``` 43 | 44 | > If you are using `recbole==1.0.1`, please refer to our `recbole1.0.1` branch [[link]](https://github.com/hyp1231/RecBole-GNN/tree/recbole1.0.1). 45 | 46 | ## Quick-Start 47 | 48 | With the source code, you can use the provided script for initial usage of our library: 49 | 50 | ```bash 51 | python run_recbole_gnn.py 52 | ``` 53 | 54 | If you want to change the models or datasets, just run the script by setting additional command parameters: 55 | 56 | ```bash 57 | python run_recbole_gnn.py -m [model] -d [dataset] 58 | ``` 59 | 60 | ## Implemented Models 61 | 62 | We list currently supported models according to category: 63 | 64 | **General Recommendation**: 65 | 66 | * **[NGCF](recbole_gnn/model/general_recommender/ngcf.py)** from Wang *et al.*: [Neural Graph Collaborative Filtering](https://arxiv.org/abs/1905.08108) (SIGIR 2019). 67 | * **[LightGCN](recbole_gnn/model/general_recommender/lightgcn.py)** from He *et al.*: [LightGCN: Simplifying and Powering Graph Convolution Network for Recommendation](https://arxiv.org/abs/2002.02126) (SIGIR 2020). 68 | * **[SSL4Rec](recbole_gnn/model/general_recommender/ssl4rec.py)** from Yao *et al.*: [Self-supervised Learning for Large-scale Item Recommendations](https://arxiv.org/abs/2007.12865) (CIKM 2021). 69 | * **[SGL](recbole_gnn/model/general_recommender/sgl.py)** from Wu *et al.*: [Self-supervised Graph Learning for Recommendation](https://arxiv.org/abs/2010.10783) (SIGIR 2021). 70 | * **[HMLET](recbole_gnn/model/general_recommender/hmlet.py)** from Kong *et al.*: [Linear, or Non-Linear, That is the Question!](https://arxiv.org/abs/2111.07265) (WSDM 2022). 71 | * **[NCL](recbole_gnn/model/general_recommender/ncl.py)** from Lin *et al.*: [Improving Graph Collaborative Filtering with Neighborhood-enriched Contrastive Learning](https://arxiv.org/abs/2202.06200) (TheWebConf 2022). 
72 | * **[DirectAU](recbole_gnn/model/general_recommender/directau.py)** from Wang *et al.*: [Towards Representation Alignment and Uniformity in Collaborative Filtering](https://arxiv.org/abs/2206.12811) (KDD 2022). 73 | * **[SimGCL](recbole_gnn/model/general_recommender/simgcl.py)** from Yu *et al.*: [Are Graph Augmentations Necessary? Simple Graph Contrastive Learning for Recommendation](https://arxiv.org/abs/2112.08679) (SIGIR 2022). 74 | * **[XSimGCL](recbole_gnn/model/general_recommender/xsimgcl.py)** from Yu *et al.*: [XSimGCL: Towards Extremely Simple Graph Contrastive Learning for Recommendation](https://arxiv.org/abs/2209.02544) (TKDE 2023). 75 | * **[LightGCL](recbole_gnn/model/general_recommender/lightgcl.py)** from Cai *et al.*: [LightGCL: Simple Yet Effective Graph Contrastive Learning for Recommendation 76 | ](https://arxiv.org/abs/2302.08191) (ICLR 2023). 77 | 78 | **Sequential Recommendation**: 79 | 80 | * **[SR-GNN](recbole_gnn/model/sequential_recommender/srgnn.py)** from Wu *et al.*: [Session-based Recommendation with Graph Neural Networks](https://arxiv.org/abs/1811.00855) (AAAI 2019). 81 | * **[GC-SAN](recbole_gnn/model/sequential_recommender/gcsan.py)** from Xu *et al.*: [Graph Contextualized Self-Attention Network for Session-based Recommendation](https://www.ijcai.org/proceedings/2019/547) (IJCAI 2019). 82 | * **[NISER+](recbole_gnn/model/sequential_recommender/niser.py)** from Gupta *et al.*: [NISER: Normalized Item and Session Representations to Handle Popularity Bias](https://arxiv.org/abs/1909.04276) (GRLA, CIKM 2019 workshop). 83 | * **[LESSR](recbole_gnn/model/sequential_recommender/lessr.py)** from Chen *et al.*: [Handling Information Loss of Graph Neural Networks for Session-based Recommendation](https://dl.acm.org/doi/10.1145/3394486.3403170) (KDD 2020). 84 | * **[TAGNN](recbole_gnn/model/sequential_recommender/tagnn.py)** from Yu *et al.*: [TAGNN: Target Attentive Graph Neural Networks for Session-based Recommendation](https://arxiv.org/abs/2005.02844) (SIGIR 2020 short). 85 | * **[GCE-GNN](recbole_gnn/model/sequential_recommender/gcegnn.py)** from Wang *et al.*: [Global Context Enhanced Graph Neural Networks for Session-based Recommendation](https://arxiv.org/abs/2106.05081) (SIGIR 2020). 86 | * **[SGNN-HN](recbole_gnn/model/sequential_recommender/sgnnhn.py)** from Pan *et al.*: [Star Graph Neural Networks for Session-based Recommendation](https://dl.acm.org/doi/10.1145/3340531.3412014) (CIKM 2020). 87 | 88 | **Social Recommendation**: 89 | 90 | > Note that datasets for social recommendation methods can be downloaded from [Social-Datasets](https://github.com/Sherry-XLL/Social-Datasets). 91 | 92 | * **[DiffNet](recbole_gnn/model/social_recommender/diffnet.py)** from Wu *et al.*: [A Neural Influence Diffusion Model for Social Recommendation](https://arxiv.org/abs/1904.10322) (SIGIR 2019). 93 | * **[MHCN](recbole_gnn/model/social_recommender/mhcn.py)** from Yu *et al.*: [Self-Supervised Multi-Channel Hypergraph Convolutional Network for Social Recommendation](https://doi.org/10.1145/3442381.3449844) (WWW 2021). 94 | * **[SEPT](recbole_gnn/model/social_recommender/sept.py)** from Yu *et al.*: [Socially-Aware Self-Supervised Tri-Training for Recommendation](https://doi.org/10.1145/3447548.3467340) (KDD 2021). 
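
Beyond the `run_recbole_gnn.py` command-line entry point shown in Quick-Start, the implemented models can also be launched programmatically. The snippet below is a minimal sketch, assuming that `recbole_gnn.quick_start.run_recbole_gnn` follows the same `(model, dataset, config_file_list, config_dict)` interface as RecBole's `run_recbole`; check `recbole_gnn/quick_start.py` for the exact signature. For social models, the corresponding dataset (e.g. from Social-Datasets) must first be placed under the `dataset/` directory.

```python
# Minimal programmatic usage sketch -- the function name and keyword
# arguments are assumed to mirror RecBole's `run_recbole`; see
# recbole_gnn/quick_start.py for the authoritative interface.
from recbole_gnn.quick_start import run_recbole_gnn

# Train and evaluate LightGCN on MovieLens-1M, overriding two
# hyper-parameters that are otherwise read from
# recbole_gnn/properties/model/LightGCN.yaml.
run_recbole_gnn(
    model='LightGCN',
    dataset='ml-1m',
    config_dict={'embedding_size': 64, 'n_layers': 2},
)
```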
95 | 96 | ## Result 97 | 98 | ### Leaderboard 99 | 100 | We carefully tune the hyper-parameters of the implemented models of each research field and release the corresponding leaderboards for reference: 101 | 102 | - **General** recommendation on `MovieLens-1M` dataset [[link]](results/general/ml-1m.md); 103 | - **Sequential** recommendation on `Diginetica` dataset [[link]](results/sequential/diginetica.md); 104 | - **Social** recommendation on `LastFM` dataset [[link]](results/social/lastfm.md); 105 | 106 | ### Efficiency 107 | 108 | With the sequential/session graphs preprocessing technique, as well as efficient GNN layers, we speed up the training process of our sequential recommenders a lot. 109 | 110 | 111 | 112 | ## The Team 113 | 114 | RecBole-GNN is initially developed and maintained by members from [RUCAIBox](http://aibox.ruc.edu.cn/), the main developers are Yupeng Hou ([@hyp1231](https://github.com/hyp1231)), Lanling Xu ([@Sherry-XLL](https://github.com/Sherry-XLL)) and Changxin Tian ([@ChangxinTian](https://github.com/ChangxinTian)). We also thank Xinzhou ([@downeykking](https://github.com/downeykking)), Wanli ([@wending0417](https://github.com/wending0417)), and Jingqi ([@Tokkiu](https://github.com/Tokkiu)) for their great contribution! ❤️ 115 | 116 | ## Acknowledgement 117 | 118 | The implementation is based on the open-source recommendation library [RecBole](https://github.com/RUCAIBox/RecBole). RecBole-GNN is part of [RecBole 2.0](https://github.com/RUCAIBox/RecBole2.0) now! 119 | 120 | Please cite the following paper as the reference if you use our code or processed datasets. 121 | 122 | ```bibtex 123 | @inproceedings{zhao2022recbole2, 124 | author={Wayne Xin Zhao and Yupeng Hou and Xingyu Pan and Chen Yang and Zeyu Zhang and Zihan Lin and Jingsen Zhang and Shuqing Bian and Jiakai Tang and Wenqi Sun and Yushuo Chen and Lanling Xu and Gaowei Zhang and Zhen Tian and Changxin Tian and Shanlei Mu and Xinyan Fan and Xu Chen and Ji-Rong Wen}, 125 | title={RecBole 2.0: Towards a More Up-to-Date Recommendation Library}, 126 | booktitle = {{CIKM}}, 127 | year={2022} 128 | } 129 | 130 | @inproceedings{zhao2021recbole, 131 | author = {Wayne Xin Zhao and Shanlei Mu and Yupeng Hou and Zihan Lin and Yushuo Chen and Xingyu Pan and Kaiyuan Li and Yujie Lu and Hui Wang and Changxin Tian and Yingqian Min and Zhichao Feng and Xinyan Fan and Xu Chen and Pengfei Wang and Wendi Ji and Yaliang Li and Xiaoling Wang and Ji{-}Rong Wen}, 132 | title = {RecBole: Towards a Unified, Comprehensive and Efficient Framework for Recommendation Algorithms}, 133 | booktitle = {{CIKM}}, 134 | pages = {4653--4664}, 135 | publisher = {{ACM}}, 136 | year = {2021} 137 | } 138 | ``` 139 | -------------------------------------------------------------------------------- /asset/arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RUCAIBox/RecBole-GNN/632ef888589944c190ad8f449b49ca559618d4df/asset/arch.png -------------------------------------------------------------------------------- /asset/diginetica.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RUCAIBox/RecBole-GNN/632ef888589944c190ad8f449b49ca559618d4df/asset/diginetica.png -------------------------------------------------------------------------------- /asset/ml-1m.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/RUCAIBox/RecBole-GNN/632ef888589944c190ad8f449b49ca559618d4df/asset/ml-1m.png -------------------------------------------------------------------------------- /asset/recbole-gnn-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RUCAIBox/RecBole-GNN/632ef888589944c190ad8f449b49ca559618d4df/asset/recbole-gnn-logo.png -------------------------------------------------------------------------------- /recbole_gnn/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import recbole 3 | from recbole.config.configurator import Config as RecBole_Config 4 | from recbole.utils import ModelType as RecBoleModelType 5 | 6 | from recbole_gnn.utils import get_model, ModelType 7 | 8 | 9 | class Config(RecBole_Config): 10 | def __init__(self, model=None, dataset=None, config_file_list=None, config_dict=None): 11 | """ 12 | Args: 13 | model (str/AbstractRecommender): the model name or the model class, default is None, if it is None, config 14 | will search the parameter 'model' from the external input as the model name or model class. 15 | dataset (str): the dataset name, default is None, if it is None, config will search the parameter 'dataset' 16 | from the external input as the dataset name. 17 | config_file_list (list of str): the external config file, it allows multiple config files, default is None. 18 | config_dict (dict): the external parameter dictionaries, default is None. 19 | """ 20 | if recbole.__version__ == "1.1.1": 21 | self.compatibility_settings() 22 | super(Config, self).__init__(model, dataset, config_file_list, config_dict) 23 | 24 | def compatibility_settings(self): 25 | import numpy as np 26 | np.bool = np.bool_ 27 | np.int = np.int_ 28 | np.float = np.float_ 29 | np.complex = np.complex_ 30 | np.object = np.object_ 31 | np.str = np.str_ 32 | np.long = np.int_ 33 | np.unicode = np.unicode_ 34 | 35 | def _get_model_and_dataset(self, model, dataset): 36 | 37 | if model is None: 38 | try: 39 | model = self.external_config_dict['model'] 40 | except KeyError: 41 | raise KeyError( 42 | 'model need to be specified in at least one of the these ways: ' 43 | '[model variable, config file, config dict, command line] ' 44 | ) 45 | if not isinstance(model, str): 46 | final_model_class = model 47 | final_model = model.__name__ 48 | else: 49 | final_model = model 50 | final_model_class = get_model(final_model) 51 | 52 | if dataset is None: 53 | try: 54 | final_dataset = self.external_config_dict['dataset'] 55 | except KeyError: 56 | raise KeyError( 57 | 'dataset need to be specified in at least one of the these ways: ' 58 | '[dataset variable, config file, config dict, command line] ' 59 | ) 60 | else: 61 | final_dataset = dataset 62 | 63 | return final_model, final_model_class, final_dataset 64 | 65 | def _load_internal_config_dict(self, model, model_class, dataset): 66 | super()._load_internal_config_dict(model, model_class, dataset) 67 | current_path = os.path.dirname(os.path.realpath(__file__)) 68 | model_init_file = os.path.join(current_path, './properties/model/' + model + '.yaml') 69 | quick_start_config_path = os.path.join(current_path, './properties/quick_start_config/') 70 | sequential_base_init = os.path.join(quick_start_config_path, 'sequential_base.yaml') 71 | social_base_init = os.path.join(quick_start_config_path, 'social_base.yaml') 72 | 73 | if os.path.isfile(model_init_file): 74 | config_dict = 
self._update_internal_config_dict(model_init_file) 75 | 76 | self.internal_config_dict['MODEL_TYPE'] = model_class.type 77 | if self.internal_config_dict['MODEL_TYPE'] == RecBoleModelType.SEQUENTIAL: 78 | self._update_internal_config_dict(sequential_base_init) 79 | if self.internal_config_dict['MODEL_TYPE'] == ModelType.SOCIAL: 80 | self._update_internal_config_dict(social_base_init) 81 | -------------------------------------------------------------------------------- /recbole_gnn/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RUCAIBox/RecBole-GNN/632ef888589944c190ad8f449b49ca559618d4df/recbole_gnn/data/__init__.py -------------------------------------------------------------------------------- /recbole_gnn/data/dataloader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from recbole.data.interaction import cat_interactions 4 | from recbole.data.dataloader.general_dataloader import TrainDataLoader, NegSampleEvalDataLoader, FullSortEvalDataLoader 5 | 6 | from recbole_gnn.data.transform import gnn_construct_transform 7 | 8 | 9 | class CustomizedTrainDataLoader(TrainDataLoader): 10 | def __init__(self, config, dataset, sampler, shuffle=False): 11 | super().__init__(config, dataset, sampler, shuffle=shuffle) 12 | if config['gnn_transform'] is not None: 13 | self.transform = gnn_construct_transform(config) 14 | 15 | 16 | class CustomizedNegSampleEvalDataLoader(NegSampleEvalDataLoader): 17 | def __init__(self, config, dataset, sampler, shuffle=False): 18 | super().__init__(config, dataset, sampler, shuffle=shuffle) 19 | if config['gnn_transform'] is not None: 20 | self.transform = gnn_construct_transform(config) 21 | 22 | def collate_fn(self, index): 23 | index = np.array(index) 24 | if ( 25 | self.neg_sample_args["distribution"] != "none" 26 | and self.neg_sample_args["sample_num"] != "none" 27 | ): 28 | uid_list = self.uid_list[index] 29 | data_list = [] 30 | idx_list = [] 31 | positive_u = [] 32 | positive_i = torch.tensor([], dtype=torch.int64) 33 | 34 | for idx, uid in enumerate(uid_list): 35 | index = self.uid2index[uid] 36 | data_list.append(self._neg_sampling(self._dataset[index])) 37 | idx_list += [idx for i in range(self.uid2items_num[uid] * self.times)] 38 | positive_u += [idx for i in range(self.uid2items_num[uid])] 39 | positive_i = torch.cat( 40 | (positive_i, self._dataset[index][self.iid_field]), 0 41 | ) 42 | 43 | cur_data = cat_interactions(data_list) 44 | idx_list = torch.from_numpy(np.array(idx_list)).long() 45 | positive_u = torch.from_numpy(np.array(positive_u)).long() 46 | 47 | return self.transform(self._dataset, cur_data), idx_list, positive_u, positive_i 48 | else: 49 | data = self._dataset[index] 50 | transformed_data = self.transform(self._dataset, data) 51 | cur_data = self._neg_sampling(transformed_data) 52 | return cur_data, None, None, None 53 | 54 | 55 | class CustomizedFullSortEvalDataLoader(FullSortEvalDataLoader): 56 | def __init__(self, config, dataset, sampler, shuffle=False): 57 | super().__init__(config, dataset, sampler, shuffle=shuffle) 58 | if config['gnn_transform'] is not None: 59 | self.transform = gnn_construct_transform(config) 60 | -------------------------------------------------------------------------------- /recbole_gnn/data/transform.py: -------------------------------------------------------------------------------- 1 | from logging import getLogger 2 | import torch 3 | from 
torch.nn.utils.rnn import pad_sequence 4 | from recbole.data.interaction import Interaction 5 | 6 | 7 | def gnn_construct_transform(config): 8 | if config['gnn_transform'] is None: 9 | raise ValueError('config["gnn_transform"] is None but trying to construct transform.') 10 | str2transform = { 11 | 'sess_graph': SessionGraph, 12 | } 13 | return str2transform[config['gnn_transform']](config) 14 | 15 | 16 | class SessionGraph: 17 | def __init__(self, config): 18 | self.logger = getLogger() 19 | self.logger.info('SessionGraph Transform in DataLoader.') 20 | 21 | def __call__(self, dataset, interaction): 22 | graph_objs = dataset.graph_objs 23 | index = interaction['graph_idx'] 24 | graph_batch = { 25 | k: [graph_objs[k][_.item()] for _ in index] 26 | for k in graph_objs 27 | } 28 | graph_batch['batch'] = [] 29 | 30 | tot_node_num = torch.ones([1], dtype=torch.long) 31 | for i in range(index.shape[0]): 32 | for k in graph_batch: 33 | if 'edge_index' in k: 34 | graph_batch[k][i] = graph_batch[k][i] + tot_node_num 35 | if 'alias_inputs' in graph_batch: 36 | graph_batch['alias_inputs'][i] = graph_batch['alias_inputs'][i] + tot_node_num 37 | graph_batch['batch'].append(torch.full_like(graph_batch['x'][i], i)) 38 | tot_node_num += graph_batch['x'][i].shape[0] 39 | 40 | if hasattr(dataset, 'node_attr'): 41 | node_attr = ['batch'] + dataset.node_attr 42 | else: 43 | node_attr = ['x', 'batch'] 44 | for k in node_attr: 45 | graph_batch[k] = [torch.zeros([1], dtype=graph_batch[k][-1].dtype)] + graph_batch[k] 46 | 47 | for k in graph_batch: 48 | if k == 'alias_inputs': 49 | graph_batch[k] = pad_sequence(graph_batch[k], batch_first=True) 50 | else: 51 | graph_batch[k] = torch.cat(graph_batch[k], dim=-1) 52 | 53 | interaction.update(Interaction(graph_batch)) 54 | return interaction 55 | -------------------------------------------------------------------------------- /recbole_gnn/model/abstract_recommender.py: -------------------------------------------------------------------------------- 1 | from recbole.model.abstract_recommender import GeneralRecommender 2 | from recbole.utils import ModelType as RecBoleModelType 3 | 4 | from recbole_gnn.utils import ModelType 5 | 6 | 7 | class GeneralGraphRecommender(GeneralRecommender): 8 | """This is an abstract general graph recommender. All the general graph models should implement in this class. 9 | The base general graph recommender class provide the basic U-I graph dataset and parameters information. 10 | """ 11 | type = RecBoleModelType.GENERAL 12 | 13 | def __init__(self, config, dataset): 14 | super(GeneralGraphRecommender, self).__init__(config, dataset) 15 | self.edge_index, self.edge_weight = dataset.get_norm_adj_mat(enable_sparse=config["enable_sparse"]) 16 | self.use_sparse = config["enable_sparse"] and dataset.is_sparse 17 | if self.use_sparse: 18 | self.edge_index, self.edge_weight = self.edge_index.to(self.device), None 19 | else: 20 | self.edge_index, self.edge_weight = self.edge_index.to(self.device), self.edge_weight.to(self.device) 21 | 22 | 23 | class SocialRecommender(GeneralRecommender): 24 | """This is an abstract social recommender. All the social graph model should implement this class. 25 | The base social recommender class provide the basic social graph dataset and parameters information. 
26 | """ 27 | type = ModelType.SOCIAL 28 | 29 | def __init__(self, config, dataset): 30 | super(SocialRecommender, self).__init__(config, dataset) 31 | -------------------------------------------------------------------------------- /recbole_gnn/model/general_recommender/__init__.py: -------------------------------------------------------------------------------- 1 | from recbole_gnn.model.general_recommender.lightgcn import LightGCN 2 | from recbole_gnn.model.general_recommender.hmlet import HMLET 3 | from recbole_gnn.model.general_recommender.ncl import NCL 4 | from recbole_gnn.model.general_recommender.ngcf import NGCF 5 | from recbole_gnn.model.general_recommender.sgl import SGL 6 | from recbole_gnn.model.general_recommender.lightgcl import LightGCL 7 | from recbole_gnn.model.general_recommender.simgcl import SimGCL 8 | from recbole_gnn.model.general_recommender.xsimgcl import XSimGCL 9 | from recbole_gnn.model.general_recommender.directau import DirectAU 10 | from recbole_gnn.model.general_recommender.ssl4rec import SSL4REC 11 | -------------------------------------------------------------------------------- /recbole_gnn/model/general_recommender/directau.py: -------------------------------------------------------------------------------- 1 | # r""" 2 | # DiretAU 3 | # ################################################ 4 | # Reference: 5 | # Chenyang Wang et al. "Towards Representation Alignment and Uniformity in Collaborative Filtering." in KDD 2022. 6 | 7 | # Reference code: 8 | # https://github.com/THUwangcy/DirectAU 9 | # """ 10 | 11 | import numpy as np 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | 16 | from recbole.model.init import xavier_normal_initialization 17 | from recbole.utils import InputType 18 | from recbole.model.general_recommender import BPR 19 | from recbole_gnn.model.general_recommender import LightGCN 20 | 21 | from recbole_gnn.model.abstract_recommender import GeneralGraphRecommender 22 | 23 | 24 | class DirectAU(GeneralGraphRecommender): 25 | input_type = InputType.PAIRWISE 26 | 27 | def __init__(self, config, dataset): 28 | super(DirectAU, self).__init__(config, dataset) 29 | 30 | # load parameters info 31 | self.embedding_size = config['embedding_size'] 32 | self.gamma = config['gamma'] 33 | self.encoder_name = config['encoder'] 34 | 35 | # define encoder 36 | if self.encoder_name == 'MF': 37 | self.encoder = MFEncoder(config, dataset) 38 | elif self.encoder_name == 'LightGCN': 39 | self.encoder = LGCNEncoder(config, dataset) 40 | else: 41 | raise ValueError('Non-implemented Encoder.') 42 | 43 | # storage variables for full sort evaluation acceleration 44 | self.restore_user_e = None 45 | self.restore_item_e = None 46 | 47 | # parameters initialization 48 | self.apply(xavier_normal_initialization) 49 | 50 | def forward(self, user, item): 51 | user_e, item_e = self.encoder(user, item) 52 | return F.normalize(user_e, dim=-1), F.normalize(item_e, dim=-1) 53 | 54 | @staticmethod 55 | def alignment(x, y, alpha=2): 56 | return (x - y).norm(p=2, dim=1).pow(alpha).mean() 57 | 58 | @staticmethod 59 | def uniformity(x, t=2): 60 | return torch.pdist(x, p=2).pow(2).mul(-t).exp().mean().log() 61 | 62 | def calculate_loss(self, interaction): 63 | if self.restore_user_e is not None or self.restore_item_e is not None: 64 | self.restore_user_e, self.restore_item_e = None, None 65 | 66 | user = interaction[self.USER_ID] 67 | item = interaction[self.ITEM_ID] 68 | 69 | user_e, item_e = self.forward(user, item) 70 | align = 
self.alignment(user_e, item_e) 71 | uniform = self.gamma * (self.uniformity(user_e) + self.uniformity(item_e)) / 2 72 | 73 | return align, uniform 74 | 75 | def predict(self, interaction): 76 | user = interaction[self.USER_ID] 77 | item = interaction[self.ITEM_ID] 78 | user_e = self.user_embedding(user) 79 | item_e = self.item_embedding(item) 80 | return torch.mul(user_e, item_e).sum(dim=1) 81 | 82 | def full_sort_predict(self, interaction): 83 | user = interaction[self.USER_ID] 84 | if self.encoder_name == 'LightGCN': 85 | if self.restore_user_e is None or self.restore_item_e is None: 86 | self.restore_user_e, self.restore_item_e = self.encoder.get_all_embeddings() 87 | user_e = self.restore_user_e[user] 88 | all_item_e = self.restore_item_e 89 | else: 90 | user_e = self.encoder.user_embedding(user) 91 | all_item_e = self.encoder.item_embedding.weight 92 | score = torch.matmul(user_e, all_item_e.transpose(0, 1)) 93 | return score.view(-1) 94 | 95 | 96 | class MFEncoder(BPR): 97 | def __init__(self, config, dataset): 98 | super(MFEncoder, self).__init__(config, dataset) 99 | 100 | def forward(self, user_id, item_id): 101 | return super().forward(user_id, item_id) 102 | 103 | def get_all_embeddings(self): 104 | user_embeddings = self.user_embedding.weight 105 | item_embeddings = self.item_embedding.weight 106 | return user_embeddings, item_embeddings 107 | 108 | 109 | class LGCNEncoder(LightGCN): 110 | def __init__(self, config, dataset): 111 | super(LGCNEncoder, self).__init__(config, dataset) 112 | 113 | def forward(self, user_id, item_id): 114 | user_all_embeddings, item_all_embeddings = self.get_all_embeddings() 115 | u_embed = user_all_embeddings[user_id] 116 | i_embed = item_all_embeddings[item_id] 117 | return u_embed, i_embed 118 | 119 | def get_all_embeddings(self): 120 | return super().forward() 121 | -------------------------------------------------------------------------------- /recbole_gnn/model/general_recommender/hmlet.py: -------------------------------------------------------------------------------- 1 | # @Time : 2022/3/21 2 | # @Author : Yupeng Hou 3 | # @Email : houyupeng@ruc.edu.cn 4 | 5 | r""" 6 | HMLET 7 | ################################################ 8 | Reference: 9 | Taeyong Kong et al. "Linear, or Non-Linear, That is the Question!." in WSDM 2022. 
10 | 11 | Reference code: 12 | https://github.com/qbxlvnf11/HMLET 13 | """ 14 | 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | from recbole.model.init import xavier_uniform_initialization 19 | from recbole.model.loss import BPRLoss, EmbLoss 20 | from recbole.model.layers import activation_layer 21 | from recbole.utils import InputType 22 | 23 | from recbole_gnn.model.abstract_recommender import GeneralGraphRecommender 24 | from recbole_gnn.model.layers import LightGCNConv 25 | 26 | 27 | class Gating_Net(nn.Module): 28 | def __init__(self, embedding_dim, mlp_dims, dropout_p): 29 | super(Gating_Net, self).__init__() 30 | self.embedding_dim = embedding_dim 31 | 32 | fc_layers = [] 33 | for i in range(len(mlp_dims)): 34 | if i == 0: 35 | fc = nn.Linear(embedding_dim*2, mlp_dims[i]) 36 | fc_layers.append(fc) 37 | else: 38 | fc = nn.Linear(mlp_dims[i-1], mlp_dims[i]) 39 | fc_layers.append(fc) 40 | if i != len(mlp_dims) - 1: 41 | fc_layers.append(nn.BatchNorm1d(mlp_dims[i])) 42 | fc_layers.append(nn.Dropout(p=dropout_p)) 43 | fc_layers.append(nn.ReLU(inplace=True)) 44 | self.mlp = nn.Sequential(*fc_layers) 45 | 46 | def gumbel_softmax(self, logits, temperature, hard): 47 | """Sample from the Gumbel-Softmax distribution and optionally discretize. 48 | Args: 49 | logits: [batch_size, n_class] unnormalized log-probs 50 | temperature: non-negative scalar 51 | hard: if True, take argmax, but differentiate w.r.t. soft sample y 52 | Returns: 53 | [batch_size, n_class] sample from the Gumbel-Softmax distribution. 54 | If hard=True, then the returned sample will be one-hot, otherwise it will 55 | be a probabilitiy distribution that sums to 1 across classes 56 | """ 57 | y = self.gumbel_softmax_sample(logits, temperature) ## (0.6, 0.2, 0.1,..., 0.11) 58 | if hard: 59 | k = logits.size(1) # k is numb of classes 60 | # y_hard = tf.cast(tf.one_hot(tf.argmax(y,1),k), y.dtype) ## (1, 0, 0, ..., 0) 61 | y_hard = torch.eq(y, torch.max(y, dim=1, keepdim=True)[0]).type_as(y) 62 | y = (y_hard - y).detach() + y 63 | return y 64 | 65 | def gumbel_softmax_sample(self, logits, temperature): 66 | """ Draw a sample from the Gumbel-Softmax distribution""" 67 | noise = self.sample_gumbel(logits) 68 | y = (logits + noise) / temperature 69 | return F.softmax(y, dim=1) 70 | 71 | def sample_gumbel(self, logits): 72 | """Sample from Gumbel(0, 1)""" 73 | noise = torch.rand(logits.size()) 74 | eps = 1e-20 75 | noise.add_(eps).log_().neg_() 76 | noise.add_(eps).log_().neg_() 77 | return torch.Tensor(noise.float()).to(logits.device) 78 | 79 | def forward(self, feature, temperature, hard): 80 | x = self.mlp(feature) 81 | out = self.gumbel_softmax(x, temperature, hard) 82 | out_value = out.unsqueeze(2) 83 | gating_out = out_value.repeat(1, 1, self.embedding_dim) 84 | return gating_out 85 | 86 | 87 | class HMLET(GeneralGraphRecommender): 88 | r"""HMLET combines both linear and non-linear propagation layers for general recommendation and yields better performance. 
89 | """ 90 | input_type = InputType.PAIRWISE 91 | 92 | def __init__(self, config, dataset): 93 | super(HMLET, self).__init__(config, dataset) 94 | 95 | # load parameters info 96 | self.latent_dim = config['embedding_size'] # int type:the embedding size of lightGCN 97 | self.n_layers = config['n_layers'] # int type:the layer num of lightGCN 98 | self.reg_weight = config['reg_weight'] # float32 type: the weight decay for l2 normalization 99 | self.require_pow = config['require_pow'] # bool type: whether to require pow when regularization 100 | self.gate_layer_ids = config['gate_layer_ids'] # list type: layer ids for non-linear gating 101 | self.gating_mlp_dims = config['gating_mlp_dims'] # list type: list of mlp dimensions in gating module 102 | self.dropout_ratio = config['dropout_ratio'] # dropout ratio for mlp in gating module 103 | self.gum_temp = config['ori_temp'] 104 | self.logger.info(f'Model initialization, gumbel softmax temperature: {self.gum_temp}') 105 | 106 | # define layers and loss 107 | self.user_embedding = torch.nn.Embedding(num_embeddings=self.n_users, embedding_dim=self.latent_dim) 108 | self.item_embedding = torch.nn.Embedding(num_embeddings=self.n_items, embedding_dim=self.latent_dim) 109 | self.gcn_conv = LightGCNConv(dim=self.latent_dim) 110 | self.activation = nn.ELU() if config['activation_function'] == 'elu' else activation_layer(config['activation_function']) 111 | self.gating_nets = nn.ModuleList([ 112 | Gating_Net(self.latent_dim, self.gating_mlp_dims, self.dropout_ratio) for _ in range(len(self.gate_layer_ids)) 113 | ]) 114 | 115 | self.mf_loss = BPRLoss() 116 | self.reg_loss = EmbLoss() 117 | 118 | # storage variables for full sort evaluation acceleration 119 | self.restore_user_e = None 120 | self.restore_item_e = None 121 | 122 | # parameters initialization 123 | self.apply(xavier_uniform_initialization) 124 | self.other_parameter_name = ['restore_user_e', 'restore_item_e', 'gum_temp'] 125 | 126 | for gating in self.gating_nets: 127 | self._gating_freeze(gating, False) 128 | 129 | def _gating_freeze(self, model, freeze_flag): 130 | for name, child in model.named_children(): 131 | for param in child.parameters(): 132 | param.requires_grad = freeze_flag 133 | 134 | def __choosing_one(self, features, gumbel_out): 135 | feature = torch.sum(torch.mul(features, gumbel_out), dim=1) # batch x embedding_dim (or batch x embedding_dim x layer_num) 136 | return feature 137 | 138 | def __where(self, idx, lst): 139 | for i in range(len(lst)): 140 | if lst[i] == idx: 141 | return i 142 | raise ValueError(f'{idx} not in {lst}.') 143 | 144 | def get_ego_embeddings(self): 145 | r"""Get the embedding of users and items and combine to an embedding matrix. 146 | Returns: 147 | Tensor of the embedding matrix. 
Shape of [n_items+n_users, embedding_dim] 148 | """ 149 | user_embeddings = self.user_embedding.weight 150 | item_embeddings = self.item_embedding.weight 151 | ego_embeddings = torch.cat([user_embeddings, item_embeddings], dim=0) 152 | return ego_embeddings 153 | 154 | def forward(self): 155 | all_embeddings = self.get_ego_embeddings() 156 | embeddings_list = [all_embeddings] 157 | non_lin_emb_list = [all_embeddings] 158 | 159 | for layer_idx in range(self.n_layers): 160 | linear_embeddings = self.gcn_conv(all_embeddings, self.edge_index, self.edge_weight) 161 | if layer_idx not in self.gate_layer_ids: 162 | all_embeddings = linear_embeddings 163 | else: 164 | non_lin_id = self.__where(layer_idx, self.gate_layer_ids) 165 | last_non_lin_emb = non_lin_emb_list[non_lin_id] 166 | non_lin_embeddings = self.activation(self.gcn_conv(last_non_lin_emb, self.edge_index, self.edge_weight)) 167 | stack_embeddings = torch.stack([linear_embeddings, non_lin_embeddings], dim=1) 168 | concat_embeddings = torch.cat((linear_embeddings, non_lin_embeddings), dim=-1) 169 | gumbel_out = self.gating_nets[non_lin_id](concat_embeddings, self.gum_temp, not self.training) 170 | all_embeddings = self.__choosing_one(stack_embeddings, gumbel_out) 171 | non_lin_emb_list.append(all_embeddings) 172 | embeddings_list.append(all_embeddings) 173 | hmlet_all_embeddings = torch.stack(embeddings_list, dim=1) 174 | hmlet_all_embeddings = torch.mean(hmlet_all_embeddings, dim=1) 175 | 176 | user_all_embeddings, item_all_embeddings = torch.split(hmlet_all_embeddings, [self.n_users, self.n_items]) 177 | return user_all_embeddings, item_all_embeddings 178 | 179 | def calculate_loss(self, interaction): 180 | # clear the storage variable when training 181 | if self.restore_user_e is not None or self.restore_item_e is not None: 182 | self.restore_user_e, self.restore_item_e = None, None 183 | 184 | user = interaction[self.USER_ID] 185 | pos_item = interaction[self.ITEM_ID] 186 | neg_item = interaction[self.NEG_ITEM_ID] 187 | 188 | user_all_embeddings, item_all_embeddings = self.forward() 189 | u_embeddings = user_all_embeddings[user] 190 | pos_embeddings = item_all_embeddings[pos_item] 191 | neg_embeddings = item_all_embeddings[neg_item] 192 | 193 | # calculate BPR Loss 194 | pos_scores = torch.mul(u_embeddings, pos_embeddings).sum(dim=1) 195 | neg_scores = torch.mul(u_embeddings, neg_embeddings).sum(dim=1) 196 | mf_loss = self.mf_loss(pos_scores, neg_scores) 197 | 198 | # calculate regularization Loss 199 | u_ego_embeddings = self.user_embedding(user) 200 | pos_ego_embeddings = self.item_embedding(pos_item) 201 | neg_ego_embeddings = self.item_embedding(neg_item) 202 | 203 | reg_loss = self.reg_loss(u_ego_embeddings, pos_ego_embeddings, neg_ego_embeddings, require_pow=self.require_pow) 204 | loss = mf_loss + self.reg_weight * reg_loss 205 | 206 | return loss 207 | 208 | def predict(self, interaction): 209 | user = interaction[self.USER_ID] 210 | item = interaction[self.ITEM_ID] 211 | 212 | user_all_embeddings, item_all_embeddings = self.forward() 213 | 214 | u_embeddings = user_all_embeddings[user] 215 | i_embeddings = item_all_embeddings[item] 216 | scores = torch.mul(u_embeddings, i_embeddings).sum(dim=1) 217 | return scores 218 | 219 | def full_sort_predict(self, interaction): 220 | user = interaction[self.USER_ID] 221 | if self.restore_user_e is None or self.restore_item_e is None: 222 | self.restore_user_e, self.restore_item_e = self.forward() 223 | # get user embedding from storage variable 224 | u_embeddings = 
self.restore_user_e[user] 225 | 226 | # dot with all item embedding to accelerate 227 | scores = torch.matmul(u_embeddings, self.restore_item_e.transpose(0, 1)) 228 | 229 | return scores.view(-1) -------------------------------------------------------------------------------- /recbole_gnn/model/general_recommender/lightgcl.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2023/04/12 3 | # @Author : Wanli Yang 4 | # @Email : 2013774@mail.nankai.edu.cn 5 | 6 | r""" 7 | LightGCL 8 | ################################################ 9 | Reference: 10 | Xuheng Cai et al. "LightGCL: Simple Yet Effective Graph Contrastive Learning for Recommendation" in ICLR 2023. 11 | 12 | Reference code: 13 | https://github.com/HKUDS/LightGCL 14 | """ 15 | 16 | import numpy as np 17 | import scipy.sparse as sp 18 | import torch 19 | import torch.nn as nn 20 | from recbole.model.abstract_recommender import GeneralRecommender 21 | from recbole.model.init import xavier_uniform_initialization 22 | from recbole.model.loss import EmbLoss 23 | from recbole.utils import InputType 24 | import torch.nn.functional as F 25 | 26 | 27 | class LightGCL(GeneralRecommender): 28 | r"""LightGCL is a GCN-based recommender model. 29 | 30 | LightGCL guides graph augmentation by singular value decomposition (SVD) to not only 31 | distill the useful information of user-item interactions but also inject the global 32 | collaborative context into the representation alignment of contrastive learning. 33 | 34 | We implement the model following the original author with a pairwise training mode. 35 | """ 36 | input_type = InputType.PAIRWISE 37 | 38 | def __init__(self, config, dataset): 39 | super(LightGCL, self).__init__(config, dataset) 40 | self._user = dataset.inter_feat[dataset.uid_field] 41 | self._item = dataset.inter_feat[dataset.iid_field] 42 | 43 | # load parameters info 44 | self.embed_dim = config["embedding_size"] 45 | self.n_layers = config["n_layers"] 46 | self.dropout = config["dropout"] 47 | self.temp = config["temp"] 48 | self.lambda_1 = config["lambda1"] 49 | self.lambda_2 = config["lambda2"] 50 | self.q = config["q"] 51 | self.act = nn.LeakyReLU(0.5) 52 | self.reg_loss = EmbLoss() 53 | 54 | # get the normalized adjust matrix 55 | self.adj_norm = self.coo2tensor(self.create_adjust_matrix()) 56 | 57 | # perform svd reconstruction 58 | svd_u, s, svd_v = torch.svd_lowrank(self.adj_norm, q=self.q) 59 | self.u_mul_s = svd_u @ (torch.diag(s)) 60 | self.v_mul_s = svd_v @ (torch.diag(s)) 61 | del s 62 | self.ut = svd_u.T 63 | self.vt = svd_v.T 64 | 65 | self.E_u_0 = nn.Parameter(nn.init.xavier_uniform_(torch.empty(self.n_users, self.embed_dim))) 66 | self.E_i_0 = nn.Parameter(nn.init.xavier_uniform_(torch.empty(self.n_items, self.embed_dim))) 67 | self.E_u_list = [None] * (self.n_layers + 1) 68 | self.E_i_list = [None] * (self.n_layers + 1) 69 | self.E_u_list[0] = self.E_u_0 70 | self.E_i_list[0] = self.E_i_0 71 | self.Z_u_list = [None] * (self.n_layers + 1) 72 | self.Z_i_list = [None] * (self.n_layers + 1) 73 | self.G_u_list = [None] * (self.n_layers + 1) 74 | self.G_i_list = [None] * (self.n_layers + 1) 75 | self.G_u_list[0] = self.E_u_0 76 | self.G_i_list[0] = self.E_i_0 77 | 78 | self.E_u = None 79 | self.E_i = None 80 | self.restore_user_e = None 81 | self.restore_item_e = None 82 | 83 | self.apply(xavier_uniform_initialization) 84 | self.other_parameter_name = ['restore_user_e', 'restore_item_e'] 85 | 86 | def create_adjust_matrix(self): 87 | 
r"""Get the normalized interaction matrix of users and items. 88 | 89 | Returns: 90 | coo_matrix of the normalized interaction matrix. 91 | """ 92 | ratings = np.ones_like(self._user, dtype=np.float32) 93 | matrix = sp.csr_matrix( 94 | (ratings, (self._user, self._item)), 95 | shape=(self.n_users, self.n_items), 96 | ).tocoo() 97 | rowD = np.squeeze(np.array(matrix.sum(1)), axis=1) 98 | colD = np.squeeze(np.array(matrix.sum(0)), axis=0) 99 | for i in range(len(matrix.data)): 100 | matrix.data[i] = matrix.data[i] / pow(rowD[matrix.row[i]] * colD[matrix.col[i]], 0.5) 101 | return matrix 102 | 103 | def coo2tensor(self, matrix: sp.coo_matrix): 104 | r"""Convert coo_matrix to tensor. 105 | 106 | Args: 107 | matrix (scipy.coo_matrix): Sparse matrix to be converted. 108 | 109 | Returns: 110 | torch.sparse.FloatTensor: Transformed sparse matrix. 111 | """ 112 | indices = torch.from_numpy( 113 | np.vstack((matrix.row, matrix.col)).astype(np.int64)) 114 | values = torch.from_numpy(matrix.data) 115 | shape = torch.Size(matrix.shape) 116 | x = torch.sparse.FloatTensor(indices, values, shape).coalesce().to(self.device) 117 | return x 118 | 119 | def sparse_dropout(self, matrix, dropout): 120 | if dropout == 0.0: 121 | return matrix 122 | indices = matrix.indices() 123 | values = F.dropout(matrix.values(), p=dropout) 124 | size = matrix.size() 125 | return torch.sparse.FloatTensor(indices, values, size) 126 | 127 | def forward(self): 128 | for layer in range(1, self.n_layers + 1): 129 | # GNN propagation 130 | self.Z_u_list[layer] = torch.spmm(self.sparse_dropout(self.adj_norm, self.dropout), 131 | self.E_i_list[layer - 1]) 132 | self.Z_i_list[layer] = torch.spmm(self.sparse_dropout(self.adj_norm, self.dropout).transpose(0, 1), 133 | self.E_u_list[layer - 1]) 134 | # aggregate 135 | self.E_u_list[layer] = self.Z_u_list[layer] 136 | self.E_i_list[layer] = self.Z_i_list[layer] 137 | 138 | # aggregate across layer 139 | self.E_u = sum(self.E_u_list) 140 | self.E_i = sum(self.E_i_list) 141 | 142 | return self.E_u, self.E_i 143 | 144 | def calculate_loss(self, interaction): 145 | if self.restore_user_e is not None or self.restore_item_e is not None: 146 | self.restore_user_e, self.restore_item_e = None, None 147 | 148 | user_list = interaction[self.USER_ID] 149 | pos_item_list = interaction[self.ITEM_ID] 150 | neg_item_list = interaction[self.NEG_ITEM_ID] 151 | E_u_norm, E_i_norm = self.forward() 152 | bpr_loss = self.calc_bpr_loss(E_u_norm, E_i_norm, user_list, pos_item_list, neg_item_list) 153 | ssl_loss = self.calc_ssl_loss(E_u_norm, E_i_norm, user_list, pos_item_list) 154 | total_loss = bpr_loss + ssl_loss 155 | return total_loss 156 | 157 | def calc_bpr_loss(self, E_u_norm, E_i_norm, user_list, pos_item_list, neg_item_list): 158 | r"""Calculate the pairwise Bayesian Personalized Ranking (BPR) loss and parameter regularization loss. 159 | 160 | Args: 161 | E_u_norm (torch.Tensor): Ego embedding of all users after forwarding. 162 | E_i_norm (torch.Tensor): Ego embedding of all items after forwarding. 163 | user_list (torch.Tensor): List of the user. 164 | pos_item_list (torch.Tensor): List of positive examples. 165 | neg_item_list (torch.Tensor): List of negative examples. 166 | 167 | Returns: 168 | torch.Tensor: Loss of BPR tasks and parameter regularization. 
169 | """ 170 | u_e = E_u_norm[user_list] 171 | pi_e = E_i_norm[pos_item_list] 172 | ni_e = E_i_norm[neg_item_list] 173 | pos_scores = torch.mul(u_e, pi_e).sum(dim=1) 174 | neg_scores = torch.mul(u_e, ni_e).sum(dim=1) 175 | loss1 = -(pos_scores - neg_scores).sigmoid().log().mean() 176 | 177 | # reg loss 178 | loss_reg = 0 179 | for param in self.parameters(): 180 | loss_reg += param.norm(2).square() 181 | loss_reg *= self.lambda_2 182 | return loss1 + loss_reg 183 | 184 | def calc_ssl_loss(self, E_u_norm, E_i_norm, user_list, pos_item_list): 185 | r"""Calculate the loss of self-supervised tasks. 186 | 187 | Args: 188 | E_u_norm (torch.Tensor): Ego embedding of all users in the original graph after forwarding. 189 | E_i_norm (torch.Tensor): Ego embedding of all items in the original graph after forwarding. 190 | user_list (torch.Tensor): List of the user. 191 | pos_item_list (torch.Tensor): List of positive examples. 192 | 193 | Returns: 194 | torch.Tensor: Loss of self-supervised tasks. 195 | """ 196 | # calculate G_u_norm&G_i_norm 197 | for layer in range(1, self.n_layers + 1): 198 | # svd_adj propagation 199 | vt_ei = self.vt @ self.E_i_list[layer - 1] 200 | self.G_u_list[layer] = self.u_mul_s @ vt_ei 201 | ut_eu = self.ut @ self.E_u_list[layer - 1] 202 | self.G_i_list[layer] = self.v_mul_s @ ut_eu 203 | 204 | # aggregate across layer 205 | G_u_norm = sum(self.G_u_list) 206 | G_i_norm = sum(self.G_i_list) 207 | 208 | neg_score = torch.log(torch.exp(G_u_norm[user_list] @ E_u_norm.T / self.temp).sum(1) + 1e-8).mean() 209 | neg_score += torch.log(torch.exp(G_i_norm[pos_item_list] @ E_i_norm.T / self.temp).sum(1) + 1e-8).mean() 210 | pos_score = (torch.clamp((G_u_norm[user_list] * E_u_norm[user_list]).sum(1) / self.temp, -5.0, 5.0)).mean() + ( 211 | torch.clamp((G_i_norm[pos_item_list] * E_i_norm[pos_item_list]).sum(1) / self.temp, -5.0, 5.0)).mean() 212 | ssl_loss = -pos_score + neg_score 213 | return self.lambda_1 * ssl_loss 214 | 215 | def predict(self, interaction): 216 | if self.restore_user_e is None or self.restore_item_e is None: 217 | self.restore_user_e, self.restore_item_e = self.forward() 218 | user = self.restore_user_e[interaction[self.USER_ID]] 219 | item = self.restore_item_e[interaction[self.ITEM_ID]] 220 | return torch.sum(user * item, dim=1) 221 | 222 | def full_sort_predict(self, interaction): 223 | if self.restore_user_e is None or self.restore_item_e is None: 224 | self.restore_user_e, self.restore_item_e = self.forward() 225 | user = self.restore_user_e[interaction[self.USER_ID]] 226 | return user.matmul(self.restore_item_e.T) 227 | -------------------------------------------------------------------------------- /recbole_gnn/model/general_recommender/lightgcn.py: -------------------------------------------------------------------------------- 1 | # @Time : 2022/3/8 2 | # @Author : Lanling Xu 3 | # @Email : xulanling_sherry@163.com 4 | 5 | r""" 6 | LightGCN 7 | ################################################ 8 | Reference: 9 | Xiangnan He et al. "LightGCN: Simplifying and Powering Graph Convolution Network for Recommendation." in SIGIR 2020. 
10 | 11 | Reference code: 12 | https://github.com/kuandeng/LightGCN 13 | """ 14 | 15 | import numpy as np 16 | import torch 17 | 18 | from recbole.model.init import xavier_uniform_initialization 19 | from recbole.model.loss import BPRLoss, EmbLoss 20 | from recbole.utils import InputType 21 | 22 | from recbole_gnn.model.abstract_recommender import GeneralGraphRecommender 23 | from recbole_gnn.model.layers import LightGCNConv 24 | 25 | 26 | class LightGCN(GeneralGraphRecommender): 27 | r"""LightGCN is a GCN-based recommender model, implemented via PyG. 28 | LightGCN includes only the most essential component in GCN — neighborhood aggregation — for 29 | collaborative filtering. Specifically, LightGCN learns user and item embeddings by linearly 30 | propagating them on the user-item interaction graph, and uses the weighted sum of the embeddings 31 | learned at all layers as the final embedding. 32 | We implement the model following the original author with a pairwise training mode. 33 | """ 34 | input_type = InputType.PAIRWISE 35 | 36 | def __init__(self, config, dataset): 37 | super(LightGCN, self).__init__(config, dataset) 38 | 39 | # load parameters info 40 | self.latent_dim = config['embedding_size'] # int type:the embedding size of lightGCN 41 | self.n_layers = config['n_layers'] # int type:the layer num of lightGCN 42 | self.reg_weight = config['reg_weight'] # float32 type: the weight decay for l2 normalization 43 | self.require_pow = config['require_pow'] # bool type: whether to require pow when regularization 44 | 45 | # define layers and loss 46 | self.user_embedding = torch.nn.Embedding(num_embeddings=self.n_users, embedding_dim=self.latent_dim) 47 | self.item_embedding = torch.nn.Embedding(num_embeddings=self.n_items, embedding_dim=self.latent_dim) 48 | self.gcn_conv = LightGCNConv(dim=self.latent_dim) 49 | self.mf_loss = BPRLoss() 50 | self.reg_loss = EmbLoss() 51 | 52 | # storage variables for full sort evaluation acceleration 53 | self.restore_user_e = None 54 | self.restore_item_e = None 55 | 56 | # parameters initialization 57 | self.apply(xavier_uniform_initialization) 58 | self.other_parameter_name = ['restore_user_e', 'restore_item_e'] 59 | 60 | def get_ego_embeddings(self): 61 | r"""Get the embedding of users and items and combine to an embedding matrix. 62 | Returns: 63 | Tensor of the embedding matrix. 
Shape of [n_items+n_users, embedding_dim] 64 | """ 65 | user_embeddings = self.user_embedding.weight 66 | item_embeddings = self.item_embedding.weight 67 | ego_embeddings = torch.cat([user_embeddings, item_embeddings], dim=0) 68 | return ego_embeddings 69 | 70 | def forward(self): 71 | all_embeddings = self.get_ego_embeddings() 72 | embeddings_list = [all_embeddings] 73 | 74 | for layer_idx in range(self.n_layers): 75 | all_embeddings = self.gcn_conv(all_embeddings, self.edge_index, self.edge_weight) 76 | embeddings_list.append(all_embeddings) 77 | lightgcn_all_embeddings = torch.stack(embeddings_list, dim=1) 78 | lightgcn_all_embeddings = torch.mean(lightgcn_all_embeddings, dim=1) 79 | 80 | user_all_embeddings, item_all_embeddings = torch.split(lightgcn_all_embeddings, [self.n_users, self.n_items]) 81 | return user_all_embeddings, item_all_embeddings 82 | 83 | def calculate_loss(self, interaction): 84 | # clear the storage variable when training 85 | if self.restore_user_e is not None or self.restore_item_e is not None: 86 | self.restore_user_e, self.restore_item_e = None, None 87 | 88 | user = interaction[self.USER_ID] 89 | pos_item = interaction[self.ITEM_ID] 90 | neg_item = interaction[self.NEG_ITEM_ID] 91 | 92 | user_all_embeddings, item_all_embeddings = self.forward() 93 | u_embeddings = user_all_embeddings[user] 94 | pos_embeddings = item_all_embeddings[pos_item] 95 | neg_embeddings = item_all_embeddings[neg_item] 96 | 97 | # calculate BPR Loss 98 | pos_scores = torch.mul(u_embeddings, pos_embeddings).sum(dim=1) 99 | neg_scores = torch.mul(u_embeddings, neg_embeddings).sum(dim=1) 100 | mf_loss = self.mf_loss(pos_scores, neg_scores) 101 | 102 | # calculate regularization Loss 103 | u_ego_embeddings = self.user_embedding(user) 104 | pos_ego_embeddings = self.item_embedding(pos_item) 105 | neg_ego_embeddings = self.item_embedding(neg_item) 106 | 107 | reg_loss = self.reg_loss(u_ego_embeddings, pos_ego_embeddings, neg_ego_embeddings, require_pow=self.require_pow) 108 | loss = mf_loss + self.reg_weight * reg_loss 109 | 110 | return loss 111 | 112 | def predict(self, interaction): 113 | user = interaction[self.USER_ID] 114 | item = interaction[self.ITEM_ID] 115 | 116 | user_all_embeddings, item_all_embeddings = self.forward() 117 | 118 | u_embeddings = user_all_embeddings[user] 119 | i_embeddings = item_all_embeddings[item] 120 | scores = torch.mul(u_embeddings, i_embeddings).sum(dim=1) 121 | return scores 122 | 123 | def full_sort_predict(self, interaction): 124 | user = interaction[self.USER_ID] 125 | if self.restore_user_e is None or self.restore_item_e is None: 126 | self.restore_user_e, self.restore_item_e = self.forward() 127 | # get user embedding from storage variable 128 | u_embeddings = self.restore_user_e[user] 129 | 130 | # dot with all item embedding to accelerate 131 | scores = torch.matmul(u_embeddings, self.restore_item_e.transpose(0, 1)) 132 | 133 | return scores.view(-1) 134 | -------------------------------------------------------------------------------- /recbole_gnn/model/general_recommender/ncl.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | r""" 3 | NCL 4 | ################################################ 5 | Reference: 6 | Zihan Lin*, Changxin Tian*, Yupeng Hou*, Wayne Xin Zhao. "Improving Graph Collaborative Filtering with Neighborhood-enriched Contrastive Learning." in WWW 2022. 
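NCL enriches graph collaborative filtering with two contrastive objectives: a structure-contrastive loss that aligns each node with its even-hop neighborhood view (obtained after ``hyper_layers * 2`` propagation steps), and a prototype-contrastive loss that aligns each node with the centroid of its k-means cluster.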
7 | """ 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | 12 | from recbole.model.init import xavier_uniform_initialization 13 | from recbole.model.loss import BPRLoss, EmbLoss 14 | from recbole.utils import InputType 15 | 16 | from recbole_gnn.model.abstract_recommender import GeneralGraphRecommender 17 | from recbole_gnn.model.layers import LightGCNConv 18 | 19 | 20 | class NCL(GeneralGraphRecommender): 21 | input_type = InputType.PAIRWISE 22 | 23 | def __init__(self, config, dataset): 24 | super(NCL, self).__init__(config, dataset) 25 | 26 | # load parameters info 27 | self.latent_dim = config['embedding_size'] # int type: the embedding size of the base model 28 | self.n_layers = config['n_layers'] # int type: the layer num of the base model 29 | self.reg_weight = config['reg_weight'] # float32 type: the weight decay for l2 normalization 30 | 31 | self.ssl_temp = config['ssl_temp'] 32 | self.ssl_reg = config['ssl_reg'] 33 | self.hyper_layers = config['hyper_layers'] 34 | 35 | self.alpha = config['alpha'] 36 | 37 | self.proto_reg = config['proto_reg'] 38 | self.k = config['num_clusters'] 39 | 40 | # define layers and loss 41 | self.user_embedding = torch.nn.Embedding(num_embeddings=self.n_users, embedding_dim=self.latent_dim) 42 | self.item_embedding = torch.nn.Embedding(num_embeddings=self.n_items, embedding_dim=self.latent_dim) 43 | self.gcn_conv = LightGCNConv(dim=self.latent_dim) 44 | self.mf_loss = BPRLoss() 45 | self.reg_loss = EmbLoss() 46 | 47 | # storage variables for full sort evaluation acceleration 48 | self.restore_user_e = None 49 | self.restore_item_e = None 50 | 51 | # parameters initialization 52 | self.apply(xavier_uniform_initialization) 53 | self.other_parameter_name = ['restore_user_e', 'restore_item_e'] 54 | 55 | self.user_centroids = None 56 | self.user_2cluster = None 57 | self.item_centroids = None 58 | self.item_2cluster = None 59 | 60 | def e_step(self): 61 | user_embeddings = self.user_embedding.weight.detach().cpu().numpy() 62 | item_embeddings = self.item_embedding.weight.detach().cpu().numpy() 63 | self.user_centroids, self.user_2cluster = self.run_kmeans(user_embeddings) 64 | self.item_centroids, self.item_2cluster = self.run_kmeans(item_embeddings) 65 | 66 | def run_kmeans(self, x): 67 | """Run K-means algorithm to get k clusters of the input tensor x 68 | """ 69 | import faiss 70 | kmeans = faiss.Kmeans(d=self.latent_dim, k=self.k, gpu=True) 71 | kmeans.train(x) 72 | cluster_cents = kmeans.centroids 73 | 74 | _, I = kmeans.index.search(x, 1) 75 | 76 | # convert to cuda Tensors for broadcast 77 | centroids = torch.Tensor(cluster_cents).to(self.device) 78 | centroids = F.normalize(centroids, p=2, dim=1) 79 | 80 | node2cluster = torch.LongTensor(I).squeeze().to(self.device) 81 | return centroids, node2cluster 82 | 83 | def get_ego_embeddings(self): 84 | r"""Get the embedding of users and items and combine to an embedding matrix. 85 | Returns: 86 | Tensor of the embedding matrix. 
Shape of [n_items+n_users, embedding_dim] 87 | """ 88 | user_embeddings = self.user_embedding.weight 89 | item_embeddings = self.item_embedding.weight 90 | ego_embeddings = torch.cat([user_embeddings, item_embeddings], dim=0) 91 | return ego_embeddings 92 | 93 | def forward(self): 94 | all_embeddings = self.get_ego_embeddings() 95 | embeddings_list = [all_embeddings] 96 | for layer_idx in range(max(self.n_layers, self.hyper_layers * 2)): 97 | all_embeddings = self.gcn_conv(all_embeddings, self.edge_index, self.edge_weight) 98 | embeddings_list.append(all_embeddings) 99 | 100 | lightgcn_all_embeddings = torch.stack(embeddings_list[:self.n_layers + 1], dim=1) 101 | lightgcn_all_embeddings = torch.mean(lightgcn_all_embeddings, dim=1) 102 | 103 | user_all_embeddings, item_all_embeddings = torch.split(lightgcn_all_embeddings, [self.n_users, self.n_items]) 104 | return user_all_embeddings, item_all_embeddings, embeddings_list 105 | 106 | def ProtoNCE_loss(self, node_embedding, user, item): 107 | user_embeddings_all, item_embeddings_all = torch.split(node_embedding, [self.n_users, self.n_items]) 108 | 109 | user_embeddings = user_embeddings_all[user] # [B, e] 110 | norm_user_embeddings = F.normalize(user_embeddings) 111 | 112 | user2cluster = self.user_2cluster[user] # [B,] 113 | user2centroids = self.user_centroids[user2cluster] # [B, e] 114 | pos_score_user = torch.mul(norm_user_embeddings, user2centroids).sum(dim=1) 115 | pos_score_user = torch.exp(pos_score_user / self.ssl_temp) 116 | ttl_score_user = torch.matmul(norm_user_embeddings, self.user_centroids.transpose(0, 1)) 117 | ttl_score_user = torch.exp(ttl_score_user / self.ssl_temp).sum(dim=1) 118 | 119 | proto_nce_loss_user = -torch.log(pos_score_user / ttl_score_user).sum() 120 | 121 | item_embeddings = item_embeddings_all[item] 122 | norm_item_embeddings = F.normalize(item_embeddings) 123 | 124 | item2cluster = self.item_2cluster[item] # [B, ] 125 | item2centroids = self.item_centroids[item2cluster] # [B, e] 126 | pos_score_item = torch.mul(norm_item_embeddings, item2centroids).sum(dim=1) 127 | pos_score_item = torch.exp(pos_score_item / self.ssl_temp) 128 | ttl_score_item = torch.matmul(norm_item_embeddings, self.item_centroids.transpose(0, 1)) 129 | ttl_score_item = torch.exp(ttl_score_item / self.ssl_temp).sum(dim=1) 130 | proto_nce_loss_item = -torch.log(pos_score_item / ttl_score_item).sum() 131 | 132 | proto_nce_loss = self.proto_reg * (proto_nce_loss_user + proto_nce_loss_item) 133 | return proto_nce_loss 134 | 135 | def ssl_layer_loss(self, current_embedding, previous_embedding, user, item): 136 | current_user_embeddings, current_item_embeddings = torch.split(current_embedding, [self.n_users, self.n_items]) 137 | previous_user_embeddings_all, previous_item_embeddings_all = torch.split(previous_embedding, [self.n_users, self.n_items]) 138 | 139 | current_user_embeddings = current_user_embeddings[user] 140 | previous_user_embeddings = previous_user_embeddings_all[user] 141 | norm_user_emb1 = F.normalize(current_user_embeddings) 142 | norm_user_emb2 = F.normalize(previous_user_embeddings) 143 | norm_all_user_emb = F.normalize(previous_user_embeddings_all) 144 | pos_score_user = torch.mul(norm_user_emb1, norm_user_emb2).sum(dim=1) 145 | ttl_score_user = torch.matmul(norm_user_emb1, norm_all_user_emb.transpose(0, 1)) 146 | pos_score_user = torch.exp(pos_score_user / self.ssl_temp) 147 | ttl_score_user = torch.exp(ttl_score_user / self.ssl_temp).sum(dim=1) 148 | 149 | ssl_loss_user = -torch.log(pos_score_user / ttl_score_user).sum() 
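        # The block below mirrors the user-side InfoNCE term above for items: the
        # higher-layer view (current_embedding) of each positive item is pulled towards
        # its ego view (previous_embedding) and pushed away from the ego views of all
        # other items.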
150 | 151 | current_item_embeddings = current_item_embeddings[item] 152 | previous_item_embeddings = previous_item_embeddings_all[item] 153 | norm_item_emb1 = F.normalize(current_item_embeddings) 154 | norm_item_emb2 = F.normalize(previous_item_embeddings) 155 | norm_all_item_emb = F.normalize(previous_item_embeddings_all) 156 | pos_score_item = torch.mul(norm_item_emb1, norm_item_emb2).sum(dim=1) 157 | ttl_score_item = torch.matmul(norm_item_emb1, norm_all_item_emb.transpose(0, 1)) 158 | pos_score_item = torch.exp(pos_score_item / self.ssl_temp) 159 | ttl_score_item = torch.exp(ttl_score_item / self.ssl_temp).sum(dim=1) 160 | 161 | ssl_loss_item = -torch.log(pos_score_item / ttl_score_item).sum() 162 | 163 | ssl_loss = self.ssl_reg * (ssl_loss_user + self.alpha * ssl_loss_item) 164 | return ssl_loss 165 | 166 | def calculate_loss(self, interaction): 167 | # clear the storage variable when training 168 | if self.restore_user_e is not None or self.restore_item_e is not None: 169 | self.restore_user_e, self.restore_item_e = None, None 170 | 171 | user = interaction[self.USER_ID] 172 | pos_item = interaction[self.ITEM_ID] 173 | neg_item = interaction[self.NEG_ITEM_ID] 174 | 175 | user_all_embeddings, item_all_embeddings, embeddings_list = self.forward() 176 | 177 | center_embedding = embeddings_list[0] 178 | context_embedding = embeddings_list[self.hyper_layers * 2] 179 | 180 | ssl_loss = self.ssl_layer_loss(context_embedding, center_embedding, user, pos_item) 181 | proto_loss = self.ProtoNCE_loss(center_embedding, user, pos_item) 182 | 183 | u_embeddings = user_all_embeddings[user] 184 | pos_embeddings = item_all_embeddings[pos_item] 185 | neg_embeddings = item_all_embeddings[neg_item] 186 | 187 | # calculate BPR Loss 188 | pos_scores = torch.mul(u_embeddings, pos_embeddings).sum(dim=1) 189 | neg_scores = torch.mul(u_embeddings, neg_embeddings).sum(dim=1) 190 | 191 | mf_loss = self.mf_loss(pos_scores, neg_scores) 192 | 193 | u_ego_embeddings = self.user_embedding(user) 194 | pos_ego_embeddings = self.item_embedding(pos_item) 195 | neg_ego_embeddings = self.item_embedding(neg_item) 196 | 197 | reg_loss = self.reg_loss(u_ego_embeddings, pos_ego_embeddings, neg_ego_embeddings) 198 | 199 | return mf_loss + self.reg_weight * reg_loss, ssl_loss, proto_loss 200 | 201 | def predict(self, interaction): 202 | user = interaction[self.USER_ID] 203 | item = interaction[self.ITEM_ID] 204 | 205 | user_all_embeddings, item_all_embeddings, embeddings_list = self.forward() 206 | 207 | u_embeddings = user_all_embeddings[user] 208 | i_embeddings = item_all_embeddings[item] 209 | scores = torch.mul(u_embeddings, i_embeddings).sum(dim=1) 210 | return scores 211 | 212 | def full_sort_predict(self, interaction): 213 | user = interaction[self.USER_ID] 214 | if self.restore_user_e is None or self.restore_item_e is None: 215 | self.restore_user_e, self.restore_item_e, embedding_list = self.forward() 216 | # get user embedding from storage variable 217 | u_embeddings = self.restore_user_e[user] 218 | 219 | # dot with all item embedding to accelerate 220 | scores = torch.matmul(u_embeddings, self.restore_item_e.transpose(0, 1)) 221 | 222 | return scores.view(-1) 223 | -------------------------------------------------------------------------------- /recbole_gnn/model/general_recommender/ngcf.py: -------------------------------------------------------------------------------- 1 | # @Time : 2022/3/8 2 | # @Author : Changxin Tian 3 | # @Email : cx.tian@outlook.com 4 | r""" 5 | NGCF 6 | 
################################################ 7 | Reference: 8 | Xiang Wang et al. "Neural Graph Collaborative Filtering." in SIGIR 2019. 9 | 10 | Reference code: 11 | https://github.com/xiangwang1223/neural_graph_collaborative_filtering 12 | 13 | """ 14 | 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | from torch_geometric.utils import dropout_adj 19 | 20 | from recbole.model.init import xavier_normal_initialization 21 | from recbole.model.loss import BPRLoss, EmbLoss 22 | from recbole.utils import InputType 23 | 24 | from recbole_gnn.model.abstract_recommender import GeneralGraphRecommender 25 | from recbole_gnn.model.layers import BiGNNConv 26 | 27 | 28 | class NGCF(GeneralGraphRecommender): 29 | r"""NGCF is a model that incorporate GNN for recommendation. 30 | We implement the model following the original author with a pairwise training mode. 31 | """ 32 | input_type = InputType.PAIRWISE 33 | 34 | def __init__(self, config, dataset): 35 | super(NGCF, self).__init__(config, dataset) 36 | 37 | # load parameters info 38 | self.embedding_size = config['embedding_size'] 39 | self.hidden_size_list = config['hidden_size_list'] 40 | self.hidden_size_list = [self.embedding_size] + self.hidden_size_list 41 | self.node_dropout = config['node_dropout'] 42 | self.message_dropout = config['message_dropout'] 43 | self.reg_weight = config['reg_weight'] 44 | 45 | # define layers and loss 46 | self.user_embedding = nn.Embedding(self.n_users, self.embedding_size) 47 | self.item_embedding = nn.Embedding(self.n_items, self.embedding_size) 48 | self.GNNlayers = torch.nn.ModuleList() 49 | for input_size, output_size in zip(self.hidden_size_list[:-1], self.hidden_size_list[1:]): 50 | self.GNNlayers.append(BiGNNConv(input_size, output_size)) 51 | self.mf_loss = BPRLoss() 52 | self.reg_loss = EmbLoss() 53 | 54 | # storage variables for full sort evaluation acceleration 55 | self.restore_user_e = None 56 | self.restore_item_e = None 57 | 58 | # parameters initialization 59 | self.apply(xavier_normal_initialization) 60 | self.other_parameter_name = ['restore_user_e', 'restore_item_e'] 61 | 62 | def get_ego_embeddings(self): 63 | r"""Get the embedding of users and items and combine to an embedding matrix. 64 | 65 | Returns: 66 | Tensor of the embedding matrix. 
Shape of (n_items+n_users, embedding_dim) 67 | """ 68 | user_embeddings = self.user_embedding.weight 69 | item_embeddings = self.item_embedding.weight 70 | ego_embeddings = torch.cat([user_embeddings, item_embeddings], dim=0) 71 | return ego_embeddings 72 | 73 | def forward(self): 74 | if self.node_dropout == 0: 75 | edge_index, edge_weight = self.edge_index, self.edge_weight 76 | else: 77 | edge_index, edge_weight = self.edge_index, self.edge_weight 78 | if self.use_sparse: 79 | row, col, edge_weight = edge_index.t().coo() 80 | edge_index = torch.stack([row, col], 0) 81 | edge_index, edge_weight = dropout_adj(edge_index=edge_index, edge_attr=edge_weight, 82 | p=self.node_dropout, training=self.training) 83 | from torch_sparse import SparseTensor 84 | edge_index = SparseTensor(row=edge_index[0], col=edge_index[1], value=edge_weight, 85 | sparse_sizes=(self.n_users + self.n_items, self.n_users + self.n_items)) 86 | edge_index = edge_index.t() 87 | edge_weight = None 88 | else: 89 | edge_index, edge_weight = dropout_adj(edge_index=edge_index, edge_attr=edge_weight, 90 | p=self.node_dropout, training=self.training) 91 | 92 | all_embeddings = self.get_ego_embeddings() 93 | embeddings_list = [all_embeddings] 94 | for gnn in self.GNNlayers: 95 | all_embeddings = gnn(all_embeddings, edge_index, edge_weight) 96 | all_embeddings = nn.LeakyReLU(negative_slope=0.2)(all_embeddings) 97 | all_embeddings = nn.Dropout(self.message_dropout)(all_embeddings) 98 | all_embeddings = F.normalize(all_embeddings, p=2, dim=1) 99 | embeddings_list += [all_embeddings] # storage output embedding of each layer 100 | ngcf_all_embeddings = torch.cat(embeddings_list, dim=1) 101 | 102 | user_all_embeddings, item_all_embeddings = torch.split(ngcf_all_embeddings, [self.n_users, self.n_items]) 103 | 104 | return user_all_embeddings, item_all_embeddings 105 | 106 | def calculate_loss(self, interaction): 107 | # clear the storage variable when training 108 | if self.restore_user_e is not None or self.restore_item_e is not None: 109 | self.restore_user_e, self.restore_item_e = None, None 110 | 111 | user = interaction[self.USER_ID] 112 | pos_item = interaction[self.ITEM_ID] 113 | neg_item = interaction[self.NEG_ITEM_ID] 114 | 115 | user_all_embeddings, item_all_embeddings = self.forward() 116 | u_embeddings = user_all_embeddings[user] 117 | pos_embeddings = item_all_embeddings[pos_item] 118 | neg_embeddings = item_all_embeddings[neg_item] 119 | 120 | pos_scores = torch.mul(u_embeddings, pos_embeddings).sum(dim=1) 121 | neg_scores = torch.mul(u_embeddings, neg_embeddings).sum(dim=1) 122 | mf_loss = self.mf_loss(pos_scores, neg_scores) # calculate BPR Loss 123 | 124 | reg_loss = self.reg_loss(u_embeddings, pos_embeddings, neg_embeddings) # L2 regularization of embeddings 125 | 126 | return mf_loss + self.reg_weight * reg_loss 127 | 128 | def predict(self, interaction): 129 | user = interaction[self.USER_ID] 130 | item = interaction[self.ITEM_ID] 131 | 132 | user_all_embeddings, item_all_embeddings = self.forward() 133 | 134 | u_embeddings = user_all_embeddings[user] 135 | i_embeddings = item_all_embeddings[item] 136 | scores = torch.mul(u_embeddings, i_embeddings).sum(dim=1) 137 | return scores 138 | 139 | def full_sort_predict(self, interaction): 140 | user = interaction[self.USER_ID] 141 | if self.restore_user_e is None or self.restore_item_e is None: 142 | self.restore_user_e, self.restore_item_e = self.forward() 143 | # get user embedding from storage variable 144 | u_embeddings = self.restore_user_e[user] 145 | 146 | # dot 
with all item embedding to accelerate 147 | scores = torch.matmul(u_embeddings, self.restore_item_e.transpose(0, 1)) 148 | 149 | return scores.view(-1) 150 | -------------------------------------------------------------------------------- /recbole_gnn/model/general_recommender/sgl.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2022/3/8 3 | # @Author : Changxin Tian 4 | # @Email : cx.tian@outlook.com 5 | r""" 6 | SGL 7 | ################################################ 8 | Reference: 9 | Jiancan Wu et al. "SGL: Self-supervised Graph Learning for Recommendation" in SIGIR 2021. 10 | 11 | Reference code: 12 | https://github.com/wujcan/SGL 13 | """ 14 | 15 | import numpy as np 16 | import torch 17 | import torch.nn.functional as F 18 | from torch_geometric.utils import degree 19 | from torch_geometric.nn.conv.gcn_conv import gcn_norm 20 | 21 | from recbole.model.init import xavier_uniform_initialization 22 | from recbole.model.loss import EmbLoss 23 | from recbole.utils import InputType 24 | 25 | from recbole_gnn.model.abstract_recommender import GeneralGraphRecommender 26 | from recbole_gnn.model.layers import LightGCNConv 27 | 28 | 29 | class SGL(GeneralGraphRecommender): 30 | r"""SGL is a GCN-based recommender model. 31 | 32 | SGL supplements the classical supervised task of recommendation with an auxiliary 33 | self supervised task, which reinforces node representation learning via self- 34 | discrimination.Specifically,SGL generates multiple views of a node, maximizing the 35 | agreement between different views of the same node compared to that of other nodes. 36 | SGL devises three operators to generate the views — node dropout, edge dropout, and 37 | random walk — that change the graph structure in different manners. 38 | 39 | We implement the model following the original author with a pairwise training mode. 40 | """ 41 | input_type = InputType.PAIRWISE 42 | 43 | def __init__(self, config, dataset): 44 | super(SGL, self).__init__(config, dataset) 45 | 46 | # load parameters info 47 | self.latent_dim = config["embedding_size"] 48 | self.n_layers = int(config["n_layers"]) 49 | self.aug_type = config["type"] 50 | self.drop_ratio = config["drop_ratio"] 51 | self.ssl_tau = config["ssl_tau"] 52 | self.reg_weight = config["reg_weight"] 53 | self.ssl_weight = config["ssl_weight"] 54 | 55 | self._user = dataset.inter_feat[dataset.uid_field] 56 | self._item = dataset.inter_feat[dataset.iid_field] 57 | self.dataset = dataset 58 | 59 | # define layers and loss 60 | self.user_embedding = torch.nn.Embedding(self.n_users, self.latent_dim) 61 | self.item_embedding = torch.nn.Embedding(self.n_items, self.latent_dim) 62 | self.gcn_conv = LightGCNConv(dim=self.latent_dim) 63 | self.reg_loss = EmbLoss() 64 | 65 | # storage variables for full sort evaluation acceleration 66 | self.restore_user_e = None 67 | self.restore_item_e = None 68 | 69 | # parameters initialization 70 | self.apply(xavier_uniform_initialization) 71 | self.other_parameter_name = ['restore_user_e', 'restore_item_e'] 72 | 73 | def train(self, mode: bool = True): 74 | r"""Override train method of base class. The subgraph is reconstructed each time it is called. 75 | 76 | """ 77 | T = super().train(mode=mode) 78 | if mode: 79 | self.graph_construction() 80 | return T 81 | 82 | def graph_construction(self): 83 | r"""Devise three operators to generate the views — node dropout, edge dropout, and random walk of a node. 
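        Node dropout (ND) and edge dropout (ED) share a single perturbed graph across
        all layers, while random walk (RW) samples an independent perturbed graph for
        every layer.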
84 | 85 | """ 86 | if self.aug_type == "ND" or self.aug_type == "ED": 87 | self.sub_graph1 = [self.random_graph_augment()] * self.n_layers 88 | self.sub_graph2 = [self.random_graph_augment()] * self.n_layers 89 | elif self.aug_type == "RW": 90 | self.sub_graph1 = [self.random_graph_augment() for _ in range(self.n_layers)] 91 | self.sub_graph2 = [self.random_graph_augment() for _ in range(self.n_layers)] 92 | 93 | def random_graph_augment(self): 94 | def rand_sample(high, size=None, replace=True): 95 | return np.random.choice(np.arange(high), size=size, replace=replace) 96 | 97 | if self.aug_type == "ND": 98 | drop_user = rand_sample(self.n_users, size=int(self.n_users * self.drop_ratio), replace=False) 99 | drop_item = rand_sample(self.n_items, size=int(self.n_items * self.drop_ratio), replace=False) 100 | 101 | mask = np.isin(self._user.numpy(), drop_user) 102 | mask |= np.isin(self._item.numpy(), drop_item) 103 | keep = np.where(~mask) 104 | 105 | row = self._user[keep] 106 | col = self._item[keep] + self.n_users 107 | 108 | elif self.aug_type == "ED" or self.aug_type == "RW": 109 | keep = rand_sample(len(self._user), size=int(len(self._user) * (1 - self.drop_ratio)), replace=False) 110 | row = self._user[keep] 111 | col = self._item[keep] + self.n_users 112 | 113 | edge_index1 = torch.stack([row, col]) 114 | edge_index2 = torch.stack([col, row]) 115 | edge_index = torch.cat([edge_index1, edge_index2], dim=1) 116 | edge_weight = torch.ones(edge_index.size(1)) 117 | num_nodes = self.n_users + self.n_items 118 | 119 | if self.use_sparse: 120 | adj_t = self.dataset.edge_index_to_adj_t(edge_index, edge_weight, num_nodes, num_nodes) 121 | adj_t = gcn_norm(adj_t, None, num_nodes, add_self_loops=False) 122 | return adj_t.to(self.device), None 123 | 124 | edge_index, edge_weight = gcn_norm(edge_index, edge_weight, num_nodes, add_self_loops=False) 125 | 126 | return edge_index.to(self.device), edge_weight.to(self.device) 127 | 128 | def forward(self, graph=None): 129 | all_embeddings = torch.cat([self.user_embedding.weight, self.item_embedding.weight]) 130 | embeddings_list = [all_embeddings] 131 | 132 | if graph is None: # for the original graph 133 | for _ in range(self.n_layers): 134 | all_embeddings = self.gcn_conv(all_embeddings, self.edge_index, self.edge_weight) 135 | embeddings_list.append(all_embeddings) 136 | else: # for the augmented graph 137 | for graph_edge_index, graph_edge_weight in graph: 138 | all_embeddings = self.gcn_conv(all_embeddings, graph_edge_index, graph_edge_weight) 139 | embeddings_list.append(all_embeddings) 140 | 141 | embeddings_list = torch.stack(embeddings_list, dim=1) 142 | embeddings_list = torch.mean(embeddings_list, dim=1, keepdim=False) 143 | user_all_embeddings, item_all_embeddings = torch.split(embeddings_list, [self.n_users, self.n_items], dim=0) 144 | 145 | return user_all_embeddings, item_all_embeddings 146 | 147 | def calc_bpr_loss(self, user_emd, item_emd, user_list, pos_item_list, neg_item_list): 148 | r"""Calculate the the pairwise Bayesian Personalized Ranking (BPR) loss and parameter regularization loss. 149 | 150 | Args: 151 | user_emd (torch.Tensor): Ego embedding of all users after forwarding. 152 | item_emd (torch.Tensor): Ego embedding of all items after forwarding. 153 | user_list (torch.Tensor): List of the user. 154 | pos_item_list (torch.Tensor): List of positive examples. 155 | neg_item_list (torch.Tensor): List of negative examples. 156 | 157 | Returns: 158 | torch.Tensor: Loss of BPR tasks and parameter regularization. 
159 | """ 160 | u_e = user_emd[user_list] 161 | pi_e = item_emd[pos_item_list] 162 | ni_e = item_emd[neg_item_list] 163 | p_scores = torch.mul(u_e, pi_e).sum(dim=1) 164 | n_scores = torch.mul(u_e, ni_e).sum(dim=1) 165 | 166 | l1 = torch.sum(-F.logsigmoid(p_scores - n_scores)) 167 | 168 | u_e_p = self.user_embedding(user_list) 169 | pi_e_p = self.item_embedding(pos_item_list) 170 | ni_e_p = self.item_embedding(neg_item_list) 171 | 172 | l2 = self.reg_loss(u_e_p, pi_e_p, ni_e_p) 173 | 174 | return l1 + l2 * self.reg_weight 175 | 176 | def calc_ssl_loss(self, user_list, pos_item_list, user_sub1, user_sub2, item_sub1, item_sub2): 177 | r"""Calculate the loss of self-supervised tasks. 178 | 179 | Args: 180 | user_list (torch.Tensor): List of the user. 181 | pos_item_list (torch.Tensor): List of positive examples. 182 | user_sub1 (torch.Tensor): Ego embedding of all users in the first subgraph after forwarding. 183 | user_sub2 (torch.Tensor): Ego embedding of all users in the second subgraph after forwarding. 184 | item_sub1 (torch.Tensor): Ego embedding of all items in the first subgraph after forwarding. 185 | item_sub2 (torch.Tensor): Ego embedding of all items in the second subgraph after forwarding. 186 | 187 | Returns: 188 | torch.Tensor: Loss of self-supervised tasks. 189 | """ 190 | 191 | u_emd1 = F.normalize(user_sub1[user_list], dim=1) 192 | u_emd2 = F.normalize(user_sub2[user_list], dim=1) 193 | all_user2 = F.normalize(user_sub2, dim=1) 194 | v1 = torch.sum(u_emd1 * u_emd2, dim=1) 195 | v2 = u_emd1.matmul(all_user2.T) 196 | v1 = torch.exp(v1 / self.ssl_tau) 197 | v2 = torch.sum(torch.exp(v2 / self.ssl_tau), dim=1) 198 | ssl_user = -torch.sum(torch.log(v1 / v2)) 199 | 200 | i_emd1 = F.normalize(item_sub1[pos_item_list], dim=1) 201 | i_emd2 = F.normalize(item_sub2[pos_item_list], dim=1) 202 | all_item2 = F.normalize(item_sub2, dim=1) 203 | v3 = torch.sum(i_emd1 * i_emd2, dim=1) 204 | v4 = i_emd1.matmul(all_item2.T) 205 | v3 = torch.exp(v3 / self.ssl_tau) 206 | v4 = torch.sum(torch.exp(v4 / self.ssl_tau), dim=1) 207 | ssl_item = -torch.sum(torch.log(v3 / v4)) 208 | 209 | return (ssl_item + ssl_user) * self.ssl_weight 210 | 211 | def calculate_loss(self, interaction): 212 | if self.restore_user_e is not None or self.restore_item_e is not None: 213 | self.restore_user_e, self.restore_item_e = None, None 214 | 215 | user_list = interaction[self.USER_ID] 216 | pos_item_list = interaction[self.ITEM_ID] 217 | neg_item_list = interaction[self.NEG_ITEM_ID] 218 | 219 | user_emd, item_emd = self.forward() 220 | user_sub1, item_sub1 = self.forward(self.sub_graph1) 221 | user_sub2, item_sub2 = self.forward(self.sub_graph2) 222 | 223 | total_loss = self.calc_bpr_loss(user_emd, item_emd, user_list, pos_item_list, neg_item_list) + \ 224 | self.calc_ssl_loss(user_list, pos_item_list, user_sub1, user_sub2, item_sub1, item_sub2) 225 | return total_loss 226 | 227 | def predict(self, interaction): 228 | if self.restore_user_e is None or self.restore_item_e is None: 229 | self.restore_user_e, self.restore_item_e = self.forward() 230 | 231 | user = self.restore_user_e[interaction[self.USER_ID]] 232 | item = self.restore_item_e[interaction[self.ITEM_ID]] 233 | return torch.sum(user * item, dim=1) 234 | 235 | def full_sort_predict(self, interaction): 236 | if self.restore_user_e is None or self.restore_item_e is None: 237 | self.restore_user_e, self.restore_item_e = self.forward() 238 | 239 | user = self.restore_user_e[interaction[self.USER_ID]] 240 | return user.matmul(self.restore_item_e.T) 241 | 
-------------------------------------------------------------------------------- /recbole_gnn/model/general_recommender/simgcl.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | r""" 3 | SimGCL 4 | ################################################ 5 | Reference: 6 | Junliang Yu, Hongzhi Yin, Xin Xia, Tong Chen, Lizhen Cui, Quoc Viet Hung Nguyen. "Are Graph Augmentations Necessary? Simple Graph Contrastive Learning for Recommendation." in SIGIR 2022. 7 | """ 8 | 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | 13 | from recbole_gnn.model.general_recommender import LightGCN 14 | 15 | 16 | class SimGCL(LightGCN): 17 | def __init__(self, config, dataset): 18 | super(SimGCL, self).__init__(config, dataset) 19 | 20 | self.cl_rate = config['lambda'] 21 | self.eps = config['eps'] 22 | self.temperature = config['temperature'] 23 | 24 | def forward(self, perturbed=False): 25 | all_embs = self.get_ego_embeddings() 26 | embeddings_list = [] 27 | 28 | for layer_idx in range(self.n_layers): 29 | all_embs = self.gcn_conv(all_embs, self.edge_index, self.edge_weight) 30 | if perturbed: 31 | random_noise = torch.rand_like(all_embs, device=all_embs.device) 32 | all_embs = all_embs + torch.sign(all_embs) * F.normalize(random_noise, dim=-1) * self.eps 33 | embeddings_list.append(all_embs) 34 | lightgcn_all_embeddings = torch.stack(embeddings_list, dim=1) 35 | lightgcn_all_embeddings = torch.mean(lightgcn_all_embeddings, dim=1) 36 | 37 | user_all_embeddings, item_all_embeddings = torch.split(lightgcn_all_embeddings, [self.n_users, self.n_items]) 38 | return user_all_embeddings, item_all_embeddings 39 | 40 | def calculate_cl_loss(self, x1, x2): 41 | x1, x2 = F.normalize(x1, dim=-1), F.normalize(x2, dim=-1) 42 | pos_score = (x1 * x2).sum(dim=-1) 43 | pos_score = torch.exp(pos_score / self.temperature) 44 | ttl_score = torch.matmul(x1, x2.transpose(0, 1)) 45 | ttl_score = torch.exp(ttl_score / self.temperature).sum(dim=1) 46 | return -torch.log(pos_score / ttl_score).sum() 47 | 48 | def calculate_loss(self, interaction): 49 | loss = super().calculate_loss(interaction) 50 | 51 | user = torch.unique(interaction[self.USER_ID]) 52 | pos_item = torch.unique(interaction[self.ITEM_ID]) 53 | 54 | perturbed_user_embs_1, perturbed_item_embs_1 = self.forward(perturbed=True) 55 | perturbed_user_embs_2, perturbed_item_embs_2 = self.forward(perturbed=True) 56 | 57 | user_cl_loss = self.calculate_cl_loss(perturbed_user_embs_1[user], perturbed_user_embs_2[user]) 58 | item_cl_loss = self.calculate_cl_loss(perturbed_item_embs_1[pos_item], perturbed_item_embs_2[pos_item]) 59 | 60 | return loss + self.cl_rate * (user_cl_loss + item_cl_loss) 61 | -------------------------------------------------------------------------------- /recbole_gnn/model/general_recommender/ssl4rec.py: -------------------------------------------------------------------------------- 1 | r""" 2 | SSL4REC 3 | ################################################ 4 | Reference: 5 | Tiansheng Yao et al. "Self-supervised Learning for Large-scale Item Recommendations." in CIKM 2021. 
6 | 7 | Reference code: 8 | https://github.com/Coder-Yu/SELFRec/model/graph/SSL4Rec.py 9 | """ 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | from recbole.model.loss import EmbLoss 16 | from recbole.utils import InputType 17 | 18 | from recbole.model.init import xavier_uniform_initialization 19 | from recbole_gnn.model.abstract_recommender import GeneralGraphRecommender 20 | 21 | 22 | class SSL4REC(GeneralGraphRecommender): 23 | input_type = InputType.PAIRWISE 24 | 25 | def __init__(self, config, dataset): 26 | super(SSL4REC, self).__init__(config, dataset) 27 | 28 | # load parameters info 29 | self.tau = config["tau"] 30 | self.reg_weight = config["reg_weight"] 31 | self.cl_rate = config["ssl_weight"] 32 | self.require_pow = config["require_pow"] 33 | 34 | self.reg_loss = EmbLoss() 35 | 36 | self.encoder = DNN_Encoder(config, dataset) 37 | 38 | # storage variables for full sort evaluation acceleration 39 | self.restore_user_e = None 40 | self.restore_item_e = None 41 | 42 | # parameters initialization 43 | self.apply(xavier_uniform_initialization) 44 | self.other_parameter_name = ['restore_user_e', 'restore_item_e'] 45 | 46 | def forward(self, user, item): 47 | user_e, item_e = self.encoder(user, item) 48 | return user_e, item_e 49 | 50 | def calculate_batch_softmax_loss(self, user_emb, item_emb, temperature): 51 | user_emb, item_emb = F.normalize(user_emb, dim=1), F.normalize(item_emb, dim=1) 52 | pos_score = (user_emb * item_emb).sum(dim=-1) 53 | pos_score = torch.exp(pos_score / temperature) 54 | ttl_score = torch.matmul(user_emb, item_emb.transpose(0, 1)) 55 | ttl_score = torch.exp(ttl_score / temperature).sum(dim=1) 56 | loss = -torch.log(pos_score / ttl_score + 10e-6) 57 | return torch.mean(loss) 58 | 59 | def calculate_loss(self, interaction): 60 | # clear the storage variable when training 61 | if self.restore_user_e is not None or self.restore_item_e is not None: 62 | self.restore_user_e, self.restore_item_e = None, None 63 | 64 | user = interaction[self.USER_ID] 65 | pos_item = interaction[self.ITEM_ID] 66 | 67 | user_embeddings, item_embeddings = self.forward(user, pos_item) 68 | 69 | rec_loss = self.calculate_batch_softmax_loss(user_embeddings, item_embeddings, self.tau) 70 | cl_loss = self.encoder.calculate_cl_loss(pos_item) 71 | reg_loss = self.reg_loss(user_embeddings, item_embeddings, require_pow=self.require_pow) 72 | 73 | loss = rec_loss + self.cl_rate * cl_loss + self.reg_weight * reg_loss 74 | 75 | return loss 76 | 77 | def predict(self, interaction): 78 | user = interaction[self.USER_ID] 79 | item = interaction[self.ITEM_ID] 80 | 81 | user_embeddings, item_embeddings = self.forward(user, item) 82 | 83 | u_embeddings = user_embeddings[user] 84 | i_embeddings = item_embeddings[item] 85 | scores = torch.mul(u_embeddings, i_embeddings).sum(dim=1) 86 | return scores 87 | 88 | def full_sort_predict(self, interaction): 89 | user = interaction[self.USER_ID] 90 | if self.restore_user_e is None or self.restore_item_e is None: 91 | self.restore_user_e, self.restore_item_e = self.forward(torch.arange( 92 | self.n_users, device=self.device), torch.arange(self.n_items, device=self.device)) 93 | # get user embedding from storage variable 94 | u_embeddings = self.restore_user_e[user] 95 | 96 | # dot with all item embedding to accelerate 97 | scores = torch.matmul(u_embeddings, self.restore_item_e.transpose(0, 1)) 98 | 99 | return scores.view(-1) 100 | 101 | 102 | class DNN_Encoder(nn.Module): 103 | def __init__(self, config, dataset): 
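        # Two-tower (query/item) MLP encoder used by SSL4REC: each tower maps
        # embedding_size -> 1024 -> 128, and item_encoding() produces two
        # dropout-perturbed views of the same items for the contrastive task.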
104 | super(DNN_Encoder, self).__init__() 105 | 106 | self.emb_size = config["embedding_size"] 107 | self.drop_ratio = config["drop_ratio"] 108 | self.tau = config["tau"] 109 | 110 | self.USER_ID = config["USER_ID_FIELD"] 111 | self.ITEM_ID = config["ITEM_ID_FIELD"] 112 | self.n_users = dataset.num(self.USER_ID) 113 | self.n_items = dataset.num(self.ITEM_ID) 114 | 115 | self.user_tower = nn.Sequential( 116 | nn.Linear(self.emb_size, 1024), 117 | nn.ReLU(True), 118 | nn.Linear(1024, 128), 119 | nn.Tanh() 120 | ) 121 | self.item_tower = nn.Sequential( 122 | nn.Linear(self.emb_size, 1024), 123 | nn.ReLU(True), 124 | nn.Linear(1024, 128), 125 | nn.Tanh() 126 | ) 127 | self.dropout = nn.Dropout(self.drop_ratio) 128 | 129 | self.initial_user_emb = nn.Embedding(self.n_users, self.emb_size) 130 | self.initial_item_emb = nn.Embedding(self.n_items, self.emb_size) 131 | self.reset_parameters() 132 | 133 | def reset_parameters(self): 134 | nn.init.xavier_uniform_(self.initial_user_emb.weight) 135 | nn.init.xavier_uniform_(self.initial_item_emb.weight) 136 | 137 | def forward(self, q, x): 138 | q_emb = self.initial_user_emb(q) 139 | i_emb = self.initial_item_emb(x) 140 | 141 | q_emb = self.user_tower(q_emb) 142 | i_emb = self.item_tower(i_emb) 143 | 144 | return q_emb, i_emb 145 | 146 | def item_encoding(self, x): 147 | i_emb = self.initial_item_emb(x) 148 | i1_emb = self.dropout(i_emb) 149 | i2_emb = self.dropout(i_emb) 150 | 151 | i1_emb = self.item_tower(i1_emb) 152 | i2_emb = self.item_tower(i2_emb) 153 | 154 | return i1_emb, i2_emb 155 | 156 | def calculate_cl_loss(self, idx): 157 | x1, x2 = self.item_encoding(idx) 158 | x1, x2 = F.normalize(x1, dim=-1), F.normalize(x2, dim=-1) 159 | pos_score = (x1 * x2).sum(dim=-1) 160 | pos_score = torch.exp(pos_score / self.tau) 161 | ttl_score = torch.matmul(x1, x2.transpose(0, 1)) 162 | ttl_score = torch.exp(ttl_score / self.tau).sum(dim=1) 163 | return -torch.log(pos_score / ttl_score).mean() 164 | -------------------------------------------------------------------------------- /recbole_gnn/model/general_recommender/xsimgcl.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | r""" 3 | XSimGCL 4 | ################################################ 5 | Reference: 6 | Junliang Yu, Xin Xia, Tong Chen, Lizhen Cui, Nguyen Quoc Viet Hung, Hongzhi Yin. "XSimGCL: Towards Extremely Simple Graph Contrastive Learning for Recommendation" in TKDE 2023. 
7 | 8 | Reference code: 9 | https://github.com/Coder-Yu/SELFRec/blob/main/model/graph/XSimGCL.py 10 | """ 11 | 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | 16 | from recbole_gnn.model.general_recommender import LightGCN 17 | 18 | 19 | class XSimGCL(LightGCN): 20 | def __init__(self, config, dataset): 21 | super(XSimGCL, self).__init__(config, dataset) 22 | 23 | self.cl_rate = config['lambda'] 24 | self.eps = config['eps'] 25 | self.temperature = config['temperature'] 26 | self.layer_cl = config['layer_cl'] 27 | 28 | def forward(self, perturbed=False): 29 | all_embs = self.get_ego_embeddings() 30 | all_embs_cl = all_embs 31 | embeddings_list = [] 32 | 33 | for layer_idx in range(self.n_layers): 34 | all_embs = self.gcn_conv(all_embs, self.edge_index, self.edge_weight) 35 | if perturbed: 36 | random_noise = torch.rand_like(all_embs, device=all_embs.device) 37 | all_embs = all_embs + torch.sign(all_embs) * F.normalize(random_noise, dim=-1) * self.eps 38 | embeddings_list.append(all_embs) 39 | if layer_idx == self.layer_cl - 1: 40 | all_embs_cl = all_embs 41 | lightgcn_all_embeddings = torch.stack(embeddings_list, dim=1) 42 | lightgcn_all_embeddings = torch.mean(lightgcn_all_embeddings, dim=1) 43 | 44 | user_all_embeddings, item_all_embeddings = torch.split(lightgcn_all_embeddings, [self.n_users, self.n_items]) 45 | user_all_embeddings_cl, item_all_embeddings_cl = torch.split(all_embs_cl, [self.n_users, self.n_items]) 46 | if perturbed: 47 | return user_all_embeddings, item_all_embeddings, user_all_embeddings_cl, item_all_embeddings_cl 48 | return user_all_embeddings, item_all_embeddings 49 | 50 | def calculate_cl_loss(self, x1, x2): 51 | x1, x2 = F.normalize(x1, dim=-1), F.normalize(x2, dim=-1) 52 | pos_score = (x1 * x2).sum(dim=-1) 53 | pos_score = torch.exp(pos_score / self.temperature) 54 | ttl_score = torch.matmul(x1, x2.transpose(0, 1)) 55 | ttl_score = torch.exp(ttl_score / self.temperature).sum(dim=1) 56 | return -torch.log(pos_score / ttl_score).mean() 57 | 58 | def calculate_loss(self, interaction): 59 | # clear the storage variable when training 60 | if self.restore_user_e is not None or self.restore_item_e is not None: 61 | self.restore_user_e, self.restore_item_e = None, None 62 | 63 | user = interaction[self.USER_ID] 64 | pos_item = interaction[self.ITEM_ID] 65 | neg_item = interaction[self.NEG_ITEM_ID] 66 | 67 | user_all_embeddings, item_all_embeddings, user_all_embeddings_cl, item_all_embeddings_cl = self.forward(perturbed=True) 68 | u_embeddings = user_all_embeddings[user] 69 | pos_embeddings = item_all_embeddings[pos_item] 70 | neg_embeddings = item_all_embeddings[neg_item] 71 | 72 | # calculate BPR Loss 73 | pos_scores = torch.mul(u_embeddings, pos_embeddings).sum(dim=1) 74 | neg_scores = torch.mul(u_embeddings, neg_embeddings).sum(dim=1) 75 | mf_loss = self.mf_loss(pos_scores, neg_scores) 76 | 77 | # calculate regularization Loss 78 | u_ego_embeddings = self.user_embedding(user) 79 | pos_ego_embeddings = self.item_embedding(pos_item) 80 | neg_ego_embeddings = self.item_embedding(neg_item) 81 | reg_loss = self.reg_loss(u_ego_embeddings, pos_ego_embeddings, neg_ego_embeddings, require_pow=self.require_pow) 82 | 83 | user = torch.unique(interaction[self.USER_ID]) 84 | pos_item = torch.unique(interaction[self.ITEM_ID]) 85 | 86 | # calculate CL Loss 87 | user_cl_loss = self.calculate_cl_loss(user_all_embeddings[user], user_all_embeddings_cl[user]) 88 | item_cl_loss = self.calculate_cl_loss(item_all_embeddings[pos_item], item_all_embeddings_cl[pos_item]) 89 | 
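        # The three loss terms are returned as a tuple; RecBole's trainer sums the
        # members of a returned tuple for back-propagation and logs each term separately.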
90 | return mf_loss, self.reg_weight * reg_loss, self.cl_rate * (user_cl_loss + item_cl_loss) 91 | -------------------------------------------------------------------------------- /recbole_gnn/model/layers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from torch_geometric.nn import MessagePassing 5 | from torch_sparse import matmul 6 | 7 | 8 | class LightGCNConv(MessagePassing): 9 | def __init__(self, dim): 10 | super(LightGCNConv, self).__init__(aggr='add') 11 | self.dim = dim 12 | 13 | def forward(self, x, edge_index, edge_weight): 14 | return self.propagate(edge_index, x=x, edge_weight=edge_weight) 15 | 16 | def message(self, x_j, edge_weight): 17 | return edge_weight.view(-1, 1) * x_j 18 | 19 | def message_and_aggregate(self, adj_t, x): 20 | return matmul(adj_t, x, reduce=self.aggr) 21 | 22 | def __repr__(self): 23 | return '{}({})'.format(self.__class__.__name__, self.dim) 24 | 25 | 26 | class BipartiteGCNConv(MessagePassing): 27 | def __init__(self, dim): 28 | super(BipartiteGCNConv, self).__init__(aggr='add') 29 | self.dim = dim 30 | 31 | def forward(self, x, edge_index, edge_weight, size): 32 | return self.propagate(edge_index, x=x, edge_weight=edge_weight, size=size) 33 | 34 | def message(self, x_j, edge_weight): 35 | return edge_weight.view(-1, 1) * x_j 36 | 37 | def __repr__(self): 38 | return '{}({})'.format(self.__class__.__name__, self.dim) 39 | 40 | 41 | class BiGNNConv(MessagePassing): 42 | r"""Propagate a layer of Bi-interaction GNN 43 | 44 | .. math:: 45 | output = (L+I)EW_1 + LE \otimes EW_2 46 | """ 47 | 48 | def __init__(self, in_channels, out_channels): 49 | super().__init__(aggr='add') 50 | self.in_channels, self.out_channels = in_channels, out_channels 51 | self.lin1 = torch.nn.Linear(in_features=in_channels, out_features=out_channels) 52 | self.lin2 = torch.nn.Linear(in_features=in_channels, out_features=out_channels) 53 | 54 | def forward(self, x, edge_index, edge_weight): 55 | x_prop = self.propagate(edge_index, x=x, edge_weight=edge_weight) 56 | x_trans = self.lin1(x_prop + x) 57 | x_inter = self.lin2(torch.mul(x_prop, x)) 58 | return x_trans + x_inter 59 | 60 | def message(self, x_j, edge_weight): 61 | return edge_weight.view(-1, 1) * x_j 62 | 63 | def message_and_aggregate(self, adj_t, x): 64 | return matmul(adj_t, x, reduce=self.aggr) 65 | 66 | def __repr__(self): 67 | return '{}({},{})'.format(self.__class__.__name__, self.in_channels, self.out_channels) 68 | 69 | 70 | class SRGNNConv(MessagePassing): 71 | def __init__(self, dim): 72 | # mean aggregation to incorporate weight naturally 73 | super(SRGNNConv, self).__init__(aggr='mean') 74 | 75 | self.lin = torch.nn.Linear(dim, dim) 76 | 77 | def forward(self, x, edge_index): 78 | x = self.lin(x) 79 | return self.propagate(edge_index, x=x) 80 | 81 | 82 | class SRGNNCell(nn.Module): 83 | def __init__(self, dim): 84 | super(SRGNNCell, self).__init__() 85 | 86 | self.dim = dim 87 | self.incomming_conv = SRGNNConv(dim) 88 | self.outcomming_conv = SRGNNConv(dim) 89 | 90 | self.lin_ih = nn.Linear(2 * dim, 3 * dim) 91 | self.lin_hh = nn.Linear(dim, 3 * dim) 92 | 93 | self._reset_parameters() 94 | 95 | def forward(self, hidden, edge_index): 96 | input_in = self.incomming_conv(hidden, edge_index) 97 | reversed_edge_index = torch.flip(edge_index, dims=[0]) 98 | input_out = self.outcomming_conv(hidden, reversed_edge_index) 99 | inputs = torch.cat([input_in, input_out], dim=-1) 100 | 101 | gi = self.lin_ih(inputs) 
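        # gi and gh are split into three chunks below to form the reset gate, the
        # update gate (named input_gate here) and the candidate state of a GRU-style
        # cell over the aggregated incoming/outgoing messages.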
102 | gh = self.lin_hh(hidden) 103 | i_r, i_i, i_n = gi.chunk(3, -1) 104 | h_r, h_i, h_n = gh.chunk(3, -1) 105 | reset_gate = torch.sigmoid(i_r + h_r) 106 | input_gate = torch.sigmoid(i_i + h_i) 107 | new_gate = torch.tanh(i_n + reset_gate * h_n) 108 | hy = (1 - input_gate) * hidden + input_gate * new_gate 109 | return hy 110 | 111 | def _reset_parameters(self): 112 | stdv = 1.0 / np.sqrt(self.dim) 113 | for weight in self.parameters(): 114 | weight.data.uniform_(-stdv, stdv) 115 | -------------------------------------------------------------------------------- /recbole_gnn/model/sequential_recommender/__init__.py: -------------------------------------------------------------------------------- 1 | from recbole_gnn.model.sequential_recommender.gcegnn import GCEGNN 2 | from recbole_gnn.model.sequential_recommender.gcsan import GCSAN 3 | from recbole_gnn.model.sequential_recommender.lessr import LESSR 4 | from recbole_gnn.model.sequential_recommender.niser import NISER 5 | from recbole_gnn.model.sequential_recommender.sgnnhn import SGNNHN 6 | from recbole_gnn.model.sequential_recommender.srgnn import SRGNN 7 | from recbole_gnn.model.sequential_recommender.tagnn import TAGNN 8 | -------------------------------------------------------------------------------- /recbole_gnn/model/sequential_recommender/gcsan.py: -------------------------------------------------------------------------------- 1 | # @Time : 2022/3/7 2 | # @Author : Yupeng Hou 3 | # @Email : houyupeng@ruc.edu.cn 4 | 5 | r""" 6 | GCSAN 7 | ################################################ 8 | 9 | Reference: 10 | Chengfeng Xu et al. "Graph Contextualized Self-Attention Network for Session-based Recommendation." in IJCAI 2019. 11 | 12 | """ 13 | 14 | import torch 15 | from torch import nn 16 | from recbole.model.layers import TransformerEncoder 17 | from recbole.model.loss import EmbLoss, BPRLoss 18 | from recbole.model.abstract_recommender import SequentialRecommender 19 | 20 | from recbole_gnn.model.layers import SRGNNCell 21 | 22 | 23 | class GCSAN(SequentialRecommender): 24 | r"""GCSAN captures rich local dependencies via graph neural network, 25 | and learns long-range dependencies by applying the self-attention mechanism. 26 | 27 | Note: 28 | 29 | In the original paper, the attention mechanism in the self-attention layer is a single head, 30 | for the reusability of the project code, we use a unified transformer component. 31 | According to the experimental results, we only applied regularization to embedding. 
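        The final session representation is a convex combination of the self-attention
        output and the last GNN hidden state, weighted by ``config['weight']``.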
32 | """ 33 | 34 | def __init__(self, config, dataset): 35 | super(GCSAN, self).__init__(config, dataset) 36 | 37 | # load parameters info 38 | self.n_layers = config['n_layers'] 39 | self.n_heads = config['n_heads'] 40 | self.hidden_size = config['hidden_size'] # same as embedding_size 41 | self.inner_size = config['inner_size'] # the dimensionality in feed-forward layer 42 | self.hidden_dropout_prob = config['hidden_dropout_prob'] 43 | self.attn_dropout_prob = config['attn_dropout_prob'] 44 | self.hidden_act = config['hidden_act'] 45 | self.layer_norm_eps = config['layer_norm_eps'] 46 | 47 | self.step = config['step'] 48 | self.device = config['device'] 49 | self.weight = config['weight'] 50 | self.reg_weight = config['reg_weight'] 51 | self.loss_type = config['loss_type'] 52 | self.initializer_range = config['initializer_range'] 53 | 54 | # item embedding 55 | self.item_embedding = nn.Embedding(self.n_items, self.hidden_size, padding_idx=0) 56 | 57 | # define layers and loss 58 | self.gnncell = SRGNNCell(self.hidden_size) 59 | self.self_attention = TransformerEncoder( 60 | n_layers=self.n_layers, 61 | n_heads=self.n_heads, 62 | hidden_size=self.hidden_size, 63 | inner_size=self.inner_size, 64 | hidden_dropout_prob=self.hidden_dropout_prob, 65 | attn_dropout_prob=self.attn_dropout_prob, 66 | hidden_act=self.hidden_act, 67 | layer_norm_eps=self.layer_norm_eps 68 | ) 69 | self.reg_loss = EmbLoss() 70 | if self.loss_type == 'BPR': 71 | self.loss_fct = BPRLoss() 72 | elif self.loss_type == 'CE': 73 | self.loss_fct = nn.CrossEntropyLoss() 74 | else: 75 | raise NotImplementedError("Make sure 'loss_type' in ['BPR', 'CE']!") 76 | 77 | # parameters initialization 78 | self.apply(self._init_weights) 79 | 80 | def _init_weights(self, module): 81 | """ Initialize the weights """ 82 | if isinstance(module, (nn.Linear, nn.Embedding)): 83 | # Slightly different from the TF version which uses truncated_normal for initialization 84 | # cf https://github.com/pytorch/pytorch/pull/5617 85 | module.weight.data.normal_(mean=0.0, std=self.initializer_range) 86 | elif isinstance(module, nn.LayerNorm): 87 | module.bias.data.zero_() 88 | module.weight.data.fill_(1.0) 89 | if isinstance(module, nn.Linear) and module.bias is not None: 90 | module.bias.data.zero_() 91 | 92 | def get_attention_mask(self, item_seq): 93 | """Generate left-to-right uni-directional attention mask for multi-head attention.""" 94 | attention_mask = (item_seq > 0).long() 95 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) # torch.int64 96 | # mask for left-to-right unidirectional 97 | max_len = attention_mask.size(-1) 98 | attn_shape = (1, max_len, max_len) 99 | subsequent_mask = torch.triu(torch.ones(attn_shape), diagonal=1) # torch.uint8 100 | subsequent_mask = (subsequent_mask == 0).unsqueeze(1) 101 | subsequent_mask = subsequent_mask.long().to(item_seq.device) 102 | 103 | extended_attention_mask = extended_attention_mask * subsequent_mask 104 | extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility 105 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 106 | return extended_attention_mask 107 | 108 | def forward(self, x, edge_index, alias_inputs, item_seq_len): 109 | hidden = self.item_embedding(x) 110 | for i in range(self.step): 111 | hidden = self.gnncell(hidden, edge_index) 112 | 113 | seq_hidden = hidden[alias_inputs] 114 | # fetch the last hidden state of last timestamp 115 | ht = self.gather_indexes(seq_hidden, item_seq_len - 1) 116 | 
117 | attention_mask = self.get_attention_mask(alias_inputs) 118 | outputs = self.self_attention(seq_hidden, attention_mask, output_all_encoded_layers=True) 119 | output = outputs[-1] 120 | at = self.gather_indexes(output, item_seq_len - 1) 121 | seq_output = self.weight * at + (1 - self.weight) * ht 122 | return seq_output 123 | 124 | def calculate_loss(self, interaction): 125 | x = interaction['x'] 126 | edge_index = interaction['edge_index'] 127 | alias_inputs = interaction['alias_inputs'] 128 | item_seq_len = interaction[self.ITEM_SEQ_LEN] 129 | seq_output = self.forward(x, edge_index, alias_inputs, item_seq_len) 130 | pos_items = interaction[self.POS_ITEM_ID] 131 | if self.loss_type == 'BPR': 132 | neg_items = interaction[self.NEG_ITEM_ID] 133 | pos_items_emb = self.item_embedding(pos_items) 134 | neg_items_emb = self.item_embedding(neg_items) 135 | pos_score = torch.sum(seq_output * pos_items_emb, dim=-1) # [B] 136 | neg_score = torch.sum(seq_output * neg_items_emb, dim=-1) # [B] 137 | loss = self.loss_fct(pos_score, neg_score) 138 | else: # self.loss_type = 'CE' 139 | test_item_emb = self.item_embedding.weight 140 | logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1)) 141 | loss = self.loss_fct(logits, pos_items) 142 | reg_loss = self.reg_loss(self.item_embedding.weight) 143 | total_loss = loss + self.reg_weight * reg_loss 144 | return total_loss 145 | 146 | def predict(self, interaction): 147 | test_item = interaction[self.ITEM_ID] 148 | x = interaction['x'] 149 | edge_index = interaction['edge_index'] 150 | alias_inputs = interaction['alias_inputs'] 151 | item_seq_len = interaction[self.ITEM_SEQ_LEN] 152 | seq_output = self.forward(x, edge_index, alias_inputs, item_seq_len) 153 | test_item_emb = self.item_embedding(test_item) 154 | scores = torch.mul(seq_output, test_item_emb).sum(dim=1) # [B] 155 | return scores 156 | 157 | def full_sort_predict(self, interaction): 158 | x = interaction['x'] 159 | edge_index = interaction['edge_index'] 160 | alias_inputs = interaction['alias_inputs'] 161 | item_seq_len = interaction[self.ITEM_SEQ_LEN] 162 | seq_output = self.forward(x, edge_index, alias_inputs, item_seq_len) 163 | test_items_emb = self.item_embedding.weight 164 | scores = torch.matmul(seq_output, test_items_emb.transpose(0, 1)) # [B, n_items] 165 | return scores 166 | -------------------------------------------------------------------------------- /recbole_gnn/model/sequential_recommender/lessr.py: -------------------------------------------------------------------------------- 1 | # @Time : 2022/3/11 2 | # @Author : Yupeng Hou 3 | # @Email : houyupeng@ruc.edu.cn 4 | 5 | r""" 6 | LESSR 7 | ################################################ 8 | 9 | Reference: 10 | Tianwen Chen and Raymond Chi-Wing Wong. "Handling Information Loss of Graph Neural Networks for Session-based Recommendation." in KDD 2020. 
11 | 12 | Reference code: 13 | https://github.com/twchen/lessr 14 | 15 | """ 16 | 17 | import torch 18 | from torch import nn 19 | from torch_geometric.utils import softmax 20 | from torch_geometric.nn import global_add_pool 21 | from recbole.model.abstract_recommender import SequentialRecommender 22 | 23 | 24 | class EOPA(nn.Module): 25 | def __init__( 26 | self, input_dim, output_dim, batch_norm=True, feat_drop=0.0, activation=None 27 | ): 28 | super().__init__() 29 | self.batch_norm = nn.BatchNorm1d(input_dim) if batch_norm else None 30 | self.feat_drop = nn.Dropout(feat_drop) 31 | self.gru = nn.GRU(input_dim, input_dim, batch_first=True) 32 | self.fc_self = nn.Linear(input_dim, output_dim, bias=False) 33 | self.fc_neigh = nn.Linear(input_dim, output_dim, bias=False) 34 | self.activation = activation 35 | 36 | def reducer(self, nodes): 37 | m = nodes.mailbox['m'] # (num_nodes, deg, d) 38 | # m[i]: the messages passed to the i-th node with in-degree equal to 'deg' 39 | # the order of messages follows the order of incoming edges 40 | # since the edges are sorted by occurrence time when the EOP multigraph is built 41 | # the messages are in the order required by EOPA 42 | _, hn = self.gru(m) # hn: (1, num_nodes, d) 43 | return {'neigh': hn.squeeze(0)} 44 | 45 | def forward(self, mg, feat): 46 | import dgl.function as fn 47 | 48 | with mg.local_scope(): 49 | if self.batch_norm is not None: 50 | feat = self.batch_norm(feat) 51 | mg.ndata['ft'] = self.feat_drop(feat) 52 | if mg.number_of_edges() > 0: 53 | mg.update_all(fn.copy_u('ft', 'm'), self.reducer) 54 | neigh = mg.ndata['neigh'] 55 | rst = self.fc_self(feat) + self.fc_neigh(neigh) 56 | else: 57 | rst = self.fc_self(feat) 58 | if self.activation is not None: 59 | rst = self.activation(rst) 60 | return rst 61 | 62 | 63 | class SGAT(nn.Module): 64 | def __init__( 65 | self, 66 | input_dim, 67 | hidden_dim, 68 | output_dim, 69 | batch_norm=True, 70 | feat_drop=0.0, 71 | activation=None, 72 | ): 73 | super().__init__() 74 | self.batch_norm = nn.BatchNorm1d(input_dim) if batch_norm else None 75 | self.feat_drop = nn.Dropout(feat_drop) 76 | self.fc_q = nn.Linear(input_dim, hidden_dim, bias=True) 77 | self.fc_k = nn.Linear(input_dim, hidden_dim, bias=False) 78 | self.fc_v = nn.Linear(input_dim, output_dim, bias=False) 79 | self.fc_e = nn.Linear(hidden_dim, 1, bias=False) 80 | self.activation = activation 81 | 82 | def forward(self, sg, feat): 83 | import dgl.ops as F 84 | 85 | if self.batch_norm is not None: 86 | feat = self.batch_norm(feat) 87 | feat = self.feat_drop(feat) 88 | q = self.fc_q(feat) 89 | k = self.fc_k(feat) 90 | v = self.fc_v(feat) 91 | e = F.u_add_v(sg, q, k) 92 | e = self.fc_e(torch.sigmoid(e)) 93 | a = F.edge_softmax(sg, e) 94 | rst = F.u_mul_e_sum(sg, v, a) 95 | if self.activation is not None: 96 | rst = self.activation(rst) 97 | return rst 98 | 99 | 100 | class AttnReadout(nn.Module): 101 | def __init__( 102 | self, 103 | input_dim, 104 | hidden_dim, 105 | output_dim, 106 | batch_norm=True, 107 | feat_drop=0.0, 108 | activation=None, 109 | ): 110 | super().__init__() 111 | self.batch_norm = nn.BatchNorm1d(input_dim) if batch_norm else None 112 | self.feat_drop = nn.Dropout(feat_drop) 113 | self.fc_u = nn.Linear(input_dim, hidden_dim, bias=False) 114 | self.fc_v = nn.Linear(input_dim, hidden_dim, bias=True) 115 | self.fc_e = nn.Linear(hidden_dim, 1, bias=False) 116 | self.fc_out = ( 117 | nn.Linear(input_dim, output_dim, bias=False) 118 | if output_dim != input_dim else None 119 | ) 120 | self.activation = activation 121 | 
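    # Soft-attention readout over a batch of session graphs: each node is scored against
    # its session's last node, the scores are normalized per session via softmax, and the
    # weighted node features are sum-pooled into a single vector per session.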
122 | def forward(self, g, feat, last_nodes, batch): 123 | if self.batch_norm is not None: 124 | feat = self.batch_norm(feat) 125 | feat = self.feat_drop(feat) 126 | feat_u = self.fc_u(feat) 127 | feat_v = self.fc_v(feat[last_nodes]) 128 | feat_v = torch.index_select(feat_v, dim=0, index=batch) 129 | e = self.fc_e(torch.sigmoid(feat_u + feat_v)) 130 | alpha = softmax(e, batch) 131 | feat_norm = feat * alpha 132 | rst = global_add_pool(feat_norm, batch) 133 | if self.fc_out is not None: 134 | rst = self.fc_out(rst) 135 | if self.activation is not None: 136 | rst = self.activation(rst) 137 | return rst 138 | 139 | 140 | class LESSR(SequentialRecommender): 141 | r"""LESSR analyzes the information losses when constructing session graphs, 142 | and highlights the lossy session encoding problem and the ineffective long-range dependency capturing problem. 143 | To solve the first problem, the authors propose a lossless encoding scheme and an edge-order preserving aggregation layer. 144 | To solve the second problem, the authors propose a shortcut graph attention layer that effectively captures long-range dependencies. 145 | 146 | Note: 147 | We follow the original implementation, which requires the DGL package. 148 | These operations are difficult to implement via PyG, so we keep the original DGL-based implementation. 149 | If you would like to test this model, please install DGL. 150 | """ 151 | 152 | def __init__(self, config, dataset): 153 | super().__init__(config, dataset) 154 | 155 | embedding_dim = config['embedding_size'] 156 | self.num_layers = config['n_layers'] 157 | batch_norm = config['batch_norm'] 158 | feat_drop = config['feat_drop'] 159 | self.loss_type = config['loss_type'] 160 | 161 | self.item_embedding = nn.Embedding(self.n_items, embedding_dim, max_norm=1) 162 | self.layers = nn.ModuleList() 163 | input_dim = embedding_dim 164 | for i in range(self.num_layers): 165 | if i % 2 == 0: 166 | layer = EOPA( 167 | input_dim, 168 | embedding_dim, 169 | batch_norm=batch_norm, 170 | feat_drop=feat_drop, 171 | activation=nn.PReLU(embedding_dim), 172 | ) 173 | else: 174 | layer = SGAT( 175 | input_dim, 176 | embedding_dim, 177 | embedding_dim, 178 | batch_norm=batch_norm, 179 | feat_drop=feat_drop, 180 | activation=nn.PReLU(embedding_dim), 181 | ) 182 | input_dim += embedding_dim 183 | self.layers.append(layer) 184 | self.readout = AttnReadout( 185 | input_dim, 186 | embedding_dim, 187 | embedding_dim, 188 | batch_norm=batch_norm, 189 | feat_drop=feat_drop, 190 | activation=nn.PReLU(embedding_dim), 191 | ) 192 | input_dim += embedding_dim 193 | self.batch_norm = nn.BatchNorm1d(input_dim) if batch_norm else None 194 | self.feat_drop = nn.Dropout(feat_drop) 195 | self.fc_sr = nn.Linear(input_dim, embedding_dim, bias=False) 196 | 197 | if self.loss_type == 'CE': 198 | self.loss_fct = nn.CrossEntropyLoss() 199 | else: 200 | raise NotImplementedError("Make sure 'loss_type' in ['CE']!") 201 | 202 | def forward(self, x, edge_index_EOP, edge_index_shortcut, batch, is_last): 203 | import dgl 204 | 205 | mg = dgl.graph((edge_index_EOP[0], edge_index_EOP[1]), num_nodes=batch.shape[0]) 206 | sg = dgl.graph((edge_index_shortcut[0], edge_index_shortcut[1]), num_nodes=batch.shape[0]) 207 | 208 | feat = self.item_embedding(x) 209 | for i, layer in enumerate(self.layers): 210 | if i % 2 == 0: 211 | out = layer(mg, feat) 212 | else: 213 | out = layer(sg, feat) 214 | feat = torch.cat([out, feat], dim=1) 215 | sr_g = self.readout(mg, feat, is_last, batch) 216 | sr_l = feat[is_last] 217 | sr = torch.cat([sr_l, sr_g], dim=1) 218 | if self.batch_norm
is not None: 219 | sr = self.batch_norm(sr) 220 | sr = self.fc_sr(self.feat_drop(sr)) 221 | return sr 222 | 223 | def calculate_loss(self, interaction): 224 | x = interaction['x'] 225 | edge_index_EOP = interaction['edge_index_EOP'] 226 | edge_index_shortcut = interaction['edge_index_shortcut'] 227 | batch = interaction['batch'] 228 | is_last = interaction['is_last'] 229 | seq_output = self.forward(x, edge_index_EOP, edge_index_shortcut, batch, is_last) 230 | pos_items = interaction[self.POS_ITEM_ID] 231 | test_item_emb = self.item_embedding.weight 232 | logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1)) 233 | loss = self.loss_fct(logits, pos_items) 234 | return loss 235 | 236 | def predict(self, interaction): 237 | test_item = interaction[self.ITEM_ID] 238 | x = interaction['x'] 239 | edge_index_EOP = interaction['edge_index_EOP'] 240 | edge_index_shortcut = interaction['edge_index_shortcut'] 241 | batch = interaction['batch'] 242 | is_last = interaction['is_last'] 243 | seq_output = self.forward(x, edge_index_EOP, edge_index_shortcut, batch, is_last) 244 | test_item_emb = self.item_embedding(test_item) 245 | scores = torch.mul(seq_output, test_item_emb).sum(dim=1) # [B] 246 | return scores 247 | 248 | def full_sort_predict(self, interaction): 249 | x = interaction['x'] 250 | edge_index_EOP = interaction['edge_index_EOP'] 251 | edge_index_shortcut = interaction['edge_index_shortcut'] 252 | batch = interaction['batch'] 253 | is_last = interaction['is_last'] 254 | seq_output = self.forward(x, edge_index_EOP, edge_index_shortcut, batch, is_last) 255 | test_items_emb = self.item_embedding.weight 256 | scores = torch.matmul(seq_output, test_items_emb.transpose(0, 1)) # [B, n_items] 257 | return scores 258 | -------------------------------------------------------------------------------- /recbole_gnn/model/sequential_recommender/niser.py: -------------------------------------------------------------------------------- 1 | # @Time : 2022/3/7 2 | # @Author : Yupeng Hou 3 | # @Email : houyupeng@ruc.edu.cn 4 | 5 | r""" 6 | NISER 7 | ################################################ 8 | 9 | Reference: 10 | Priyanka Gupta et al. "NISER: Normalized Item and Session Representations to Handle Popularity Bias." in CIKM 2019 GRLA workshop. 11 | 12 | """ 13 | import numpy as np 14 | import torch 15 | from torch import nn 16 | import torch.nn.functional as F 17 | from recbole.model.loss import BPRLoss 18 | from recbole.model.abstract_recommender import SequentialRecommender 19 | 20 | from recbole_gnn.model.layers import SRGNNCell 21 | 22 | 23 | class NISER(SequentialRecommender): 24 | r"""NISER+ is a GNN-based model that normalizes session and item embeddings to handle popularity bias. 
25 | """ 26 | 27 | def __init__(self, config, dataset): 28 | super(NISER, self).__init__(config, dataset) 29 | 30 | # load parameters info 31 | self.embedding_size = config['embedding_size'] 32 | self.step = config['step'] 33 | self.device = config['device'] 34 | self.loss_type = config['loss_type'] 35 | self.sigma = config['sigma'] 36 | self.max_seq_length = dataset.field2seqlen[self.ITEM_SEQ] 37 | 38 | # item embedding 39 | self.item_embedding = nn.Embedding(self.n_items, self.embedding_size, padding_idx=0) 40 | self.pos_embedding = nn.Embedding(self.max_seq_length, self.embedding_size) 41 | self.item_dropout = nn.Dropout(config['item_dropout']) 42 | 43 | # define layers and loss 44 | self.gnncell = SRGNNCell(self.embedding_size) 45 | self.linear_one = nn.Linear(self.embedding_size, self.embedding_size) 46 | self.linear_two = nn.Linear(self.embedding_size, self.embedding_size) 47 | self.linear_three = nn.Linear(self.embedding_size, 1, bias=False) 48 | self.linear_transform = nn.Linear(self.embedding_size * 2, self.embedding_size) 49 | if self.loss_type == 'BPR': 50 | self.loss_fct = BPRLoss() 51 | elif self.loss_type == 'CE': 52 | self.loss_fct = nn.CrossEntropyLoss() 53 | else: 54 | raise NotImplementedError("Make sure 'loss_type' in ['BPR', 'CE']!") 55 | 56 | # parameters initialization 57 | self._reset_parameters() 58 | 59 | def _reset_parameters(self): 60 | stdv = 1.0 / np.sqrt(self.embedding_size) 61 | for weight in self.parameters(): 62 | weight.data.uniform_(-stdv, stdv) 63 | 64 | def forward(self, x, edge_index, alias_inputs, item_seq_len): 65 | mask = alias_inputs.gt(0) 66 | hidden = self.item_embedding(x) 67 | # Dropout in NISER+ 68 | hidden = self.item_dropout(hidden) 69 | # Normalize item embeddings 70 | hidden = F.normalize(hidden, dim=-1) 71 | for i in range(self.step): 72 | hidden = self.gnncell(hidden, edge_index) 73 | 74 | seq_hidden = hidden[alias_inputs] 75 | batch_size = seq_hidden.shape[0] 76 | pos_emb = self.pos_embedding.weight[:seq_hidden.shape[1]] 77 | pos_emb = pos_emb.unsqueeze(0).expand(batch_size, -1, -1) 78 | seq_hidden = seq_hidden + pos_emb 79 | # fetch the last hidden state of last timestamp 80 | ht = self.gather_indexes(seq_hidden, item_seq_len - 1) 81 | q1 = self.linear_one(ht).view(ht.size(0), 1, ht.size(1)) 82 | q2 = self.linear_two(seq_hidden) 83 | 84 | alpha = self.linear_three(torch.sigmoid(q1 + q2)) 85 | a = torch.sum(alpha * seq_hidden * mask.view(mask.size(0), -1, 1).float(), 1) 86 | seq_output = self.linear_transform(torch.cat([a, ht], dim=1)) 87 | # Normalize session embeddings 88 | seq_output = F.normalize(seq_output, dim=-1) 89 | return seq_output 90 | 91 | def calculate_loss(self, interaction): 92 | x = interaction['x'] 93 | edge_index = interaction['edge_index'] 94 | alias_inputs = interaction['alias_inputs'] 95 | item_seq_len = interaction[self.ITEM_SEQ_LEN] 96 | seq_output = self.forward(x, edge_index, alias_inputs, item_seq_len) 97 | pos_items = interaction[self.POS_ITEM_ID] 98 | if self.loss_type == 'BPR': 99 | neg_items = interaction[self.NEG_ITEM_ID] 100 | pos_items_emb = F.normalize(self.item_embedding(pos_items), dim=-1) 101 | neg_items_emb = F.normalize(self.item_embedding(neg_items), dim=-1) 102 | pos_score = torch.sum(seq_output * pos_items_emb, dim=-1) # [B] 103 | neg_score = torch.sum(seq_output * neg_items_emb, dim=-1) # [B] 104 | loss = self.loss_fct(self.sigma * pos_score, self.sigma * neg_score) 105 | return loss 106 | else: # self.loss_type = 'CE' 107 | test_item_emb = F.normalize(self.item_embedding.weight, dim=-1) 108 | 
logits = self.sigma * torch.matmul(seq_output, test_item_emb.transpose(0, 1)) 109 | loss = self.loss_fct(logits, pos_items) 110 | return loss 111 | 112 | def predict(self, interaction): 113 | test_item = interaction[self.ITEM_ID] 114 | x = interaction['x'] 115 | edge_index = interaction['edge_index'] 116 | alias_inputs = interaction['alias_inputs'] 117 | item_seq_len = interaction[self.ITEM_SEQ_LEN] 118 | seq_output = self.forward(x, edge_index, alias_inputs, item_seq_len) 119 | test_item_emb = F.normalize(self.item_embedding(test_item), dim=-1) 120 | scores = torch.mul(seq_output, test_item_emb).sum(dim=1) # [B] 121 | return scores 122 | 123 | def full_sort_predict(self, interaction): 124 | x = interaction['x'] 125 | edge_index = interaction['edge_index'] 126 | alias_inputs = interaction['alias_inputs'] 127 | item_seq_len = interaction[self.ITEM_SEQ_LEN] 128 | seq_output = self.forward(x, edge_index, alias_inputs, item_seq_len) 129 | test_items_emb = F.normalize(self.item_embedding.weight, dim=-1) 130 | scores = torch.matmul(seq_output, test_items_emb.transpose(0, 1)) # [B, n_items] 131 | return scores 132 | -------------------------------------------------------------------------------- /recbole_gnn/model/sequential_recommender/sgnnhn.py: -------------------------------------------------------------------------------- 1 | # @Time : 2022/3/28 2 | # @Author : Yupeng Hou 3 | # @Email : houyupeng@ruc.edu.cn 4 | 5 | r""" 6 | SGNN-HN 7 | ################################################ 8 | 9 | Reference: 10 | Zhiqiang Pan et al. "Star Graph Neural Networks for Session-based Recommendation." in CIKM 2020. 11 | 12 | Reference code: 13 | https://bitbucket.org/nudtpanzq/sgnn-hn 14 | 15 | """ 16 | 17 | import math 18 | import numpy as np 19 | import torch 20 | from torch import nn 21 | from torch_geometric.nn import global_mean_pool, global_add_pool 22 | from torch_geometric.utils import softmax 23 | from recbole.model.abstract_recommender import SequentialRecommender 24 | from recbole.model.loss import BPRLoss 25 | 26 | from recbole_gnn.model.layers import SRGNNCell 27 | 28 | 29 | def layer_norm(x): 30 | ave_x = torch.mean(x, -1).unsqueeze(-1) 31 | x = x - ave_x 32 | norm_x = torch.sqrt(torch.sum(x**2, -1)).unsqueeze(-1) 33 | y = x / norm_x 34 | return y 35 | 36 | 37 | class SGNNHN(SequentialRecommender): 38 | r"""SGNN-HN applies a star graph neural network to model the complex transition relationship between items in an ongoing session. 39 | To avoid overfitting, it applies highway networks to adaptively select embeddings from item representations.
40 | """ 41 | 42 | def __init__(self, config, dataset): 43 | super(SGNNHN, self).__init__(config, dataset) 44 | 45 | # load parameters info 46 | self.embedding_size = config['embedding_size'] 47 | self.step = config['step'] 48 | self.device = config['device'] 49 | self.loss_type = config['loss_type'] 50 | self.scale = config['scale'] 51 | 52 | # item embedding 53 | self.item_embedding = nn.Embedding(self.n_items, self.embedding_size, padding_idx=0) 54 | self.max_seq_length = dataset.field2seqlen[self.ITEM_SEQ] 55 | self.pos_embedding = nn.Embedding(self.max_seq_length, self.embedding_size) 56 | 57 | # define layers and loss 58 | self.gnncell = SRGNNCell(self.embedding_size) 59 | self.linear_one = nn.Linear(self.embedding_size, self.embedding_size) 60 | self.linear_two = nn.Linear(self.embedding_size, self.embedding_size) 61 | self.linear_three = nn.Linear(self.embedding_size, self.embedding_size) 62 | self.linear_four = nn.Linear(self.embedding_size, 1, bias=False) 63 | self.linear_transform = nn.Linear(self.embedding_size * 2, self.embedding_size) 64 | if self.loss_type == 'BPR': 65 | self.loss_fct = BPRLoss() 66 | elif self.loss_type == 'CE': 67 | self.loss_fct = nn.CrossEntropyLoss() 68 | else: 69 | raise NotImplementedError("Make sure 'loss_type' in ['BPR', 'CE']!") 70 | 71 | # parameters initialization 72 | self._reset_parameters() 73 | 74 | def _reset_parameters(self): 75 | stdv = 1.0 / np.sqrt(self.embedding_size) 76 | for weight in self.parameters(): 77 | weight.data.uniform_(-stdv, stdv) 78 | 79 | def att_out(self, hidden, star_node, batch): 80 | star_node_repeat = torch.index_select(star_node, 0, batch) 81 | sim = (hidden * star_node_repeat).sum(dim=-1) 82 | sim = softmax(sim, batch) 83 | att_hidden = sim.unsqueeze(-1) * hidden 84 | output = global_add_pool(att_hidden, batch) 85 | 86 | return output 87 | 88 | def forward(self, x, edge_index, batch, alias_inputs, item_seq_len): 89 | mask = alias_inputs.gt(0) 90 | hidden = self.item_embedding(x) 91 | 92 | star_node = global_mean_pool(hidden, batch) 93 | for i in range(self.step): 94 | hidden = self.gnncell(hidden, edge_index) 95 | star_node_repeat = torch.index_select(star_node, 0, batch) 96 | sim = (hidden * star_node_repeat).sum(dim=-1, keepdim=True) / math.sqrt(self.embedding_size) 97 | alpha = torch.sigmoid(sim) 98 | hidden = (1 - alpha) * hidden + alpha * star_node_repeat 99 | star_node = self.att_out(hidden, star_node, batch) 100 | 101 | seq_hidden = hidden[alias_inputs] 102 | bs, item_num, _ = seq_hidden.shape 103 | pos_emb = self.pos_embedding.weight[:item_num] 104 | pos_emb = pos_emb.unsqueeze(0).expand(bs, -1, -1) 105 | seq_hidden = seq_hidden + pos_emb 106 | 107 | # fetch the last hidden state of last timestamp 108 | ht = self.gather_indexes(seq_hidden, item_seq_len - 1) 109 | q1 = self.linear_one(ht).view(ht.size(0), 1, ht.size(1)) 110 | q2 = self.linear_two(seq_hidden) 111 | q3 = self.linear_three(star_node).view(star_node.shape[0], 1, star_node.shape[1]) 112 | 113 | alpha = self.linear_four(torch.sigmoid(q1 + q2 + q3)) 114 | a = torch.sum(alpha * seq_hidden * mask.view(mask.size(0), -1, 1).float(), 1) 115 | seq_output = self.linear_transform(torch.cat([a, ht], dim=1)) 116 | return layer_norm(seq_output) 117 | 118 | def calculate_loss(self, interaction): 119 | x = interaction['x'] 120 | edge_index = interaction['edge_index'] 121 | batch = interaction['batch'] 122 | alias_inputs = interaction['alias_inputs'] 123 | item_seq_len = interaction[self.ITEM_SEQ_LEN] 124 | seq_output = self.forward(x, edge_index, batch, 
alias_inputs, item_seq_len) 125 | pos_items = interaction[self.POS_ITEM_ID] 126 | if self.loss_type == 'BPR': 127 | neg_items = interaction[self.NEG_ITEM_ID] 128 | pos_items_emb = layer_norm(self.item_embedding(pos_items)) 129 | neg_items_emb = layer_norm(self.item_embedding(neg_items)) 130 | pos_score = torch.sum(seq_output * pos_items_emb, dim=-1) * self.scale # [B] 131 | neg_score = torch.sum(seq_output * neg_items_emb, dim=-1) * self.scale # [B] 132 | loss = self.loss_fct(pos_score, neg_score) 133 | return loss 134 | else: # self.loss_type = 'CE' 135 | test_item_emb = layer_norm(self.item_embedding.weight) 136 | logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1)) * self.scale 137 | loss = self.loss_fct(logits, pos_items) 138 | return loss 139 | 140 | def predict(self, interaction): 141 | test_item = interaction[self.ITEM_ID] 142 | x = interaction['x'] 143 | edge_index = interaction['edge_index'] 144 | batch = interaction['batch'] 145 | alias_inputs = interaction['alias_inputs'] 146 | item_seq_len = interaction[self.ITEM_SEQ_LEN] 147 | seq_output = self.forward(x, edge_index, batch, alias_inputs, item_seq_len) 148 | test_item_emb = layer_norm(self.item_embedding(test_item)) 149 | scores = torch.mul(seq_output, test_item_emb).sum(dim=1) * self.scale # [B] 150 | return scores 151 | 152 | def full_sort_predict(self, interaction): 153 | x = interaction['x'] 154 | edge_index = interaction['edge_index'] 155 | batch = interaction['batch'] 156 | alias_inputs = interaction['alias_inputs'] 157 | item_seq_len = interaction[self.ITEM_SEQ_LEN] 158 | seq_output = self.forward(x, edge_index, batch, alias_inputs, item_seq_len) 159 | test_items_emb = layer_norm(self.item_embedding.weight) 160 | scores = torch.matmul(seq_output, test_items_emb.transpose(0, 1)) * self.scale # [B, n_items] 161 | return scores 162 | -------------------------------------------------------------------------------- /recbole_gnn/model/sequential_recommender/srgnn.py: -------------------------------------------------------------------------------- 1 | # @Time : 2022/3/7 2 | # @Author : Yupeng Hou 3 | # @Email : houyupeng@ruc.edu.cn 4 | 5 | r""" 6 | SRGNN 7 | ################################################ 8 | 9 | Reference: 10 | Shu Wu et al. "Session-based Recommendation with Graph Neural Networks." in AAAI 2019. 11 | 12 | Reference code: 13 | https://github.com/CRIPAC-DIG/SR-GNN 14 | 15 | """ 16 | import numpy as np 17 | import torch 18 | from torch import nn 19 | from recbole.model.loss import BPRLoss 20 | from recbole.model.abstract_recommender import SequentialRecommender 21 | 22 | from recbole_gnn.model.layers import SRGNNCell 23 | 24 | 25 | class SRGNN(SequentialRecommender): 26 | r"""SRGNN regards the session history as a directed graph. 27 | In addition to the connections between adjacent items, 28 | it also considers the connections with the other items interacted with in the same session.
29 | 30 | For example, for the session sequence (item1, item2, item3, item2, item4), the connection matrix A is as follows. 31 | 32 | Outgoing edges: 33 | === ===== ===== ===== ===== 34 | \ 1 2 3 4 35 | === ===== ===== ===== ===== 36 | 1 0 1 0 0 37 | 2 0 0 1/2 1/2 38 | 3 0 1 0 0 39 | 4 0 0 0 0 40 | === ===== ===== ===== ===== 41 | 42 | Incoming edges: 43 | === ===== ===== ===== ===== 44 | \ 1 2 3 4 45 | === ===== ===== ===== ===== 46 | 1 0 0 0 0 47 | 2 1/2 0 1/2 0 48 | 3 0 1 0 0 49 | 4 0 1 0 0 50 | === ===== ===== ===== ===== 51 | """ 52 | 53 | def __init__(self, config, dataset): 54 | super(SRGNN, self).__init__(config, dataset) 55 | 56 | # load parameters info 57 | self.embedding_size = config['embedding_size'] 58 | self.step = config['step'] 59 | self.device = config['device'] 60 | self.loss_type = config['loss_type'] 61 | 62 | # item embedding 63 | self.item_embedding = nn.Embedding(self.n_items, self.embedding_size, padding_idx=0) 64 | 65 | # define layers and loss 66 | self.gnncell = SRGNNCell(self.embedding_size) 67 | self.linear_one = nn.Linear(self.embedding_size, self.embedding_size) 68 | self.linear_two = nn.Linear(self.embedding_size, self.embedding_size) 69 | self.linear_three = nn.Linear(self.embedding_size, 1, bias=False) 70 | self.linear_transform = nn.Linear(self.embedding_size * 2, self.embedding_size) 71 | if self.loss_type == 'BPR': 72 | self.loss_fct = BPRLoss() 73 | elif self.loss_type == 'CE': 74 | self.loss_fct = nn.CrossEntropyLoss() 75 | else: 76 | raise NotImplementedError("Make sure 'loss_type' in ['BPR', 'CE']!") 77 | 78 | # parameters initialization 79 | self._reset_parameters() 80 | 81 | def _reset_parameters(self): 82 | stdv = 1.0 / np.sqrt(self.embedding_size) 83 | for weight in self.parameters(): 84 | weight.data.uniform_(-stdv, stdv) 85 | 86 | def forward(self, x, edge_index, alias_inputs, item_seq_len): 87 | mask = alias_inputs.gt(0) 88 | hidden = self.item_embedding(x) 89 | for i in range(self.step): 90 | hidden = self.gnncell(hidden, edge_index) 91 | 92 | seq_hidden = hidden[alias_inputs] 93 | # fetch the last hidden state of last timestamp 94 | ht = self.gather_indexes(seq_hidden, item_seq_len - 1) 95 | q1 = self.linear_one(ht).view(ht.size(0), 1, ht.size(1)) 96 | q2 = self.linear_two(seq_hidden) 97 | 98 | alpha = self.linear_three(torch.sigmoid(q1 + q2)) 99 | a = torch.sum(alpha * seq_hidden * mask.view(mask.size(0), -1, 1).float(), 1) 100 | seq_output = self.linear_transform(torch.cat([a, ht], dim=1)) 101 | return seq_output 102 | 103 | def calculate_loss(self, interaction): 104 | x = interaction['x'] 105 | edge_index = interaction['edge_index'] 106 | alias_inputs = interaction['alias_inputs'] 107 | item_seq_len = interaction[self.ITEM_SEQ_LEN] 108 | seq_output = self.forward(x, edge_index, alias_inputs, item_seq_len) 109 | pos_items = interaction[self.POS_ITEM_ID] 110 | if self.loss_type == 'BPR': 111 | neg_items = interaction[self.NEG_ITEM_ID] 112 | pos_items_emb = self.item_embedding(pos_items) 113 | neg_items_emb = self.item_embedding(neg_items) 114 | pos_score = torch.sum(seq_output * pos_items_emb, dim=-1) # [B] 115 | neg_score = torch.sum(seq_output * neg_items_emb, dim=-1) # [B] 116 | loss = self.loss_fct(pos_score, neg_score) 117 | return loss 118 | else: # self.loss_type = 'CE' 119 | test_item_emb = self.item_embedding.weight 120 | logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1)) 121 | loss = self.loss_fct(logits, pos_items) 122 | return loss 123 | 124 | def predict(self, interaction): 125 | test_item = interaction[self.ITEM_ID]
126 | x = interaction['x'] 127 | edge_index = interaction['edge_index'] 128 | alias_inputs = interaction['alias_inputs'] 129 | item_seq_len = interaction[self.ITEM_SEQ_LEN] 130 | seq_output = self.forward(x, edge_index, alias_inputs, item_seq_len) 131 | test_item_emb = self.item_embedding(test_item) 132 | scores = torch.mul(seq_output, test_item_emb).sum(dim=1) # [B] 133 | return scores 134 | 135 | def full_sort_predict(self, interaction): 136 | x = interaction['x'] 137 | edge_index = interaction['edge_index'] 138 | alias_inputs = interaction['alias_inputs'] 139 | item_seq_len = interaction[self.ITEM_SEQ_LEN] 140 | seq_output = self.forward(x, edge_index, alias_inputs, item_seq_len) 141 | test_items_emb = self.item_embedding.weight 142 | scores = torch.matmul(seq_output, test_items_emb.transpose(0, 1)) # [B, n_items] 143 | return scores 144 | -------------------------------------------------------------------------------- /recbole_gnn/model/sequential_recommender/tagnn.py: -------------------------------------------------------------------------------- 1 | # @Time : 2022/3/17 2 | # @Author : Yupeng Hou 3 | # @Email : houyupeng@ruc.edu.cn 4 | 5 | r""" 6 | TAGNN 7 | ################################################ 8 | 9 | Reference: 10 | Feng Yu et al. "TAGNN: Target Attentive Graph Neural Networks for Session-based Recommendation." in SIGIR 2020 short. 11 | Implemented using PyTorch Geometric. 12 | 13 | Reference code: 14 | https://github.com/CRIPAC-DIG/TAGNN 15 | 16 | """ 17 | import numpy as np 18 | import torch 19 | from torch import nn 20 | import torch.nn.functional as F 21 | from recbole.model.abstract_recommender import SequentialRecommender 22 | 23 | from recbole_gnn.model.layers import SRGNNCell 24 | 25 | 26 | class TAGNN(SequentialRecommender): 27 | r"""TAGNN introduces target-aware attention and adaptively activates different user interests with respect to varied target items. 
28 | """ 29 | 30 | def __init__(self, config, dataset): 31 | super(TAGNN, self).__init__(config, dataset) 32 | 33 | # load parameters info 34 | self.embedding_size = config['embedding_size'] 35 | self.step = config['step'] 36 | self.device = config['device'] 37 | self.loss_type = config['loss_type'] 38 | 39 | # item embedding 40 | self.item_embedding = nn.Embedding(self.n_items, self.embedding_size, padding_idx=0) 41 | 42 | # define layers and loss 43 | self.gnncell = SRGNNCell(self.embedding_size) 44 | self.linear_one = nn.Linear(self.embedding_size, self.embedding_size) 45 | self.linear_two = nn.Linear(self.embedding_size, self.embedding_size) 46 | self.linear_three = nn.Linear(self.embedding_size, 1, bias=False) 47 | self.linear_transform = nn.Linear(self.embedding_size * 2, self.embedding_size) 48 | self.linear_t = nn.Linear(self.embedding_size, self.embedding_size, bias=False) # target attention 49 | if self.loss_type == 'CE': 50 | self.loss_fct = nn.CrossEntropyLoss() 51 | else: 52 | raise NotImplementedError("Make sure 'loss_type' in ['CE']!") 53 | 54 | # parameters initialization 55 | self._reset_parameters() 56 | 57 | def _reset_parameters(self): 58 | stdv = 1.0 / np.sqrt(self.embedding_size) 59 | for weight in self.parameters(): 60 | weight.data.uniform_(-stdv, stdv) 61 | 62 | def forward(self, x, edge_index, alias_inputs, item_seq_len): 63 | mask = alias_inputs.gt(0) 64 | hidden = self.item_embedding(x) 65 | for i in range(self.step): 66 | hidden = self.gnncell(hidden, edge_index) 67 | 68 | seq_hidden = hidden[alias_inputs] 69 | # fetch the last hidden state of last timestamp 70 | ht = self.gather_indexes(seq_hidden, item_seq_len - 1) 71 | q1 = self.linear_one(ht).view(ht.size(0), 1, ht.size(1)) 72 | q2 = self.linear_two(seq_hidden) 73 | 74 | alpha = self.linear_three(torch.sigmoid(q1 + q2)) 75 | alpha = F.softmax(alpha, 1) 76 | a = torch.sum(alpha * seq_hidden * mask.view(mask.size(0), -1, 1).float(), 1) 77 | seq_output = self.linear_transform(torch.cat([a, ht], dim=1)) 78 | 79 | seq_hidden = seq_hidden * mask.view(mask.shape[0], -1, 1).float() 80 | qt = self.linear_t(seq_hidden) 81 | b = self.item_embedding.weight 82 | beta = F.softmax(b @ qt.transpose(1,2), -1) 83 | target = beta @ seq_hidden 84 | a = seq_output.view(ht.shape[0], 1, ht.shape[1]) # b,1,d 85 | a = a + target # b,n,d 86 | scores = torch.sum(a * b, -1) # b,n 87 | return scores 88 | 89 | def calculate_loss(self, interaction): 90 | x = interaction['x'] 91 | edge_index = interaction['edge_index'] 92 | alias_inputs = interaction['alias_inputs'] 93 | item_seq_len = interaction[self.ITEM_SEQ_LEN] 94 | logits = self.forward(x, edge_index, alias_inputs, item_seq_len) 95 | pos_items = interaction[self.POS_ITEM_ID] 96 | loss = self.loss_fct(logits, pos_items) 97 | return loss 98 | 99 | def predict(self, interaction): 100 | pass 101 | 102 | def full_sort_predict(self, interaction): 103 | x = interaction['x'] 104 | edge_index = interaction['edge_index'] 105 | alias_inputs = interaction['alias_inputs'] 106 | item_seq_len = interaction[self.ITEM_SEQ_LEN] 107 | scores = self.forward(x, edge_index, alias_inputs, item_seq_len) 108 | return scores 109 | -------------------------------------------------------------------------------- /recbole_gnn/model/social_recommender/__init__.py: -------------------------------------------------------------------------------- 1 | from recbole_gnn.model.social_recommender.diffnet import DiffNet 2 | from recbole_gnn.model.social_recommender.mhcn import MHCN 3 | from
recbole_gnn.model.social_recommender.sept import SEPT -------------------------------------------------------------------------------- /recbole_gnn/model/social_recommender/diffnet.py: -------------------------------------------------------------------------------- 1 | # @Time : 2022/3/15 2 | # @Author : Lanling Xu 3 | # @Email : xulanling_sherry@163.com 4 | 5 | r""" 6 | DiffNet 7 | ################################################ 8 | Reference: 9 | Le Wu et al. "A Neural Influence Diffusion Model for Social Recommendation." in SIGIR 2019. 10 | 11 | Reference code: 12 | https://github.com/PeiJieSun/diffnet 13 | """ 14 | 15 | import numpy as np 16 | import torch 17 | import torch.nn as nn 18 | 19 | from recbole.model.init import xavier_uniform_initialization 20 | from recbole.model.loss import BPRLoss, EmbLoss 21 | from recbole.utils import InputType 22 | 23 | from recbole_gnn.model.abstract_recommender import SocialRecommender 24 | from recbole_gnn.model.layers import BipartiteGCNConv 25 | 26 | 27 | class DiffNet(SocialRecommender): 28 | r"""DiffNet is a deep influence propagation model that simulates how users are influenced by the recursive social diffusion process for social recommendation. 29 | We implement the model following the original authors, with a pairwise training mode. 30 | """ 31 | input_type = InputType.PAIRWISE 32 | 33 | def __init__(self, config, dataset): 34 | super(DiffNet, self).__init__(config, dataset) 35 | 36 | # load dataset info 37 | self.edge_index, self.edge_weight = dataset.get_bipartite_inter_mat(row='user') 38 | self.edge_index, self.edge_weight = self.edge_index.to(self.device), self.edge_weight.to(self.device) 39 | 40 | self.net_edge_index, self.net_edge_weight = dataset.get_norm_net_adj_mat(row_norm=True) 41 | self.net_edge_index, self.net_edge_weight = self.net_edge_index.to(self.device), self.net_edge_weight.to(self.device) 42 | 43 | # load parameters info 44 | self.embedding_size = config['embedding_size'] # int type: the embedding size of DiffNet 45 | self.n_layers = config['n_layers'] # int type: the GCN layer num of DiffNet for the social net 46 | self.reg_weight = config['reg_weight'] # float32 type: the weight decay for l2 normalization 47 | self.pretrained_review = config['pretrained_review'] # bool type: whether to load pre-trained review vectors of users and items 48 | 49 | # define layers and loss 50 | self.user_embedding = torch.nn.Embedding(num_embeddings=self.n_users, embedding_dim=self.embedding_size) 51 | self.item_embedding = torch.nn.Embedding(num_embeddings=self.n_items, embedding_dim=self.embedding_size) 52 | self.bipartite_gcn_conv = BipartiteGCNConv(dim=self.embedding_size) 53 | self.mf_loss = BPRLoss() 54 | self.reg_loss = EmbLoss() 55 | 56 | # storage variables for full sort evaluation acceleration 57 | self.restore_user_e = None 58 | self.restore_item_e = None 59 | 60 | # parameters initialization 61 | self.apply(xavier_uniform_initialization) 62 | self.other_parameter_name = ['restore_user_e', 'restore_item_e'] 63 | 64 | if self.pretrained_review: 65 | # handle review information, map the original review embeddings into the new space 66 | self.user_review_embedding = nn.Embedding(self.n_users, self.embedding_size, padding_idx=0) 67 | self.user_review_embedding.weight.requires_grad = False 68 | self.user_review_embedding.weight.data.copy_(self.convertDistribution(dataset.user_feat['user_review_emb'])) 69 | 70 | self.item_review_embedding = nn.Embedding(self.n_items, self.embedding_size, padding_idx=0) 71 | self.item_review_embedding.weight.requires_grad
= False 72 | self.item_review_embedding.weight.data.copy_(self.convertDistribution(dataset.item_feat['item_review_emb'])) 73 | 74 | self.user_fusion_layer = nn.Linear(self.embedding_size, self.embedding_size) 75 | self.item_fusion_layer = nn.Linear(self.embedding_size, self.embedding_size) 76 | self.activation = nn.Sigmoid() 77 | 78 | def convertDistribution(self, x): 79 | mean, std = torch.mean(x), torch.std(x) 80 | y = (x - mean) * 0.2 / std 81 | return y 82 | 83 | def forward(self): 84 | user_embedding = self.user_embedding.weight 85 | final_item_embedding = self.item_embedding.weight 86 | 87 | if self.pretrained_review: 88 | user_reduce_dim_vector_matrix = self.activation(self.user_fusion_layer(self.user_review_embedding.weight)) 89 | item_reduce_dim_vector_matrix = self.activation(self.item_fusion_layer(self.item_review_embedding.weight)) 90 | 91 | user_review_vector_matrix = self.convertDistribution(user_reduce_dim_vector_matrix) 92 | item_review_vector_matrix = self.convertDistribution(item_reduce_dim_vector_matrix) 93 | 94 | user_embedding = user_embedding + user_review_vector_matrix 95 | final_item_embedding = final_item_embedding + item_review_vector_matrix 96 | 97 | user_embedding_from_consumed_items = self.bipartite_gcn_conv(x=(final_item_embedding, user_embedding), edge_index=self.edge_index.flip([0]), edge_weight=self.edge_weight, size=(self.n_items, self.n_users)) 98 | 99 | embeddings_list = [user_embedding] 100 | for layer_idx in range(self.n_layers): 101 | user_embedding = self.bipartite_gcn_conv((user_embedding, user_embedding), self.net_edge_index.flip([0]), self.net_edge_weight, size=(self.n_users, self.n_users)) 102 | embeddings_list.append(user_embedding) 103 | final_user_embedding = torch.stack(embeddings_list, dim=1) 104 | final_user_embedding = torch.sum(final_user_embedding, dim=1) + user_embedding_from_consumed_items 105 | 106 | return final_user_embedding, final_item_embedding 107 | 108 | def calculate_loss(self, interaction): 109 | # clear the storage variable when training 110 | if self.restore_user_e is not None or self.restore_item_e is not None: 111 | self.restore_user_e, self.restore_item_e = None, None 112 | 113 | user = interaction[self.USER_ID] 114 | pos_item = interaction[self.ITEM_ID] 115 | neg_item = interaction[self.NEG_ITEM_ID] 116 | 117 | user_all_embeddings, item_all_embeddings = self.forward() 118 | u_embeddings = user_all_embeddings[user] 119 | pos_embeddings = item_all_embeddings[pos_item] 120 | neg_embeddings = item_all_embeddings[neg_item] 121 | 122 | # calculate BPR Loss 123 | pos_scores = torch.mul(u_embeddings, pos_embeddings).sum(dim=1) 124 | neg_scores = torch.mul(u_embeddings, neg_embeddings).sum(dim=1) 125 | mf_loss = self.mf_loss(pos_scores, neg_scores) 126 | 127 | # calculate regularization Loss 128 | u_ego_embeddings = self.user_embedding(user) 129 | pos_ego_embeddings = self.item_embedding(pos_item) 130 | neg_ego_embeddings = self.item_embedding(neg_item) 131 | 132 | reg_loss = self.reg_loss(u_ego_embeddings, pos_ego_embeddings, neg_ego_embeddings) 133 | loss = mf_loss + self.reg_weight * reg_loss 134 | 135 | return loss 136 | 137 | def predict(self, interaction): 138 | user = interaction[self.USER_ID] 139 | item = interaction[self.ITEM_ID] 140 | 141 | user_all_embeddings, item_all_embeddings = self.forward() 142 | 143 | u_embeddings = user_all_embeddings[user] 144 | i_embeddings = item_all_embeddings[item] 145 | scores = torch.mul(u_embeddings, i_embeddings).sum(dim=1) 146 | return scores 147 | 148 | def full_sort_predict(self, 
interaction): 149 | user = interaction[self.USER_ID] 150 | if self.restore_user_e is None or self.restore_item_e is None: 151 | self.restore_user_e, self.restore_item_e = self.forward() 152 | # get user embedding from storage variable 153 | u_embeddings = self.restore_user_e[user] 154 | 155 | # dot with all item embedding to accelerate 156 | scores = torch.matmul(u_embeddings, self.restore_item_e.transpose(0, 1)) 157 | 158 | return scores.view(-1) -------------------------------------------------------------------------------- /recbole_gnn/properties/model/DiffNet.yaml: -------------------------------------------------------------------------------- 1 | embedding_size: 64 2 | n_layers: 2 3 | reg_weight: 1e-05 4 | pretrained_review: False -------------------------------------------------------------------------------- /recbole_gnn/properties/model/DirectAU.yaml: -------------------------------------------------------------------------------- 1 | embedding_size: 64 2 | encoder: "MF" # "MF" or "lightGCN" 3 | gamma: 0.5 4 | weight_decay: 1e-6 5 | train_batch_size: 256 6 | 7 | # n_layers: 3 # needed for LightGCN -------------------------------------------------------------------------------- /recbole_gnn/properties/model/GCEGNN.yaml: -------------------------------------------------------------------------------- 1 | embedding_size: 64 2 | leakyrelu_alpha: 0.2 3 | dropout_local: 0. 4 | dropout_global: 0.5 5 | dropout_gcn: 0. 6 | loss_type: CE 7 | gnn_transform: sess_graph 8 | 9 | # global 10 | build_global_graph: True 11 | sample_num: 12 12 | hop: 1 13 | -------------------------------------------------------------------------------- /recbole_gnn/properties/model/GCSAN.yaml: -------------------------------------------------------------------------------- 1 | n_layers: 1 2 | n_heads: 1 3 | hidden_size: 64 4 | inner_size: 256 5 | hidden_dropout_prob: 0.2 6 | attn_dropout_prob: 0.2 7 | hidden_act: 'gelu' 8 | layer_norm_eps: 1e-12 9 | initializer_range: 0.02 10 | step: 1 11 | weight: 0.6 12 | reg_weight: 5e-5 13 | loss_type: 'CE' 14 | gnn_transform: sess_graph 15 | -------------------------------------------------------------------------------- /recbole_gnn/properties/model/HMLET.yaml: -------------------------------------------------------------------------------- 1 | embedding_size: 64 2 | n_layers: 4 3 | reg_weight: 1e-05 4 | require_pow: True 5 | gate_layer_ids: [2,3] 6 | gating_mlp_dims: [64,16,2] 7 | dropout_ratio: 0.2 8 | activation_function: elu 9 | 10 | warm_up_epochs: 50 11 | ori_temp: 0.7 12 | min_temp: 0.01 13 | gum_temp_decay: 0.005 14 | epoch_temp_decay: 1 15 | -------------------------------------------------------------------------------- /recbole_gnn/properties/model/LESSR.yaml: -------------------------------------------------------------------------------- 1 | embedding_size: 64 2 | n_layers: 4 3 | batch_norm: True 4 | feat_drop: 0.2 5 | loss_type: CE 6 | gnn_transform: sess_graph 7 | -------------------------------------------------------------------------------- /recbole_gnn/properties/model/LightGCL.yaml: -------------------------------------------------------------------------------- 1 | embedding_size: 64 # (int) The embedding size of users and items. 2 | n_layers: 2 # (int) The number of layers in LightGCL. 3 | dropout: 0.0 # (float) The dropout ratio. 4 | temp: 0.8 # (float) The temperature in softmax. 5 | lambda1: 0.01 # (float) The hyperparameter to control the strengths of SSL. 6 | lambda2: 1e-05 # (float) The L2 regularization weight. 
7 | q: 5 # (int) A slightly overestimated rank of the adjacency matrix. -------------------------------------------------------------------------------- /recbole_gnn/properties/model/LightGCN.yaml: -------------------------------------------------------------------------------- 1 | embedding_size: 64 2 | n_layers: 2 3 | reg_weight: 1e-05 4 | require_pow: True -------------------------------------------------------------------------------- /recbole_gnn/properties/model/MHCN.yaml: -------------------------------------------------------------------------------- 1 | embedding_size: 64 2 | n_layers: 2 3 | ssl_reg: 1e-05 4 | reg_weight: 1e-05 -------------------------------------------------------------------------------- /recbole_gnn/properties/model/NCL.yaml: -------------------------------------------------------------------------------- 1 | embedding_size: 64 2 | n_layers: 3 3 | reg_weight: 1e-4 4 | 5 | ssl_temp: 0.1 6 | ssl_reg: 1e-7 7 | hyper_layers: 1 8 | 9 | alpha: 1 10 | 11 | proto_reg: 8e-8 12 | num_clusters: 1000 13 | 14 | m_step: 1 15 | warm_up_step: 20 16 | -------------------------------------------------------------------------------- /recbole_gnn/properties/model/NGCF.yaml: -------------------------------------------------------------------------------- 1 | embedding_size: 64 2 | hidden_size_list: [64,64,64] 3 | node_dropout: 0.0 4 | message_dropout: 0.1 5 | reg_weight: 1e-5 -------------------------------------------------------------------------------- /recbole_gnn/properties/model/NISER.yaml: -------------------------------------------------------------------------------- 1 | embedding_size: 64 2 | step: 1 3 | sigma: 16 4 | item_dropout: 0.1 5 | loss_type: 'CE' 6 | gnn_transform: sess_graph -------------------------------------------------------------------------------- /recbole_gnn/properties/model/SEPT.yaml: -------------------------------------------------------------------------------- 1 | warm_up_epochs: 100 2 | embedding_size: 64 3 | n_layers: 2 4 | drop_ratio: 0.3 5 | instance_cnt: 10 6 | reg_weight: 1e-05 7 | ssl_weight: 1e-07 8 | ssl_tau: 0.1 -------------------------------------------------------------------------------- /recbole_gnn/properties/model/SGL.yaml: -------------------------------------------------------------------------------- 1 | type: "ED" 2 | n_layers: 3 3 | ssl_tau: 0.5 4 | reg_weight: 1e-5 5 | ssl_weight: 0.05 6 | drop_ratio: 0.1 7 | embedding_size: 64 -------------------------------------------------------------------------------- /recbole_gnn/properties/model/SGNNHN.yaml: -------------------------------------------------------------------------------- 1 | embedding_size: 64 2 | step: 6 3 | scale: 12 4 | loss_type: 'CE' 5 | gnn_transform: sess_graph 6 | -------------------------------------------------------------------------------- /recbole_gnn/properties/model/SRGNN.yaml: -------------------------------------------------------------------------------- 1 | embedding_size: 64 2 | step: 1 3 | loss_type: 'CE' 4 | gnn_transform: sess_graph -------------------------------------------------------------------------------- /recbole_gnn/properties/model/SSL4REC.yaml: -------------------------------------------------------------------------------- 1 | embedding_size: 64 2 | drop_ratio: 0.1 3 | tau: 0.1 4 | reg_weight: 1e-04 5 | ssl_weight: 1e-05 6 | require_pow: True -------------------------------------------------------------------------------- /recbole_gnn/properties/model/SimGCL.yaml: 
-------------------------------------------------------------------------------- 1 | embedding_size: 64 2 | n_layers: 2 3 | reg_weight: 1e-4 4 | 5 | lambda: 0.5 6 | eps: 0.1 7 | temperature: 0.2 8 | -------------------------------------------------------------------------------- /recbole_gnn/properties/model/TAGNN.yaml: -------------------------------------------------------------------------------- 1 | embedding_size: 64 2 | step: 1 3 | loss_type: 'CE' 4 | gnn_transform: sess_graph 5 | -------------------------------------------------------------------------------- /recbole_gnn/properties/model/XSimGCL.yaml: -------------------------------------------------------------------------------- 1 | embedding_size: 64 2 | n_layers: 2 3 | reg_weight: 0.0001 4 | 5 | lambda: 0.1 6 | eps: 0.2 7 | temperature: 0.2 8 | layer_cl: 1 9 | require_pow: True 10 | -------------------------------------------------------------------------------- /recbole_gnn/properties/quick_start_config/sequential_base.yaml: -------------------------------------------------------------------------------- 1 | train_neg_sample_args: ~ 2 | -------------------------------------------------------------------------------- /recbole_gnn/properties/quick_start_config/social_base.yaml: -------------------------------------------------------------------------------- 1 | NET_SOURCE_ID_FIELD: source_id 2 | NET_TARGET_ID_FIELD: target_id 3 | 4 | load_col: 5 | inter: ['user_id', 'item_id', 'rating', 'timestamp'] 6 | net: [source_id, target_id] 7 | 8 | filter_net_by_inter: True 9 | undirected_net: True -------------------------------------------------------------------------------- /recbole_gnn/quick_start.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from logging import getLogger 3 | from recbole.utils import init_logger, init_seed, set_color 4 | 5 | from recbole_gnn.config import Config 6 | from recbole_gnn.utils import create_dataset, data_preparation, get_model, get_trainer 7 | 8 | 9 | def run_recbole_gnn(model=None, dataset=None, config_file_list=None, config_dict=None, saved=True): 10 | r""" A fast running api, which includes the complete process of 11 | training and testing a model on a specified dataset 12 | Args: 13 | model (str, optional): Model name. Defaults to ``None``. 14 | dataset (str, optional): Dataset name. Defaults to ``None``. 15 | config_file_list (list, optional): Config files used to modify experiment parameters. Defaults to ``None``. 16 | config_dict (dict, optional): Parameters dictionary used to modify experiment parameters. Defaults to ``None``. 17 | saved (bool, optional): Whether to save the model. Defaults to ``True``. 
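    Example:
        A minimal sketch of a direct call (assumes the `ml-100k` atomic files can be fetched by RecBole or already sit under `data_path`)::

            >>> from recbole_gnn.quick_start import run_recbole_gnn
            >>> run_recbole_gnn(model='LightGCN', dataset='ml-100k')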
18 | """ 19 | # configurations initialization 20 | config = Config(model=model, dataset=dataset, config_file_list=config_file_list, config_dict=config_dict) 21 | try: 22 | assert config["enable_sparse"] in [True, False, None] 23 | except AssertionError: 24 | raise ValueError("Your config `enable_sparse` must be `True` or `False` or `None`") 25 | init_seed(config['seed'], config['reproducibility']) 26 | # logger initialization 27 | init_logger(config) 28 | logger = getLogger() 29 | 30 | logger.info(config) 31 | 32 | # dataset filtering 33 | dataset = create_dataset(config) 34 | logger.info(dataset) 35 | 36 | # dataset splitting 37 | train_data, valid_data, test_data = data_preparation(config, dataset) 38 | 39 | # model loading and initialization 40 | init_seed(config['seed'], config['reproducibility']) 41 | model = get_model(config['model'])(config, train_data.dataset).to(config['device']) 42 | logger.info(model) 43 | 44 | # trainer loading and initialization 45 | trainer = get_trainer(config['MODEL_TYPE'], config['model'])(config, model) 46 | 47 | # model training 48 | best_valid_score, best_valid_result = trainer.fit( 49 | train_data, valid_data, saved=saved, show_progress=config['show_progress'] 50 | ) 51 | 52 | # model evaluation 53 | test_result = trainer.evaluate(test_data, load_best_model=saved, show_progress=config['show_progress']) 54 | 55 | logger.info(set_color('best valid ', 'yellow') + f': {best_valid_result}') 56 | logger.info(set_color('test result', 'yellow') + f': {test_result}') 57 | 58 | return { 59 | 'best_valid_score': best_valid_score, 60 | 'valid_score_bigger': config['valid_metric_bigger'], 61 | 'best_valid_result': best_valid_result, 62 | 'test_result': test_result 63 | } 64 | 65 | 66 | def objective_function(config_dict=None, config_file_list=None, saved=True): 67 | r""" The default objective_function used in HyperTuning 68 | 69 | Args: 70 | config_dict (dict, optional): Parameters dictionary used to modify experiment parameters. Defaults to ``None``. 71 | config_file_list (list, optional): Config files used to modify experiment parameters. Defaults to ``None``. 72 | saved (bool, optional): Whether to save the model. Defaults to ``True``. 
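    Example:
        A minimal sketch of a direct call; in this repo the function is usually consumed by the hyper-parameter tuning entry point ``run_hyper.py``, and the dataset files are assumed to be available locally::

            >>> objective_function(config_dict={'model': 'SRGNN', 'dataset': 'diginetica'})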
73 | """ 74 | 75 | config = Config(config_dict=config_dict, config_file_list=config_file_list) 76 | try: 77 | assert config["enable_sparse"] in [True, False, None] 78 | except AssertionError: 79 | raise ValueError("Your config `enable_sparse` must be `True` or `False` or `None`") 80 | init_seed(config['seed'], config['reproducibility']) 81 | logging.basicConfig(level=logging.ERROR) 82 | dataset = create_dataset(config) 83 | train_data, valid_data, test_data = data_preparation(config, dataset) 84 | init_seed(config['seed'], config['reproducibility']) 85 | model = get_model(config['model'])(config, train_data.dataset).to(config['device']) 86 | trainer = get_trainer(config['MODEL_TYPE'], config['model'])(config, model) 87 | best_valid_score, best_valid_result = trainer.fit(train_data, valid_data, verbose=False, saved=saved) 88 | test_result = trainer.evaluate(test_data, load_best_model=saved) 89 | 90 | return { 91 | 'model': config['model'], 92 | 'best_valid_score': best_valid_score, 93 | 'valid_score_bigger': config['valid_metric_bigger'], 94 | 'best_valid_result': best_valid_result, 95 | 'test_result': test_result 96 | } 97 | -------------------------------------------------------------------------------- /recbole_gnn/trainer.py: -------------------------------------------------------------------------------- 1 | from time import time 2 | import math 3 | from torch.nn.utils.clip_grad import clip_grad_norm_ 4 | from tqdm import tqdm 5 | from recbole.trainer import Trainer 6 | from recbole.utils import early_stopping, dict2str, set_color, get_gpu_usage 7 | 8 | 9 | class NCLTrainer(Trainer): 10 | def __init__(self, config, model): 11 | super(NCLTrainer, self).__init__(config, model) 12 | 13 | self.num_m_step = config['m_step'] 14 | assert self.num_m_step is not None 15 | 16 | def fit(self, train_data, valid_data=None, verbose=True, saved=True, show_progress=False, callback_fn=None): 17 | r"""Train the model based on the train data and the valid data. 18 | Args: 19 | train_data (DataLoader): the train data 20 | valid_data (DataLoader, optional): the valid data, default: None. 21 | If it's None, the early_stopping is invalid. 22 | verbose (bool, optional): whether to write training and evaluation information to logger, default: True 23 | saved (bool, optional): whether to save the model parameters, default: True 24 | show_progress (bool): Show the progress of training epoch and evaluate epoch. Defaults to ``False``. 25 | callback_fn (callable): Optional callback function executed at end of epoch. 26 | Includes (epoch_idx, valid_score) input arguments. 27 | Returns: 28 | (float, dict): best valid score and best valid result. If valid_data is None, it returns (-1, None) 29 | """ 30 | if saved and self.start_epoch >= self.epochs: 31 | self._save_checkpoint(-1) 32 | 33 | self.eval_collector.data_collect(train_data) 34 | 35 | for epoch_idx in range(self.start_epoch, self.epochs): 36 | 37 | # only differences from the original trainer 38 | if epoch_idx % self.num_m_step == 0: 39 | self.logger.info("Running E-step ! 
") 40 | self.model.e_step() 41 | # train 42 | training_start_time = time() 43 | train_loss = self._train_epoch(train_data, epoch_idx, show_progress=show_progress) 44 | self.train_loss_dict[epoch_idx] = sum(train_loss) if isinstance(train_loss, tuple) else train_loss 45 | training_end_time = time() 46 | train_loss_output = \ 47 | self._generate_train_loss_output(epoch_idx, training_start_time, training_end_time, train_loss) 48 | if verbose: 49 | self.logger.info(train_loss_output) 50 | self._add_train_loss_to_tensorboard(epoch_idx, train_loss) 51 | 52 | # eval 53 | if self.eval_step <= 0 or not valid_data: 54 | if saved: 55 | self._save_checkpoint(epoch_idx) 56 | update_output = set_color('Saving current', 'blue') + ': %s' % self.saved_model_file 57 | if verbose: 58 | self.logger.info(update_output) 59 | continue 60 | if (epoch_idx + 1) % self.eval_step == 0: 61 | valid_start_time = time() 62 | valid_score, valid_result = self._valid_epoch(valid_data, show_progress=show_progress) 63 | self.best_valid_score, self.cur_step, stop_flag, update_flag = early_stopping( 64 | valid_score, 65 | self.best_valid_score, 66 | self.cur_step, 67 | max_step=self.stopping_step, 68 | bigger=self.valid_metric_bigger 69 | ) 70 | valid_end_time = time() 71 | valid_score_output = (set_color("epoch %d evaluating", 'green') + " [" + set_color("time", 'blue') 72 | + ": %.2fs, " + set_color("valid_score", 'blue') + ": %f]") % \ 73 | (epoch_idx, valid_end_time - valid_start_time, valid_score) 74 | valid_result_output = set_color('valid result', 'blue') + ': \n' + dict2str(valid_result) 75 | if verbose: 76 | self.logger.info(valid_score_output) 77 | self.logger.info(valid_result_output) 78 | self.tensorboard.add_scalar('Vaild_score', valid_score, epoch_idx) 79 | 80 | if update_flag: 81 | if saved: 82 | self._save_checkpoint(epoch_idx) 83 | update_output = set_color('Saving current best', 'blue') + ': %s' % self.saved_model_file 84 | if verbose: 85 | self.logger.info(update_output) 86 | self.best_valid_result = valid_result 87 | 88 | if callback_fn: 89 | callback_fn(epoch_idx, valid_score) 90 | 91 | if stop_flag: 92 | stop_output = 'Finished training, best eval result in epoch %d' % \ 93 | (epoch_idx - self.cur_step * self.eval_step) 94 | if verbose: 95 | self.logger.info(stop_output) 96 | break 97 | self._add_hparam_to_tensorboard(self.best_valid_score) 98 | return self.best_valid_score, self.best_valid_result 99 | 100 | def _train_epoch(self, train_data, epoch_idx, loss_func=None, show_progress=False): 101 | r"""Train the model in an epoch 102 | Args: 103 | train_data (DataLoader): The train data. 104 | epoch_idx (int): The current epoch id. 105 | loss_func (function): The loss function of :attr:`model`. If it is ``None``, the loss function will be 106 | :attr:`self.model.calculate_loss`. Defaults to ``None``. 107 | show_progress (bool): Show the progress of training epoch. Defaults to ``False``. 108 | Returns: 109 | float/tuple: The sum of loss returned by all batches in this epoch. If the loss in each batch contains 110 | multiple parts and the model return these multiple parts loss instead of the sum of loss, it will return a 111 | tuple which includes the sum of loss in each part. 
112 | """ 113 | self.model.train() 114 | loss_func = loss_func or self.model.calculate_loss 115 | total_loss = None 116 | iter_data = ( 117 | tqdm( 118 | train_data, 119 | total=len(train_data), 120 | ncols=100, 121 | desc=set_color(f"Train {epoch_idx:>5}", 'pink'), 122 | ) if show_progress else train_data 123 | ) 124 | for batch_idx, interaction in enumerate(iter_data): 125 | interaction = interaction.to(self.device) 126 | self.optimizer.zero_grad() 127 | losses = loss_func(interaction) 128 | if isinstance(losses, tuple): 129 | if epoch_idx < self.config['warm_up_step']: 130 | losses = losses[:-1] 131 | loss = sum(losses) 132 | loss_tuple = tuple(per_loss.item() for per_loss in losses) 133 | total_loss = loss_tuple if total_loss is None else tuple(map(sum, zip(total_loss, loss_tuple))) 134 | else: 135 | loss = losses 136 | total_loss = losses.item() if total_loss is None else total_loss + losses.item() 137 | self._check_nan(loss) 138 | loss.backward() 139 | if self.clip_grad_norm: 140 | clip_grad_norm_(self.model.parameters(), **self.clip_grad_norm) 141 | self.optimizer.step() 142 | if self.gpu_available and show_progress: 143 | iter_data.set_postfix_str(set_color('GPU RAM: ' + get_gpu_usage(self.device), 'yellow')) 144 | return total_loss 145 | 146 | 147 | class HMLETTrainer(Trainer): 148 | def __init__(self, config, model): 149 | super(HMLETTrainer, self).__init__(config, model) 150 | 151 | self.warm_up_epochs = config['warm_up_epochs'] 152 | self.ori_temp = config['ori_temp'] 153 | self.min_temp = config['min_temp'] 154 | self.gum_temp_decay = config['gum_temp_decay'] 155 | self.epoch_temp_decay = config['epoch_temp_decay'] 156 | 157 | def _train_epoch(self, train_data, epoch_idx, loss_func=None, show_progress=False): 158 | if epoch_idx > self.warm_up_epochs: 159 | # Temp decay 160 | gum_temp = self.ori_temp * math.exp(-self.gum_temp_decay*(epoch_idx - self.warm_up_epochs)) 161 | self.model.gum_temp = max(gum_temp, self.min_temp) 162 | self.logger.info(f'Current gumbel softmax temperature: {self.model.gum_temp}') 163 | 164 | for gating in self.model.gating_nets: 165 | self.model._gating_freeze(gating, True) 166 | return super()._train_epoch(train_data, epoch_idx, loss_func, show_progress) 167 | 168 | 169 | class SEPTTrainer(Trainer): 170 | def __init__(self, config, model): 171 | super(SEPTTrainer, self).__init__(config, model) 172 | self.warm_up_epochs = config['warm_up_epochs'] 173 | 174 | def _train_epoch(self, train_data, epoch_idx, loss_func=None, show_progress=False): 175 | if epoch_idx < self.warm_up_epochs: 176 | loss_func = self.model.calculate_rec_loss 177 | else: 178 | self.model.subgraph_construction() 179 | return super()._train_epoch(train_data, epoch_idx, loss_func, show_progress) -------------------------------------------------------------------------------- /recbole_gnn/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import importlib 4 | from logging import getLogger 5 | from recbole.data.utils import load_split_dataloaders, create_samplers, save_split_dataloaders 6 | from recbole.data.utils import create_dataset as create_recbole_dataset 7 | from recbole.data.utils import data_preparation as recbole_data_preparation 8 | from recbole.utils import set_color, Enum 9 | from recbole.utils import get_model as get_recbole_model 10 | from recbole.utils import get_trainer as get_recbole_trainer 11 | from recbole.utils.argument_list import dataset_arguments 12 | 13 | from recbole_gnn.data.dataloader 
import CustomizedTrainDataLoader, CustomizedNegSampleEvalDataLoader, CustomizedFullSortEvalDataLoader 14 | 15 | 16 | def create_dataset(config): 17 | """Create dataset according to :attr:`config['model']` and :attr:`config['MODEL_TYPE']`. 18 | If :attr:`config['dataset_save_path']` file exists and 19 | its :attr:`config` of dataset is equal to current :attr:`config` of dataset. 20 | It will return the saved dataset in :attr:`config['dataset_save_path']`. 21 | Args: 22 | config (Config): An instance object of Config, used to record parameter information. 23 | Returns: 24 | Dataset: Constructed dataset. 25 | """ 26 | model_type = config['MODEL_TYPE'] 27 | dataset_module = importlib.import_module('recbole_gnn.data.dataset') 28 | gen_graph_module_path = '.'.join(['recbole_gnn.model.general_recommender', config['model'].lower()]) 29 | seq_module_path = '.'.join(['recbole_gnn.model.sequential_recommender', config['model'].lower()]) 30 | if hasattr(dataset_module, config['model'] + 'Dataset'): 31 | dataset_class = getattr(dataset_module, config['model'] + 'Dataset') 32 | elif importlib.util.find_spec(gen_graph_module_path, __name__): 33 | dataset_class = getattr(dataset_module, 'GeneralGraphDataset') 34 | elif importlib.util.find_spec(seq_module_path, __name__): 35 | dataset_class = getattr(dataset_module, 'SessionGraphDataset') 36 | elif model_type == ModelType.SOCIAL: 37 | dataset_class = getattr(dataset_module, 'SocialDataset') 38 | else: 39 | return create_recbole_dataset(config) 40 | 41 | default_file = os.path.join(config['checkpoint_dir'], f'{config["dataset"]}-{dataset_class.__name__}.pth') 42 | file = config['dataset_save_path'] or default_file 43 | if os.path.exists(file): 44 | with open(file, 'rb') as f: 45 | dataset = pickle.load(f) 46 | dataset_args_unchanged = True 47 | for arg in dataset_arguments + ['seed', 'repeatable']: 48 | if config[arg] != dataset.config[arg]: 49 | dataset_args_unchanged = False 50 | break 51 | if dataset_args_unchanged: 52 | logger = getLogger() 53 | logger.info(set_color('Load filtered dataset from', 'pink') + f': [{file}]') 54 | return dataset 55 | 56 | dataset = dataset_class(config) 57 | if config['save_dataset']: 58 | dataset.save() 59 | return dataset 60 | 61 | 62 | def get_model(model_name): 63 | r"""Automatically select model class based on model name 64 | Args: 65 | model_name (str): model name 66 | Returns: 67 | Recommender: model class 68 | """ 69 | model_submodule = [ 70 | 'general_recommender', 'sequential_recommender', 'social_recommender' 71 | ] 72 | 73 | model_file_name = model_name.lower() 74 | model_module = None 75 | for submodule in model_submodule: 76 | module_path = '.'.join(['recbole_gnn.model', submodule, model_file_name]) 77 | if importlib.util.find_spec(module_path, __name__): 78 | model_module = importlib.import_module(module_path, __name__) 79 | break 80 | 81 | if model_module is None: 82 | model_class = get_recbole_model(model_name) 83 | else: 84 | model_class = getattr(model_module, model_name) 85 | return model_class 86 | 87 | 88 | def _get_customized_dataloader(config, phase): 89 | if phase == 'train': 90 | return CustomizedTrainDataLoader 91 | else: 92 | eval_mode = config["eval_args"]["mode"] 93 | if eval_mode == 'full': 94 | return CustomizedFullSortEvalDataLoader 95 | else: 96 | return CustomizedNegSampleEvalDataLoader 97 | 98 | 99 | def data_preparation(config, dataset): 100 | """Split the dataset by :attr:`config['eval_args']` and create training, validation and test dataloader. 
101 | Note: 102 | If we can load split dataloaders by :meth:`load_split_dataloaders`, we will not create new split dataloaders. 103 | Args: 104 | config (Config): An instance object of Config, used to record parameter information. 105 | dataset (Dataset): An instance object of Dataset, which contains all interaction records. 106 | Returns: 107 | tuple: 108 | - train_data (AbstractDataLoader): The dataloader for training. 109 | - valid_data (AbstractDataLoader): The dataloader for validation. 110 | - test_data (AbstractDataLoader): The dataloader for testing. 111 | """ 112 | seq_module_path = '.'.join(['recbole_gnn.model.sequential_recommender', config['model'].lower()]) 113 | if importlib.util.find_spec(seq_module_path, __name__): 114 | # Special condition for sequential models of RecBole-Graph 115 | dataloaders = load_split_dataloaders(config) 116 | if dataloaders is not None: 117 | train_data, valid_data, test_data = dataloaders 118 | else: 119 | built_datasets = dataset.build() 120 | train_dataset, valid_dataset, test_dataset = built_datasets 121 | train_sampler, valid_sampler, test_sampler = create_samplers(config, dataset, built_datasets) 122 | 123 | train_data = _get_customized_dataloader(config, 'train')(config, train_dataset, train_sampler, shuffle=True) 124 | valid_data = _get_customized_dataloader(config, 'evaluation')(config, valid_dataset, valid_sampler, shuffle=False) 125 | test_data = _get_customized_dataloader(config, 'evaluation')(config, test_dataset, test_sampler, shuffle=False) 126 | if config['save_dataloaders']: 127 | save_split_dataloaders(config, dataloaders=(train_data, valid_data, test_data)) 128 | 129 | logger = getLogger() 130 | logger.info( 131 | set_color('[Training]: ', 'pink') + set_color('train_batch_size', 'cyan') + ' = ' + 132 | set_color(f'[{config["train_batch_size"]}]', 'yellow') + set_color(' negative sampling', 'cyan') + ': ' + 133 | set_color(f'[{config["train_neg_sample_args"]}]', 'yellow') 134 | ) 135 | logger.info( 136 | set_color('[Evaluation]: ', 'pink') + set_color('eval_batch_size', 'cyan') + ' = ' + 137 | set_color(f'[{config["eval_batch_size"]}]', 'yellow') + set_color(' eval_args', 'cyan') + ': ' + 138 | set_color(f'[{config["eval_args"]}]', 'yellow') 139 | ) 140 | return train_data, valid_data, test_data 141 | else: 142 | return recbole_data_preparation(config, dataset) 143 | 144 | 145 | def get_trainer(model_type, model_name): 146 | r"""Automatically select trainer class based on model type and model name 147 | Args: 148 | model_type (ModelType): model type 149 | model_name (str): model name 150 | Returns: 151 | Trainer: trainer class 152 | """ 153 | try: 154 | return getattr(importlib.import_module('recbole_gnn.trainer'), model_name + 'Trainer') 155 | except AttributeError: 156 | return get_recbole_trainer(model_type, model_name) 157 | 158 | 159 | class ModelType(Enum): 160 | """Type of models. 
161 | 162 | - ``Social``: Social-based Recommendation 163 | """ 164 | 165 | SOCIAL = 7 -------------------------------------------------------------------------------- /results/README.md: -------------------------------------------------------------------------------- 1 | ## General Model Results 2 | 3 | * [ml-1m](general/ml-1m.md) 4 | 5 | ## Sequential Model Results 6 | 7 | * [diginetica](sequential/diginetica.md) 8 | 9 | ## Social-aware Model Results 10 | 11 | * [lastfm](social/lastfm.md) 12 | -------------------------------------------------------------------------------- /results/general/ml-1m.md: -------------------------------------------------------------------------------- 1 | # Experimental Setting 2 | 3 | **Dataset:** [MovieLens-1M](https://grouplens.org/datasets/movielens/) 4 | 5 | **Filtering:** Remove interactions with a rating score of less than 3 6 | 7 | **Evaluation:** ratio-based 8:1:1, full sort 8 | 9 | **Metrics:** Recall@10, NDCG@10, MRR@10, Hit@10, Precision@10 10 | 11 | **Properties:** 12 | 13 | ```yaml 14 | # dataset config 15 | field_separator: "\t" 16 | seq_separator: " " 17 | USER_ID_FIELD: user_id 18 | ITEM_ID_FIELD: item_id 19 | RATING_FIELD: rating 20 | NEG_PREFIX: neg_ 21 | LABEL_FIELD: label 22 | load_col: 23 | inter: [user_id, item_id, rating] 24 | val_interval: 25 | rating: "[3,inf)" 26 | unused_col: 27 | inter: [rating] 28 | 29 | # training and evaluation 30 | epochs: 500 31 | train_batch_size: 4096 32 | valid_metric: MRR@10 33 | eval_batch_size: 4096000 34 | ``` 35 | 36 | For fairness, we restrict the user and item embedding dimension as follows. Please adjust the name of the corresponding argument for each model. 37 | ``` 38 | embedding_size: 64 39 | ``` 40 | 41 | # Dataset Statistics 42 | 43 | | Dataset | #Users | #Items | #Interactions | Sparsity | 44 | | ---------- | ------ | ------ | ------------- | -------- | 45 | | ml-1m | 6,040 | 3,629 | 836,478 | 96.18% | 46 | 47 | # Evaluation Results 48 | 49 | | Method | Recall@10 | MRR@10 | NDCG@10 | Hit@10 | Precision@10 | 50 | |--------------|-----------|--------|---------|--------|--------------| 51 | | **BPR** | 0.1776 | 0.4187 | 0.2401 | 0.7199 | 0.1779 | 52 | | **NeuMF** | 0.1651 | 0.4020 | 0.2271 | 0.7029 | 0.1700 | 53 | | **NGCF** | 0.1814 | 0.4354 | 0.2508 | 0.7239 | 0.1850 | 54 | | **LightGCN** | 0.1861 | 0.4388 | 0.2538 | 0.7330 | 0.1863 | 55 | | **LightGCL** | 0.1867 | 0.4283 | 0.2479 | 0.7370 | 0.1815 | 56 | | **SGL** | 0.1889 | 0.4315 | 0.2505 | 0.7392 | 0.1843 | 57 | | **HMLET** | 0.1847 | 0.4297 | 0.2490 | 0.7305 | 0.1836 | 58 | | **NCL** | 0.2021 | 0.4599 | 0.2702 | 0.7565 | 0.1962 | 59 | | **SimGCL** | 0.2029 | 0.4550 | 0.2667 | 0.7640 | 0.1933 | 60 | | **XSimGCL** | 0.2116 | 0.4638 | 0.2750 | 0.7743 | 0.1987 | 61 | 62 | # Hyper-parameters 63 | 64 | | | Best hyper-parameters | Tuning range | 65 |--------------|--------------------------------------------------|--------------------------------------------------| 66 | | **BPR** | learning_rate=0.001 | learning_rate choice [0.05, 0.02, 0.01, 0.005, 0.002, 0.001, 0.0005, 0.0002, 0.0001, 0.00005, 0.00002, 0.00001] | 67 | | **NeuMF** | learning_rate=0.0001<br>
mlp_hidden_size=[32,16,8]
dropout_prob=0 | learning_rate choice [0.005, 0.002, 0.001, 0.0005, 0.0002, 0.0001, 0.00005]
mlp_hidden_size choice ['[64,64]', '[64,32]', '[64,32,16]','[32,16,8]']
dropout_prob choice [0, 0.1, 0.2] | 68 | | **NGCF** | learning_rate=0.0002
message_dropout=0.0
node_dropout=0.0 | learning_rate choice [0.001, 0.0005, 0.0002]
node_dropout choice [0.0, 0.1]
message_dropout choice [0.0, 0.1] | 69 | | **LightGCN** | learning_rate=0.002
n_layers=3
reg_weight=0.0001 | learning_rate choice [0.005, 0.002, 0.001]
n_layers choice [2, 3]
reg_weight choice [1e-4, 1e-5] | 70 | | **LightGCL** | learning_rate=0.001
n_layers=2
lambda1=0.0001
temp=2
lambda2=1e-7
dropout=0.1 | learning_rate choice [0.001]
n_layers choice [2, 3]
lambda1 choice [0.01, 0.005, 0.001, 0.0001, 1e-5, 1e-7]
temp choice [0.5, 0.8, 2, 3]
lambda2 choice [1e-4, 1e-5, 1e-7]
dropout choice [0.0, 0.1, 0.25] | 71 | | **SGL** | learning_rate=0.002
n_layers=3
reg_weight=0.0001
ssl_tau=0.5
drop_ratio=0.1
ssl_weight=0.005 | learning_rate choice [0.002]
n_layers choice [3]
reg_weight choice [1e-4]
ssl_tau choice [0.1, 0.5]
drop_ratio choice [0.1, 0.3]
ssl_weight choice [1e-5, 1e-6, 1e-7, 0.005, 0.01, 0.05] | 72 | | **HMLET** | learning_rate=0.002
n_layers=4
activation_function=leakyrelu | learning_rate choice [0.002, 0.001, 0.0005]
n_layers choice [3, 4]
activation_function choice ['elu', 'leakyrelu'] | 73 | | **NCL** | learning_rate=0.002
n_layers=3
reg_weight=0.0001
ssl_temp=0.1
ssl_reg=1e-06
hyper_layers=1
alpha=1.5 | learning_rate choice [0.002]
n_layers choice [3]
reg_weight choice [1e-4]
ssl_temp choice [0.1, 0.05]
ssl_reg choice [1e-7, 1e-6]
hyper_layers choice [1]
alpha choice [1, 0.8, 1.5] | 74 | | **SimGCL** | learning_rate=0.002
n_layers=2
reg_weight=0.0001
temperature=0.05
lambda=1e-5
eps=0.1 | learning_rate choice [0.002]
n_layers choice [2, 3]
reg_weight choice [1e-4]
temperature choice [0.05, 0.1, 0.2]
lambda choice [1e-5, 1e-6, 1e-7, 0.005, 0.01, 0.05]
eps choice [0.1, 0.2] | 75 | | **XSimGCL** | learning_rate=0.002
n_layers=2
reg_weight=0.0001
temperature=0.2
lambda=0.1
eps=0.2
layer_cl=1 | learning_rate choice [0.002]
n_layers choice [2, 3]
reg_weight choice [1e-4]
temperature choice [0.05, 0.1, 0.2]
lambda choice [1e-5, 1e-6, 1e-7, 1e-4, 0.005, 0.01, 0.05, 0.1]
eps choice [0.1, 0.2]
layer_cl choice [1] | 76 | -------------------------------------------------------------------------------- /results/sequential/diginetica.md: -------------------------------------------------------------------------------- 1 | # Experimental Setting 2 | 3 | **Dataset:** diginetica-not-merged 4 | 5 | **Filtering:** Remove users and items with fewer than 5 interactions 6 | 7 | **Evaluation:** leave one out, full sort 8 | 9 | **Metrics:** Recall@10, NDCG@10, MRR@10, Hit@10, Precision@10 10 | 11 | **Properties:** 12 | 13 | ```yaml 14 | # dataset config 15 | field_separator: "\t" 16 | seq_separator: " " 17 | USER_ID_FIELD: session_id 18 | ITEM_ID_FIELD: item_id 19 | TIME_FIELD: timestamp 20 | NEG_PREFIX: neg_ 21 | ITEM_LIST_LENGTH_FIELD: item_length 22 | LIST_SUFFIX: _list 23 | MAX_ITEM_LIST_LENGTH: 20 24 | POSITION_FIELD: position_id 25 | load_col: 26 | inter: [session_id, item_id, timestamp] 27 | user_inter_num_interval: "[5,inf)" 28 | item_inter_num_interval: "[5,inf)" 29 | 30 | # training and evaluation 31 | epochs: 500 32 | train_batch_size: 4096 33 | eval_batch_size: 2000 34 | valid_metric: MRR@10 35 | eval_args: 36 | split: {'LS':"valid_and_test"} 37 | mode: full 38 | order: TO 39 | train_neg_sample_args: ~ 40 | ``` 41 | 42 | For fairness, we restrict the user and item embedding dimension as follows. Please adjust the name of the corresponding argument for each model. 43 | ``` 44 | embedding_size: 64 45 | ``` 46 | 47 | # Dataset Statistics 48 | 49 | | Dataset | #Users | #Items | #Interactions | Sparsity | 50 | | ---------- | ------ | ------ | ------------- | -------- | 51 | | diginetica | 72,014 | 29,454 | 580,490 | 99.97% | 52 | 53 | # Evaluation Results 54 | 55 | | Method | Recall@10 | MRR@10 | NDCG@10 | Hit@10 | Precision@10 | 56 | | -------------------- | --------- | ------ | ------- | ------ | ------------ | 57 | | **GRU4Rec** | 0.3691 | 0.1632 | 0.2114 | 0.3691 | 0.0369 | 58 | | **NARM** | 0.3801 | 0.1695 | 0.2188 | 0.3801 | 0.0380 | 59 | | **SASRec** | 0.4144 | 0.1857 | 0.2393 | 0.4144 | 0.0414 | 60 | | **SR-GNN** | 0.3881 | 0.1754 | 0.2253 | 0.3881 | 0.0388 | 61 | | **GC-SAN** | 0.4127 | 0.1881 | 0.2408 | 0.4127 | 0.0413 | 62 | | **NISER+** | 0.4144 | 0.1904 | 0.2430 | 0.4144 | 0.0414 | 63 | | **LESSR** | 0.3964 | 0.1763 | 0.2279 | 0.3964 | 0.0396 | 64 | | **TAGNN** | 0.3894 | 0.1763 | 0.2263 | 0.3894 | 0.0389 | 65 | | **GCE-GNN** | 0.4284 | 0.1961 | 0.2507 | 0.4284 | 0.0428 | 66 | | **SGNN-HN** | 0.4183 | 0.1877 | 0.2418 | 0.4183 | 0.0418 | 67 | 68 | # Hyper-parameters 69 | 70 | | | Best hyper-parameters | Tuning range | 71 | | -------------------- | ------------------------------------------------------------ | ------------------------------------------------------------ | 72 | | **GRU4Rec** | learning_rate=0.01<br>
hidden_size=128
dropout_prob=0.3
num_layers=1 | learning_rate in [1e-2, 1e-3, 3e-3]
num_layers in [1, 2, 3]
hidden_size in [128]
dropout_prob in [0.1, 0.2, 0.3] | 73 | | **SASRec** | learning_rate=0.001
n_layers=2
attn_dropout_prob=0.2
hidden_dropout_prob=0.2 | learning_rate in [0.001, 0.0001]
n_layers in [1, 2]
hidden_dropout_prob in [0.2, 0.5]
attn_dropout_prob in [0.2, 0.5] | 74 | | **NARM** | learning_rate=0.001
hidden_size=128
n_layers=1
dropout_probs=[0.25, 0.5] | learning_rate in [0.001, 0.01, 0.03]
hidden_size in [128]
n_layers in [1, 2]
dropout_probs in ['[0.25,0.5]', '[0.2,0.2]', '[0.1,0.2]'] | 75 | | **SR-GNN** | learning_rate=0.001
step=1 | learning_rate in [0.01, 0.001, 0.0001]
step in [1, 2] | 76 | | **GC-SAN** | learning_rate=0.001
step=1 | learning_rate in [0.01, 0.001, 0.0001]
step in [1, 2] | 77 | | **NISER+** | learning_rate=0.001
sigma=16 | learning_rate in [0.01, 0.001, 0.003]
sigma in [10, 16, 20] | 78 | | **LESSR** | learning_rate=0.001
n_layers=4 | learning_rate in [0.01, 0.001, 0.003]
n_layers in [2, 4] | 79 | | **TAGNN** | learning_rate=0.001 | learning_rate in [0.01, 0.001, 0.003]
train_batch_size=512 | 80 | | **GCE-GNN** | learning_rate=0.001
dropout_global=0.5 | learning_rate in [0.01, 0.001, 0.003]
dropout_global in [0.2, 0.5] | 81 | | **SGNN-HN** | learning_rate=0.003
scale=12
step=2 | learning_rate in [0.01, 0.001, 0.003]
scale in [12, 16, 20]
step in [2, 4, 6] | 82 | -------------------------------------------------------------------------------- /results/social/lastfm.md: -------------------------------------------------------------------------------- 1 | # Experimental Setting 2 | 3 | **Dataset:** [LastFM](http://files.grouplens.org/datasets/hetrec2011/) 4 | 5 | > Note that datasets for social recommendation methods can be downloaded from [Social-Datasets](https://github.com/Sherry-XLL/Social-Datasets). 6 | 7 | **Filtering:** None 8 | 9 | **Evaluation:** ratio-based 8:1:1, full sort 10 | 11 | **Metrics:** Recall@10, NDCG@10, MRR@10, Hit@10, Precision@10 12 | 13 | **Properties:** 14 | 15 | ```yaml 16 | # dataset config 17 | field_separator: "\t" 18 | seq_separator: " " 19 | USER_ID_FIELD: user_id 20 | ITEM_ID_FIELD: artist_id 21 | NET_SOURCE_ID_FIELD: source_id 22 | NET_TARGET_ID_FIELD: target_id 23 | LABEL_FIELD: label 24 | NEG_PREFIX: neg_ 25 | load_col: 26 | inter: [user_id, artist_id] 27 | net: [source_id, target_id] 28 | 29 | # social network config 30 | filter_net_by_inter: True 31 | undirected_net: True 32 | 33 | # training and evaluation 34 | epochs: 5000 35 | train_batch_size: 4096 36 | eval_batch_size: 409600000 37 | valid_metric: NDCG@10 38 | stopping_step: 50 39 | ``` 40 | 41 | For fairness, we restrict the user and item embedding dimension as follows. Please adjust the name of the corresponding argument for each model. 42 | ``` 43 | embedding_size: 64 44 | ``` 45 | 46 | # Dataset Statistics 47 | 48 | | Dataset | #Users | #Items | #Interactions | Sparsity | 49 | | ---------- | ------ | ------ | ------------- | -------- | 50 | | lastfm | 1,892 | 17,632 | 92,834 | 99.72% | 51 | 52 | # Evaluation Results 53 | 54 | | Method | Recall@10 | MRR@10 | NDCG@10 | Hit@10 | Precision@10 | 55 | | -------------------- | --------- | ------ | ------- | ------ | ------------ | 56 | | **BPR** | 0.1761 | 0.3026 | 0.1674 | 0.5573 | 0.0858 | 57 | | **NeuMF** | 0.1696 | 0.2924 | 0.1604 | 0.5456 | 0.0828 | 58 | | **NGCF** | 0.1960 | 0.3479 | 0.1898 | 0.6141 | 0.0961 | 59 | | **LightGCN** | 0.2064 | 0.3559 | 0.1972 | 0.6322 | 0.1009 | 60 | | **DiffNet** | 0.1757 | 0.3117 | 0.1694 | 0.5621 | 0.0857 | 61 | | **MHCN** | 0.2123 | 0.3782 | 0.2068 | 0.6523 | 0.1042 | 62 | | **SEPT** | 0.2127 | 0.3703 | 0.2057 | 0.6465 | 0.1044 | 63 | 64 | # Hyper-parameters 65 | 66 | | | Best hyper-parameters | Tuning range | 67 | | -------------------- | ------------------------------------------------------------ | ------------------------------------------------------------ | 68 | | **BPR** | learning_rate=0.0005 | learning_rate in [0.01, 0.005, 0.001, 0.0005, 0.0001] | 69 | | **NeuMF** | learning_rate=0.0005<br>
dropout_prob=0.1 | learning_rate in [0.01, 0.005, 0.001, 0.0005, 0.0001]
dropout_prob in [0.1, 0.2, 0.3] | 70 | | **NGCF** | learning_rate=0.0005
hidden_size_list=[64,64,64] | learning_rate in [0.01, 0.005, 0.001, 0.0005, 0.0001]
hidden_size_list in ['[64]', '[64,64]', '[64,64,64]'] | 71 | | **LightGCN** | learning_rate=0.001
n_layers=3 | learning_rate in [0.01, 0.005, 0.001, 0.0005, 0.0001]
n_layers in [1, 2, 3] | 72 | | **DiffNet** | learning_rate=0.0005
n_layers=1 | learning_rate in [0.01, 0.005, 0.001, 0.0005, 0.0001]
n_layers in [1, 2, 3] | 73 | | **MHCN** | learning_rate=0.0005
n_layers=2
ssl_reg=1e-05 | learning_rate in [0.01, 0.005, 0.001, 0.0005, 0.0001]
n_layers in [1, 2, 3]
ssl_reg in [1e-04, 1e-05, 1e-06] | 74 | | **SEPT** | learning_rate=0.0005
n_layers=2
ssl_weight=1e-07 | learning_rate in [0.01, 0.005, 0.001, 0.0005, 0.0001]
n_layers in [1, 2, 3]
ssl_weight in [1e-3, 1e-4, 1e-5, 1e-6, 1e-7] | 75 | -------------------------------------------------------------------------------- /run_hyper.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from recbole.trainer import HyperTuning 4 | from recbole_gnn.quick_start import objective_function 5 | 6 | 7 | def main(): 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--config_files', type=str, default=None, help='fixed config files') 10 | parser.add_argument('--params_file', type=str, default=None, help='parameters file') 11 | parser.add_argument('--output_file', type=str, default='hyper_example.result', help='output file') 12 | args, _ = parser.parse_known_args() 13 | 14 | # plz set algo='exhaustive' to use exhaustive search, in this case, max_evals is auto set 15 | config_file_list = args.config_files.strip().split(' ') if args.config_files else None 16 | hp = HyperTuning(objective_function, algo='exhaustive', 17 | params_file=args.params_file, fixed_config_file_list=config_file_list) 18 | hp.run() 19 | hp.export_result(output_file=args.output_file) 20 | print('best params: ', hp.best_params) 21 | print('best result: ') 22 | print(hp.params2result[hp.params2str(hp.best_params)]) 23 | 24 | 25 | if __name__ == '__main__': 26 | main() 27 | -------------------------------------------------------------------------------- /run_recbole_gnn.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from recbole_gnn.quick_start import run_recbole_gnn 4 | 5 | 6 | if __name__ == '__main__': 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--model', '-m', type=str, default='BPR', help='name of models') 9 | parser.add_argument('--dataset', '-d', type=str, default='ml-100k', help='name of datasets') 10 | parser.add_argument('--config_files', type=str, default=None, help='config files') 11 | 12 | args, _ = parser.parse_known_args() 13 | 14 | config_file_list = args.config_files.strip().split(' ') if args.config_files else None 15 | run_recbole_gnn(model=args.model, dataset=args.dataset, config_file_list=config_file_list) 16 | -------------------------------------------------------------------------------- /run_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | python -m pytest -v tests/test_model.py 5 | echo "model tests finished" 6 | -------------------------------------------------------------------------------- /tests/test_data/test/test.net: -------------------------------------------------------------------------------- 1 | source_id:token target_id:token 2 | 187 100 3 | 119 40 4 | 96 119 5 | 12 52 6 | 153 131 7 | 259 232 8 | 191 307 9 | 83 150 10 | 86 255 11 | 177 4 12 | 210 192 13 | 25 323 14 | 90 298 15 | 38 47 16 | 201 283 17 | 93 63 18 | 115 190 19 | 143 293 20 | 147 265 21 | 320 68 22 | 188 273 23 | 332 321 24 | 212 203 25 | 326 98 26 | 74 270 27 | 4 333 28 | 87 261 29 | 163 207 30 | 18 175 31 | 127 77 32 | 296 179 33 | 17 101 34 | 24 30 35 | 102 288 36 | 345 269 37 | 270 188 38 | 235 297 39 | 68 303 40 | 313 43 41 | 239 109 42 | 28 76 43 | 108 227 44 | 78 218 45 | 96 30 46 | 180 301 47 | 211 12 48 | 234 34 49 | 178 53 50 | 3 243 51 | 179 73 52 | 98 92 53 | 310 116 54 | 154 271 55 | 293 3 56 | 80 297 57 | 329 254 58 | 198 134 59 | 341 238 60 | 75 185 61 | 166 64 62 | 205 142 63 | 317 163 64 | 261 91 65 | 314 322 66 | 4 33 67 | 71 73 68 | 289 182 69 | 21 12 70 | 248 49 71 | 255 32 72 | 
261 170 73 | 257 314 74 | 159 118 75 | 212 221 76 | 5 177 77 | 204 57 78 | 132 120 79 | 13 275 80 | 340 252 81 | 245 251 82 | 334 15 83 | 130 103 84 | 280 187 85 | 232 153 86 | 242 341 87 | 219 123 88 | 6 290 89 | 49 289 90 | 46 347 91 | 185 231 92 | 57 254 93 | 134 248 94 | 24 234 95 | 57 207 96 | 147 295 97 | 191 274 98 | 340 54 99 | 280 150 100 | 190 4 101 | 238 198 102 | 72 123 103 | 122 178 104 | 7 334 105 | 11 90 106 | 232 78 107 | 16 77 108 | 41 190 109 | 108 101 110 | 212 66 111 | 258 18 112 | 321 250 113 | 126 280 114 | 271 85 115 | 11 176 116 | 22 69 117 | 129 159 118 | 235 193 119 | 129 88 120 | 221 315 121 | 308 329 122 | 103 83 123 | 180 43 124 | 208 87 125 | 64 75 126 | 92 36 127 | 298 151 128 | 56 103 129 | 162 268 130 | 81 252 131 | 344 115 132 | 67 282 133 | 132 17 134 | 83 307 135 | 299 82 136 | 321 227 137 | 48 13 138 | 212 57 139 | 344 280 140 | 195 81 141 | 112 122 142 | 345 346 143 | 65 18 144 | 269 3 145 | 131 123 146 | 185 311 147 | 124 330 148 | 347 297 149 | 321 251 150 | 196 135 151 | 65 122 152 | 322 197 153 | 334 160 154 | 129 64 155 | 38 17 156 | 289 256 157 | 51 286 158 | 107 260 159 | 300 101 160 | 290 281 161 | 192 170 162 | 42 2 163 | 54 260 164 | 126 1 165 | 326 294 166 | 119 14 167 | 48 172 168 | 133 191 169 | 332 157 170 | 311 99 171 | 115 123 172 | 160 201 173 | 269 267 174 | 302 184 175 | 262 168 176 | 11 80 177 | 317 155 178 | 163 310 179 | 290 32 180 | 90 239 181 | 246 129 182 | 105 189 183 | 336 8 184 | 266 100 185 | 153 311 186 | 7 20 187 | 329 94 188 | 135 38 189 | 216 331 190 | 291 89 191 | 121 253 192 | 246 82 193 | 113 325 194 | 99 313 195 | 226 188 196 | 319 60 197 | 195 280 198 | 245 319 199 | 168 291 200 | 63 127 201 | 316 280 202 | 67 69 203 | 40 143 204 | 177 18 205 | 239 253 206 | 213 304 207 | 218 315 208 | 18 312 209 | 165 6 210 | 324 232 211 | 167 156 212 | 295 275 213 | 42 110 214 | 25 226 215 | 114 104 216 | 172 305 217 | 66 26 218 | 51 303 219 | 247 110 220 | 245 18 221 | 335 307 222 | 325 95 223 | 289 81 224 | 166 141 225 | 4 39 226 | 171 16 227 | 79 145 228 | 187 65 229 | 102 105 230 | 234 70 231 | 321 104 232 | 62 179 233 | 171 122 234 | 225 239 235 | 283 315 236 | 121 107 237 | 154 297 238 | 309 170 239 | 3 38 240 | 78 345 241 | 164 238 242 | 92 142 243 | 339 4 244 | 251 61 245 | 223 240 246 | 167 39 247 | 223 8 248 | 61 253 249 | 220 256 250 | 139 247 251 | 199 267 252 | 344 264 253 | 336 56 254 | 110 235 255 | 75 90 256 | 93 321 257 | 345 277 258 | 119 260 259 | 214 10 260 | 15 86 261 | 102 5 262 | 34 213 263 | 223 238 264 | 243 169 265 | 107 223 266 | 106 175 267 | 218 104 268 | 28 82 269 | 267 37 270 | 331 124 271 | 16 146 272 | 186 289 273 | 226 304 274 | 109 34 275 | 124 73 276 | 165 286 277 | 260 70 278 | 94 159 279 | 151 257 280 | 151 210 281 | 263 288 282 | 276 218 283 | 222 79 284 | 48 133 285 | 67 218 286 | 282 250 287 | 127 195 288 | 222 316 289 | 19 272 290 | 238 43 291 | 71 240 292 | 208 65 293 | 219 300 294 | 338 29 295 | 75 86 296 | 86 269 297 | 91 100 298 | 273 248 299 | 202 9 300 | 190 33 301 | 84 92 302 | 124 306 303 | 284 70 304 | 281 341 305 | 247 302 306 | 306 230 307 | 320 279 308 | 319 41 309 | 91 160 310 | 323 201 311 | 305 194 312 | 41 156 313 | 220 264 314 | 296 310 315 | 183 131 316 | 232 21 317 | 239 218 318 | 302 49 319 | 250 287 320 | 200 109 321 | 96 263 322 | 225 221 323 | 123 263 324 | 329 256 325 | 136 344 326 | 338 76 327 | 233 245 328 | 347 198 329 | 99 83 330 | 240 81 331 | 238 291 332 | 78 331 333 | 56 225 334 | 21 93 335 | 24 293 336 | 28 155 337 | 245 19 338 | 225 198 339 | 90 235 340 | 
191 35 341 | 146 28 342 | 303 194 343 | 203 276 344 | 189 49 345 | 265 232 346 | 204 198 347 | 283 217 348 | 306 44 349 | 133 175 350 | 256 80 351 | 345 215 352 | 97 13 353 | 25 287 354 | 104 48 355 | 20 50 356 | 155 340 357 | 202 57 358 | 343 263 359 | 135 293 360 | 152 266 361 | 232 182 362 | 86 217 363 | 73 72 364 | 143 44 365 | 299 162 366 | 277 324 367 | 154 124 368 | 307 210 369 | 210 226 370 | 323 293 371 | 55 97 372 | 52 8 373 | 32 163 374 | 312 307 375 | 271 171 376 | 204 34 377 | 64 282 378 | 311 315 379 | 174 58 380 | 56 84 381 | 217 275 382 | 86 180 383 | 342 84 384 | 340 174 385 | 13 80 386 | 100 197 387 | 189 341 388 | 5 86 389 | 9 40 390 | 210 329 391 | 260 188 392 | 236 261 393 | 94 282 394 | 105 188 395 | 141 258 396 | 132 285 397 | 17 156 398 | 70 213 399 | 204 5 400 | 344 74 401 | 34 202 402 | 347 263 403 | 121 312 404 | 146 219 405 | 31 48 406 | 53 291 407 | 213 203 408 | 125 9 409 | 279 301 410 | 247 140 411 | 217 2 412 | 298 83 413 | 315 311 414 | 165 209 415 | 169 270 416 | 259 40 417 | 174 285 418 | 21 276 419 | 58 229 420 | 165 84 421 | 48 29 422 | 222 257 423 | 38 209 424 | 336 30 425 | 53 63 426 | 269 243 427 | 36 324 428 | 252 138 429 | 113 155 430 | 123 290 431 | 10 253 432 | 346 15 433 | 217 36 434 | 15 102 435 | 264 149 436 | 143 122 437 | 300 178 438 | 25 220 439 | 58 231 440 | 19 250 441 | 11 147 442 | 73 186 443 | 90 109 444 | 248 104 445 | 196 55 446 | 308 298 447 | 316 7 448 | 160 208 449 | 173 323 450 | 196 176 451 | 147 168 452 | 168 293 453 | 274 328 454 | 6 133 455 | 177 226 456 | 49 336 457 | 173 7 458 | 307 1 459 | 85 128 460 | 63 241 461 | 39 323 462 | 167 173 463 | 298 253 464 | 171 42 465 | 196 326 466 | 53 329 467 | 221 307 468 | 51 194 469 | 192 231 470 | 13 23 471 | 308 117 472 | 324 84 473 | 228 13 474 | 231 156 475 | 314 286 476 | 321 314 477 | 140 30 478 | 143 288 479 | 55 340 480 | 192 264 481 | 119 220 482 | 28 226 483 | 248 309 484 | 227 122 485 | 157 227 486 | 81 178 487 | 143 329 488 | 327 170 489 | 199 308 490 | 297 27 491 | 28 101 492 | 317 179 493 | 176 293 494 | 328 265 495 | 64 256 496 | 176 316 497 | 336 315 498 | 137 189 499 | 290 209 500 | 243 232 501 | 305 233 502 | 28 26 503 | 216 306 504 | 155 65 505 | 246 166 506 | 148 218 507 | 28 343 508 | 31 148 509 | 6 38 510 | 43 267 511 | 85 30 512 | 5 212 513 | 328 157 514 | 93 65 515 | 158 179 516 | 315 256 517 | 261 210 518 | 8 234 519 | 137 163 520 | 261 9 521 | 247 231 522 | 32 266 523 | 118 191 524 | 107 34 525 | 87 153 526 | 132 81 527 | 41 235 528 | 80 103 529 | 13 167 530 | 31 166 531 | 290 32 532 | 53 125 533 | 131 163 534 | 188 82 535 | 68 38 536 | 94 325 537 | 254 129 538 | 99 63 539 | 267 164 540 | 1 46 541 | 175 36 542 | 99 72 543 | 328 80 544 | 84 221 545 | 164 80 546 | 232 264 547 | 172 70 548 | 227 346 549 | 183 44 550 | 208 184 551 | 120 317 552 | 20 154 553 | 76 315 554 | 52 200 555 | 231 46 556 | 343 241 557 | 42 284 558 | 229 345 559 | 213 75 560 | 155 135 561 | 28 261 562 | 22 255 563 | 106 169 564 | 310 347 565 | 212 275 566 | 104 314 567 | 347 181 568 | 285 72 569 | 26 68 570 | 6 331 571 | 19 227 572 | 325 108 573 | 325 110 574 | 152 226 575 | 221 160 576 | 310 226 577 | 145 57 578 | 228 299 579 | 233 139 580 | 291 1 581 | 52 173 582 | 173 33 583 | 48 339 584 | 188 27 585 | 329 117 586 | 216 73 587 | 291 325 588 | 180 22 589 | 343 95 590 | 293 172 591 | 31 146 592 | 99 213 593 | 290 10 594 | 79 212 595 | 184 96 596 | 257 27 597 | 11 323 598 | 117 95 599 | 215 118 600 | 258 23 601 | 
-------------------------------------------------------------------------------- /tests/test_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | 4 | from recbole_gnn.quick_start import objective_function 5 | 6 | current_path = os.path.dirname(os.path.realpath(__file__)) 7 | config_file_list = [os.path.join(current_path, 'test_model.yaml')] 8 | 9 | 10 | def quick_test(config_dict): 11 | objective_function(config_dict=config_dict, config_file_list=config_file_list, saved=False) 12 | 13 | 14 | class TestGeneralRecommender(unittest.TestCase): 15 | def test_bpr(self): 16 | config_dict = { 17 | 'model': 'BPR', 18 | } 19 | quick_test(config_dict) 20 | 21 | def test_neumf(self): 22 | config_dict = { 23 | 'model': 'NeuMF', 24 | } 25 | quick_test(config_dict) 26 | 27 | def test_ngcf(self): 28 | config_dict = { 29 | 'model': 'NGCF', 30 | } 31 | quick_test(config_dict) 32 | 33 | def test_lightgcn(self): 34 | config_dict = { 35 | 'model': 'LightGCN', 36 | } 37 | quick_test(config_dict) 38 | 39 | def test_sgl(self): 40 | config_dict = { 41 | 'model': 'SGL', 42 | } 43 | quick_test(config_dict) 44 | 45 | def test_hmlet(self): 46 | config_dict = { 47 | 'model': 'HMLET', 48 | } 49 | quick_test(config_dict) 50 | 51 | def test_ncl(self): 52 | config_dict = { 53 | 'model': 'NCL', 54 | 'num_clusters': 10 55 | } 56 | quick_test(config_dict) 57 | 58 | def test_simgcl(self): 59 | config_dict = { 60 | 'model': 'SimGCL' 61 | } 62 | quick_test(config_dict) 63 | 64 | def test_xsimgcl(self): 65 | config_dict = { 66 | 'model': 'XSimGCL' 67 | } 68 | quick_test(config_dict) 69 | 70 | def test_lightgcl(self): 71 | config_dict = { 72 | 'model': 'LightGCL' 73 | } 74 | quick_test(config_dict) 75 | 76 | def test_directau(self): 77 | config_dict = { 78 | 'model': 'DirectAU' 79 | } 80 | quick_test(config_dict) 81 | 82 | def test_ssl4rec(self): 83 | config_dict = { 84 | 'model': 'SSL4REC' 85 | } 86 | quick_test(config_dict) 87 | 88 | 89 | class TestSequentialRecommender(unittest.TestCase): 90 | def test_gru4rec(self): 91 | config_dict = { 92 | 'model': 'GRU4Rec', 93 | } 94 | quick_test(config_dict) 95 | 96 | def test_narm(self): 97 | config_dict = { 98 | 'model': 'NARM', 99 | } 100 | quick_test(config_dict) 101 | 102 | def test_sasrec(self): 103 | config_dict = { 104 | 'model': 'SASRec', 105 | } 106 | quick_test(config_dict) 107 | 108 | def test_srgnn(self): 109 | config_dict = { 110 | 'model': 'SRGNN', 111 | } 112 | quick_test(config_dict) 113 | 114 | def test_srgnn_uni100(self): 115 | config_dict = { 116 | 'model': 'SRGNN', 117 | 'eval_args': { 118 | 'split': {'LS': "valid_and_test"}, 119 | 'mode': 'uni100', 120 | 'order': 'TO' 121 | } 122 | } 123 | quick_test(config_dict) 124 | 125 | def test_gcsan(self): 126 | config_dict = { 127 | 'model': 'GCSAN', 128 | } 129 | quick_test(config_dict) 130 | 131 | def test_niser(self): 132 | config_dict = { 133 | 'model': 'NISER', 134 | } 135 | quick_test(config_dict) 136 | 137 | def test_lessr(self): 138 | config_dict = { 139 | 'model': 'LESSR' 140 | } 141 | quick_test(config_dict) 142 | 143 | def test_tagnn(self): 144 | config_dict = { 145 | 'model': 'TAGNN' 146 | } 147 | quick_test(config_dict) 148 | 149 | def test_gcegnn(self): 150 | config_dict = { 151 | 'model': 'GCEGNN' 152 | } 153 | quick_test(config_dict) 154 | 155 | def test_sgnnhn(self): 156 | config_dict = { 157 | 'model': 'SGNNHN' 158 | } 159 | quick_test(config_dict) 160 | 161 | 162 | class TestSocialRecommender(unittest.TestCase): 163 | def 
test_diffnet(self): 164 | config_dict = { 165 | 'model': 'DiffNet', 166 | } 167 | quick_test(config_dict) 168 | 169 | def test_mhcn(self): 170 | config_dict = { 171 | 'model': 'MHCN', 172 | } 173 | quick_test(config_dict) 174 | 175 | def test_sept(self): 176 | config_dict = { 177 | 'model': 'SEPT', 178 | } 179 | quick_test(config_dict) 180 | 181 | 182 | if __name__ == '__main__': 183 | unittest.main() 184 | -------------------------------------------------------------------------------- /tests/test_model.yaml: -------------------------------------------------------------------------------- 1 | dataset: test 2 | epochs: 1 3 | state: ERROR 4 | data_path: tests/test_data/ 5 | 6 | # Atomic File Format 7 | field_separator: "\t" 8 | seq_separator: " " 9 | 10 | # Common Features 11 | USER_ID_FIELD: user_id 12 | ITEM_ID_FIELD: item_id 13 | RATING_FIELD: rating 14 | TIME_FIELD: timestamp 15 | seq_len: ~ 16 | # Label for Point-wise DataLoader 17 | LABEL_FIELD: label 18 | # NegSample Prefix for Pair-wise DataLoader 19 | NEG_PREFIX: neg_ 20 | # Sequential Model Needed 21 | ITEM_LIST_LENGTH_FIELD: item_length 22 | LIST_SUFFIX: _list 23 | MAX_ITEM_LIST_LENGTH: 50 24 | POSITION_FIELD: position_id 25 | # social network config 26 | NET_SOURCE_ID_FIELD: source_id 27 | NET_TARGET_ID_FIELD: target_id 28 | filter_net_by_inter: True 29 | undirected_net: True 30 | 31 | # Selectively Loading 32 | load_col: 33 | inter: [user_id, item_id, rating, timestamp] 34 | net: [source_id, target_id] 35 | 36 | unload_col: ~ 37 | 38 | # Preprocessing 39 | alias_of_user_id: ~ 40 | alias_of_item_id: ~ 41 | alias_of_entity_id: ~ 42 | alias_of_relation_id: ~ 43 | preload_weight: ~ 44 | normalize_field: ~ 45 | normalize_all: True 46 | --------------------------------------------------------------------------------
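
Every case in `tests/test_model.py` reduces to the same call: build a small `config_dict` with the model name and pass it to `objective_function` together with `tests/test_model.yaml`, which points at the tiny `test` dataset under `tests/test_data/`. As a usage sketch only (not a file in the repository), the snippet below shows how that pattern could be reused in a standalone script placed next to `test_model.yaml` to smoke-test a single model, for example a newly added one, without going through pytest; the file name `smoke_test.py` and the `smoke_test` helper are assumptions for illustration.

```python
# smoke_test.py (hypothetical helper living in tests/, next to test_model.yaml)
import argparse
import os

from recbole_gnn.quick_start import objective_function

# Reuse the shared 1-epoch test configuration and the tiny `test` dataset.
current_path = os.path.dirname(os.path.realpath(__file__))
config_file_list = [os.path.join(current_path, 'test_model.yaml')]


def smoke_test(model_name, extra_config=None):
    """Run one short training/evaluation pass for `model_name` on the test dataset."""
    config_dict = {'model': model_name}
    if extra_config:
        # e.g. {'num_clusters': 10} for NCL, mirroring the existing test cases
        config_dict.update(extra_config)
    objective_function(config_dict=config_dict,
                       config_file_list=config_file_list,
                       saved=False)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', '-m', type=str, default='LightGCN',
                        help='name of the model to smoke-test')
    args, _ = parser.parse_known_args()
    smoke_test(args.model)
```

The full suite is still run the usual way, via `bash run_test.sh`, which simply calls `python -m pytest -v tests/test_model.py`.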