├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.md
│   │   ├── bug_report_CN.md
│   │   ├── feature_request.md
│   │   └── feature_request_CN.md
│   └── workflows
│       └── python-package.yml
├── .gitignore
├── LICENSE
├── README.md
├── asset
│   ├── arch.png
│   ├── diginetica.png
│   ├── ml-1m.png
│   └── recbole-gnn-logo.png
├── recbole_gnn
│   ├── config.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── dataloader.py
│   │   ├── dataset.py
│   │   └── transform.py
│   ├── model
│   │   ├── abstract_recommender.py
│   │   ├── general_recommender
│   │   │   ├── __init__.py
│   │   │   ├── directau.py
│   │   │   ├── hmlet.py
│   │   │   ├── lightgcl.py
│   │   │   ├── lightgcn.py
│   │   │   ├── ncl.py
│   │   │   ├── ngcf.py
│   │   │   ├── sgl.py
│   │   │   ├── simgcl.py
│   │   │   ├── ssl4rec.py
│   │   │   └── xsimgcl.py
│   │   ├── layers.py
│   │   ├── sequential_recommender
│   │   │   ├── __init__.py
│   │   │   ├── gcegnn.py
│   │   │   ├── gcsan.py
│   │   │   ├── lessr.py
│   │   │   ├── niser.py
│   │   │   ├── sgnnhn.py
│   │   │   ├── srgnn.py
│   │   │   └── tagnn.py
│   │   └── social_recommender
│   │       ├── __init__.py
│   │       ├── diffnet.py
│   │       ├── mhcn.py
│   │       └── sept.py
│   ├── properties
│   │   ├── model
│   │   │   ├── DiffNet.yaml
│   │   │   ├── DirectAU.yaml
│   │   │   ├── GCEGNN.yaml
│   │   │   ├── GCSAN.yaml
│   │   │   ├── HMLET.yaml
│   │   │   ├── LESSR.yaml
│   │   │   ├── LightGCL.yaml
│   │   │   ├── LightGCN.yaml
│   │   │   ├── MHCN.yaml
│   │   │   ├── NCL.yaml
│   │   │   ├── NGCF.yaml
│   │   │   ├── NISER.yaml
│   │   │   ├── SEPT.yaml
│   │   │   ├── SGL.yaml
│   │   │   ├── SGNNHN.yaml
│   │   │   ├── SRGNN.yaml
│   │   │   ├── SSL4REC.yaml
│   │   │   ├── SimGCL.yaml
│   │   │   ├── TAGNN.yaml
│   │   │   └── XSimGCL.yaml
│   │   └── quick_start_config
│   │       ├── sequential_base.yaml
│   │       └── social_base.yaml
│   ├── quick_start.py
│   ├── trainer.py
│   └── utils.py
├── results
│   ├── README.md
│   ├── general
│   │   └── ml-1m.md
│   ├── sequential
│   │   └── diginetica.md
│   └── social
│       └── lastfm.md
├── run_hyper.py
├── run_recbole_gnn.py
├── run_test.sh
└── tests
    ├── test_data
    │   └── test
    │       ├── test.inter
    │       └── test.net
    ├── test_model.py
    └── test_model.yaml
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: "[\U0001F41BBUG] Describe your problem in one sentence."
5 | labels: bug
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Describe the bug**
11 | A clear and concise description of what the bug is.
12 |
13 | **To Reproduce**
14 | Steps to reproduce the behavior:
15 | 1. extra yaml file
16 | 2. your code
17 | 3. script for running
18 |
19 | **Expected behavior**
20 | A clear and concise description of what you expected to happen.
21 |
22 | **Screenshots**
23 | If applicable, add screenshots to help explain your problem.
24 |
25 | **Colab Links**
26 | If applicable, add links to Colab or other online Jupyter platforms that can reproduce the bug.
27 |
28 | **Desktop (please complete the following information):**
29 | - OS: [e.g. Linux, macOS or Windows]
30 | - RecBole Version [e.g. 0.1.0]
31 | - Python Version [e.g. 3.7.9]
32 | - PyTorch Version [e.g. 1.6.0]
33 | - cudatoolkit Version [e.g. 9.2, none]
34 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report_CN.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Submit a bug report to help make RecBole-GNN better
4 | title: "[\U0001F41BBUG] Describe your problem in one sentence."
5 | labels: bug
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Describe the bug**
11 | A clear and concise description of what the bug is.
12 |
13 | **To Reproduce**
14 | Steps to reproduce the bug:
15 | 1. the extra yaml file you introduced
16 | 2. your code
17 | 3. your running script
18 |
19 | **Expected behavior**
20 | A clear and concise description of what you expected to happen.
21 |
22 | **Screenshots**
23 | Add screenshots to help explain your problem. (optional)
24 |
25 | **Links**
26 | Add links that can reproduce the bug, e.g. Colab or other online Jupyter platforms. (optional)
27 |
28 | **Environment (please complete the following information):**
29 | - OS: [e.g. Linux, macOS or Windows]
30 | - RecBole Version [e.g. 0.1.0]
31 | - Python Version [e.g. 3.7.9]
32 | - PyTorch Version [e.g. 1.6.0]
33 | - cudatoolkit Version [e.g. 9.2, none]
34 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: "[\U0001F4A1SUG] Description of what you want to happen in one sentence"
5 | labels: enhancement
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Is your feature request related to a problem? Please describe.**
11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
12 |
13 | **Describe the solution you'd like**
14 | A clear and concise description of what you want to happen.
15 |
16 | **Describe alternatives you've considered**
17 | A clear and concise description of any alternative solutions or features you've considered.
18 |
19 | **Additional context**
20 | Add any other context or screenshots about the feature request here.
21 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request_CN.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest a new feature or enhancement for this project
4 | title: "[\U0001F4A1SUG] Describe the feature you would like to add in one sentence"
5 | labels: enhancement
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Is the feature you would like related to a problem?**
11 | A clear and concise description of the problem, e.g., I'm always frustrated when [...].
12 |
13 | **Describe the solution you'd like**
14 | A clear and concise description of the solution you would like.
15 |
16 | **Describe alternatives you've considered**
17 | A clear and concise description of any alternative solutions you've considered that could provide this feature.
18 |
19 | **Additional context**
20 | Add any other material, links or screenshots here to help us understand the new feature.
21 |
--------------------------------------------------------------------------------
/.github/workflows/python-package.yml:
--------------------------------------------------------------------------------
1 | name: RecBole-GNN tests
2 |
3 | # Controls when the action will run.
4 | on:
5 | # Triggers the workflow on push or pull request events but only for the master branch
6 | push:
7 | pull_request:
8 |
9 | # Allows you to run this workflow manually from the Actions tab
10 | workflow_dispatch:
11 |
12 | jobs:
13 | build:
14 |
15 | runs-on: ubuntu-latest
16 | strategy:
17 | matrix:
18 | python-version: [3.9]
19 | torch-version: [2.0.0]
20 | defaults:
21 | run:
22 | shell: bash -l {0}
23 |
24 | steps:
25 | - uses: actions/checkout@v2
26 | - name: Setup Miniconda
27 | uses: conda-incubator/setup-miniconda@v2
28 | with:
29 | python-version: ${{ matrix.python-version }}
30 | channels: conda-forge
31 | channel-priority: true
32 | auto-activate-base: true
33 | # install setuptools as an interim solution for bugs in PyTorch 1.10.2 (#69904)
34 | - name: Install dependencies
35 | run: |
36 | python -m pip install --upgrade pip
37 | pip install pytest
38 | pip install dgl
39 | pip install torch==${{ matrix.torch-version}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
40 | pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-${{ matrix.torch-version }}+cpu.html
41 | pip install recbole==1.1.1
42 | conda install -c conda-forge faiss-cpu
43 | # Use "python -m pytest" instead of "pytest" to fix imports
44 | - name: Test model
45 | run: |
46 | python -m pytest -v tests/test_model.py
47 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # RecBole
132 | log_tensorboard/
133 | saved/
134 | dataset/
135 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 RUCAIBox
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # RecBole-GNN
2 |
3 | 
4 |
5 | -----
6 |
7 | *Updates*:
8 |
9 | * [Oct 29, 2023] Add [SSL4Rec](https://github.com/RUCAIBox/RecBole-GNN/blob/main/recbole_gnn/model/general_recommender/ssl4rec.py). (https://github.com/RUCAIBox/RecBole-GNN/pull/76, by [@downeykking](https://github.com/downeykking))
10 | * [Oct 23, 2023] Add sparse tensor support, accelerating LightGCN & NGCF by ~5x with about 1/6 of the GPU memory. (https://github.com/RUCAIBox/RecBole-GNN/pull/75, by [@downeykking](https://github.com/downeykking))
11 | * [Oct 20, 2023] Add [DirectAU](https://github.com/RUCAIBox/RecBole-GNN/blob/main/recbole_gnn/model/general_recommender/directau.py). (https://github.com/RUCAIBox/RecBole-GNN/pull/74, by [@downeykking](https://github.com/downeykking))
12 | * [Oct 16, 2023] Add [XSimGCL](https://github.com/RUCAIBox/RecBole-GNN/blob/main/recbole_gnn/model/general_recommender/xsimgcl.py). (https://github.com/RUCAIBox/RecBole-GNN/pull/72, by [@downeykking](https://github.com/downeykking))
13 | * [Apr 12, 2023] Add [LightGCL](https://github.com/RUCAIBox/RecBole-GNN/blob/main/recbole_gnn/model/general_recommender/lightgcl.py). (https://github.com/RUCAIBox/RecBole-GNN/pull/63, by [@wending0417](https://github.com/wending0417))
14 | * [Oct 29, 2022] Adaptation to RecBole 1.1.1. (https://github.com/RUCAIBox/RecBole-GNN/pull/53)
15 | * [Jun 15, 2022] Add [MultiBehaviorDataset](https://github.com/RUCAIBox/RecBole-GNN/blob/8c61463451b294dce9af2d1939a5e054f7955e0f/recbole_gnn/data/dataset.py#L145). (https://github.com/RUCAIBox/RecBole-GNN/pull/43, by [@Tokkiu](https://github.com/Tokkiu))
16 |
17 | -----
18 |
19 | **RecBole-GNN** is a library built upon [PyTorch](https://pytorch.org) and [RecBole](https://github.com/RUCAIBox/RecBole) for reproducing and developing recommendation algorithms based on graph neural networks (GNNs). Our library includes algorithms covering three major categories:
20 | * **General Recommendation** with user-item interaction graphs;
21 | * **Sequential Recommendation** with session/sequence graphs;
22 | * **Social Recommendation** with social networks.
23 |
24 | 
25 |
26 | ## Highlights
27 |
28 | * **Easy-to-use and unified API**:
29 | Our library shares the same unified API and input (atomic files) as RecBole.
30 | * **Efficient and reusable graph processing**:
31 | We provide highly efficient and reusable basic datasets, dataloaders and layers for graph processing and learning.
32 | * **Extensive graph library**:
33 | Graph neural networks from widely-used libraries like [PyG](https://github.com/pyg-team/pytorch_geometric) are incorporated. Recently proposed graph algorithms can be easily plugged in and compared with existing methods.
34 |
35 | ## Requirements
36 |
37 | ```
38 | recbole==1.1.1
39 | pyg>=2.0.4
40 | pytorch>=1.7.0
41 | python>=3.7.0
42 | ```
43 |
44 | > If you are using `recbole==1.0.1`, please refer to our `recbole1.0.1` branch [[link]](https://github.com/hyp1231/RecBole-GNN/tree/recbole1.0.1).
45 |
46 | ## Quick-Start
47 |
48 | With the source code, you can use the provided script to get started with our library:
49 |
50 | ```bash
51 | python run_recbole_gnn.py
52 | ```
53 |
54 | If you want to change the model or dataset, just run the script with additional command-line parameters:
55 |
56 | ```bash
57 | python run_recbole_gnn.py -m [model] -d [dataset]
58 | ```
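
Equivalently, you can call the quick-start API from Python. A minimal sketch, assuming `run_recbole_gnn` in `recbole_gnn/quick_start.py` keeps the same signature as RecBole's `run_recbole` (the model and dataset values below are only examples):

```python
# Minimal programmatic quick start; assumes run_recbole_gnn mirrors RecBole's run_recbole signature.
from recbole_gnn.quick_start import run_recbole_gnn

run_recbole_gnn(model='LightGCN', dataset='ml-100k')
```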
59 |
60 | ## Implemented Models
61 |
62 | We list currently supported models according to category:
63 |
64 | **General Recommendation**:
65 |
66 | * **[NGCF](recbole_gnn/model/general_recommender/ngcf.py)** from Wang *et al.*: [Neural Graph Collaborative Filtering](https://arxiv.org/abs/1905.08108) (SIGIR 2019).
67 | * **[LightGCN](recbole_gnn/model/general_recommender/lightgcn.py)** from He *et al.*: [LightGCN: Simplifying and Powering Graph Convolution Network for Recommendation](https://arxiv.org/abs/2002.02126) (SIGIR 2020).
68 | * **[SSL4Rec](recbole_gnn/model/general_recommender/ssl4rec.py)** from Yao *et al.*: [Self-supervised Learning for Large-scale Item Recommendations](https://arxiv.org/abs/2007.12865) (CIKM 2021).
69 | * **[SGL](recbole_gnn/model/general_recommender/sgl.py)** from Wu *et al.*: [Self-supervised Graph Learning for Recommendation](https://arxiv.org/abs/2010.10783) (SIGIR 2021).
70 | * **[HMLET](recbole_gnn/model/general_recommender/hmlet.py)** from Kong *et al.*: [Linear, or Non-Linear, That is the Question!](https://arxiv.org/abs/2111.07265) (WSDM 2022).
71 | * **[NCL](recbole_gnn/model/general_recommender/ncl.py)** from Lin *et al.*: [Improving Graph Collaborative Filtering with Neighborhood-enriched Contrastive Learning](https://arxiv.org/abs/2202.06200) (TheWebConf 2022).
72 | * **[DirectAU](recbole_gnn/model/general_recommender/directau.py)** from Wang *et al.*: [Towards Representation Alignment and Uniformity in Collaborative Filtering](https://arxiv.org/abs/2206.12811) (KDD 2022).
73 | * **[SimGCL](recbole_gnn/model/general_recommender/simgcl.py)** from Yu *et al.*: [Are Graph Augmentations Necessary? Simple Graph Contrastive Learning for Recommendation](https://arxiv.org/abs/2112.08679) (SIGIR 2022).
74 | * **[XSimGCL](recbole_gnn/model/general_recommender/xsimgcl.py)** from Yu *et al.*: [XSimGCL: Towards Extremely Simple Graph Contrastive Learning for Recommendation](https://arxiv.org/abs/2209.02544) (TKDE 2023).
75 | * **[LightGCL](recbole_gnn/model/general_recommender/lightgcl.py)** from Cai *et al.*: [LightGCL: Simple Yet Effective Graph Contrastive Learning for Recommendation](https://arxiv.org/abs/2302.08191) (ICLR 2023).
76 |
77 |
78 | **Sequential Recommendation**:
79 |
80 | * **[SR-GNN](recbole_gnn/model/sequential_recommender/srgnn.py)** from Wu *et al.*: [Session-based Recommendation with Graph Neural Networks](https://arxiv.org/abs/1811.00855) (AAAI 2019).
81 | * **[GC-SAN](recbole_gnn/model/sequential_recommender/gcsan.py)** from Xu *et al.*: [Graph Contextualized Self-Attention Network for Session-based Recommendation](https://www.ijcai.org/proceedings/2019/547) (IJCAI 2019).
82 | * **[NISER+](recbole_gnn/model/sequential_recommender/niser.py)** from Gupta *et al.*: [NISER: Normalized Item and Session Representations to Handle Popularity Bias](https://arxiv.org/abs/1909.04276) (GRLA, CIKM 2019 workshop).
83 | * **[LESSR](recbole_gnn/model/sequential_recommender/lessr.py)** from Chen *et al.*: [Handling Information Loss of Graph Neural Networks for Session-based Recommendation](https://dl.acm.org/doi/10.1145/3394486.3403170) (KDD 2020).
84 | * **[TAGNN](recbole_gnn/model/sequential_recommender/tagnn.py)** from Yu *et al.*: [TAGNN: Target Attentive Graph Neural Networks for Session-based Recommendation](https://arxiv.org/abs/2005.02844) (SIGIR 2020 short).
85 | * **[GCE-GNN](recbole_gnn/model/sequential_recommender/gcegnn.py)** from Wang *et al.*: [Global Context Enhanced Graph Neural Networks for Session-based Recommendation](https://arxiv.org/abs/2106.05081) (SIGIR 2020).
86 | * **[SGNN-HN](recbole_gnn/model/sequential_recommender/sgnnhn.py)** from Pan *et al.*: [Star Graph Neural Networks for Session-based Recommendation](https://dl.acm.org/doi/10.1145/3340531.3412014) (CIKM 2020).
87 |
88 | **Social Recommendation**:
89 |
90 | > Note that datasets for social recommendation methods can be downloaded from [Social-Datasets](https://github.com/Sherry-XLL/Social-Datasets).
91 |
92 | * **[DiffNet](recbole_gnn/model/social_recommender/diffnet.py)** from Wu *et al.*: [A Neural Influence Diffusion Model for Social Recommendation](https://arxiv.org/abs/1904.10322) (SIGIR 2019).
93 | * **[MHCN](recbole_gnn/model/social_recommender/mhcn.py)** from Yu *et al.*: [Self-Supervised Multi-Channel Hypergraph Convolutional Network for Social Recommendation](https://doi.org/10.1145/3442381.3449844) (WWW 2021).
94 | * **[SEPT](recbole_gnn/model/social_recommender/sept.py)** from Yu *et al.*: [Socially-Aware Self-Supervised Tri-Training for Recommendation](https://doi.org/10.1145/3447548.3467340) (KDD 2021).
95 |
96 | ## Result
97 |
98 | ### Leaderboard
99 |
100 | We carefully tune the hyper-parameters of the implemented models in each research field and release the corresponding leaderboards for reference:
101 |
102 | - **General** recommendation on `MovieLens-1M` dataset [[link]](results/general/ml-1m.md);
103 | - **Sequential** recommendation on `Diginetica` dataset [[link]](results/sequential/diginetica.md);
104 | - **Social** recommendation on `LastFM` dataset [[link]](results/social/lastfm.md);
105 |
106 | ### Efficiency
107 |
108 | With our sequential/session graph preprocessing technique, as well as efficient GNN layers, we substantially speed up the training of our sequential recommenders.
109 |
110 | 
111 |
112 | ## The Team
113 |
114 | RecBole-GNN was initially developed and is maintained by members of [RUCAIBox](http://aibox.ruc.edu.cn/). The main developers are Yupeng Hou ([@hyp1231](https://github.com/hyp1231)), Lanling Xu ([@Sherry-XLL](https://github.com/Sherry-XLL)) and Changxin Tian ([@ChangxinTian](https://github.com/ChangxinTian)). We also thank Xinzhou ([@downeykking](https://github.com/downeykking)), Wanli ([@wending0417](https://github.com/wending0417)), and Jingqi ([@Tokkiu](https://github.com/Tokkiu)) for their great contributions! ❤️
115 |
116 | ## Acknowledgement
117 |
118 | The implementation is based on the open-source recommendation library [RecBole](https://github.com/RUCAIBox/RecBole). RecBole-GNN is part of [RecBole 2.0](https://github.com/RUCAIBox/RecBole2.0) now!
119 |
120 | Please cite the following papers as references if you use our code or processed datasets.
121 |
122 | ```bibtex
123 | @inproceedings{zhao2022recbole2,
124 | author={Wayne Xin Zhao and Yupeng Hou and Xingyu Pan and Chen Yang and Zeyu Zhang and Zihan Lin and Jingsen Zhang and Shuqing Bian and Jiakai Tang and Wenqi Sun and Yushuo Chen and Lanling Xu and Gaowei Zhang and Zhen Tian and Changxin Tian and Shanlei Mu and Xinyan Fan and Xu Chen and Ji-Rong Wen},
125 | title={RecBole 2.0: Towards a More Up-to-Date Recommendation Library},
126 | booktitle = {{CIKM}},
127 | year={2022}
128 | }
129 |
130 | @inproceedings{zhao2021recbole,
131 | author = {Wayne Xin Zhao and Shanlei Mu and Yupeng Hou and Zihan Lin and Yushuo Chen and Xingyu Pan and Kaiyuan Li and Yujie Lu and Hui Wang and Changxin Tian and Yingqian Min and Zhichao Feng and Xinyan Fan and Xu Chen and Pengfei Wang and Wendi Ji and Yaliang Li and Xiaoling Wang and Ji{-}Rong Wen},
132 | title = {RecBole: Towards a Unified, Comprehensive and Efficient Framework for Recommendation Algorithms},
133 | booktitle = {{CIKM}},
134 | pages = {4653--4664},
135 | publisher = {{ACM}},
136 | year = {2021}
137 | }
138 | ```
139 |
--------------------------------------------------------------------------------
/asset/arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RUCAIBox/RecBole-GNN/632ef888589944c190ad8f449b49ca559618d4df/asset/arch.png
--------------------------------------------------------------------------------
/asset/diginetica.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RUCAIBox/RecBole-GNN/632ef888589944c190ad8f449b49ca559618d4df/asset/diginetica.png
--------------------------------------------------------------------------------
/asset/ml-1m.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RUCAIBox/RecBole-GNN/632ef888589944c190ad8f449b49ca559618d4df/asset/ml-1m.png
--------------------------------------------------------------------------------
/asset/recbole-gnn-logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RUCAIBox/RecBole-GNN/632ef888589944c190ad8f449b49ca559618d4df/asset/recbole-gnn-logo.png
--------------------------------------------------------------------------------
/recbole_gnn/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import recbole
3 | from recbole.config.configurator import Config as RecBole_Config
4 | from recbole.utils import ModelType as RecBoleModelType
5 |
6 | from recbole_gnn.utils import get_model, ModelType
7 |
8 |
9 | class Config(RecBole_Config):
10 | def __init__(self, model=None, dataset=None, config_file_list=None, config_dict=None):
11 | """
12 | Args:
13 | model (str/AbstractRecommender): the model name or the model class, default is None, if it is None, config
14 | will search the parameter 'model' from the external input as the model name or model class.
15 | dataset (str): the dataset name, default is None, if it is None, config will search the parameter 'dataset'
16 | from the external input as the dataset name.
17 | config_file_list (list of str): the external config file, it allows multiple config files, default is None.
18 | config_dict (dict): the external parameter dictionaries, default is None.
19 | """
20 | if recbole.__version__ == "1.1.1":
21 | self.compatibility_settings()
22 | super(Config, self).__init__(model, dataset, config_file_list, config_dict)
23 |
24 | def compatibility_settings(self):
25 | import numpy as np
26 | np.bool = np.bool_
27 | np.int = np.int_
28 | np.float = np.float_
29 | np.complex = np.complex_
30 | np.object = np.object_
31 | np.str = np.str_
32 | np.long = np.int_
33 | np.unicode = np.unicode_
34 |
35 | def _get_model_and_dataset(self, model, dataset):
36 |
37 | if model is None:
38 | try:
39 | model = self.external_config_dict['model']
40 | except KeyError:
41 | raise KeyError(
42 | 'model needs to be specified in at least one of these ways: '
43 | '[model variable, config file, config dict, command line] '
44 | )
45 | if not isinstance(model, str):
46 | final_model_class = model
47 | final_model = model.__name__
48 | else:
49 | final_model = model
50 | final_model_class = get_model(final_model)
51 |
52 | if dataset is None:
53 | try:
54 | final_dataset = self.external_config_dict['dataset']
55 | except KeyError:
56 | raise KeyError(
57 | 'dataset needs to be specified in at least one of these ways: '
58 | '[dataset variable, config file, config dict, command line] '
59 | )
60 | else:
61 | final_dataset = dataset
62 |
63 | return final_model, final_model_class, final_dataset
64 |
65 | def _load_internal_config_dict(self, model, model_class, dataset):
66 | super()._load_internal_config_dict(model, model_class, dataset)
67 | current_path = os.path.dirname(os.path.realpath(__file__))
68 | model_init_file = os.path.join(current_path, './properties/model/' + model + '.yaml')
69 | quick_start_config_path = os.path.join(current_path, './properties/quick_start_config/')
70 | sequential_base_init = os.path.join(quick_start_config_path, 'sequential_base.yaml')
71 | social_base_init = os.path.join(quick_start_config_path, 'social_base.yaml')
72 |
73 | if os.path.isfile(model_init_file):
74 | config_dict = self._update_internal_config_dict(model_init_file)
75 |
76 | self.internal_config_dict['MODEL_TYPE'] = model_class.type
77 | if self.internal_config_dict['MODEL_TYPE'] == RecBoleModelType.SEQUENTIAL:
78 | self._update_internal_config_dict(sequential_base_init)
79 | if self.internal_config_dict['MODEL_TYPE'] == ModelType.SOCIAL:
80 | self._update_internal_config_dict(social_base_init)
81 |
--------------------------------------------------------------------------------
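
A minimal sketch of how the extended `Config` above is typically constructed; the keys shown are illustrative, not required (`enable_sparse` is read by `GeneralGraphRecommender` in `abstract_recommender.py` below):

```python
# Minimal sketch (illustrative keys) of building the extended Config defined above.
from recbole_gnn.config import Config

config = Config(
    model='LightGCN',                     # resolved via recbole_gnn.utils.get_model
    dataset='ml-100k',                    # any RecBole atomic-file dataset name
    config_dict={'enable_sparse': True},  # external parameters override the YAML defaults
)
print(config['MODEL_TYPE'])               # ModelType.GENERAL for LightGCN
```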
/recbole_gnn/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RUCAIBox/RecBole-GNN/632ef888589944c190ad8f449b49ca559618d4df/recbole_gnn/data/__init__.py
--------------------------------------------------------------------------------
/recbole_gnn/data/dataloader.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from recbole.data.interaction import cat_interactions
4 | from recbole.data.dataloader.general_dataloader import TrainDataLoader, NegSampleEvalDataLoader, FullSortEvalDataLoader
5 |
6 | from recbole_gnn.data.transform import gnn_construct_transform
7 |
8 |
9 | class CustomizedTrainDataLoader(TrainDataLoader):
10 | def __init__(self, config, dataset, sampler, shuffle=False):
11 | super().__init__(config, dataset, sampler, shuffle=shuffle)
12 | if config['gnn_transform'] is not None:
13 | self.transform = gnn_construct_transform(config)
14 |
15 |
16 | class CustomizedNegSampleEvalDataLoader(NegSampleEvalDataLoader):
17 | def __init__(self, config, dataset, sampler, shuffle=False):
18 | super().__init__(config, dataset, sampler, shuffle=shuffle)
19 | if config['gnn_transform'] is not None:
20 | self.transform = gnn_construct_transform(config)
21 |
22 | def collate_fn(self, index):
23 | index = np.array(index)
24 | if (
25 | self.neg_sample_args["distribution"] != "none"
26 | and self.neg_sample_args["sample_num"] != "none"
27 | ):
28 | uid_list = self.uid_list[index]
29 | data_list = []
30 | idx_list = []
31 | positive_u = []
32 | positive_i = torch.tensor([], dtype=torch.int64)
33 |
34 | for idx, uid in enumerate(uid_list):
35 | index = self.uid2index[uid]
36 | data_list.append(self._neg_sampling(self._dataset[index]))
37 | idx_list += [idx for i in range(self.uid2items_num[uid] * self.times)]
38 | positive_u += [idx for i in range(self.uid2items_num[uid])]
39 | positive_i = torch.cat(
40 | (positive_i, self._dataset[index][self.iid_field]), 0
41 | )
42 |
43 | cur_data = cat_interactions(data_list)
44 | idx_list = torch.from_numpy(np.array(idx_list)).long()
45 | positive_u = torch.from_numpy(np.array(positive_u)).long()
46 |
47 | return self.transform(self._dataset, cur_data), idx_list, positive_u, positive_i
48 | else:
49 | data = self._dataset[index]
50 | transformed_data = self.transform(self._dataset, data)
51 | cur_data = self._neg_sampling(transformed_data)
52 | return cur_data, None, None, None
53 |
54 |
55 | class CustomizedFullSortEvalDataLoader(FullSortEvalDataLoader):
56 | def __init__(self, config, dataset, sampler, shuffle=False):
57 | super().__init__(config, dataset, sampler, shuffle=shuffle)
58 | if config['gnn_transform'] is not None:
59 | self.transform = gnn_construct_transform(config)
60 |
--------------------------------------------------------------------------------
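
These customized dataloaders only attach a GNN transform when `gnn_transform` is set in the config; `'sess_graph'` is the single transform registered in the transform module that follows. A minimal sketch of how they are usually obtained, assuming `recbole_gnn.utils` exposes `create_dataset` / `data_preparation` helpers analogous to RecBole's:

```python
# Minimal sketch: the Customized*DataLoader classes above are built by the data-preparation helpers.
# Assumption: create_dataset / data_preparation in recbole_gnn.utils mirror RecBole's helpers.
from recbole_gnn.config import Config
from recbole_gnn.utils import create_dataset, data_preparation

config = Config(model='SRGNN', dataset='diginetica',
                config_dict={'gnn_transform': 'sess_graph'})  # enable the session-graph transform
dataset = create_dataset(config)
train_data, valid_data, test_data = data_preparation(config, dataset)
```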
/recbole_gnn/data/transform.py:
--------------------------------------------------------------------------------
1 | from logging import getLogger
2 | import torch
3 | from torch.nn.utils.rnn import pad_sequence
4 | from recbole.data.interaction import Interaction
5 |
6 |
7 | def gnn_construct_transform(config):
8 | if config['gnn_transform'] is None:
9 | raise ValueError('config["gnn_transform"] is None but trying to construct transform.')
10 | str2transform = {
11 | 'sess_graph': SessionGraph,
12 | }
13 | return str2transform[config['gnn_transform']](config)
14 |
15 |
16 | class SessionGraph:
17 | def __init__(self, config):
18 | self.logger = getLogger()
19 | self.logger.info('SessionGraph Transform in DataLoader.')
20 |
21 | def __call__(self, dataset, interaction):
22 | graph_objs = dataset.graph_objs
23 | index = interaction['graph_idx']
24 | graph_batch = {
25 | k: [graph_objs[k][_.item()] for _ in index]
26 | for k in graph_objs
27 | }
28 | graph_batch['batch'] = []
29 |
30 | tot_node_num = torch.ones([1], dtype=torch.long)
31 | for i in range(index.shape[0]):
32 | for k in graph_batch:
33 | if 'edge_index' in k:
34 | graph_batch[k][i] = graph_batch[k][i] + tot_node_num
35 | if 'alias_inputs' in graph_batch:
36 | graph_batch['alias_inputs'][i] = graph_batch['alias_inputs'][i] + tot_node_num
37 | graph_batch['batch'].append(torch.full_like(graph_batch['x'][i], i))
38 | tot_node_num += graph_batch['x'][i].shape[0]
39 |
40 | if hasattr(dataset, 'node_attr'):
41 | node_attr = ['batch'] + dataset.node_attr
42 | else:
43 | node_attr = ['x', 'batch']
44 | for k in node_attr:
45 | graph_batch[k] = [torch.zeros([1], dtype=graph_batch[k][-1].dtype)] + graph_batch[k]
46 |
47 | for k in graph_batch:
48 | if k == 'alias_inputs':
49 | graph_batch[k] = pad_sequence(graph_batch[k], batch_first=True)
50 | else:
51 | graph_batch[k] = torch.cat(graph_batch[k], dim=-1)
52 |
53 | interaction.update(Interaction(graph_batch))
54 | return interaction
55 |
--------------------------------------------------------------------------------
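
A toy illustration (not library code) of the batching scheme in `SessionGraph.__call__`: node ids of each per-session graph are shifted by a running offset so the graphs can be concatenated into one disconnected graph, with a `batch` vector recording which session each node belongs to:

```python
# Toy illustration of the node-offset batching used by SessionGraph (index 0 is reserved for padding).
import torch

x = [torch.tensor([11, 12, 13]), torch.tensor([21, 22])]                 # per-session node features
edge_index = [torch.tensor([[0, 1], [1, 2]]), torch.tensor([[0], [1]])]  # per-session edges

tot_node_num = torch.ones([1], dtype=torch.long)        # running offset, starts at 1
edges, batch = [], []
for i in range(len(x)):
    edges.append(edge_index[i] + tot_node_num)          # shift into the global index space
    batch.append(torch.full_like(x[i], i))              # session id of every node
    tot_node_num = tot_node_num + x[i].shape[0]

print(torch.cat(edges, dim=-1))   # tensor([[1, 2, 4], [2, 3, 5]])
print(torch.cat(batch, dim=-1))   # tensor([0, 0, 0, 1, 1])
```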
/recbole_gnn/model/abstract_recommender.py:
--------------------------------------------------------------------------------
1 | from recbole.model.abstract_recommender import GeneralRecommender
2 | from recbole.utils import ModelType as RecBoleModelType
3 |
4 | from recbole_gnn.utils import ModelType
5 |
6 |
7 | class GeneralGraphRecommender(GeneralRecommender):
8 | """This is an abstract general graph recommender. All the general graph models should implement in this class.
9 | The base general graph recommender class provide the basic U-I graph dataset and parameters information.
10 | """
11 | type = RecBoleModelType.GENERAL
12 |
13 | def __init__(self, config, dataset):
14 | super(GeneralGraphRecommender, self).__init__(config, dataset)
15 | self.edge_index, self.edge_weight = dataset.get_norm_adj_mat(enable_sparse=config["enable_sparse"])
16 | self.use_sparse = config["enable_sparse"] and dataset.is_sparse
17 | if self.use_sparse:
18 | self.edge_index, self.edge_weight = self.edge_index.to(self.device), None
19 | else:
20 | self.edge_index, self.edge_weight = self.edge_index.to(self.device), self.edge_weight.to(self.device)
21 |
22 |
23 | class SocialRecommender(GeneralRecommender):
24 | """This is an abstract social recommender. All the social graph model should implement this class.
25 | The base social recommender class provide the basic social graph dataset and parameters information.
26 | """
27 | type = ModelType.SOCIAL
28 |
29 | def __init__(self, config, dataset):
30 | super(SocialRecommender, self).__init__(config, dataset)
31 |
--------------------------------------------------------------------------------
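
To show what `GeneralGraphRecommender` provides to subclasses (`self.edge_index` / `self.edge_weight` already moved to the right device, plus `n_users` / `n_items` from RecBole), here is a minimal hypothetical subclass; `ToyGraphModel` is not part of the library and omits the loss and prediction methods:

```python
# Hypothetical minimal subclass of GeneralGraphRecommender (illustration only, not in the library).
import torch
from recbole.utils import InputType

from recbole_gnn.model.abstract_recommender import GeneralGraphRecommender
from recbole_gnn.model.layers import LightGCNConv


class ToyGraphModel(GeneralGraphRecommender):
    input_type = InputType.PAIRWISE

    def __init__(self, config, dataset):
        super().__init__(config, dataset)
        # a single embedding table covering users and items (user rows first, then item rows)
        self.embedding = torch.nn.Embedding(self.n_users + self.n_items, config['embedding_size'])
        self.gcn_conv = LightGCNConv(dim=config['embedding_size'])

    def forward(self):
        # one propagation step over the normalized user-item graph prepared by the base class
        return self.gcn_conv(self.embedding.weight, self.edge_index, self.edge_weight)
```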
/recbole_gnn/model/general_recommender/__init__.py:
--------------------------------------------------------------------------------
1 | from recbole_gnn.model.general_recommender.lightgcn import LightGCN
2 | from recbole_gnn.model.general_recommender.hmlet import HMLET
3 | from recbole_gnn.model.general_recommender.ncl import NCL
4 | from recbole_gnn.model.general_recommender.ngcf import NGCF
5 | from recbole_gnn.model.general_recommender.sgl import SGL
6 | from recbole_gnn.model.general_recommender.lightgcl import LightGCL
7 | from recbole_gnn.model.general_recommender.simgcl import SimGCL
8 | from recbole_gnn.model.general_recommender.xsimgcl import XSimGCL
9 | from recbole_gnn.model.general_recommender.directau import DirectAU
10 | from recbole_gnn.model.general_recommender.ssl4rec import SSL4REC
11 |
--------------------------------------------------------------------------------
/recbole_gnn/model/general_recommender/directau.py:
--------------------------------------------------------------------------------
1 | # r"""
2 | # DirectAU
3 | # ################################################
4 | # Reference:
5 | # Chenyang Wang et al. "Towards Representation Alignment and Uniformity in Collaborative Filtering." in KDD 2022.
6 |
7 | # Reference code:
8 | # https://github.com/THUwangcy/DirectAU
9 | # """
10 |
11 | import numpy as np
12 | import torch
13 | import torch.nn as nn
14 | import torch.nn.functional as F
15 |
16 | from recbole.model.init import xavier_normal_initialization
17 | from recbole.utils import InputType
18 | from recbole.model.general_recommender import BPR
19 | from recbole_gnn.model.general_recommender import LightGCN
20 |
21 | from recbole_gnn.model.abstract_recommender import GeneralGraphRecommender
22 |
23 |
24 | class DirectAU(GeneralGraphRecommender):
25 | input_type = InputType.PAIRWISE
26 |
27 | def __init__(self, config, dataset):
28 | super(DirectAU, self).__init__(config, dataset)
29 |
30 | # load parameters info
31 | self.embedding_size = config['embedding_size']
32 | self.gamma = config['gamma']
33 | self.encoder_name = config['encoder']
34 |
35 | # define encoder
36 | if self.encoder_name == 'MF':
37 | self.encoder = MFEncoder(config, dataset)
38 | elif self.encoder_name == 'LightGCN':
39 | self.encoder = LGCNEncoder(config, dataset)
40 | else:
41 | raise ValueError('Non-implemented Encoder.')
42 |
43 | # storage variables for full sort evaluation acceleration
44 | self.restore_user_e = None
45 | self.restore_item_e = None
46 |
47 | # parameters initialization
48 | self.apply(xavier_normal_initialization)
49 |
50 | def forward(self, user, item):
51 | user_e, item_e = self.encoder(user, item)
52 | return F.normalize(user_e, dim=-1), F.normalize(item_e, dim=-1)
53 |
54 | @staticmethod
55 | def alignment(x, y, alpha=2):
56 | return (x - y).norm(p=2, dim=1).pow(alpha).mean()
57 |
58 | @staticmethod
59 | def uniformity(x, t=2):
60 | return torch.pdist(x, p=2).pow(2).mul(-t).exp().mean().log()
61 |
62 | def calculate_loss(self, interaction):
63 | if self.restore_user_e is not None or self.restore_item_e is not None:
64 | self.restore_user_e, self.restore_item_e = None, None
65 |
66 | user = interaction[self.USER_ID]
67 | item = interaction[self.ITEM_ID]
68 |
69 | user_e, item_e = self.forward(user, item)
70 | align = self.alignment(user_e, item_e)
71 | uniform = self.gamma * (self.uniformity(user_e) + self.uniformity(item_e)) / 2
72 |
73 | return align, uniform
74 |
75 | def predict(self, interaction):
76 | user = interaction[self.USER_ID]
77 | item = interaction[self.ITEM_ID]
78 | user_e = self.encoder.user_embedding(user)
79 | item_e = self.encoder.item_embedding(item)
80 | return torch.mul(user_e, item_e).sum(dim=1)
81 |
82 | def full_sort_predict(self, interaction):
83 | user = interaction[self.USER_ID]
84 | if self.encoder_name == 'LightGCN':
85 | if self.restore_user_e is None or self.restore_item_e is None:
86 | self.restore_user_e, self.restore_item_e = self.encoder.get_all_embeddings()
87 | user_e = self.restore_user_e[user]
88 | all_item_e = self.restore_item_e
89 | else:
90 | user_e = self.encoder.user_embedding(user)
91 | all_item_e = self.encoder.item_embedding.weight
92 | score = torch.matmul(user_e, all_item_e.transpose(0, 1))
93 | return score.view(-1)
94 |
95 |
96 | class MFEncoder(BPR):
97 | def __init__(self, config, dataset):
98 | super(MFEncoder, self).__init__(config, dataset)
99 |
100 | def forward(self, user_id, item_id):
101 | return super().forward(user_id, item_id)
102 |
103 | def get_all_embeddings(self):
104 | user_embeddings = self.user_embedding.weight
105 | item_embeddings = self.item_embedding.weight
106 | return user_embeddings, item_embeddings
107 |
108 |
109 | class LGCNEncoder(LightGCN):
110 | def __init__(self, config, dataset):
111 | super(LGCNEncoder, self).__init__(config, dataset)
112 |
113 | def forward(self, user_id, item_id):
114 | user_all_embeddings, item_all_embeddings = self.get_all_embeddings()
115 | u_embed = user_all_embeddings[user_id]
116 | i_embed = item_all_embeddings[item_id]
117 | return u_embed, i_embed
118 |
119 | def get_all_embeddings(self):
120 | return super().forward()
121 |
--------------------------------------------------------------------------------
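
For intuition, a quick standalone check of the `alignment` and `uniformity` terms defined above, on random normalized embeddings (default `alpha=2`, `t=2`; the numbers themselves are meaningless):

```python
# Standalone check of DirectAU's alignment / uniformity formulas on random normalized embeddings.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
user_e = F.normalize(torch.randn(16, 8), dim=-1)
item_e = F.normalize(torch.randn(16, 8), dim=-1)

align = (user_e - item_e).norm(p=2, dim=1).pow(2).mean()               # alignment, alpha = 2
uniform = torch.pdist(user_e, p=2).pow(2).mul(-2).exp().mean().log()   # uniformity, t = 2
print(float(align), float(uniform))  # lower alignment and more negative uniformity are better
```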
/recbole_gnn/model/general_recommender/hmlet.py:
--------------------------------------------------------------------------------
1 | # @Time : 2022/3/21
2 | # @Author : Yupeng Hou
3 | # @Email : houyupeng@ruc.edu.cn
4 |
5 | r"""
6 | HMLET
7 | ################################################
8 | Reference:
9 | Taeyong Kong et al. "Linear, or Non-Linear, That is the Question!." in WSDM 2022.
10 |
11 | Reference code:
12 | https://github.com/qbxlvnf11/HMLET
13 | """
14 |
15 | import torch
16 | import torch.nn as nn
17 | import torch.nn.functional as F
18 | from recbole.model.init import xavier_uniform_initialization
19 | from recbole.model.loss import BPRLoss, EmbLoss
20 | from recbole.model.layers import activation_layer
21 | from recbole.utils import InputType
22 |
23 | from recbole_gnn.model.abstract_recommender import GeneralGraphRecommender
24 | from recbole_gnn.model.layers import LightGCNConv
25 |
26 |
27 | class Gating_Net(nn.Module):
28 | def __init__(self, embedding_dim, mlp_dims, dropout_p):
29 | super(Gating_Net, self).__init__()
30 | self.embedding_dim = embedding_dim
31 |
32 | fc_layers = []
33 | for i in range(len(mlp_dims)):
34 | if i == 0:
35 | fc = nn.Linear(embedding_dim*2, mlp_dims[i])
36 | fc_layers.append(fc)
37 | else:
38 | fc = nn.Linear(mlp_dims[i-1], mlp_dims[i])
39 | fc_layers.append(fc)
40 | if i != len(mlp_dims) - 1:
41 | fc_layers.append(nn.BatchNorm1d(mlp_dims[i]))
42 | fc_layers.append(nn.Dropout(p=dropout_p))
43 | fc_layers.append(nn.ReLU(inplace=True))
44 | self.mlp = nn.Sequential(*fc_layers)
45 |
46 | def gumbel_softmax(self, logits, temperature, hard):
47 | """Sample from the Gumbel-Softmax distribution and optionally discretize.
48 | Args:
49 | logits: [batch_size, n_class] unnormalized log-probs
50 | temperature: non-negative scalar
51 | hard: if True, take argmax, but differentiate w.r.t. soft sample y
52 | Returns:
53 | [batch_size, n_class] sample from the Gumbel-Softmax distribution.
54 | If hard=True, then the returned sample will be one-hot, otherwise it will
55 | be a probability distribution that sums to 1 across classes
56 | """
57 | y = self.gumbel_softmax_sample(logits, temperature) ## (0.6, 0.2, 0.1,..., 0.11)
58 | if hard:
59 | k = logits.size(1) # k is the number of classes
60 | # y_hard = tf.cast(tf.one_hot(tf.argmax(y,1),k), y.dtype) ## (1, 0, 0, ..., 0)
61 | y_hard = torch.eq(y, torch.max(y, dim=1, keepdim=True)[0]).type_as(y)
62 | y = (y_hard - y).detach() + y
63 | return y
64 |
65 | def gumbel_softmax_sample(self, logits, temperature):
66 | """ Draw a sample from the Gumbel-Softmax distribution"""
67 | noise = self.sample_gumbel(logits)
68 | y = (logits + noise) / temperature
69 | return F.softmax(y, dim=1)
70 |
71 | def sample_gumbel(self, logits):
72 | """Sample from Gumbel(0, 1)"""
73 | noise = torch.rand(logits.size())
74 | eps = 1e-20
75 | noise.add_(eps).log_().neg_()
76 | noise.add_(eps).log_().neg_()
77 | return torch.Tensor(noise.float()).to(logits.device)
78 |
79 | def forward(self, feature, temperature, hard):
80 | x = self.mlp(feature)
81 | out = self.gumbel_softmax(x, temperature, hard)
82 | out_value = out.unsqueeze(2)
83 | gating_out = out_value.repeat(1, 1, self.embedding_dim)
84 | return gating_out
85 |
86 |
87 | class HMLET(GeneralGraphRecommender):
88 | r"""HMLET combines both linear and non-linear propagation layers for general recommendation and yields better performance.
89 | """
90 | input_type = InputType.PAIRWISE
91 |
92 | def __init__(self, config, dataset):
93 | super(HMLET, self).__init__(config, dataset)
94 |
95 | # load parameters info
96 | self.latent_dim = config['embedding_size'] # int type:the embedding size of lightGCN
97 | self.n_layers = config['n_layers'] # int type:the layer num of lightGCN
98 | self.reg_weight = config['reg_weight'] # float32 type: the weight decay for l2 normalization
99 | self.require_pow = config['require_pow'] # bool type: whether to require pow when regularization
100 | self.gate_layer_ids = config['gate_layer_ids'] # list type: layer ids for non-linear gating
101 | self.gating_mlp_dims = config['gating_mlp_dims'] # list type: list of mlp dimensions in gating module
102 | self.dropout_ratio = config['dropout_ratio'] # dropout ratio for mlp in gating module
103 | self.gum_temp = config['ori_temp']
104 | self.logger.info(f'Model initialization, gumbel softmax temperature: {self.gum_temp}')
105 |
106 | # define layers and loss
107 | self.user_embedding = torch.nn.Embedding(num_embeddings=self.n_users, embedding_dim=self.latent_dim)
108 | self.item_embedding = torch.nn.Embedding(num_embeddings=self.n_items, embedding_dim=self.latent_dim)
109 | self.gcn_conv = LightGCNConv(dim=self.latent_dim)
110 | self.activation = nn.ELU() if config['activation_function'] == 'elu' else activation_layer(config['activation_function'])
111 | self.gating_nets = nn.ModuleList([
112 | Gating_Net(self.latent_dim, self.gating_mlp_dims, self.dropout_ratio) for _ in range(len(self.gate_layer_ids))
113 | ])
114 |
115 | self.mf_loss = BPRLoss()
116 | self.reg_loss = EmbLoss()
117 |
118 | # storage variables for full sort evaluation acceleration
119 | self.restore_user_e = None
120 | self.restore_item_e = None
121 |
122 | # parameters initialization
123 | self.apply(xavier_uniform_initialization)
124 | self.other_parameter_name = ['restore_user_e', 'restore_item_e', 'gum_temp']
125 |
126 | for gating in self.gating_nets:
127 | self._gating_freeze(gating, False)
128 |
129 | def _gating_freeze(self, model, freeze_flag):
130 | for name, child in model.named_children():
131 | for param in child.parameters():
132 | param.requires_grad = freeze_flag
133 |
134 | def __choosing_one(self, features, gumbel_out):
135 | feature = torch.sum(torch.mul(features, gumbel_out), dim=1) # batch x embedding_dim (or batch x embedding_dim x layer_num)
136 | return feature
137 |
138 | def __where(self, idx, lst):
139 | for i in range(len(lst)):
140 | if lst[i] == idx:
141 | return i
142 | raise ValueError(f'{idx} not in {lst}.')
143 |
144 | def get_ego_embeddings(self):
145 | r"""Get the embedding of users and items and combine to an embedding matrix.
146 | Returns:
147 | Tensor of the embedding matrix. Shape of [n_items+n_users, embedding_dim]
148 | """
149 | user_embeddings = self.user_embedding.weight
150 | item_embeddings = self.item_embedding.weight
151 | ego_embeddings = torch.cat([user_embeddings, item_embeddings], dim=0)
152 | return ego_embeddings
153 |
154 | def forward(self):
155 | all_embeddings = self.get_ego_embeddings()
156 | embeddings_list = [all_embeddings]
157 | non_lin_emb_list = [all_embeddings]
158 |
159 | for layer_idx in range(self.n_layers):
160 | linear_embeddings = self.gcn_conv(all_embeddings, self.edge_index, self.edge_weight)
161 | if layer_idx not in self.gate_layer_ids:
162 | all_embeddings = linear_embeddings
163 | else:
164 | non_lin_id = self.__where(layer_idx, self.gate_layer_ids)
165 | last_non_lin_emb = non_lin_emb_list[non_lin_id]
166 | non_lin_embeddings = self.activation(self.gcn_conv(last_non_lin_emb, self.edge_index, self.edge_weight))
167 | stack_embeddings = torch.stack([linear_embeddings, non_lin_embeddings], dim=1)
168 | concat_embeddings = torch.cat((linear_embeddings, non_lin_embeddings), dim=-1)
169 | gumbel_out = self.gating_nets[non_lin_id](concat_embeddings, self.gum_temp, not self.training)
170 | all_embeddings = self.__choosing_one(stack_embeddings, gumbel_out)
171 | non_lin_emb_list.append(all_embeddings)
172 | embeddings_list.append(all_embeddings)
173 | hmlet_all_embeddings = torch.stack(embeddings_list, dim=1)
174 | hmlet_all_embeddings = torch.mean(hmlet_all_embeddings, dim=1)
175 |
176 | user_all_embeddings, item_all_embeddings = torch.split(hmlet_all_embeddings, [self.n_users, self.n_items])
177 | return user_all_embeddings, item_all_embeddings
178 |
179 | def calculate_loss(self, interaction):
180 | # clear the storage variable when training
181 | if self.restore_user_e is not None or self.restore_item_e is not None:
182 | self.restore_user_e, self.restore_item_e = None, None
183 |
184 | user = interaction[self.USER_ID]
185 | pos_item = interaction[self.ITEM_ID]
186 | neg_item = interaction[self.NEG_ITEM_ID]
187 |
188 | user_all_embeddings, item_all_embeddings = self.forward()
189 | u_embeddings = user_all_embeddings[user]
190 | pos_embeddings = item_all_embeddings[pos_item]
191 | neg_embeddings = item_all_embeddings[neg_item]
192 |
193 | # calculate BPR Loss
194 | pos_scores = torch.mul(u_embeddings, pos_embeddings).sum(dim=1)
195 | neg_scores = torch.mul(u_embeddings, neg_embeddings).sum(dim=1)
196 | mf_loss = self.mf_loss(pos_scores, neg_scores)
197 |
198 | # calculate regularization Loss
199 | u_ego_embeddings = self.user_embedding(user)
200 | pos_ego_embeddings = self.item_embedding(pos_item)
201 | neg_ego_embeddings = self.item_embedding(neg_item)
202 |
203 | reg_loss = self.reg_loss(u_ego_embeddings, pos_ego_embeddings, neg_ego_embeddings, require_pow=self.require_pow)
204 | loss = mf_loss + self.reg_weight * reg_loss
205 |
206 | return loss
207 |
208 | def predict(self, interaction):
209 | user = interaction[self.USER_ID]
210 | item = interaction[self.ITEM_ID]
211 |
212 | user_all_embeddings, item_all_embeddings = self.forward()
213 |
214 | u_embeddings = user_all_embeddings[user]
215 | i_embeddings = item_all_embeddings[item]
216 | scores = torch.mul(u_embeddings, i_embeddings).sum(dim=1)
217 | return scores
218 |
219 | def full_sort_predict(self, interaction):
220 | user = interaction[self.USER_ID]
221 | if self.restore_user_e is None or self.restore_item_e is None:
222 | self.restore_user_e, self.restore_item_e = self.forward()
223 | # get user embedding from storage variable
224 | u_embeddings = self.restore_user_e[user]
225 |
226 | # dot with all item embedding to accelerate
227 | scores = torch.matmul(u_embeddings, self.restore_item_e.transpose(0, 1))
228 |
229 | return scores.view(-1)
--------------------------------------------------------------------------------
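
A toy demo of the straight-through estimator used in `Gating_Net.gumbel_softmax` above (the logits and temperature are arbitrary): the forward pass uses the hard one-hot choice, while gradients flow through the soft sample via `(y_hard - y).detach() + y`:

```python
# Toy demo of the straight-through Gumbel-Softmax trick used by HMLET's gating network.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.tensor([[1.0, 0.2, -0.5]], requires_grad=True)
gumbel = -torch.log(-torch.log(torch.rand_like(logits) + 1e-20) + 1e-20)   # Gumbel(0, 1) noise
y = F.softmax((logits + gumbel) / 0.7, dim=1)                              # soft sample, temperature 0.7
y_hard = torch.eq(y, y.max(dim=1, keepdim=True)[0]).type_as(y)             # one-hot argmax
y_st = (y_hard - y).detach() + y                                           # hard forward, soft backward

(y_st * torch.tensor([[2.0, 1.0, 0.0]])).sum().backward()
print(y_st)          # one-hot vector in the forward pass
print(logits.grad)   # non-zero gradients, propagated through the soft sample y
```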
/recbole_gnn/model/general_recommender/lightgcl.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2023/04/12
3 | # @Author : Wanli Yang
4 | # @Email : 2013774@mail.nankai.edu.cn
5 |
6 | r"""
7 | LightGCL
8 | ################################################
9 | Reference:
10 | Xuheng Cai et al. "LightGCL: Simple Yet Effective Graph Contrastive Learning for Recommendation" in ICLR 2023.
11 |
12 | Reference code:
13 | https://github.com/HKUDS/LightGCL
14 | """
15 |
16 | import numpy as np
17 | import scipy.sparse as sp
18 | import torch
19 | import torch.nn as nn
20 | from recbole.model.abstract_recommender import GeneralRecommender
21 | from recbole.model.init import xavier_uniform_initialization
22 | from recbole.model.loss import EmbLoss
23 | from recbole.utils import InputType
24 | import torch.nn.functional as F
25 |
26 |
27 | class LightGCL(GeneralRecommender):
28 | r"""LightGCL is a GCN-based recommender model.
29 |
30 | LightGCL guides graph augmentation by singular value decomposition (SVD) to not only
31 | distill the useful information of user-item interactions but also inject the global
32 | collaborative context into the representation alignment of contrastive learning.
33 |
34 | We implement the model following the original author with a pairwise training mode.
35 | """
36 | input_type = InputType.PAIRWISE
37 |
38 | def __init__(self, config, dataset):
39 | super(LightGCL, self).__init__(config, dataset)
40 | self._user = dataset.inter_feat[dataset.uid_field]
41 | self._item = dataset.inter_feat[dataset.iid_field]
42 |
43 | # load parameters info
44 | self.embed_dim = config["embedding_size"]
45 | self.n_layers = config["n_layers"]
46 | self.dropout = config["dropout"]
47 | self.temp = config["temp"]
48 | self.lambda_1 = config["lambda1"]
49 | self.lambda_2 = config["lambda2"]
50 | self.q = config["q"]
51 | self.act = nn.LeakyReLU(0.5)
52 | self.reg_loss = EmbLoss()
53 |
54 | # get the normalized adjacency (interaction) matrix
55 | self.adj_norm = self.coo2tensor(self.create_adjust_matrix())
56 |
57 | # perform svd reconstruction
58 | svd_u, s, svd_v = torch.svd_lowrank(self.adj_norm, q=self.q)
59 | self.u_mul_s = svd_u @ (torch.diag(s))
60 | self.v_mul_s = svd_v @ (torch.diag(s))
61 | del s
62 | self.ut = svd_u.T
63 | self.vt = svd_v.T
64 |
65 | self.E_u_0 = nn.Parameter(nn.init.xavier_uniform_(torch.empty(self.n_users, self.embed_dim)))
66 | self.E_i_0 = nn.Parameter(nn.init.xavier_uniform_(torch.empty(self.n_items, self.embed_dim)))
67 | self.E_u_list = [None] * (self.n_layers + 1)
68 | self.E_i_list = [None] * (self.n_layers + 1)
69 | self.E_u_list[0] = self.E_u_0
70 | self.E_i_list[0] = self.E_i_0
71 | self.Z_u_list = [None] * (self.n_layers + 1)
72 | self.Z_i_list = [None] * (self.n_layers + 1)
73 | self.G_u_list = [None] * (self.n_layers + 1)
74 | self.G_i_list = [None] * (self.n_layers + 1)
75 | self.G_u_list[0] = self.E_u_0
76 | self.G_i_list[0] = self.E_i_0
77 |
78 | self.E_u = None
79 | self.E_i = None
80 | self.restore_user_e = None
81 | self.restore_item_e = None
82 |
83 | self.apply(xavier_uniform_initialization)
84 | self.other_parameter_name = ['restore_user_e', 'restore_item_e']
85 |
86 | def create_adjust_matrix(self):
87 | r"""Get the normalized interaction matrix of users and items.
88 |
89 | Returns:
90 | coo_matrix of the normalized interaction matrix.
91 | """
92 | ratings = np.ones_like(self._user, dtype=np.float32)
93 | matrix = sp.csr_matrix(
94 | (ratings, (self._user, self._item)),
95 | shape=(self.n_users, self.n_items),
96 | ).tocoo()
97 | rowD = np.squeeze(np.array(matrix.sum(1)), axis=1)
98 | colD = np.squeeze(np.array(matrix.sum(0)), axis=0)
99 | for i in range(len(matrix.data)):
100 | matrix.data[i] = matrix.data[i] / pow(rowD[matrix.row[i]] * colD[matrix.col[i]], 0.5)
101 | return matrix
102 |
103 | def coo2tensor(self, matrix: sp.coo_matrix):
104 | r"""Convert coo_matrix to tensor.
105 |
106 | Args:
107 | matrix (scipy.coo_matrix): Sparse matrix to be converted.
108 |
109 | Returns:
110 | torch.sparse.FloatTensor: Transformed sparse matrix.
111 | """
112 | indices = torch.from_numpy(
113 | np.vstack((matrix.row, matrix.col)).astype(np.int64))
114 | values = torch.from_numpy(matrix.data)
115 | shape = torch.Size(matrix.shape)
116 | x = torch.sparse.FloatTensor(indices, values, shape).coalesce().to(self.device)
117 | return x
118 |
119 | def sparse_dropout(self, matrix, dropout):
120 | if dropout == 0.0:
121 | return matrix
122 | indices = matrix.indices()
123 | values = F.dropout(matrix.values(), p=dropout)
124 | size = matrix.size()
125 | return torch.sparse.FloatTensor(indices, values, size)
126 |
127 | def forward(self):
128 | for layer in range(1, self.n_layers + 1):
129 | # GNN propagation
130 | self.Z_u_list[layer] = torch.spmm(self.sparse_dropout(self.adj_norm, self.dropout),
131 | self.E_i_list[layer - 1])
132 | self.Z_i_list[layer] = torch.spmm(self.sparse_dropout(self.adj_norm, self.dropout).transpose(0, 1),
133 | self.E_u_list[layer - 1])
134 | # aggregate
135 | self.E_u_list[layer] = self.Z_u_list[layer]
136 | self.E_i_list[layer] = self.Z_i_list[layer]
137 |
138 | # aggregate across layer
139 | self.E_u = sum(self.E_u_list)
140 | self.E_i = sum(self.E_i_list)
141 |
142 | return self.E_u, self.E_i
143 |
144 | def calculate_loss(self, interaction):
145 | if self.restore_user_e is not None or self.restore_item_e is not None:
146 | self.restore_user_e, self.restore_item_e = None, None
147 |
148 | user_list = interaction[self.USER_ID]
149 | pos_item_list = interaction[self.ITEM_ID]
150 | neg_item_list = interaction[self.NEG_ITEM_ID]
151 | E_u_norm, E_i_norm = self.forward()
152 | bpr_loss = self.calc_bpr_loss(E_u_norm, E_i_norm, user_list, pos_item_list, neg_item_list)
153 | ssl_loss = self.calc_ssl_loss(E_u_norm, E_i_norm, user_list, pos_item_list)
154 | total_loss = bpr_loss + ssl_loss
155 | return total_loss
156 |
157 | def calc_bpr_loss(self, E_u_norm, E_i_norm, user_list, pos_item_list, neg_item_list):
158 | r"""Calculate the pairwise Bayesian Personalized Ranking (BPR) loss and parameter regularization loss.
159 |
160 | Args:
161 | E_u_norm (torch.Tensor): Ego embedding of all users after forwarding.
162 | E_i_norm (torch.Tensor): Ego embedding of all items after forwarding.
163 | user_list (torch.Tensor): List of users.
164 | pos_item_list (torch.Tensor): List of positive examples.
165 | neg_item_list (torch.Tensor): List of negative examples.
166 |
167 | Returns:
168 | torch.Tensor: Loss of BPR tasks and parameter regularization.
169 | """
170 | u_e = E_u_norm[user_list]
171 | pi_e = E_i_norm[pos_item_list]
172 | ni_e = E_i_norm[neg_item_list]
173 | pos_scores = torch.mul(u_e, pi_e).sum(dim=1)
174 | neg_scores = torch.mul(u_e, ni_e).sum(dim=1)
175 | loss1 = -(pos_scores - neg_scores).sigmoid().log().mean()
176 |
177 | # reg loss
178 | loss_reg = 0
179 | for param in self.parameters():
180 | loss_reg += param.norm(2).square()
181 | loss_reg *= self.lambda_2
182 | return loss1 + loss_reg
183 |
184 | def calc_ssl_loss(self, E_u_norm, E_i_norm, user_list, pos_item_list):
185 | r"""Calculate the loss of self-supervised tasks.
186 |
187 | Args:
188 | E_u_norm (torch.Tensor): Ego embedding of all users in the original graph after forwarding.
189 | E_i_norm (torch.Tensor): Ego embedding of all items in the original graph after forwarding.
190 | user_list (torch.Tensor): List of users.
191 | pos_item_list (torch.Tensor): List of positive examples.
192 |
193 | Returns:
194 | torch.Tensor: Loss of self-supervised tasks.
195 | """
196 | # calculate G_u_norm&G_i_norm
197 | for layer in range(1, self.n_layers + 1):
198 | # svd_adj propagation
199 | vt_ei = self.vt @ self.E_i_list[layer - 1]
200 | self.G_u_list[layer] = self.u_mul_s @ vt_ei
201 | ut_eu = self.ut @ self.E_u_list[layer - 1]
202 | self.G_i_list[layer] = self.v_mul_s @ ut_eu
203 |
204 | # aggregate across layer
205 | G_u_norm = sum(self.G_u_list)
206 | G_i_norm = sum(self.G_i_list)
207 |
208 | neg_score = torch.log(torch.exp(G_u_norm[user_list] @ E_u_norm.T / self.temp).sum(1) + 1e-8).mean()
209 | neg_score += torch.log(torch.exp(G_i_norm[pos_item_list] @ E_i_norm.T / self.temp).sum(1) + 1e-8).mean()
210 | pos_score = (torch.clamp((G_u_norm[user_list] * E_u_norm[user_list]).sum(1) / self.temp, -5.0, 5.0)).mean() + (
211 | torch.clamp((G_i_norm[pos_item_list] * E_i_norm[pos_item_list]).sum(1) / self.temp, -5.0, 5.0)).mean()
212 | ssl_loss = -pos_score + neg_score
213 | return self.lambda_1 * ssl_loss
214 |
215 | def predict(self, interaction):
216 | if self.restore_user_e is None or self.restore_item_e is None:
217 | self.restore_user_e, self.restore_item_e = self.forward()
218 | user = self.restore_user_e[interaction[self.USER_ID]]
219 | item = self.restore_item_e[interaction[self.ITEM_ID]]
220 | return torch.sum(user * item, dim=1)
221 |
222 | def full_sort_predict(self, interaction):
223 | if self.restore_user_e is None or self.restore_item_e is None:
224 | self.restore_user_e, self.restore_item_e = self.forward()
225 | user = self.restore_user_e[interaction[self.USER_ID]]
226 | return user.matmul(self.restore_item_e.T)
227 |
--------------------------------------------------------------------------------
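Written out as an equation, the self-supervised term assembled in calc_ssl_loss above is a clamped InfoNCE between the SVD-reconstructed view G and the propagated view E (user part shown, the item part is analogous; tau = self.temp, B is the mini-batch, and the 1e-8 stabilizer inside the log is dropped for readability). The method returns lambda_1 times the sum of the user and item terms:

\mathcal{L}_{ssl}^{user}
    = \frac{1}{|B|} \sum_{u \in B} \Big[ \log \sum_{v} \exp\big( g_u^{\top} e_v / \tau \big)
    - \operatorname{clip}\big( g_u^{\top} e_u / \tau,\, -5,\, 5 \big) \Big]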
/recbole_gnn/model/general_recommender/lightgcn.py:
--------------------------------------------------------------------------------
1 | # @Time : 2022/3/8
2 | # @Author : Lanling Xu
3 | # @Email : xulanling_sherry@163.com
4 |
5 | r"""
6 | LightGCN
7 | ################################################
8 | Reference:
9 | Xiangnan He et al. "LightGCN: Simplifying and Powering Graph Convolution Network for Recommendation." in SIGIR 2020.
10 |
11 | Reference code:
12 | https://github.com/kuandeng/LightGCN
13 | """
14 |
15 | import numpy as np
16 | import torch
17 |
18 | from recbole.model.init import xavier_uniform_initialization
19 | from recbole.model.loss import BPRLoss, EmbLoss
20 | from recbole.utils import InputType
21 |
22 | from recbole_gnn.model.abstract_recommender import GeneralGraphRecommender
23 | from recbole_gnn.model.layers import LightGCNConv
24 |
25 |
26 | class LightGCN(GeneralGraphRecommender):
27 | r"""LightGCN is a GCN-based recommender model, implemented via PyG.
28 | LightGCN includes only the most essential component in GCN — neighborhood aggregation — for
29 | collaborative filtering. Specifically, LightGCN learns user and item embeddings by linearly
30 | propagating them on the user-item interaction graph, and uses the weighted sum of the embeddings
31 | learned at all layers as the final embedding.
32 | We implement the model following the original author with a pairwise training mode.
33 | """
34 | input_type = InputType.PAIRWISE
35 |
36 | def __init__(self, config, dataset):
37 | super(LightGCN, self).__init__(config, dataset)
38 |
39 | # load parameters info
40 |         self.latent_dim = config['embedding_size']  # int type: the embedding size of LightGCN
41 |         self.n_layers = config['n_layers']  # int type: the number of propagation layers in LightGCN
42 |         self.reg_weight = config['reg_weight']  # float32 type: the weight of the L2 regularization term
43 |         self.require_pow = config['require_pow']  # bool type: whether to use the powered form of the regularization term
44 |
45 | # define layers and loss
46 | self.user_embedding = torch.nn.Embedding(num_embeddings=self.n_users, embedding_dim=self.latent_dim)
47 | self.item_embedding = torch.nn.Embedding(num_embeddings=self.n_items, embedding_dim=self.latent_dim)
48 | self.gcn_conv = LightGCNConv(dim=self.latent_dim)
49 | self.mf_loss = BPRLoss()
50 | self.reg_loss = EmbLoss()
51 |
52 | # storage variables for full sort evaluation acceleration
53 | self.restore_user_e = None
54 | self.restore_item_e = None
55 |
56 | # parameters initialization
57 | self.apply(xavier_uniform_initialization)
58 | self.other_parameter_name = ['restore_user_e', 'restore_item_e']
59 |
60 | def get_ego_embeddings(self):
61 |         r"""Get the embeddings of users and items and combine them into one embedding matrix.
62 |         Returns:
63 |             Tensor of the embedding matrix. Shape of [n_users + n_items, embedding_dim], with users stacked before items.
64 | """
65 | user_embeddings = self.user_embedding.weight
66 | item_embeddings = self.item_embedding.weight
67 | ego_embeddings = torch.cat([user_embeddings, item_embeddings], dim=0)
68 | return ego_embeddings
69 |
70 | def forward(self):
71 | all_embeddings = self.get_ego_embeddings()
72 | embeddings_list = [all_embeddings]
73 |
74 | for layer_idx in range(self.n_layers):
75 | all_embeddings = self.gcn_conv(all_embeddings, self.edge_index, self.edge_weight)
76 | embeddings_list.append(all_embeddings)
77 | lightgcn_all_embeddings = torch.stack(embeddings_list, dim=1)
78 | lightgcn_all_embeddings = torch.mean(lightgcn_all_embeddings, dim=1)
79 |
80 | user_all_embeddings, item_all_embeddings = torch.split(lightgcn_all_embeddings, [self.n_users, self.n_items])
81 | return user_all_embeddings, item_all_embeddings
82 |
83 | def calculate_loss(self, interaction):
84 | # clear the storage variable when training
85 | if self.restore_user_e is not None or self.restore_item_e is not None:
86 | self.restore_user_e, self.restore_item_e = None, None
87 |
88 | user = interaction[self.USER_ID]
89 | pos_item = interaction[self.ITEM_ID]
90 | neg_item = interaction[self.NEG_ITEM_ID]
91 |
92 | user_all_embeddings, item_all_embeddings = self.forward()
93 | u_embeddings = user_all_embeddings[user]
94 | pos_embeddings = item_all_embeddings[pos_item]
95 | neg_embeddings = item_all_embeddings[neg_item]
96 |
97 | # calculate BPR Loss
98 | pos_scores = torch.mul(u_embeddings, pos_embeddings).sum(dim=1)
99 | neg_scores = torch.mul(u_embeddings, neg_embeddings).sum(dim=1)
100 | mf_loss = self.mf_loss(pos_scores, neg_scores)
101 |
102 | # calculate regularization Loss
103 | u_ego_embeddings = self.user_embedding(user)
104 | pos_ego_embeddings = self.item_embedding(pos_item)
105 | neg_ego_embeddings = self.item_embedding(neg_item)
106 |
107 | reg_loss = self.reg_loss(u_ego_embeddings, pos_ego_embeddings, neg_ego_embeddings, require_pow=self.require_pow)
108 | loss = mf_loss + self.reg_weight * reg_loss
109 |
110 | return loss
111 |
112 | def predict(self, interaction):
113 | user = interaction[self.USER_ID]
114 | item = interaction[self.ITEM_ID]
115 |
116 | user_all_embeddings, item_all_embeddings = self.forward()
117 |
118 | u_embeddings = user_all_embeddings[user]
119 | i_embeddings = item_all_embeddings[item]
120 | scores = torch.mul(u_embeddings, i_embeddings).sum(dim=1)
121 | return scores
122 |
123 | def full_sort_predict(self, interaction):
124 | user = interaction[self.USER_ID]
125 | if self.restore_user_e is None or self.restore_item_e is None:
126 | self.restore_user_e, self.restore_item_e = self.forward()
127 | # get user embedding from storage variable
128 | u_embeddings = self.restore_user_e[user]
129 |
130 | # dot with all item embedding to accelerate
131 | scores = torch.matmul(u_embeddings, self.restore_item_e.transpose(0, 1))
132 |
133 | return scores.view(-1)
134 |
--------------------------------------------------------------------------------
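As a quick illustration of the propagation in forward() above, here is a minimal torch-only sketch that uses a dense toy adjacency in place of the sparse edge_index/edge_weight pair consumed by LightGCNConv (the interaction matrix R and all sizes below are made up for illustration):

import torch

# toy bipartite graph: 3 users, 4 items; R[u, i] = 1 if user u interacted with item i
n_users, n_items, dim, n_layers = 3, 4, 8, 2
R = torch.tensor([[1., 0., 1., 0.],
                  [0., 1., 0., 0.],
                  [1., 1., 0., 1.]])
A = torch.zeros(n_users + n_items, n_users + n_items)
A[:n_users, n_users:] = R
A[n_users:, :n_users] = R.t()

# symmetric normalization D^{-1/2} A D^{-1/2}; the real model pre-computes this as edge_weight
deg = A.sum(dim=1).clamp(min=1.0)
A_hat = A / deg.sqrt().unsqueeze(1) / deg.sqrt().unsqueeze(0)

E = torch.randn(n_users + n_items, dim)          # ego embeddings, users stacked before items
layers = [E]
for _ in range(n_layers):                        # purely linear propagation: no weights, no nonlinearity
    E = A_hat @ E
    layers.append(E)
final = torch.stack(layers, dim=1).mean(dim=1)   # average the layer outputs, as in forward()
user_emb, item_emb = torch.split(final, [n_users, n_items])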
/recbole_gnn/model/general_recommender/ncl.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | r"""
3 | NCL
4 | ################################################
5 | Reference:
6 | Zihan Lin*, Changxin Tian*, Yupeng Hou*, Wayne Xin Zhao. "Improving Graph Collaborative Filtering with Neighborhood-enriched Contrastive Learning." in WWW 2022.
7 | """
8 |
9 | import torch
10 | import torch.nn.functional as F
11 |
12 | from recbole.model.init import xavier_uniform_initialization
13 | from recbole.model.loss import BPRLoss, EmbLoss
14 | from recbole.utils import InputType
15 |
16 | from recbole_gnn.model.abstract_recommender import GeneralGraphRecommender
17 | from recbole_gnn.model.layers import LightGCNConv
18 |
19 |
20 | class NCL(GeneralGraphRecommender):
21 | input_type = InputType.PAIRWISE
22 |
23 | def __init__(self, config, dataset):
24 | super(NCL, self).__init__(config, dataset)
25 |
26 | # load parameters info
27 | self.latent_dim = config['embedding_size'] # int type: the embedding size of the base model
28 | self.n_layers = config['n_layers'] # int type: the layer num of the base model
29 | self.reg_weight = config['reg_weight'] # float32 type: the weight decay for l2 normalization
30 |
31 | self.ssl_temp = config['ssl_temp']
32 | self.ssl_reg = config['ssl_reg']
33 | self.hyper_layers = config['hyper_layers']
34 |
35 | self.alpha = config['alpha']
36 |
37 | self.proto_reg = config['proto_reg']
38 | self.k = config['num_clusters']
39 |
40 | # define layers and loss
41 | self.user_embedding = torch.nn.Embedding(num_embeddings=self.n_users, embedding_dim=self.latent_dim)
42 | self.item_embedding = torch.nn.Embedding(num_embeddings=self.n_items, embedding_dim=self.latent_dim)
43 | self.gcn_conv = LightGCNConv(dim=self.latent_dim)
44 | self.mf_loss = BPRLoss()
45 | self.reg_loss = EmbLoss()
46 |
47 | # storage variables for full sort evaluation acceleration
48 | self.restore_user_e = None
49 | self.restore_item_e = None
50 |
51 | # parameters initialization
52 | self.apply(xavier_uniform_initialization)
53 | self.other_parameter_name = ['restore_user_e', 'restore_item_e']
54 |
55 | self.user_centroids = None
56 | self.user_2cluster = None
57 | self.item_centroids = None
58 | self.item_2cluster = None
59 |
60 | def e_step(self):
61 | user_embeddings = self.user_embedding.weight.detach().cpu().numpy()
62 | item_embeddings = self.item_embedding.weight.detach().cpu().numpy()
63 | self.user_centroids, self.user_2cluster = self.run_kmeans(user_embeddings)
64 | self.item_centroids, self.item_2cluster = self.run_kmeans(item_embeddings)
65 |
66 | def run_kmeans(self, x):
67 | """Run K-means algorithm to get k clusters of the input tensor x
68 | """
69 | import faiss
70 | kmeans = faiss.Kmeans(d=self.latent_dim, k=self.k, gpu=True)
71 | kmeans.train(x)
72 | cluster_cents = kmeans.centroids
73 |
74 | _, I = kmeans.index.search(x, 1)
75 |
76 | # convert to cuda Tensors for broadcast
77 | centroids = torch.Tensor(cluster_cents).to(self.device)
78 | centroids = F.normalize(centroids, p=2, dim=1)
79 |
80 | node2cluster = torch.LongTensor(I).squeeze().to(self.device)
81 | return centroids, node2cluster
82 |
83 | def get_ego_embeddings(self):
84 |         r"""Get the embeddings of users and items and combine them into one embedding matrix.
85 |         Returns:
86 |             Tensor of the embedding matrix. Shape of [n_users + n_items, embedding_dim], with users stacked before items.
87 | """
88 | user_embeddings = self.user_embedding.weight
89 | item_embeddings = self.item_embedding.weight
90 | ego_embeddings = torch.cat([user_embeddings, item_embeddings], dim=0)
91 | return ego_embeddings
92 |
93 | def forward(self):
94 | all_embeddings = self.get_ego_embeddings()
95 | embeddings_list = [all_embeddings]
96 | for layer_idx in range(max(self.n_layers, self.hyper_layers * 2)):
97 | all_embeddings = self.gcn_conv(all_embeddings, self.edge_index, self.edge_weight)
98 | embeddings_list.append(all_embeddings)
99 |
100 | lightgcn_all_embeddings = torch.stack(embeddings_list[:self.n_layers + 1], dim=1)
101 | lightgcn_all_embeddings = torch.mean(lightgcn_all_embeddings, dim=1)
102 |
103 | user_all_embeddings, item_all_embeddings = torch.split(lightgcn_all_embeddings, [self.n_users, self.n_items])
104 | return user_all_embeddings, item_all_embeddings, embeddings_list
105 |
106 | def ProtoNCE_loss(self, node_embedding, user, item):
107 | user_embeddings_all, item_embeddings_all = torch.split(node_embedding, [self.n_users, self.n_items])
108 |
109 | user_embeddings = user_embeddings_all[user] # [B, e]
110 | norm_user_embeddings = F.normalize(user_embeddings)
111 |
112 | user2cluster = self.user_2cluster[user] # [B,]
113 | user2centroids = self.user_centroids[user2cluster] # [B, e]
114 | pos_score_user = torch.mul(norm_user_embeddings, user2centroids).sum(dim=1)
115 | pos_score_user = torch.exp(pos_score_user / self.ssl_temp)
116 | ttl_score_user = torch.matmul(norm_user_embeddings, self.user_centroids.transpose(0, 1))
117 | ttl_score_user = torch.exp(ttl_score_user / self.ssl_temp).sum(dim=1)
118 |
119 | proto_nce_loss_user = -torch.log(pos_score_user / ttl_score_user).sum()
120 |
121 | item_embeddings = item_embeddings_all[item]
122 | norm_item_embeddings = F.normalize(item_embeddings)
123 |
124 | item2cluster = self.item_2cluster[item] # [B, ]
125 | item2centroids = self.item_centroids[item2cluster] # [B, e]
126 | pos_score_item = torch.mul(norm_item_embeddings, item2centroids).sum(dim=1)
127 | pos_score_item = torch.exp(pos_score_item / self.ssl_temp)
128 | ttl_score_item = torch.matmul(norm_item_embeddings, self.item_centroids.transpose(0, 1))
129 | ttl_score_item = torch.exp(ttl_score_item / self.ssl_temp).sum(dim=1)
130 | proto_nce_loss_item = -torch.log(pos_score_item / ttl_score_item).sum()
131 |
132 | proto_nce_loss = self.proto_reg * (proto_nce_loss_user + proto_nce_loss_item)
133 | return proto_nce_loss
134 |
135 | def ssl_layer_loss(self, current_embedding, previous_embedding, user, item):
136 | current_user_embeddings, current_item_embeddings = torch.split(current_embedding, [self.n_users, self.n_items])
137 | previous_user_embeddings_all, previous_item_embeddings_all = torch.split(previous_embedding, [self.n_users, self.n_items])
138 |
139 | current_user_embeddings = current_user_embeddings[user]
140 | previous_user_embeddings = previous_user_embeddings_all[user]
141 | norm_user_emb1 = F.normalize(current_user_embeddings)
142 | norm_user_emb2 = F.normalize(previous_user_embeddings)
143 | norm_all_user_emb = F.normalize(previous_user_embeddings_all)
144 | pos_score_user = torch.mul(norm_user_emb1, norm_user_emb2).sum(dim=1)
145 | ttl_score_user = torch.matmul(norm_user_emb1, norm_all_user_emb.transpose(0, 1))
146 | pos_score_user = torch.exp(pos_score_user / self.ssl_temp)
147 | ttl_score_user = torch.exp(ttl_score_user / self.ssl_temp).sum(dim=1)
148 |
149 | ssl_loss_user = -torch.log(pos_score_user / ttl_score_user).sum()
150 |
151 | current_item_embeddings = current_item_embeddings[item]
152 | previous_item_embeddings = previous_item_embeddings_all[item]
153 | norm_item_emb1 = F.normalize(current_item_embeddings)
154 | norm_item_emb2 = F.normalize(previous_item_embeddings)
155 | norm_all_item_emb = F.normalize(previous_item_embeddings_all)
156 | pos_score_item = torch.mul(norm_item_emb1, norm_item_emb2).sum(dim=1)
157 | ttl_score_item = torch.matmul(norm_item_emb1, norm_all_item_emb.transpose(0, 1))
158 | pos_score_item = torch.exp(pos_score_item / self.ssl_temp)
159 | ttl_score_item = torch.exp(ttl_score_item / self.ssl_temp).sum(dim=1)
160 |
161 | ssl_loss_item = -torch.log(pos_score_item / ttl_score_item).sum()
162 |
163 | ssl_loss = self.ssl_reg * (ssl_loss_user + self.alpha * ssl_loss_item)
164 | return ssl_loss
165 |
166 | def calculate_loss(self, interaction):
167 | # clear the storage variable when training
168 | if self.restore_user_e is not None or self.restore_item_e is not None:
169 | self.restore_user_e, self.restore_item_e = None, None
170 |
171 | user = interaction[self.USER_ID]
172 | pos_item = interaction[self.ITEM_ID]
173 | neg_item = interaction[self.NEG_ITEM_ID]
174 |
175 | user_all_embeddings, item_all_embeddings, embeddings_list = self.forward()
176 |
177 | center_embedding = embeddings_list[0]
178 | context_embedding = embeddings_list[self.hyper_layers * 2]
179 |
180 | ssl_loss = self.ssl_layer_loss(context_embedding, center_embedding, user, pos_item)
181 | proto_loss = self.ProtoNCE_loss(center_embedding, user, pos_item)
182 |
183 | u_embeddings = user_all_embeddings[user]
184 | pos_embeddings = item_all_embeddings[pos_item]
185 | neg_embeddings = item_all_embeddings[neg_item]
186 |
187 | # calculate BPR Loss
188 | pos_scores = torch.mul(u_embeddings, pos_embeddings).sum(dim=1)
189 | neg_scores = torch.mul(u_embeddings, neg_embeddings).sum(dim=1)
190 |
191 | mf_loss = self.mf_loss(pos_scores, neg_scores)
192 |
193 | u_ego_embeddings = self.user_embedding(user)
194 | pos_ego_embeddings = self.item_embedding(pos_item)
195 | neg_ego_embeddings = self.item_embedding(neg_item)
196 |
197 | reg_loss = self.reg_loss(u_ego_embeddings, pos_ego_embeddings, neg_ego_embeddings)
198 |
199 | return mf_loss + self.reg_weight * reg_loss, ssl_loss, proto_loss
200 |
201 | def predict(self, interaction):
202 | user = interaction[self.USER_ID]
203 | item = interaction[self.ITEM_ID]
204 |
205 | user_all_embeddings, item_all_embeddings, embeddings_list = self.forward()
206 |
207 | u_embeddings = user_all_embeddings[user]
208 | i_embeddings = item_all_embeddings[item]
209 | scores = torch.mul(u_embeddings, i_embeddings).sum(dim=1)
210 | return scores
211 |
212 | def full_sort_predict(self, interaction):
213 | user = interaction[self.USER_ID]
214 | if self.restore_user_e is None or self.restore_item_e is None:
215 | self.restore_user_e, self.restore_item_e, embedding_list = self.forward()
216 | # get user embedding from storage variable
217 | u_embeddings = self.restore_user_e[user]
218 |
219 | # dot with all item embedding to accelerate
220 | scores = torch.matmul(u_embeddings, self.restore_item_e.transpose(0, 1))
221 |
222 | return scores.view(-1)
223 |
--------------------------------------------------------------------------------
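In equations, ssl_layer_loss above contrasts each node's normalized layer-0 embedding with its normalized embedding after 2 * hyper_layers propagation steps (hats denote L2 normalization, tau = ssl_temp; the total is ssl_reg * (user term + alpha * item term)), while ProtoNCE_loss replaces the positive with the centroid c of the node's k-means cluster and sums the denominator over all K centroids (scaled by proto_reg):

\mathcal{L}_{struct}^{user}
    = -\sum_{u \in B} \log
      \frac{\exp\big( \hat z_u^{(2k)\top} \hat z_u^{(0)} / \tau \big)}
           {\sum_{v} \exp\big( \hat z_u^{(2k)\top} \hat z_v^{(0)} / \tau \big)},
\qquad
\mathcal{L}_{proto}^{user}
    = -\sum_{u \in B} \log
      \frac{\exp\big( \hat z_u^{(0)\top} c_{\pi(u)} / \tau \big)}
           {\sum_{j=1}^{K} \exp\big( \hat z_u^{(0)\top} c_j / \tau \big)}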
/recbole_gnn/model/general_recommender/ngcf.py:
--------------------------------------------------------------------------------
1 | # @Time : 2022/3/8
2 | # @Author : Changxin Tian
3 | # @Email : cx.tian@outlook.com
4 | r"""
5 | NGCF
6 | ################################################
7 | Reference:
8 | Xiang Wang et al. "Neural Graph Collaborative Filtering." in SIGIR 2019.
9 |
10 | Reference code:
11 | https://github.com/xiangwang1223/neural_graph_collaborative_filtering
12 |
13 | """
14 |
15 | import torch
16 | import torch.nn as nn
17 | import torch.nn.functional as F
18 | from torch_geometric.utils import dropout_adj
19 |
20 | from recbole.model.init import xavier_normal_initialization
21 | from recbole.model.loss import BPRLoss, EmbLoss
22 | from recbole.utils import InputType
23 |
24 | from recbole_gnn.model.abstract_recommender import GeneralGraphRecommender
25 | from recbole_gnn.model.layers import BiGNNConv
26 |
27 |
28 | class NGCF(GeneralGraphRecommender):
29 |     r"""NGCF is a model that incorporates GNNs for recommendation.
30 | We implement the model following the original author with a pairwise training mode.
31 | """
32 | input_type = InputType.PAIRWISE
33 |
34 | def __init__(self, config, dataset):
35 | super(NGCF, self).__init__(config, dataset)
36 |
37 | # load parameters info
38 | self.embedding_size = config['embedding_size']
39 | self.hidden_size_list = config['hidden_size_list']
40 | self.hidden_size_list = [self.embedding_size] + self.hidden_size_list
41 | self.node_dropout = config['node_dropout']
42 | self.message_dropout = config['message_dropout']
43 | self.reg_weight = config['reg_weight']
44 |
45 | # define layers and loss
46 | self.user_embedding = nn.Embedding(self.n_users, self.embedding_size)
47 | self.item_embedding = nn.Embedding(self.n_items, self.embedding_size)
48 | self.GNNlayers = torch.nn.ModuleList()
49 | for input_size, output_size in zip(self.hidden_size_list[:-1], self.hidden_size_list[1:]):
50 | self.GNNlayers.append(BiGNNConv(input_size, output_size))
51 | self.mf_loss = BPRLoss()
52 | self.reg_loss = EmbLoss()
53 |
54 | # storage variables for full sort evaluation acceleration
55 | self.restore_user_e = None
56 | self.restore_item_e = None
57 |
58 | # parameters initialization
59 | self.apply(xavier_normal_initialization)
60 | self.other_parameter_name = ['restore_user_e', 'restore_item_e']
61 |
62 | def get_ego_embeddings(self):
63 |         r"""Get the embeddings of users and items and combine them into one embedding matrix.
64 |
65 | Returns:
66 |             Tensor of the embedding matrix. Shape of (n_users + n_items, embedding_dim), with users stacked before items.
67 | """
68 | user_embeddings = self.user_embedding.weight
69 | item_embeddings = self.item_embedding.weight
70 | ego_embeddings = torch.cat([user_embeddings, item_embeddings], dim=0)
71 | return ego_embeddings
72 |
73 | def forward(self):
74 | if self.node_dropout == 0:
75 | edge_index, edge_weight = self.edge_index, self.edge_weight
76 | else:
77 | edge_index, edge_weight = self.edge_index, self.edge_weight
78 | if self.use_sparse:
79 | row, col, edge_weight = edge_index.t().coo()
80 | edge_index = torch.stack([row, col], 0)
81 | edge_index, edge_weight = dropout_adj(edge_index=edge_index, edge_attr=edge_weight,
82 | p=self.node_dropout, training=self.training)
83 | from torch_sparse import SparseTensor
84 | edge_index = SparseTensor(row=edge_index[0], col=edge_index[1], value=edge_weight,
85 | sparse_sizes=(self.n_users + self.n_items, self.n_users + self.n_items))
86 | edge_index = edge_index.t()
87 | edge_weight = None
88 | else:
89 | edge_index, edge_weight = dropout_adj(edge_index=edge_index, edge_attr=edge_weight,
90 | p=self.node_dropout, training=self.training)
91 |
92 | all_embeddings = self.get_ego_embeddings()
93 | embeddings_list = [all_embeddings]
94 | for gnn in self.GNNlayers:
95 | all_embeddings = gnn(all_embeddings, edge_index, edge_weight)
96 | all_embeddings = nn.LeakyReLU(negative_slope=0.2)(all_embeddings)
97 | all_embeddings = nn.Dropout(self.message_dropout)(all_embeddings)
98 | all_embeddings = F.normalize(all_embeddings, p=2, dim=1)
99 | embeddings_list += [all_embeddings] # storage output embedding of each layer
100 | ngcf_all_embeddings = torch.cat(embeddings_list, dim=1)
101 |
102 | user_all_embeddings, item_all_embeddings = torch.split(ngcf_all_embeddings, [self.n_users, self.n_items])
103 |
104 | return user_all_embeddings, item_all_embeddings
105 |
106 | def calculate_loss(self, interaction):
107 | # clear the storage variable when training
108 | if self.restore_user_e is not None or self.restore_item_e is not None:
109 | self.restore_user_e, self.restore_item_e = None, None
110 |
111 | user = interaction[self.USER_ID]
112 | pos_item = interaction[self.ITEM_ID]
113 | neg_item = interaction[self.NEG_ITEM_ID]
114 |
115 | user_all_embeddings, item_all_embeddings = self.forward()
116 | u_embeddings = user_all_embeddings[user]
117 | pos_embeddings = item_all_embeddings[pos_item]
118 | neg_embeddings = item_all_embeddings[neg_item]
119 |
120 | pos_scores = torch.mul(u_embeddings, pos_embeddings).sum(dim=1)
121 | neg_scores = torch.mul(u_embeddings, neg_embeddings).sum(dim=1)
122 | mf_loss = self.mf_loss(pos_scores, neg_scores) # calculate BPR Loss
123 |
124 | reg_loss = self.reg_loss(u_embeddings, pos_embeddings, neg_embeddings) # L2 regularization of embeddings
125 |
126 | return mf_loss + self.reg_weight * reg_loss
127 |
128 | def predict(self, interaction):
129 | user = interaction[self.USER_ID]
130 | item = interaction[self.ITEM_ID]
131 |
132 | user_all_embeddings, item_all_embeddings = self.forward()
133 |
134 | u_embeddings = user_all_embeddings[user]
135 | i_embeddings = item_all_embeddings[item]
136 | scores = torch.mul(u_embeddings, i_embeddings).sum(dim=1)
137 | return scores
138 |
139 | def full_sort_predict(self, interaction):
140 | user = interaction[self.USER_ID]
141 | if self.restore_user_e is None or self.restore_item_e is None:
142 | self.restore_user_e, self.restore_item_e = self.forward()
143 | # get user embedding from storage variable
144 | u_embeddings = self.restore_user_e[user]
145 |
146 | # dot with all item embedding to accelerate
147 | scores = torch.matmul(u_embeddings, self.restore_item_e.transpose(0, 1))
148 |
149 | return scores.view(-1)
150 |
--------------------------------------------------------------------------------
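Spelling out one propagation step of forward() above: with \hat A the normalized adjacency carried by edge_index/edge_weight, \odot the element-wise product, and biases omitted, each layer applies a BiGNNConv followed by LeakyReLU, message dropout and L2 normalization, and the per-layer outputs are finally concatenated into the representation E^{(0)} \Vert \cdots \Vert E^{(L)}:

E^{(l+1)} = \operatorname{normalize}\Big( \operatorname{Dropout}\big( \operatorname{LeakyReLU}\big(
    (\hat A + I)\, E^{(l)} W_1^{(l)} + \big( \hat A E^{(l)} \odot E^{(l)} \big) W_2^{(l)} \big) \big) \Big)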
/recbole_gnn/model/general_recommender/sgl.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2022/3/8
3 | # @Author : Changxin Tian
4 | # @Email : cx.tian@outlook.com
5 | r"""
6 | SGL
7 | ################################################
8 | Reference:
9 | Jiancan Wu et al. "SGL: Self-supervised Graph Learning for Recommendation" in SIGIR 2021.
10 |
11 | Reference code:
12 | https://github.com/wujcan/SGL
13 | """
14 |
15 | import numpy as np
16 | import torch
17 | import torch.nn.functional as F
18 | from torch_geometric.utils import degree
19 | from torch_geometric.nn.conv.gcn_conv import gcn_norm
20 |
21 | from recbole.model.init import xavier_uniform_initialization
22 | from recbole.model.loss import EmbLoss
23 | from recbole.utils import InputType
24 |
25 | from recbole_gnn.model.abstract_recommender import GeneralGraphRecommender
26 | from recbole_gnn.model.layers import LightGCNConv
27 |
28 |
29 | class SGL(GeneralGraphRecommender):
30 | r"""SGL is a GCN-based recommender model.
31 |
32 |     SGL supplements the classical supervised recommendation task with an auxiliary
33 |     self-supervised task, which reinforces node representation learning via self-
34 |     discrimination. Specifically, SGL generates multiple views of a node, maximizing the
35 |     agreement between different views of the same node compared with the views of other nodes.
36 | SGL devises three operators to generate the views — node dropout, edge dropout, and
37 | random walk — that change the graph structure in different manners.
38 |
39 | We implement the model following the original author with a pairwise training mode.
40 | """
41 | input_type = InputType.PAIRWISE
42 |
43 | def __init__(self, config, dataset):
44 | super(SGL, self).__init__(config, dataset)
45 |
46 | # load parameters info
47 | self.latent_dim = config["embedding_size"]
48 | self.n_layers = int(config["n_layers"])
49 | self.aug_type = config["type"]
50 | self.drop_ratio = config["drop_ratio"]
51 | self.ssl_tau = config["ssl_tau"]
52 | self.reg_weight = config["reg_weight"]
53 | self.ssl_weight = config["ssl_weight"]
54 |
55 | self._user = dataset.inter_feat[dataset.uid_field]
56 | self._item = dataset.inter_feat[dataset.iid_field]
57 | self.dataset = dataset
58 |
59 | # define layers and loss
60 | self.user_embedding = torch.nn.Embedding(self.n_users, self.latent_dim)
61 | self.item_embedding = torch.nn.Embedding(self.n_items, self.latent_dim)
62 | self.gcn_conv = LightGCNConv(dim=self.latent_dim)
63 | self.reg_loss = EmbLoss()
64 |
65 | # storage variables for full sort evaluation acceleration
66 | self.restore_user_e = None
67 | self.restore_item_e = None
68 |
69 | # parameters initialization
70 | self.apply(xavier_uniform_initialization)
71 | self.other_parameter_name = ['restore_user_e', 'restore_item_e']
72 |
73 | def train(self, mode: bool = True):
74 |         r"""Override the train method of the base class: the augmented sub-graphs are rebuilt each time the model is switched to training mode.
75 |
76 | """
77 | T = super().train(mode=mode)
78 | if mode:
79 | self.graph_construction()
80 | return T
81 |
82 | def graph_construction(self):
83 |         r"""Generate two augmented views of the interaction graph using one of three operators: node dropout, edge dropout, or random walk.
84 |
85 | """
86 | if self.aug_type == "ND" or self.aug_type == "ED":
87 | self.sub_graph1 = [self.random_graph_augment()] * self.n_layers
88 | self.sub_graph2 = [self.random_graph_augment()] * self.n_layers
89 | elif self.aug_type == "RW":
90 | self.sub_graph1 = [self.random_graph_augment() for _ in range(self.n_layers)]
91 | self.sub_graph2 = [self.random_graph_augment() for _ in range(self.n_layers)]
92 |
93 | def random_graph_augment(self):
94 | def rand_sample(high, size=None, replace=True):
95 | return np.random.choice(np.arange(high), size=size, replace=replace)
96 |
97 | if self.aug_type == "ND":
98 | drop_user = rand_sample(self.n_users, size=int(self.n_users * self.drop_ratio), replace=False)
99 | drop_item = rand_sample(self.n_items, size=int(self.n_items * self.drop_ratio), replace=False)
100 |
101 | mask = np.isin(self._user.numpy(), drop_user)
102 | mask |= np.isin(self._item.numpy(), drop_item)
103 | keep = np.where(~mask)
104 |
105 | row = self._user[keep]
106 | col = self._item[keep] + self.n_users
107 |
108 | elif self.aug_type == "ED" or self.aug_type == "RW":
109 | keep = rand_sample(len(self._user), size=int(len(self._user) * (1 - self.drop_ratio)), replace=False)
110 | row = self._user[keep]
111 | col = self._item[keep] + self.n_users
112 |
113 | edge_index1 = torch.stack([row, col])
114 | edge_index2 = torch.stack([col, row])
115 | edge_index = torch.cat([edge_index1, edge_index2], dim=1)
116 | edge_weight = torch.ones(edge_index.size(1))
117 | num_nodes = self.n_users + self.n_items
118 |
119 | if self.use_sparse:
120 | adj_t = self.dataset.edge_index_to_adj_t(edge_index, edge_weight, num_nodes, num_nodes)
121 | adj_t = gcn_norm(adj_t, None, num_nodes, add_self_loops=False)
122 | return adj_t.to(self.device), None
123 |
124 | edge_index, edge_weight = gcn_norm(edge_index, edge_weight, num_nodes, add_self_loops=False)
125 |
126 | return edge_index.to(self.device), edge_weight.to(self.device)
127 |
128 | def forward(self, graph=None):
129 | all_embeddings = torch.cat([self.user_embedding.weight, self.item_embedding.weight])
130 | embeddings_list = [all_embeddings]
131 |
132 | if graph is None: # for the original graph
133 | for _ in range(self.n_layers):
134 | all_embeddings = self.gcn_conv(all_embeddings, self.edge_index, self.edge_weight)
135 | embeddings_list.append(all_embeddings)
136 | else: # for the augmented graph
137 | for graph_edge_index, graph_edge_weight in graph:
138 | all_embeddings = self.gcn_conv(all_embeddings, graph_edge_index, graph_edge_weight)
139 | embeddings_list.append(all_embeddings)
140 |
141 | embeddings_list = torch.stack(embeddings_list, dim=1)
142 | embeddings_list = torch.mean(embeddings_list, dim=1, keepdim=False)
143 | user_all_embeddings, item_all_embeddings = torch.split(embeddings_list, [self.n_users, self.n_items], dim=0)
144 |
145 | return user_all_embeddings, item_all_embeddings
146 |
147 | def calc_bpr_loss(self, user_emd, item_emd, user_list, pos_item_list, neg_item_list):
148 |         r"""Calculate the pairwise Bayesian Personalized Ranking (BPR) loss and the parameter regularization loss.
149 |
150 | Args:
151 | user_emd (torch.Tensor): Ego embedding of all users after forwarding.
152 | item_emd (torch.Tensor): Ego embedding of all items after forwarding.
153 |             user_list (torch.Tensor): List of user IDs in the batch.
154 | pos_item_list (torch.Tensor): List of positive examples.
155 | neg_item_list (torch.Tensor): List of negative examples.
156 |
157 | Returns:
158 | torch.Tensor: Loss of BPR tasks and parameter regularization.
159 | """
160 | u_e = user_emd[user_list]
161 | pi_e = item_emd[pos_item_list]
162 | ni_e = item_emd[neg_item_list]
163 | p_scores = torch.mul(u_e, pi_e).sum(dim=1)
164 | n_scores = torch.mul(u_e, ni_e).sum(dim=1)
165 |
166 | l1 = torch.sum(-F.logsigmoid(p_scores - n_scores))
167 |
168 | u_e_p = self.user_embedding(user_list)
169 | pi_e_p = self.item_embedding(pos_item_list)
170 | ni_e_p = self.item_embedding(neg_item_list)
171 |
172 | l2 = self.reg_loss(u_e_p, pi_e_p, ni_e_p)
173 |
174 | return l1 + l2 * self.reg_weight
175 |
176 | def calc_ssl_loss(self, user_list, pos_item_list, user_sub1, user_sub2, item_sub1, item_sub2):
177 | r"""Calculate the loss of self-supervised tasks.
178 |
179 | Args:
180 |             user_list (torch.Tensor): List of user IDs in the batch.
181 | pos_item_list (torch.Tensor): List of positive examples.
182 | user_sub1 (torch.Tensor): Ego embedding of all users in the first subgraph after forwarding.
183 | user_sub2 (torch.Tensor): Ego embedding of all users in the second subgraph after forwarding.
184 | item_sub1 (torch.Tensor): Ego embedding of all items in the first subgraph after forwarding.
185 | item_sub2 (torch.Tensor): Ego embedding of all items in the second subgraph after forwarding.
186 |
187 | Returns:
188 | torch.Tensor: Loss of self-supervised tasks.
189 | """
190 |
191 | u_emd1 = F.normalize(user_sub1[user_list], dim=1)
192 | u_emd2 = F.normalize(user_sub2[user_list], dim=1)
193 | all_user2 = F.normalize(user_sub2, dim=1)
194 | v1 = torch.sum(u_emd1 * u_emd2, dim=1)
195 | v2 = u_emd1.matmul(all_user2.T)
196 | v1 = torch.exp(v1 / self.ssl_tau)
197 | v2 = torch.sum(torch.exp(v2 / self.ssl_tau), dim=1)
198 | ssl_user = -torch.sum(torch.log(v1 / v2))
199 |
200 | i_emd1 = F.normalize(item_sub1[pos_item_list], dim=1)
201 | i_emd2 = F.normalize(item_sub2[pos_item_list], dim=1)
202 | all_item2 = F.normalize(item_sub2, dim=1)
203 | v3 = torch.sum(i_emd1 * i_emd2, dim=1)
204 | v4 = i_emd1.matmul(all_item2.T)
205 | v3 = torch.exp(v3 / self.ssl_tau)
206 | v4 = torch.sum(torch.exp(v4 / self.ssl_tau), dim=1)
207 | ssl_item = -torch.sum(torch.log(v3 / v4))
208 |
209 | return (ssl_item + ssl_user) * self.ssl_weight
210 |
211 | def calculate_loss(self, interaction):
212 | if self.restore_user_e is not None or self.restore_item_e is not None:
213 | self.restore_user_e, self.restore_item_e = None, None
214 |
215 | user_list = interaction[self.USER_ID]
216 | pos_item_list = interaction[self.ITEM_ID]
217 | neg_item_list = interaction[self.NEG_ITEM_ID]
218 |
219 | user_emd, item_emd = self.forward()
220 | user_sub1, item_sub1 = self.forward(self.sub_graph1)
221 | user_sub2, item_sub2 = self.forward(self.sub_graph2)
222 |
223 | total_loss = self.calc_bpr_loss(user_emd, item_emd, user_list, pos_item_list, neg_item_list) + \
224 | self.calc_ssl_loss(user_list, pos_item_list, user_sub1, user_sub2, item_sub1, item_sub2)
225 | return total_loss
226 |
227 | def predict(self, interaction):
228 | if self.restore_user_e is None or self.restore_item_e is None:
229 | self.restore_user_e, self.restore_item_e = self.forward()
230 |
231 | user = self.restore_user_e[interaction[self.USER_ID]]
232 | item = self.restore_item_e[interaction[self.ITEM_ID]]
233 | return torch.sum(user * item, dim=1)
234 |
235 | def full_sort_predict(self, interaction):
236 | if self.restore_user_e is None or self.restore_item_e is None:
237 | self.restore_user_e, self.restore_item_e = self.forward()
238 |
239 | user = self.restore_user_e[interaction[self.USER_ID]]
240 | return user.matmul(self.restore_item_e.T)
241 |
--------------------------------------------------------------------------------
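A torch-only sketch of the edge-dropout ("ED") branch of random_graph_augment above, using made-up interactions; the real method then passes the result through gcn_norm and, on the sparse path, converts it to an adjacency tensor before LightGCNConv consumes it:

import torch

# toy observed interactions: parallel lists of user ids and item ids
user = torch.tensor([0, 0, 1, 2, 2, 2])
item = torch.tensor([0, 2, 1, 0, 1, 3])
n_users, n_items, drop_ratio = 3, 4, 0.5

keep = torch.randperm(len(user))[: int(len(user) * (1 - drop_ratio))]   # randomly keep a subset of edges
row = user[keep]
col = item[keep] + n_users                      # shift item ids so they index nodes after the users
edge_index = torch.cat([torch.stack([row, col]),                        # user -> item edges
                        torch.stack([col, row])], dim=1)                # plus the symmetric direction
edge_weight = torch.ones(edge_index.size(1))
# graph_construction() builds two such sub-graphs; calc_ssl_loss then contrasts
# the embeddings produced on them with an InfoNCE at temperature ssl_tau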
/recbole_gnn/model/general_recommender/simgcl.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | r"""
3 | SimGCL
4 | ################################################
5 | Reference:
6 | Junliang Yu, Hongzhi Yin, Xin Xia, Tong Chen, Lizhen Cui, Quoc Viet Hung Nguyen. "Are Graph Augmentations Necessary? Simple Graph Contrastive Learning for Recommendation." in SIGIR 2022.
7 | """
8 |
9 |
10 | import torch
11 | import torch.nn.functional as F
12 |
13 | from recbole_gnn.model.general_recommender import LightGCN
14 |
15 |
16 | class SimGCL(LightGCN):
17 | def __init__(self, config, dataset):
18 | super(SimGCL, self).__init__(config, dataset)
19 |
20 | self.cl_rate = config['lambda']
21 | self.eps = config['eps']
22 | self.temperature = config['temperature']
23 |
24 | def forward(self, perturbed=False):
25 | all_embs = self.get_ego_embeddings()
26 | embeddings_list = []
27 |
28 | for layer_idx in range(self.n_layers):
29 | all_embs = self.gcn_conv(all_embs, self.edge_index, self.edge_weight)
30 | if perturbed:
31 | random_noise = torch.rand_like(all_embs, device=all_embs.device)
32 | all_embs = all_embs + torch.sign(all_embs) * F.normalize(random_noise, dim=-1) * self.eps
33 | embeddings_list.append(all_embs)
34 | lightgcn_all_embeddings = torch.stack(embeddings_list, dim=1)
35 | lightgcn_all_embeddings = torch.mean(lightgcn_all_embeddings, dim=1)
36 |
37 | user_all_embeddings, item_all_embeddings = torch.split(lightgcn_all_embeddings, [self.n_users, self.n_items])
38 | return user_all_embeddings, item_all_embeddings
39 |
40 | def calculate_cl_loss(self, x1, x2):
41 | x1, x2 = F.normalize(x1, dim=-1), F.normalize(x2, dim=-1)
42 | pos_score = (x1 * x2).sum(dim=-1)
43 | pos_score = torch.exp(pos_score / self.temperature)
44 | ttl_score = torch.matmul(x1, x2.transpose(0, 1))
45 | ttl_score = torch.exp(ttl_score / self.temperature).sum(dim=1)
46 | return -torch.log(pos_score / ttl_score).sum()
47 |
48 | def calculate_loss(self, interaction):
49 | loss = super().calculate_loss(interaction)
50 |
51 | user = torch.unique(interaction[self.USER_ID])
52 | pos_item = torch.unique(interaction[self.ITEM_ID])
53 |
54 | perturbed_user_embs_1, perturbed_item_embs_1 = self.forward(perturbed=True)
55 | perturbed_user_embs_2, perturbed_item_embs_2 = self.forward(perturbed=True)
56 |
57 | user_cl_loss = self.calculate_cl_loss(perturbed_user_embs_1[user], perturbed_user_embs_2[user])
58 | item_cl_loss = self.calculate_cl_loss(perturbed_item_embs_1[pos_item], perturbed_item_embs_2[pos_item])
59 |
60 | return loss + self.cl_rate * (user_cl_loss + item_cl_loss)
61 |
--------------------------------------------------------------------------------
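The perturbed forward pass above adds, at every layer and independently in each of the two contrastive passes, a small sign-aligned noise vector to each row of the embedding matrix, with \delta drawn uniformly from [0, 1)^d, \epsilon = self.eps, and row-wise L2 normalization of the noise; the two passes are then contrasted with an InfoNCE at temperature self.temperature, weighted by config['lambda']:

e_i^{(l)} \leftarrow e_i^{(l)} + \epsilon \cdot \operatorname{sign}\big( e_i^{(l)} \big) \odot \frac{\delta_i^{(l)}}{\lVert \delta_i^{(l)} \rVert_2}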
/recbole_gnn/model/general_recommender/ssl4rec.py:
--------------------------------------------------------------------------------
1 | r"""
2 | SSL4REC
3 | ################################################
4 | Reference:
5 | Tiansheng Yao et al. "Self-supervised Learning for Large-scale Item Recommendations." in CIKM 2021.
6 |
7 | Reference code:
8 |     https://github.com/Coder-Yu/SELFRec/blob/main/model/graph/SSL4Rec.py
9 | """
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 |
15 | from recbole.model.loss import EmbLoss
16 | from recbole.utils import InputType
17 |
18 | from recbole.model.init import xavier_uniform_initialization
19 | from recbole_gnn.model.abstract_recommender import GeneralGraphRecommender
20 |
21 |
22 | class SSL4REC(GeneralGraphRecommender):
23 | input_type = InputType.PAIRWISE
24 |
25 | def __init__(self, config, dataset):
26 | super(SSL4REC, self).__init__(config, dataset)
27 |
28 | # load parameters info
29 | self.tau = config["tau"]
30 | self.reg_weight = config["reg_weight"]
31 | self.cl_rate = config["ssl_weight"]
32 | self.require_pow = config["require_pow"]
33 |
34 | self.reg_loss = EmbLoss()
35 |
36 | self.encoder = DNN_Encoder(config, dataset)
37 |
38 | # storage variables for full sort evaluation acceleration
39 | self.restore_user_e = None
40 | self.restore_item_e = None
41 |
42 | # parameters initialization
43 | self.apply(xavier_uniform_initialization)
44 | self.other_parameter_name = ['restore_user_e', 'restore_item_e']
45 |
46 | def forward(self, user, item):
47 | user_e, item_e = self.encoder(user, item)
48 | return user_e, item_e
49 |
50 | def calculate_batch_softmax_loss(self, user_emb, item_emb, temperature):
51 | user_emb, item_emb = F.normalize(user_emb, dim=1), F.normalize(item_emb, dim=1)
52 | pos_score = (user_emb * item_emb).sum(dim=-1)
53 | pos_score = torch.exp(pos_score / temperature)
54 | ttl_score = torch.matmul(user_emb, item_emb.transpose(0, 1))
55 | ttl_score = torch.exp(ttl_score / temperature).sum(dim=1)
56 | loss = -torch.log(pos_score / ttl_score + 10e-6)
57 | return torch.mean(loss)
58 |
59 | def calculate_loss(self, interaction):
60 | # clear the storage variable when training
61 | if self.restore_user_e is not None or self.restore_item_e is not None:
62 | self.restore_user_e, self.restore_item_e = None, None
63 |
64 | user = interaction[self.USER_ID]
65 | pos_item = interaction[self.ITEM_ID]
66 |
67 | user_embeddings, item_embeddings = self.forward(user, pos_item)
68 |
69 | rec_loss = self.calculate_batch_softmax_loss(user_embeddings, item_embeddings, self.tau)
70 | cl_loss = self.encoder.calculate_cl_loss(pos_item)
71 | reg_loss = self.reg_loss(user_embeddings, item_embeddings, require_pow=self.require_pow)
72 |
73 | loss = rec_loss + self.cl_rate * cl_loss + self.reg_weight * reg_loss
74 |
75 | return loss
76 |
77 | def predict(self, interaction):
78 | user = interaction[self.USER_ID]
79 | item = interaction[self.ITEM_ID]
80 |
81 | user_embeddings, item_embeddings = self.forward(user, item)
82 |
83 |         # forward(user, item) already returns the embeddings of this batch,
84 |         # so score the rows directly instead of indexing them again with global ids
85 |         scores = torch.mul(user_embeddings, item_embeddings).sum(dim=1)
86 | return scores
87 |
88 | def full_sort_predict(self, interaction):
89 | user = interaction[self.USER_ID]
90 | if self.restore_user_e is None or self.restore_item_e is None:
91 | self.restore_user_e, self.restore_item_e = self.forward(torch.arange(
92 | self.n_users, device=self.device), torch.arange(self.n_items, device=self.device))
93 | # get user embedding from storage variable
94 | u_embeddings = self.restore_user_e[user]
95 |
96 | # dot with all item embedding to accelerate
97 | scores = torch.matmul(u_embeddings, self.restore_item_e.transpose(0, 1))
98 |
99 | return scores.view(-1)
100 |
101 |
102 | class DNN_Encoder(nn.Module):
103 | def __init__(self, config, dataset):
104 | super(DNN_Encoder, self).__init__()
105 |
106 | self.emb_size = config["embedding_size"]
107 | self.drop_ratio = config["drop_ratio"]
108 | self.tau = config["tau"]
109 |
110 | self.USER_ID = config["USER_ID_FIELD"]
111 | self.ITEM_ID = config["ITEM_ID_FIELD"]
112 | self.n_users = dataset.num(self.USER_ID)
113 | self.n_items = dataset.num(self.ITEM_ID)
114 |
115 | self.user_tower = nn.Sequential(
116 | nn.Linear(self.emb_size, 1024),
117 | nn.ReLU(True),
118 | nn.Linear(1024, 128),
119 | nn.Tanh()
120 | )
121 | self.item_tower = nn.Sequential(
122 | nn.Linear(self.emb_size, 1024),
123 | nn.ReLU(True),
124 | nn.Linear(1024, 128),
125 | nn.Tanh()
126 | )
127 | self.dropout = nn.Dropout(self.drop_ratio)
128 |
129 | self.initial_user_emb = nn.Embedding(self.n_users, self.emb_size)
130 | self.initial_item_emb = nn.Embedding(self.n_items, self.emb_size)
131 | self.reset_parameters()
132 |
133 | def reset_parameters(self):
134 | nn.init.xavier_uniform_(self.initial_user_emb.weight)
135 | nn.init.xavier_uniform_(self.initial_item_emb.weight)
136 |
137 | def forward(self, q, x):
138 | q_emb = self.initial_user_emb(q)
139 | i_emb = self.initial_item_emb(x)
140 |
141 | q_emb = self.user_tower(q_emb)
142 | i_emb = self.item_tower(i_emb)
143 |
144 | return q_emb, i_emb
145 |
146 | def item_encoding(self, x):
147 | i_emb = self.initial_item_emb(x)
148 | i1_emb = self.dropout(i_emb)
149 | i2_emb = self.dropout(i_emb)
150 |
151 | i1_emb = self.item_tower(i1_emb)
152 | i2_emb = self.item_tower(i2_emb)
153 |
154 | return i1_emb, i2_emb
155 |
156 | def calculate_cl_loss(self, idx):
157 | x1, x2 = self.item_encoding(idx)
158 | x1, x2 = F.normalize(x1, dim=-1), F.normalize(x2, dim=-1)
159 | pos_score = (x1 * x2).sum(dim=-1)
160 | pos_score = torch.exp(pos_score / self.tau)
161 | ttl_score = torch.matmul(x1, x2.transpose(0, 1))
162 | ttl_score = torch.exp(ttl_score / self.tau).sum(dim=1)
163 | return -torch.log(pos_score / ttl_score).mean()
164 |
--------------------------------------------------------------------------------
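A self-contained sketch of calculate_batch_softmax_loss above with toy tensors: for each user, the matched item is the positive and every other item in the batch serves as a negative (the 10e-6 stabilizer from the code is written here as a generic small constant):

import torch
import torch.nn.functional as F

# toy batch of 4 users and their matched items, already encoded by the two towers
user_emb = F.normalize(torch.randn(4, 16), dim=1)
item_emb = F.normalize(torch.randn(4, 16), dim=1)
tau = 0.2

pos = torch.exp((user_emb * item_emb).sum(dim=-1) / tau)        # scores of the matched pairs
ttl = torch.exp(user_emb @ item_emb.t() / tau).sum(dim=1)       # scores against all in-batch items
loss = -torch.log(pos / ttl + 1e-6).mean()                      # in-batch softmax loss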
/recbole_gnn/model/general_recommender/xsimgcl.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | r"""
3 | XSimGCL
4 | ################################################
5 | Reference:
6 | Junliang Yu, Xin Xia, Tong Chen, Lizhen Cui, Nguyen Quoc Viet Hung, Hongzhi Yin. "XSimGCL: Towards Extremely Simple Graph Contrastive Learning for Recommendation" in TKDE 2023.
7 |
8 | Reference code:
9 | https://github.com/Coder-Yu/SELFRec/blob/main/model/graph/XSimGCL.py
10 | """
11 |
12 |
13 | import torch
14 | import torch.nn.functional as F
15 |
16 | from recbole_gnn.model.general_recommender import LightGCN
17 |
18 |
19 | class XSimGCL(LightGCN):
20 | def __init__(self, config, dataset):
21 | super(XSimGCL, self).__init__(config, dataset)
22 |
23 | self.cl_rate = config['lambda']
24 | self.eps = config['eps']
25 | self.temperature = config['temperature']
26 | self.layer_cl = config['layer_cl']
27 |
28 | def forward(self, perturbed=False):
29 | all_embs = self.get_ego_embeddings()
30 | all_embs_cl = all_embs
31 | embeddings_list = []
32 |
33 | for layer_idx in range(self.n_layers):
34 | all_embs = self.gcn_conv(all_embs, self.edge_index, self.edge_weight)
35 | if perturbed:
36 | random_noise = torch.rand_like(all_embs, device=all_embs.device)
37 | all_embs = all_embs + torch.sign(all_embs) * F.normalize(random_noise, dim=-1) * self.eps
38 | embeddings_list.append(all_embs)
39 | if layer_idx == self.layer_cl - 1:
40 | all_embs_cl = all_embs
41 | lightgcn_all_embeddings = torch.stack(embeddings_list, dim=1)
42 | lightgcn_all_embeddings = torch.mean(lightgcn_all_embeddings, dim=1)
43 |
44 | user_all_embeddings, item_all_embeddings = torch.split(lightgcn_all_embeddings, [self.n_users, self.n_items])
45 | user_all_embeddings_cl, item_all_embeddings_cl = torch.split(all_embs_cl, [self.n_users, self.n_items])
46 | if perturbed:
47 | return user_all_embeddings, item_all_embeddings, user_all_embeddings_cl, item_all_embeddings_cl
48 | return user_all_embeddings, item_all_embeddings
49 |
50 | def calculate_cl_loss(self, x1, x2):
51 | x1, x2 = F.normalize(x1, dim=-1), F.normalize(x2, dim=-1)
52 | pos_score = (x1 * x2).sum(dim=-1)
53 | pos_score = torch.exp(pos_score / self.temperature)
54 | ttl_score = torch.matmul(x1, x2.transpose(0, 1))
55 | ttl_score = torch.exp(ttl_score / self.temperature).sum(dim=1)
56 | return -torch.log(pos_score / ttl_score).mean()
57 |
58 | def calculate_loss(self, interaction):
59 | # clear the storage variable when training
60 | if self.restore_user_e is not None or self.restore_item_e is not None:
61 | self.restore_user_e, self.restore_item_e = None, None
62 |
63 | user = interaction[self.USER_ID]
64 | pos_item = interaction[self.ITEM_ID]
65 | neg_item = interaction[self.NEG_ITEM_ID]
66 |
67 | user_all_embeddings, item_all_embeddings, user_all_embeddings_cl, item_all_embeddings_cl = self.forward(perturbed=True)
68 | u_embeddings = user_all_embeddings[user]
69 | pos_embeddings = item_all_embeddings[pos_item]
70 | neg_embeddings = item_all_embeddings[neg_item]
71 |
72 | # calculate BPR Loss
73 | pos_scores = torch.mul(u_embeddings, pos_embeddings).sum(dim=1)
74 | neg_scores = torch.mul(u_embeddings, neg_embeddings).sum(dim=1)
75 | mf_loss = self.mf_loss(pos_scores, neg_scores)
76 |
77 | # calculate regularization Loss
78 | u_ego_embeddings = self.user_embedding(user)
79 | pos_ego_embeddings = self.item_embedding(pos_item)
80 | neg_ego_embeddings = self.item_embedding(neg_item)
81 | reg_loss = self.reg_loss(u_ego_embeddings, pos_ego_embeddings, neg_ego_embeddings, require_pow=self.require_pow)
82 |
83 | user = torch.unique(interaction[self.USER_ID])
84 | pos_item = torch.unique(interaction[self.ITEM_ID])
85 |
86 | # calculate CL Loss
87 | user_cl_loss = self.calculate_cl_loss(user_all_embeddings[user], user_all_embeddings_cl[user])
88 | item_cl_loss = self.calculate_cl_loss(item_all_embeddings[pos_item], item_all_embeddings_cl[pos_item])
89 |
90 | return mf_loss, self.reg_weight * reg_loss, self.cl_rate * (user_cl_loss + item_cl_loss)
91 |
--------------------------------------------------------------------------------
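Compared with SimGCL above, a single perturbed forward pass supplies both contrastive views: the layer-averaged final embedding \hat z_u is contrasted with the normalized embedding \hat z_u^{cl} taken after layer_cl propagation steps, negatives come from the unique users (resp. items) of the batch, and the batch mean rather than the sum is taken. The three loss terms returned by calculate_loss are summed by the trainer:

\mathcal{L}_{cl}^{user}
    = -\frac{1}{|B|} \sum_{u \in B} \log
      \frac{\exp\big( \hat z_u^{\top} \hat z_u^{cl} / \tau \big)}
           {\sum_{v \in B} \exp\big( \hat z_u^{\top} \hat z_v^{cl} / \tau \big)}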
/recbole_gnn/model/layers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from torch_geometric.nn import MessagePassing
5 | from torch_sparse import matmul
6 |
7 |
8 | class LightGCNConv(MessagePassing):
9 | def __init__(self, dim):
10 | super(LightGCNConv, self).__init__(aggr='add')
11 | self.dim = dim
12 |
13 | def forward(self, x, edge_index, edge_weight):
14 | return self.propagate(edge_index, x=x, edge_weight=edge_weight)
15 |
16 | def message(self, x_j, edge_weight):
17 | return edge_weight.view(-1, 1) * x_j
18 |
19 | def message_and_aggregate(self, adj_t, x):
20 | return matmul(adj_t, x, reduce=self.aggr)
21 |
22 | def __repr__(self):
23 | return '{}({})'.format(self.__class__.__name__, self.dim)
24 |
25 |
26 | class BipartiteGCNConv(MessagePassing):
27 | def __init__(self, dim):
28 | super(BipartiteGCNConv, self).__init__(aggr='add')
29 | self.dim = dim
30 |
31 | def forward(self, x, edge_index, edge_weight, size):
32 | return self.propagate(edge_index, x=x, edge_weight=edge_weight, size=size)
33 |
34 | def message(self, x_j, edge_weight):
35 | return edge_weight.view(-1, 1) * x_j
36 |
37 | def __repr__(self):
38 | return '{}({})'.format(self.__class__.__name__, self.dim)
39 |
40 |
41 | class BiGNNConv(MessagePassing):
42 | r"""Propagate a layer of Bi-interaction GNN
43 |
44 | .. math::
45 | output = (L+I)EW_1 + LE \otimes EW_2
46 | """
47 |
48 | def __init__(self, in_channels, out_channels):
49 | super().__init__(aggr='add')
50 | self.in_channels, self.out_channels = in_channels, out_channels
51 | self.lin1 = torch.nn.Linear(in_features=in_channels, out_features=out_channels)
52 | self.lin2 = torch.nn.Linear(in_features=in_channels, out_features=out_channels)
53 |
54 | def forward(self, x, edge_index, edge_weight):
55 | x_prop = self.propagate(edge_index, x=x, edge_weight=edge_weight)
56 | x_trans = self.lin1(x_prop + x)
57 | x_inter = self.lin2(torch.mul(x_prop, x))
58 | return x_trans + x_inter
59 |
60 | def message(self, x_j, edge_weight):
61 | return edge_weight.view(-1, 1) * x_j
62 |
63 | def message_and_aggregate(self, adj_t, x):
64 | return matmul(adj_t, x, reduce=self.aggr)
65 |
66 | def __repr__(self):
67 | return '{}({},{})'.format(self.__class__.__name__, self.in_channels, self.out_channels)
68 |
69 |
70 | class SRGNNConv(MessagePassing):
71 | def __init__(self, dim):
72 | # mean aggregation to incorporate weight naturally
73 | super(SRGNNConv, self).__init__(aggr='mean')
74 |
75 | self.lin = torch.nn.Linear(dim, dim)
76 |
77 | def forward(self, x, edge_index):
78 | x = self.lin(x)
79 | return self.propagate(edge_index, x=x)
80 |
81 |
82 | class SRGNNCell(nn.Module):
83 | def __init__(self, dim):
84 | super(SRGNNCell, self).__init__()
85 |
86 | self.dim = dim
87 | self.incomming_conv = SRGNNConv(dim)
88 | self.outcomming_conv = SRGNNConv(dim)
89 |
90 | self.lin_ih = nn.Linear(2 * dim, 3 * dim)
91 | self.lin_hh = nn.Linear(dim, 3 * dim)
92 |
93 | self._reset_parameters()
94 |
95 | def forward(self, hidden, edge_index):
96 | input_in = self.incomming_conv(hidden, edge_index)
97 | reversed_edge_index = torch.flip(edge_index, dims=[0])
98 | input_out = self.outcomming_conv(hidden, reversed_edge_index)
99 | inputs = torch.cat([input_in, input_out], dim=-1)
100 |
101 | gi = self.lin_ih(inputs)
102 | gh = self.lin_hh(hidden)
103 | i_r, i_i, i_n = gi.chunk(3, -1)
104 | h_r, h_i, h_n = gh.chunk(3, -1)
105 | reset_gate = torch.sigmoid(i_r + h_r)
106 | input_gate = torch.sigmoid(i_i + h_i)
107 | new_gate = torch.tanh(i_n + reset_gate * h_n)
108 | hy = (1 - input_gate) * hidden + input_gate * new_gate
109 | return hy
110 |
111 | def _reset_parameters(self):
112 | stdv = 1.0 / np.sqrt(self.dim)
113 | for weight in self.parameters():
114 | weight.data.uniform_(-stdv, stdv)
115 |
--------------------------------------------------------------------------------
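SRGNNCell above is a GRU-style cell driven by graph messages: with a_v the concatenation of the mean-aggregated incoming messages and the outgoing messages obtained on the reversed edges (the two SRGNNConv instances), and with biases folded into lin_ih and lin_hh, the update reads as follows; note that z_v gates the new state, the opposite convention to torch.nn.GRUCell:

r_v = \sigma\big( W_r a_v + U_r h_v \big), \quad
z_v = \sigma\big( W_z a_v + U_z h_v \big), \quad
n_v = \tanh\big( W_n a_v + r_v \odot U_n h_v \big), \quad
h_v' = (1 - z_v) \odot h_v + z_v \odot n_v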
/recbole_gnn/model/sequential_recommender/__init__.py:
--------------------------------------------------------------------------------
1 | from recbole_gnn.model.sequential_recommender.gcegnn import GCEGNN
2 | from recbole_gnn.model.sequential_recommender.gcsan import GCSAN
3 | from recbole_gnn.model.sequential_recommender.lessr import LESSR
4 | from recbole_gnn.model.sequential_recommender.niser import NISER
5 | from recbole_gnn.model.sequential_recommender.sgnnhn import SGNNHN
6 | from recbole_gnn.model.sequential_recommender.srgnn import SRGNN
7 | from recbole_gnn.model.sequential_recommender.tagnn import TAGNN
8 |
--------------------------------------------------------------------------------
/recbole_gnn/model/sequential_recommender/gcsan.py:
--------------------------------------------------------------------------------
1 | # @Time : 2022/3/7
2 | # @Author : Yupeng Hou
3 | # @Email : houyupeng@ruc.edu.cn
4 |
5 | r"""
6 | GCSAN
7 | ################################################
8 |
9 | Reference:
10 | Chengfeng Xu et al. "Graph Contextualized Self-Attention Network for Session-based Recommendation." in IJCAI 2019.
11 |
12 | """
13 |
14 | import torch
15 | from torch import nn
16 | from recbole.model.layers import TransformerEncoder
17 | from recbole.model.loss import EmbLoss, BPRLoss
18 | from recbole.model.abstract_recommender import SequentialRecommender
19 |
20 | from recbole_gnn.model.layers import SRGNNCell
21 |
22 |
23 | class GCSAN(SequentialRecommender):
24 |     r"""GCSAN captures rich local dependencies via a graph neural network,
25 | and learns long-range dependencies by applying the self-attention mechanism.
26 |
27 | Note:
28 |
29 |         In the original paper, the attention mechanism in the self-attention layer uses a single head.
30 |         For the reusability of the project code, we use a unified Transformer component instead.
31 |         According to the experimental results, we only apply regularization to the embeddings.
32 | """
33 |
34 | def __init__(self, config, dataset):
35 | super(GCSAN, self).__init__(config, dataset)
36 |
37 | # load parameters info
38 | self.n_layers = config['n_layers']
39 | self.n_heads = config['n_heads']
40 | self.hidden_size = config['hidden_size'] # same as embedding_size
41 | self.inner_size = config['inner_size'] # the dimensionality in feed-forward layer
42 | self.hidden_dropout_prob = config['hidden_dropout_prob']
43 | self.attn_dropout_prob = config['attn_dropout_prob']
44 | self.hidden_act = config['hidden_act']
45 | self.layer_norm_eps = config['layer_norm_eps']
46 |
47 | self.step = config['step']
48 | self.device = config['device']
49 | self.weight = config['weight']
50 | self.reg_weight = config['reg_weight']
51 | self.loss_type = config['loss_type']
52 | self.initializer_range = config['initializer_range']
53 |
54 | # item embedding
55 | self.item_embedding = nn.Embedding(self.n_items, self.hidden_size, padding_idx=0)
56 |
57 | # define layers and loss
58 | self.gnncell = SRGNNCell(self.hidden_size)
59 | self.self_attention = TransformerEncoder(
60 | n_layers=self.n_layers,
61 | n_heads=self.n_heads,
62 | hidden_size=self.hidden_size,
63 | inner_size=self.inner_size,
64 | hidden_dropout_prob=self.hidden_dropout_prob,
65 | attn_dropout_prob=self.attn_dropout_prob,
66 | hidden_act=self.hidden_act,
67 | layer_norm_eps=self.layer_norm_eps
68 | )
69 | self.reg_loss = EmbLoss()
70 | if self.loss_type == 'BPR':
71 | self.loss_fct = BPRLoss()
72 | elif self.loss_type == 'CE':
73 | self.loss_fct = nn.CrossEntropyLoss()
74 | else:
75 | raise NotImplementedError("Make sure 'loss_type' in ['BPR', 'CE']!")
76 |
77 | # parameters initialization
78 | self.apply(self._init_weights)
79 |
80 | def _init_weights(self, module):
81 | """ Initialize the weights """
82 | if isinstance(module, (nn.Linear, nn.Embedding)):
83 | # Slightly different from the TF version which uses truncated_normal for initialization
84 | # cf https://github.com/pytorch/pytorch/pull/5617
85 | module.weight.data.normal_(mean=0.0, std=self.initializer_range)
86 | elif isinstance(module, nn.LayerNorm):
87 | module.bias.data.zero_()
88 | module.weight.data.fill_(1.0)
89 | if isinstance(module, nn.Linear) and module.bias is not None:
90 | module.bias.data.zero_()
91 |
92 | def get_attention_mask(self, item_seq):
93 | """Generate left-to-right uni-directional attention mask for multi-head attention."""
94 | attention_mask = (item_seq > 0).long()
95 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) # torch.int64
96 | # mask for left-to-right unidirectional
97 | max_len = attention_mask.size(-1)
98 | attn_shape = (1, max_len, max_len)
99 | subsequent_mask = torch.triu(torch.ones(attn_shape), diagonal=1) # torch.uint8
100 | subsequent_mask = (subsequent_mask == 0).unsqueeze(1)
101 | subsequent_mask = subsequent_mask.long().to(item_seq.device)
102 |
103 | extended_attention_mask = extended_attention_mask * subsequent_mask
104 | extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
105 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
106 | return extended_attention_mask
107 |
108 | def forward(self, x, edge_index, alias_inputs, item_seq_len):
109 | hidden = self.item_embedding(x)
110 | for i in range(self.step):
111 | hidden = self.gnncell(hidden, edge_index)
112 |
113 | seq_hidden = hidden[alias_inputs]
114 | # fetch the last hidden state of last timestamp
115 | ht = self.gather_indexes(seq_hidden, item_seq_len - 1)
116 |
117 | attention_mask = self.get_attention_mask(alias_inputs)
118 | outputs = self.self_attention(seq_hidden, attention_mask, output_all_encoded_layers=True)
119 | output = outputs[-1]
120 | at = self.gather_indexes(output, item_seq_len - 1)
121 | seq_output = self.weight * at + (1 - self.weight) * ht
122 | return seq_output
123 |
124 | def calculate_loss(self, interaction):
125 | x = interaction['x']
126 | edge_index = interaction['edge_index']
127 | alias_inputs = interaction['alias_inputs']
128 | item_seq_len = interaction[self.ITEM_SEQ_LEN]
129 | seq_output = self.forward(x, edge_index, alias_inputs, item_seq_len)
130 | pos_items = interaction[self.POS_ITEM_ID]
131 | if self.loss_type == 'BPR':
132 | neg_items = interaction[self.NEG_ITEM_ID]
133 | pos_items_emb = self.item_embedding(pos_items)
134 | neg_items_emb = self.item_embedding(neg_items)
135 | pos_score = torch.sum(seq_output * pos_items_emb, dim=-1) # [B]
136 | neg_score = torch.sum(seq_output * neg_items_emb, dim=-1) # [B]
137 | loss = self.loss_fct(pos_score, neg_score)
138 | else: # self.loss_type = 'CE'
139 | test_item_emb = self.item_embedding.weight
140 | logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1))
141 | loss = self.loss_fct(logits, pos_items)
142 | reg_loss = self.reg_loss(self.item_embedding.weight)
143 | total_loss = loss + self.reg_weight * reg_loss
144 | return total_loss
145 |
146 | def predict(self, interaction):
147 | test_item = interaction[self.ITEM_ID]
148 | x = interaction['x']
149 | edge_index = interaction['edge_index']
150 | alias_inputs = interaction['alias_inputs']
151 | item_seq_len = interaction[self.ITEM_SEQ_LEN]
152 | seq_output = self.forward(x, edge_index, alias_inputs, item_seq_len)
153 | test_item_emb = self.item_embedding(test_item)
154 | scores = torch.mul(seq_output, test_item_emb).sum(dim=1) # [B]
155 | return scores
156 |
157 | def full_sort_predict(self, interaction):
158 | x = interaction['x']
159 | edge_index = interaction['edge_index']
160 | alias_inputs = interaction['alias_inputs']
161 | item_seq_len = interaction[self.ITEM_SEQ_LEN]
162 | seq_output = self.forward(x, edge_index, alias_inputs, item_seq_len)
163 | test_items_emb = self.item_embedding.weight
164 | scores = torch.matmul(seq_output, test_items_emb.transpose(0, 1)) # [B, n_items]
165 | return scores
166 |
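A note on `get_attention_mask` above: it combines a padding mask (item id 0 is padding) with a lower-triangular causal mask and turns the result into additive offsets of -10000 on the attention logits. The sketch below reproduces the same computation on a toy padded batch rather than the RecBole dataloader output; it is only an illustration.

import torch

def build_causal_mask(item_seq):
    # 1 for real items, 0 for padding (padding id is 0)
    pad_mask = (item_seq > 0).long().unsqueeze(1).unsqueeze(2)      # [B, 1, 1, L]
    max_len = item_seq.size(-1)
    # tril(ones) is the same as the (triu(ones, diagonal=1) == 0) construction used above
    causal = torch.tril(torch.ones(1, 1, max_len, max_len)).long()  # [1, 1, L, L]
    allowed = pad_mask * causal                                     # 1 where attention is permitted
    return (1.0 - allowed.float()) * -10000.0                       # additive mask for attention logits

item_seq = torch.tensor([[3, 7, 5, 0]])   # one session of length 3, right-padded to 4
print(build_causal_mask(item_seq)[0, 0])
# row i can attend to positions <= i; the padded position 3 is always blocked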
--------------------------------------------------------------------------------
/recbole_gnn/model/sequential_recommender/lessr.py:
--------------------------------------------------------------------------------
1 | # @Time : 2022/3/11
2 | # @Author : Yupeng Hou
3 | # @Email : houyupeng@ruc.edu.cn
4 |
5 | r"""
6 | LESSR
7 | ################################################
8 |
9 | Reference:
10 | Tianwen Chen and Raymond Chi-Wing Wong. "Handling Information Loss of Graph Neural Networks for Session-based Recommendation." in KDD 2020.
11 |
12 | Reference code:
13 | https://github.com/twchen/lessr
14 |
15 | """
16 |
17 | import torch
18 | from torch import nn
19 | from torch_geometric.utils import softmax
20 | from torch_geometric.nn import global_add_pool
21 | from recbole.model.abstract_recommender import SequentialRecommender
22 |
23 |
24 | class EOPA(nn.Module):
25 | def __init__(
26 | self, input_dim, output_dim, batch_norm=True, feat_drop=0.0, activation=None
27 | ):
28 | super().__init__()
29 | self.batch_norm = nn.BatchNorm1d(input_dim) if batch_norm else None
30 | self.feat_drop = nn.Dropout(feat_drop)
31 | self.gru = nn.GRU(input_dim, input_dim, batch_first=True)
32 | self.fc_self = nn.Linear(input_dim, output_dim, bias=False)
33 | self.fc_neigh = nn.Linear(input_dim, output_dim, bias=False)
34 | self.activation = activation
35 |
36 | def reducer(self, nodes):
37 | m = nodes.mailbox['m'] # (num_nodes, deg, d)
38 | # m[i]: the messages passed to the i-th node with in-degree equal to 'deg'
39 | # the order of messages follows the order of incoming edges
40 | # since the edges are sorted by occurrence time when the EOP multigraph is built
41 | # the messages are in the order required by EOPA
42 | _, hn = self.gru(m) # hn: (1, num_nodes, d)
43 | return {'neigh': hn.squeeze(0)}
44 |
45 | def forward(self, mg, feat):
46 | import dgl.function as fn
47 |
48 | with mg.local_scope():
49 | if self.batch_norm is not None:
50 | feat = self.batch_norm(feat)
51 | mg.ndata['ft'] = self.feat_drop(feat)
52 | if mg.number_of_edges() > 0:
53 | mg.update_all(fn.copy_u('ft', 'm'), self.reducer)
54 | neigh = mg.ndata['neigh']
55 | rst = self.fc_self(feat) + self.fc_neigh(neigh)
56 | else:
57 | rst = self.fc_self(feat)
58 | if self.activation is not None:
59 | rst = self.activation(rst)
60 | return rst
61 |
62 |
63 | class SGAT(nn.Module):
64 | def __init__(
65 | self,
66 | input_dim,
67 | hidden_dim,
68 | output_dim,
69 | batch_norm=True,
70 | feat_drop=0.0,
71 | activation=None,
72 | ):
73 | super().__init__()
74 | self.batch_norm = nn.BatchNorm1d(input_dim) if batch_norm else None
75 | self.feat_drop = nn.Dropout(feat_drop)
76 | self.fc_q = nn.Linear(input_dim, hidden_dim, bias=True)
77 | self.fc_k = nn.Linear(input_dim, hidden_dim, bias=False)
78 | self.fc_v = nn.Linear(input_dim, output_dim, bias=False)
79 | self.fc_e = nn.Linear(hidden_dim, 1, bias=False)
80 | self.activation = activation
81 |
82 | def forward(self, sg, feat):
83 | import dgl.ops as F
84 |
85 | if self.batch_norm is not None:
86 | feat = self.batch_norm(feat)
87 | feat = self.feat_drop(feat)
88 | q = self.fc_q(feat)
89 | k = self.fc_k(feat)
90 | v = self.fc_v(feat)
91 | e = F.u_add_v(sg, q, k)
92 | e = self.fc_e(torch.sigmoid(e))
93 | a = F.edge_softmax(sg, e)
94 | rst = F.u_mul_e_sum(sg, v, a)
95 | if self.activation is not None:
96 | rst = self.activation(rst)
97 | return rst
98 |
99 |
100 | class AttnReadout(nn.Module):
101 | def __init__(
102 | self,
103 | input_dim,
104 | hidden_dim,
105 | output_dim,
106 | batch_norm=True,
107 | feat_drop=0.0,
108 | activation=None,
109 | ):
110 | super().__init__()
111 | self.batch_norm = nn.BatchNorm1d(input_dim) if batch_norm else None
112 | self.feat_drop = nn.Dropout(feat_drop)
113 | self.fc_u = nn.Linear(input_dim, hidden_dim, bias=False)
114 | self.fc_v = nn.Linear(input_dim, hidden_dim, bias=True)
115 | self.fc_e = nn.Linear(hidden_dim, 1, bias=False)
116 | self.fc_out = (
117 | nn.Linear(input_dim, output_dim, bias=False)
118 | if output_dim != input_dim else None
119 | )
120 | self.activation = activation
121 |
122 | def forward(self, g, feat, last_nodes, batch):
123 | if self.batch_norm is not None:
124 | feat = self.batch_norm(feat)
125 | feat = self.feat_drop(feat)
126 | feat_u = self.fc_u(feat)
127 | feat_v = self.fc_v(feat[last_nodes])
128 | feat_v = torch.index_select(feat_v, dim=0, index=batch)
129 | e = self.fc_e(torch.sigmoid(feat_u + feat_v))
130 | alpha = softmax(e, batch)
131 | feat_norm = feat * alpha
132 | rst = global_add_pool(feat_norm, batch)
133 | if self.fc_out is not None:
134 | rst = self.fc_out(rst)
135 | if self.activation is not None:
136 | rst = self.activation(rst)
137 | return rst
138 |
139 |
140 | class LESSR(SequentialRecommender):
141 | r"""LESSR analyzes the information losses when constructing session graphs,
142 | and emphasises lossy session encoding problem and the ineffective long-range dependency capturing problem.
143 | To solve the first problem, authors propose a lossless encoding scheme and an edge-order preserving aggregation layer.
144 | To solve the second problem, authors propose a shortcut graph attention layer that effectively captures long-range dependencies.
145 |
146 | Note:
147 | We follow the original implementation, which requires the DGL package.
148 | These operations are difficult to express in PyG, so we keep the DGL-based implementation.
149 | If you would like to test this model, please install DGL first.
150 | """
151 |
152 | def __init__(self, config, dataset):
153 | super().__init__(config, dataset)
154 |
155 | embedding_dim = config['embedding_size']
156 | self.num_layers = config['n_layers']
157 | batch_norm = config['batch_norm']
158 | feat_drop = config['feat_drop']
159 | self.loss_type = config['loss_type']
160 |
161 | self.item_embedding = nn.Embedding(self.n_items, embedding_dim, max_norm=1)
162 | self.layers = nn.ModuleList()
163 | input_dim = embedding_dim
164 | for i in range(self.num_layers):
165 | if i % 2 == 0:
166 | layer = EOPA(
167 | input_dim,
168 | embedding_dim,
169 | batch_norm=batch_norm,
170 | feat_drop=feat_drop,
171 | activation=nn.PReLU(embedding_dim),
172 | )
173 | else:
174 | layer = SGAT(
175 | input_dim,
176 | embedding_dim,
177 | embedding_dim,
178 | batch_norm=batch_norm,
179 | feat_drop=feat_drop,
180 | activation=nn.PReLU(embedding_dim),
181 | )
182 | input_dim += embedding_dim
183 | self.layers.append(layer)
184 | self.readout = AttnReadout(
185 | input_dim,
186 | embedding_dim,
187 | embedding_dim,
188 | batch_norm=batch_norm,
189 | feat_drop=feat_drop,
190 | activation=nn.PReLU(embedding_dim),
191 | )
192 | input_dim += embedding_dim
193 | self.batch_norm = nn.BatchNorm1d(input_dim) if batch_norm else None
194 | self.feat_drop = nn.Dropout(feat_drop)
195 | self.fc_sr = nn.Linear(input_dim, embedding_dim, bias=False)
196 |
197 | if self.loss_type == 'CE':
198 | self.loss_fct = nn.CrossEntropyLoss()
199 | else:
200 | raise NotImplementedError("Make sure 'loss_type' in ['CE']!")
201 |
202 | def forward(self, x, edge_index_EOP, edge_index_shortcut, batch, is_last):
203 | import dgl
204 |
205 | mg = dgl.graph((edge_index_EOP[0], edge_index_EOP[1]), num_nodes=batch.shape[0])
206 | sg = dgl.graph((edge_index_shortcut[0], edge_index_shortcut[1]), num_nodes=batch.shape[0])
207 |
208 | feat = self.item_embedding(x)
209 | for i, layer in enumerate(self.layers):
210 | if i % 2 == 0:
211 | out = layer(mg, feat)
212 | else:
213 | out = layer(sg, feat)
214 | feat = torch.cat([out, feat], dim=1)
215 | sr_g = self.readout(mg, feat, is_last, batch)
216 | sr_l = feat[is_last]
217 | sr = torch.cat([sr_l, sr_g], dim=1)
218 | if self.batch_norm is not None:
219 | sr = self.batch_norm(sr)
220 | sr = self.fc_sr(self.feat_drop(sr))
221 | return sr
222 |
223 | def calculate_loss(self, interaction):
224 | x = interaction['x']
225 | edge_index_EOP = interaction['edge_index_EOP']
226 | edge_index_shortcut = interaction['edge_index_shortcut']
227 | batch = interaction['batch']
228 | is_last = interaction['is_last']
229 | seq_output = self.forward(x, edge_index_EOP, edge_index_shortcut, batch, is_last)
230 | pos_items = interaction[self.POS_ITEM_ID]
231 | test_item_emb = self.item_embedding.weight
232 | logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1))
233 | loss = self.loss_fct(logits, pos_items)
234 | return loss
235 |
236 | def predict(self, interaction):
237 | test_item = interaction[self.ITEM_ID]
238 | x = interaction['x']
239 | edge_index_EOP = interaction['edge_index_EOP']
240 | edge_index_shortcut = interaction['edge_index_shortcut']
241 | batch = interaction['batch']
242 | is_last = interaction['is_last']
243 | seq_output = self.forward(x, edge_index_EOP, edge_index_shortcut, batch, is_last)
244 | test_item_emb = self.item_embedding(test_item)
245 | scores = torch.mul(seq_output, test_item_emb).sum(dim=1) # [B]
246 | return scores
247 |
248 | def full_sort_predict(self, interaction):
249 | x = interaction['x']
250 | edge_index_EOP = interaction['edge_index_EOP']
251 | edge_index_shortcut = interaction['edge_index_shortcut']
252 | batch = interaction['batch']
253 | is_last = interaction['is_last']
254 | seq_output = self.forward(x, edge_index_EOP, edge_index_shortcut, batch, is_last)
255 | test_items_emb = self.item_embedding.weight
256 | scores = torch.matmul(seq_output, test_items_emb.transpose(0, 1)) # [B, n_items]
257 | return scores
258 |
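The forward pass above expects two graph views per batch of sessions: an edge-order-preserving (EOP) multigraph over consecutive transitions and a shortcut graph linking each click to the clicks that follow it. The real construction lives in recbole_gnn/data/transform.py (not shown here); the sketch below only illustrates the idea and the tensor shapes for a single toy session, so details such as batching and padding handling may differ.

import torch

session = [3, 7, 5, 7, 2]                     # item ids in click order
nodes = sorted(set(session))                  # unique items -> graph nodes
node_id = {item: i for i, item in enumerate(nodes)}
alias = [node_id[i] for i in session]         # session positions -> node indices

# EOP multigraph: one directed edge per consecutive transition, order and multiplicity kept
eop_src = [alias[i] for i in range(len(alias) - 1)]
eop_dst = [alias[i + 1] for i in range(len(alias) - 1)]
edge_index_EOP = torch.tensor([eop_src, eop_dst])

# Shortcut graph: connect each occurrence to every later click, deduplicated
pairs = {(alias[i], alias[j]) for i in range(len(alias)) for j in range(i + 1, len(alias))}
edge_index_shortcut = torch.tensor(sorted(pairs)).t()

x = torch.tensor(nodes)                              # node -> original item id
batch = torch.zeros(len(nodes), dtype=torch.long)    # all nodes belong to session 0
is_last = torch.zeros(len(nodes), dtype=torch.bool)
is_last[alias[-1]] = True                            # marks the node of the last click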
--------------------------------------------------------------------------------
/recbole_gnn/model/sequential_recommender/niser.py:
--------------------------------------------------------------------------------
1 | # @Time : 2022/3/7
2 | # @Author : Yupeng Hou
3 | # @Email : houyupeng@ruc.edu.cn
4 |
5 | r"""
6 | NISER
7 | ################################################
8 |
9 | Reference:
10 | Priyanka Gupta et al. "NISER: Normalized Item and Session Representations to Handle Popularity Bias." in CIKM 2019 GRLA workshop.
11 |
12 | """
13 | import numpy as np
14 | import torch
15 | from torch import nn
16 | import torch.nn.functional as F
17 | from recbole.model.loss import BPRLoss
18 | from recbole.model.abstract_recommender import SequentialRecommender
19 |
20 | from recbole_gnn.model.layers import SRGNNCell
21 |
22 |
23 | class NISER(SequentialRecommender):
24 | r"""NISER+ is a GNN-based model that normalizes session and item embeddings to handle popularity bias.
25 | """
26 |
27 | def __init__(self, config, dataset):
28 | super(NISER, self).__init__(config, dataset)
29 |
30 | # load parameters info
31 | self.embedding_size = config['embedding_size']
32 | self.step = config['step']
33 | self.device = config['device']
34 | self.loss_type = config['loss_type']
35 | self.sigma = config['sigma']
36 | self.max_seq_length = dataset.field2seqlen[self.ITEM_SEQ]
37 |
38 | # item embedding
39 | self.item_embedding = nn.Embedding(self.n_items, self.embedding_size, padding_idx=0)
40 | self.pos_embedding = nn.Embedding(self.max_seq_length, self.embedding_size)
41 | self.item_dropout = nn.Dropout(config['item_dropout'])
42 |
43 | # define layers and loss
44 | self.gnncell = SRGNNCell(self.embedding_size)
45 | self.linear_one = nn.Linear(self.embedding_size, self.embedding_size)
46 | self.linear_two = nn.Linear(self.embedding_size, self.embedding_size)
47 | self.linear_three = nn.Linear(self.embedding_size, 1, bias=False)
48 | self.linear_transform = nn.Linear(self.embedding_size * 2, self.embedding_size)
49 | if self.loss_type == 'BPR':
50 | self.loss_fct = BPRLoss()
51 | elif self.loss_type == 'CE':
52 | self.loss_fct = nn.CrossEntropyLoss()
53 | else:
54 | raise NotImplementedError("Make sure 'loss_type' in ['BPR', 'CE']!")
55 |
56 | # parameters initialization
57 | self._reset_parameters()
58 |
59 | def _reset_parameters(self):
60 | stdv = 1.0 / np.sqrt(self.embedding_size)
61 | for weight in self.parameters():
62 | weight.data.uniform_(-stdv, stdv)
63 |
64 | def forward(self, x, edge_index, alias_inputs, item_seq_len):
65 | mask = alias_inputs.gt(0)
66 | hidden = self.item_embedding(x)
67 | # Dropout in NISER+
68 | hidden = self.item_dropout(hidden)
69 | # Normalize item embeddings
70 | hidden = F.normalize(hidden, dim=-1)
71 | for i in range(self.step):
72 | hidden = self.gnncell(hidden, edge_index)
73 |
74 | seq_hidden = hidden[alias_inputs]
75 | batch_size = seq_hidden.shape[0]
76 | pos_emb = self.pos_embedding.weight[:seq_hidden.shape[1]]
77 | pos_emb = pos_emb.unsqueeze(0).expand(batch_size, -1, -1)
78 | seq_hidden = seq_hidden + pos_emb
79 | # fetch the last hidden state of last timestamp
80 | ht = self.gather_indexes(seq_hidden, item_seq_len - 1)
81 | q1 = self.linear_one(ht).view(ht.size(0), 1, ht.size(1))
82 | q2 = self.linear_two(seq_hidden)
83 |
84 | alpha = self.linear_three(torch.sigmoid(q1 + q2))
85 | a = torch.sum(alpha * seq_hidden * mask.view(mask.size(0), -1, 1).float(), 1)
86 | seq_output = self.linear_transform(torch.cat([a, ht], dim=1))
87 | # Normalize session embeddings
88 | seq_output = F.normalize(seq_output, dim=-1)
89 | return seq_output
90 |
91 | def calculate_loss(self, interaction):
92 | x = interaction['x']
93 | edge_index = interaction['edge_index']
94 | alias_inputs = interaction['alias_inputs']
95 | item_seq_len = interaction[self.ITEM_SEQ_LEN]
96 | seq_output = self.forward(x, edge_index, alias_inputs, item_seq_len)
97 | pos_items = interaction[self.POS_ITEM_ID]
98 | if self.loss_type == 'BPR':
99 | neg_items = interaction[self.NEG_ITEM_ID]
100 | pos_items_emb = F.normalize(self.item_embedding(pos_items), dim=-1)
101 | neg_items_emb = F.normalize(self.item_embedding(neg_items), dim=-1)
102 | pos_score = torch.sum(seq_output * pos_items_emb, dim=-1) # [B]
103 | neg_score = torch.sum(seq_output * neg_items_emb, dim=-1) # [B]
104 | loss = self.loss_fct(self.sigma * pos_score, self.sigma * neg_score)
105 | return loss
106 | else: # self.loss_type = 'CE'
107 | test_item_emb = F.normalize(self.item_embedding.weight, dim=-1)
108 | logits = self.sigma * torch.matmul(seq_output, test_item_emb.transpose(0, 1))
109 | loss = self.loss_fct(logits, pos_items)
110 | return loss
111 |
112 | def predict(self, interaction):
113 | test_item = interaction[self.ITEM_ID]
114 | x = interaction['x']
115 | edge_index = interaction['edge_index']
116 | alias_inputs = interaction['alias_inputs']
117 | item_seq_len = interaction[self.ITEM_SEQ_LEN]
118 | seq_output = self.forward(x, edge_index, alias_inputs, item_seq_len)
119 | test_item_emb = F.normalize(self.item_embedding(test_item), dim=-1)
120 | scores = torch.mul(seq_output, test_item_emb).sum(dim=1) # [B]
121 | return scores
122 |
123 | def full_sort_predict(self, interaction):
124 | x = interaction['x']
125 | edge_index = interaction['edge_index']
126 | alias_inputs = interaction['alias_inputs']
127 | item_seq_len = interaction[self.ITEM_SEQ_LEN]
128 | seq_output = self.forward(x, edge_index, alias_inputs, item_seq_len)
129 | test_items_emb = F.normalize(self.item_embedding.weight, dim=-1)
130 | scores = torch.matmul(seq_output, test_items_emb.transpose(0, 1)) # [B, n_items]
131 | return scores
132 |
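The point of NISER's normalisation above is that, once both the session vector and the item embeddings have unit length, the matching score is a cosine similarity bounded in [-1, 1], and `sigma` stretches it back into a useful logit range before cross-entropy. A toy illustration with random tensors (sizes are made up):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
sigma = 16.0
seq_output = F.normalize(torch.randn(2, 64), dim=-1)   # unit-length session vectors
item_emb   = F.normalize(torch.randn(10, 64), dim=-1)  # unit-length item vectors

cosine = seq_output @ item_emb.t()      # values in [-1, 1] regardless of item popularity
logits = sigma * cosine                 # sigma restores a useful logit range for the CE loss
print(bool(cosine.abs().max() <= 1.0), logits.shape)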
--------------------------------------------------------------------------------
/recbole_gnn/model/sequential_recommender/sgnnhn.py:
--------------------------------------------------------------------------------
1 | # @Time : 2022/3/28
2 | # @Author : Yupeng Hou
3 | # @Email : houyupeng@ruc.edu.cn
4 |
5 | r"""
6 | SGNNHN
7 | ################################################
8 |
9 | Reference:
10 | Zhiqiang Pan et al. "Star Graph Neural Networks for Session-based Recommendation." in CIKM 2020.
11 |
12 | Reference code:
13 | https://bitbucket.org/nudtpanzq/sgnn-hn
14 |
15 | """
16 |
17 | import math
18 | import numpy as np
19 | import torch
20 | from torch import nn
21 | from torch_geometric.nn import global_mean_pool, global_add_pool
22 | from torch_geometric.utils import softmax
23 | from recbole.model.abstract_recommender import SequentialRecommender
24 | from recbole.model.loss import BPRLoss
25 |
26 | from recbole_gnn.model.layers import SRGNNCell
27 |
28 |
29 | def layer_norm(x):
30 | ave_x = torch.mean(x, -1).unsqueeze(-1)
31 | x = x - ave_x
32 | norm_x = torch.sqrt(torch.sum(x**2, -1)).unsqueeze(-1)
33 | y = x / norm_x
34 | return y
35 |
36 |
37 | class SGNNHN(SequentialRecommender):
38 | r"""SGNN-HN applies a star graph neural network to model the complex transition relationship between items in an ongoing session.
39 | To avoid overfitting, it applies highway networks to adaptively select embeddings from item representations.
40 | """
41 |
42 | def __init__(self, config, dataset):
43 | super(SGNNHN, self).__init__(config, dataset)
44 |
45 | # load parameters info
46 | self.embedding_size = config['embedding_size']
47 | self.step = config['step']
48 | self.device = config['device']
49 | self.loss_type = config['loss_type']
50 | self.scale = config['scale']
51 |
52 | # item embedding
53 | self.item_embedding = nn.Embedding(self.n_items, self.embedding_size, padding_idx=0)
54 | self.max_seq_length = dataset.field2seqlen[self.ITEM_SEQ]
55 | self.pos_embedding = nn.Embedding(self.max_seq_length, self.embedding_size)
56 |
57 | # define layers and loss
58 | self.gnncell = SRGNNCell(self.embedding_size)
59 | self.linear_one = nn.Linear(self.embedding_size, self.embedding_size)
60 | self.linear_two = nn.Linear(self.embedding_size, self.embedding_size)
61 | self.linear_three = nn.Linear(self.embedding_size, self.embedding_size)
62 | self.linear_four = nn.Linear(self.embedding_size, 1, bias=False)
63 | self.linear_transform = nn.Linear(self.embedding_size * 2, self.embedding_size)
64 | if self.loss_type == 'BPR':
65 | self.loss_fct = BPRLoss()
66 | elif self.loss_type == 'CE':
67 | self.loss_fct = nn.CrossEntropyLoss()
68 | else:
69 | raise NotImplementedError("Make sure 'loss_type' in ['BPR', 'CE']!")
70 |
71 | # parameters initialization
72 | self._reset_parameters()
73 |
74 | def _reset_parameters(self):
75 | stdv = 1.0 / np.sqrt(self.embedding_size)
76 | for weight in self.parameters():
77 | weight.data.uniform_(-stdv, stdv)
78 |
79 | def att_out(self, hidden, star_node, batch):
80 | star_node_repeat = torch.index_select(star_node, 0, batch)
81 | sim = (hidden * star_node_repeat).sum(dim=-1)
82 | sim = softmax(sim, batch)
83 | att_hidden = sim.unsqueeze(-1) * hidden
84 | output = global_add_pool(att_hidden, batch)
85 |
86 | return output
87 |
88 | def forward(self, x, edge_index, batch, alias_inputs, item_seq_len):
89 | mask = alias_inputs.gt(0)
90 | hidden = self.item_embedding(x)
91 |
92 | star_node = global_mean_pool(hidden, batch)
93 | for i in range(self.step):
94 | hidden = self.gnncell(hidden, edge_index)
95 | star_node_repeat = torch.index_select(star_node, 0, batch)
96 | sim = (hidden * star_node_repeat).sum(dim=-1, keepdim=True) / math.sqrt(self.embedding_size)
97 | alpha = torch.sigmoid(sim)
98 | hidden = (1 - alpha) * hidden + alpha * star_node_repeat
99 | star_node = self.att_out(hidden, star_node, batch)
100 |
101 | seq_hidden = hidden[alias_inputs]
102 | bs, item_num, _ = seq_hidden.shape
103 | pos_emb = self.pos_embedding.weight[:item_num]
104 | pos_emb = pos_emb.unsqueeze(0).expand(bs, -1, -1)
105 | seq_hidden = seq_hidden + pos_emb
106 |
107 | # fetch the last hidden state of last timestamp
108 | ht = self.gather_indexes(seq_hidden, item_seq_len - 1)
109 | q1 = self.linear_one(ht).view(ht.size(0), 1, ht.size(1))
110 | q2 = self.linear_two(seq_hidden)
111 | q3 = self.linear_three(star_node).view(star_node.shape[0], 1, star_node.shape[1])
112 |
113 | alpha = self.linear_four(torch.sigmoid(q1 + q2 + q3))
114 | a = torch.sum(alpha * seq_hidden * mask.view(mask.size(0), -1, 1).float(), 1)
115 | seq_output = self.linear_transform(torch.cat([a, ht], dim=1))
116 | return layer_norm(seq_output)
117 |
118 | def calculate_loss(self, interaction):
119 | x = interaction['x']
120 | edge_index = interaction['edge_index']
121 | batch = interaction['batch']
122 | alias_inputs = interaction['alias_inputs']
123 | item_seq_len = interaction[self.ITEM_SEQ_LEN]
124 | seq_output = self.forward(x, edge_index, batch, alias_inputs, item_seq_len)
125 | pos_items = interaction[self.POS_ITEM_ID]
126 | if self.loss_type == 'BPR':
127 | neg_items = interaction[self.NEG_ITEM_ID]
128 | pos_items_emb = layer_norm(self.item_embedding(pos_items))
129 | neg_items_emb = layer_norm(self.item_embedding(neg_items))
130 | pos_score = torch.sum(seq_output * pos_items_emb, dim=-1) * self.scale # [B]
131 | neg_score = torch.sum(seq_output * neg_items_emb, dim=-1) * self.scale # [B]
132 | loss = self.loss_fct(pos_score, neg_score)
133 | return loss
134 | else: # self.loss_type = 'CE'
135 | test_item_emb = layer_norm(self.item_embedding.weight)
136 | logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1)) * self.scale
137 | loss = self.loss_fct(logits, pos_items)
138 | return loss
139 |
140 | def predict(self, interaction):
141 | test_item = interaction[self.ITEM_ID]
142 | x = interaction['x']
143 | edge_index = interaction['edge_index']
144 | batch = interaction['batch']
145 | alias_inputs = interaction['alias_inputs']
146 | item_seq_len = interaction[self.ITEM_SEQ_LEN]
147 | seq_output = self.forward(x, edge_index, batch, alias_inputs, item_seq_len)
148 | test_item_emb = layer_norm(self.item_embedding(test_item))
149 | scores = torch.mul(seq_output, test_item_emb).sum(dim=1) * self.scale # [B]
150 | return scores
151 |
152 | def full_sort_predict(self, interaction):
153 | x = interaction['x']
154 | edge_index = interaction['edge_index']
155 | batch = interaction['batch']
156 | alias_inputs = interaction['alias_inputs']
157 | item_seq_len = interaction[self.ITEM_SEQ_LEN]
158 | seq_output = self.forward(x, edge_index, batch, alias_inputs, item_seq_len)
159 | test_items_emb = layer_norm(self.item_embedding.weight)
160 | scores = torch.matmul(seq_output, test_items_emb.transpose(0, 1)) * self.scale # [B, n_items]
161 | return scores
162 |
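Note that the module-level `layer_norm` above is not `nn.LayerNorm`: it centres each vector and then divides by the L2 norm of the centred result, so session and item representations are compared on the unit sphere before the `scale` factor is applied. A quick sanity check of that reading (an equivalent rewrite, not the code above):

import torch
import torch.nn.functional as F

def center_and_normalize(x):
    # equivalent rewrite of layer_norm above: subtract the mean, then L2-normalise
    x = x - x.mean(dim=-1, keepdim=True)
    return x / x.norm(dim=-1, keepdim=True)

x = torch.randn(4, 8)
print(torch.allclose(center_and_normalize(x), F.normalize(x - x.mean(-1, keepdim=True), dim=-1)))
# True: unlike nn.LayerNorm there is no division by the standard deviation and no affine parameters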
--------------------------------------------------------------------------------
/recbole_gnn/model/sequential_recommender/srgnn.py:
--------------------------------------------------------------------------------
1 | # @Time : 2022/3/7
2 | # @Author : Yupeng Hou
3 | # @Email : houyupeng@ruc.edu.cn
4 |
5 | r"""
6 | SRGNN
7 | ################################################
8 |
9 | Reference:
10 | Shu Wu et al. "Session-based Recommendation with Graph Neural Networks." in AAAI 2019.
11 |
12 | Reference code:
13 | https://github.com/CRIPAC-DIG/SR-GNN
14 |
15 | """
16 | import numpy as np
17 | import torch
18 | from torch import nn
19 | from recbole.model.loss import BPRLoss
20 | from recbole.model.abstract_recommender import SequentialRecommender
21 |
22 | from recbole_gnn.model.layers import SRGNNCell
23 |
24 |
25 | class SRGNN(SequentialRecommender):
26 | r"""SRGNN regards the conversation history as a directed graph.
27 | In addition to considering the connection between the item and the adjacent item,
28 | it also considers the connection with other interactive items.
29 |
30 | Such as: A example of a session sequence(eg:item1, item2, item3, item2, item4) and the connection matrix A
31 |
32 | Outgoing edges:
33 | === ===== ===== ===== =====
34 | \ 1 2 3 4
35 | === ===== ===== ===== =====
36 | 1 0 1 0 0
37 | 2 0 0 1/2 1/2
38 | 3 0 1 0 0
39 | 4 0 0 0 0
40 | === ===== ===== ===== =====
41 |
42 | Incoming edges:
43 | === ===== ===== ===== =====
44 | \ 1 2 3 4
45 | === ===== ===== ===== =====
46 | 1 0 0 0 0
47 | 2 1/2 0 1/2 0
48 | 3 0 1 0 0
49 | 4 0 1 0 0
50 | === ===== ===== ===== =====
51 | """
52 |
53 | def __init__(self, config, dataset):
54 | super(SRGNN, self).__init__(config, dataset)
55 |
56 | # load parameters info
57 | self.embedding_size = config['embedding_size']
58 | self.step = config['step']
59 | self.device = config['device']
60 | self.loss_type = config['loss_type']
61 |
62 | # item embedding
63 | self.item_embedding = nn.Embedding(self.n_items, self.embedding_size, padding_idx=0)
64 |
65 | # define layers and loss
66 | self.gnncell = SRGNNCell(self.embedding_size)
67 | self.linear_one = nn.Linear(self.embedding_size, self.embedding_size)
68 | self.linear_two = nn.Linear(self.embedding_size, self.embedding_size)
69 | self.linear_three = nn.Linear(self.embedding_size, 1, bias=False)
70 | self.linear_transform = nn.Linear(self.embedding_size * 2, self.embedding_size)
71 | if self.loss_type == 'BPR':
72 | self.loss_fct = BPRLoss()
73 | elif self.loss_type == 'CE':
74 | self.loss_fct = nn.CrossEntropyLoss()
75 | else:
76 | raise NotImplementedError("Make sure 'loss_type' in ['BPR', 'CE']!")
77 |
78 | # parameters initialization
79 | self._reset_parameters()
80 |
81 | def _reset_parameters(self):
82 | stdv = 1.0 / np.sqrt(self.embedding_size)
83 | for weight in self.parameters():
84 | weight.data.uniform_(-stdv, stdv)
85 |
86 | def forward(self, x, edge_index, alias_inputs, item_seq_len):
87 | mask = alias_inputs.gt(0)
88 | hidden = self.item_embedding(x)
89 | for i in range(self.step):
90 | hidden = self.gnncell(hidden, edge_index)
91 |
92 | seq_hidden = hidden[alias_inputs]
93 | # fetch the last hidden state of last timestamp
94 | ht = self.gather_indexes(seq_hidden, item_seq_len - 1)
95 | q1 = self.linear_one(ht).view(ht.size(0), 1, ht.size(1))
96 | q2 = self.linear_two(seq_hidden)
97 |
98 | alpha = self.linear_three(torch.sigmoid(q1 + q2))
99 | a = torch.sum(alpha * seq_hidden * mask.view(mask.size(0), -1, 1).float(), 1)
100 | seq_output = self.linear_transform(torch.cat([a, ht], dim=1))
101 | return seq_output
102 |
103 | def calculate_loss(self, interaction):
104 | x = interaction['x']
105 | edge_index = interaction['edge_index']
106 | alias_inputs = interaction['alias_inputs']
107 | item_seq_len = interaction[self.ITEM_SEQ_LEN]
108 | seq_output = self.forward(x, edge_index, alias_inputs, item_seq_len)
109 | pos_items = interaction[self.POS_ITEM_ID]
110 | if self.loss_type == 'BPR':
111 | neg_items = interaction[self.NEG_ITEM_ID]
112 | pos_items_emb = self.item_embedding(pos_items)
113 | neg_items_emb = self.item_embedding(neg_items)
114 | pos_score = torch.sum(seq_output * pos_items_emb, dim=-1) # [B]
115 | neg_score = torch.sum(seq_output * neg_items_emb, dim=-1) # [B]
116 | loss = self.loss_fct(pos_score, neg_score)
117 | return loss
118 | else: # self.loss_type = 'CE'
119 | test_item_emb = self.item_embedding.weight
120 | logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1))
121 | loss = self.loss_fct(logits, pos_items)
122 | return loss
123 |
124 | def predict(self, interaction):
125 | test_item = interaction[self.ITEM_ID]
126 | x = interaction['x']
127 | edge_index = interaction['edge_index']
128 | alias_inputs = interaction['alias_inputs']
129 | item_seq_len = interaction[self.ITEM_SEQ_LEN]
130 | seq_output = self.forward(x, edge_index, alias_inputs, item_seq_len)
131 | test_item_emb = self.item_embedding(test_item)
132 | scores = torch.mul(seq_output, test_item_emb).sum(dim=1) # [B]
133 | return scores
134 |
135 | def full_sort_predict(self, interaction):
136 | x = interaction['x']
137 | edge_index = interaction['edge_index']
138 | alias_inputs = interaction['alias_inputs']
139 | item_seq_len = interaction[self.ITEM_SEQ_LEN]
140 | seq_output = self.forward(x, edge_index, alias_inputs, item_seq_len)
141 | test_items_emb = self.item_embedding.weight
142 | scores = torch.matmul(seq_output, test_items_emb.transpose(0, 1)) # [B, n_items]
143 | return scores
144 |
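To connect the docstring's example with the tensors that `forward` consumes: each padded session is mapped to a small graph whose nodes are the distinct item ids and whose directed edges are the observed transitions, while `alias_inputs` maps every sequence position back to its node. The actual conversion is performed by the `sess_graph` transform in recbole_gnn/data/transform.py (not shown), so the sketch below is only an illustration of the layout for the session from the docstring; edge weighting and normalisation are omitted.

import torch

MAX_LEN = 6
session = [1, 2, 3, 2, 4]                          # the session from the docstring above
padded = session + [0] * (MAX_LEN - len(session))  # right-pad with the padding item 0
nodes = sorted(set(padded))                        # unique items (incl. padding) -> nodes [0, 1, 2, 3, 4]
node_id = {item: i for i, item in enumerate(nodes)}

x = torch.tensor(nodes)                                        # node features = item ids
edges = {(node_id[a], node_id[b]) for a, b in zip(session, session[1:])}
edge_index = torch.tensor(sorted(edges)).t()                   # transitions 1->2, 2->3, 3->2, 2->4
alias_inputs = torch.tensor([[node_id[i] for i in padded]])    # position -> node index, [1, MAX_LEN]
item_seq_len = torch.tensor([len(session)])

print(edge_index)
# tensor([[1, 2, 2, 3],
#         [2, 3, 4, 2]])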
--------------------------------------------------------------------------------
/recbole_gnn/model/sequential_recommender/tagnn.py:
--------------------------------------------------------------------------------
1 | # @Time : 2022/3/17
2 | # @Author : Yupeng Hou
3 | # @Email : houyupeng@ruc.edu.cn
4 |
5 | r"""
6 | TAGNN
7 | ################################################
8 |
9 | Reference:
10 | Feng Yu et al. "TAGNN: Target Attentive Graph Neural Networks for Session-based Recommendation." in SIGIR 2020 short.
11 | Implemented using PyTorch Geometric.
12 |
13 | Reference code:
14 | https://github.com/CRIPAC-DIG/TAGNN
15 |
16 | """
17 | import numpy as np
18 | import torch
19 | from torch import nn
20 | import torch.nn.functional as F
21 | from recbole.model.abstract_recommender import SequentialRecommender
22 |
23 | from recbole_gnn.model.layers import SRGNNCell
24 |
25 |
26 | class TAGNN(SequentialRecommender):
27 | r"""TAGNN introduces target-aware attention and adaptively activates different user interests with respect to varied target items.
28 | """
29 |
30 | def __init__(self, config, dataset):
31 | super(TAGNN, self).__init__(config, dataset)
32 |
33 | # load parameters info
34 | self.embedding_size = config['embedding_size']
35 | self.step = config['step']
36 | self.device = config['device']
37 | self.loss_type = config['loss_type']
38 |
39 | # item embedding
40 | self.item_embedding = nn.Embedding(self.n_items, self.embedding_size, padding_idx=0)
41 |
42 | # define layers and loss
43 | self.gnncell = SRGNNCell(self.embedding_size)
44 | self.linear_one = nn.Linear(self.embedding_size, self.embedding_size)
45 | self.linear_two = nn.Linear(self.embedding_size, self.embedding_size)
46 | self.linear_three = nn.Linear(self.embedding_size, 1, bias=False)
47 | self.linear_transform = nn.Linear(self.embedding_size * 2, self.embedding_size)
48 | self.linear_t = nn.Linear(self.embedding_size, self.embedding_size, bias=False) #target attention
49 | if self.loss_type == 'CE':
50 | self.loss_fct = nn.CrossEntropyLoss()
51 | else:
52 | raise NotImplementedError("Make sure 'loss_type' in ['CE']!")
53 |
54 | # parameters initialization
55 | self._reset_parameters()
56 |
57 | def _reset_parameters(self):
58 | stdv = 1.0 / np.sqrt(self.embedding_size)
59 | for weight in self.parameters():
60 | weight.data.uniform_(-stdv, stdv)
61 |
62 | def forward(self, x, edge_index, alias_inputs, item_seq_len):
63 | mask = alias_inputs.gt(0)
64 | hidden = self.item_embedding(x)
65 | for i in range(self.step):
66 | hidden = self.gnncell(hidden, edge_index)
67 |
68 | seq_hidden = hidden[alias_inputs]
69 | # fetch the last hidden state of last timestamp
70 | ht = self.gather_indexes(seq_hidden, item_seq_len - 1)
71 | q1 = self.linear_one(ht).view(ht.size(0), 1, ht.size(1))
72 | q2 = self.linear_two(seq_hidden)
73 |
74 | alpha = self.linear_three(torch.sigmoid(q1 + q2))
75 | alpha = F.softmax(alpha, 1)
76 | a = torch.sum(alpha * seq_hidden * mask.view(mask.size(0), -1, 1).float(), 1)
77 | seq_output = self.linear_transform(torch.cat([a, ht], dim=1))
78 |
79 | seq_hidden = seq_hidden * mask.view(mask.shape[0], -1, 1).float()
80 | qt = self.linear_t(seq_hidden)
81 | b = self.item_embedding.weight
82 | beta = F.softmax(b @ qt.transpose(1,2), -1)
83 | target = beta @ seq_hidden
84 | a = seq_output.view(ht.shape[0], 1, ht.shape[1]) # b,1,d
85 | a = a + target # b,n,d
86 | scores = torch.sum(a * b, -1) # b,n
87 | return scores
88 |
89 | def calculate_loss(self, interaction):
90 | x = interaction['x']
91 | edge_index = interaction['edge_index']
92 | alias_inputs = interaction['alias_inputs']
93 | item_seq_len = interaction[self.ITEM_SEQ_LEN]
94 | logits = self.forward(x, edge_index, alias_inputs, item_seq_len)
95 | pos_items = interaction[self.POS_ITEM_ID]
96 | loss = self.loss_fct(logits, pos_items)
97 | return loss
98 |
99 | def predict(self, interaction):
100 | pass
101 |
102 | def full_sort_predict(self, interaction):
103 | x = interaction['x']
104 | edge_index = interaction['edge_index']
105 | alias_inputs = interaction['alias_inputs']
106 | item_seq_len = interaction[self.ITEM_SEQ_LEN]
107 | scores = self.forward(x, edge_index, alias_inputs, item_seq_len)
108 | return scores
109 |
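The target-attention block in `forward` gives every candidate item its own attention distribution over the session positions, so the session is summarised differently per target before scoring. A shape-only sketch with random tensors (masking and the `linear_t` projection are left out):

import torch
import torch.nn.functional as F

B, L, n_items, d = 2, 5, 100, 64            # toy sizes: batch, seq length, catalogue, embedding
seq_hidden = torch.randn(B, L, d)           # per-position session states
seq_output = torch.randn(B, d)              # the order-aware session vector
item_emb = torch.randn(n_items, d)          # candidate ("target") item embeddings b

qt = seq_hidden                                            # stands in for linear_t(seq_hidden)
beta = F.softmax(item_emb @ qt.transpose(1, 2), dim=-1)    # [B, n_items, L]: each target attends over positions
target = beta @ seq_hidden                                 # [B, n_items, d]: a different session summary per target
scores = ((seq_output.unsqueeze(1) + target) * item_emb).sum(-1)   # [B, n_items]
print(scores.shape)  # torch.Size([2, 100])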
--------------------------------------------------------------------------------
/recbole_gnn/model/social_recommender/__init__.py:
--------------------------------------------------------------------------------
1 | from recbole_gnn.model.social_recommender.diffnet import DiffNet
2 | from recbole_gnn.model.social_recommender.mhcn import MHCN
3 | from recbole_gnn.model.social_recommender.sept import SEPT
--------------------------------------------------------------------------------
/recbole_gnn/model/social_recommender/diffnet.py:
--------------------------------------------------------------------------------
1 | # @Time : 2022/3/15
2 | # @Author : Lanling Xu
3 | # @Email : xulanling_sherry@163.com
4 |
5 | r"""
6 | DiffNet
7 | ################################################
8 | Reference:
9 | Le Wu et al. "A Neural Influence Diffusion Model for Social Recommendation." in SIGIR 2019.
10 |
11 | Reference code:
12 | https://github.com/PeiJieSun/diffnet
13 | """
14 |
15 | import numpy as np
16 | import torch
17 | import torch.nn as nn
18 |
19 | from recbole.model.init import xavier_uniform_initialization
20 | from recbole.model.loss import BPRLoss, EmbLoss
21 | from recbole.utils import InputType
22 |
23 | from recbole_gnn.model.abstract_recommender import SocialRecommender
24 | from recbole_gnn.model.layers import BipartiteGCNConv
25 |
26 |
27 | class DiffNet(SocialRecommender):
28 | r"""DiffNet is a deep influence propagation model to stimulate how users are influenced by the recursive social diffusion process for social recommendation.
29 | We implement the model following the original author with a pairwise training mode.
30 | """
31 | input_type = InputType.PAIRWISE
32 |
33 | def __init__(self, config, dataset):
34 | super(DiffNet, self).__init__(config, dataset)
35 |
36 | # load dataset info
37 | self.edge_index, self.edge_weight = dataset.get_bipartite_inter_mat(row='user')
38 | self.edge_index, self.edge_weight = self.edge_index.to(self.device), self.edge_weight.to(self.device)
39 |
40 | self.net_edge_index, self.net_edge_weight = dataset.get_norm_net_adj_mat(row_norm=True)
41 | self.net_edge_index, self.net_edge_weight = self.net_edge_index.to(self.device), self.net_edge_weight.to(self.device)
42 |
43 | # load parameters info
44 | self.embedding_size = config['embedding_size'] # int type:the embedding size of DiffNet
45 | self.n_layers = config['n_layers'] # int type:the GCN layer num of DiffNet for social net
46 | self.reg_weight = config['reg_weight'] # float32 type: the weight decay for l2 normalization
47 | self.pretrained_review = config['pretrained_review'] # bool type:whether to load pre-trained review vectors of users and items
48 |
49 | # define layers and loss
50 | self.user_embedding = torch.nn.Embedding(num_embeddings=self.n_users, embedding_dim=self.embedding_size)
51 | self.item_embedding = torch.nn.Embedding(num_embeddings=self.n_items, embedding_dim=self.embedding_size)
52 | self.bipartite_gcn_conv = BipartiteGCNConv(dim=self.embedding_size)
53 | self.mf_loss = BPRLoss()
54 | self.reg_loss = EmbLoss()
55 |
56 | # storage variables for full sort evaluation acceleration
57 | self.restore_user_e = None
58 | self.restore_item_e = None
59 |
60 | # parameters initialization
61 | self.apply(xavier_uniform_initialization)
62 | self.other_parameter_name = ['restore_user_e', 'restore_item_e']
63 |
64 | if self.pretrained_review:
65 | # handle review information, map the origin review into the new space
66 | self.user_review_embedding = nn.Embedding(self.n_users, self.embedding_size, padding_idx=0)
67 | self.user_review_embedding.weight.requires_grad = False
68 | self.user_review_embedding.weight.data.copy_(self.convertDistribution(dataset.user_feat['user_review_emb']))
69 |
70 | self.item_review_embedding = nn.Embedding(self.n_items, self.embedding_size, padding_idx=0)
71 | self.item_review_embedding.weight.requires_grad = False
72 | self.item_review_embedding.weight.data.copy_(self.convertDistribution(dataset.item_feat['item_review_emb']))
73 |
74 | self.user_fusion_layer = nn.Linear(self.embedding_size, self.embedding_size)
75 | self.item_fusion_layer = nn.Linear(self.embedding_size, self.embedding_size)
76 | self.activation = nn.Sigmoid()
77 |
78 | def convertDistribution(self, x):
79 | mean, std = torch.mean(x), torch.std(x)
80 | y = (x - mean) * 0.2 / std
81 | return y
82 |
83 | def forward(self):
84 | user_embedding = self.user_embedding.weight
85 | final_item_embedding = self.item_embedding.weight
86 |
87 | if self.pretrained_review:
88 | user_reduce_dim_vector_matrix = self.activation(self.user_fusion_layer(self.user_review_embedding.weight))
89 | item_reduce_dim_vector_matrix = self.activation(self.item_fusion_layer(self.item_review_embedding.weight))
90 |
91 | user_review_vector_matrix = self.convertDistribution(user_reduce_dim_vector_matrix)
92 | item_review_vector_matrix = self.convertDistribution(item_reduce_dim_vector_matrix)
93 |
94 | user_embedding = user_embedding + user_review_vector_matrix
95 | final_item_embedding = final_item_embedding + item_review_vector_matrix
96 |
97 | user_embedding_from_consumed_items = self.bipartite_gcn_conv(x=(final_item_embedding, user_embedding), edge_index=self.edge_index.flip([0]), edge_weight=self.edge_weight, size=(self.n_items, self.n_users))
98 |
99 | embeddings_list = [user_embedding]
100 | for layer_idx in range(self.n_layers):
101 | user_embedding = self.bipartite_gcn_conv((user_embedding, user_embedding), self.net_edge_index.flip([0]), self.net_edge_weight, size=(self.n_users, self.n_users))
102 | embeddings_list.append(user_embedding)
103 | final_user_embedding = torch.stack(embeddings_list, dim=1)
104 | final_user_embedding = torch.sum(final_user_embedding, dim=1) + user_embedding_from_consumed_items
105 |
106 | return final_user_embedding, final_item_embedding
107 |
108 | def calculate_loss(self, interaction):
109 | # clear the storage variable when training
110 | if self.restore_user_e is not None or self.restore_item_e is not None:
111 | self.restore_user_e, self.restore_item_e = None, None
112 |
113 | user = interaction[self.USER_ID]
114 | pos_item = interaction[self.ITEM_ID]
115 | neg_item = interaction[self.NEG_ITEM_ID]
116 |
117 | user_all_embeddings, item_all_embeddings = self.forward()
118 | u_embeddings = user_all_embeddings[user]
119 | pos_embeddings = item_all_embeddings[pos_item]
120 | neg_embeddings = item_all_embeddings[neg_item]
121 |
122 | # calculate BPR Loss
123 | pos_scores = torch.mul(u_embeddings, pos_embeddings).sum(dim=1)
124 | neg_scores = torch.mul(u_embeddings, neg_embeddings).sum(dim=1)
125 | mf_loss = self.mf_loss(pos_scores, neg_scores)
126 |
127 | # calculate regularization Loss
128 | u_ego_embeddings = self.user_embedding(user)
129 | pos_ego_embeddings = self.item_embedding(pos_item)
130 | neg_ego_embeddings = self.item_embedding(neg_item)
131 |
132 | reg_loss = self.reg_loss(u_ego_embeddings, pos_ego_embeddings, neg_ego_embeddings)
133 | loss = mf_loss + self.reg_weight * reg_loss
134 |
135 | return loss
136 |
137 | def predict(self, interaction):
138 | user = interaction[self.USER_ID]
139 | item = interaction[self.ITEM_ID]
140 |
141 | user_all_embeddings, item_all_embeddings = self.forward()
142 |
143 | u_embeddings = user_all_embeddings[user]
144 | i_embeddings = item_all_embeddings[item]
145 | scores = torch.mul(u_embeddings, i_embeddings).sum(dim=1)
146 | return scores
147 |
148 | def full_sort_predict(self, interaction):
149 | user = interaction[self.USER_ID]
150 | if self.restore_user_e is None or self.restore_item_e is None:
151 | self.restore_user_e, self.restore_item_e = self.forward()
152 | # get user embedding from storage variable
153 | u_embeddings = self.restore_user_e[user]
154 |
155 | # dot with all item embedding to accelerate
156 | scores = torch.matmul(u_embeddings, self.restore_item_e.transpose(0, 1))
157 |
158 | return scores.view(-1)
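`convertDistribution` rescales the pre-trained review vectors to zero mean and a standard deviation of 0.2 (computed over the whole matrix), so that when `pretrained_review` is enabled they are on a scale comparable to the learned embeddings. A quick check:

import torch

def convert_distribution(x):       # same transform as DiffNet.convertDistribution above
    return (x - x.mean()) * 0.2 / x.std()

review_emb = torch.randn(1000, 64) * 5.0 + 3.0     # pretrained vectors with an arbitrary scale
y = convert_distribution(review_emb)
print(round(y.mean().item(), 3), round(y.std().item(), 3))   # ~0.0, ~0.2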
--------------------------------------------------------------------------------
/recbole_gnn/properties/model/DiffNet.yaml:
--------------------------------------------------------------------------------
1 | embedding_size: 64
2 | n_layers: 2
3 | reg_weight: 1e-05
4 | pretrained_review: False
--------------------------------------------------------------------------------
/recbole_gnn/properties/model/DirectAU.yaml:
--------------------------------------------------------------------------------
1 | embedding_size: 64
2 | encoder: "MF" # "MF" or "lightGCN"
3 | gamma: 0.5
4 | weight_decay: 1e-6
5 | train_batch_size: 256
6 |
7 | # n_layers: 3 # needed for LightGCN
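As the comment notes, `n_layers` only becomes necessary once the encoder is switched from MF to the LightGCN variant. One way to supply both settings without editing the yaml is a `config_dict` override; the dataset name below is just a placeholder and the encoder spelling follows the comment above.

from recbole_gnn.quick_start import run_recbole_gnn

run_recbole_gnn(
    model='DirectAU',
    dataset='ml-100k',                                     # placeholder dataset name
    config_dict={'encoder': 'lightGCN', 'n_layers': 3},    # the graph encoder needs a layer count
)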
--------------------------------------------------------------------------------
/recbole_gnn/properties/model/GCEGNN.yaml:
--------------------------------------------------------------------------------
1 | embedding_size: 64
2 | leakyrelu_alpha: 0.2
3 | dropout_local: 0.
4 | dropout_global: 0.5
5 | dropout_gcn: 0.
6 | loss_type: CE
7 | gnn_transform: sess_graph
8 |
9 | # global
10 | build_global_graph: True
11 | sample_num: 12
12 | hop: 1
13 |
--------------------------------------------------------------------------------
/recbole_gnn/properties/model/GCSAN.yaml:
--------------------------------------------------------------------------------
1 | n_layers: 1
2 | n_heads: 1
3 | hidden_size: 64
4 | inner_size: 256
5 | hidden_dropout_prob: 0.2
6 | attn_dropout_prob: 0.2
7 | hidden_act: 'gelu'
8 | layer_norm_eps: 1e-12
9 | initializer_range: 0.02
10 | step: 1
11 | weight: 0.6
12 | reg_weight: 5e-5
13 | loss_type: 'CE'
14 | gnn_transform: sess_graph
15 |
--------------------------------------------------------------------------------
/recbole_gnn/properties/model/HMLET.yaml:
--------------------------------------------------------------------------------
1 | embedding_size: 64
2 | n_layers: 4
3 | reg_weight: 1e-05
4 | require_pow: True
5 | gate_layer_ids: [2,3]
6 | gating_mlp_dims: [64,16,2]
7 | dropout_ratio: 0.2
8 | activation_function: elu
9 |
10 | warm_up_epochs: 50
11 | ori_temp: 0.7
12 | min_temp: 0.01
13 | gum_temp_decay: 0.005
14 | epoch_temp_decay: 1
15 |
--------------------------------------------------------------------------------
/recbole_gnn/properties/model/LESSR.yaml:
--------------------------------------------------------------------------------
1 | embedding_size: 64
2 | n_layers: 4
3 | batch_norm: True
4 | feat_drop: 0.2
5 | loss_type: CE
6 | gnn_transform: sess_graph
7 |
--------------------------------------------------------------------------------
/recbole_gnn/properties/model/LightGCL.yaml:
--------------------------------------------------------------------------------
1 | embedding_size: 64 # (int) The embedding size of users and items.
2 | n_layers: 2 # (int) The number of layers in LightGCL.
3 | dropout: 0.0 # (float) The dropout ratio.
4 | temp: 0.8 # (float) The temperature in softmax.
5 | lambda1: 0.01 # (float) The hyperparameter to control the strengths of SSL.
6 | lambda2: 1e-05 # (float) The L2 regularization weight.
7 | q: 5 # (int) A slightly overestimated rank of the adjacency matrix.
--------------------------------------------------------------------------------
/recbole_gnn/properties/model/LightGCN.yaml:
--------------------------------------------------------------------------------
1 | embedding_size: 64
2 | n_layers: 2
3 | reg_weight: 1e-05
4 | require_pow: True
--------------------------------------------------------------------------------
/recbole_gnn/properties/model/MHCN.yaml:
--------------------------------------------------------------------------------
1 | embedding_size: 64
2 | n_layers: 2
3 | ssl_reg: 1e-05
4 | reg_weight: 1e-05
--------------------------------------------------------------------------------
/recbole_gnn/properties/model/NCL.yaml:
--------------------------------------------------------------------------------
1 | embedding_size: 64
2 | n_layers: 3
3 | reg_weight: 1e-4
4 |
5 | ssl_temp: 0.1
6 | ssl_reg: 1e-7
7 | hyper_layers: 1
8 |
9 | alpha: 1
10 |
11 | proto_reg: 8e-8
12 | num_clusters: 1000
13 |
14 | m_step: 1
15 | warm_up_step: 20
16 |
--------------------------------------------------------------------------------
/recbole_gnn/properties/model/NGCF.yaml:
--------------------------------------------------------------------------------
1 | embedding_size: 64
2 | hidden_size_list: [64,64,64]
3 | node_dropout: 0.0
4 | message_dropout: 0.1
5 | reg_weight: 1e-5
--------------------------------------------------------------------------------
/recbole_gnn/properties/model/NISER.yaml:
--------------------------------------------------------------------------------
1 | embedding_size: 64
2 | step: 1
3 | sigma: 16
4 | item_dropout: 0.1
5 | loss_type: 'CE'
6 | gnn_transform: sess_graph
--------------------------------------------------------------------------------
/recbole_gnn/properties/model/SEPT.yaml:
--------------------------------------------------------------------------------
1 | warm_up_epochs: 100
2 | embedding_size: 64
3 | n_layers: 2
4 | drop_ratio: 0.3
5 | instance_cnt: 10
6 | reg_weight: 1e-05
7 | ssl_weight: 1e-07
8 | ssl_tau: 0.1
--------------------------------------------------------------------------------
/recbole_gnn/properties/model/SGL.yaml:
--------------------------------------------------------------------------------
1 | type: "ED"
2 | n_layers: 3
3 | ssl_tau: 0.5
4 | reg_weight: 1e-5
5 | ssl_weight: 0.05
6 | drop_ratio: 0.1
7 | embedding_size: 64
--------------------------------------------------------------------------------
/recbole_gnn/properties/model/SGNNHN.yaml:
--------------------------------------------------------------------------------
1 | embedding_size: 64
2 | step: 6
3 | scale: 12
4 | loss_type: 'CE'
5 | gnn_transform: sess_graph
6 |
--------------------------------------------------------------------------------
/recbole_gnn/properties/model/SRGNN.yaml:
--------------------------------------------------------------------------------
1 | embedding_size: 64
2 | step: 1
3 | loss_type: 'CE'
4 | gnn_transform: sess_graph
--------------------------------------------------------------------------------
/recbole_gnn/properties/model/SSL4REC.yaml:
--------------------------------------------------------------------------------
1 | embedding_size: 64
2 | drop_ratio: 0.1
3 | tau: 0.1
4 | reg_weight: 1e-04
5 | ssl_weight: 1e-05
6 | require_pow: True
--------------------------------------------------------------------------------
/recbole_gnn/properties/model/SimGCL.yaml:
--------------------------------------------------------------------------------
1 | embedding_size: 64
2 | n_layers: 2
3 | reg_weight: 1e-4
4 |
5 | lambda: 0.5
6 | eps: 0.1
7 | temperature: 0.2
8 |
--------------------------------------------------------------------------------
/recbole_gnn/properties/model/TAGNN.yaml:
--------------------------------------------------------------------------------
1 | embedding_size: 64
2 | step: 1
3 | loss_type: 'CE'
4 | gnn_transform: sess_graph
5 |
--------------------------------------------------------------------------------
/recbole_gnn/properties/model/XSimGCL.yaml:
--------------------------------------------------------------------------------
1 | embedding_size: 64
2 | n_layers: 2
3 | reg_weight: 0.0001
4 |
5 | lambda: 0.1
6 | eps: 0.2
7 | temperature: 0.2
8 | layer_cl: 1
9 | require_pow: True
10 |
--------------------------------------------------------------------------------
/recbole_gnn/properties/quick_start_config/sequential_base.yaml:
--------------------------------------------------------------------------------
1 | train_neg_sample_args: ~
2 |
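`~` is YAML's null, so this quick-start base config disables training-time negative sampling; the session-based models above default to `loss_type: 'CE'`, which scores the full item set and needs no sampled negatives. The same switch expressed as a Python override (hypothetical variable name):

# equivalent to sequential_base.yaml: YAML's `~` and Python's None both disable negative sampling
sequential_base_overrides = {'train_neg_sample_args': None}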
--------------------------------------------------------------------------------
/recbole_gnn/properties/quick_start_config/social_base.yaml:
--------------------------------------------------------------------------------
1 | NET_SOURCE_ID_FIELD: source_id
2 | NET_TARGET_ID_FIELD: target_id
3 |
4 | load_col:
5 | inter: ['user_id', 'item_id', 'rating', 'timestamp']
6 | net: [source_id, target_id]
7 |
8 | filter_net_by_inter: True
9 | undirected_net: True
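`load_col` expects the social network in a `<dataset>.net` file following RecBole's atomic-file convention: tab-separated columns whose header carries the field name and type (see tests/test_data/test/test.net). A sketch of a minimal file matching the field names above (the user ids and file name are made up). As the flag names suggest, `undirected_net` treats each link as mutual and `filter_net_by_inter` drops links whose users never appear in the interaction file; the exact behaviour is defined in recbole_gnn's dataset code.

rows = [
    "source_id:token\ttarget_id:token",   # header: field name + feature type
    "u1\tu2",
    "u2\tu3",
]
with open("toy.net", "w") as f:           # hypothetical file name
    f.write("\n".join(rows) + "\n")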
--------------------------------------------------------------------------------
/recbole_gnn/quick_start.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from logging import getLogger
3 | from recbole.utils import init_logger, init_seed, set_color
4 |
5 | from recbole_gnn.config import Config
6 | from recbole_gnn.utils import create_dataset, data_preparation, get_model, get_trainer
7 |
8 |
9 | def run_recbole_gnn(model=None, dataset=None, config_file_list=None, config_dict=None, saved=True):
10 | r""" A fast running api, which includes the complete process of
11 | training and testing a model on a specified dataset
12 | Args:
13 | model (str, optional): Model name. Defaults to ``None``.
14 | dataset (str, optional): Dataset name. Defaults to ``None``.
15 | config_file_list (list, optional): Config files used to modify experiment parameters. Defaults to ``None``.
16 | config_dict (dict, optional): Parameters dictionary used to modify experiment parameters. Defaults to ``None``.
17 | saved (bool, optional): Whether to save the model. Defaults to ``True``.
18 | """
19 | # configurations initialization
20 | config = Config(model=model, dataset=dataset, config_file_list=config_file_list, config_dict=config_dict)
21 | try:
22 | assert config["enable_sparse"] in [True, False, None]
23 | except AssertionError:
24 | raise ValueError("Your config `enable_sparse` must be `True` or `False` or `None`")
25 | init_seed(config['seed'], config['reproducibility'])
26 | # logger initialization
27 | init_logger(config)
28 | logger = getLogger()
29 |
30 | logger.info(config)
31 |
32 | # dataset filtering
33 | dataset = create_dataset(config)
34 | logger.info(dataset)
35 |
36 | # dataset splitting
37 | train_data, valid_data, test_data = data_preparation(config, dataset)
38 |
39 | # model loading and initialization
40 | init_seed(config['seed'], config['reproducibility'])
41 | model = get_model(config['model'])(config, train_data.dataset).to(config['device'])
42 | logger.info(model)
43 |
44 | # trainer loading and initialization
45 | trainer = get_trainer(config['MODEL_TYPE'], config['model'])(config, model)
46 |
47 | # model training
48 | best_valid_score, best_valid_result = trainer.fit(
49 | train_data, valid_data, saved=saved, show_progress=config['show_progress']
50 | )
51 |
52 | # model evaluation
53 | test_result = trainer.evaluate(test_data, load_best_model=saved, show_progress=config['show_progress'])
54 |
55 | logger.info(set_color('best valid ', 'yellow') + f': {best_valid_result}')
56 | logger.info(set_color('test result', 'yellow') + f': {test_result}')
57 |
58 | return {
59 | 'best_valid_score': best_valid_score,
60 | 'valid_score_bigger': config['valid_metric_bigger'],
61 | 'best_valid_result': best_valid_result,
62 | 'test_result': test_result
63 | }
64 |
65 |
66 | def objective_function(config_dict=None, config_file_list=None, saved=True):
67 | r""" The default objective_function used in HyperTuning
68 |
69 | Args:
70 | config_dict (dict, optional): Parameters dictionary used to modify experiment parameters. Defaults to ``None``.
71 | config_file_list (list, optional): Config files used to modify experiment parameters. Defaults to ``None``.
72 | saved (bool, optional): Whether to save the model. Defaults to ``True``.
73 | """
74 |
75 | config = Config(config_dict=config_dict, config_file_list=config_file_list)
76 | try:
77 | assert config["enable_sparse"] in [True, False, None]
78 | except AssertionError:
79 | raise ValueError("Your config `enable_sparse` must be `True` or `False` or `None`")
80 | init_seed(config['seed'], config['reproducibility'])
81 | logging.basicConfig(level=logging.ERROR)
82 | dataset = create_dataset(config)
83 | train_data, valid_data, test_data = data_preparation(config, dataset)
84 | init_seed(config['seed'], config['reproducibility'])
85 | model = get_model(config['model'])(config, train_data.dataset).to(config['device'])
86 | trainer = get_trainer(config['MODEL_TYPE'], config['model'])(config, model)
87 | best_valid_score, best_valid_result = trainer.fit(train_data, valid_data, verbose=False, saved=saved)
88 | test_result = trainer.evaluate(test_data, load_best_model=saved)
89 |
90 | return {
91 | 'model': config['model'],
92 | 'best_valid_score': best_valid_score,
93 | 'valid_score_bigger': config['valid_metric_bigger'],
94 | 'best_valid_result': best_valid_result,
95 | 'test_result': test_result
96 | }
97 |
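Both entry points are plain functions, so a run can be launched from a script or a notebook; `run_recbole_gnn` returns the best validation and test results as a dict. A minimal sketch (dataset name and overrides are placeholders):

from recbole_gnn.quick_start import run_recbole_gnn

result = run_recbole_gnn(
    model='LightGCN',
    dataset='ml-100k',                                   # placeholder dataset name
    config_dict={'epochs': 1, 'enable_sparse': True},    # any config key can be overridden here
)
print(result['best_valid_score'], result['test_result'])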
--------------------------------------------------------------------------------
/recbole_gnn/trainer.py:
--------------------------------------------------------------------------------
1 | from time import time
2 | import math
3 | from torch.nn.utils.clip_grad import clip_grad_norm_
4 | from tqdm import tqdm
5 | from recbole.trainer import Trainer
6 | from recbole.utils import early_stopping, dict2str, set_color, get_gpu_usage
7 |
8 |
9 | class NCLTrainer(Trainer):
10 | def __init__(self, config, model):
11 | super(NCLTrainer, self).__init__(config, model)
12 |
13 | self.num_m_step = config['m_step']
14 | assert self.num_m_step is not None
15 |
16 | def fit(self, train_data, valid_data=None, verbose=True, saved=True, show_progress=False, callback_fn=None):
17 | r"""Train the model based on the train data and the valid data.
18 | Args:
19 | train_data (DataLoader): the train data
20 | valid_data (DataLoader, optional): the valid data, default: None.
21 | If it is None, early stopping is disabled.
22 | verbose (bool, optional): whether to write training and evaluation information to logger, default: True
23 | saved (bool, optional): whether to save the model parameters, default: True
24 | show_progress (bool): Show the progress of training epoch and evaluate epoch. Defaults to ``False``.
25 | callback_fn (callable): Optional callback function executed at end of epoch.
26 | Includes (epoch_idx, valid_score) input arguments.
27 | Returns:
28 | (float, dict): best valid score and best valid result. If valid_data is None, it returns (-1, None)
29 | """
30 | if saved and self.start_epoch >= self.epochs:
31 | self._save_checkpoint(-1)
32 |
33 | self.eval_collector.data_collect(train_data)
34 |
35 | for epoch_idx in range(self.start_epoch, self.epochs):
36 |
37 | # the only difference from the original trainer
38 | if epoch_idx % self.num_m_step == 0:
39 | self.logger.info("Running E-step!")
40 | self.model.e_step()
41 | # train
42 | training_start_time = time()
43 | train_loss = self._train_epoch(train_data, epoch_idx, show_progress=show_progress)
44 | self.train_loss_dict[epoch_idx] = sum(train_loss) if isinstance(train_loss, tuple) else train_loss
45 | training_end_time = time()
46 | train_loss_output = \
47 | self._generate_train_loss_output(epoch_idx, training_start_time, training_end_time, train_loss)
48 | if verbose:
49 | self.logger.info(train_loss_output)
50 | self._add_train_loss_to_tensorboard(epoch_idx, train_loss)
51 |
52 | # eval
53 | if self.eval_step <= 0 or not valid_data:
54 | if saved:
55 | self._save_checkpoint(epoch_idx)
56 | update_output = set_color('Saving current', 'blue') + ': %s' % self.saved_model_file
57 | if verbose:
58 | self.logger.info(update_output)
59 | continue
60 | if (epoch_idx + 1) % self.eval_step == 0:
61 | valid_start_time = time()
62 | valid_score, valid_result = self._valid_epoch(valid_data, show_progress=show_progress)
63 | self.best_valid_score, self.cur_step, stop_flag, update_flag = early_stopping(
64 | valid_score,
65 | self.best_valid_score,
66 | self.cur_step,
67 | max_step=self.stopping_step,
68 | bigger=self.valid_metric_bigger
69 | )
70 | valid_end_time = time()
71 | valid_score_output = (set_color("epoch %d evaluating", 'green') + " [" + set_color("time", 'blue')
72 | + ": %.2fs, " + set_color("valid_score", 'blue') + ": %f]") % \
73 | (epoch_idx, valid_end_time - valid_start_time, valid_score)
74 | valid_result_output = set_color('valid result', 'blue') + ': \n' + dict2str(valid_result)
75 | if verbose:
76 | self.logger.info(valid_score_output)
77 | self.logger.info(valid_result_output)
78 |                 self.tensorboard.add_scalar('Valid_score', valid_score, epoch_idx)
79 |
80 | if update_flag:
81 | if saved:
82 | self._save_checkpoint(epoch_idx)
83 | update_output = set_color('Saving current best', 'blue') + ': %s' % self.saved_model_file
84 | if verbose:
85 | self.logger.info(update_output)
86 | self.best_valid_result = valid_result
87 |
88 | if callback_fn:
89 | callback_fn(epoch_idx, valid_score)
90 |
91 | if stop_flag:
92 | stop_output = 'Finished training, best eval result in epoch %d' % \
93 | (epoch_idx - self.cur_step * self.eval_step)
94 | if verbose:
95 | self.logger.info(stop_output)
96 | break
97 | self._add_hparam_to_tensorboard(self.best_valid_score)
98 | return self.best_valid_score, self.best_valid_result
99 |
100 | def _train_epoch(self, train_data, epoch_idx, loss_func=None, show_progress=False):
101 | r"""Train the model in an epoch
102 | Args:
103 | train_data (DataLoader): The train data.
104 | epoch_idx (int): The current epoch id.
105 | loss_func (function): The loss function of :attr:`model`. If it is ``None``, the loss function will be
106 | :attr:`self.model.calculate_loss`. Defaults to ``None``.
107 | show_progress (bool): Show the progress of training epoch. Defaults to ``False``.
108 | Returns:
109 |             float/tuple: The sum of the loss over all batches in this epoch. If the loss of each batch contains
110 |             multiple parts and the model returns these parts separately instead of their sum, this method returns a
111 |             tuple with the epoch-level sum of each part.
112 | """
113 | self.model.train()
114 | loss_func = loss_func or self.model.calculate_loss
115 | total_loss = None
116 | iter_data = (
117 | tqdm(
118 | train_data,
119 | total=len(train_data),
120 | ncols=100,
121 | desc=set_color(f"Train {epoch_idx:>5}", 'pink'),
122 | ) if show_progress else train_data
123 | )
124 | for batch_idx, interaction in enumerate(iter_data):
125 | interaction = interaction.to(self.device)
126 | self.optimizer.zero_grad()
127 | losses = loss_func(interaction)
128 | if isinstance(losses, tuple):
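                    # During the warm-up epochs, drop the last loss term (the prototype-level contrastive loss in NCL) and optimize only the remaining objectives.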
129 | if epoch_idx < self.config['warm_up_step']:
130 | losses = losses[:-1]
131 | loss = sum(losses)
132 | loss_tuple = tuple(per_loss.item() for per_loss in losses)
133 | total_loss = loss_tuple if total_loss is None else tuple(map(sum, zip(total_loss, loss_tuple)))
134 | else:
135 | loss = losses
136 | total_loss = losses.item() if total_loss is None else total_loss + losses.item()
137 | self._check_nan(loss)
138 | loss.backward()
139 | if self.clip_grad_norm:
140 | clip_grad_norm_(self.model.parameters(), **self.clip_grad_norm)
141 | self.optimizer.step()
142 | if self.gpu_available and show_progress:
143 | iter_data.set_postfix_str(set_color('GPU RAM: ' + get_gpu_usage(self.device), 'yellow'))
144 | return total_loss
145 |
146 |
147 | class HMLETTrainer(Trainer):
148 | def __init__(self, config, model):
149 | super(HMLETTrainer, self).__init__(config, model)
150 |
151 | self.warm_up_epochs = config['warm_up_epochs']
152 | self.ori_temp = config['ori_temp']
153 | self.min_temp = config['min_temp']
154 | self.gum_temp_decay = config['gum_temp_decay']
155 | self.epoch_temp_decay = config['epoch_temp_decay']
156 |
157 | def _train_epoch(self, train_data, epoch_idx, loss_func=None, show_progress=False):
158 | if epoch_idx > self.warm_up_epochs:
159 |             # Exponentially decay the Gumbel-Softmax temperature after warm-up, clipped below at min_temp
160 | gum_temp = self.ori_temp * math.exp(-self.gum_temp_decay*(epoch_idx - self.warm_up_epochs))
161 | self.model.gum_temp = max(gum_temp, self.min_temp)
162 | self.logger.info(f'Current gumbel softmax temperature: {self.model.gum_temp}')
163 |
164 | for gating in self.model.gating_nets:
165 | self.model._gating_freeze(gating, True)
166 | return super()._train_epoch(train_data, epoch_idx, loss_func, show_progress)
167 |
168 |
169 | class SEPTTrainer(Trainer):
170 | def __init__(self, config, model):
171 | super(SEPTTrainer, self).__init__(config, model)
172 | self.warm_up_epochs = config['warm_up_epochs']
173 |
174 | def _train_epoch(self, train_data, epoch_idx, loss_func=None, show_progress=False):
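        # During warm-up, optimize only the recommendation loss; afterwards, rebuild SEPT's self-supervised subgraphs before each epoch's full training objective.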
175 | if epoch_idx < self.warm_up_epochs:
176 | loss_func = self.model.calculate_rec_loss
177 | else:
178 | self.model.subgraph_construction()
179 | return super()._train_epoch(train_data, epoch_idx, loss_func, show_progress)
--------------------------------------------------------------------------------
/recbole_gnn/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import importlib
4 | from logging import getLogger
5 | from recbole.data.utils import load_split_dataloaders, create_samplers, save_split_dataloaders
6 | from recbole.data.utils import create_dataset as create_recbole_dataset
7 | from recbole.data.utils import data_preparation as recbole_data_preparation
8 | from recbole.utils import set_color, Enum
9 | from recbole.utils import get_model as get_recbole_model
10 | from recbole.utils import get_trainer as get_recbole_trainer
11 | from recbole.utils.argument_list import dataset_arguments
12 |
13 | from recbole_gnn.data.dataloader import CustomizedTrainDataLoader, CustomizedNegSampleEvalDataLoader, CustomizedFullSortEvalDataLoader
14 |
15 |
16 | def create_dataset(config):
17 | """Create dataset according to :attr:`config['model']` and :attr:`config['MODEL_TYPE']`.
18 |     If the file :attr:`config['dataset_save_path']` exists and
19 |     the :attr:`config` stored with that dataset equals the current dataset :attr:`config`,
20 |     the saved dataset in :attr:`config['dataset_save_path']` is returned.
21 | Args:
22 | config (Config): An instance object of Config, used to record parameter information.
23 | Returns:
24 | Dataset: Constructed dataset.
25 | """
26 | model_type = config['MODEL_TYPE']
27 | dataset_module = importlib.import_module('recbole_gnn.data.dataset')
28 | gen_graph_module_path = '.'.join(['recbole_gnn.model.general_recommender', config['model'].lower()])
29 | seq_module_path = '.'.join(['recbole_gnn.model.sequential_recommender', config['model'].lower()])
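    # Dispatch order: model-specific *Dataset class -> general graph dataset -> session graph dataset -> social dataset -> fall back to RecBole's create_dataset.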
30 | if hasattr(dataset_module, config['model'] + 'Dataset'):
31 | dataset_class = getattr(dataset_module, config['model'] + 'Dataset')
32 | elif importlib.util.find_spec(gen_graph_module_path, __name__):
33 | dataset_class = getattr(dataset_module, 'GeneralGraphDataset')
34 | elif importlib.util.find_spec(seq_module_path, __name__):
35 | dataset_class = getattr(dataset_module, 'SessionGraphDataset')
36 | elif model_type == ModelType.SOCIAL:
37 | dataset_class = getattr(dataset_module, 'SocialDataset')
38 | else:
39 | return create_recbole_dataset(config)
40 |
41 | default_file = os.path.join(config['checkpoint_dir'], f'{config["dataset"]}-{dataset_class.__name__}.pth')
42 | file = config['dataset_save_path'] or default_file
43 | if os.path.exists(file):
44 | with open(file, 'rb') as f:
45 | dataset = pickle.load(f)
46 | dataset_args_unchanged = True
47 | for arg in dataset_arguments + ['seed', 'repeatable']:
48 | if config[arg] != dataset.config[arg]:
49 | dataset_args_unchanged = False
50 | break
51 | if dataset_args_unchanged:
52 | logger = getLogger()
53 | logger.info(set_color('Load filtered dataset from', 'pink') + f': [{file}]')
54 | return dataset
55 |
56 | dataset = dataset_class(config)
57 | if config['save_dataset']:
58 | dataset.save()
59 | return dataset
60 |
61 |
62 | def get_model(model_name):
63 | r"""Automatically select model class based on model name
64 | Args:
65 | model_name (str): model name
66 | Returns:
67 | Recommender: model class
68 | """
69 | model_submodule = [
70 | 'general_recommender', 'sequential_recommender', 'social_recommender'
71 | ]
72 |
73 | model_file_name = model_name.lower()
74 | model_module = None
75 | for submodule in model_submodule:
76 | module_path = '.'.join(['recbole_gnn.model', submodule, model_file_name])
77 | if importlib.util.find_spec(module_path, __name__):
78 | model_module = importlib.import_module(module_path, __name__)
79 | break
80 |
81 | if model_module is None:
82 | model_class = get_recbole_model(model_name)
83 | else:
84 | model_class = getattr(model_module, model_name)
85 | return model_class
86 |
87 |
88 | def _get_customized_dataloader(config, phase):
89 | if phase == 'train':
90 | return CustomizedTrainDataLoader
91 | else:
92 | eval_mode = config["eval_args"]["mode"]
93 | if eval_mode == 'full':
94 | return CustomizedFullSortEvalDataLoader
95 | else:
96 | return CustomizedNegSampleEvalDataLoader
97 |
98 |
99 | def data_preparation(config, dataset):
100 | """Split the dataset by :attr:`config['eval_args']` and create training, validation and test dataloader.
101 | Note:
102 | If we can load split dataloaders by :meth:`load_split_dataloaders`, we will not create new split dataloaders.
103 | Args:
104 | config (Config): An instance object of Config, used to record parameter information.
105 | dataset (Dataset): An instance object of Dataset, which contains all interaction records.
106 | Returns:
107 | tuple:
108 | - train_data (AbstractDataLoader): The dataloader for training.
109 | - valid_data (AbstractDataLoader): The dataloader for validation.
110 | - test_data (AbstractDataLoader): The dataloader for testing.
111 | """
112 | seq_module_path = '.'.join(['recbole_gnn.model.sequential_recommender', config['model'].lower()])
113 | if importlib.util.find_spec(seq_module_path, __name__):
114 |         # Special handling for the sequential models of RecBole-GNN
115 | dataloaders = load_split_dataloaders(config)
116 | if dataloaders is not None:
117 | train_data, valid_data, test_data = dataloaders
118 | else:
119 | built_datasets = dataset.build()
120 | train_dataset, valid_dataset, test_dataset = built_datasets
121 | train_sampler, valid_sampler, test_sampler = create_samplers(config, dataset, built_datasets)
122 |
123 | train_data = _get_customized_dataloader(config, 'train')(config, train_dataset, train_sampler, shuffle=True)
124 | valid_data = _get_customized_dataloader(config, 'evaluation')(config, valid_dataset, valid_sampler, shuffle=False)
125 | test_data = _get_customized_dataloader(config, 'evaluation')(config, test_dataset, test_sampler, shuffle=False)
126 | if config['save_dataloaders']:
127 | save_split_dataloaders(config, dataloaders=(train_data, valid_data, test_data))
128 |
129 | logger = getLogger()
130 | logger.info(
131 | set_color('[Training]: ', 'pink') + set_color('train_batch_size', 'cyan') + ' = ' +
132 | set_color(f'[{config["train_batch_size"]}]', 'yellow') + set_color(' negative sampling', 'cyan') + ': ' +
133 | set_color(f'[{config["train_neg_sample_args"]}]', 'yellow')
134 | )
135 | logger.info(
136 | set_color('[Evaluation]: ', 'pink') + set_color('eval_batch_size', 'cyan') + ' = ' +
137 | set_color(f'[{config["eval_batch_size"]}]', 'yellow') + set_color(' eval_args', 'cyan') + ': ' +
138 | set_color(f'[{config["eval_args"]}]', 'yellow')
139 | )
140 | return train_data, valid_data, test_data
141 | else:
142 | return recbole_data_preparation(config, dataset)
143 |
144 |
145 | def get_trainer(model_type, model_name):
146 | r"""Automatically select trainer class based on model type and model name
147 | Args:
148 | model_type (ModelType): model type
149 | model_name (str): model name
150 | Returns:
151 | Trainer: trainer class
152 | """
153 | try:
154 | return getattr(importlib.import_module('recbole_gnn.trainer'), model_name + 'Trainer')
155 | except AttributeError:
156 | return get_recbole_trainer(model_type, model_name)
157 |
158 |
159 | class ModelType(Enum):
160 | """Type of models.
161 |
162 | - ``Social``: Social-based Recommendation
163 | """
164 |
165 | SOCIAL = 7
--------------------------------------------------------------------------------
/results/README.md:
--------------------------------------------------------------------------------
1 | ## General Model Results
2 |
3 | * [ml-1m](general/ml-1m.md)
4 |
5 | ## Sequential Model Results
6 |
7 | * [diginetica](sequential/diginetica.md)
8 |
9 | ## Social-aware Model Results
10 |
11 | * [lastfm](social/lastfm.md)
12 |
--------------------------------------------------------------------------------
/results/general/ml-1m.md:
--------------------------------------------------------------------------------
1 | # Experimental Setting
2 |
3 | **Dataset:** [MovieLens-1M](https://grouplens.org/datasets/movielens/)
4 |
5 | **Filtering:** Remove interactions with a rating score of less than 3
6 |
7 | **Evaluation:** ratio-based 8:1:1, full sort
8 |
9 | **Metrics:** Recall@10, NDCG@10, MRR@10, Hit@10, Precision@10
10 |
11 | **Properties:**
12 |
13 | ```yaml
14 | # dataset config
15 | field_separator: "\t"
16 | seq_separator: " "
17 | USER_ID_FIELD: user_id
18 | ITEM_ID_FIELD: item_id
19 | RATING_FIELD: rating
20 | NEG_PREFIX: neg_
21 | LABEL_FIELD: label
22 | load_col:
23 | inter: [user_id, item_id, rating]
24 | val_interval:
25 | rating: "[3,inf)"
26 | unused_col:
27 | inter: [rating]
28 |
29 | # training and evaluation
30 | epochs: 500
31 | train_batch_size: 4096
32 | valid_metric: MRR@10
33 | eval_batch_size: 4096000
34 | ```
35 |
36 | For fairness, we fix the embedding dimension of users and items as follows. Please adjust the corresponding argument names for the different models.
37 | ```
38 | embedding_size: 64
39 | ```
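
The settings above can be reproduced with the quick-start API. A minimal sketch, assuming the properties above are saved to a local file named `ml-1m.yaml` (the file name is hypothetical):

```python
# Minimal sketch: train and evaluate LightGCN on ml-1m under the settings
# above; `ml-1m.yaml` is a hypothetical file holding the properties listed here.
from recbole_gnn.quick_start import run_recbole_gnn

run_recbole_gnn(model='LightGCN', dataset='ml-1m',
                config_file_list=['ml-1m.yaml'])
```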
40 |
41 | # Dataset Statistics
42 |
43 | | Dataset | #Users | #Items | #Interactions | Sparsity |
44 | | ---------- | ------ | ------ | ------------- | -------- |
45 | | ml-1m | 6,040 | 3,629 | 836,478 | 96.18% |
46 |
47 | # Evaluation Results
48 |
49 | | Method | Recall@10 | MRR@10 | NDCG@10 | Hit@10 | Precision@10 |
50 | |--------------|-----------|--------|---------|--------|--------------|
51 | | **BPR** | 0.1776 | 0.4187 | 0.2401 | 0.7199 | 0.1779 |
52 | | **NeuMF** | 0.1651 | 0.4020 | 0.2271 | 0.7029 | 0.1700 |
53 | | **NGCF** | 0.1814 | 0.4354 | 0.2508 | 0.7239 | 0.1850 |
54 | | **LightGCN** | 0.1861 | 0.4388 | 0.2538 | 0.7330 | 0.1863 |
55 | | **LightGCL** | 0.1867 | 0.4283 | 0.2479 | 0.7370 | 0.1815 |
56 | | **SGL** | 0.1889 | 0.4315 | 0.2505 | 0.7392 | 0.1843 |
57 | | **HMLET** | 0.1847 | 0.4297 | 0.2490 | 0.7305 | 0.1836 |
58 | | **NCL** | 0.2021 | 0.4599 | 0.2702 | 0.7565 | 0.1962 |
59 | | **SimGCL** | 0.2029 | 0.4550 | 0.2667 | 0.7640 | 0.1933 |
60 | | **XSimGCL** | 0.2116 | 0.4638 | 0.2750 | 0.7743 | 0.1987 |
61 |
62 | # Hyper-parameters
63 |
64 | | | Best hyper-parameters | Tuning range |
65 | |--------------|------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
66 | | **BPR** | learning_rate=0.001 | learning_rate choice [0.05, 0.02, 0.01, 0.005, 0.002, 0.001, 0.0005, 0.0002, 0.0001, 0.00005, 0.00002, 0.00001] |
67 | | **NeuMF** | learning_rate=0.0001<br>mlp_hidden_size=[32,16,8]<br>dropout_prob=0 | learning_rate choice [0.005, 0.002, 0.001, 0.0005, 0.0002, 0.0001, 0.00005]<br>mlp_hidden_size choice ['[64,64]', '[64,32]', '[64,32,16]','[32,16,8]']<br>dropout_prob choice [0, 0.1, 0.2] |
68 | | **NGCF** | learning_rate=0.0002<br>message_dropout=0.0<br>node_dropout=0.0 | learning_rate choice [0.001, 0.0005, 0.0002]<br>node_dropout choice [0.0, 0.1]<br>message_dropout choice [0.0, 0.1] |
69 | | **LightGCN** | learning_rate=0.002<br>n_layers=3<br>reg_weight=0.0001 | learning_rate choice [0.005, 0.002, 0.001]<br>n_layers choice [2, 3]<br>reg_weight choice [1e-4, 1e-5] |
70 | | **LightGCL** | learning_rate=0.001<br>n_layers=2<br>lambda1=0.0001<br>temp=2<br>lambda2=1e-7<br>dropout=0.1 | learning_rate choice [0.001]<br>n_layers choice [2, 3]<br>lambda1 choice [0.01, 0.005, 0.001, 0.0001, 1e-5, 1e-7]<br>temp choice [0.5, 0.8, 2, 3]<br>lambda2 choice [1e-4, 1e-5, 1e-7]<br>dropout choice [0.0, 0.1, 0.25] |
71 | | **SGL** | learning_rate=0.002<br>n_layers=3<br>reg_weight=0.0001<br>ssl_tau=0.5<br>drop_ratio=0.1<br>ssl_weight=0.005 | learning_rate choice [0.002]<br>n_layers choice [3]<br>reg_weight choice [1e-4]<br>ssl_tau choice [0.1, 0.5]<br>drop_ratio choice [0.1, 0.3]<br>ssl_weight choice [1e-5, 1e-6, 1e-7, 0.005, 0.01, 0.05] |
72 | | **HMLET** | learning_rate=0.002<br>n_layers=4<br>activation_function=leakyrelu | learning_rate choice [0.002, 0.001, 0.0005]<br>n_layers choice [3, 4]<br>activation_function choice ['elu', 'leakyrelu'] |
73 | | **NCL** | learning_rate=0.002<br>n_layers=3<br>reg_weight=0.0001<br>ssl_temp=0.1<br>ssl_reg=1e-06<br>hyper_layers=1<br>alpha=1.5 | learning_rate choice [0.002]<br>n_layers choice [3]<br>reg_weight choice [1e-4]<br>ssl_temp choice [0.1, 0.05]<br>ssl_reg choice [1e-7, 1e-6]<br>hyper_layers choice [1]<br>alpha choice [1, 0.8, 1.5] |
74 | | **SimGCL** | learning_rate=0.002<br>n_layers=2<br>reg_weight=0.0001<br>temperature=0.05<br>lambda=1e-5<br>eps=0.1 | learning_rate choice [0.002]<br>n_layers choice [2, 3]<br>reg_weight choice [1e-4]<br>temperature choice [0.05, 0.1, 0.2]<br>lambda choice [1e-5, 1e-6, 1e-7, 0.005, 0.01, 0.05]<br>eps choice [0.1, 0.2] |
75 | | **XSimGCL** | learning_rate=0.002<br>n_layers=2<br>reg_weight=0.0001<br>temperature=0.2<br>lambda=0.1<br>eps=0.2<br>layer_cl=1 | learning_rate choice [0.002]<br>n_layers choice [2, 3]<br>reg_weight choice [1e-4]<br>temperature choice [0.05, 0.1, 0.2]<br>lambda choice [1e-5, 1e-6, 1e-7, 1e-4, 0.005, 0.01, 0.05, 0.1]<br>eps choice [0.1, 0.2]<br>layer_cl choice [1] |
76 |
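The tuning ranges above can be searched with `run_hyper.py` or, equivalently, the underlying `HyperTuning` API. A minimal sketch, assuming the fixed properties are stored in `ml-1m.yaml` and the search space in `lightgcn.hyper` (both file names are hypothetical; each line of the params file is of the form `<param> choice [...]`):

```python
# Minimal sketch of an exhaustive hyper-parameter search for LightGCN,
# mirroring run_hyper.py; both file names are hypothetical.
from recbole.trainer import HyperTuning
from recbole_gnn.quick_start import objective_function

hp = HyperTuning(objective_function, algo='exhaustive',
                 params_file='lightgcn.hyper',           # e.g. "learning_rate choice [0.005, 0.002, 0.001]"
                 fixed_config_file_list=['ml-1m.yaml'])
hp.run()
print('best params:', hp.best_params)
```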
--------------------------------------------------------------------------------
/results/sequential/diginetica.md:
--------------------------------------------------------------------------------
1 | # Experimental Setting
2 |
3 | **Dataset:** diginetica-not-merged
4 |
5 | **Filtering:** Remove users and items with less than 5 interactions
6 |
7 | **Evaluation:** leave one out, full sort
8 |
9 | **Metrics:** Recall@10, NDCG@10, MRR@10, Hit@10, Precision@10
10 |
11 | **Properties:**
12 |
13 | ```yaml
14 | # dataset config
15 | field_separator: "\t"
16 | seq_separator: " "
17 | USER_ID_FIELD: session_id
18 | ITEM_ID_FIELD: item_id
19 | TIME_FIELD: timestamp
20 | NEG_PREFIX: neg_
21 | ITEM_LIST_LENGTH_FIELD: item_length
22 | LIST_SUFFIX: _list
23 | MAX_ITEM_LIST_LENGTH: 20
24 | POSITION_FIELD: position_id
25 | load_col:
26 | inter: [session_id, item_id, timestamp]
27 | user_inter_num_interval: "[5,inf)"
28 | item_inter_num_interval: "[5,inf)"
29 |
30 | # training and evaluation
31 | epochs: 500
32 | train_batch_size: 4096
33 | eval_batch_size: 2000
34 | valid_metric: MRR@10
35 | eval_args:
36 | split: {'LS':"valid_and_test"}
37 | mode: full
38 | order: TO
39 | train_neg_sample_args: ~
40 | ```
41 |
42 | For fairness, we fix the embedding dimension of users and items as follows. Please adjust the corresponding argument names for the different models.
43 | ```
44 | embedding_size: 64
45 | ```
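
A single configuration can also be evaluated programmatically through `objective_function`, which returns the best validation result and the test result. A minimal sketch, assuming a hypothetical `diginetica.yaml` that holds the properties above together with the `dataset`/`data_path` entries:

```python
# Minimal sketch: evaluate SR-GNN under the settings above; `diginetica.yaml`
# is a hypothetical file holding the properties listed here plus the
# dataset/data_path entries.
from recbole_gnn.quick_start import objective_function

result = objective_function(config_dict={'model': 'SRGNN'},
                            config_file_list=['diginetica.yaml'],
                            saved=False)
print(result['best_valid_result'])
print(result['test_result'])
```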
46 |
47 | # Dataset Statistics
48 |
49 | | Dataset | #Users | #Items | #Interactions | Sparsity |
50 | | ---------- | ------ | ------ | ------------- | -------- |
51 | | diginetica | 72,014 | 29,454 | 580,490 | 99.97% |
52 |
53 | # Evaluation Results
54 |
55 | | Method | Recall@10 | MRR@10 | NDCG@10 | Hit@10 | Precision@10 |
56 | | -------------------- | --------- | ------ | ------- | ------ | ------------ |
57 | | **GRU4Rec** | 0.3691 | 0.1632 | 0.2114 | 0.3691 | 0.0369 |
58 | | **NARM** | 0.3801 | 0.1695 | 0.2188 | 0.3801 | 0.0380 |
59 | | **SASRec** | 0.4144 | 0.1857 | 0.2393 | 0.4144 | 0.0414 |
60 | | **SR-GNN** | 0.3881 | 0.1754 | 0.2253 | 0.3881 | 0.0388 |
61 | | **GC-SAN** | 0.4127 | 0.1881 | 0.2408 | 0.4127 | 0.0413 |
62 | | **NISER+** | 0.4144 | 0.1904 | 0.2430 | 0.4144 | 0.0414 |
63 | | **LESSR** | 0.3964 | 0.1763 | 0.2279 | 0.3964 | 0.0396 |
64 | | **TAGNN** | 0.3894 | 0.1763 | 0.2263 | 0.3894 | 0.0389 |
65 | | **GCE-GNN** | 0.4284 | 0.1961 | 0.2507 | 0.4284 | 0.0428 |
66 | | **SGNN-HN** | 0.4183 | 0.1877 | 0.2418 | 0.4183 | 0.0418 |
67 |
68 | # Hyper-parameters
69 |
70 | | | Best hyper-parameters | Tuning range |
71 | | -------------------- | ------------------------------------------------------------ | ------------------------------------------------------------ |
72 | | **GRU4Rec** | learning_rate=0.01<br>hidden_size=128<br>dropout_prob=0.3<br>num_layers=1 | learning_rate in [1e-2, 1e-3, 3e-3]<br>num_layers in [1, 2, 3]<br>hidden_size in [128]<br>dropout_prob in [0.1, 0.2, 0.3] |
73 | | **SASRec** | learning_rate=0.001<br>n_layers=2<br>attn_dropout_prob=0.2<br>hidden_dropout_prob=0.2 | learning_rate in [0.001, 0.0001]<br>n_layers in [1, 2]<br>hidden_dropout_prob in [0.2, 0.5]<br>attn_dropout_prob in [0.2, 0.5] |
74 | | **NARM** | learning_rate=0.001<br>hidden_size=128<br>n_layers=1<br>dropout_probs=[0.25, 0.5] | learning_rate in [0.001, 0.01, 0.03]<br>hidden_size in [128]<br>n_layers in [1, 2]<br>dropout_probs in ['[0.25,0.5]', '[0.2,0.2]', '[0.1,0.2]'] |
75 | | **SR-GNN** | learning_rate=0.001<br>step=1 | learning_rate in [0.01, 0.001, 0.0001]<br>step in [1, 2] |
76 | | **GC-SAN** | learning_rate=0.001<br>step=1 | learning_rate in [0.01, 0.001, 0.0001]<br>step in [1, 2] |
77 | | **NISER+** | learning_rate=0.001<br>sigma=16 | learning_rate in [0.01, 0.001, 0.003]<br>sigma in [10, 16, 20] |
78 | | **LESSR** | learning_rate=0.001<br>n_layers=4 | learning_rate in [0.01, 0.001, 0.003]<br>n_layers in [2, 4] |
79 | | **TAGNN** | learning_rate=0.001 | learning_rate in [0.01, 0.001, 0.003]<br>train_batch_size=512 |
80 | | **GCE-GNN** | learning_rate=0.001<br>dropout_global=0.5 | learning_rate in [0.01, 0.001, 0.003]<br>dropout_global in [0.2, 0.5] |
81 | | **SGNN-HN** | learning_rate=0.003<br>scale=12<br>step=2 | learning_rate in [0.01, 0.001, 0.003]<br>scale in [12, 16, 20]<br>step in [2, 4, 6] |
82 |
--------------------------------------------------------------------------------
/results/social/lastfm.md:
--------------------------------------------------------------------------------
1 | # Experimental Setting
2 |
3 | **Dataset:** [LastFM](http://files.grouplens.org/datasets/hetrec2011/)
4 |
5 | > Note that datasets for social recommendation methods can be downloaded from [Social-Datasets](https://github.com/Sherry-XLL/Social-Datasets).
6 |
7 | **Filtering:** None
8 |
9 | **Evaluation:** ratio-based 8:1:1, full sort
10 |
11 | **Metrics:** Recall@10, NDCG@10, MRR@10, Hit@10, Precision@10
12 |
13 | **Properties:**
14 |
15 | ```yaml
16 | # dataset config
17 | field_separator: "\t"
18 | seq_separator: " "
19 | USER_ID_FIELD: user_id
20 | ITEM_ID_FIELD: artist_id
21 | NET_SOURCE_ID_FIELD: source_id
22 | NET_TARGET_ID_FIELD: target_id
23 | LABEL_FIELD: label
24 | NEG_PREFIX: neg_
25 | load_col:
26 | inter: [user_id, artist_id]
27 | net: [source_id, target_id]
28 |
29 | # social network config
30 | filter_net_by_inter: True
31 | undirected_net: True
32 |
33 | # training and evaluation
34 | epochs: 5000
35 | train_batch_size: 4096
36 | eval_batch_size: 409600000
37 | valid_metric: NDCG@10
38 | stopping_step: 50
39 | ```
40 |
41 | For fairness, we fix the embedding dimension of users and items as follows. Please adjust the corresponding argument names for the different models.
42 | ```
43 | embedding_size: 64
44 | ```
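
Social models read both the interaction file and the trust network declared in `load_col` (the `.net` atomic file with `source_id`/`target_id` columns). A minimal sketch, assuming `lastfm.inter` and `lastfm.net` are available under the configured data path and the properties above are saved to a hypothetical `lastfm.yaml`:

```python
# Minimal sketch: train MHCN on lastfm under the settings above; `lastfm.yaml`
# is a hypothetical file holding the properties listed here.
from recbole_gnn.quick_start import run_recbole_gnn

run_recbole_gnn(model='MHCN', dataset='lastfm',
                config_file_list=['lastfm.yaml'])
```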
45 |
46 | # Dataset Statistics
47 |
48 | | Dataset | #Users | #Items | #Interactions | Sparsity |
49 | | ---------- | ------ | ------ | ------------- | -------- |
50 | | lastfm | 1,892 | 17,632 | 92,834 | 99.72% |
51 |
52 | # Evaluation Results
53 |
54 | | Method | Recall@10 | MRR@10 | NDCG@10 | Hit@10 | Precision@10 |
55 | | -------------------- | --------- | ------ | ------- | ------ | ------------ |
56 | | **BPR** | 0.1761 | 0.3026 | 0.1674 | 0.5573 | 0.0858 |
57 | | **NeuMF** | 0.1696 | 0.2924 | 0.1604 | 0.5456 | 0.0828 |
58 | | **NGCF** | 0.1960 | 0.3479 | 0.1898 | 0.6141 | 0.0961 |
59 | | **LightGCN** | 0.2064 | 0.3559 | 0.1972 | 0.6322 | 0.1009 |
60 | | **DiffNet** | 0.1757 | 0.3117 | 0.1694 | 0.5621 | 0.0857 |
61 | | **MHCN** | 0.2123 | 0.3782 | 0.2068 | 0.6523 | 0.1042 |
62 | | **SEPT** | 0.2127 | 0.3703 | 0.2057 | 0.6465 | 0.1044 |
63 |
64 | # Hyper-parameters
65 |
66 | | | Best hyper-parameters | Tuning range |
67 | | -------------------- | ------------------------------------------------------------ | ------------------------------------------------------------ |
68 | | **BPR** | learning_rate=0.0005 | learning_rate in [0.01, 0.005, 0.001, 0.0005, 0.0001] |
69 | | **NeuMF** | learning_rate=0.0005<br>dropout_prob=0.1 | learning_rate in [0.01, 0.005, 0.001, 0.0005, 0.0001]<br>dropout_prob in [0.1, 0.2, 0.3] |
70 | | **NGCF** | learning_rate=0.0005<br>hidden_size_list=[64,64,64] | learning_rate in [0.01, 0.005, 0.001, 0.0005, 0.0001]<br>hidden_size_list in ['[64]', '[64,64]', '[64,64,64]'] |
71 | | **LightGCN** | learning_rate=0.001<br>n_layers=3 | learning_rate in [0.01, 0.005, 0.001, 0.0005, 0.0001]<br>n_layers in [1, 2, 3] |
72 | | **DiffNet** | learning_rate=0.0005<br>n_layers=1 | learning_rate in [0.01, 0.005, 0.001, 0.0005, 0.0001]<br>n_layers in [1, 2, 3] |
73 | | **MHCN** | learning_rate=0.0005<br>n_layers=2<br>ssl_reg=1e-05 | learning_rate in [0.01, 0.005, 0.001, 0.0005, 0.0001]<br>n_layers in [1, 2, 3]<br>ssl_reg in [1e-04, 1e-05, 1e-06] |
74 | | **SEPT** | learning_rate=0.0005<br>n_layers=2<br>ssl_weight=1e-07 | learning_rate in [0.01, 0.005, 0.001, 0.0005, 0.0001]<br>n_layers in [1, 2, 3]<br>ssl_weight in [1e-3, 1e-4, 1e-5, 1e-6, 1e-7] |
75 |
--------------------------------------------------------------------------------
/run_hyper.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from recbole.trainer import HyperTuning
4 | from recbole_gnn.quick_start import objective_function
5 |
6 |
7 | def main():
8 | parser = argparse.ArgumentParser()
9 | parser.add_argument('--config_files', type=str, default=None, help='fixed config files')
10 | parser.add_argument('--params_file', type=str, default=None, help='parameters file')
11 | parser.add_argument('--output_file', type=str, default='hyper_example.result', help='output file')
12 | args, _ = parser.parse_known_args()
13 |
14 |     # Set algo='exhaustive' to use exhaustive search; in that case, max_evals is set automatically.
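    # Each line of the params file defines one search dimension, e.g. "learning_rate choice [0.01, 0.001]";
    # the tables under results/ list the ranges used for the reported results.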
15 | config_file_list = args.config_files.strip().split(' ') if args.config_files else None
16 | hp = HyperTuning(objective_function, algo='exhaustive',
17 | params_file=args.params_file, fixed_config_file_list=config_file_list)
18 | hp.run()
19 | hp.export_result(output_file=args.output_file)
20 | print('best params: ', hp.best_params)
21 | print('best result: ')
22 | print(hp.params2result[hp.params2str(hp.best_params)])
23 |
24 |
25 | if __name__ == '__main__':
26 | main()
27 |
--------------------------------------------------------------------------------
/run_recbole_gnn.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from recbole_gnn.quick_start import run_recbole_gnn
4 |
5 |
6 | if __name__ == '__main__':
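    # Example invocation (the config file name is hypothetical):
    #   python run_recbole_gnn.py -m LightGCN -d ml-1m --config_files ml-1m.yaml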
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument('--model', '-m', type=str, default='BPR', help='name of models')
9 | parser.add_argument('--dataset', '-d', type=str, default='ml-100k', help='name of datasets')
10 | parser.add_argument('--config_files', type=str, default=None, help='config files')
11 |
12 | args, _ = parser.parse_known_args()
13 |
14 | config_file_list = args.config_files.strip().split(' ') if args.config_files else None
15 | run_recbole_gnn(model=args.model, dataset=args.dataset, config_file_list=config_file_list)
16 |
--------------------------------------------------------------------------------
/run_test.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 | python -m pytest -v tests/test_model.py
5 | echo "model tests finished"
6 |
--------------------------------------------------------------------------------
/tests/test_data/test/test.net:
--------------------------------------------------------------------------------
1 | source_id:token target_id:token
2 | 187 100
3 | 119 40
4 | 96 119
5 | 12 52
6 | 153 131
7 | 259 232
8 | 191 307
9 | 83 150
10 | 86 255
11 | 177 4
12 | 210 192
13 | 25 323
14 | 90 298
15 | 38 47
16 | 201 283
17 | 93 63
18 | 115 190
19 | 143 293
20 | 147 265
21 | 320 68
22 | 188 273
23 | 332 321
24 | 212 203
25 | 326 98
26 | 74 270
27 | 4 333
28 | 87 261
29 | 163 207
30 | 18 175
31 | 127 77
32 | 296 179
33 | 17 101
34 | 24 30
35 | 102 288
36 | 345 269
37 | 270 188
38 | 235 297
39 | 68 303
40 | 313 43
41 | 239 109
42 | 28 76
43 | 108 227
44 | 78 218
45 | 96 30
46 | 180 301
47 | 211 12
48 | 234 34
49 | 178 53
50 | 3 243
51 | 179 73
52 | 98 92
53 | 310 116
54 | 154 271
55 | 293 3
56 | 80 297
57 | 329 254
58 | 198 134
59 | 341 238
60 | 75 185
61 | 166 64
62 | 205 142
63 | 317 163
64 | 261 91
65 | 314 322
66 | 4 33
67 | 71 73
68 | 289 182
69 | 21 12
70 | 248 49
71 | 255 32
72 | 261 170
73 | 257 314
74 | 159 118
75 | 212 221
76 | 5 177
77 | 204 57
78 | 132 120
79 | 13 275
80 | 340 252
81 | 245 251
82 | 334 15
83 | 130 103
84 | 280 187
85 | 232 153
86 | 242 341
87 | 219 123
88 | 6 290
89 | 49 289
90 | 46 347
91 | 185 231
92 | 57 254
93 | 134 248
94 | 24 234
95 | 57 207
96 | 147 295
97 | 191 274
98 | 340 54
99 | 280 150
100 | 190 4
101 | 238 198
102 | 72 123
103 | 122 178
104 | 7 334
105 | 11 90
106 | 232 78
107 | 16 77
108 | 41 190
109 | 108 101
110 | 212 66
111 | 258 18
112 | 321 250
113 | 126 280
114 | 271 85
115 | 11 176
116 | 22 69
117 | 129 159
118 | 235 193
119 | 129 88
120 | 221 315
121 | 308 329
122 | 103 83
123 | 180 43
124 | 208 87
125 | 64 75
126 | 92 36
127 | 298 151
128 | 56 103
129 | 162 268
130 | 81 252
131 | 344 115
132 | 67 282
133 | 132 17
134 | 83 307
135 | 299 82
136 | 321 227
137 | 48 13
138 | 212 57
139 | 344 280
140 | 195 81
141 | 112 122
142 | 345 346
143 | 65 18
144 | 269 3
145 | 131 123
146 | 185 311
147 | 124 330
148 | 347 297
149 | 321 251
150 | 196 135
151 | 65 122
152 | 322 197
153 | 334 160
154 | 129 64
155 | 38 17
156 | 289 256
157 | 51 286
158 | 107 260
159 | 300 101
160 | 290 281
161 | 192 170
162 | 42 2
163 | 54 260
164 | 126 1
165 | 326 294
166 | 119 14
167 | 48 172
168 | 133 191
169 | 332 157
170 | 311 99
171 | 115 123
172 | 160 201
173 | 269 267
174 | 302 184
175 | 262 168
176 | 11 80
177 | 317 155
178 | 163 310
179 | 290 32
180 | 90 239
181 | 246 129
182 | 105 189
183 | 336 8
184 | 266 100
185 | 153 311
186 | 7 20
187 | 329 94
188 | 135 38
189 | 216 331
190 | 291 89
191 | 121 253
192 | 246 82
193 | 113 325
194 | 99 313
195 | 226 188
196 | 319 60
197 | 195 280
198 | 245 319
199 | 168 291
200 | 63 127
201 | 316 280
202 | 67 69
203 | 40 143
204 | 177 18
205 | 239 253
206 | 213 304
207 | 218 315
208 | 18 312
209 | 165 6
210 | 324 232
211 | 167 156
212 | 295 275
213 | 42 110
214 | 25 226
215 | 114 104
216 | 172 305
217 | 66 26
218 | 51 303
219 | 247 110
220 | 245 18
221 | 335 307
222 | 325 95
223 | 289 81
224 | 166 141
225 | 4 39
226 | 171 16
227 | 79 145
228 | 187 65
229 | 102 105
230 | 234 70
231 | 321 104
232 | 62 179
233 | 171 122
234 | 225 239
235 | 283 315
236 | 121 107
237 | 154 297
238 | 309 170
239 | 3 38
240 | 78 345
241 | 164 238
242 | 92 142
243 | 339 4
244 | 251 61
245 | 223 240
246 | 167 39
247 | 223 8
248 | 61 253
249 | 220 256
250 | 139 247
251 | 199 267
252 | 344 264
253 | 336 56
254 | 110 235
255 | 75 90
256 | 93 321
257 | 345 277
258 | 119 260
259 | 214 10
260 | 15 86
261 | 102 5
262 | 34 213
263 | 223 238
264 | 243 169
265 | 107 223
266 | 106 175
267 | 218 104
268 | 28 82
269 | 267 37
270 | 331 124
271 | 16 146
272 | 186 289
273 | 226 304
274 | 109 34
275 | 124 73
276 | 165 286
277 | 260 70
278 | 94 159
279 | 151 257
280 | 151 210
281 | 263 288
282 | 276 218
283 | 222 79
284 | 48 133
285 | 67 218
286 | 282 250
287 | 127 195
288 | 222 316
289 | 19 272
290 | 238 43
291 | 71 240
292 | 208 65
293 | 219 300
294 | 338 29
295 | 75 86
296 | 86 269
297 | 91 100
298 | 273 248
299 | 202 9
300 | 190 33
301 | 84 92
302 | 124 306
303 | 284 70
304 | 281 341
305 | 247 302
306 | 306 230
307 | 320 279
308 | 319 41
309 | 91 160
310 | 323 201
311 | 305 194
312 | 41 156
313 | 220 264
314 | 296 310
315 | 183 131
316 | 232 21
317 | 239 218
318 | 302 49
319 | 250 287
320 | 200 109
321 | 96 263
322 | 225 221
323 | 123 263
324 | 329 256
325 | 136 344
326 | 338 76
327 | 233 245
328 | 347 198
329 | 99 83
330 | 240 81
331 | 238 291
332 | 78 331
333 | 56 225
334 | 21 93
335 | 24 293
336 | 28 155
337 | 245 19
338 | 225 198
339 | 90 235
340 | 191 35
341 | 146 28
342 | 303 194
343 | 203 276
344 | 189 49
345 | 265 232
346 | 204 198
347 | 283 217
348 | 306 44
349 | 133 175
350 | 256 80
351 | 345 215
352 | 97 13
353 | 25 287
354 | 104 48
355 | 20 50
356 | 155 340
357 | 202 57
358 | 343 263
359 | 135 293
360 | 152 266
361 | 232 182
362 | 86 217
363 | 73 72
364 | 143 44
365 | 299 162
366 | 277 324
367 | 154 124
368 | 307 210
369 | 210 226
370 | 323 293
371 | 55 97
372 | 52 8
373 | 32 163
374 | 312 307
375 | 271 171
376 | 204 34
377 | 64 282
378 | 311 315
379 | 174 58
380 | 56 84
381 | 217 275
382 | 86 180
383 | 342 84
384 | 340 174
385 | 13 80
386 | 100 197
387 | 189 341
388 | 5 86
389 | 9 40
390 | 210 329
391 | 260 188
392 | 236 261
393 | 94 282
394 | 105 188
395 | 141 258
396 | 132 285
397 | 17 156
398 | 70 213
399 | 204 5
400 | 344 74
401 | 34 202
402 | 347 263
403 | 121 312
404 | 146 219
405 | 31 48
406 | 53 291
407 | 213 203
408 | 125 9
409 | 279 301
410 | 247 140
411 | 217 2
412 | 298 83
413 | 315 311
414 | 165 209
415 | 169 270
416 | 259 40
417 | 174 285
418 | 21 276
419 | 58 229
420 | 165 84
421 | 48 29
422 | 222 257
423 | 38 209
424 | 336 30
425 | 53 63
426 | 269 243
427 | 36 324
428 | 252 138
429 | 113 155
430 | 123 290
431 | 10 253
432 | 346 15
433 | 217 36
434 | 15 102
435 | 264 149
436 | 143 122
437 | 300 178
438 | 25 220
439 | 58 231
440 | 19 250
441 | 11 147
442 | 73 186
443 | 90 109
444 | 248 104
445 | 196 55
446 | 308 298
447 | 316 7
448 | 160 208
449 | 173 323
450 | 196 176
451 | 147 168
452 | 168 293
453 | 274 328
454 | 6 133
455 | 177 226
456 | 49 336
457 | 173 7
458 | 307 1
459 | 85 128
460 | 63 241
461 | 39 323
462 | 167 173
463 | 298 253
464 | 171 42
465 | 196 326
466 | 53 329
467 | 221 307
468 | 51 194
469 | 192 231
470 | 13 23
471 | 308 117
472 | 324 84
473 | 228 13
474 | 231 156
475 | 314 286
476 | 321 314
477 | 140 30
478 | 143 288
479 | 55 340
480 | 192 264
481 | 119 220
482 | 28 226
483 | 248 309
484 | 227 122
485 | 157 227
486 | 81 178
487 | 143 329
488 | 327 170
489 | 199 308
490 | 297 27
491 | 28 101
492 | 317 179
493 | 176 293
494 | 328 265
495 | 64 256
496 | 176 316
497 | 336 315
498 | 137 189
499 | 290 209
500 | 243 232
501 | 305 233
502 | 28 26
503 | 216 306
504 | 155 65
505 | 246 166
506 | 148 218
507 | 28 343
508 | 31 148
509 | 6 38
510 | 43 267
511 | 85 30
512 | 5 212
513 | 328 157
514 | 93 65
515 | 158 179
516 | 315 256
517 | 261 210
518 | 8 234
519 | 137 163
520 | 261 9
521 | 247 231
522 | 32 266
523 | 118 191
524 | 107 34
525 | 87 153
526 | 132 81
527 | 41 235
528 | 80 103
529 | 13 167
530 | 31 166
531 | 290 32
532 | 53 125
533 | 131 163
534 | 188 82
535 | 68 38
536 | 94 325
537 | 254 129
538 | 99 63
539 | 267 164
540 | 1 46
541 | 175 36
542 | 99 72
543 | 328 80
544 | 84 221
545 | 164 80
546 | 232 264
547 | 172 70
548 | 227 346
549 | 183 44
550 | 208 184
551 | 120 317
552 | 20 154
553 | 76 315
554 | 52 200
555 | 231 46
556 | 343 241
557 | 42 284
558 | 229 345
559 | 213 75
560 | 155 135
561 | 28 261
562 | 22 255
563 | 106 169
564 | 310 347
565 | 212 275
566 | 104 314
567 | 347 181
568 | 285 72
569 | 26 68
570 | 6 331
571 | 19 227
572 | 325 108
573 | 325 110
574 | 152 226
575 | 221 160
576 | 310 226
577 | 145 57
578 | 228 299
579 | 233 139
580 | 291 1
581 | 52 173
582 | 173 33
583 | 48 339
584 | 188 27
585 | 329 117
586 | 216 73
587 | 291 325
588 | 180 22
589 | 343 95
590 | 293 172
591 | 31 146
592 | 99 213
593 | 290 10
594 | 79 212
595 | 184 96
596 | 257 27
597 | 11 323
598 | 117 95
599 | 215 118
600 | 258 23
601 |
--------------------------------------------------------------------------------
/tests/test_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import unittest
3 |
4 | from recbole_gnn.quick_start import objective_function
5 |
6 | current_path = os.path.dirname(os.path.realpath(__file__))
7 | config_file_list = [os.path.join(current_path, 'test_model.yaml')]
8 |
9 |
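# Each test below runs a one-epoch fit/evaluate cycle on the bundled `test` dataset (configured in tests/test_model.yaml).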
10 | def quick_test(config_dict):
11 | objective_function(config_dict=config_dict, config_file_list=config_file_list, saved=False)
12 |
13 |
14 | class TestGeneralRecommender(unittest.TestCase):
15 | def test_bpr(self):
16 | config_dict = {
17 | 'model': 'BPR',
18 | }
19 | quick_test(config_dict)
20 |
21 | def test_neumf(self):
22 | config_dict = {
23 | 'model': 'NeuMF',
24 | }
25 | quick_test(config_dict)
26 |
27 | def test_ngcf(self):
28 | config_dict = {
29 | 'model': 'NGCF',
30 | }
31 | quick_test(config_dict)
32 |
33 | def test_lightgcn(self):
34 | config_dict = {
35 | 'model': 'LightGCN',
36 | }
37 | quick_test(config_dict)
38 |
39 | def test_sgl(self):
40 | config_dict = {
41 | 'model': 'SGL',
42 | }
43 | quick_test(config_dict)
44 |
45 | def test_hmlet(self):
46 | config_dict = {
47 | 'model': 'HMLET',
48 | }
49 | quick_test(config_dict)
50 |
51 | def test_ncl(self):
52 | config_dict = {
53 | 'model': 'NCL',
54 | 'num_clusters': 10
55 | }
56 | quick_test(config_dict)
57 |
58 | def test_simgcl(self):
59 | config_dict = {
60 | 'model': 'SimGCL'
61 | }
62 | quick_test(config_dict)
63 |
64 | def test_xsimgcl(self):
65 | config_dict = {
66 | 'model': 'XSimGCL'
67 | }
68 | quick_test(config_dict)
69 |
70 | def test_lightgcl(self):
71 | config_dict = {
72 | 'model': 'LightGCL'
73 | }
74 | quick_test(config_dict)
75 |
76 | def test_directau(self):
77 | config_dict = {
78 | 'model': 'DirectAU'
79 | }
80 | quick_test(config_dict)
81 |
82 | def test_ssl4rec(self):
83 | config_dict = {
84 | 'model': 'SSL4REC'
85 | }
86 | quick_test(config_dict)
87 |
88 |
89 | class TestSequentialRecommender(unittest.TestCase):
90 | def test_gru4rec(self):
91 | config_dict = {
92 | 'model': 'GRU4Rec',
93 | }
94 | quick_test(config_dict)
95 |
96 | def test_narm(self):
97 | config_dict = {
98 | 'model': 'NARM',
99 | }
100 | quick_test(config_dict)
101 |
102 | def test_sasrec(self):
103 | config_dict = {
104 | 'model': 'SASRec',
105 | }
106 | quick_test(config_dict)
107 |
108 | def test_srgnn(self):
109 | config_dict = {
110 | 'model': 'SRGNN',
111 | }
112 | quick_test(config_dict)
113 |
114 | def test_srgnn_uni100(self):
115 | config_dict = {
116 | 'model': 'SRGNN',
117 | 'eval_args': {
118 | 'split': {'LS': "valid_and_test"},
119 | 'mode': 'uni100',
120 | 'order': 'TO'
121 | }
122 | }
123 | quick_test(config_dict)
124 |
125 | def test_gcsan(self):
126 | config_dict = {
127 | 'model': 'GCSAN',
128 | }
129 | quick_test(config_dict)
130 |
131 | def test_niser(self):
132 | config_dict = {
133 | 'model': 'NISER',
134 | }
135 | quick_test(config_dict)
136 |
137 | def test_lessr(self):
138 | config_dict = {
139 | 'model': 'LESSR'
140 | }
141 | quick_test(config_dict)
142 |
143 | def test_tagnn(self):
144 | config_dict = {
145 | 'model': 'TAGNN'
146 | }
147 | quick_test(config_dict)
148 |
149 | def test_gcegnn(self):
150 | config_dict = {
151 | 'model': 'GCEGNN'
152 | }
153 | quick_test(config_dict)
154 |
155 | def test_sgnnhn(self):
156 | config_dict = {
157 | 'model': 'SGNNHN'
158 | }
159 | quick_test(config_dict)
160 |
161 |
162 | class TestSocialRecommender(unittest.TestCase):
163 | def test_diffnet(self):
164 | config_dict = {
165 | 'model': 'DiffNet',
166 | }
167 | quick_test(config_dict)
168 |
169 | def test_mhcn(self):
170 | config_dict = {
171 | 'model': 'MHCN',
172 | }
173 | quick_test(config_dict)
174 |
175 | def test_sept(self):
176 | config_dict = {
177 | 'model': 'SEPT',
178 | }
179 | quick_test(config_dict)
180 |
181 |
182 | if __name__ == '__main__':
183 | unittest.main()
184 |
--------------------------------------------------------------------------------
/tests/test_model.yaml:
--------------------------------------------------------------------------------
1 | dataset: test
2 | epochs: 1
3 | state: ERROR
4 | data_path: tests/test_data/
5 |
6 | # Atomic File Format
7 | field_separator: "\t"
8 | seq_separator: " "
9 |
10 | # Common Features
11 | USER_ID_FIELD: user_id
12 | ITEM_ID_FIELD: item_id
13 | RATING_FIELD: rating
14 | TIME_FIELD: timestamp
15 | seq_len: ~
16 | # Label for Point-wise DataLoader
17 | LABEL_FIELD: label
18 | # NegSample Prefix for Pair-wise DataLoader
19 | NEG_PREFIX: neg_
20 | # Sequential Model Needed
21 | ITEM_LIST_LENGTH_FIELD: item_length
22 | LIST_SUFFIX: _list
23 | MAX_ITEM_LIST_LENGTH: 50
24 | POSITION_FIELD: position_id
25 | # social network config
26 | NET_SOURCE_ID_FIELD: source_id
27 | NET_TARGET_ID_FIELD: target_id
28 | filter_net_by_inter: True
29 | undirected_net: True
30 |
31 | # Selectively Loading
32 | load_col:
33 | inter: [user_id, item_id, rating, timestamp]
34 | net: [source_id, target_id]
35 |
36 | unload_col: ~
37 |
38 | # Preprocessing
39 | alias_of_user_id: ~
40 | alias_of_item_id: ~
41 | alias_of_entity_id: ~
42 | alias_of_relation_id: ~
43 | preload_weight: ~
44 | normalize_field: ~
45 | normalize_all: True
46 |
--------------------------------------------------------------------------------