├── .coveragerc
├── .github
└── workflows
│ ├── ci.yml
│ ├── ci_serving.yml
│ ├── ci_win_mac.yml
│ ├── compatibility-tensorflow1.yml
│ ├── compatibility-tensorflow2.yml
│ ├── rust.yml
│ ├── test-install.yml
│ └── wheels.yml
├── .gitignore
├── .readthedocs.yaml
├── LICENSE
├── MANIFEST.in
├── README.md
├── _config.yml
├── codecov.yml
├── distributed
├── spark
│ ├── pom.xml
│ ├── spark-recommender.iml
│ └── src
│ │ └── main
│ │ └── scala
│ │ ├── TestScala.scala
│ │ └── com
│ │ └── libreco
│ │ ├── data
│ │ └── DataSplitter.scala
│ │ ├── evaluate
│ │ ├── EvalClassifier.scala
│ │ ├── EvalRecommender.scala
│ │ └── EvalRegressor.scala
│ │ ├── example
│ │ ├── AlsExample.scala
│ │ ├── ClassifierExample.scala
│ │ └── RegressorExample.scala
│ │ ├── feature
│ │ ├── FeatureEngineering.scala
│ │ └── MultiHotEncoder.scala
│ │ ├── model
│ │ ├── Classifier.scala
│ │ ├── Recommender.scala
│ │ └── Regressor.scala
│ │ └── utils
│ │ ├── Context.scala
│ │ ├── FilterNAs.scala
│ │ └── ItemNameConverter.scala
└── youtube_distributed.py
├── docker
├── Dockerfile
└── README.md
├── docs
├── Makefile
├── make.bat
├── md_doc
│ ├── autoint_feature.jpg
│ ├── implementation_details.md
│ ├── python_serving_guide.md
│ ├── rust_serving_guide.md
│ └── user_guide.md
├── requirements.txt
└── source
│ ├── _static
│ ├── autoint_feature.jpg
│ └── css
│ │ └── custom.css
│ ├── api
│ ├── algorithms
│ │ ├── als.rst
│ │ ├── autoint.rst
│ │ ├── bases.rst
│ │ ├── bpr.rst
│ │ ├── caser.rst
│ │ ├── deepfm.rst
│ │ ├── deepwalk.rst
│ │ ├── din.rst
│ │ ├── fm.rst
│ │ ├── graphsage.rst
│ │ ├── graphsage_dgl.rst
│ │ ├── index.rst
│ │ ├── item2vec.rst
│ │ ├── item_cf.rst
│ │ ├── item_cf_rs.rst
│ │ ├── lightgcn.rst
│ │ ├── ncf.rst
│ │ ├── ngcf.rst
│ │ ├── pinsage.rst
│ │ ├── pinsage_dgl.rst
│ │ ├── rnn4rec.rst
│ │ ├── sim.rst
│ │ ├── svd.rst
│ │ ├── svdpp.rst
│ │ ├── swing.rst
│ │ ├── transformer.rst
│ │ ├── two_tower.rst
│ │ ├── user_cf.rst
│ │ ├── user_cf_rs.rst
│ │ ├── wavenet.rst
│ │ ├── wide_deep.rst
│ │ ├── youtube_ranking.rst
│ │ └── youtube_retrieval.rst
│ ├── data
│ │ ├── data_info.rst
│ │ ├── dataset.rst
│ │ ├── index.rst
│ │ ├── split.rst
│ │ └── transformed.rst
│ ├── evaluation.rst
│ └── serialization.rst
│ ├── conf.py
│ ├── index.rst
│ ├── installation.rst
│ ├── internal
│ ├── data_info.rst
│ ├── implementation_details.rst
│ └── index.rst
│ ├── serving_guide
│ ├── online.rst
│ ├── python.rst
│ └── rust.rst
│ ├── tutorial.rst
│ └── user_guide
│ ├── data_processing.rst
│ ├── embedding.rst
│ ├── evaluation_save_load.rst
│ ├── feature_engineering.rst
│ ├── index.rst
│ ├── model_retrain.rst
│ ├── model_train.rst
│ └── recommendation.rst
├── examples
├── changing_feature_example.py
├── feat_example.py
├── feat_ranking_example.py
├── feat_rating_example.py
├── knn_embedding_example.py
├── model_retrain_example.py
├── multi_sparse_example.py
├── multi_sparse_processing_example.py
├── pure_example.py
├── pure_ranking_example.py
├── pure_rating_example.py
├── sample_data
│ ├── sample_movielens_genre.csv
│ ├── sample_movielens_merged.csv
│ └── sample_movielens_rating.dat
├── save_load_example.py
├── seq_example.py
├── split_data_example.py
└── tutorial.ipynb
├── libreco
├── __init__.py
├── algorithms
│ ├── __init__.py
│ ├── _als.pyx
│ ├── _bpr.pyx
│ ├── als.py
│ ├── autoint.py
│ ├── bpr.py
│ ├── caser.py
│ ├── deepfm.py
│ ├── deepwalk.py
│ ├── din.py
│ ├── fm.py
│ ├── graphsage.py
│ ├── graphsage_dgl.py
│ ├── item2vec.py
│ ├── item_cf.py
│ ├── item_cf_rs.py
│ ├── lightgcn.py
│ ├── ncf.py
│ ├── ngcf.py
│ ├── pinsage.py
│ ├── pinsage_dgl.py
│ ├── rnn4rec.py
│ ├── sim.py
│ ├── svd.py
│ ├── svdpp.py
│ ├── swing.py
│ ├── torch_modules
│ │ ├── __init__.py
│ │ ├── graphsage_module.py
│ │ ├── lightgcn_module.py
│ │ ├── ngcf_module.py
│ │ └── pinsage_module.py
│ ├── transformer.py
│ ├── two_tower.py
│ ├── user_cf.py
│ ├── user_cf_rs.py
│ ├── wave_net.py
│ ├── wide_deep.py
│ ├── youtube_ranking.py
│ └── youtube_retrieval.py
├── bases
│ ├── __init__.py
│ ├── base.py
│ ├── cf_base.py
│ ├── cf_base_rs.py
│ ├── dyn_embed_base.py
│ ├── embed_base.py
│ ├── gensim_base.py
│ ├── meta.py
│ ├── sage_base.py
│ └── tf_base.py
├── batch
│ ├── __init__.py
│ ├── batch_data.py
│ ├── batch_unit.py
│ ├── collators.py
│ ├── enums.py
│ ├── sequence.py
│ └── tf_feed_dicts.py
├── data
│ ├── __init__.py
│ ├── consumed.py
│ ├── data_info.py
│ ├── dataset.py
│ ├── processing.py
│ ├── split.py
│ └── transformed.py
├── evaluation
│ ├── __init__.py
│ ├── computation.py
│ ├── evaluate.py
│ └── metrics.py
├── feature
│ ├── __init__.py
│ ├── column_mapping.py
│ ├── multi_sparse.py
│ ├── sparse.py
│ ├── ssl.py
│ ├── unique.py
│ └── update.py
├── graph
│ ├── __init__.py
│ ├── from_dgl.py
│ ├── inference.py
│ ├── message.py
│ └── neighbor_walk.py
├── layers
│ ├── __init__.py
│ ├── activation.py
│ ├── attention.py
│ ├── convolutional.py
│ ├── dense.py
│ ├── embedding.py
│ ├── normalization.py
│ ├── recurrent.py
│ └── transformer.py
├── prediction
│ ├── __init__.py
│ ├── predict.py
│ └── preprocess.py
├── recommendation
│ ├── __init__.py
│ ├── cold_start.py
│ ├── preprocess.py
│ ├── ranking.py
│ └── recommend.py
├── sampling
│ ├── __init__.py
│ ├── negatives.py
│ └── random_walks.py
├── tfops
│ ├── __init__.py
│ ├── configs.py
│ ├── features.py
│ ├── loss.py
│ ├── rebuild.py
│ ├── variables.py
│ └── version.py
├── torchops
│ ├── __init__.py
│ ├── configs.py
│ ├── loss.py
│ └── rebuild.py
├── training
│ ├── __init__.py
│ ├── dispatch.py
│ ├── tf_trainer.py
│ ├── torch_trainer.py
│ └── trainer.py
└── utils
│ ├── __init__.py
│ ├── _similarities.pyx
│ ├── constants.py
│ ├── exception.py
│ ├── initializers.py
│ ├── misc.py
│ ├── save_load.py
│ ├── similarities.py
│ ├── sparse.py
│ └── validate.py
├── libserving
├── .dockerignore
├── Dockerfile-py
├── Dockerfile-rs
├── __init__.py
├── actix_serving
│ ├── Cargo.lock
│ ├── Cargo.toml
│ ├── build.rs
│ ├── proto
│ │ ├── recommend.proto
│ │ ├── tensorflow
│ │ │ └── core
│ │ │ │ └── framework
│ │ │ │ ├── resource_handle.proto
│ │ │ │ ├── tensor.proto
│ │ │ │ ├── tensor_shape.proto
│ │ │ │ └── types.proto
│ │ └── tensorflow_serving
│ │ │ └── apis
│ │ │ ├── model.proto
│ │ │ ├── predict.proto
│ │ │ └── prediction_service.proto
│ ├── rustfmt.toml
│ ├── src
│ │ ├── bin
│ │ │ ├── benchmark.rs
│ │ │ ├── realtime.rs
│ │ │ ├── realtime_grpc_client.rs
│ │ │ └── realtime_grpc_server.rs
│ │ ├── embed_deploy.rs
│ │ ├── knn_deploy.rs
│ │ ├── lib.rs
│ │ ├── main.rs
│ │ ├── online_deploy.rs
│ │ ├── online_deploy_grpc.rs
│ │ ├── tf_deploy.rs
│ │ └── utils
│ │ │ ├── common.rs
│ │ │ ├── constants.rs
│ │ │ ├── errors.rs
│ │ │ ├── faiss.rs
│ │ │ ├── features.rs
│ │ │ ├── mod.rs
│ │ │ └── redis_ops.rs
│ └── tests
│ │ ├── common
│ │ └── mod.rs
│ │ ├── embed.rs
│ │ ├── knn.rs
│ │ └── tf.rs
├── crate-index-config
├── docker-compose-py.yml
├── docker-compose-rs.yml
├── docker-compose-tf-serving.yml
├── request.py
├── sanic_serving
│ ├── __init__.py
│ ├── benchmark.py
│ ├── common.py
│ ├── embed_deploy.py
│ ├── knn_deploy.py
│ ├── online_deploy.py
│ └── tf_deploy.py
└── serialization
│ ├── __init__.py
│ ├── common.py
│ ├── embed.py
│ ├── knn.py
│ ├── online.py
│ ├── redis.py
│ └── tfmodel.py
├── pyproject.toml
├── python-package-conda.yml
├── requirements-dev.txt
├── requirements-serving.txt
├── requirements.txt
├── rust
├── .gitignore
├── Cargo.toml
├── README.md
├── pyproject.toml
├── recfarm
│ └── __init__.py
├── rustfmt.toml
└── src
│ ├── graph.rs
│ ├── incremental.rs
│ ├── inference.rs
│ ├── item_cf.rs
│ ├── lib.rs
│ ├── ordering.rs
│ ├── serialization.rs
│ ├── similarities.rs
│ ├── sparse.rs
│ ├── swing.rs
│ ├── user_cf.rs
│ └── utils.rs
├── setup.cfg
├── setup.py
└── tests
├── __init__.py
├── compatibility_test.py
├── conftest.py
├── models
├── __init__.py
├── test_als.py
├── test_autoint.py
├── test_base.py
├── test_bpr.py
├── test_caser.py
├── test_deepfm.py
├── test_deepwalk.py
├── test_din.py
├── test_fm.py
├── test_graphsage.py
├── test_graphsage_dgl.py
├── test_item2vec.py
├── test_item_cf.py
├── test_item_cf_rs.py
├── test_lightgcn.py
├── test_ncf.py
├── test_ngcf.py
├── test_pinsage.py
├── test_pinsage_dgl.py
├── test_rnn4rec.py
├── test_sim.py
├── test_svd.py
├── test_svdpp.py
├── test_swing.py
├── test_transformer.py
├── test_two_tower.py
├── test_user_cf.py
├── test_user_cf_rs.py
├── test_wave_net.py
├── test_wide_deep.py
├── test_youtube_ranking.py
├── test_youtube_retrieval.py
└── utils_tf.py
├── retrain
├── __init__.py
├── test_als_retrain.py
├── test_gensim_model_retrain.py
├── test_rs_cf_retrain.py
├── test_rs_swing_retrain.py
├── test_tfmodel_retrain_feat.py
├── test_tfmodel_retrain_pure.py
├── test_thmodel_retrain_feat.py
├── test_thmodel_retrain_feat_dgl.py
├── test_thmodel_retrain_pure.py
└── test_two_tower_retrain.py
├── sample_data
├── sample_movielens_genre.csv
├── sample_movielens_merged.csv
└── sample_movielens_rating.dat
├── serving
├── __init__.py
├── conftest.py
├── mock_tf_server.py
├── setup_coverage.sh
├── subprocess_coverage_setup.py
├── test_embed_serving.py
├── test_faiss_index.py
├── test_knn_serving.py
├── test_online_serving.py
├── test_serialization.py
└── test_tf_serving.py
├── test_collators.py
├── test_consumed.py
├── test_data.py
├── test_dgl.py
├── test_feature.py
├── test_initializers.py
├── test_knn_embed.py
├── test_misc.py
├── test_multi_sparse_processing.py
├── test_multiprocessing_seeds.py
├── test_rank_reco.py
├── test_similarities.py
├── test_split_data.py
├── test_tf_layers.py
├── utils_data.py
├── utils_metrics.py
├── utils_multi_sparse_models.py
├── utils_pred.py
├── utils_reco.py
└── utils_save_load.py
/.coveragerc:
--------------------------------------------------------------------------------
1 | [run]
2 | parallel = True
3 |
4 | concurrency = multiprocessing
5 |
6 | source =
7 | libreco/
8 | libserving/serialization/
9 | libserving/sanic_serving/
10 |
11 | omit =
12 | libreco/utils/exception.py
13 | libreco/utils/sampling.py
14 | libserving/sanic_serving/benchmark.py
15 |
16 | [report]
17 | exclude_lines =
18 | pragma: no cover
19 | raise AssertionError
20 | raise NameError
21 | raise NotImplementedError
22 | raise OSError.*
23 | raise ValueError
24 | raise SanicException.*
25 | except .*redis.*
26 | except \(ImportError, ModuleNotFoundError\):
27 | except ValidationError.
28 | if __name__ == .__main__.:
29 | @(abc\.)?abstractmethod
30 |
31 | precision = 2
32 |
33 | show_missing = True
34 |
35 | skip_empty = True
36 |
37 | [html]
38 | directory = html-coverage-report
39 |
40 | title = LibRecommender Coverage Report
41 |
42 | [xml]
43 | output = coverage.xml
44 |
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | name: CI
2 |
3 | on:
4 | push:
5 | branches:
6 | - master
7 | pull_request:
8 | branches:
9 | - master
10 | # Manual run
11 | workflow_dispatch:
12 |
13 | jobs:
14 | build:
15 | name: testing
16 | runs-on: ${{ matrix.os }}
17 | strategy:
18 | fail-fast: false
19 | matrix:
20 | os: [ubuntu-22.04]
21 | python-version: [3.7, 3.8, 3.9, '3.10', '3.11']
22 |
23 | steps:
24 | - uses: actions/checkout@v4
25 | - name: Set up Python ${{ matrix.python-version }}
26 | uses: actions/setup-python@v5
27 | with:
28 | python-version: ${{ matrix.python-version }}
29 | cache: 'pip'
30 |
31 | - name: Display Python version
32 | run: python -c "import sys; print(sys.version)"
33 |
34 | - name: Install dependencies
35 | run: |
36 | python -m pip install -U pip wheel setuptools
37 | python -m pip install -r requirements-dev.txt
38 | python -m pip install -e .
39 |
40 | - name: Lint with flake8
41 | run: |
42 | flake8 libreco/ libserving/ tests/ examples/
43 |
44 | - name: Lint with ruff
45 | run: |
46 | ruff check libreco/ libserving/ tests/ examples/
47 | if: matrix.python-version != '3.6'
48 |
49 | - name: Test
50 | run: |
51 | python -m pip install pytest
52 | python -m pytest tests/ --ignore="tests/serving"
53 | if: matrix.python-version != '3.10'
54 |
55 | - name: Test with coverage
56 | run: |
57 | bash tests/serving/setup_coverage.sh
58 | coverage --version && coverage erase
59 | coverage run -m pytest tests/ --ignore="tests/serving"
60 | coverage combine && coverage report
61 | coverage xml
62 | if: matrix.python-version == '3.10'
63 |
64 | - name: Upload coverage to Codecov
65 | uses: codecov/codecov-action@v4
66 | with:
67 | file: ./coverage.xml
68 | flags: CI
69 | name: python${{ matrix.python-version }}-test
70 | token: ${{ secrets.CODECOV_TOKEN }}
71 | fail_ci_if_error: false
72 | verbose: true
73 | if: matrix.python-version == '3.10'
74 |
75 | - name: Upload coverage to Codacy
76 | uses: codacy/codacy-coverage-reporter-action@v1
77 | with:
78 | project-token: ${{ secrets.CODACY_PROJECT_TOKEN }}
79 | coverage-reports: ./coverage.xml
80 | if: matrix.python-version == '3.10'
81 |
--------------------------------------------------------------------------------
/.github/workflows/ci_serving.yml:
--------------------------------------------------------------------------------
1 | name: CI-Serving
2 |
3 | on:
4 | push:
5 | branches:
6 | - master
7 | pull_request:
8 | branches:
9 | - master
10 | # Manual run
11 | workflow_dispatch:
12 |
13 | jobs:
14 | testing:
15 | runs-on: ${{ matrix.os }}
16 | strategy:
17 | fail-fast: false
18 | matrix:
19 | os: [ubuntu-22.04]
20 | python-version: [3.9, '3.10', '3.11']
21 |
22 | steps:
23 | - uses: actions/checkout@v4
24 | - name: Set up Python ${{ matrix.python-version }}
25 | uses: actions/setup-python@v5
26 | with:
27 | python-version: ${{ matrix.python-version }}
28 | cache: 'pip'
29 |
30 | - name: Display Python version
31 | run: python -c "import sys; print(sys.version)"
32 |
33 | - name: Install dependencies
34 | run: |
35 | python -m pip install -U pip wheel setuptools
36 | python -m pip install "numpy>=1.19.5"
37 | python -m pip install "scipy>=1.2.1,<1.13.0"
38 | python -m pip install "pandas>=1.0.0"
39 | python -m pip install "scikit-learn>=0.20.0"
40 | python -m pip install "tensorflow>=1.15.0,<2.16.0"
41 | python -m pip install "torch>=1.10.0"
42 | python -m pip install "gensim>=4.0.0"
43 | python -m pip install tqdm
44 | python -m pip install recfarm
45 | python -m pip install -r requirements-serving.txt
46 | python -m pip install -e .
47 |
48 | - name: Set up Redis
49 | uses: shogo82148/actions-setup-redis@v1
50 | with:
51 | redis-version: '7.x'
52 |
53 | - name: Test Redis
54 | run: redis-cli ping
55 |
56 | - name: Test
57 | run: |
58 | python -m pip install pytest
59 | python -m pytest tests/serving
60 | if: matrix.python-version != '3.10'
61 |
62 | - name: Test with coverage
63 | run: |
64 | python -m pip install pytest coverage
65 | bash tests/serving/setup_coverage.sh
66 | coverage --version && coverage erase
67 | coverage run -m pytest tests/serving
68 | coverage combine && coverage report
69 | coverage xml
70 | if: matrix.python-version == '3.10'
71 |
72 | - name: Upload coverage to Codecov
73 | uses: codecov/codecov-action@v4
74 | with:
75 | file: ./coverage.xml
76 | flags: CI
77 | name: python${{ matrix.python-version }}-serving
78 | token: ${{ secrets.CODECOV_TOKEN }}
79 | fail_ci_if_error: false
80 | verbose: true
81 | if: matrix.python-version == '3.10'
82 |
--------------------------------------------------------------------------------
/.github/workflows/ci_win_mac.yml:
--------------------------------------------------------------------------------
1 | name: CI-Windows-macOS
2 |
3 | on:
4 | pull_request:
5 | branches:
6 | - master
7 | # Manual run
8 | workflow_dispatch:
9 |
10 | jobs:
11 | testing:
12 | runs-on: ${{ matrix.os }}
13 | strategy:
14 | fail-fast: false
15 | matrix:
16 | os: [macos-latest, windows-latest]
17 | python-version: [3.8, '3.10']
18 |
19 | steps:
20 | - uses: actions/checkout@v4
21 | - name: Set up Python ${{ matrix.python-version }}
22 | uses: actions/setup-python@v5
23 | with:
24 | python-version: ${{ matrix.python-version }}
25 | cache: 'pip'
26 |
27 | - name: Display Python version
28 | run: python -c "import sys; print(sys.version)"
29 |
30 | - name: Install dependencies
31 | run: |
32 | python -m pip install -U pip wheel setuptools
33 | python -m pip install "numpy>=1.19.5"
34 | python -m pip install "scipy>=1.2.1,<1.13.0"
35 | python -m pip install "pandas>=1.0.0"
36 | python -m pip install "scikit-learn>=0.20.0"
37 | python -m pip install "tensorflow>=1.15.0,<2.16.0"
38 | python -m pip install "torch>=1.10.0"
39 | python -m pip install "smart_open<7.0.0"
40 | python -m pip install "gensim>=4.0.0"
41 | python -m pip install tqdm
42 | python -m pip install -e .
43 |
44 | - name: Install DGL on Windows
45 | run: python -m pip install 'dgl<=1.1.0' -f https://data.dgl.ai/wheels/repo.html
46 | if: matrix.os == 'windows-latest'
47 |
48 | - name: Install DGL on macOS
49 | run: |
50 | python -m pip install 'dgl<2.0.0' -f https://data.dgl.ai/wheels/repo.html
51 | if: matrix.os == 'macos-latest'
52 |
53 | - name: Install dataclasses
54 | run: |
55 | python -m pip install dataclasses
56 | if: matrix.python-version == '3.6'
57 |
58 | - name: Install recfarm
59 | run: |
60 | python -m pip install recfarm
61 | if: matrix.python-version != '3.6'
62 |
63 | - name: Test with pytest
64 | run: |
65 | python -m pip install pytest
66 | python -m pytest tests/ --ignore="tests/serving"
67 |
--------------------------------------------------------------------------------
/.github/workflows/rust.yml:
--------------------------------------------------------------------------------
1 | name: Rust
2 |
3 | on:
4 | workflow_dispatch:
5 |
6 | permissions:
7 | contents: read
8 |
9 | jobs:
10 | linux:
11 | runs-on: ubuntu-latest
12 | steps:
13 | - uses: actions/checkout@v4
14 | - uses: actions/setup-python@v5
15 | with:
16 | python-version: '3.10'
17 | - name: Build wheels
18 | uses: PyO3/maturin-action@v1
19 | with:
20 | target: x86_64
21 | working-directory: ./rust
22 | args: --release --out dist
23 | sccache: 'true'
24 | manylinux: auto
25 | rust-toolchain: stable
26 | - name: Upload wheels
27 | uses: actions/upload-artifact@v4
28 | with:
29 | name: wheels-linux
30 | path: ./rust/dist
31 |
32 | windows:
33 | runs-on: windows-latest
34 | steps:
35 | - uses: actions/checkout@v4
36 | - uses: actions/setup-python@v5
37 | with:
38 | python-version: '3.10'
39 | - name: Build wheels
40 | uses: PyO3/maturin-action@v1
41 | with:
42 | target: x86_64
43 | args: --release --out dist --manifest-path rust/Cargo.toml
44 | sccache: 'true'
45 | rust-toolchain: stable
46 | - name: Upload wheels
47 | uses: actions/upload-artifact@v4
48 | with:
49 | name: wheels-windows
50 | path: dist
51 |
52 | macos:
53 | runs-on: macos-latest
54 | steps:
55 | - uses: actions/checkout@v4
56 | - uses: actions/setup-python@v5
57 | with:
58 | python-version: '3.10'
59 | - name: Build wheels
60 | uses: PyO3/maturin-action@v1
61 | with:
62 | target: x86_64
63 | args: --release --out dist --manifest-path rust/Cargo.toml
64 | sccache: 'true'
65 | rust-toolchain: stable
66 | - name: Upload wheels
67 | uses: actions/upload-artifact@v4
68 | with:
69 | name: wheels-macos
70 | path: dist
71 |
72 | sdist:
73 | runs-on: ubuntu-latest
74 | steps:
75 | - uses: actions/checkout@v4
76 | - name: Build sdist
77 | uses: PyO3/maturin-action@v1
78 | with:
79 | command: sdist
80 | args: --out dist --manifest-path rust/Cargo.toml
81 | - name: Upload sdist
82 | uses: actions/upload-artifact@v4
83 | with:
84 | name: sdist
85 | path: dist
86 |
--------------------------------------------------------------------------------
/.github/workflows/test-install.yml:
--------------------------------------------------------------------------------
1 | name: test install library
2 |
3 | on:
4 | # Manual run
5 | workflow_dispatch:
6 |
7 | jobs:
8 | build:
9 | name: test install
10 | runs-on: ${{ matrix.os }}
11 | strategy:
12 | fail-fast: false
13 | matrix:
14 | os: [ubuntu-20.04, windows-latest, macos-12]
15 | python-version: [3.6, 3.7, 3.8, 3.9, '3.10', '3.11']
16 |
17 | steps:
18 | - uses: actions/checkout@v4
19 | - name: Set up Python ${{ matrix.python-version }}
20 | uses: actions/setup-python@v5
21 | with:
22 | python-version: ${{ matrix.python-version }}
23 |
24 | - name: Display Python version
25 | run: python -c "import sys; print(sys.version)"
26 |
27 | - name: Install dependencies
28 | run: |
29 | python -m pip install --upgrade pip
30 | python -m pip install $(sed "/nmslib*/d;/dgl*/d;/dataclasses*/d" requirements.txt | tr -d ' ')
31 | python -m pip install Cython==0.29.37
32 | python -m pip install -e .
33 |
34 | - name: Install dataclasses
35 | run: |
36 | python -m pip install dataclasses
37 | if: matrix.python-version == '3.6'
38 |
39 | - name: Test install
40 | run: python -c "import libreco; print('libreco --', libreco.__version__)"
41 |
42 | - name: Test install dgl
43 | run: |
44 | python -m pip install 'dgl<2.0.0' -f https://data.dgl.ai/wheels/repo.html
45 | python -c "import dgl; print('dgl --', dgl.__version__)"
46 |
47 | - name: Test running on Linux
48 | run: |
49 | cp tests/sample_data/sample_movielens_rating.dat /home/runner/work/
50 | cp tests/compatibility_test.py /home/runner/work/ && cd /home/runner/work/
51 | python -m compatibility_test
52 | if: matrix.os == 'ubuntu-20.04'
53 |
54 | - name: Test running on Windows
55 | run: |
56 | copy D:\a\LibRecommender\LibRecommender\tests\sample_data\sample_movielens_rating.dat D:\a\
57 | copy D:\a\LibRecommender\LibRecommender\tests\compatibility_test.py D:\a\ && Set-Location -Path "D:\a\"
58 | python -m compatibility_test
59 | if: matrix.os == 'windows-latest'
60 |
61 | - name: Test running on macOS
62 | run: |
63 | cp tests/sample_data/sample_movielens_rating.dat /Users/runner/work/
64 | cp tests/compatibility_test.py /Users/runner/work/ && cd /Users/runner/work/
65 | python -m compatibility_test
66 | if: matrix.os == 'macos-12'
67 |
68 | - name: Test install nmslib
69 | run: |
70 | python -m pip install nmslib
71 | python -c "import nmslib; print('nmslib --', nmslib.__version__)"
72 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # .idea folder
10 | .idea/
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | pip-wheel-metadata/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # .readthedocs.yaml
2 | # Read the Docs configuration file
3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
4 |
5 | version: 2
6 |
7 | build:
8 | os: ubuntu-22.04
9 | tools:
10 | python: "3.10"
11 |
12 | sphinx:
13 | configuration: docs/source/conf.py
14 |
15 | python:
16 | install:
17 | - requirements: docs/requirements.txt
18 | - method: pip
19 | path: .
20 | # system_packages: true
21 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright 2023 massquantity
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the “Software”), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include README.md
2 | include LICENSE
3 | include requirements.txt requirements-serving.txt
4 | recursive-include libreco *.c *.cpp *.pyx
5 | recursive-include libserving *.py *.sh
6 | recursive-exclude examples *
7 | recursive-exclude distributed *
8 | recursive-exclude tests *
9 |
--------------------------------------------------------------------------------
/_config.yml:
--------------------------------------------------------------------------------
1 | theme: jekyll-theme-hacker
--------------------------------------------------------------------------------
/codecov.yml:
--------------------------------------------------------------------------------
1 | codecov:
2 | notify:
3 | after_n_builds: 1
4 |
5 | coverage:
6 | status:
7 | project:
8 | default:
9 | target: auto
10 | threshold: 1%
11 | patch:
12 | default:
13 | target: '50'
14 | informational: true
15 |
16 | ignore:
17 | - "docs/*"
18 | - "docker/*"
19 |
--------------------------------------------------------------------------------
/distributed/spark/pom.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <project xmlns="http://maven.apache.org/POM/4.0.0"
3 |          xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
4 |          xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
5 |   <modelVersion>4.0.0</modelVersion>
6 |
7 |   <groupId>com.massquantity</groupId>
8 |   <artifactId>spark-recommender</artifactId>
9 |   <version>1.0-SNAPSHOT</version>
10 |
11 |   <build>
12 |     <plugins>
13 |       <plugin>
14 |         <groupId>org.apache.maven.plugins</groupId>
15 |         <artifactId>maven-compiler-plugin</artifactId>
16 |         <configuration>
17 |           <source>1.8</source>
18 |           <target>1.8</target>
19 |         </configuration>
20 |       </plugin>
21 |     </plugins>
22 |   </build>
23 |
24 |   <properties>
25 |     <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
26 |     <java.version>1.8</java.version>
27 |     <spark.version>2.3.3</spark.version>
28 |     <scala.version>2.11</scala.version>
29 |     <maven.compiler.source>1.8</maven.compiler.source>
30 |     <maven.compiler.target>1.8</maven.compiler.target>
31 |   </properties>
32 |
33 |   <dependencies>
34 |     <dependency>
35 |       <groupId>junit</groupId>
36 |       <artifactId>junit</artifactId>
37 |       <version>4.13.1</version>
38 |       <scope>test</scope>
39 |     </dependency>
40 |
41 |     <dependency>
42 |       <groupId>org.apache.spark</groupId>
43 |       <artifactId>spark-core_${scala.version}</artifactId>
44 |       <version>${spark.version}</version>
45 |     </dependency>
46 |
47 |     <dependency>
48 |       <groupId>org.apache.spark</groupId>
49 |       <artifactId>spark-sql_${scala.version}</artifactId>
50 |       <version>${spark.version}</version>
51 |     </dependency>
52 |
53 |     <dependency>
54 |       <groupId>org.apache.spark</groupId>
55 |       <artifactId>spark-mllib_${scala.version}</artifactId>
56 |       <version>${spark.version}</version>
57 |     </dependency>
58 |   </dependencies>
59 |
60 |   <repositories>
61 |     <repository>
62 |       <id>alimaven</id>
63 |       <name>aliyun maven</name>
64 |       <url>http://maven.aliyun.com/nexus/content/groups/public/</url>
65 |       <releases>
66 |         <enabled>true</enabled>
67 |       </releases>
68 |       <snapshots>
69 |         <enabled>false</enabled>
70 |       </snapshots>
71 |     </repository>
72 |   </repositories>
73 | </project>
--------------------------------------------------------------------------------
/distributed/spark/src/main/scala/TestScala.scala:
--------------------------------------------------------------------------------
1 | import org.apache.log4j.{Level, Logger}
2 | import org.apache.spark.ml.{Pipeline, Transformer}
3 | import org.apache.spark.ml.param.ParamMap
4 | import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCols}
5 | import org.apache.spark.ml.util.DefaultParamsWritable
6 | import org.apache.spark.sql.functions.{array_contains, col, split}
7 | import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
8 | import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
9 | import org.apache.spark.{SparkConf, SparkContext}
10 | import org.apache.spark.ml.util.Identifiable
11 |
12 |
13 | object TestScala {
14 | def main(args: Array[String]): Unit = {
15 | printf("test scala...")
16 | }
17 | }
18 |
--------------------------------------------------------------------------------
/distributed/spark/src/main/scala/com/libreco/example/AlsExample.scala:
--------------------------------------------------------------------------------
1 | package com.libreco.example
2 |
3 | import org.apache.spark.sql.functions.{round, sum, lit}
4 | import com.libreco.utils.{Context, ItemNameConverter}
5 | import com.libreco.data.DataSplitter
6 | import com.libreco.model.Recommender
7 |
8 | import scala.util.Random
9 |
10 |
11 | object AlsExample extends Context {
12 | import spark.implicits._
13 |
14 | def main(args: Array[String]): Unit = {
15 |
16 | val dataPath = this.getClass.getResource("/ml-1m/ratings.dat").toString
17 | val splitter = new DataSplitter()
18 | var data = spark.read.textFile(dataPath)
19 | .map(splitter.parseRating("::"))
20 | .toDF("user", "item", "rating", "timestamp")
21 | .withColumn("label", lit(1))
22 | data.show(4, truncate = false)
23 | // data.columns.foreach(x => println(s"$x -> ${data.filter(data(x).isNull).count}"))
24 | // data.groupBy("rating").count().orderBy($"count".desc)
25 | // .withColumn("percent", round($"count" / sum("count").over(), 4)).show()
26 |
27 | data = data.sample(withReplacement = false, 0.1) // sample 0.1
28 |
29 | val model = new Recommender()
30 | time(model.train(data), "Training")
31 | val transformedData = model.transform(data)
32 | transformedData.show(4, truncate = false)
33 |
34 | val movieMap = ItemNameConverter.getId2ItemName()
35 | val rec = model.recommendForUsers(data, 10, movieMap)
36 | rec.show(20, truncate = false)
37 |
38 | // val model = new Recommender()
39 | // time(model.train(data, evaluate = true, num = 10), "Evaluating")
40 | }
41 | }
42 |
--------------------------------------------------------------------------------
/distributed/spark/src/main/scala/com/libreco/example/RegressorExample.scala:
--------------------------------------------------------------------------------
1 | package com.libreco.example
2 |
3 | import org.apache.spark.ml.linalg.SparseVector
4 | import com.libreco.utils.{Context, ItemNameConverter}
5 | import com.libreco.model.Regressor
6 |
7 |
8 | object RegressorExample extends Context {
9 | import spark.implicits._
10 |
11 | def main(args: Array[String]): Unit = {
12 | val movieNameConverter = ItemNameConverter.getItemName()
13 | val userPath = this.getClass.getResource("/ml-1m/users.dat").toString
14 | val moviePath = this.getClass.getResource("/ml-1m/movies.dat").toString
15 | val ratingPath = this.getClass.getResource("/ml-1m/ratings.dat").toString
16 |
17 | val users = spark.read.textFile(userPath)
18 | .selectExpr("split(value, '::') as col")
19 | .selectExpr(
20 | "cast(col[0] as int) as user",
21 | "cast(col[1] as string) as sex",
22 | "cast(col[2] as int) as age",
23 | "cast(col[3] as int) as occupation")
24 | val items = spark.read.textFile(moviePath)
25 | .selectExpr("split(value, '::') as col")
26 | .selectExpr(
27 | "cast(col[0] as int) as item",
28 | "cast(col[1] as string) as movie",
29 | "cast(col[2] as string) as genre")
30 | .withColumn("movieName", movieNameConverter($"movie"))
31 | .drop($"movie")
32 | .withColumnRenamed("movieName", "movie")
33 | .select("item", "movie", "genre")
34 | var ratings = spark.read.textFile(ratingPath)
35 | .selectExpr("split(value, '::') as col")
36 | .selectExpr(
37 | "cast(col[0] as int) as user",
38 | "cast(col[1] as int) as item",
39 | "cast(col[2] as int) as rating",
40 | "cast(col[3] as long) as timestamp")
41 |
42 | ratings = ratings.sample(withReplacement = false, 0.1)
43 | val temp = ratings.join(users, Seq("user"), "left")
44 | val data = temp.join(items, Seq("item"), "left")
45 | data.show(4, truncate = false)
46 |
47 | // val model = new Regressor(Some("glr"))
48 | // time(model.train(data, evaluate = false), "Training")
49 | // val transformedData = model.transform(data)
50 | // transformedData.show(4, truncate = false)
51 |
52 | val model = new Regressor(Some("gbdt"))
53 | time(model.train(data, evaluate = true, debug = true), "Evaluating")
54 | }
55 | }
56 |
57 |
58 |
59 |
60 |
61 |
62 |
--------------------------------------------------------------------------------
/distributed/spark/src/main/scala/com/libreco/model/Recommender.scala:
--------------------------------------------------------------------------------
1 | package com.libreco.model
2 |
3 | import org.apache.spark.sql.DataFrame
4 | import org.apache.spark.ml.recommendation.{ALS, ALSModel}
5 | import org.apache.spark.sql.functions.{coalesce, typedLit}
6 | import com.libreco.utils.Context
7 | import com.libreco.evaluate.EvalRecommender
8 |
9 | import scala.collection.Map
10 |
11 | class Recommender extends Serializable with Context{
12 | import spark.implicits._
13 | var model: ALSModel = _
14 |
15 | def train(df: DataFrame, evaluate: Boolean = false, num: Int = 10): Unit = {
16 | df.cache()
17 | if (evaluate) {
18 | val evalModel = new EvalRecommender(num, "ndcg")
19 | evalModel.eval(df)
20 | }
21 | else {
22 | val als = new ALS()
23 | .setMaxIter(20)
24 | .setRegParam(0.01)
25 | .setUserCol("user")
26 | .setItemCol("item")
27 | .setRank(50)
28 | .setImplicitPrefs(true)
29 | .setRatingCol("label")
30 | model = als.fit(df)
31 | model.setColdStartStrategy("drop")
32 | }
33 | df.unpersist()
34 | }
35 |
36 | def transform(df: DataFrame): DataFrame = {
37 | model.transform(df)
38 | }
39 |
40 | def recommendForUsers(df: DataFrame,
41 | num: Int,
42 | ItemNameMap: Map[Int, String] = Map.empty): DataFrame = {
43 |
44 | val rec = model.recommendForUserSubset(df, num)
45 | .selectExpr("user", "explode(recommendations) as predAndProb")
46 | .select("user", "predAndProb.*").toDF("user", "item", "prob") // pred means specific item
47 |
48 | val nameMapCol = typedLit(ItemNameMap)
49 | rec.withColumn("name", coalesce(nameMapCol($"item")))
50 | }
51 | }
52 |
--------------------------------------------------------------------------------
/distributed/spark/src/main/scala/com/libreco/model/Regressor.scala:
--------------------------------------------------------------------------------
1 | package com.libreco.model
2 |
3 | import org.apache.spark.sql.DataFrame
4 | import com.libreco.feature.FeatureEngineering
5 | import com.libreco.evaluate.EvalRegressor
6 | import org.apache.spark.ml.regression.{GBTRegressor, GeneralizedLinearRegression}
7 | import org.apache.spark.ml.{Pipeline, PipelineModel, PipelineStage}
8 |
9 | import scala.util.Random
10 |
11 | class Regressor(algo: Option[String] = Some("gbdt")) extends Serializable {
12 | var pipelineModel: PipelineModel = _
13 |
14 | def train(df: DataFrame, evaluate: Boolean = false, debug: Boolean = false): Unit = {
15 | val prePipelineStages: Array[PipelineStage] = FeatureEngineering.preProcessPipeline(df)
16 | if (debug) {
17 | val pipeline = new Pipeline().setStages(prePipelineStages)
18 | pipelineModel = pipeline.fit(df)
19 | val transformed = pipelineModel.transform(df)
20 | transformed.show(4, truncate = false)
21 | }
22 | if (evaluate) {
23 | val evalModel = new EvalRegressor(algo, prePipelineStages)
24 | evalModel.eval(df)
25 | }
26 | else {
27 | algo match {
28 | case Some("gbdt") =>
29 | val model = new GBTRegressor()
30 | .setFeaturesCol("featureVector")
31 | .setLabelCol("rating")
32 | .setPredictionCol("pred")
33 | .setFeatureSubsetStrategy("auto")
34 | .setMaxDepth(3)
35 | .setMaxIter(20)
36 | .setStepSize(0.01)
37 | .setSubsamplingRate(0.8)
38 | .setSeed(Random.nextLong())
39 | val pipelineStages = prePipelineStages ++ Array(model)
40 | val pipeline = new Pipeline().setStages(pipelineStages)
41 | pipelineModel = pipeline.fit(df)
42 |
43 | case Some("glr") =>
44 | val model = new GeneralizedLinearRegression()
45 | .setFeaturesCol("featureVector")
46 | .setLabelCol("rating")
47 | .setPredictionCol("pred")
48 | .setFamily("gaussian")
49 | .setLink("identity")
50 | .setRegParam(0.0)
51 | val pipelineStages = prePipelineStages ++ Array(model)
52 | val pipeline = new Pipeline().setStages(pipelineStages)
53 | pipelineModel = pipeline.fit(df)
54 |
55 | case _ =>
56 | println("Model muse either be GBDTRegressor or GeneralizedLinearRegression")
57 | System.exit(1)
58 | }
59 | }
60 | }
61 |
62 | def transform(dataset: DataFrame): DataFrame = {
63 | pipelineModel.transform(dataset)
64 | }
65 | }
66 |
67 |
68 |
--------------------------------------------------------------------------------
/distributed/spark/src/main/scala/com/libreco/utils/Context.scala:
--------------------------------------------------------------------------------
1 | package com.libreco.utils
2 |
3 | import org.apache.log4j.{Level, Logger}
4 | import org.apache.spark.SparkConf
5 | import org.apache.spark.sql.SparkSession
6 |
7 | trait Context {
8 | Logger.getLogger("org").setLevel(Level.ERROR)
9 | Logger.getLogger("com").setLevel(Level.ERROR)
10 |
11 | lazy val sparkConf: SparkConf = new SparkConf()
12 | .setAppName("Spark Recommender")
13 | .setMaster("local[*]")
14 | // .set("spark.core.max", "4")
15 |
16 | lazy val spark: SparkSession = SparkSession
17 | .builder()
18 | .config(sparkConf)
19 | .getOrCreate()
20 |
21 | def time[T](block: => T, info: String): T = {
22 | val t0 = System.nanoTime()
23 | val result = block
24 | val t1 = System.nanoTime()
25 | println(f"$info time: ${(t1 - t0) / 1e9d}%.2fs")
26 | result
27 | }
28 | }
--------------------------------------------------------------------------------
/distributed/spark/src/main/scala/com/libreco/utils/FilterNAs.scala:
--------------------------------------------------------------------------------
1 | package com.libreco.utils
2 |
3 | import org.apache.spark.sql.{DataFrame, Dataset, Column}
4 | import org.apache.spark.sql.functions.col
5 |
6 |
7 | object FilterNAs {
8 | def filter(data: DataFrame): DataFrame = {
9 | println(s"find and filter NAs for each column...")
10 | data.columns.foreach(col => println(s"$col -> ${data.filter(data(col).isNull).count}"))
11 | val allCols: Array[Column] = data.columns.map(col) // col is a function to get Column
12 | val nullFilter: Column = allCols.map(_.isNotNull).reduce(_ && _)
13 | data.select(allCols: _*).filter(nullFilter)
14 | }
15 | }
16 |
--------------------------------------------------------------------------------
/distributed/spark/src/main/scala/com/libreco/utils/ItemNameConverter.scala:
--------------------------------------------------------------------------------
1 | package com.libreco.utils
2 |
3 | import org.apache.spark.sql.expressions.UserDefinedFunction
4 | import org.apache.spark.sql.{DataFrame, Dataset}
5 | import org.apache.spark.sql.functions.udf
6 |
7 | import scala.collection.Map
8 | import scala.util.matching.Regex
9 |
10 |
11 | object ItemNameConverter extends Context{
12 | import spark.implicits._
13 |
14 | def getId2ItemName(): Map[Int, String] = {
15 | val itemDataPath = this.getClass.getResource("/ml-1m/movies.dat").toString
16 | val itemData: Dataset[String] = spark.read.textFile(itemDataPath)
17 |
18 | itemData.flatMap { line: String =>
19 | val Array(id, movieName, _*): Array[String] = line.split("::")
20 | if (id.isEmpty) {
21 | None
22 | } else {
23 | val pattern = new Regex("(.+)(\\(\\d+\\))")
24 | val name = for (m <- pattern.findFirstMatchIn(movieName)) yield m.group(1)
25 | Some(id.toInt, name.mkString)
26 | }
27 | }.collect().toMap
28 | }
29 |
30 | def getItemName(): UserDefinedFunction = {
31 | val itemDataPath = this.getClass.getResource("/ml-1m/movies.dat").toString
32 | val itemData: Dataset[String] = spark.read.textFile(itemDataPath)
33 |
34 | val itemMap = itemData.map { line: String =>
35 | val Array(_, movieName, _*): Array[String] = line.split("::")
36 | val pattern = new Regex("(.+)(\\(\\d+\\))")
37 | val name = for (m <- pattern.findFirstMatchIn(movieName)) yield m.group(1)
38 | (movieName, name.mkString)
39 | }.collect().toMap
40 | udf((origName: String) => itemMap(origName))
41 | }
42 |
43 | def main(args: Array[String]): Unit = {
44 | val mm = getId2ItemName()
45 | for (i <- 1 to 5) {
46 | println(s"Id2ItemName: $i -> ${mm(i)}")
47 | }
48 | }
49 | }
50 |
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.9-slim
2 |
3 | WORKDIR /root
4 |
5 | RUN apt-get update && \
6 | apt-get install --no-install-recommends -y gcc g++ && \
7 | apt-get clean && \
8 | rm -rf /var/lib/apt/lists/*
9 |
10 | ADD ../requirements.txt /root
11 | RUN pip install --no-cache-dir -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
12 | RUN pip install --no-cache-dir LibRecommender
13 | RUN pip install --no-cache-dir jupyterlab==3.5.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
14 |
15 | RUN jupyter server --generate-config --allow-root
16 | # password generated based on https://jupyter-notebook.readthedocs.io/en/stable/config.html
17 | RUN echo "c.ServerApp.password = 'argon2:\$argon2id\$v=19\$m=10240,t=10,p=8\$1xV3ym3i6fh/Y9WrkfOfag\$pbATSK3YAwGw1GqdzGqhCw'" >> /root/.jupyter/jupyter_server_config.py
18 | RUN echo "c.ServerApp.ip = '0.0.0.0'" >> /root/.jupyter/jupyter_server_config.py
19 | RUN echo "c.ServerApp.port = 8888" >> /root/.jupyter/jupyter_server_config.py
20 |
21 | ADD ../examples /root/examples
22 |
23 | EXPOSE 8888
24 |
25 | CMD ["jupyter", "lab", "--allow-root", "--notebook-dir=/root", "--no-browser"]
26 |
--------------------------------------------------------------------------------
/docker/README.md:
--------------------------------------------------------------------------------
1 | # LibRecommender in Docker
2 |
3 | Users can run [JupyterLab](https://jupyterlab.readthedocs.io/en/stable/) in a docker container and use the library without installing the package.
4 |
5 | 1. Pull image from Docker Hub
6 | ```shell
7 | $ docker pull massquantity/librecommender:latest
8 | ```
9 | 2. Start a docker container by running the following command:
10 | ```shell
11 | $ docker run --rm -p 8889:8888 massquantity/librecommender:latest
12 | ```
13 | This command maps container port 8888 to port 8889 on your machine. Feel free to change 8889 to any port you want,
14 | but make sure it is available.
15 |
16 | Or, if you want to use your own data on your machine, try the following command:
17 | ```shell
18 | $ docker run --rm -p 8889:8888 -v $(pwd):/root/data:ro massquantity/librecommender:latest
19 | ```
20 | The `-v` flag mounts the current directory to `/root/data` in the container, and the `ro` option makes it read-only.
21 | You can change `$(pwd)` to any directory you want. For more information, see [Use bind mounts](https://docs.docker.com/storage/bind-mounts/).
22 |
23 | 3. Open JupyterLab in a browser at `http://localhost:8889`.
24 |
25 | 4. Enter `LibRecommender` as the password.
26 |
27 | 5. The `examples` folder in the repository has been included in the container, so one can use the magic command in the notebook to run some example scripts:
28 | ```shell
29 | cd examples
30 | %run pure_ranking_example.py
31 | ```
32 |
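For reference, a pure ranking script roughly follows the pattern below. This is a minimal sketch assuming the `DatasetPure` / `SVD` API and the bundled `sample_data` files; the exact parameters are illustrative, see `examples/pure_ranking_example.py` for the real script.
```python
import pandas as pd

from libreco.algorithms import SVD
from libreco.data import DatasetPure, random_split

# load the bundled MovieLens-style sample ratings
data = pd.read_csv(
    "sample_data/sample_movielens_rating.dat",
    sep="::",
    names=["user", "item", "label", "time"],
    engine="python",
)

# split the data, then build the training/eval sets and the DataInfo object
train_data, eval_data = random_split(data, multi_ratios=[0.8, 0.2])
train_data, data_info = DatasetPure.build_trainset(train_data)
eval_data = DatasetPure.build_evalset(eval_data)

# train a ranking model and recommend a few items for one user
model = SVD("ranking", data_info, embed_size=16, n_epochs=2)
model.fit(train_data, neg_sampling=True, verbose=2, eval_data=eval_data)
print(model.recommend_user(user=1, n_rec=5))
```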
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = source
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=build
12 |
13 | %SPHINXBUILD% >NUL 2>NUL
14 | if errorlevel 9009 (
15 | echo.
16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
17 | echo.installed, then set the SPHINXBUILD environment variable to point
18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
19 | echo.may add the Sphinx directory to PATH.
20 | echo.
21 | echo.If you don't have Sphinx installed, grab it from
22 | echo.https://www.sphinx-doc.org/
23 | exit /b 1
24 | )
25 |
26 | if "%1" == "" goto help
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/md_doc/autoint_feature.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/massquantity/LibRecommender/9eadf8a3d6901a898630da113b92de48bd2897fb/docs/md_doc/autoint_feature.jpg
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | sphinx==6.1.3
2 | sphinx_copybutton==0.5.2
3 | sphinx-inline-tabs==2022.1.2b11
4 | furo==2022.12.7
5 | numpy==1.23.4
6 | cython==0.29.30
7 | scipy==1.8.1
8 | pandas==1.4.3
9 | scikit-learn==1.1.1
10 | tensorflow-cpu==2.10.1
11 | torch==1.11.0 --index-url https://download.pytorch.org/whl/cpu
12 | gensim>=4.0.0
13 | tqdm==4.64.0
14 | ujson==5.4.0
15 | redis==4.3.4
16 |
--------------------------------------------------------------------------------
/docs/source/_static/autoint_feature.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/massquantity/LibRecommender/9eadf8a3d6901a898630da113b92de48bd2897fb/docs/source/_static/autoint_feature.jpg
--------------------------------------------------------------------------------
/docs/source/api/algorithms/als.rst:
--------------------------------------------------------------------------------
1 | ALS
2 | ---
3 |
4 | .. autoclass:: libreco.algorithms.ALS
5 | :members:
6 | :inherited-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/api/algorithms/autoint.rst:
--------------------------------------------------------------------------------
1 | AutoInt
2 | -------
3 |
4 | .. autoclass:: libreco.algorithms.AutoInt
5 | :members:
6 | :inherited-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/api/algorithms/bases.rst:
--------------------------------------------------------------------------------
1 | Base Classes
2 | ------------
3 |
4 | .. autoclass:: libreco.bases.Base
5 | :members:
6 | :inherited-members:
7 | :show-inheritance:
8 |
9 | .. autoclass:: libreco.bases.EmbedBase
10 | :members:
11 | :inherited-members:
12 | :show-inheritance:
13 |
14 | .. autoclass:: libreco.bases.TfBase
15 | :members:
16 | :inherited-members:
17 | :show-inheritance:
18 |
19 | .. autoclass:: libreco.bases.CfBase
20 | :members:
21 | :inherited-members:
22 | :show-inheritance:
23 |
24 | .. autoclass:: libreco.bases.RsCfBase
25 | :members:
26 | :inherited-members:
27 | :show-inheritance:
28 |
29 | .. autoclass:: libreco.bases.GensimBase
30 | :members:
31 | :inherited-members:
32 | :show-inheritance:
33 |
34 | .. autoclass:: libreco.bases.SageBase
35 | :members:
36 | :inherited-members:
37 | :show-inheritance:
38 |
39 | .. autoclass:: libreco.bases.DynEmbedBase
40 | :members:
41 | :inherited-members:
42 | :show-inheritance:
43 |
--------------------------------------------------------------------------------
/docs/source/api/algorithms/bpr.rst:
--------------------------------------------------------------------------------
1 | BPR
2 | ---
3 |
4 | .. autoclass:: libreco.algorithms.BPR
5 | :members:
6 | :inherited-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/api/algorithms/caser.rst:
--------------------------------------------------------------------------------
1 | Caser
2 | -----
3 |
4 | .. autoclass:: libreco.algorithms.Caser
5 | :members:
6 | :inherited-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/api/algorithms/deepfm.rst:
--------------------------------------------------------------------------------
1 | DeepFM
2 | ------
3 |
4 | .. autoclass:: libreco.algorithms.DeepFM
5 | :members:
6 | :inherited-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/api/algorithms/deepwalk.rst:
--------------------------------------------------------------------------------
1 | DeepWalk
2 | --------
3 |
4 | .. autoclass:: libreco.algorithms.DeepWalk
5 | :members:
6 | :inherited-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/api/algorithms/din.rst:
--------------------------------------------------------------------------------
1 | DIN
2 | ---
3 |
4 | .. autoclass:: libreco.algorithms.DIN
5 | :members:
6 | :inherited-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/api/algorithms/fm.rst:
--------------------------------------------------------------------------------
1 | FM
2 | --
3 |
4 | .. autoclass:: libreco.algorithms.FM
5 | :members:
6 | :inherited-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/api/algorithms/graphsage.rst:
--------------------------------------------------------------------------------
1 | GraphSage
2 | ---------
3 |
4 | .. autoclass:: libreco.algorithms.GraphSage
5 | :members:
6 | :inherited-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/api/algorithms/graphsage_dgl.rst:
--------------------------------------------------------------------------------
1 | GraphSageDGL
2 | ------------
3 |
4 | .. autoclass:: libreco.algorithms.GraphSageDGL
5 | :members:
6 | :inherited-members:
7 | :show-inheritance:
8 | :exclude-members: transform_blocks
9 |
--------------------------------------------------------------------------------
/docs/source/api/algorithms/index.rst:
--------------------------------------------------------------------------------
1 | Algorithms
2 | ==========
3 |
4 | .. toctree::
5 | :maxdepth: 1
6 |
7 | bases
8 | user_cf
9 | user_cf_rs
10 | item_cf
11 | item_cf_rs
12 | svd
13 | svdpp
14 | als
15 | ncf
16 | bpr
17 | wide_deep
18 | fm
19 | deepfm
20 | youtube_retrieval
21 | youtube_ranking
22 | autoint
23 | din
24 | item2vec
25 | rnn4rec
26 | caser
27 | wavenet
28 | deepwalk
29 | ngcf
30 | lightgcn
31 | graphsage
32 | graphsage_dgl
33 | pinsage
34 | pinsage_dgl
35 | two_tower
36 | transformer
37 | sim
38 | swing
39 |
--------------------------------------------------------------------------------
/docs/source/api/algorithms/item2vec.rst:
--------------------------------------------------------------------------------
1 | Item2Vec
2 | --------
3 |
4 | .. autoclass:: libreco.algorithms.Item2Vec
5 | :members:
6 | :inherited-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/api/algorithms/item_cf.rst:
--------------------------------------------------------------------------------
1 | ItemCF
2 | ------
3 |
4 | .. autoclass:: libreco.algorithms.ItemCF
5 | :members:
6 | :inherited-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/api/algorithms/item_cf_rs.rst:
--------------------------------------------------------------------------------
1 | RsItemCF
2 | --------
3 |
4 | .. autoclass:: libreco.algorithms.RsItemCF
5 | :members:
6 | :inherited-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/api/algorithms/lightgcn.rst:
--------------------------------------------------------------------------------
1 | LightGCN
2 | --------
3 |
4 | .. autoclass:: libreco.algorithms.LightGCN
5 | :members:
6 | :inherited-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/api/algorithms/ncf.rst:
--------------------------------------------------------------------------------
1 | NCF
2 | ---
3 |
4 | .. autoclass:: libreco.algorithms.NCF
5 | :members:
6 | :inherited-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/api/algorithms/ngcf.rst:
--------------------------------------------------------------------------------
1 | NGCF
2 | ----
3 |
4 | .. autoclass:: libreco.algorithms.NGCF
5 | :members:
6 | :inherited-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/api/algorithms/pinsage.rst:
--------------------------------------------------------------------------------
1 | PinSage
2 | -------
3 |
4 | .. autoclass:: libreco.algorithms.PinSage
5 | :members:
6 | :inherited-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/api/algorithms/pinsage_dgl.rst:
--------------------------------------------------------------------------------
1 | PinSageDGL
2 | ----------
3 |
4 | .. autoclass:: libreco.algorithms.PinSageDGL
5 | :members:
6 | :inherited-members:
7 | :show-inheritance:
8 | :exclude-members: transform_blocks
9 |
--------------------------------------------------------------------------------
/docs/source/api/algorithms/rnn4rec.rst:
--------------------------------------------------------------------------------
1 | RNN4Rec
2 | -------
3 |
4 | .. autoclass:: libreco.algorithms.RNN4Rec
5 | :members:
6 | :inherited-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/api/algorithms/sim.rst:
--------------------------------------------------------------------------------
1 | SIM
2 | ---
3 |
4 | .. autoclass:: libreco.algorithms.SIM
5 | :members:
6 | :inherited-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/api/algorithms/svd.rst:
--------------------------------------------------------------------------------
1 | SVD
2 | ---
3 |
4 | .. autoclass:: libreco.algorithms.SVD
5 | :members:
6 | :inherited-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/api/algorithms/svdpp.rst:
--------------------------------------------------------------------------------
1 | SVD++
2 | -----
3 |
4 | .. autoclass:: libreco.algorithms.SVDpp
5 | :members:
6 | :inherited-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/api/algorithms/swing.rst:
--------------------------------------------------------------------------------
1 | Swing
2 | -----
3 |
4 | .. autoclass:: libreco.algorithms.Swing
5 | :members:
6 | :inherited-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/api/algorithms/transformer.rst:
--------------------------------------------------------------------------------
1 | Transformer
2 | -----------
3 |
4 | .. autoclass:: libreco.algorithms.Transformer
5 | :members:
6 | :inherited-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/api/algorithms/two_tower.rst:
--------------------------------------------------------------------------------
1 | TwoTower
2 | --------
3 |
4 | .. autoclass:: libreco.algorithms.TwoTower
5 | :members:
6 | :inherited-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/api/algorithms/user_cf.rst:
--------------------------------------------------------------------------------
1 | UserCF
2 | ------
3 |
4 | .. autoclass:: libreco.algorithms.UserCF
5 | :members:
6 | :inherited-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/api/algorithms/user_cf_rs.rst:
--------------------------------------------------------------------------------
1 | RsUserCF
2 | --------
3 |
4 | .. autoclass:: libreco.algorithms.RsUserCF
5 | :members:
6 | :inherited-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/api/algorithms/wavenet.rst:
--------------------------------------------------------------------------------
1 | WaveNet
2 | -------
3 |
4 | .. autoclass:: libreco.algorithms.WaveNet
5 | :members:
6 | :inherited-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/api/algorithms/wide_deep.rst:
--------------------------------------------------------------------------------
1 | Wide & Deep
2 | -----------
3 |
4 | .. autoclass:: libreco.algorithms.WideDeep
5 | :members:
6 | :inherited-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/api/algorithms/youtube_ranking.rst:
--------------------------------------------------------------------------------
1 | YouTubeRanking
2 | --------------
3 |
4 | .. autoclass:: libreco.algorithms.YouTubeRanking
5 | :members:
6 | :inherited-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/api/algorithms/youtube_retrieval.rst:
--------------------------------------------------------------------------------
1 | YouTubeRetrieval
2 | ----------------
3 |
4 | .. autoclass:: libreco.algorithms.YouTubeRetrieval
5 | :members:
6 | :inherited-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/api/data/data_info.rst:
--------------------------------------------------------------------------------
1 | DataInfo
2 | ========
3 |
4 | .. autoclass:: libreco.data.DataInfo
5 | :members:
6 | :special-members: __repr__
7 |
8 | .. autoclass:: libreco.data.MultiSparseInfo
9 |
--------------------------------------------------------------------------------
/docs/source/api/data/dataset.rst:
--------------------------------------------------------------------------------
1 | Dataset
2 | =======
3 |
4 | .. automodule:: libreco.data.dataset
5 | :members:
6 | :inherited-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/api/data/index.rst:
--------------------------------------------------------------------------------
1 | Data
2 | ====
3 |
4 | .. toctree::
5 | :maxdepth: 1
6 |
7 | dataset
8 | data_info
9 | split
10 | transformed
--------------------------------------------------------------------------------
/docs/source/api/data/split.rst:
--------------------------------------------------------------------------------
1 | Split
2 | =====
3 |
4 | .. autofunction:: libreco.data.random_split
5 |
6 | .. autofunction:: libreco.data.split_by_ratio
7 |
8 | .. autofunction:: libreco.data.split_by_num
9 |
10 | .. autofunction:: libreco.data.split_by_ratio_chrono
11 |
12 | .. autofunction:: libreco.data.split_by_num_chrono
13 |
14 | .. autofunction:: libreco.data.split_multi_value
15 |
--------------------------------------------------------------------------------
/docs/source/api/data/transformed.rst:
--------------------------------------------------------------------------------
1 | TransformedSet
2 | ==============
3 |
4 | .. autoclass:: libreco.data.TransformedSet
5 | :members:
6 | :special-members: __len__
7 |
8 | .. autoclass:: libreco.data.TransformedEvalSet
9 | :members:
10 | :special-members: __len__
11 |
--------------------------------------------------------------------------------
/docs/source/api/evaluation.rst:
--------------------------------------------------------------------------------
1 | Evaluation
2 | ==========
3 |
4 | .. automodule:: libreco.evaluation
5 | :members:
6 |
--------------------------------------------------------------------------------
/docs/source/api/serialization.rst:
--------------------------------------------------------------------------------
1 | Serialization
2 | =============
3 |
4 | .. autofunction:: libserving.serialization.save_knn
5 |
6 | .. autofunction:: libserving.serialization.save_embed
7 |
8 | .. autofunction:: libserving.serialization.save_tf
9 |
10 | .. autofunction:: libserving.serialization.save_online
11 |
12 | .. autofunction:: libserving.serialization.knn2redis
13 |
14 | .. autofunction:: libserving.serialization.embed2redis
15 |
16 | .. autofunction:: libserving.serialization.tf2redis
17 |
18 | .. autofunction:: libserving.serialization.online2redis
19 |
--------------------------------------------------------------------------------
/docs/source/internal/data_info.rst:
--------------------------------------------------------------------------------
1 | Data Info
2 | =========
3 |
4 | The :class:`~libreco.data.DataInfo` object stores almost all the useful information from the original data.
5 | We admit there may be too much information in this object, but for ease of use of the library,
6 | we've decided not to split it.
7 | So almost every model has a ``data_info`` attribute that is used to make recommendations.
8 | Additionally, when saving and loading a model, the corresponding *DataInfo* should also be saved and loaded.
9 |
10 | When using a ``feat`` model, the :class:`~libreco.data.DataInfo` object stores the unique features of
11 | all users/items in the training data. However, if a user/item has different categories or values
12 | in the training data (which may be unlikely if the data is clean :)), only the last one will be stored.
13 | For example, if in one sample a user's age is 20, and in another sample this user's age becomes 25,
14 | then only 25 will be kept. So here we basically assume the data is always sorted by time,
15 | and you should sort it by time if it isn't.
16 |
17 | Therefore, when you call ``model.predict(user=..., item=...)`` or ``model.recommend_user(user=...)``
18 | for a feat model, the model will use the stored feature information in DataInfo.
19 |
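For example, here is a minimal sketch (the user and item ids are only illustrative) showing that
no feature columns need to be passed explicitly, since the model looks them up in *DataInfo*.
Passing ``user_feats`` (as in ``examples/seq_example.py``) supplies feature values at call time:

.. code-block:: python3

    >>> model.predict(user=2211, item=110)  # uses the features stored in DataInfo
    >>> model.recommend_user(user=2211, n_rec=7)
    >>> # provide features for this call only, instead of the stored ones
    >>> model.recommend_user(user=2211, n_rec=7, user_feats={"sex": "F", "age": 3})
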
20 | The :class:`~libreco.data.DataInfo` object also stores users' consumed items, which can be useful in sequence models
21 | and the ``unconsumed`` sampler.
22 |
23 | Changing User/Item Features
24 | ---------------------------
25 | It is also possible to change the unique user/item feature values stored in *DataInfo*,
26 | and the new features will then be used in prediction and recommendation.
27 |
28 | .. code-block:: python3
29 |
30 | >>> data_info.assign_user_features(user_data=data)
31 | >>> data_info.assign_item_features(item_data=data)
32 |
33 | The passed ``data`` argument is a ``pandas.DataFrame`` that contains the user/item information.
34 | Be careful with this assign operation if you are not sure whether the features in ``data`` are useful.
35 |
36 | .. SeeAlso::
37 |
38 | `changing_feature_example.py `_
39 |
--------------------------------------------------------------------------------
/docs/source/internal/index.rst:
--------------------------------------------------------------------------------
1 | Internal
2 | ========
3 |
4 | .. toctree::
5 | :maxdepth: 1
6 |
7 | implementation_details
8 | data_info
9 |
--------------------------------------------------------------------------------
/docs/source/user_guide/embedding.rst:
--------------------------------------------------------------------------------
1 | Embedding
2 | =========
3 |
4 | According to the `algorithm list `_,
5 | there are some algorithms that can generate user and item embeddings after training.
6 | So LibRecommender provides public APIs to get them:
7 |
8 | .. code-block:: python3
9 |
10 | >>> model = RNN4Rec(task="ranking", ...)
11 | >>> model.fit(train_data, ...)
12 | >>> model.get_user_embedding(user=1) # get user embedding for user 1
13 | >>> model.get_item_embedding(item=2) # get item embedding for item 2
14 |
15 | One can also search for similar users/items based on embeddings. By default,
16 | we use `nmslib `_ to do approximate similarity
17 | search since it's generally fast, but some people may find it difficult
18 | to build and install the library, especially on the Windows platform or with Python >= 3.10.
19 | In that case, one can fall back to a numpy-based similarity calculation if nmslib is not available.
20 |
21 | .. code-block:: python3
22 |
23 | >>> model = RNN4Rec(task="ranking", ...)
24 | >>> model.fit(train_data, ...)
25 | >>> model.init_knn(approximate=True, sim_type="cosine")
26 | >>> model.search_knn_users(user=1, k=3)
27 | >>> model.search_knn_items(item=2, k=3)
28 |
29 | Before searching, one should call :func:`~libreco.bases.EmbedBase.init_knn` to initialize the index.
30 | Set ``approximate=True`` if you can use nmslib, otherwise set ``approximate=False``.
31 | The ``sim_type`` parameter should be either ``cosine`` or ``inner-product``.
32 |
33 |
34 | Dynamic Embedding Generation
35 | ----------------------------
36 | It is also common to generate user embeddings based on features or behavior sequences.
37 | Once the user embedding has been generated, you can use it to perform similarity search with all the item embeddings.
38 |
39 | This can be useful in the cold-start scenario, so LibRecommender provides an API for dynamic user embeddings:
40 |
41 | .. code-block:: python3
42 |
43 | >>> model = RNN4Rec(task="ranking", norm_embed=True, ...)
44 | >>> model.fit(train_data, ...)
45 | >>> user_embed = model.dyn_user_embedding(user=1, seq=[0, 10])
46 |
47 | >>> model2 = YouTubeRetrieval(task="ranking", norm_embed=False, ...)
48 | >>> model2.fit(train_data, ...)
49 | >>> user_embed = model2.dyn_user_embedding(user="cold user", user_feats={"sex": "F"}, seq=[0, 10])
50 |
51 | .. SeeAlso::
52 |
53 | `knn_embedding_example.py `_
54 |
55 |
--------------------------------------------------------------------------------
/docs/source/user_guide/evaluation_save_load.rst:
--------------------------------------------------------------------------------
1 | Evaluation & Save/Load
2 | ======================
3 |
4 | Evaluate During Training
5 | ------------------------
6 |
7 | The standard procedure in LibRecommender is evaluating during training.
8 | However, for some complex models, doing a full evaluation on the eval data can be very
9 | time-consuming, so you can specify some evaluation parameters to speed this up.
10 |
11 | The default value of ``eval_batch_size`` is 8192, and you can use a higher value if
12 | you have enough machine or GPU memory. Conversely, if you encounter a memory error during
13 | evaluation, try reducing ``eval_batch_size``.
14 |
15 | The ``eval_user_num`` parameter controls how many users to use in evaluation.
16 | By default, it is ``None``, which uses all the users in eval data.
17 | You can use a smaller value if the evaluation is slow, and this will sample ``eval_user_num``
18 | users randomly from eval data.
19 |
20 | .. code-block:: python3
21 |
22 | model.fit(
23 | train_data,
24 | verbose=2,
25 | shuffle=True,
26 | eval_data=eval_data,
27 | metrics=metrics,
28 | k=10, # parameter of metrics, e.g. recall at k, ndcg at k
29 | eval_batch_size=8192,
30 | eval_user_num=100,
31 | )
32 |
33 |
34 | Evaluate After Training
35 | -----------------------
36 |
37 | After the training, one can use the :func:`~libreco.evaluation.evaluate` function to
38 | evaluate on test data directly.
39 |
40 | Note that if your evaluation data (typically in :class:`pandas.DataFrame` format) **is implicit and only contains positive labels**,
41 | then negative sampling is needed, which is enabled by passing ``neg_sampling=True``:
42 |
43 | .. literalinclude:: ../../../examples/save_load_example.py
44 | :caption: From file `examples/save_load_example.py `_
45 | :name: save_load_example.py
46 | :lines: 85-94
47 |
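The call itself mirrors ``examples/pure_example.py``; a minimal sketch is shown below
(the metric list is only illustrative):

.. code-block:: python3

    >>> from libreco.evaluation import evaluate
    >>> evaluate(
    ...     model=model,
    ...     data=test_data,
    ...     neg_sampling=True,  # sample negative items for test data
    ...     metrics=["loss", "roc_auc", "precision", "recall", "ndcg"],
    ... )
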
48 | Save/Load Model
49 | ---------------
50 |
51 | In general, we may want to save/load a model for two reasons:
52 |
53 | 1. Save the model, then load it to make some predictions and recommendations. This is called inference.
54 | 2. Save the model, then load it to retrain the model when we get some new data.
55 |
56 | The ``save/load`` API mainly deals with the first case; the retraining problem is quite
57 | different and is covered in :doc:`model_retrain`.
58 | When a model is only used for predictions and recommendations, it is unnecessary to save all the model
59 | variables, so one can pass ``inference_only=True`` to save only the essential parts, as sketched below.
60 |
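Below is a minimal sketch of the save/load round trip. The path, model name, algorithm class
and exact keyword arguments are illustrative assumptions; see ``examples/save_load_example.py``
for the precise usage.

.. code-block:: python3

    >>> # save the DataInfo and the model (assumed signatures, for illustration only)
    >>> data_info.save(path="model_path", model_name="lightgcn_model")
    >>> model.save(path="model_path", model_name="lightgcn_model", inference_only=True)

    >>> # load them in another process to make recommendations
    >>> from libreco.data import DataInfo
    >>> from libreco.algorithms import LightGCN
    >>> data_info = DataInfo.load(path="model_path", model_name="lightgcn_model")
    >>> model = LightGCN.load(path="model_path", model_name="lightgcn_model", data_info=data_info)
    >>> model.recommend_user(user=2211, n_rec=7)
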
61 | After loading the model, one can also evaluate it directly;
62 | see `save_load_example.py `_ for typical usage.
63 |
--------------------------------------------------------------------------------
/docs/source/user_guide/index.rst:
--------------------------------------------------------------------------------
1 | :orphan:
2 |
3 | User Guide
4 | ==========
5 |
6 | The purpose of this guide is to illustrate some main features that LibRecommender provides.
7 | Example usages are all listed in the `examples `_ folder.
8 |
9 | .. toctree::
10 | :maxdepth: 1
11 |
12 | data_processing
13 | feature_engineering
14 | model_train
15 | evaluation_save_load
16 | recommendation
17 | embedding
18 | model_retrain
19 |
--------------------------------------------------------------------------------
/examples/feat_example.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 |
3 | from libreco.algorithms import YouTubeRanking
4 | from libreco.data import DatasetFeat, split_by_ratio_chrono
5 |
6 | if __name__ == "__main__":
7 | data = pd.read_csv("sample_data/sample_movielens_merged.csv", sep=",", header=0)
8 |
9 | # split into train and test data based on time
10 | train_data, test_data = split_by_ratio_chrono(data, test_size=0.2)
11 |
12 | # specify complete columns information
13 | sparse_col = ["sex", "occupation", "genre1", "genre2", "genre3"]
14 | dense_col = ["age"]
15 | user_col = ["sex", "age", "occupation"]
16 | item_col = ["genre1", "genre2", "genre3"]
17 |
18 | train_data, data_info = DatasetFeat.build_trainset(
19 | train_data, user_col, item_col, sparse_col, dense_col
20 | )
21 | test_data = DatasetFeat.build_testset(test_data)
22 | print(data_info) # n_users: 5953, n_items: 3209, data density: 0.4213 %
23 |
24 | ytb_ranking = YouTubeRanking(
25 | task="ranking",
26 | data_info=data_info,
27 | embed_size=16,
28 | n_epochs=3,
29 | lr=1e-4,
30 | batch_size=512,
31 | use_bn=True,
32 | hidden_units=(128, 64, 32),
33 | )
34 | ytb_ranking.fit(
35 | train_data,
36 | neg_sampling=True, # sample negative items for train and eval data
37 | verbose=2,
38 | shuffle=True,
39 | eval_data=test_data,
40 | metrics=["loss", "roc_auc", "precision", "recall", "map", "ndcg"],
41 | )
42 |
43 | # predict preference of user 2211 to item 110
44 | print("prediction: ", ytb_ranking.predict(user=2211, item=110))
45 | # recommend 7 items for user 2211
46 | print("recommendation: ", ytb_ranking.recommend_user(user=2211, n_rec=7))
47 |
48 | # cold-start prediction
49 | print(
50 | "cold prediction: ",
51 | ytb_ranking.predict(user="ccc", item="not item", cold_start="average"),
52 | )
53 | # cold-start recommendation
54 | print(
55 | "cold recommendation: ",
56 | ytb_ranking.recommend_user(user="are we good?", n_rec=7, cold_start="popular"),
57 | )
58 |
--------------------------------------------------------------------------------
/examples/knn_embedding_example.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 |
3 | from libreco.algorithms import RNN4Rec
4 | from libreco.data import DatasetPure, split_by_ratio_chrono
5 | from libreco.utils.misc import colorize
6 |
7 | try:
8 | import nmslib # noqa: F401
9 |
10 | approximate = True
11 | print_str = "using `nmslib` for similarity search"
12 | print(f"{colorize(print_str, 'cyan')}")
13 | except (ImportError, ModuleNotFoundError):
14 | approximate = False
15 | print_str = "failed to import `nmslib`, using `numpy` for similarity search"
16 | print(f"{colorize(print_str, 'cyan')}")
17 |
18 |
19 | if __name__ == "__main__":
20 | data = pd.read_csv(
21 | "sample_data/sample_movielens_rating.dat",
22 | sep="::",
23 | names=["user", "item", "label", "time"],
24 | )
25 |
26 | train_data, eval_data = split_by_ratio_chrono(data, test_size=0.2)
27 | train_data, data_info = DatasetPure.build_trainset(train_data)
28 | eval_data = DatasetPure.build_evalset(eval_data)
29 |
30 | rnn = RNN4Rec(
31 | "ranking",
32 | data_info,
33 | rnn_type="lstm",
34 | loss_type="cross_entropy",
35 | embed_size=16,
36 | n_epochs=2,
37 | lr=0.001,
38 | lr_decay=False,
39 | hidden_units=16,
40 | reg=None,
41 | batch_size=2048,
42 | num_neg=1,
43 | dropout_rate=None,
44 | recent_num=10,
45 | tf_sess_config=None,
46 | )
47 | rnn.fit(train_data, neg_sampling=True, verbose=2)
48 |
49 | # `sim_type` should either be `cosine` or `inner-product`
50 | rnn.init_knn(approximate=approximate, sim_type="cosine")
51 | print("embedding for user 1: ", rnn.get_user_embedding(user=1))
52 | print("embedding for item 2: ", rnn.get_item_embedding(item=2))
53 | print()
54 |
55 | print(" 3 most similar users for user 1: ", rnn.search_knn_users(user=1, k=3))
56 | print(" 3 most similar items for item 2: ", rnn.search_knn_items(item=2, k=3))
57 | print()
58 |
59 | user_embed = rnn.dyn_user_embedding(user=1, seq=[0, 10, 100])
60 | print("generate embedding for user 1: ", user_embed)
61 |
--------------------------------------------------------------------------------
/examples/multi_sparse_example.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 |
3 | from libreco.algorithms import DeepFM
4 | from libreco.data import DatasetFeat, split_by_ratio_chrono
5 |
6 | if __name__ == "__main__":
7 | data = pd.read_csv("sample_data/sample_movielens_merged.csv", sep=",", header=0)
8 | train_data, eval_data = split_by_ratio_chrono(data, test_size=0.2)
9 |
10 | # specify complete columns information
11 | sparse_col = ["sex", "occupation"]
12 | multi_sparse_col = [["genre1", "genre2", "genre3"]] # should be list of list
13 | dense_col = ["age"]
14 | user_col = ["sex", "age", "occupation"]
15 | item_col = ["genre1", "genre2", "genre3"]
16 |
17 | train_data, data_info = DatasetFeat.build_trainset(
18 | train_data=train_data,
19 | user_col=user_col,
20 | item_col=item_col,
21 | sparse_col=sparse_col,
22 | dense_col=dense_col,
23 | multi_sparse_col=multi_sparse_col,
24 | pad_val=["missing"], # specify padding value
25 | )
26 | eval_data = DatasetFeat.build_testset(eval_data)
27 | print(data_info)
28 |
29 | deepfm = DeepFM(
30 | "ranking",
31 | data_info,
32 | embed_size=16,
33 | n_epochs=2,
34 | lr=1e-4,
35 | lr_decay=False,
36 | reg=None,
37 | batch_size=2048,
38 | num_neg=1,
39 | use_bn=False,
40 | dropout_rate=None,
41 | hidden_units=(128, 64, 32),
42 | tf_sess_config=None,
43 | multi_sparse_combiner="sqrtn", # specify multi_sparse combiner
44 | )
45 |
46 | deepfm.fit(
47 | train_data,
48 | neg_sampling=True,
49 | verbose=2,
50 | shuffle=True,
51 | eval_data=eval_data,
52 | metrics=[
53 | "loss",
54 | "balanced_accuracy",
55 | "roc_auc",
56 | "pr_auc",
57 | "precision",
58 | "recall",
59 | "map",
60 | "ndcg",
61 | ],
62 | )
63 |
64 | print("prediction: ", deepfm.predict(user=1, item=2333))
65 | print("recommendation: ", deepfm.recommend_user(user=1, n_rec=7))
66 |
--------------------------------------------------------------------------------
/examples/pure_example.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 |
3 | from libreco.algorithms import LightGCN # pure data, algorithm LightGCN
4 | from libreco.data import DatasetPure, random_split
5 | from libreco.evaluation import evaluate
6 |
7 | if __name__ == "__main__":
8 | data = pd.read_csv(
9 | "sample_data/sample_movielens_rating.dat",
10 | sep="::",
11 | names=["user", "item", "label", "time"],
12 | )
13 |
14 | # split whole data into three folds for training, evaluating and testing
15 | train_data, eval_data, test_data = random_split(data, multi_ratios=[0.8, 0.1, 0.1])
16 |
17 | train_data, data_info = DatasetPure.build_trainset(train_data)
18 | eval_data = DatasetPure.build_evalset(eval_data)
19 | test_data = DatasetPure.build_testset(test_data)
20 | print(data_info) # n_users: 5894, n_items: 3253, data sparsity: 0.4172 %
21 |
22 | lightgcn = LightGCN(
23 | task="ranking",
24 | data_info=data_info,
25 | loss_type="bpr",
26 | embed_size=16,
27 | n_epochs=3,
28 | lr=1e-3,
29 | batch_size=2048,
30 | num_neg=1,
31 | device="cuda",
32 | )
33 | # monitor metrics on eval_data during training
34 | lightgcn.fit(
35 | train_data,
36 | neg_sampling=True, # sample negative items for train and eval data
37 | verbose=2,
38 | eval_data=eval_data,
39 | metrics=["loss", "roc_auc", "precision", "recall", "ndcg"],
40 | )
41 |
42 | # do final evaluation on test data
43 | print(
44 | "evaluate_result: ",
45 | evaluate(
46 | model=lightgcn,
47 | data=test_data,
48 | neg_sampling=True, # sample negative items for test data
49 | metrics=["loss", "roc_auc", "precision", "recall", "ndcg"],
50 | ),
51 | )
52 | # predict preference of user 2211 to item 110
53 | print("prediction: ", lightgcn.predict(user=2211, item=110))
54 | # recommend 7 items for user 2211
55 | print("recommendation: ", lightgcn.recommend_user(user=2211, n_rec=7))
56 |
57 | # cold-start prediction
58 | print(
59 | "cold prediction: ",
60 | lightgcn.predict(user="ccc", item="not item", cold_start="average"),
61 | )
62 | # cold-start recommendation
63 | print(
64 | "cold recommendation: ",
65 | lightgcn.recommend_user(user="are we good?", n_rec=7, cold_start="popular"),
66 | )
67 |
--------------------------------------------------------------------------------
/examples/seq_example.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import pandas as pd
4 |
5 | from libreco.algorithms import DIN
6 | from libreco.data import DatasetFeat
7 |
8 | if __name__ == "__main__":
9 | start_time = time.perf_counter()
10 | data = pd.read_csv("sample_data/sample_movielens_merged.csv", sep=",", header=0)
11 |
12 | # specify complete columns information
13 | sparse_col = ["sex", "occupation", "genre1", "genre2", "genre3"]
14 | dense_col = ["age"]
15 | user_col = ["sex", "age", "occupation"]
16 | item_col = ["genre1", "genre2", "genre3"]
17 |
18 | train_data, data_info = DatasetFeat.build_trainset(
19 | data, user_col, item_col, sparse_col, dense_col
20 | )
21 |
22 | din = DIN(
23 | "ranking",
24 | data_info,
25 | loss_type="focal",
26 | embed_size=16,
27 | n_epochs=1,
28 | lr=3e-3,
29 | lr_decay=False,
30 | reg=None,
31 | batch_size=64,
32 | sampler="popular",
33 | num_neg=1,
34 | use_bn=True,
35 | hidden_units=(110, 32),
36 | recent_num=10,
37 | tf_sess_config=None,
38 | use_tf_attention=True,
39 | )
40 | din.fit(train_data, neg_sampling=True, verbose=2, shuffle=True, eval_data=None)
41 |
42 | print(
43 | "feat recommendation: ",
44 | din.recommend_user(user=4617, n_rec=7, user_feats={"sex": "F", "age": 3}),
45 | )
46 | print(
47 | "seq recommendation1: ",
48 | din.recommend_user(
49 | user=4617,
50 | n_rec=7,
51 | seq=[4, 0, 1, 222, "cold item", 222, 12, 1213, 1197, 1193],
52 | ),
53 | )
54 | print(
55 | "seq recommendation2: ",
56 | din.recommend_user(
57 | user=4617,
58 | n_rec=7,
59 | seq=["cold item", 1270, 2161, 110, 3827, 12, 34, 1273, 1589],
60 | ),
61 | )
62 | print(
63 | "feat & seq recommendation1: ",
64 | din.recommend_user(
65 | user=1, n_rec=7, user_feats={"sex": "F", "age": 3}, seq=[4, 0, 1, 222]
66 | ),
67 | )
68 | print(
69 | "feat & seq recommendation2: ",
70 | din.recommend_user(
71 | user=1, n_rec=7, user_feats={"sex": "M", "age": 33}, seq=[4, 0, 337, 1497]
72 | ),
73 | )
74 |
--------------------------------------------------------------------------------
/examples/split_data_example.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import pandas as pd
4 |
5 | from libreco.data import (
6 | random_split,
7 | split_by_num,
8 | split_by_num_chrono,
9 | split_by_ratio,
10 | split_by_ratio_chrono,
11 | )
12 |
13 | if __name__ == "__main__":
14 | start_time = time.perf_counter()
15 | data = pd.read_csv("sample_data/sample_movielens_merged.csv", sep=",", header=0)
16 |
17 | train_data, eval_data, test_data = random_split(data, multi_ratios=[0.8, 0.1, 0.1])
18 |
19 | train_data2, eval_data2 = split_by_ratio(data, test_size=0.2)
20 | print(train_data2.shape, eval_data2.shape)
21 |
22 | train_data3, eval_data3 = split_by_num(data, test_size=1)
23 | print(train_data3.shape, eval_data3.shape)
24 |
25 | train_data4, eval_data4 = split_by_ratio_chrono(data, test_size=0.2)
26 | print(train_data4.shape, eval_data4.shape)
27 |
28 | train_data5, eval_data5 = split_by_num_chrono(data, test_size=1)
29 | print(train_data5.shape, eval_data5.shape)
30 |
--------------------------------------------------------------------------------
/libreco/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "1.5.1"
2 |
--------------------------------------------------------------------------------
/libreco/algorithms/__init__.py:
--------------------------------------------------------------------------------
1 | from .als import ALS
2 | from .autoint import AutoInt
3 | from .bpr import BPR
4 | from .caser import Caser
5 | from .deepfm import DeepFM
6 | from .deepwalk import DeepWalk
7 | from .din import DIN
8 | from .fm import FM
9 | from .graphsage import GraphSage
10 | from .graphsage_dgl import GraphSageDGL
11 | from .item2vec import Item2Vec
12 | from .item_cf import ItemCF
13 | from .item_cf_rs import RsItemCF
14 | from .lightgcn import LightGCN
15 | from .ncf import NCF
16 | from .ngcf import NGCF
17 | from .pinsage import PinSage
18 | from .pinsage_dgl import PinSageDGL
19 | from .rnn4rec import RNN4Rec
20 | from .sim import SIM
21 | from .svd import SVD
22 | from .svdpp import SVDpp
23 | from .swing import Swing
24 | from .transformer import Transformer
25 | from .two_tower import TwoTower
26 | from .user_cf import UserCF
27 | from .user_cf_rs import RsUserCF
28 | from .wave_net import WaveNet
29 | from .wide_deep import WideDeep
30 | from .youtube_ranking import YouTubeRanking
31 | from .youtube_retrieval import YouTubeRetrieval
32 |
33 | __all__ = [
34 | "UserCF",
35 | "RsUserCF",
36 | "ItemCF",
37 | "RsItemCF",
38 | "SVD",
39 | "SVDpp",
40 | "ALS",
41 | "BPR",
42 | "NCF",
43 | "YouTubeRetrieval",
44 | "YouTubeRanking",
45 | "FM",
46 | "WideDeep",
47 | "DeepFM",
48 | "AutoInt",
49 | "DIN",
50 | "RNN4Rec",
51 | "Caser",
52 | "WaveNet",
53 | "Item2Vec",
54 | "DeepWalk",
55 | "NGCF",
56 | "LightGCN",
57 | "PinSage",
58 | "PinSageDGL",
59 | "GraphSage",
60 | "GraphSageDGL",
61 | "TwoTower",
62 | "Transformer",
63 | "SIM",
64 | "Swing",
65 | ]
66 |
--------------------------------------------------------------------------------
/libreco/algorithms/item2vec.py:
--------------------------------------------------------------------------------
1 | """Implementation of Item2Vec."""
2 | from gensim.models import Word2Vec
3 | from tqdm import tqdm
4 |
5 | from ..bases import GensimBase
6 |
7 |
8 | class Item2Vec(GensimBase):
9 | """*Item2Vec* algorithm.
10 |
11 | .. WARNING::
12 | Item2Vec can only be used in the ``ranking`` task.
13 |
14 | Parameters
15 | ----------
16 | task : {'ranking'}
17 | Recommendation task. See :ref:`Task`.
18 | data_info : :class:`~libreco.data.DataInfo` object
19 | Object that contains useful information for training and inference.
20 | embed_size : int, default: 16
21 | Vector size of embeddings.
22 | norm_embed : bool, default: False
23 | Whether to l2 normalize output embeddings.
24 | window_size : int, default: 5
25 | Maximum item distance within a sequence during training.
26 | n_epochs : int, default: 10
27 | Number of epochs for training.
28 | n_threads : int, default: 0
29 | Number of threads to use, `0` will use all cores.
30 | seed : int, default: 42
31 | Random seed.
32 | lower_upper_bound : tuple or None, default: None
33 | Lower and upper score bound for `rating` task.
34 |
35 | References
36 | ----------
37 | *Oren Barkan and Noam Koenigstein.* `Item2Vec: Neural Item Embedding for Collaborative Filtering
38 | `_.
39 | """
40 |
41 | def __init__(
42 | self,
43 | task,
44 | data_info=None,
45 | embed_size=16,
46 | norm_embed=False,
47 | window_size=5,
48 | n_epochs=10,
49 | n_threads=0,
50 | seed=42,
51 | lower_upper_bound=None,
52 | ):
53 | super().__init__(
54 | task,
55 | data_info,
56 | embed_size,
57 | norm_embed,
58 | window_size,
59 | n_epochs,
60 | n_threads,
61 | seed,
62 | lower_upper_bound,
63 | )
64 | assert task == "ranking", "Item2Vec is only suitable for ranking"
65 | self.all_args = locals()
66 |
67 | def get_data(self):
68 | return _ItemCorpus(self.user_consumed)
69 |
70 | def build_model(self):
71 | model = Word2Vec(
72 | vector_size=self.embed_size,
73 | window=self.window_size,
74 | sg=1,
75 | hs=0,
76 | negative=5,
77 | seed=self.seed,
78 | min_count=1,
79 | workers=self.workers,
80 | sorted_vocab=0,
81 | )
82 | model.build_vocab(self.data, update=False)
83 | return model
84 |
85 |
86 | class _ItemCorpus:
87 | def __init__(self, user_consumed):
88 | self.item_seqs = user_consumed.values()
89 | self.i = 0
90 |
91 | def __iter__(self):
92 | for items in tqdm(self.item_seqs, desc=f"Item2vec iter{self.i}"):
93 | yield list(map(str, items))
94 | self.i += 1
95 |
--------------------------------------------------------------------------------
/libreco/algorithms/item_cf_rs.py:
--------------------------------------------------------------------------------
1 | """Implementation of RsItemCF."""
2 | from ..bases import RsCfBase
3 |
4 |
5 | class RsItemCF(RsCfBase):
6 | """*Item Collaborative Filtering* algorithm implemented in Rust.
7 |
8 | Parameters
9 | ----------
10 | task : {'rating', 'ranking'}
11 | Recommendation task. See :ref:`Task`.
12 | data_info : :class:`~libreco.data.DataInfo` object
13 | Object that contains useful information for training and inference.
14 | k_sim : int, default: 20
15 | Number of similar items to use.
16 | num_threads : int, default: 1
17 | Number of threads to use.
18 | min_common : int, default: 1
19 | Minimum number of common items to consider when computing similarities.
20 | mode : {'forward', 'invert'}, default: 'invert'
21 | Whether to use forward index or invert index.
22 | seed : int, default: 42
23 | Random seed.
24 | lower_upper_bound : tuple or None, default: None
25 | Lower and upper score bound for `rating` task.
26 | """
27 |
28 | def __init__(
29 | self,
30 | task,
31 | data_info,
32 | k_sim=20,
33 | num_threads=1,
34 | min_common=1,
35 | mode="invert",
36 | seed=42,
37 | lower_upper_bound=None,
38 | ):
39 | super().__init__(
40 | task,
41 | data_info,
42 | k_sim,
43 | num_threads,
44 | min_common,
45 | mode,
46 | seed,
47 | lower_upper_bound,
48 | )
49 | self.all_args = locals()
50 |
--------------------------------------------------------------------------------
/libreco/algorithms/torch_modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .graphsage_module import GraphSageDGLModel, GraphSageModel
2 | from .lightgcn_module import LightGCNModel
3 | from .ngcf_module import NGCFModel
4 | from .pinsage_module import PinSageDGLModel, PinSageModel
5 |
6 | __all__ = [
7 | "GraphSageModel",
8 | "GraphSageDGLModel",
9 | "LightGCNModel",
10 | "NGCFModel",
11 | "PinSageModel",
12 | "PinSageDGLModel",
13 | ]
14 |
--------------------------------------------------------------------------------
/libreco/algorithms/user_cf_rs.py:
--------------------------------------------------------------------------------
1 | """Implementation of RsUserCF."""
2 | from ..bases import RsCfBase
3 |
4 |
5 | class RsUserCF(RsCfBase):
6 | """*User Collaborative Filtering* algorithm implemented in Rust.
7 |
8 | Parameters
9 | ----------
10 | task : {'rating', 'ranking'}
11 | Recommendation task. See :ref:`Task`.
12 | data_info : :class:`~libreco.data.DataInfo` object
13 | Object that contains useful information for training and inference.
14 | k_sim : int, default: 20
15 | Number of similar users to use.
16 | num_threads : int, default: 1
17 | Number of threads to use.
18 | min_common : int, default: 1
19 | Minimum number of common users to consider when computing similarities.
20 | mode : {'forward', 'invert'}, default: 'invert'
21 | Whether to use forward index or invert index.
22 | seed : int, default: 42
23 | Random seed.
24 | lower_upper_bound : tuple or None, default: None
25 | Lower and upper score bound for `rating` task.
26 | """
27 |
28 | def __init__(
29 | self,
30 | task,
31 | data_info,
32 | k_sim=20,
33 | num_threads=1,
34 | min_common=1,
35 | mode="invert",
36 | seed=42,
37 | lower_upper_bound=None,
38 | ):
39 | super().__init__(
40 | task,
41 | data_info,
42 | k_sim,
43 | num_threads,
44 | min_common,
45 | mode,
46 | seed,
47 | lower_upper_bound,
48 | )
49 | self.all_args = locals()
50 |
--------------------------------------------------------------------------------
/libreco/bases/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import Base
2 | from .cf_base import CfBase
3 | from .cf_base_rs import RsCfBase
4 | from .dyn_embed_base import DynEmbedBase
5 | from .embed_base import EmbedBase
6 | from .gensim_base import GensimBase
7 | from .meta import ModelMeta
8 | from .sage_base import SageBase
9 | from .tf_base import TfBase
10 |
11 | __all__ = [
12 | "Base",
13 | "CfBase",
14 | "RsCfBase",
15 | "DynEmbedBase",
16 | "EmbedBase",
17 | "GensimBase",
18 | "ModelMeta",
19 | "SageBase",
20 | "TfBase",
21 | ]
22 |
--------------------------------------------------------------------------------
/libreco/bases/meta.py:
--------------------------------------------------------------------------------
1 | from abc import ABCMeta
2 |
3 | from ..tfops import rebuild_tf_model
4 | from ..torchops import rebuild_torch_model
5 |
6 |
7 | class ModelMeta(ABCMeta):
8 | def __new__(mcs, cls_name, bases, cls_dict, **kwargs):
9 | backend = kwargs["backend"] if "backend" in kwargs else "none"
10 | if bases[0].__name__ == "TfBase" or backend == "tensorflow":
11 | cls_dict["rebuild_model"] = rebuild_tf_model
12 | elif backend == "torch":
13 | cls_dict["rebuild_model"] = rebuild_torch_model
14 | return super().__new__(mcs, cls_name, bases, cls_dict)
15 |
--------------------------------------------------------------------------------
/libreco/batch/__init__.py:
--------------------------------------------------------------------------------
1 | from .batch_data import adjust_batch_size, get_batch_loader
2 | from .tf_feed_dicts import get_tf_feeds
3 |
4 | __all__ = ["adjust_batch_size", "get_batch_loader", "get_tf_feeds"]
5 |
--------------------------------------------------------------------------------
/libreco/batch/enums.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 |
4 | class FeatType(Enum):
5 | SPARSE = "sparse"
6 | DENSE = "dense"
7 |
8 |
9 | class Backend(Enum):
10 | TF = "tensorflow"
11 | TORCH = "torch"
12 |
--------------------------------------------------------------------------------
/libreco/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .data_info import DataInfo, MultiSparseInfo
2 | from .dataset import DatasetFeat, DatasetPure
3 | from .processing import process_data, split_multi_value
4 | from .split import (
5 | random_split,
6 | split_by_num,
7 | split_by_num_chrono,
8 | split_by_ratio,
9 | split_by_ratio_chrono,
10 | )
11 | from .transformed import TransformedEvalSet, TransformedSet
12 |
13 | __all__ = [
14 | "DatasetPure",
15 | "DatasetFeat",
16 | "DataInfo",
17 | "MultiSparseInfo",
18 | "process_data",
19 | "split_multi_value",
20 | "split_by_num",
21 | "split_by_ratio",
22 | "split_by_num_chrono",
23 | "split_by_ratio_chrono",
24 | "random_split",
25 | "TransformedSet",
26 | "TransformedEvalSet",
27 | ]
28 |
--------------------------------------------------------------------------------
/libreco/data/consumed.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from collections import OrderedDict, defaultdict
3 |
4 | import numpy as np
5 |
6 |
7 | def interaction_consumed(user_indices, item_indices):
8 | """The underlying rust function will remove consecutive repeated elements."""
9 | if isinstance(user_indices, np.ndarray):
10 | user_indices = user_indices.tolist()
11 | if isinstance(item_indices, np.ndarray):
12 | item_indices = item_indices.tolist()
13 |
14 | try:
15 | from recfarm import build_consumed_unique
16 |
17 | return build_consumed_unique(user_indices, item_indices)
18 | except ModuleNotFoundError: # pragma: no cover
19 | return _interaction_consumed(user_indices, item_indices)
20 |
21 |
22 | def _interaction_consumed(user_indices, item_indices): # pragma: no cover
23 | user_consumed = defaultdict(list)
24 | item_consumed = defaultdict(list)
25 | for u, i in zip(user_indices, item_indices):
26 | user_consumed[u].append(i)
27 | item_consumed[i].append(u)
28 | return _remove_duplicates(user_consumed, item_consumed)
29 |
30 |
31 | def _remove_duplicates(user_consumed, item_consumed): # pragma: no cover
32 | # keys will preserve order in dict since Python3.7
33 | if sys.version_info[:2] >= (3, 7):
34 | dict_func = dict.fromkeys
35 | else: # pragma: no cover
36 | dict_func = OrderedDict.fromkeys
37 | user_dedup = {u: list(dict_func(items)) for u, items in user_consumed.items()}
38 | item_dedup = {i: list(dict_func(users)) for i, users in item_consumed.items()}
39 | return user_dedup, item_dedup
40 |
41 |
42 | def update_consumed(
43 | user_indices, item_indices, n_users, n_items, old_info, merge_behavior
44 | ):
45 | user_consumed, item_consumed = interaction_consumed(user_indices, item_indices)
46 | if merge_behavior:
47 | user_consumed = _merge_dedup(user_consumed, n_users, old_info.user_consumed)
48 | item_consumed = _merge_dedup(item_consumed, n_items, old_info.item_consumed)
49 | else:
50 | user_consumed = _fill_empty(user_consumed, n_users, old_info.user_consumed)
51 | item_consumed = _fill_empty(item_consumed, n_items, old_info.item_consumed)
52 | return user_consumed, item_consumed
53 |
54 |
55 | def _merge_dedup(new_consumed, num, old_consumed):
56 | result = dict()
57 | for i in range(num):
58 | assert i in new_consumed or i in old_consumed
59 | if i in new_consumed and i in old_consumed:
60 | result[i] = old_consumed[i] + new_consumed[i]
61 | else:
62 | result[i] = new_consumed[i] if i in new_consumed else old_consumed[i]
63 | return result
64 |
65 |
66 | # some users may not appear in new data
67 | def _fill_empty(consumed, num, old_consumed):
68 | return {i: consumed[i] if i in consumed else old_consumed[i] for i in range(num)}
69 |
70 |
71 | # def _remove_first_duplicates(consumed):
72 | # return pd.Series(consumed).drop_duplicates(keep="last").tolist()
73 |
--------------------------------------------------------------------------------
/libreco/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | from .evaluate import evaluate, print_metrics
2 |
3 | __all__ = ["evaluate", "print_metrics"]
4 |
--------------------------------------------------------------------------------
/libreco/evaluation/computation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | from tqdm import tqdm
4 |
5 | from ..data import TransformedEvalSet
6 | from ..prediction.preprocess import convert_id
7 | from ..utils.validate import check_labels
8 |
9 |
10 | def build_eval_transformed_data(model, data, neg_sampling, seed):
11 | if isinstance(data, pd.DataFrame):
12 | assert "user" in data and "item" in data and "label" in data
13 | users = data["user"].tolist()
14 | items = data["item"].tolist()
15 | user_indices, item_indices = convert_id(model, users, items, inner_id=False)
16 | labels = data["label"].to_numpy(dtype=np.float32)
17 | data = TransformedEvalSet(user_indices, item_indices, labels)
18 | if neg_sampling and not data.has_sampled:
19 | num_neg = model.num_neg or 1 if hasattr(model, "num_neg") else 1
20 | data.build_negatives(model.n_items, num_neg, seed=seed)
21 | else:
22 | check_labels(model, data.labels, neg_sampling)
23 | return data
24 |
25 |
26 | def compute_preds(model, data, batch_size):
27 | y_pred = list()
28 | y_label = list()
29 | for i in tqdm(range(0, len(data), batch_size), desc="eval_pointwise"):
30 | user_indices, item_indices, labels = data[i : i + batch_size]
31 | preds = model.predict(user_indices, item_indices, inner_id=True)
32 | y_pred.extend(preds)
33 | y_label.extend(labels)
34 | return y_pred, y_label
35 |
36 |
37 | def compute_probs(model, data, batch_size):
38 | return compute_preds(model, data, batch_size)
39 |
40 |
41 | def compute_recommends(model, users, k, num_batch_users):
42 | y_recommends = dict()
43 | for i in tqdm(range(0, len(users), num_batch_users), desc="eval_listwise"):
44 | batch_users = users[i : i + num_batch_users]
45 | batch_recs = model.recommend_user(
46 | user=batch_users,
47 | n_rec=k,
48 | inner_id=True,
49 | filter_consumed=True,
50 | random_rec=False,
51 | )
52 | y_recommends.update(batch_recs)
53 | return y_recommends
54 |
--------------------------------------------------------------------------------
/libreco/feature/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/massquantity/LibRecommender/9eadf8a3d6901a898630da113b92de48bd2897fb/libreco/feature/__init__.py
--------------------------------------------------------------------------------
/libreco/feature/column_mapping.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict, defaultdict
2 |
3 | import numpy as np
4 |
5 |
6 | # format: {column_family_name: {column_name: index}}
7 | # if no such family exists, it will simply not appear in the returned dict
8 | def col_name2index(user_col=None, item_col=None, sparse_col=None, dense_col=None):
9 | name_mapping = defaultdict(OrderedDict)
10 | if sparse_col:
11 | sparse_col_dict = {col: i for i, col in enumerate(sparse_col)}
12 | name_mapping["sparse_col"].update(sparse_col_dict)
13 | if dense_col:
14 | dense_col_dict = {col: i for i, col in enumerate(dense_col)}
15 | name_mapping["dense_col"].update(dense_col_dict)
16 |
17 | if user_col and sparse_col:
18 | user_sparse_col = _extract_common_col(sparse_col, user_col)
19 | for col in user_sparse_col:
20 | name_mapping["user_sparse_col"].update(
21 | {col: name_mapping["sparse_col"][col]}
22 | )
23 | if user_col and dense_col:
24 | user_dense_col = _extract_common_col(dense_col, user_col)
25 | for col in user_dense_col:
26 | name_mapping["user_dense_col"].update({col: name_mapping["dense_col"][col]})
27 |
28 | if item_col and sparse_col:
29 | item_sparse_col = _extract_common_col(sparse_col, item_col)
30 | for col in item_sparse_col:
31 | name_mapping["item_sparse_col"].update(
32 | {col: name_mapping["sparse_col"][col]}
33 | )
34 | if item_col and dense_col:
35 | item_dense_col = _extract_common_col(dense_col, item_col)
36 | for col in item_dense_col:
37 | name_mapping["item_dense_col"].update({col: name_mapping["dense_col"][col]})
38 |
39 | return dict(name_mapping)
40 |
41 |
42 | # `np.intersect1d` will return the sorted common column names,
43 | # but we also want to preserve the original order of common column in col1 and col2
44 | def _extract_common_col(col1, col2):
45 | common_col, indices_in_col1, _ = np.intersect1d(
46 | col1, col2, assume_unique=True, return_indices=True
47 | )
48 | return common_col[np.lexsort((common_col, indices_in_col1))]
49 |
--------------------------------------------------------------------------------
/libreco/feature/ssl.py:
--------------------------------------------------------------------------------
1 | """Feature Generation for Self-Supervised Learning."""
2 | import numpy as np
3 | from sklearn.metrics import mutual_info_score
4 |
5 |
6 | def get_ssl_features(model, batch_size):
7 | ssl_feats = dict()
8 | rng, n_items = model.data_info.np_rng, model.n_items
9 | replace = False if batch_size < n_items else True
10 | item_indices = rng.choice(n_items, size=batch_size, replace=replace)
11 | feat_indices = model.data_info.item_sparse_unique[item_indices]
12 | # add offset since default embedding has 0 index
13 | sparse_indices = np.hstack(
14 | [np.expand_dims(item_indices + 1, 1), feat_indices + n_items + 1]
15 | )
16 | feat_num = sparse_indices.shape[1]
17 | mid_point = feat_num // 2
18 | if model.ssl_pattern.startswith("cfm"):
19 | seed_col = rng.integers(feat_num)
20 | left_cols = model.sparse_feat_mutual_info[seed_col]
21 | right_cols = np.setdiff1d(range(feat_num), left_cols)
22 | elif model.ssl_pattern.endswith("complementary"):
23 | random_cols = rng.permutation(feat_num)
24 | left_cols, right_cols = np.split(random_cols, [mid_point])
25 | else:
26 | left_cols = rng.permutation(feat_num)[:mid_point]
27 | right_cols = rng.permutation(feat_num)[:mid_point]
28 |
29 | left_sparse_indices = sparse_indices.copy()
30 | left_sparse_indices[:, left_cols] = 0
31 | ssl_feats.update({model.ssl_left_sparse_indices: left_sparse_indices})
32 | right_sparse_indices = sparse_indices.copy()
33 | right_sparse_indices[:, right_cols] = 0
34 | ssl_feats.update({model.ssl_right_sparse_indices: right_sparse_indices})
35 |
36 | if model.item_dense:
37 | dense_values = model.data_info.item_dense_unique[item_indices]
38 | ssl_feats.update({model.ssl_left_dense_values: dense_values})
39 | ssl_feats.update({model.ssl_right_dense_values: dense_values})
40 | return ssl_feats
41 |
42 |
43 | def get_mutual_info(data, data_info):
44 | """Compute mutual information for each pair of item sparse features."""
45 | item_indices = np.expand_dims(data.item_indices, 1)
46 | feat_indices = data.sparse_indices[:, data_info.item_sparse_col.index]
47 | sparse_indices = np.hstack([item_indices, feat_indices])
48 | feat_num = sparse_indices.shape[1]
49 | pairwise_mutual_info = np.zeros((feat_num, feat_num))
50 | # assign self mutual info to impossible value
51 | np.fill_diagonal(pairwise_mutual_info, -1)
52 | for i in range(feat_num):
53 | for j in range(i + 1, feat_num):
54 | mi = mutual_info_score(sparse_indices[:, i], sparse_indices[:, j])
55 | pairwise_mutual_info[i][j] = pairwise_mutual_info[j][i] = mi
56 |
57 | n = feat_num // 2
58 | topn_mutual_info = np.argsort(pairwise_mutual_info, axis=1)[:, -n:]
59 | return {i: topn_mutual_info[i] for i in range(feat_num)}
60 |
--------------------------------------------------------------------------------
/libreco/feature/unique.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def construct_unique_feat(
5 | user_indices,
6 | item_indices,
7 | sparse_indices,
8 | dense_values,
9 | col_name_mapping,
10 | unique_feat,
11 | ):
12 | # use stable mergesort to preserve original order when features are not unique
13 | sort_kind = "quicksort" if unique_feat else "mergesort"
14 | user_pos = np.argsort(user_indices, kind=sort_kind)
15 | item_pos = np.argsort(item_indices, kind=sort_kind)
16 |
17 | user_sparse_matrix, item_sparse_matrix = None, None
18 | user_dense_matrix, item_dense_matrix = None, None
19 | if "user_sparse_col" in col_name_mapping:
20 | user_sparse_col = list(col_name_mapping["user_sparse_col"].values())
21 | user_sparse_matrix = _compress_unique_values(
22 | sparse_indices, user_sparse_col, user_indices, user_pos
23 | )
24 | if "item_sparse_col" in col_name_mapping:
25 | item_sparse_col = list(col_name_mapping["item_sparse_col"].values())
26 | item_sparse_matrix = _compress_unique_values(
27 | sparse_indices, item_sparse_col, item_indices, item_pos
28 | )
29 | if "user_dense_col" in col_name_mapping:
30 | user_dense_col = list(col_name_mapping["user_dense_col"].values())
31 | user_dense_matrix = _compress_unique_values(
32 | dense_values, user_dense_col, user_indices, user_pos
33 | )
34 | if "item_dense_col" in col_name_mapping:
35 | item_dense_col = list(col_name_mapping["item_dense_col"].values())
36 | item_dense_matrix = _compress_unique_values(
37 | dense_values, item_dense_col, item_indices, item_pos
38 | )
39 | return (
40 | user_sparse_matrix,
41 | user_dense_matrix,
42 | item_sparse_matrix,
43 | item_dense_matrix,
44 | )
45 |
46 |
47 | # https://stackoverflow.com/questions/46390376/drop-duplicates-from-structured-numpy-array-python3-x
48 | def _compress_unique_values(orig_val, col, indices, pos):
49 | values = np.take(orig_val, col, axis=1)
50 | values = values.reshape(-1, 1) if orig_val.ndim == 1 else values
51 | indices = indices[pos]
52 | mask = np.empty(len(indices), dtype=bool)
53 | mask[:-1] = indices[:-1] != indices[1:]
54 | mask[-1] = True
55 | mask = pos[mask]
56 | unique_values = values[mask]
57 | assert len(np.unique(indices)) == len(unique_values)
58 | return unique_values
59 |
--------------------------------------------------------------------------------
/libreco/graph/__init__.py:
--------------------------------------------------------------------------------
1 | from .from_dgl import (
2 | build_i2i_homo_graph,
3 | build_subgraphs,
4 | build_u2i_hetero_graph,
5 | check_dgl,
6 | compute_i2i_edge_scores,
7 | compute_u2i_edge_scores,
8 | pairs_from_dgl_graph,
9 | )
10 | from .neighbor_walk import NeighborWalker, NeighborWalkerDGL
11 |
12 | __all__ = [
13 | "build_i2i_homo_graph",
14 | "build_subgraphs",
15 | "build_u2i_hetero_graph",
16 | "check_dgl",
17 | "compute_i2i_edge_scores",
18 | "compute_u2i_edge_scores",
19 | "pairs_from_dgl_graph",
20 | "NeighborWalker",
21 | "NeighborWalkerDGL",
22 | ]
23 |
--------------------------------------------------------------------------------
/libreco/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .attention import (
2 | compute_causal_mask,
3 | compute_seq_mask,
4 | din_attention,
5 | multi_head_attention,
6 | tf_attention,
7 | )
8 | from .convolutional import conv_nn, max_pool
9 | from .dense import dense_nn, shared_dense, tf_dense
10 | from .embedding import embedding_lookup, seq_embeds_pooling, sparse_embeds_pooling
11 | from .normalization import layer_normalization, normalize_embeds, rms_norm
12 | from .recurrent import tf_rnn
13 |
14 | __all__ = [
15 | "compute_causal_mask",
16 | "compute_seq_mask",
17 | "conv_nn",
18 | "dense_nn",
19 | "din_attention",
20 | "embedding_lookup",
21 | "layer_normalization",
22 | "max_pool",
23 | "multi_head_attention",
24 | "normalize_embeds",
25 | "rms_norm",
26 | "shared_dense",
27 | "seq_embeds_pooling",
28 | "sparse_embeds_pooling",
29 | "tf_attention",
30 | "tf_dense",
31 | "tf_rnn",
32 | ]
33 |
--------------------------------------------------------------------------------
/libreco/layers/activation.py:
--------------------------------------------------------------------------------
1 | from ..tfops import tf
2 |
3 |
4 | def gelu(x):
5 | return 0.5 * x * (1.0 + tf.math.erf(x / tf.cast(1.4142135623730951, tf.float32)))
6 |
7 |
8 | def swish(x):
9 | return x * tf.sigmoid(x)
10 |
--------------------------------------------------------------------------------
/libreco/layers/convolutional.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | from ..tfops import get_tf_version, tf
4 |
5 |
6 | def conv_nn(
7 | filters, kernel_size, strides, padding, activation, dilation_rate=1, version=None
8 | ):
9 | tf_version = get_tf_version(version)
10 | if tf_version >= "2.0.0":
11 | net = tf.keras.layers.Conv1D(
12 | filters=filters,
13 | kernel_size=kernel_size,
14 | strides=strides,
15 | padding=padding,
16 | activation=activation,
17 | dilation_rate=dilation_rate,
18 | )
19 | else:
20 | net = partial(
21 | tf.layers.conv1d,
22 | filters=filters,
23 | kernel_size=kernel_size,
24 | strides=strides,
25 | padding=padding,
26 | activation=activation,
27 | )
28 | return net
29 |
30 |
31 | def max_pool(pool_size, strides, padding, version=None):
32 | tf_version = get_tf_version(version)
33 | if tf_version >= "2.0.0":
34 | net = tf.keras.layers.MaxPool1D(
35 | pool_size=pool_size, strides=strides, padding=padding
36 | )
37 | else:
38 | net = partial(
39 | tf.layers.max_pooling1d,
40 | pool_size=pool_size,
41 | strides=strides,
42 | padding=padding,
43 | )
44 | return net
45 |
--------------------------------------------------------------------------------
/libreco/layers/embedding.py:
--------------------------------------------------------------------------------
1 | from ..tfops import tf
2 |
3 |
4 | def embedding_lookup(
5 | indices,
6 | var_name=None,
7 | var_shape=None,
8 | initializer=None,
9 | regularizer=None,
10 | reuse_layer=None,
11 | embed_var=None,
12 | scope_name="embedding",
13 | ):
14 | reuse = tf.AUTO_REUSE if reuse_layer else None
15 | with tf.variable_scope(scope_name, reuse=reuse):
16 | if embed_var is None:
17 | embed_var = tf.get_variable(
18 | name=var_name,
19 | shape=var_shape,
20 | initializer=initializer,
21 | regularizer=regularizer,
22 | )
23 | return tf.nn.embedding_lookup(embed_var, indices)
24 |
25 |
26 | def sparse_embeds_pooling(
27 | sparse_indices,
28 | var_name,
29 | var_shape,
30 | initializer,
31 | regularizer=None,
32 | reuse_layer=None,
33 | combiner="sqrtn",
34 | scope_name="sparse_embeds_pooling",
35 | ):
36 | reuse = tf.AUTO_REUSE if reuse_layer else None
37 | with tf.variable_scope(scope_name, reuse=reuse):
38 | embed_var = tf.get_variable(
39 | name=var_name,
40 | shape=var_shape,
41 | initializer=initializer,
42 | regularizer=regularizer,
43 | )
44 | # unknown user will return 0-vector in `safe_embedding_lookup_sparse`
45 | return tf.nn.safe_embedding_lookup_sparse(
46 | embed_var,
47 | sparse_indices,
48 | sparse_weights=None,
49 | combiner=combiner,
50 | default_id=None,
51 | )
52 |
53 |
54 | def seq_embeds_pooling(
55 | seq_indices,
56 | seq_lens,
57 | n_items,
58 | var_name,
59 | var_shape,
60 | initializer=None,
61 | regularizer=None,
62 | reuse_layer=None,
63 | scope_name="seq_embeds_pooling",
64 | ):
65 | reuse = tf.AUTO_REUSE if reuse_layer else None
66 | with tf.variable_scope(scope_name, reuse=reuse):
67 | embed_var = tf.get_variable(
68 | name=var_name,
69 | shape=var_shape,
70 | initializer=initializer,
71 | regularizer=regularizer,
72 | )
73 | # unknown items are padded to 0-vector
74 | embed_size = var_shape[1]
75 | zero_padding_op = tf.scatter_update(
76 | embed_var, n_items, tf.zeros([embed_size], dtype=tf.float32)
77 | )
78 | with tf.control_dependencies([zero_padding_op]):
79 | # B * seq * K
80 | multi_item_embed = tf.nn.embedding_lookup(embed_var, seq_indices)
81 |
82 | return tf.div_no_nan(
83 | tf.reduce_sum(multi_item_embed, axis=1),
84 | tf.expand_dims(tf.sqrt(tf.cast(seq_lens, tf.float32)), axis=1),
85 | )
86 |
--------------------------------------------------------------------------------
/libreco/layers/normalization.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch.linalg
3 |
4 | from ..tfops import tf
5 |
6 |
7 | def layer_normalization(inputs, reuse_layer=False, scope_name="layer_norm"):
8 | reuse = tf.AUTO_REUSE if reuse_layer else None
9 | with tf.variable_scope(scope_name, reuse=reuse):
10 | dim = inputs.get_shape().as_list()[-1]
11 | scale = tf.get_variable("scale", shape=[dim], initializer=tf.ones_initializer())
12 | bias = tf.get_variable("bias", shape=[dim], initializer=tf.zeros_initializer())
13 | mean = tf.reduce_mean(inputs, axis=-1, keepdims=True)
14 | variance = tf.reduce_mean(
15 | tf.squared_difference(inputs, mean), axis=-1, keepdims=True
16 | )
17 | outputs = (inputs - mean) * tf.rsqrt(variance + 1e-8)
18 | return outputs * scale + bias
19 |
20 |
21 | def rms_norm(inputs, reuse_layer=False, scope_name="rms_norm"):
22 | """Root mean square layer normalization."""
23 | reuse = tf.AUTO_REUSE if reuse_layer else None
24 | with tf.variable_scope(scope_name, reuse=reuse):
25 | dim = inputs.get_shape().as_list()[-1]
26 | scale = tf.get_variable("scale", shape=[dim], initializer=tf.ones_initializer())
27 | mean_square = tf.reduce_mean(tf.square(inputs), axis=-1, keepdims=True)
28 | outputs = inputs * tf.rsqrt(mean_square + 1e-8)
29 | return outputs * scale
30 |
31 |
32 | def normalize_embeds(*embeds, backend):
33 | normed_embeds = []
34 | for e in embeds:
35 | if backend == "tf":
36 | ne = tf.linalg.l2_normalize(e, axis=1)
37 | elif backend == "torch":
38 | norms = torch.linalg.norm(e, dim=1, keepdim=True)
39 | ne = e / norms
40 | else:
41 | norms = np.linalg.norm(e, axis=1, keepdims=True)
42 | ne = e / norms
43 | normed_embeds.append(ne)
44 | return normed_embeds[0] if len(embeds) == 1 else normed_embeds
45 |
--------------------------------------------------------------------------------
/libreco/layers/recurrent.py:
--------------------------------------------------------------------------------
1 | from ..tfops import get_tf_version, tf
2 |
3 |
4 | def tf_rnn(
5 | inputs,
6 | rnn_type,
7 | lengths,
8 | maxlen,
9 | hidden_units,
10 | dropout_rate,
11 | use_ln,
12 | is_training,
13 | version=None,
14 | ):
15 | tf_version = get_tf_version(version)
16 | if tf_version >= "2.0.0":
17 | # cell_type = (
18 | # tf.keras.layers.LSTMCell
19 | # if self.rnn_type.endswith("lstm")
20 | # else tf.keras.layers.GRUCell
21 | # )
22 | # cells = [cell_type(size) for size in self.hidden_units]
23 | # masks = tf.sequence_mask(self.user_interacted_len, self.max_seq_len)
24 | # tf2_rnn = tf.keras.layers.RNN(cells, return_state=True)
25 | # output, *state = tf2_rnn(seq_item_embed, mask=masks)
26 |
27 | rnn_layer = (
28 | tf.keras.layers.LSTM if rnn_type.endswith("lstm") else tf.keras.layers.GRU
29 | )
30 | output = inputs
31 | masks = tf.sequence_mask(lengths, maxlen)
32 | for units in hidden_units:
33 | output = rnn_layer(
34 | units,
35 | return_sequences=True,
36 | dropout=dropout_rate,
37 | recurrent_dropout=dropout_rate,
38 | activation=None if use_ln else "tanh",
39 | )(output, mask=masks, training=is_training)
40 |
41 | if use_ln:
42 | output = tf.keras.layers.LayerNormalization()(output)
43 | output = tf.keras.activations.get("tanh")(output)
44 |
45 | return output[:, -1, :]
46 |
47 | else:
48 | cell_type = (
49 | tf.nn.rnn_cell.LSTMCell
50 | if rnn_type.endswith("lstm")
51 | else tf.nn.rnn_cell.GRUCell
52 | )
53 | cells = [cell_type(size) for size in hidden_units]
54 | stacked_cells = tf.nn.rnn_cell.MultiRNNCell(cells)
55 | zero_state = stacked_cells.zero_state(tf.shape(inputs)[0], dtype=tf.float32)
56 | _, state = tf.nn.dynamic_rnn(
57 | cell=stacked_cells,
58 | inputs=inputs,
59 | sequence_length=lengths,
60 | initial_state=zero_state,
61 | time_major=False,
62 | )
63 | return state[-1][1] if rnn_type == "lstm" else state[-1]
64 |
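Both branches of `tf_rnn` end up returning the hidden state at the last valid step of each sequence: the TF2 path relies on the Keras mask to carry state through padded steps before taking `output[:, -1, :]`, while the TF1 path passes `sequence_length` to `dynamic_rnn`. A toy NumPy sketch of that "last valid step" selection, with made-up outputs:

import numpy as np

# hypothetical RNN outputs: batch 2, max length 4, hidden size 3
outputs = np.arange(24, dtype=np.float32).reshape(2, 4, 3)
lengths = np.array([2, 4])

last_valid = outputs[np.arange(len(lengths)), lengths - 1]
print(last_valid.shape)  # (2, 3): rows taken at steps 1 and 3 respectively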
--------------------------------------------------------------------------------
/libreco/prediction/__init__.py:
--------------------------------------------------------------------------------
1 | from .predict import (
2 | normalize_prediction,
3 | predict_data_with_feats,
4 | predict_from_embedding,
5 | predict_tf_feat,
6 | )
7 |
8 | __all__ = [
9 | "predict_data_with_feats",
10 | "predict_from_embedding",
11 | "predict_tf_feat",
12 | "normalize_prediction",
13 | ]
14 |
--------------------------------------------------------------------------------
/libreco/recommendation/__init__.py:
--------------------------------------------------------------------------------
1 | from .cold_start import cold_start_rec, popular_recommendations
2 | from .ranking import rank_recommendations
3 | from .recommend import (
4 | check_dynamic_rec_feats,
5 | construct_rec,
6 | recommend_from_embedding,
7 | recommend_tf_feat,
8 | )
9 |
10 | __all__ = [
11 | "check_dynamic_rec_feats",
12 | "cold_start_rec",
13 | "construct_rec",
14 | "popular_recommendations",
15 | "rank_recommendations",
16 | "recommend_from_embedding",
17 | "recommend_tf_feat",
18 | ]
19 |
--------------------------------------------------------------------------------
/libreco/recommendation/cold_start.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def popular_recommendations(data_info, inner_id, n_rec):
5 | popular_recs = data_info.np_rng.choice(data_info.popular_items, n_rec)
6 | if inner_id:
7 | return np.array([data_info.item2id[i] for i in popular_recs])
8 | else:
9 | return popular_recs
10 |
11 |
12 | def average_recommendations(data_info, default_recs, inner_id, n_rec):
13 | average_recs = data_info.np_rng.choice(default_recs, n_rec)
14 | if inner_id:
15 | return average_recs
16 | else:
17 | return np.array([data_info.id2item[i] for i in average_recs])
18 |
19 |
20 | def cold_start_rec(data_info, default_recs, cold_start, users, n_rec, inner_id):
21 | if cold_start not in ("average", "popular"):
22 | raise ValueError(f"Unknown cold start strategy: {cold_start}")
23 | result_recs = dict()
24 | for u in users:
25 | if cold_start == "average":
26 | result_recs[u] = average_recommendations(
27 | data_info, default_recs, inner_id, n_rec
28 | )
29 | elif cold_start == "popular":
30 | result_recs[u] = popular_recommendations(data_info, inner_id, n_rec)
31 | return result_recs
32 |
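A short usage sketch of `cold_start_rec`, with a hypothetical stand-in for the `data_info` object (only the attributes touched above are provided):

import numpy as np
from types import SimpleNamespace

from libreco.recommendation import cold_start_rec

data_info = SimpleNamespace(
    np_rng=np.random.default_rng(0),
    popular_items=["a", "b", "c", "d"],
    item2id={"a": 0, "b": 1, "c": 2, "d": 3},
    id2item={0: "a", 1: "b", 2: "c", 3: "d"},
)

recs = cold_start_rec(
    data_info,
    default_recs=np.array([0, 1, 2, 3]),
    cold_start="popular",
    users=["new_user_1", "new_user_2"],
    n_rec=2,
    inner_id=False,
)
print(recs)  # e.g. {'new_user_1': array(['d', 'a'], dtype='<U1'), 'new_user_2': ...}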
--------------------------------------------------------------------------------
/libreco/recommendation/ranking.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from numpy.random import default_rng
3 | from scipy.special import expit, softmax
4 |
5 | # NumPy docs recommend using the new random Generator API
6 | # https://numpy.org/doc/stable/reference/random/index.html
7 | np_rng = default_rng()
8 |
9 |
10 | def rank_recommendations(
11 | task,
12 | user_ids,
13 | model_preds,
14 | n_rec,
15 | n_items,
16 | user_consumed,
17 | filter_consumed=True,
18 | random_rec=False,
19 | return_scores=False,
20 | ):
21 | if n_rec > n_items:
22 | raise ValueError(f"`n_rec` {n_rec} exceeds num of items {n_items}")
23 | if model_preds.ndim == 1:
24 | assert len(model_preds) % n_items == 0
25 | batch_size = int(len(model_preds) / n_items)
26 | all_preds = model_preds.reshape(batch_size, n_items)
27 | else:
28 | batch_size = len(model_preds)
29 | all_preds = model_preds
30 | all_ids = np.tile(np.arange(n_items), (batch_size, 1))
31 |
32 | batch_ids, batch_preds = [], []
33 | for i in range(batch_size):
34 | user = user_ids[i]
35 | ids = all_ids[i]
36 | preds = all_preds[i]
37 | consumed = user_consumed[user] if user in user_consumed else []
38 | if filter_consumed and consumed and n_rec + len(consumed) <= n_items:
39 | ids, preds = filter_items(ids, preds, consumed)
40 | if random_rec:
41 | ids, preds = random_select(ids, preds, n_rec)
42 | else:
43 | ids, preds = partition_select(ids, preds, n_rec)
44 | batch_ids.append(ids)
45 | batch_preds.append(preds)
46 |
47 | ids, preds = np.array(batch_ids), np.array(batch_preds)
48 | indices = np.argsort(preds, axis=1)[:, ::-1]
49 | ids = np.take_along_axis(ids, indices, axis=1)
50 | if return_scores:
51 | scores = np.take_along_axis(preds, indices, axis=1)
52 | if task == "ranking":
53 | scores = expit(scores)
54 | return ids, scores
55 | else:
56 | return ids
57 |
58 |
59 | def filter_items(ids, preds, items):
60 | mask = np.isin(ids, items, assume_unique=True, invert=True)
61 | return ids[mask], preds[mask]
62 |
63 |
64 | # apply `**0.75` to flatten the distribution and lower the probability of high-score items
65 | def get_reco_probs(preds):
66 | p = np.power(softmax(preds), 0.75) + 1e-8 # avoid zero probs
67 | return p / p.sum()
68 |
69 |
70 | def random_select(ids, preds, n_rec):
71 | p = get_reco_probs(preds)
72 | mask = np_rng.choice(len(preds), n_rec, p=p, replace=False, shuffle=False)
73 | return ids[mask], preds[mask]
74 |
75 |
76 | def partition_select(ids, preds, n_rec):
77 | mask = np.argpartition(preds, -n_rec)[-n_rec:]
78 | return ids[mask], preds[mask]
79 |
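A NumPy sketch contrasting the two selection paths above on toy scores: `partition_select` keeps the top `n_rec` predictions deterministically, while `random_select` samples `n_rec` items from the flattened `softmax(preds) ** 0.75` distribution. Values are illustrative only.

import numpy as np
from scipy.special import softmax

rng = np.random.default_rng(7)
ids = np.arange(6)
preds = np.array([0.1, 2.0, -1.0, 3.5, 0.7, 1.2])
n_rec = 3

# deterministic: indices of the 3 largest predictions (final order fixed by argsort later)
top_mask = np.argpartition(preds, -n_rec)[-n_rec:]
print(ids[top_mask], preds[top_mask])

# stochastic: the 0.75 exponent flattens the distribution so top items dominate less
p = np.power(softmax(preds), 0.75) + 1e-8
p /= p.sum()
sampled = rng.choice(len(preds), n_rec, p=p, replace=False, shuffle=False)
print(ids[sampled], preds[sampled])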
--------------------------------------------------------------------------------
/libreco/sampling/__init__.py:
--------------------------------------------------------------------------------
1 | from .negatives import (
2 | neg_probs_from_frequency,
3 | negatives_from_out_batch,
4 | negatives_from_popular,
5 | negatives_from_random,
6 | negatives_from_unconsumed,
7 | pos_probs_from_frequency,
8 | )
9 | from .random_walks import (
10 | bipartite_neighbors,
11 | bipartite_neighbors_with_weights,
12 | pairs_from_random_walk,
13 | )
14 |
15 | __all__ = [
16 | "bipartite_neighbors",
17 | "bipartite_neighbors_with_weights",
18 | "negatives_from_out_batch",
19 | "negatives_from_popular",
20 | "negatives_from_random",
21 | "negatives_from_unconsumed",
22 | "neg_probs_from_frequency",
23 | "pairs_from_random_walk",
24 | "pos_probs_from_frequency",
25 | ]
26 |
--------------------------------------------------------------------------------
/libreco/tfops/__init__.py:
--------------------------------------------------------------------------------
1 | from .configs import (
2 | attention_config,
3 | dropout_config,
4 | lr_decay_config,
5 | reg_config,
6 | sess_config,
7 | )
8 | from .loss import choose_tf_loss
9 | from .rebuild import rebuild_tf_model
10 | from .variables import get_variable_from_graph, modify_variable_names, var_list_by_name
11 | from .version import TF_VERSION, get_tf_version, tf
12 |
13 | __all__ = [
14 | "attention_config",
15 | "dropout_config",
16 | "get_variable_from_graph",
17 | "lr_decay_config",
18 | "reg_config",
19 | "sess_config",
20 | "rebuild_tf_model",
21 | "choose_tf_loss",
22 | "modify_variable_names",
23 | "var_list_by_name",
24 | "tf",
25 | "TF_VERSION",
26 | "get_tf_version",
27 | ]
28 |
--------------------------------------------------------------------------------
/libreco/tfops/configs.py:
--------------------------------------------------------------------------------
1 | import multiprocessing
2 |
3 | from .version import tf
4 |
5 |
6 | def attention_config(att_embed_size):
7 | if not att_embed_size:
8 | att_embed_size = (8, 8, 8)
9 | att_layer_num = 3
10 | elif isinstance(att_embed_size, int):
11 | att_embed_size = [att_embed_size]
12 | att_layer_num = 1
13 | elif isinstance(att_embed_size, (list, tuple)):
14 | att_layer_num = len(att_embed_size)
15 | else:
16 | raise ValueError("att_embed_size must be int or list")
17 | return att_embed_size, att_layer_num
18 |
19 |
20 | def reg_config(reg):
21 | if not reg:
22 | return None
23 | elif isinstance(reg, float) and reg > 0.0:
24 | return tf.keras.regularizers.l2(reg)
25 | else:
26 | raise ValueError("reg must be float and positive...")
27 |
28 |
29 | def dropout_config(dropout_rate):
30 | if not dropout_rate:
31 | return 0.0
32 | elif dropout_rate <= 0.0 or dropout_rate >= 1.0:
33 | raise ValueError("dropout_rate must be in (0.0, 1.0)")
34 | else:
35 | return dropout_rate
36 |
37 |
38 | def lr_decay_config(initial_lr, default_decay_steps, **kwargs):
39 | decay_steps = kwargs.get("decay_steps", default_decay_steps)
40 | decay_rate = kwargs.get("decay_rate", 0.96)
41 | global_steps = tf.Variable(0, trainable=False, name="global_steps")
42 | learning_rate = tf.train.exponential_decay(
43 | initial_lr, global_steps, decay_steps, decay_rate, staircase=True
44 | )
45 | return learning_rate, global_steps
46 |
47 |
48 | def sess_config(tf_sess_config=None):
49 | if not tf_sess_config:
50 | # Session config based on:
51 | # https://software.intel.com/content/www/us/en/develop/articles/tips-to-improve-performance-for-popular-deep-learning-frameworks-on-multi-core-cpus.html
52 | # https://github.com/tensorflow/tensorflow/blob/v2.10.0/tensorflow/core/protobuf/config.proto#L452
53 | tf_sess_config = {
54 | "intra_op_parallelism_threads": 0,
55 | "inter_op_parallelism_threads": 0,
56 | "allow_soft_placement": True,
57 | "device_count": {"CPU": multiprocessing.cpu_count()},
58 | }
59 | # os.environ["OMP_NUM_THREADS"] = f"{self.cpu_num}"
60 |
61 | config = tf.ConfigProto(**tf_sess_config)
62 | return tf.Session(config=config)
63 |
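With `staircase=True`, the schedule built by `lr_decay_config` is `initial_lr * decay_rate ** floor(step / decay_steps)`. A pure-Python sketch with illustrative numbers:

def staircase_lr(initial_lr, step, decay_steps, decay_rate=0.96):
    # mirrors tf.train.exponential_decay(..., staircase=True)
    return initial_lr * decay_rate ** (step // decay_steps)

for step in (0, 999, 1000, 5000):
    print(step, round(staircase_lr(0.01, step, decay_steps=1000), 6))
# 0 -> 0.01, 999 -> 0.01, 1000 -> 0.0096, 5000 -> 0.008154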
--------------------------------------------------------------------------------
/libreco/tfops/version.py:
--------------------------------------------------------------------------------
1 | import tensorflow
2 |
3 | tf = tensorflow.compat.v1
4 | tf.disable_v2_behavior()
5 |
6 | TF_VERSION = tf.__version__
7 |
8 |
9 | def get_tf_version(version):
10 | if version is not None:
11 | assert isinstance(version, str)
12 | return version
13 | else:
14 | return TF_VERSION
15 |
--------------------------------------------------------------------------------
/libreco/torchops/__init__.py:
--------------------------------------------------------------------------------
1 | from .configs import device_config, hidden_units_config, set_torch_seed
2 | from .loss import (
3 | binary_cross_entropy_loss,
4 | bpr_loss,
5 | compute_pair_scores,
6 | focal_loss,
7 | max_margin_loss,
8 | pairwise_bce_loss,
9 | pairwise_focal_loss,
10 | )
11 | from .rebuild import rebuild_torch_model
12 |
13 | __all__ = [
14 | "binary_cross_entropy_loss",
15 | "bpr_loss",
16 | "compute_pair_scores",
17 | "device_config",
18 | "hidden_units_config",
19 | "focal_loss",
20 | "max_margin_loss",
21 | "pairwise_bce_loss",
22 | "pairwise_focal_loss",
23 | "rebuild_torch_model",
24 | "set_torch_seed",
25 | ]
26 |
--------------------------------------------------------------------------------
/libreco/torchops/configs.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def device_config(device):
5 | if device == "cuda" and torch.cuda.is_available():
6 | return torch.device("cuda")
7 | else:
8 | return torch.device("cpu")
9 |
10 |
11 | def hidden_units_config(hidden_units):
12 | if isinstance(hidden_units, int):
13 | return [hidden_units]
14 | elif not isinstance(hidden_units, (list, tuple)):
15 | raise ValueError(
16 | f"`hidden_units` must be one of (int, list of int, tuple of int), "
17 | f"got: {type(hidden_units)}, {hidden_units}"
18 | )
19 | for i in hidden_units:
20 | if not isinstance(i, int):
21 | raise ValueError(f"`hidden_units` contains a non-integer value: {hidden_units}")
22 | return list(hidden_units)
23 |
24 |
25 | def set_torch_seed(seed):
26 | torch.manual_seed(seed)
27 | if torch.cuda.is_available():
28 | torch.cuda.manual_seed(seed)
29 | torch.cuda.manual_seed_all(seed)
30 | # torch.backends.cudnn.deterministic = True
31 | # torch.backends.cudnn.benchmark = False
32 | # torch.use_deterministic_algorithms(True)
33 |
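A brief usage sketch of the helpers above (the import path follows the package layout in this repository):

from libreco.torchops import device_config, hidden_units_config

print(hidden_units_config(128))         # [128]
print(hidden_units_config((256, 128)))  # [256, 128]
print(device_config("cuda"))            # cuda if available, otherwise cpu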
--------------------------------------------------------------------------------
/libreco/training/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/massquantity/LibRecommender/9eadf8a3d6901a898630da113b92de48bd2897fb/libreco/training/__init__.py
--------------------------------------------------------------------------------
/libreco/training/dispatch.py:
--------------------------------------------------------------------------------
1 | from .tf_trainer import TensorFlowTrainer, WideDeepTrainer, YoutubeRetrievalTrainer
2 | from .torch_trainer import GraphTrainer, TorchTrainer
3 | from ..utils.constants import SageModels, TfTrainModels
4 |
5 |
6 | def get_trainer(model):
7 | train_params = {
8 | "model": model,
9 | "task": model.task,
10 | "loss_type": model.loss_type,
11 | "n_epochs": model.n_epochs,
12 | "lr": model.lr,
13 | "lr_decay": model.lr_decay,
14 | "epsilon": model.epsilon,
15 | "batch_size": model.batch_size,
16 | "sampler": model.sampler,
17 | "num_neg": model.__dict__.get("num_neg"),
18 | }
19 |
20 | if TfTrainModels.contains(model.model_name):
21 | if model.model_name == "YouTubeRetrieval":
22 | train_params["num_sampled_per_batch"] = model.num_sampled_per_batch
23 | tf_trainer_cls = YoutubeRetrievalTrainer
24 | elif model.model_name == "WideDeep":
25 | tf_trainer_cls = WideDeepTrainer
26 | else:
27 | tf_trainer_cls = TensorFlowTrainer
28 | return tf_trainer_cls(**train_params)
29 | else:
30 | train_params.update(
31 | {
32 | "amsgrad": model.amsgrad,
33 | "reg": model.reg,
34 | "margin": model.margin,
35 | "device": model.device,
36 | }
37 | )
38 | if SageModels.contains(model.model_name):
39 | torch_trainer_cls = GraphTrainer
40 | else:
41 | torch_trainer_cls = TorchTrainer
42 | return torch_trainer_cls(**train_params)
43 |
--------------------------------------------------------------------------------
/libreco/training/trainer.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 | from ..batch import adjust_batch_size
4 | from ..utils.validate import is_listwise_training
5 |
6 |
7 | class BaseTrainer(abc.ABC):
8 | def __init__(
9 | self,
10 | model,
11 | task,
12 | loss_type,
13 | n_epochs,
14 | lr,
15 | lr_decay,
16 | epsilon,
17 | batch_size,
18 | sampler,
19 | num_neg,
20 | ):
21 | self.model = model
22 | self.task = task
23 | self.loss_type = loss_type
24 | self.n_epochs = n_epochs
25 | self.lr = lr
26 | self.lr_decay = lr_decay
27 | self.epsilon = epsilon
28 | self.batch_size = adjust_batch_size(model, batch_size)
29 | self.sampler = sampler
30 | self.num_neg = num_neg
31 |
32 | def _check_params(self):
33 | if not is_listwise_training(self.model):
34 | n_items = self.model.data_info.n_items
35 | assert 0 < self.num_neg < n_items, (
36 | f"`num_neg` should be positive and smaller than total items, "
37 | f"got {self.num_neg}, {n_items}"
38 | )
39 | if self.sampler not in ("random", "unconsumed", "popular"):
40 | raise ValueError(
41 | f"`sampler` must be one of (`random`, `unconsumed`, `popular`), "
42 | f"got {self.sampler}"
43 | )
44 |
45 | @abc.abstractmethod
46 | def run(self, *args, **kwargs):
47 | raise NotImplementedError
48 |
--------------------------------------------------------------------------------
/libreco/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/massquantity/LibRecommender/9eadf8a3d6901a898630da113b92de48bd2897fb/libreco/utils/__init__.py
--------------------------------------------------------------------------------
/libreco/utils/constants.py:
--------------------------------------------------------------------------------
1 | from enum import Enum, unique
2 |
3 |
4 | class StrEnum(str, Enum):
5 | @classmethod
6 | def contains(cls, x):
7 | return x in cls.__members__.values() # cls._member_names_
8 |
9 |
10 | @unique
11 | class FeatModels(StrEnum):
12 | WIDEDEEP = "WideDeep"
13 | FM = "FM"
14 | DEEPFM = "DeepFM"
15 | YOUTUBERETRIEVAL = "YouTubeRetrieval"
16 | YOUTUBERANKING = "YouTubeRanking"
17 | AUTOINT = "AutoInt"
18 | DIN = "DIN"
19 | GRAPHSAGE = "GraphSage"
20 | GRAPHSAGEDGL = "GraphSageDGL"
21 | PINSAGE = "PinSage"
22 | PINSAGEDGL = "PinSageDGL"
23 | TWOTOWER = "TwoTower"
24 | TRANSFORMER = "Transformer"
25 | SIM = "SIM"
26 |
27 |
28 | @unique
29 | class SequenceModels(StrEnum):
30 | YOUTUBERETRIEVAL = "YouTubeRetrieval"
31 | YOUTUBERANKING = "YouTubeRanking"
32 | DIN = "DIN"
33 | RNN4REC = "RNN4Rec"
34 | CASER = "Caser"
35 | WAVENET = "WaveNet"
36 | TRANSFORMER = "Transformer"
37 | SIM = "SIM"
38 |
39 |
40 | @unique
41 | class TfTrainModels(StrEnum):
42 | SVD = "SVD"
43 | SVDPP = "SVDpp"
44 | NCF = "NCF"
45 | BPR = "BPR"
46 | WIDEDEEP = "WideDeep"
47 | FM = "FM"
48 | DEEPFM = "DeepFM"
49 | YOUTUBERETRIEVAL = "YouTubeRetrieval"
50 | YOUTUBERANKING = "YouTubeRanking"
51 | AUTOINT = "AutoInt"
52 | DIN = "DIN"
53 | RNN4REC = "RNN4Rec"
54 | CASER = "Caser"
55 | WAVENET = "WaveNet"
56 | TWOTOWER = "TwoTower"
57 | TRANSFORMER = "Transformer"
58 | SIM = "SIM"
59 |
60 |
61 | @unique
62 | class EmbeddingModels(StrEnum):
63 | SVD = "SVD"
64 | SVDPP = "SVDpp"
65 | ALS = "ALS"
66 | BPR = "BPR"
67 | YOUTUBERETRIEVAL = "YouTubeRetrieval"
68 | ITEM2VEC = "Item2Vec"
69 | RNN4REC = "RNN4Rec"
70 | CASER = "Caser"
71 | WAVENET = "WaveNet"
72 | DEEPWALK = "DeepWalk"
73 | NGCF = "NGCF"
74 | LIGHTGCN = "LightGCN"
75 | GRAPHSAGE = "GraphSage"
76 | GRAPHSAGEDGL = "GraphSageDGL"
77 | PINSAGE = "PinSage"
78 | PINSAGEDGL = "PinSageDGL"
79 | TWOTOWER = "TwoTower"
80 |
81 |
82 | @unique
83 | class SageModels(StrEnum):
84 | GRAPHSAGE = "GraphSage"
85 | GRAPHSAGEDGL = "GraphSageDGL"
86 | PINSAGE = "PinSage"
87 | PINSAGEDGL = "PinSageDGL"
88 |
89 |
90 | @unique
91 | class UserEmbedModels(StrEnum):
92 | """Models can only generate user embeddings dynamically."""
93 |
94 | YOUTUBERETRIEVAL = "YouTubeRetrieval"
95 | RNN4REC = "RNN4Rec"
96 | CASER = "Caser"
97 | WAVENET = "WaveNet"
98 |
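A short sketch of how these enums are queried: `contains` matches on the member *value* (the model name string), which is how `get_trainer` in dispatch.py chooses a trainer class.

from libreco.utils.constants import SageModels, TfTrainModels

print(TfTrainModels.contains("DeepFM"))    # True  -> TensorFlow trainer branch
print(SageModels.contains("PinSage"))      # True  -> GraphTrainer branch
print(TfTrainModels.contains("LightGCN"))  # False -> torch trainer branch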
--------------------------------------------------------------------------------
/libreco/utils/exception.py:
--------------------------------------------------------------------------------
1 | class NotSamplingError(Exception):
2 | """Exception related to sampling data
3 |
4 | If client wants to use batch_sampling and then evaluation on the dataset,
5 | but forgot to do whole data sampling beforehand, this exception will be
6 | raised. Because in this case, unsampled data can't be evaluated.
7 | """
8 |
9 | pass
10 |
--------------------------------------------------------------------------------
/libreco/utils/initializers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def truncated_normal(
5 | np_rng: np.random.Generator,
6 | shape,
7 | mean=0.0,
8 | scale=0.05,
9 | tolerance=5,
10 | ):
11 | # total_num = np.multiply(*shape)
12 | total_num = shape if len(shape) == 1 else np.multiply(*shape)
13 | array = np_rng.normal(mean, scale, total_num).astype(np.float32)
14 | upper_limit, lower_limit = mean + 2 * scale, mean - 2 * scale
15 | for _ in range(tolerance):
16 | index = np.logical_or((array > upper_limit), (array < lower_limit))
17 | num = len(np.where(index)[0])
18 | if num == 0:
19 | break
20 | array[index] = np_rng.normal(mean, scale, num)
21 | return array.reshape(*shape)
22 |
23 |
24 | def xavier_init(np_rng, fan_in, fan_out):
25 | std = np.sqrt(2.0 / (fan_in + fan_out))
26 | return truncated_normal(np_rng, mean=0.0, scale=std, shape=[fan_in, fan_out])
27 |
28 |
29 | def he_init(np_rng, fan_in, fan_out):
30 | std = 2.0 / np.sqrt(fan_in + fan_out)
31 | # std = np.sqrt(2.0 / fan_in)
32 | return truncated_normal(np_rng, mean=0.0, scale=std, shape=[fan_in, fan_out])
33 |
34 |
35 | def variance_scaling(np_rng, scale, fan_in=None, fan_out=None, mode="fan_in"):
36 | """
37 | xavier: mode = "fan_average", scale = 1.0
38 | he: mode = "fan_in", scale = 2.0
39 | he2: mode = "fan_average", scale = 2.0
40 | """
41 | if mode == "fan_in":
42 | std = np.sqrt(scale / fan_in)
43 | elif mode == "fan_out":
44 | std = np.sqrt(scale / fan_out)
45 | elif mode == "fan_average":
46 | std = np.sqrt(2.0 * scale / (fan_in + fan_out))
47 | else:
48 | raise ValueError("mode must be one of these: fan_in, fan_out, fan_average")
49 | return truncated_normal(np_rng, mean=0.0, scale=std, shape=[fan_in, fan_out])
50 |
--------------------------------------------------------------------------------
/libreco/utils/sparse.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import List
3 |
4 | from scipy.sparse import csr_matrix
5 |
6 |
7 | @dataclass
8 | class SparseMatrix:
9 | sparse_indices: List[int]
10 | sparse_indptr: List[int]
11 | sparse_data: List[float]
12 |
13 |
14 | def build_sparse(matrix: csr_matrix, transpose: bool = False):
15 | m = matrix.T.tocsr() if transpose else matrix
16 | return SparseMatrix(
17 | m.indices.tolist(),
18 | m.indptr.tolist(),
19 | m.data.tolist(),
20 | )
21 |
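A usage sketch of `build_sparse` on a toy user-item matrix, showing how `transpose=True` switches the CSR layout from rows-per-user to rows-per-item:

import numpy as np
from scipy.sparse import csr_matrix

from libreco.utils.sparse import build_sparse

m = csr_matrix(np.array([[1, 0, 0, 2],
                         [0, 0, 3, 0],
                         [4, 0, 0, 0]], dtype=np.float64))

s = build_sparse(m)                   # CSR over users
st = build_sparse(m, transpose=True)  # CSR over items
print(s.sparse_indptr)   # [0, 2, 3, 4]
print(st.sparse_indptr)  # [0, 2, 2, 3, 4]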
--------------------------------------------------------------------------------
/libserving/.dockerignore:
--------------------------------------------------------------------------------
1 | */*/bin
2 | */target
3 | */request*
4 | */*/utils/*22*
--------------------------------------------------------------------------------
/libserving/Dockerfile-py:
--------------------------------------------------------------------------------
1 | FROM python:3.7-slim
2 |
3 | WORKDIR /app
4 |
5 | RUN pip install --no-cache-dir numpy==1.19.5 -i https://pypi.tuna.tsinghua.edu.cn/simple
6 | RUN pip install --no-cache-dir sanic==22.6.2 -i https://pypi.tuna.tsinghua.edu.cn/simple
7 | RUN pip install --no-cache-dir aiohttp==3.8.1 -i https://pypi.tuna.tsinghua.edu.cn/simple
8 | RUN pip install --no-cache-dir pydantic==1.9.1 -i https://pypi.tuna.tsinghua.edu.cn/simple
9 | RUN pip install --no-cache-dir ujson==5.4.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
10 | RUN pip install --no-cache-dir redis==4.3.4 -i https://pypi.tuna.tsinghua.edu.cn/simple
11 | RUN pip install --no-cache-dir faiss-cpu==1.7.2 -i https://pypi.tuna.tsinghua.edu.cn/simple
12 |
13 | COPY sanic_serving /app/sanic_serving
14 |
15 | ENV PYTHONPATH=/app
16 |
17 | EXPOSE 8000
18 |
--------------------------------------------------------------------------------
/libserving/Dockerfile-rs:
--------------------------------------------------------------------------------
1 | FROM debian:bullseye-slim AS faiss-builder
2 |
3 | WORKDIR /cmake
4 | # install blas used in faiss
5 | RUN echo "deb http://mirrors.tuna.tsinghua.edu.cn/debian/ bullseye main contrib non-free" >/etc/apt/sources.list && \
6 | apt-get update && \
7 | apt-get install -y gcc g++ make wget git libblas-dev liblapack-dev && \
8 | apt-get clean && \
9 | rm -rf /var/lib/apt/lists/*
10 |
11 | # install cmake to build faiss
12 | RUN wget https://cmake.org/files/LatestRelease/cmake-3.25.0-linux-x86_64.tar.gz
13 | RUN tar -zxf cmake-3.25.0-linux-x86_64.tar.gz -C /cmake --strip-components 1
14 | RUN ln -s /cmake/bin/cmake /usr/bin/cmake
15 | RUN cmake --version
16 |
17 | WORKDIR /faiss
18 | # clone branch `c_api_head` in faiss repository
19 | RUN git clone -b c_api_head https://github.com/Enet4/faiss.git .
20 | # COPY ./faiss /faiss
21 | RUN cmake -B build . \
22 | -DFAISS_ENABLE_C_API=ON \
23 | -DBUILD_SHARED_LIBS=ON \
24 | -DCMAKE_BUILD_TYPE=Release \
25 | -DFAISS_ENABLE_GPU=OFF \
26 | -DFAISS_ENABLE_PYTHON=OFF \
27 | -DBUILD_TESTING=OFF
28 | RUN make -C build/c_api
29 |
30 | FROM rust:1.64-slim-bullseye AS rust-builder
31 |
32 | WORKDIR /serving_build
33 |
34 | RUN echo "deb http://mirrors.tuna.tsinghua.edu.cn/debian/ bullseye main contrib non-free" >/etc/apt/sources.list && \
35 | apt-get update && \
36 | apt-get install -y libblas-dev liblapack-dev && \
37 | apt-get clean && \
38 | rm -rf /var/lib/apt/lists/*
39 |
40 | # cache crate index
41 | COPY crate-index-config /usr/local/cargo/config
42 | RUN cargo init
43 | COPY actix_serving/Cargo.toml actix_serving/Cargo.lock /serving_build/
44 | RUN cargo fetch
45 |
46 | COPY actix_serving/src /serving_build/src
47 | COPY --from=faiss-builder /faiss/build/c_api/libfaiss_c.so /usr/lib
48 | COPY --from=faiss-builder /faiss/build/faiss/libfaiss.so /usr/lib
49 | ENV LD_LIBRARY_PATH=/usr/lib
50 | RUN cargo build --release
51 |
52 | FROM debian:bullseye-slim
53 |
54 | WORKDIR /app
55 |
56 | # need gcc & blas for faiss
57 | RUN echo "deb http://mirrors.tuna.tsinghua.edu.cn/debian/ bullseye main contrib non-free" >/etc/apt/sources.list && \
58 | apt-get update && \
59 | apt-get install -y gcc libblas-dev liblapack-dev && \
60 | apt-get clean && \
61 | rm -rf /var/lib/apt/lists/*
62 |
63 | COPY --from=faiss-builder /faiss/build/c_api/libfaiss_c.so /usr/lib
64 | COPY --from=faiss-builder /faiss/build/faiss/libfaiss.so /usr/lib
65 | ENV LD_LIBRARY_PATH=/usr/lib
66 |
67 | COPY --from=rust-builder /serving_build/target/release/actix_serving /app
68 |
69 | USER 1001
70 |
71 | EXPOSE 8080
72 |
73 | CMD ["/app/actix_serving"]
74 |
--------------------------------------------------------------------------------
/libserving/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "1.5.1"
2 |
--------------------------------------------------------------------------------
/libserving/actix_serving/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "actix_serving"
3 | version = "1.2.0"
4 | edition = "2021"
5 | description = "Online model serving for LibRecommender"
6 | license = "MIT"
7 |
8 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
9 |
10 | [dependencies]
11 | actix-web = "4.3"
12 | clap = { version = "4.1.8", features = ["derive"] }
13 | deadpool-redis = "0.11.1"
14 | env_logger = "0.9.2"
15 | faiss = "0.11.0"
16 | fnv = "1.0.7"
17 | futures = "0.3.25"
18 | log = "0.4.15"
19 | num_cpus = "1.0"
20 | once_cell = "1.18"
21 | prost = "0.11"
22 | # openssl = { version = "0.10", features = ["vendored"] }
23 | redis = { version = "0.22.1", features = ["connection-manager", "tokio-comp"] }
24 | reqwest = { version = "0.11", default-features = false, features = ["blocking", "json", "rustls-tls"] }
25 | serde = { version = "1.0", features = ["derive"] }
26 | serde_json = "1.0"
27 | serde_with = "3.0"
28 | thiserror = "1.0"
29 | tokio = { version = "1.0", features = ["macros", "rt-multi-thread"] }
30 | tonic = "0.9"
31 | walkdir = "2.3.2"
32 |
33 | [build-dependencies]
34 | tonic-build = "0.9"
35 |
36 | [dev-dependencies]
37 | assert_cmd = "2.0.8"
38 | pretty_assertions = "1.0"
39 |
40 | [profile.release]
41 | strip = true
42 | opt-level = 3
43 | lto = true
44 | codegen-units = 1
45 |
--------------------------------------------------------------------------------
/libserving/actix_serving/build.rs:
--------------------------------------------------------------------------------
1 | fn main() -> Result<(), Box<dyn std::error::Error>> {
2 | tonic_build::compile_protos("proto/recommend.proto")?;
3 |
4 | tonic_build::configure()
5 | .build_server(false)
6 | .compile(
7 | &["proto/tensorflow_serving/apis/prediction_service.proto"],
8 | &["proto"],
9 | )?;
10 |
11 | Ok(())
12 | }
13 |
--------------------------------------------------------------------------------
/libserving/actix_serving/proto/recommend.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package recommend;
4 |
5 | service Recommend {
6 | rpc GetRecommendation(RecRequest) returns (RecResponse);
7 | }
8 |
9 | message RecRequest {
10 | string user = 1;
11 | int32 n_rec = 2;
12 | map<string, Feature> user_feats = 3;
13 | repeated int32 seq = 4;
14 | }
15 |
16 | message RecResponse {
17 | repeated string items = 1;
18 | }
19 |
20 | message Feature {
21 | oneof value {
22 | string string_val = 1;
23 | int32 int_val = 2;
24 | float float_val = 3;
25 | }
26 | }
27 |
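A hedged Python client sketch for this service, mirroring the Rust client further down. It assumes the proto has been compiled with grpcio-tools into the default module names `recommend_pb2` / `recommend_pb2_grpc`, which are not part of this repository.

import grpc

import recommend_pb2
import recommend_pb2_grpc

channel = grpc.insecure_channel("localhost:50051")
stub = recommend_pb2_grpc.RecommendStub(channel)

request = recommend_pb2.RecRequest(
    user="1",
    n_rec=11,
    user_feats={
        "sex": recommend_pb2.Feature(string_val="F"),
        "age": recommend_pb2.Feature(int_val=33),
    },
    seq=[1, 2, 3],
)
response = stub.GetRecommendation(request)
print(response.items)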
--------------------------------------------------------------------------------
/libserving/actix_serving/proto/tensorflow/core/framework/resource_handle.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package tensorflow;
4 |
5 | import "tensorflow/core/framework/tensor_shape.proto";
6 | import "tensorflow/core/framework/types.proto";
7 |
8 | option cc_enable_arenas = true;
9 | option java_outer_classname = "ResourceHandle";
10 | option java_multiple_files = true;
11 | option java_package = "org.tensorflow.framework";
12 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/resource_handle_go_proto";
13 |
14 | // Protocol buffer representing a handle to a tensorflow resource. Handles are
15 | // not valid across executions, but can be serialized back and forth from within
16 | // a single run.
17 | message ResourceHandleProto {
18 | // Unique name for the device containing the resource.
19 | string device = 1;
20 |
21 | // Container in which this resource is placed.
22 | string container = 2;
23 |
24 | // Unique name of this resource.
25 | string name = 3;
26 |
27 | // Hash code for the type of the resource. Is only valid in the same device
28 | // and in the same execution.
29 | uint64 hash_code = 4;
30 |
31 | // For debug-only, the name of the type pointed to by this handle, if
32 | // available.
33 | string maybe_type_name = 5;
34 |
35 | // Protocol buffer representing a pair of (data type, tensor shape).
36 | message DtypeAndShape {
37 | DataType dtype = 1;
38 | TensorShapeProto shape = 2;
39 | }
40 |
41 | // Data types and shapes for the underlying resource.
42 | repeated DtypeAndShape dtypes_and_shapes = 6;
43 |
44 | reserved 7;
45 | }
--------------------------------------------------------------------------------
/libserving/actix_serving/proto/tensorflow/core/framework/tensor_shape.proto:
--------------------------------------------------------------------------------
1 | // Protocol buffer representing the shape of tensors.
2 |
3 | syntax = "proto3";
4 |
5 | option cc_enable_arenas = true;
6 | option java_outer_classname = "TensorShapeProtos";
7 | option java_multiple_files = true;
8 | option java_package = "org.tensorflow.framework";
9 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/tensor_shape_go_proto";
10 |
11 | package tensorflow;
12 |
13 | // Dimensions of a tensor.
14 | message TensorShapeProto {
15 | // One dimension of the tensor.
16 | message Dim {
17 | // Size of the tensor in that dimension.
18 | // This value must be >= -1, but values of -1 are reserved for "unknown"
19 | // shapes (values of -1 mean "unknown" dimension). Certain wrappers
20 | // that work with TensorShapeProto may fail at runtime when deserializing
21 | // a TensorShapeProto containing a dim value of -1.
22 | int64 size = 1;
23 |
24 | // Optional name of the tensor dimension.
25 | string name = 2;
26 | };
27 |
28 | // Dimensions of the tensor, such as {"input", 30}, {"output", 40}
29 | // for a 30 x 40 2D tensor. If an entry has size -1, this
30 | // corresponds to a dimension of unknown size. The names are
31 | // optional.
32 | //
33 | // The order of entries in "dim" matters: It indicates the layout of the
34 | // values in the tensor in-memory representation.
35 | //
36 | // The first entry in "dim" is the outermost dimension used to layout the
37 | // values, the last entry is the innermost dimension. This matches the
38 | // in-memory layout of RowMajor Eigen tensors.
39 | //
40 | // If "dim.size()" > 0, "unknown_rank" must be false.
41 | repeated Dim dim = 2;
42 |
43 | // If true, the number of dimensions in the shape is unknown.
44 | //
45 | // If true, "dim.size()" must be 0.
46 | bool unknown_rank = 3;
47 | };
--------------------------------------------------------------------------------
/libserving/actix_serving/proto/tensorflow_serving/apis/model.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package tensorflow.serving;
4 | option cc_enable_arenas = true;
5 |
6 | import "google/protobuf/wrappers.proto";
7 |
8 | // Metadata for an inference request such as the model name and version.
9 | message ModelSpec {
10 | // Required servable name.
11 | string name = 1;
12 |
13 | // Optional choice of which version of the model to use.
14 | //
15 | // Recommended to be left unset in the common case. Should be specified only
16 | // when there is a strong version consistency requirement.
17 | //
18 | // When left unspecified, the system will serve the best available version.
19 | // This is typically the latest version, though during version transitions,
20 | // notably when serving on a fleet of instances, may be either the previous or
21 | // new version.
22 | oneof version_choice {
23 | // Use this specific version number.
24 | google.protobuf.Int64Value version = 2;
25 |
26 | // Use the version associated with the given label.
27 | string version_label = 4;
28 | }
29 |
30 | // A named signature to evaluate. If unspecified, the default signature will
31 | // be used.
32 | string signature_name = 3;
33 | }
--------------------------------------------------------------------------------
/libserving/actix_serving/proto/tensorflow_serving/apis/predict.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package tensorflow.serving;
4 |
5 | import "tensorflow/core/framework/tensor.proto";
6 | import "tensorflow_serving/apis/model.proto";
7 |
8 | option cc_enable_arenas = true;
9 |
10 | // PredictRequest specifies which TensorFlow model to run, as well as
11 | // how inputs are mapped to tensors and how outputs are filtered before
12 | // returning to user.
13 | message PredictRequest {
14 | // Model Specification. If version is not specified, will use the latest
15 | // (numerical) version.
16 | ModelSpec model_spec = 1;
17 |
18 | // Input tensors.
19 | // Names of input tensor are alias names. The mapping from aliases to real
20 | // input tensor names is stored in the SavedModel export as a prediction
21 | // SignatureDef under the 'inputs' field.
22 | map<string, TensorProto> inputs = 2;
23 |
24 | // Output filter.
25 | // Names specified are alias names. The mapping from aliases to real output
26 | // tensor names is stored in the SavedModel export as a prediction
27 | // SignatureDef under the 'outputs' field.
28 | // Only tensors specified here will be run/fetched and returned, with the
29 | // exception that when none is specified, all tensors specified in the
30 | // named signature will be run/fetched and returned.
31 | repeated string output_filter = 3;
32 |
33 | // Reserved field 4.
34 | reserved 4;
35 | }
36 |
37 | // Response for PredictRequest on successful run.
38 | message PredictResponse {
39 | // Effective Model Specification used to process PredictRequest.
40 | ModelSpec model_spec = 2;
41 |
42 | // Output tensors.
43 | map<string, TensorProto> outputs = 1;
44 | }
--------------------------------------------------------------------------------
/libserving/actix_serving/proto/tensorflow_serving/apis/prediction_service.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package tensorflow.serving;
4 |
5 | option cc_enable_arenas = true;
6 |
7 | import "tensorflow_serving/apis/predict.proto";
8 |
9 | // open source marker; do not remove
10 | // PredictionService provides access to machine-learned models loaded by
11 | // model_servers.
12 | service PredictionService {
13 | // Predict -- provides access to loaded TensorFlow model.
14 | rpc Predict(PredictRequest) returns (PredictResponse);
15 | }
16 |
--------------------------------------------------------------------------------
/libserving/actix_serving/rustfmt.toml:
--------------------------------------------------------------------------------
1 | array_width = 50
2 | chain_width = 50
3 | edition = "2021"
4 | group_imports = "StdExternalCrate"
5 | imports_granularity = "Module"
6 | max_width = 100
7 | reorder_imports = true
8 | reorder_modules = true
9 | use_field_init_shorthand = true
10 |
--------------------------------------------------------------------------------
/libserving/actix_serving/src/bin/realtime.rs:
--------------------------------------------------------------------------------
1 | use actix_web::{middleware::Logger, web, App, HttpServer};
2 | use once_cell::sync::Lazy;
3 |
4 | use actix_serving::common::get_env;
5 | use actix_serving::online_serving;
6 | use actix_serving::redis_ops::create_redis_pool;
7 | use actix_serving::tf_deploy::{init_tf_state, TfAppState};
8 |
9 | static TF_STATE: Lazy<web::Data<TfAppState>> = Lazy::new(|| web::Data::new(init_tf_state()));
10 |
11 | #[actix_web::main]
12 | async fn main() -> std::io::Result<()> {
13 | let (redis_host, port, workers, log_level) = get_env()?;
14 | std::env::set_var("RUST_LOG", log_level);
15 | env_logger::init();
16 |
17 | let redis_pool = web::Data::new(create_redis_pool(redis_host)?);
18 | HttpServer::new(move || {
19 | App::new()
20 | .wrap(Logger::default())
21 | .app_data(redis_pool.clone())
22 | .app_data(web::Data::clone(&TF_STATE))
23 | .service(online_serving)
24 | })
25 | .workers(workers)
26 | .bind(("0.0.0.0", port))?
27 | .run()
28 | .await
29 | }
30 |
--------------------------------------------------------------------------------
/libserving/actix_serving/src/bin/realtime_grpc_client.rs:
--------------------------------------------------------------------------------
1 | use std::collections::HashMap;
2 |
3 | use actix_serving::online_deploy_grpc::recommend_proto::recommend_client::RecommendClient;
4 | use actix_serving::online_deploy_grpc::recommend_proto::{
5 | feature::Value as FeatValue, Feature, RecRequest,
6 | };
7 |
8 | #[tokio::main]
9 | async fn main() -> Result<(), Box<dyn std::error::Error>> {
10 | let mut client = RecommendClient::connect("http://[::1]:50051").await?;
11 |
12 | let user = String::from("1");
13 | let n_rec = 11;
14 | let feature_sparse = (
15 | String::from("sex"),
16 | Feature {
17 | value: Some(FeatValue::StringVal(String::from("F"))),
18 | },
19 | );
20 | let feature_dense = (
21 | String::from("age"),
22 | Feature {
23 | value: Some(FeatValue::IntVal(33)),
24 | },
25 | );
26 | let request = RecRequest {
27 | user,
28 | n_rec,
29 | user_feats: HashMap::from([feature_sparse, feature_dense]),
30 | seq: vec![1, 2, 3],
31 | };
32 |
33 | let response = client
34 | .get_recommendation(tonic::Request::new(request))
35 | .await?;
36 |
37 | println!("rec for user: {:?}", response.into_inner().items);
38 | Ok(())
39 | }
40 |
--------------------------------------------------------------------------------
/libserving/actix_serving/src/bin/realtime_grpc_server.rs:
--------------------------------------------------------------------------------
1 | use tonic::transport::Server;
2 |
3 | use actix_serving::online_deploy_grpc::recommend_proto::recommend_server::RecommendServer;
4 | use actix_serving::online_deploy_grpc::RecommendService;
5 | use actix_serving::redis_ops;
6 |
7 | #[tokio::main(worker_threads = 4)]
8 | async fn main() -> Result<(), Box<dyn std::error::Error>> {
9 | std::env::set_var("RUST_LOG", "info");
10 | env_logger::init();
11 |
12 | let addr = "[::1]:50051".parse()?;
13 | let redis_pool = redis_ops::create_redis_pool(String::from("127.0.0.1"))
14 | .expect("Failed to connect to redis pool");
15 | let service = RecommendService { redis_pool };
16 |
17 | Server::builder()
18 | .add_service(RecommendServer::new(service))
19 | .serve(addr)
20 | .await?;
21 |
22 | Ok(())
23 | }
24 |
--------------------------------------------------------------------------------
/libserving/actix_serving/src/lib.rs:
--------------------------------------------------------------------------------
1 | pub mod embed_deploy;
2 | pub mod knn_deploy;
3 | pub mod online_deploy;
4 | #[allow(clippy::too_many_arguments)]
5 | pub mod online_deploy_grpc;
6 | pub mod tf_deploy;
7 | pub mod utils;
8 |
9 | pub use embed_deploy::embed_serving;
10 | pub use knn_deploy::knn_serving;
11 | pub use online_deploy::online_serving;
12 | pub use tf_deploy::tf_serving;
13 | pub use utils::common;
14 | pub use utils::constants;
15 | pub use utils::errors;
16 | pub use utils::faiss;
17 | pub use utils::features;
18 | pub use utils::redis_ops;
19 |
--------------------------------------------------------------------------------
/libserving/actix_serving/src/utils/common.rs:
--------------------------------------------------------------------------------
1 | use std::collections::HashMap;
2 |
3 | use serde::{Deserialize, Serialize};
4 | use serde_json::Value;
5 |
6 | use crate::errors::ServingError;
7 |
8 | #[derive(Serialize, Deserialize)]
9 | pub struct Payload {
10 | pub user: String,
11 | pub n_rec: usize,
12 | }
13 |
14 | #[derive(Serialize, Deserialize)]
15 | pub struct RealtimePayload {
16 | pub user: String,
17 | pub n_rec: usize,
18 | pub user_feats: Option<HashMap<String, Value>>,
19 | pub seq: Option<Vec<Value>>,
20 | }
21 |
22 | #[derive(Debug, Serialize, Deserialize)]
23 | pub struct Recommendation {
24 | pub rec_list: Vec<String>,
25 | }
26 |
27 | #[derive(Deserialize)]
28 | pub struct Prediction {
29 | pub outputs: Vec<f32>,
30 | }
31 |
32 | #[derive(Debug, Deserialize)]
33 | pub struct RankedItems {
34 | pub outputs: Vec<usize>,
35 | }
36 |
37 | pub fn get_env() -> Result<(String, u16, usize, String), ServingError> {
38 | let host = std::env::var("REDIS_HOST").unwrap_or_else(|_| String::from("127.0.0.1"));
39 | let port = std::env::var("PORT")
40 | .map_err(|e| ServingError::EnvError(e, "PORT"))?
41 | .parse::<u16>()?;
42 | let workers = std::env::var("WORKERS").map_or(Ok(4), |w| w.parse::<usize>())?;
43 | let log_level = std::env::var("RUST_LOG").unwrap_or_else(|_| String::from("info"));
44 | Ok((host, port, workers, log_level))
45 | }
46 |
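A hedged Python sketch of the JSON bodies these payload structs deserialize, sent with the `requests` library. The `/knn/recommend` route is taken from the tests further down; the realtime route path below is only a guess, since the actual route is defined in online_deploy.rs, which is not shown in this dump.

import requests

# Payload: plain recommendation request
resp = requests.post(
    "http://localhost:8080/knn/recommend",
    json={"user": "10", "n_rec": 3},
)
print(resp.json())  # {"rec_list": [...]}

# RealtimePayload: optional features and behavior sequence
resp = requests.post(
    "http://localhost:8080/online/recommend",  # hypothetical path, see online_deploy.rs
    json={
        "user": "10",
        "n_rec": 3,
        "user_feats": {"sex": "F", "age": 33},
        "seq": [1, 2, 3],
    },
)
print(resp.status_code)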
--------------------------------------------------------------------------------
/libserving/actix_serving/src/utils/constants.rs:
--------------------------------------------------------------------------------
1 | use crate::redis_ops::RedisFeatKeys;
2 |
3 | pub const KNN_MODELS: [&'static str; 2] = ["UserCF", "ItemCF"];
4 |
5 | pub const EMBED_MODELS: [&'static str; 14] = [
6 | "SVD",
7 | "SVDpp",
8 | "ALS",
9 | "BPR",
10 | "YouTubeRetrieval",
11 | "Item2Vec",
12 | "RNN4Rec",
13 | "Caser",
14 | "WaveNet",
15 | "DeepWalk",
16 | "NGCF",
17 | "LightGCN",
18 | "PinSage",
19 | "PinSageDGL",
20 | ];
21 |
22 | pub const CROSS_FEAT_MODELS: [&'static str; 6] = [
23 | "WideDeep",
24 | "FM",
25 | "DeepFM",
26 | "YouTubeRanking",
27 | "AutoInt",
28 | "DIN",
29 | ];
30 |
31 | pub const SEQ_EMBED_MODELS: [&'static str; 3] = ["RNN4Rec", "Caser", "WaveNet"];
32 |
33 | pub const USER_ID_EMBED_MODELS: [&'static str; 2] = ["Caser", "WaveNet"];
34 |
35 | pub const SEPARATE_FEAT_MODELS: [&'static str; 1] = ["TwoTower"];
36 |
37 | pub const SPARSE_SEQ_MODELS: [&'static str; 1] = ["YouTubeRetrieval"];
38 |
39 | pub const CROSS_SEQ_MODELS: [&'static str; 2] = ["YouTubeRanking", "DIN"];
40 |
41 | pub const SPARSE_REDIS_KEYS: RedisFeatKeys = RedisFeatKeys {
42 | user_index: "user_sparse_col_index",
43 | item_index: "item_sparse_col_index",
44 | user_value: "user_sparse_values",
45 | item_value: "item_sparse_values",
46 | };
47 |
48 | pub const DENSE_REDIS_KEYS: RedisFeatKeys = RedisFeatKeys {
49 | user_index: "user_dense_col_index",
50 | item_index: "item_dense_col_index",
51 | user_value: "user_dense_values",
52 | item_value: "item_dense_values",
53 | };
54 |
--------------------------------------------------------------------------------
/libserving/actix_serving/src/utils/errors.rs:
--------------------------------------------------------------------------------
1 | use actix_web::http::StatusCode;
2 | use actix_web::HttpResponse;
3 |
4 | pub type ServingResult<T> = std::result::Result<T, ServingError>;
5 |
6 | #[derive(thiserror::Error, Debug)]
7 | pub enum ServingError {
8 | #[error("error: failed to get environment variable `{1}`")]
9 | EnvError(#[source] std::env::VarError, &'static str),
10 | #[error("faiss error: {0}")]
11 | FaissError(#[source] faiss::error::Error),
12 | #[error(transparent)]
13 | IoError(#[from] std::io::Error),
14 | #[error(transparent)]
15 | JsonParseError(#[from] serde_json::Error),
16 | #[error("error: `{0}` doesn't exist in redis")]
17 | NotExist(&'static str),
18 | #[error("error: `{0}` not found")]
19 | NotFound(&'static str),
20 | #[error("error: {0}")]
21 | Other(&'static str),
22 | #[error(transparent)]
23 | ParseError(#[from] std::num::ParseIntError),
24 | #[error("error: redis error, {0}")]
25 | RedisError(#[from] redis::RedisError),
26 | #[error("error: failed to create redis pool, {0}")]
27 | RedisCreatePoolError(#[from] deadpool_redis::CreatePoolError),
28 | #[error("error: failed to get redis pool, {0}")]
29 | RedisGetPoolError(#[from] deadpool_redis::PoolError),
30 | #[error("error: failed to execute tokio blocking task, {0}")]
31 | TaskError(#[from] tokio::task::JoinError),
32 | #[error("error: failed to get prediction from tf serving, {0}")]
33 | TfServingError(#[from] reqwest::Error),
34 | #[error("error: request timeout")]
35 | Timeout,
36 | #[error("error: unknown model `{0}`")]
37 | UnknownModel(String),
38 | }
39 |
40 | impl actix_web::error::ResponseError for ServingError {
41 | fn status_code(&self) -> StatusCode {
42 | match *self {
43 | ServingError::NotExist(_) => StatusCode::BAD_REQUEST,
44 | ServingError::Timeout => StatusCode::REQUEST_TIMEOUT,
45 | ServingError::TfServingError(_) => StatusCode::GATEWAY_TIMEOUT,
46 | _ => StatusCode::INTERNAL_SERVER_ERROR,
47 | }
48 | }
49 |
50 | fn error_response(&self) -> HttpResponse {
51 | HttpResponse::build(self.status_code()).body(self.to_string())
52 | }
53 | }
54 |
55 | impl From for std::io::Error {
56 | fn from(e: ServingError) -> Self {
57 | std::io::Error::new(std::io::ErrorKind::Other, e.to_string())
58 | }
59 | }
60 |
--------------------------------------------------------------------------------
/libserving/actix_serving/src/utils/faiss.rs:
--------------------------------------------------------------------------------
1 | use std::path::{Component::RootDir, Path, PathBuf};
2 |
3 | use walkdir::WalkDir;
4 |
5 | use crate::errors::{ServingError, ServingResult};
6 |
7 | pub(crate) fn find_index_path(path: Option<String>) -> ServingResult<String> {
8 | let cur_dir = path.map_or(std::env::current_dir()?, PathBuf::from);
9 | // search in two level parent directory
10 | let dual_parent = cur_dir
11 | .parent()
12 | .and_then(|p| p.parent())
13 | .unwrap_or_else(|| Path::new(RootDir.as_os_str()));
14 | let walk_dirs = WalkDir::new(dual_parent)
15 | .into_iter()
16 | .filter_map(|d| d.ok());
17 | for entry in walk_dirs {
18 | let file_name = entry.file_name().to_string_lossy();
19 | if file_name.starts_with("faiss_index") && !entry.path().is_dir() {
20 | log::info!("Found faiss index in {}", entry.path().display());
21 | return Ok(entry.path().to_string_lossy().into_owned());
22 | }
23 | }
24 | Err(ServingError::NotFound("faiss index"))
25 | }
26 |
--------------------------------------------------------------------------------
/libserving/actix_serving/src/utils/mod.rs:
--------------------------------------------------------------------------------
1 | pub mod common;
2 | #[allow(clippy::redundant_static_lifetimes)]
3 | pub mod constants;
4 | pub mod errors;
5 | pub mod faiss;
6 | pub mod features;
7 | pub mod redis_ops;
8 |
--------------------------------------------------------------------------------
/libserving/actix_serving/tests/common/mod.rs:
--------------------------------------------------------------------------------
1 | use std::process::Command;
2 | use std::time;
3 |
4 | use assert_cmd::prelude::*;
5 | use serde::Serialize;
6 |
7 | #[derive(Serialize)]
8 | pub struct InvalidParam {
9 | pub user: i32,
10 | pub n_rec: usize,
11 | }
12 |
13 | pub fn start_server(model_type: &str) {
14 | Command::cargo_bin("actix_serving")
15 | .unwrap()
16 | .env("REDIS_HOST", "localhost")
17 | .env("PORT", "8080")
18 | .env("MODEL_TYPE", model_type)
19 | .env("RUST_LOG", "debug")
20 | .env("WORKERS", "2")
21 | .spawn()
22 | .expect("Failed to start actix server");
23 | // std::env::set_var("RUST_LOG", "debug");
24 | // let cmd = Command::new("./target/debug/actix_serving")
25 | // .env_clear()
26 | // .env("WORKERS", "2")
27 | // .spawn()
28 | // .unwrap();
29 | std::thread::sleep(time::Duration::from_secs(1));
30 | }
31 |
32 | pub fn stop_server() {
33 | Command::new("pkill")
34 | .arg("actix_serving")
35 | .output()
36 | .expect("Failed to stop actix server");
37 | std::thread::sleep(time::Duration::from_millis(200));
38 | }
39 |
--------------------------------------------------------------------------------
/libserving/actix_serving/tests/embed.rs:
--------------------------------------------------------------------------------
1 | use actix_serving::common::{Payload, Recommendation};
2 | use pretty_assertions::assert_eq;
3 |
4 | mod common;
5 | use common::{start_server, stop_server, InvalidParam};
6 |
7 | // cargo test --package actix_serving --test embed -- --test-threads=1
8 | #[test]
9 | fn test_main_embed_serving() {
10 | start_server("embed");
11 | let req = Payload {
12 | user: String::from("10"),
13 | n_rec: 3,
14 | };
15 | let resp: Recommendation = reqwest::blocking::Client::new()
16 | .post("http://localhost:8080/embed/recommend")
17 | .json(&req)
18 | .send()
19 | .unwrap()
20 | .json()
21 | .unwrap();
22 | assert_eq!(resp.rec_list.len(), 3);
23 | stop_server();
24 | }
25 |
26 | #[test]
27 | fn test_bad_request() {
28 | start_server("embed");
29 | let invalid_req = InvalidParam { user: 10, n_rec: 3 };
30 | let resp = reqwest::blocking::Client::new()
31 | .post("http://localhost:8080/embed/recommend")
32 | .json(&invalid_req)
33 | .send()
34 | .unwrap();
35 | assert_eq!(resp.status(), reqwest::StatusCode::BAD_REQUEST);
36 | stop_server();
37 | }
38 |
39 | #[test]
40 | fn test_not_found() {
41 | start_server("embed");
42 | let req = Payload {
43 | user: String::from("10"),
44 | n_rec: 3,
45 | };
46 | let resp = reqwest::blocking::Client::new()
47 | .post("http://localhost:8080/nooo_embed/recommend")
48 | .json(&req)
49 | .send()
50 | .unwrap();
51 | assert_eq!(resp.status(), reqwest::StatusCode::NOT_FOUND);
52 | assert_eq!(
53 | resp.text().unwrap(),
54 | "`nooo_embed/recommend` is not available, make sure you've started the right service."
55 | );
56 | stop_server();
57 | }
58 |
59 | #[test]
60 | fn test_method_not_allowed() {
61 | start_server("embed");
62 | let resp = reqwest::blocking::get("http://localhost:8080/embed/recommend").unwrap();
63 | assert_eq!(resp.status(), reqwest::StatusCode::METHOD_NOT_ALLOWED);
64 | stop_server();
65 | }
66 |
--------------------------------------------------------------------------------
/libserving/actix_serving/tests/knn.rs:
--------------------------------------------------------------------------------
1 | use actix_serving::common::{Payload, Recommendation};
2 | use pretty_assertions::assert_eq;
3 |
4 | mod common;
5 | use common::{start_server, stop_server, InvalidParam};
6 |
7 | // cargo test --package actix_serving --test knn -- --test-threads=1
8 | #[test]
9 | fn test_main_knn_serving() {
10 | start_server("knn");
11 | let req = Payload {
12 | user: String::from("10"),
13 | n_rec: 3,
14 | };
15 | let resp: Recommendation = reqwest::blocking::Client::new()
16 | .post("http://localhost:8080/knn/recommend")
17 | .json(&req)
18 | .send()
19 | .unwrap()
20 | .json()
21 | .unwrap();
22 | assert_eq!(resp.rec_list.len(), 3);
23 | stop_server();
24 | }
25 |
26 | #[test]
27 | fn test_bad_request() {
28 | start_server("knn");
29 | let invalid_req = InvalidParam { user: 10, n_rec: 3 };
30 | let resp = reqwest::blocking::Client::new()
31 | .post("http://localhost:8080/knn/recommend")
32 | .json(&invalid_req)
33 | .send()
34 | .unwrap();
35 | assert_eq!(resp.status(), reqwest::StatusCode::BAD_REQUEST);
36 | stop_server();
37 | }
38 |
39 | #[test]
40 | fn test_not_found() {
41 | start_server("knn");
42 | let req = Payload {
43 | user: String::from("10"),
44 | n_rec: 3,
45 | };
46 | let resp = reqwest::blocking::Client::new()
47 | .post("http://localhost:8080/nooo_knn/recommend")
48 | .json(&req)
49 | .send()
50 | .unwrap();
51 | assert_eq!(resp.status(), reqwest::StatusCode::NOT_FOUND);
52 | assert_eq!(
53 | resp.text().unwrap(),
54 | "`nooo_knn/recommend` is not available, make sure you've started the right service."
55 | );
56 | stop_server();
57 | }
58 |
59 | #[test]
60 | fn test_method_not_allowed() {
61 | start_server("knn");
62 | let resp = reqwest::blocking::get("http://localhost:8080/knn/recommend").unwrap();
63 | assert_eq!(resp.status(), reqwest::StatusCode::METHOD_NOT_ALLOWED);
64 | stop_server();
65 | }
66 |
--------------------------------------------------------------------------------
/libserving/actix_serving/tests/tf.rs:
--------------------------------------------------------------------------------
1 | use actix_serving::common::{Payload, Recommendation};
2 | use pretty_assertions::assert_eq;
3 |
4 | mod common;
5 | use common::{start_server, stop_server, InvalidParam};
6 |
7 | // cargo test --package actix_serving --test tf -- --test-threads=1
8 | #[test]
9 | fn test_main_tf_serving() {
10 | start_server("tf");
11 | let req = Payload {
12 | user: String::from("10"),
13 | n_rec: 3,
14 | };
15 | let resp: Recommendation = reqwest::blocking::Client::new()
16 | .post("http://localhost:8080/tf/recommend")
17 | .json(&req)
18 | .send()
19 | .unwrap()
20 | .json()
21 | .unwrap();
22 | assert_eq!(resp.rec_list.len(), 3);
23 | stop_server();
24 | }
25 |
26 | #[test]
27 | fn test_bad_request() {
28 | start_server("tf");
29 | let invalid_req = InvalidParam { user: 10, n_rec: 3 };
30 | let resp = reqwest::blocking::Client::new()
31 | .post("http://localhost:8080/tf/recommend")
32 | .json(&invalid_req)
33 | .send()
34 | .unwrap();
35 | assert_eq!(resp.status(), reqwest::StatusCode::BAD_REQUEST);
36 | stop_server();
37 | }
38 |
39 | #[test]
40 | fn test_not_found() {
41 | start_server("tf");
42 | let req = Payload {
43 | user: String::from("10"),
44 | n_rec: 3,
45 | };
46 | let resp = reqwest::blocking::Client::new()
47 | .post("http://localhost:8080/nooo_tf/recommend")
48 | .json(&req)
49 | .send()
50 | .unwrap();
51 | assert_eq!(resp.status(), reqwest::StatusCode::NOT_FOUND);
52 | assert_eq!(
53 | resp.text().unwrap(),
54 | "`nooo_tf/recommend` is not available, make sure you've started the right service."
55 | );
56 | stop_server();
57 | }
58 |
59 | #[test]
60 | fn test_method_not_allowed() {
61 | start_server("tf");
62 | let resp = reqwest::blocking::get("http://localhost:8080/tf/recommend").unwrap();
63 | assert_eq!(resp.status(), reqwest::StatusCode::METHOD_NOT_ALLOWED);
64 | stop_server();
65 | }
66 |
--------------------------------------------------------------------------------
/libserving/crate-index-config:
--------------------------------------------------------------------------------
1 | [source]
2 |
3 | [source.crates-io]
4 | replace-with = "tuna"
5 |
6 | [source.tuna]
7 | registry = "https://mirrors.tuna.tsinghua.edu.cn/git/crates.io-index.git"
8 |
9 | [registries]
10 |
11 | [registries.bfsu]
12 | index = "https://mirrors.bfsu.edu.cn/git/crates.io-index.git"
13 |
14 | [registries.hit]
15 | index = "https://mirrors.hit.edu.cn/crates.io-index.git"
16 |
17 | [registries.nju]
18 | index = "https://mirror.nju.edu.cn/git/crates.io-index.git"
19 |
20 | [registries.rsproxy]
21 | index = "https://rsproxy.cn/crates.io-index"
22 |
23 | [registries.sjtu]
24 | index = "https://mirrors.sjtug.sjtu.edu.cn/git/crates.io-index"
25 |
26 | [registries.tuna]
27 | index = "https://mirrors.tuna.tsinghua.edu.cn/git/crates.io-index.git"
28 |
29 | [registries.ustc]
30 | index = "git://mirrors.ustc.edu.cn/crates.io-index"
31 |
32 | [net]
33 | git-fetch-with-cli = false
--------------------------------------------------------------------------------
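The registry-mirror list above is not read by Cargo automatically; presumably it is meant to be copied into Cargo's own config file so crates.io is resolved through the selected mirror (here `tuna`). A minimal sketch, assuming the default `$CARGO_HOME` location:

```shell
# assumption: use the file as the user-level Cargo config
mkdir -p ~/.cargo
cp libserving/crate-index-config ~/.cargo/config.toml
```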
/libserving/docker-compose-py.yml:
--------------------------------------------------------------------------------
  1 | version: "3"
2 | services:
3 | libserving:
4 | image: docker.io/massquantity/sanic-serving:0.1.0
5 | ports:
6 | - '8000:8000'
7 | # command: sanic sanic_serving.embed_deploy:app --host=0.0.0.0 --port=8000 --dev --access-logs -v --workers 2
8 | command: sanic sanic_serving.tf_deploy:app --host=0.0.0.0 --port=8000 --no-access-logs --workers 8
9 | # command: python sanic_serving/knn_deploy.py
10 | environment:
11 | - REDIS_HOST=redis
12 | networks:
13 | - server
14 | volumes:
15 | - './embed_model:/app/faiss_index_path'
16 | restart: always
17 | depends_on:
18 | - redis
19 |
20 | redis:
21 | image: docker.io/redis:7.0.4-alpine
22 | ports:
23 | - '6379:6379'
24 | command: redis-server --save 60 1 --loglevel warning
25 | networks:
26 | - server
27 | volumes:
28 | - './redis_data:/data'
29 | restart: always
30 |
31 | volumes:
32 | embed_model: {}
33 | redis_data: {}
34 |
35 | networks:
36 | server: {}
37 |
--------------------------------------------------------------------------------
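A possible way to bring up the Python (Sanic) serving stack defined above; the active `command:` line decides which deploy module (tf/embed/knn) is served, so it may need to be edited first. A sketch, assuming a recent Docker Compose CLI and that the model data has already been exported:

```shell
cd libserving
# starts the sanic service plus redis; "-f" points at the compose file above
docker compose -f docker-compose-py.yml up -d
docker compose -f docker-compose-py.yml logs -f libserving
```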
/libserving/docker-compose-rs.yml:
--------------------------------------------------------------------------------
  1 | version: "3"
2 | services:
3 | libserving:
4 | image: docker.io/massquantity/actix-serving:0.1.0
5 | ports:
6 | - '8080:8080'
7 | command: /app/actix_serving
8 | environment:
9 | - PORT=8080
10 | - MODEL_TYPE=embed
11 | - REDIS_HOST=redis
12 | - WORKERS=8
13 | - RUST_LOG=info
14 | networks:
15 | - server
16 | volumes:
17 | - './embed_model:/app/faiss_index_path'
18 | restart: always
19 | depends_on:
20 | - redis
21 |
22 | redis:
23 | image: docker.io/redis:7.0.4-alpine
24 | ports:
25 | - '6379:6379'
26 | command: redis-server --save 60 1 --loglevel warning
27 | networks:
28 | - server
29 | volumes:
30 | - './redis_data:/data'
31 | restart: always
32 |
33 | volumes:
34 | embed_model: {}
35 | redis_data: {}
36 |
37 | networks:
38 | server: {}
39 |
--------------------------------------------------------------------------------
/libserving/docker-compose-tf-serving.yml:
--------------------------------------------------------------------------------
  1 | version: "3"
2 | services:
3 | libserving:
4 | environment:
5 | - TF_SERVING_HOST=tensorflow-serving
6 | depends_on:
7 | - tensorflow-serving
8 |
9 | tensorflow-serving:
10 | image: docker.io/tensorflow/serving:2.8.2
11 | ports:
12 | - '8500:8500'
13 | - '8501:8501'
14 | environment:
15 | - MODEL_BASE_PATH=/usr/local/tf_model
16 | - MODEL_NAME=youtuberanking
17 | networks:
18 | - server
19 | volumes:
20 | - './tf_model:/usr/local/tf_model'
21 | restart: always
22 |
23 | volumes:
24 | tf_model: {}
25 |
26 | networks:
27 | server: {}
28 |
--------------------------------------------------------------------------------
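This compose file only declares the `tensorflow-serving` service and a partial override of `libserving`, so it is presumably meant to be layered on top of one of the base files above rather than used alone. A sketch of that usage, relying on Docker Compose's multi-file merge:

```shell
cd libserving
# assumption: ./tf_model contains the exported SavedModel named "youtuberanking"
docker compose -f docker-compose-py.yml -f docker-compose-tf-serving.yml up -d
```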
/libserving/request.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 |
4 | import requests
5 |
6 |
7 | def parse_args():
8 | parser = argparse.ArgumentParser()
9 |
10 | parser.add_argument("--host", default="127.0.0.1")
11 | parser.add_argument("--port", default=8000)
12 | parser.add_argument("--user", type=str, help="user id")
13 | parser.add_argument("--n_rec", type=int, help="num of recommendations")
14 | parser.add_argument("--algo", type=str, help="type of serving algorithm")
15 | parser.add_argument("--user_feats", help="user features, type: dict")
16 | parser.add_argument("--seq", help="user behavior sequence, type: list")
17 | return parser.parse_args()
18 |
19 |
20 | def main():
21 | args = parse_args()
22 | url = f"http://{args.host}:{args.port}/{args.algo}/recommend"
23 | data = {"user": args.user, "n_rec": args.n_rec}
24 | if args.user_feats:
25 | data["user_feats"] = json.loads(args.user_feats)
26 | if args.seq:
27 | data["seq"] = json.loads(args.seq)
28 |
29 | response = requests.post(url, json=data, timeout=1)
30 | if response.status_code != 200:
31 | print(f"Failed to get recommendation: {url}")
32 | print(response.text)
33 | response.raise_for_status()
34 | try:
35 | print(response.json())
36 | except json.JSONDecodeError:
37 | print("Failed to decode response json")
38 |
39 |
40 | if __name__ == "__main__":
41 | main()
42 |
--------------------------------------------------------------------------------
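A sample invocation of the request script above; the endpoint path (`/{algo}/recommend`), the default port and the flag names all come from the script itself, while the concrete values (user id, algorithm name, feature dict, sequence) are made-up placeholders:

```shell
# placeholder user id and algorithm name
python libserving/request.py --user 1 --n_rec 10 --algo knn
# feature-aware models also accept user features and a behavior sequence as JSON strings
python libserving/request.py --user 1 --n_rec 10 --algo tf \
    --user_feats '{"sex": "F", "age": 25}' --seq '[1, 2, 3]'
```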
/libserving/sanic_serving/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/massquantity/LibRecommender/9eadf8a3d6901a898630da113b92de48bd2897fb/libserving/sanic_serving/__init__.py
--------------------------------------------------------------------------------
/libserving/sanic_serving/benchmark.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import asyncio
3 | import time
4 | from concurrent.futures import ThreadPoolExecutor, as_completed
5 |
6 | import aiohttp
7 | import requests
8 | import ujson
9 |
10 | REQUEST_LIMIT = 64
11 |
12 |
13 | def parse_args():
14 | parser = argparse.ArgumentParser()
15 |
16 | parser.add_argument("--host", default="127.0.0.1")
17 | parser.add_argument("--port", default=8000)
18 | parser.add_argument("--user", type=str, help="user id")
19 | parser.add_argument("--n_rec", type=int, help="num of recommendations")
20 | parser.add_argument("--n_times", type=int, help="num of requests")
21 | parser.add_argument("--n_threads", type=int, default=1, help="num of threads")
22 | parser.add_argument("--algo", type=str, help="type of algorithm")
23 | return parser.parse_args()
24 |
25 |
26 | async def get_reco_async(
27 | session: aiohttp.ClientSession, url: str, data: dict, semaphore: asyncio.Semaphore
28 | ):
29 | async with semaphore, session.post(url, json=data) as resp:
30 | # if semaphore.locked():
31 | # await asyncio.sleep(1.0)
32 | resp.raise_for_status()
33 | reco = await resp.json(loads=ujson.loads)
34 | return reco
35 |
36 |
37 | async def main_async(args):
38 | url = f"http://{args.host}:{args.port}/{args.algo}/recommend"
39 | data = {"user": args.user, "n_rec": args.n_rec}
40 | semaphore = asyncio.Semaphore(REQUEST_LIMIT)
41 | async with aiohttp.ClientSession() as session:
42 | tasks = [
43 | get_reco_async(session, url, data, semaphore) for _ in range(args.n_times)
44 | ]
45 | # await asyncio.gather(*tasks, return_exceptions=True)
46 | for future in asyncio.as_completed(tasks):
47 | _ = await future
48 |
49 |
50 | def get_reco_sync(url: str, data: dict):
51 | resp = requests.post(url, json=data, timeout=1)
52 | resp.raise_for_status()
53 | return resp.json()
54 |
55 |
56 | def main_sync(args):
57 | url = f"http://{args.host}:{args.port}/{args.algo}/recommend"
58 | data = {"user": args.user, "n_rec": args.n_rec}
59 | with ThreadPoolExecutor(max_workers=args.n_threads) as executor:
60 | futures = [
61 | executor.submit(get_reco_sync, url, data) for _ in range(args.n_times)
62 | ]
63 | for future in as_completed(futures):
64 | _ = future.result()
65 |
66 |
67 | if __name__ == "__main__":
68 | args = parse_args()
69 |
70 | start = time.perf_counter()
71 | asyncio.run(main_async(args))
72 | duration = time.perf_counter() - start
73 | print(
74 | f"total time {duration}s for async requests, "
75 | f"{duration / args.n_times * 1000} ms/request"
76 | )
77 |
78 | start = time.perf_counter()
79 | main_sync(args)
80 | duration = time.perf_counter() - start
81 | print(
82 | f"total time {duration}s for sync requests, "
83 | f"{duration / args.n_times * 1000} ms/request"
84 | )
85 |
--------------------------------------------------------------------------------
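An example run of the benchmark above: it fires `--n_times` requests twice, once through aiohttp (async, capped at 64 in-flight requests by the semaphore) and once through a thread pool, and prints the average latency for each. The values below are placeholders:

```shell
# placeholder values; the server must already be running on the given host/port
python libserving/sanic_serving/benchmark.py \
    --user 1 --n_rec 10 --algo knn --n_times 1000 --n_threads 8
```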
/libserving/sanic_serving/common.py:
--------------------------------------------------------------------------------
1 | import functools
2 | from typing import Callable, Dict, List, Optional, Type, Union
3 |
4 | from pydantic import BaseModel, Extra, ValidationError
5 | from sanic.exceptions import SanicException
6 | from sanic.log import logger
7 | from sanic.request import Request
8 |
9 |
10 | class Params(BaseModel, extra=Extra.forbid):
11 | user: Union[str, int]
12 | n_rec: int
13 | user_feats: Optional[Dict[str, Union[str, int, float]]] = None
14 | seq: Optional[List[Union[str, int]]] = None
15 |
16 |
17 | def validate(model: Type[object]):
18 | def decorator(func: Callable):
19 | @functools.wraps(func)
20 | async def decorated_function(request: Request, **kwargs):
21 | try:
22 | params = model(**request.json)
23 | kwargs["params"] = params
24 | except ValidationError as e:
25 | logger.error(f"Invalid request body: {request.json}")
26 | raise SanicException(
27 | f"Invalid payload: `{request.json}`, please check key name and value type."
28 | ) from e
29 |
30 | return await func(request, **kwargs)
31 |
32 | return decorated_function
33 |
34 | return decorator
35 |
--------------------------------------------------------------------------------
/libserving/serialization/__init__.py:
--------------------------------------------------------------------------------
1 | from .embed import save_embed, save_faiss_index
2 | from .knn import save_knn
3 | from .online import save_online
4 | from .redis import embed2redis, knn2redis, online2redis, tf2redis
5 | from .tfmodel import save_tf
6 |
7 | __all__ = [
8 | "save_knn",
9 | "save_embed",
10 | "save_faiss_index",
11 | "save_online",
12 | "save_tf",
13 | "knn2redis",
14 | "embed2redis",
15 | "tf2redis",
16 | "online2redis",
17 | ]
18 |
--------------------------------------------------------------------------------
/libserving/serialization/embed.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 |
5 | from libreco.bases import EmbedBase
6 |
7 | from .common import (
8 | check_path_exists,
9 | save_id_mapping,
10 | save_model_name,
11 | save_to_json,
12 | save_user_consumed,
13 | )
14 |
15 |
16 | def save_embed(path: str, model: EmbedBase):
17 | """Save Embed model to disk.
18 |
19 | Parameters
20 | ----------
21 | path : str
22 | Model saving path.
23 | model : EmbedBase
24 | Model to save.
25 | """
26 | check_path_exists(path)
27 | save_model_name(path, model)
28 | save_id_mapping(path, model.data_info)
29 | save_user_consumed(path, model.data_info)
30 | save_vectors(path, model.user_embeds_np, model.n_users, "user_embed.json")
31 | save_vectors(path, model.item_embeds_np, model.n_items, "item_embed.json")
32 |
33 |
34 | def save_vectors(path: str, embeds: np.ndarray, num: int, name: str):
35 | embed_path = os.path.join(path, name)
36 | embed_dict = dict()
37 | for i in range(num):
38 | embed_dict[i] = embeds[i].tolist()
39 | save_to_json(embed_path, embed_dict)
40 |
41 |
42 | def save_faiss_index(path: str, model: EmbedBase, nlist: int = 80, nprobe: int = 10):
43 | import faiss
44 |
45 | check_path_exists(path)
46 | index_path = os.path.join(path, "faiss_index.bin")
47 | item_embeds = model.item_embeds_np[: model.n_items].astype(np.float32)
48 | d = item_embeds.shape[1]
49 | quantizer = faiss.IndexFlatIP(d)
50 | index = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_INNER_PRODUCT)
51 | index.train(item_embeds)
52 | index.add(item_embeds)
53 | index.nprobe = nprobe
54 | faiss.write_index(index, index_path)
55 |
--------------------------------------------------------------------------------
/libserving/serialization/knn.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from scipy import sparse
4 |
5 | from libreco.bases import CfBase
6 |
7 | from .common import (
8 | check_path_exists,
9 | save_id_mapping,
10 | save_model_name,
11 | save_to_json,
12 | save_user_consumed,
13 | )
14 |
15 |
16 | def save_knn(path: str, model: CfBase, k: int):
17 | """Save KNN model to disk.
18 |
19 | Parameters
20 | ----------
21 | path : str
22 | Model saving path.
23 | model : CfBase
24 | Model to save.
25 | k : int
26 | Number of similar users/items to save.
27 | """
28 | check_path_exists(path)
29 | save_model_name(path, model)
30 | save_id_mapping(path, model.data_info)
31 | save_user_consumed(path, model.data_info)
32 | save_sim_matrix(path, model.sim_matrix, k)
33 |
34 |
35 | def save_sim_matrix(path: str, sim_matrix: sparse.csr_matrix, k: int):
36 | k_sims = dict()
37 | num = len(sim_matrix.indptr) - 1
38 | indices = sim_matrix.indices.tolist()
39 | indptr = sim_matrix.indptr.tolist()
40 | data = sim_matrix.data.tolist()
41 | for i in range(num):
42 | i_slice = slice(indptr[i], indptr[i + 1])
43 | sorted_sims = sorted(zip(indices[i_slice], data[i_slice]), key=lambda x: -x[1])
44 | k_sims[i] = sorted_sims[:k]
45 | sim_path = os.path.join(path, "sim.json")
46 | save_to_json(sim_path, k_sims)
47 |
--------------------------------------------------------------------------------
/python-package-conda.yml:
--------------------------------------------------------------------------------
1 | name: Python Package using Conda
2 |
3 | on: [push]
4 |
5 | jobs:
6 | build-linux:
7 | runs-on: ubuntu-latest
8 | strategy:
9 | max-parallel: 5
10 |
11 | steps:
12 | - uses: actions/checkout@v2
13 | - name: Set up Python 3.10
14 | uses: actions/setup-python@v2
15 | with:
 16 | python-version: "3.10"
17 | - name: Add conda to system path
18 | run: |
19 | # $CONDA is an environment variable pointing to the root of the miniconda directory
20 | echo $CONDA/bin >> $GITHUB_PATH
21 | - name: Install dependencies
22 | run: |
23 | conda env update --file environment.yml --name base
24 | - name: Lint with flake8
25 | run: |
26 | conda install flake8
27 | # stop the build if there are Python syntax errors or undefined names
28 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
29 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
30 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
31 | - name: Test with pytest
32 | run: |
33 | conda install pytest
34 | pytest
35 |
--------------------------------------------------------------------------------
/requirements-dev.txt:
--------------------------------------------------------------------------------
1 | -r requirements.txt
2 |
3 | flake8
4 | ruff >= 0.3.0 ; python_version > "3.6"
5 | pytest
6 | coverage
7 | cython >= 0.29.0,<3
8 | smart_open < 7.0.0 ; python_version == "3.6"
9 | pyyaml ; python_version >= "3.8"
10 | pydantic ; python_version >= "3.8"
11 |
--------------------------------------------------------------------------------
/requirements-serving.txt:
--------------------------------------------------------------------------------
1 | sanic >= 22.3
2 | ujson
3 | requests
4 | pydantic >= 2.0
5 | aiohttp
6 | redis >= 4.2.0
7 | faiss-cpu == 1.7.2 ; python_version == "3.6"
8 | faiss-cpu >= 1.5.2 ; python_version > "3.6"
9 | protobuf < 4.24.0
10 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy >= 1.19.5
2 | scipy >= 1.2.1, < 1.13.0
3 | pandas >= 1.0.0
4 | scikit-learn >= 0.20.0
5 | tensorflow >= 1.15.0, < 2.16.0
6 | torch >= 1.10.0
7 | gensim >= 4.0.0
8 | tqdm
9 | nmslib ; python_version < "3.11"
10 | dgl < 2.0.0 -f https://data.dgl.ai/wheels/repo.html
11 | dataclasses ; python_version == "3.6"
12 | recfarm ; python_version >= "3.7"
13 |
--------------------------------------------------------------------------------
/rust/.gitignore:
--------------------------------------------------------------------------------
1 | /target
2 |
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | .pytest_cache/
6 | *.py[cod]
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | .venv/
14 | env/
15 | bin/
16 | build/
17 | develop-eggs/
18 | dist/
19 | eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | include/
26 | man/
27 | venv/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 |
32 | # Installer logs
33 | pip-log.txt
34 | pip-delete-this-directory.txt
35 | pip-selfcheck.json
36 |
37 | # Unit test / coverage reports
38 | htmlcov/
39 | .tox/
40 | .coverage
41 | .cache
42 | nosetests.xml
43 | coverage.xml
44 |
45 | # Translations
46 | *.mo
47 |
48 | # Mr Developer
49 | .mr.developer.cfg
50 | .project
51 | .pydevproject
52 |
53 | # Rope
54 | .ropeproject
55 |
56 | # Django stuff:
57 | *.log
58 | *.pot
59 |
60 | .DS_Store
61 |
62 | # Sphinx documentation
63 | docs/_build/
64 |
65 | # PyCharm
66 | .idea/
67 |
68 | # VSCode
69 | .vscode/
70 |
71 | # Pyenv
72 | .python-version
73 |
--------------------------------------------------------------------------------
/rust/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "recfarm"
3 | version = "0.2.0"
4 | edition = "2021"
5 |
6 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
7 | [lib]
8 | name = "recfarm"
9 | crate-type = ["cdylib"]
10 |
11 | [dependencies]
12 | bincode = "1.3.3"
13 | dashmap = "5.5.3"
14 | flate2 = "1.0"
15 | fxhash = "0.2.1"
16 | pyo3 = { version = "0.23.3", features = ["abi3-py37"] }
17 | rand = { version = "0.8", features = ["default", "alloc"] }
18 | rayon = "1.8"
19 | serde = { version = "1.0", features = ["derive"] }
20 |
21 | [profile.release]
22 | strip = true
23 | opt-level = 3
24 | lto = true
25 | codegen-units = 4
26 |
--------------------------------------------------------------------------------
/rust/README.md:
--------------------------------------------------------------------------------
1 | # RecFarm
2 |
  3 | Rust implementations of some non-deep-learning recommender algorithms.
4 |
5 | ## Installation
6 |
7 | ```shell
8 | $ pip install recfarm
9 | ```
10 |
11 | Requires Python >= 3.7.
12 |
--------------------------------------------------------------------------------
/rust/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["maturin>=1.2,<2.0"]
3 | build-backend = "maturin"
4 |
5 | [project]
6 | name = "recfarm"
7 | description = "Rust implementation for some non-deep learning recommender algorithms."
8 | authors = [
9 | { name = "massquantity", email = "jinxin_madie@163.com" },
10 | ]
11 | readme = "README.md"
12 | requires-python = ">=3.7"
13 | classifiers = [
14 | "Programming Language :: Rust",
15 | "Programming Language :: Python :: Implementation :: CPython",
16 | ]
17 | dynamic = ["version"]
18 |
19 | [tool.maturin]
20 | features = ["pyo3/extension-module"]
21 |
--------------------------------------------------------------------------------
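Since the build backend above is maturin, a local development build of the extension module can presumably be produced as follows (a sketch, assuming a Rust toolchain and an activated virtualenv):

```shell
cd rust
pip install maturin
# builds the recfarm extension in release mode and installs it into the active environment
maturin develop --release
```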
/rust/recfarm/__init__.py:
--------------------------------------------------------------------------------
1 | from recfarm import recfarm
2 | from recfarm.recfarm import (
3 | ItemCF,
4 | UserCF,
5 | Swing,
6 | __version__,
7 | build_consumed_unique,
8 | load_item_cf,
9 | load_user_cf,
10 | save_item_cf,
11 | save_user_cf,
12 | save_swing,
13 | load_swing,
14 | )
15 |
16 | __all__ = ["recfarm", "UserCF", "ItemCF", "Swing"]
17 |
--------------------------------------------------------------------------------
/rust/rustfmt.toml:
--------------------------------------------------------------------------------
1 | array_width = 50
2 | chain_width = 50
3 | edition = "2021"
4 | group_imports = "StdExternalCrate"
5 | imports_granularity = "Module"
6 | max_width = 100
7 | reorder_imports = true
8 | reorder_modules = true
9 | use_field_init_shorthand = true
10 |
--------------------------------------------------------------------------------
/rust/src/lib.rs:
--------------------------------------------------------------------------------
1 | #![allow(clippy::too_many_arguments)]
2 |
3 | use pyo3::prelude::*;
4 |
5 | mod graph;
6 | mod incremental;
7 | mod inference;
8 | mod item_cf;
9 | mod ordering;
10 | mod serialization;
11 | mod similarities;
12 | mod sparse;
13 | mod swing;
14 | mod user_cf;
15 | mod utils;
16 |
17 | const VERSION: &str = env!("CARGO_PKG_VERSION");
18 |
19 | /// RecFarm module
20 | #[pymodule]
21 | fn recfarm(m: &Bound<'_, PyModule>) -> PyResult<()> {
 22 |     m.add_class::<user_cf::UserCF>()?;
 23 |     m.add_class::<item_cf::ItemCF>()?;
 24 |     m.add_class::<swing::Swing>()?;
25 | m.add_function(wrap_pyfunction!(user_cf::save, m)?)?;
26 | m.add_function(wrap_pyfunction!(user_cf::load, m)?)?;
27 | m.add_function(wrap_pyfunction!(item_cf::save, m)?)?;
28 | m.add_function(wrap_pyfunction!(item_cf::load, m)?)?;
29 | m.add_function(wrap_pyfunction!(swing::save, m)?)?;
30 | m.add_function(wrap_pyfunction!(swing::load, m)?)?;
31 | m.add_function(wrap_pyfunction!(utils::build_consumed, m)?)?;
32 | m.add("__version__", VERSION)?;
33 | Ok(())
34 | }
35 |
--------------------------------------------------------------------------------
/rust/src/ordering.rs:
--------------------------------------------------------------------------------
1 | use std::cmp::Ordering;
2 |
3 | /// 0: similarity, 1: label
4 | #[derive(Debug)]
5 | pub(crate) struct SimOrd(pub f32, pub f32);
6 |
7 | impl Ord for SimOrd {
8 | fn cmp(&self, other: &Self) -> Ordering {
9 | self.0
10 | .partial_cmp(&other.0)
11 | .unwrap_or(Ordering::Equal)
12 | }
13 | }
14 |
15 | impl PartialOrd for SimOrd {
 16 |     fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
17 | Some(self.cmp(other))
18 | }
19 | }
20 |
21 | impl PartialEq for SimOrd {
22 | fn eq(&self, other: &Self) -> bool {
23 | self.0 == other.0
24 | }
25 | }
26 |
27 | impl Eq for SimOrd {}
28 |
29 | #[cfg(test)]
30 | mod tests {
31 | use std::collections::BinaryHeap;
32 |
33 | use super::*;
34 |
35 | #[test]
36 | fn test_sim_max_heap() {
37 | let mut heap = BinaryHeap::new();
38 | heap.push(SimOrd(1.1, 1.1));
39 | heap.push(SimOrd(0.0, 1.1));
40 | heap.push(SimOrd(-2.0, 0.0));
41 | heap.push(SimOrd(-0.2, 3.3));
42 | heap.push(SimOrd(8.8, 8.8));
43 | assert_eq!(heap.pop(), Some(SimOrd(8.8, 8.8)));
44 | assert_eq!(heap.pop(), Some(SimOrd(1.1, 1.1)));
45 | assert_eq!(heap.pop(), Some(SimOrd(0.0, 1.1)));
46 | assert_eq!(heap.pop(), Some(SimOrd(-0.2, 3.3)));
47 | assert_eq!(heap.pop(), Some(SimOrd(-2.0, 0.0)));
48 | assert_eq!(heap.pop(), None);
49 | }
50 | }
51 |
--------------------------------------------------------------------------------
/rust/src/serialization.rs:
--------------------------------------------------------------------------------
1 | use std::fs::{File, OpenOptions};
2 | use std::io::{Read, Write};
3 | use std::path::Path;
4 |
5 | use flate2::read::GzDecoder;
6 | use flate2::write::GzEncoder;
7 | use flate2::Compression;
8 | use pyo3::exceptions::PyIOError;
9 | use pyo3::PyResult;
10 | use serde::de::DeserializeOwned;
11 | use serde::Serialize;
12 |
 13 | pub fn save_model<T: Serialize>(
14 | model: &T,
15 | path: &str,
16 | model_name: &str,
17 | class_name: &str,
18 | ) -> PyResult<()> {
19 | let file_name = format!("{model_name}.gz");
20 | let model_path = Path::new(path).join(file_name);
21 | let file = OpenOptions::new()
22 | .write(true)
23 | .create(true)
24 | .open(model_path.as_path())?;
25 | let mut encoder = GzEncoder::new(file, Compression::new(1));
 26 |     let model_bytes: Vec<u8> = match bincode::serialize(model) {
27 | Ok(bytes) => bytes,
28 | Err(e) => return Err(PyIOError::new_err(e.to_string())),
29 | };
30 | encoder.write_all(&model_bytes)?;
31 | encoder.finish()?;
32 | println!(
33 | "Save `{class_name}` model to `{}`",
34 | model_path.canonicalize()?.display()
35 | );
36 | Ok(())
37 | }
38 |
 39 | pub fn load_model<T: DeserializeOwned>(
40 | path: &str,
41 | model_name: &str,
42 | class_name: &str,
 43 | ) -> PyResult<T> {
44 | let file_name = format!("{model_name}.gz");
45 | let model_path = Path::new(path).join(file_name);
46 | let file = File::open(model_path.as_path())?;
47 | let mut decoder = GzDecoder::new(file);
 48 |     let mut model_bytes: Vec<u8> = Vec::new();
49 | decoder.read_to_end(&mut model_bytes)?;
50 | let model: T = match bincode::deserialize(&model_bytes) {
51 | Ok(m) => m,
52 | Err(e) => return Err(PyIOError::new_err(e.to_string())),
53 | };
54 | println!(
55 | "Load `{class_name}` model from `{}`",
56 | model_path.canonicalize()?.display()
57 | );
58 | Ok(model)
59 | }
60 |
--------------------------------------------------------------------------------
/rust/src/utils.rs:
--------------------------------------------------------------------------------
1 | use fxhash::FxHashMap;
2 | use pyo3::prelude::*;
3 | use pyo3::types::{IntoPyDict, PyDict, PyList};
4 |
5 | /// (x1, x2, prod, count)
6 | pub(crate) type CumValues = (i32, i32, f32, usize);
7 |
8 | #[pyfunction]
9 | #[pyo3(name = "build_consumed_unique")]
10 | pub fn build_consumed<'py>(
11 | py: Python<'py>,
12 | user_indices: &Bound<'py, PyList>,
13 | item_indices: &Bound<'py, PyList>,
14 | ) -> PyResult<(Bound<'py, PyDict>, Bound<'py, PyDict>)> {
 15 |     let add_or_insert = |mapping: &mut FxHashMap<i32, Vec<i32>>, k: i32, v: i32| {
16 | mapping
17 | .entry(k)
18 | .and_modify(|consumed| consumed.push(v))
19 | .or_insert_with(|| vec![v]);
20 | };
 21 |     let user_indices: Vec<i32> = user_indices.extract()?;
 22 |     let item_indices: Vec<i32> = item_indices.extract()?;
 23 |     let mut user_consumed: FxHashMap<i32, Vec<i32>> = FxHashMap::default();
 24 |     let mut item_consumed: FxHashMap<i32, Vec<i32>> = FxHashMap::default();
25 | for (&u, &i) in user_indices.iter().zip(item_indices.iter()) {
26 | add_or_insert(&mut user_consumed, u, i);
27 | add_or_insert(&mut item_consumed, i, u);
28 | }
29 | // remove consecutive repeated elements
30 | user_consumed.values_mut().for_each(|v| v.dedup());
31 | item_consumed.values_mut().for_each(|v| v.dedup());
32 | let user_consumed_py: Bound<'py, PyDict> = user_consumed.into_py_dict(py)?;
33 | let item_consumed_py: Bound<'py, PyDict> = item_consumed.into_py_dict(py)?;
34 | Ok((user_consumed_py, item_consumed_py))
35 | }
36 |
37 | #[cfg(test)]
38 | mod tests {
39 | use super::*;
40 |
41 | #[test]
 42 |     fn test_build_consumed() -> Result<(), Box<dyn std::error::Error>> {
 43 |         let get_values = |mapping: &Bound<'_, PyDict>, k: i32| -> PyResult<Vec<i32>> {
44 | mapping.get_item(k)?.unwrap().extract()
45 | };
46 | pyo3::prepare_freethreaded_python();
47 | Ok(Python::with_gil(|py| -> PyResult<()> {
48 | let user_indices = PyList::new(py, vec![1, 1, 1, 2, 2, 1, 2, 3, 2, 3])?;
49 | let item_indices = PyList::new(py, vec![11, 11, 999, 0, 11, 11, 999, 11, 999, 0])?;
50 | let (user_consumed, item_consumed) = build_consumed(py, &user_indices, &item_indices)?;
51 | assert_eq!(get_values(&user_consumed, 1)?, vec![11, 999, 11]);
52 | assert_eq!(get_values(&user_consumed, 2)?, vec![0, 11, 999]);
53 | assert_eq!(get_values(&user_consumed, 3)?, vec![11, 0]);
54 | assert_eq!(get_values(&item_consumed, 11)?, vec![1, 2, 1, 3]);
55 | assert_eq!(get_values(&item_consumed, 999)?, vec![1, 2]);
56 | assert_eq!(get_values(&item_consumed, 0)?, vec![2, 3]);
57 | Ok(())
58 | })?)
59 | }
60 | }
61 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 88
3 | ignore = E501,E203,F811,F401,W503,W504
4 | count = True
5 | show-source = True
6 | statistics = True
7 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/massquantity/LibRecommender/9eadf8a3d6901a898630da113b92de48bd2897fb/tests/__init__.py
--------------------------------------------------------------------------------
/tests/compatibility_test.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import pandas as pd
4 | import tensorflow as tf
5 |
6 | import libreco
7 | from libreco.algorithms import Caser, RNN4Rec
8 | from libreco.data import DatasetPure, split_by_ratio_chrono
9 | from libreco.tfops import TF_VERSION
10 |
11 | if __name__ == "__main__":
12 | print(f"tensorflow version: {TF_VERSION}")
13 | print(libreco)
14 | from libreco.algorithms._als import als_update
15 | from libreco.utils._similarities import forward_cosine, invert_cosine
16 |
17 | print("Cython functions: ", invert_cosine, forward_cosine, als_update)
18 |
19 | cur_path = Path(".").parent
20 | if Path.exists(cur_path / "sample_movielens_rating.dat"):
21 | data_path = cur_path / "sample_movielens_rating.dat"
22 | else:
23 | data_path = cur_path / "sample_data" / "sample_movielens_rating.dat"
24 |
25 | pd_data = pd.read_csv(data_path, sep="::", names=["user", "item", "label", "time"])
26 | train_data, eval_data = split_by_ratio_chrono(pd_data, test_size=0.2)
27 | train_data, data_info = DatasetPure.build_trainset(train_data)
28 | eval_data = DatasetPure.build_evalset(eval_data)
29 |
30 | rnn = RNN4Rec(
31 | "ranking",
32 | data_info,
33 | rnn_type="lstm",
34 | loss_type="cross_entropy",
35 | embed_size=16,
36 | n_epochs=1,
37 | lr=0.001,
38 | lr_decay=False,
39 | hidden_units=(16, 16),
40 | reg=None,
41 | batch_size=256,
42 | num_neg=1,
43 | dropout_rate=None,
44 | recent_num=10,
45 | tf_sess_config=None,
46 | )
47 | rnn.fit(
48 | train_data,
49 | neg_sampling=True,
50 | verbose=2,
51 | shuffle=True,
52 | eval_data=eval_data,
53 | metrics=[
54 | "loss",
55 | "balanced_accuracy",
56 | "roc_auc",
57 | "pr_auc",
58 | "precision",
59 | "recall",
60 | "map",
61 | "ndcg",
62 | ],
63 | num_workers=2,
64 | )
65 | print("prediction: ", rnn.predict(user=1, item=2))
66 | print("recommendation: ", rnn.recommend_user(user=1, n_rec=7))
67 |
68 | tf.compat.v1.reset_default_graph()
69 | caser = Caser(
70 | "ranking",
71 | data_info=data_info,
72 | loss_type="cross_entropy",
73 | embed_size=16,
74 | n_epochs=1,
75 | lr=1e-4,
76 | batch_size=2048,
77 | )
78 | caser.fit(
79 | train_data,
80 | neg_sampling=True,
81 | verbose=2,
82 | shuffle=True,
83 | eval_data=eval_data,
84 | metrics=[
85 | "loss",
86 | "balanced_accuracy",
87 | "roc_auc",
88 | "pr_auc",
89 | "precision",
90 | "recall",
91 | "map",
92 | "ndcg",
93 | ],
94 | num_workers=2,
95 | )
96 |
--------------------------------------------------------------------------------
/tests/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/massquantity/LibRecommender/9eadf8a3d6901a898630da113b92de48bd2897fb/tests/models/__init__.py
--------------------------------------------------------------------------------
/tests/models/test_base.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from libreco.bases import Base
4 |
5 |
6 | class NCF(Base):
7 | def __init__(self, task, data_info, lower_upper_bound=None):
8 | super().__init__(task, data_info, lower_upper_bound)
9 |
10 | def fit(self, train_data, **kwargs):
11 | raise NotImplementedError
12 |
13 | def predict(self, user, item, **kwargs):
14 | raise NotImplementedError
15 |
16 | def recommend_user(self, user, n_rec, **kwargs):
17 | raise NotImplementedError
18 |
19 | def save(self, path, model_name, **kwargs):
20 | raise NotImplementedError
21 |
22 | @classmethod
23 | def load(cls, path, model_name, data_info, **kwargs):
24 | raise NotImplementedError
25 |
26 |
27 | def test_base(prepare_pure_data):
28 | _, train_data, _, data_info = prepare_pure_data
29 | with pytest.raises(ValueError):
30 | _ = NCF(task="unknown", data_info=data_info)
31 | with pytest.raises(AssertionError):
32 | _ = NCF(task="rating", data_info=data_info, lower_upper_bound=1)
33 |
34 | model = NCF(task="rating", data_info=data_info, lower_upper_bound=[1, 5])
35 | with pytest.raises(NotImplementedError):
36 | model.fit(train_data)
37 | with pytest.raises(NotImplementedError):
38 | model.predict(1, 2)
39 | with pytest.raises(NotImplementedError):
40 | model.recommend_user(1, 7)
41 | with pytest.raises(NotImplementedError):
42 | model.save("path", "model_name")
43 | with pytest.raises(NotImplementedError):
44 | NCF.load("path", "model_name", data_info)
45 |
--------------------------------------------------------------------------------
/tests/models/test_deepwalk.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import tensorflow as tf
3 |
4 | from libreco.algorithms import DeepWalk
5 | from tests.utils_data import remove_path, set_ranking_labels
6 | from tests.utils_metrics import get_metrics
7 | from tests.utils_pred import ptest_preds
8 | from tests.utils_reco import ptest_recommends
9 | from tests.utils_save_load import save_load_model
10 |
11 |
12 | @pytest.mark.parametrize("task", ["rating", "ranking"])
13 | @pytest.mark.parametrize(
14 | "norm_embed, n_walks, walk_length, window_size",
15 | [(True, 10, 10, 5), (False, 1, 1, 1)],
16 | )
17 | @pytest.mark.parametrize("neg_sampling", [True, False, None])
18 | def test_deepwalk(
19 | pure_data_small, task, norm_embed, n_walks, walk_length, window_size, neg_sampling
20 | ):
21 | tf.compat.v1.reset_default_graph()
22 | pd_data, train_data, eval_data, data_info = pure_data_small
23 | if neg_sampling is False:
24 | set_ranking_labels(train_data)
25 | set_ranking_labels(eval_data)
26 |
27 | if task == "rating":
28 | with pytest.raises(AssertionError):
29 | _ = DeepWalk(task, data_info)
30 | elif neg_sampling is None:
31 | with pytest.raises(AssertionError):
32 | DeepWalk(task, data_info).fit(train_data, neg_sampling)
33 | else:
34 | model = DeepWalk(
35 | task=task,
36 | data_info=data_info,
37 | embed_size=16,
38 | n_epochs=2,
39 | norm_embed=norm_embed,
40 | window_size=window_size,
41 | n_walks=n_walks,
42 | walk_length=walk_length,
43 | )
44 | model.fit(
45 | train_data,
46 | neg_sampling,
47 | verbose=2,
48 | shuffle=True,
49 | eval_data=eval_data,
50 | metrics=get_metrics(task),
51 | )
52 | ptest_preds(model, task, pd_data, with_feats=False)
53 | ptest_recommends(model, data_info, pd_data, with_feats=False)
54 |
55 | # test save and load model
56 | loaded_model, loaded_data_info = save_load_model(DeepWalk, model, data_info)
57 | with pytest.raises(RuntimeError):
58 | loaded_model.fit(train_data, neg_sampling)
59 | ptest_preds(loaded_model, task, pd_data, with_feats=False)
60 | ptest_recommends(loaded_model, loaded_data_info, pd_data, with_feats=False)
61 | model.save("not_existed_path", "deepwalk2")
62 | remove_path("not_existed_path")
63 |
--------------------------------------------------------------------------------
/tests/models/test_item2vec.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import tensorflow as tf
3 |
4 | from libreco.algorithms import Item2Vec
5 | from tests.utils_data import remove_path, set_ranking_labels
6 | from tests.utils_metrics import get_metrics
7 | from tests.utils_pred import ptest_preds
8 | from tests.utils_reco import ptest_recommends
9 | from tests.utils_save_load import save_load_model
10 |
11 |
12 | @pytest.mark.parametrize("task", ["rating", "ranking"])
13 | @pytest.mark.parametrize("norm_embed, window_size", [(True, 5), (False, None)])
14 | @pytest.mark.parametrize("neg_sampling", [True, False, None])
15 | def test_item2vec(pure_data_small, task, norm_embed, window_size, neg_sampling):
16 | tf.compat.v1.reset_default_graph()
17 | pd_data, train_data, eval_data, data_info = pure_data_small
18 | if neg_sampling is False:
19 | set_ranking_labels(train_data)
20 | set_ranking_labels(eval_data)
21 |
22 | if task == "rating":
23 | with pytest.raises(AssertionError):
24 | _ = Item2Vec(task, data_info)
25 | elif neg_sampling is None:
26 | with pytest.raises(AssertionError):
27 | Item2Vec(task, data_info).fit(train_data, neg_sampling)
28 | else:
29 | model = Item2Vec(
30 | task=task,
31 | data_info=data_info,
32 | embed_size=16,
33 | n_epochs=2,
34 | norm_embed=norm_embed,
35 | window_size=window_size,
36 | )
37 | model.fit(
38 | train_data,
39 | neg_sampling,
40 | verbose=2,
41 | shuffle=True,
42 | eval_data=eval_data,
43 | metrics=get_metrics(task),
44 | )
45 | ptest_preds(model, task, pd_data, with_feats=False)
46 | ptest_recommends(model, data_info, pd_data, with_feats=False)
47 |
48 | # test save and load model
49 | loaded_model, loaded_data_info = save_load_model(Item2Vec, model, data_info)
50 | with pytest.raises(RuntimeError):
51 | loaded_model.fit(train_data, neg_sampling)
52 | ptest_preds(loaded_model, task, pd_data, with_feats=False)
53 | ptest_recommends(loaded_model, loaded_data_info, pd_data, with_feats=False)
54 | model.save("not_existed_path", "item2vec2")
55 | remove_path("not_existed_path")
56 |
--------------------------------------------------------------------------------
/tests/models/utils_tf.py:
--------------------------------------------------------------------------------
1 | from libreco.tfops import tf
2 |
3 |
4 | def ptest_tf_variables(model):
5 | var_names = [v.name for v in tf.trainable_variables()]
6 | if hasattr(model, "user_variables"):
7 | for v in model.user_variables:
8 | assert f"{v}:0" in var_names
9 | if hasattr(model, "item_variables"):
10 | for v in model.item_variables:
11 | assert f"{v}:0" in var_names
12 | if hasattr(model, "sparse_variables"):
13 | for v in model.sparse_variables:
14 | assert f"{v}:0" in var_names
15 | if hasattr(model, "dense_variables"):
16 | for v in model.dense_variables:
17 | assert f"{v}:0" in var_names
18 |
--------------------------------------------------------------------------------
/tests/retrain/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/massquantity/LibRecommender/9eadf8a3d6901a898630da113b92de48bd2897fb/tests/retrain/__init__.py
--------------------------------------------------------------------------------
/tests/serving/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/massquantity/LibRecommender/9eadf8a3d6901a898630da113b92de48bd2897fb/tests/serving/__init__.py
--------------------------------------------------------------------------------
/tests/serving/mock_tf_server.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sanic import Sanic
3 | from sanic.log import logger
4 | from sanic.request import Request
5 | from sanic.response import HTTPResponse, json
6 |
7 | app = Sanic("mock-tf-server")
8 |
9 |
10 | @app.post("/v1/models/")
11 | async def tf_serving(request: Request, model_name: str) -> HTTPResponse:
12 | logger.info(f"Mock predictions for {model_name.replace(':predict', '')}")
13 | rng = np.random.default_rng(42)
14 | payload = request.json["inputs"]
15 | if "k" in payload:
16 | return json({"outputs": rng.integers(0, 20, size=payload["k"]).tolist()})
17 | n_items = len(request.json["inputs"]["item_indices"])
18 | return json({"outputs": rng.normal(size=n_items).tolist()})
19 |
20 |
21 | if __name__ == "__main__":
22 | app.run(
23 | host="0.0.0.0", port=8501, debug=False, access_log=False, single_process=True
24 | )
25 |
--------------------------------------------------------------------------------
/tests/serving/setup_coverage.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | PYTHON_SITE_PATH=$(python -c 'import sysconfig; print(sysconfig.get_paths()["purelib"])')
4 | FILE_NAME="sitecustomize.py"
5 | FULL_PATH="${PYTHON_SITE_PATH}/${FILE_NAME}"
6 |
7 | echo "coverage path: ${FULL_PATH}"
8 | cp tests/serving/subprocess_coverage_setup.py "$FULL_PATH"
9 |
--------------------------------------------------------------------------------
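The script above drops a `sitecustomize.py` into the active site-packages so that the server subprocesses spawned by the serving tests also report coverage. A sketch of how it might be used before running those tests (the pytest target and the `coverage combine` step are assumptions based on the parallel-coverage setup referenced in the script it copies):

```shell
# assumption: run from the repository root with the test virtualenv active
sh tests/serving/setup_coverage.sh
coverage run -m pytest tests/serving
coverage combine
coverage report
```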
/tests/serving/subprocess_coverage_setup.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import coverage
4 |
5 | # essential for subprocess parallel coverage like sanic serving
6 | # https://coverage.readthedocs.io/en/latest/subprocess.html
7 | os.environ["COVERAGE_PROCESS_START"] = ".coveragerc"
8 | coverage.process_startup()
9 |
--------------------------------------------------------------------------------
/tests/serving/test_embed_serving.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | from pathlib import Path
3 |
4 | from libserving.serialization import embed2redis, save_embed, save_faiss_index
5 | from tests.utils_data import SAVE_PATH, remove_path
6 |
7 |
8 | def test_embed_serving(embed_model, session, close_server):
9 | save_embed(SAVE_PATH, embed_model)
10 | embed2redis(SAVE_PATH)
11 | faiss_path = str(Path(__file__).parents[2] / "libserving" / "embed_model")
12 | save_faiss_index(faiss_path, embed_model, 40, 10)
13 |
14 | subprocess.run(
15 | "sanic libserving.sanic_serving.embed_deploy:app --no-access-logs --single-process &",
16 | shell=True,
17 | check=True,
18 | )
19 | # time.sleep(2) # wait for the server to start
20 |
21 | response = session.post(
22 | "http://localhost:8000/embed/recommend",
23 | json={"user": 1, "n_rec": 1},
24 | timeout=0.5,
25 | )
26 | assert len(next(iter(response.json().values()))) == 1
27 | response = session.post(
28 | "http://localhost:8000/embed/recommend",
29 | json={"user": 33, "n_rec": 3},
30 | timeout=0.5,
31 | )
32 | assert len(next(iter(response.json().values()))) == 3
33 | remove_path(faiss_path)
34 |
--------------------------------------------------------------------------------
/tests/serving/test_faiss_index.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import pytest
4 |
5 | from libreco.algorithms import BPR
6 | from libserving.serialization import save_faiss_index
7 | from tests.utils_data import SAVE_PATH
8 |
9 |
10 | def test_faiss_index(embed_model):
11 | import faiss
12 |
13 | save_faiss_index(SAVE_PATH, embed_model, 80, 10)
14 | index = faiss.read_index(os.path.join(SAVE_PATH, "faiss_index.bin"))
15 | _, ids = index.search(embed_model.user_embeds_np[0].reshape(1, -1), 10)
16 | assert ids.shape == (1, 10)
17 | assert index.ntotal == embed_model.n_items
18 | assert index.d == embed_model.embed_size + 1 # embed + bias
19 |
20 |
21 | @pytest.fixture
22 | def embed_model(prepare_pure_data):
23 | _, train_data, _, data_info = prepare_pure_data
24 | model = BPR(
25 | data_info=data_info,
26 | n_epochs=2,
27 | lr=1e-4,
28 | batch_size=2048,
29 | use_tf=False,
30 | optimizer="adam",
31 | )
32 | model.fit(train_data, neg_sampling=True, verbose=2)
33 | return model
34 |
--------------------------------------------------------------------------------
/tests/serving/test_knn_serving.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 |
3 | import pytest
4 |
5 | from libreco.bases import CfBase
6 | from libserving.serialization import knn2redis, save_knn
7 | from tests.utils_data import SAVE_PATH
8 |
9 |
10 | @pytest.mark.parametrize("knn_model", ["UserCF", "ItemCF"], indirect=True)
11 | def test_knn_serving(knn_model, session, close_server):
12 | assert isinstance(knn_model, CfBase)
13 | save_knn(SAVE_PATH, knn_model, k=10)
14 | knn2redis(SAVE_PATH)
15 |
16 | subprocess.run(
17 | "sanic libserving.sanic_serving.knn_deploy:app --no-access-logs --single-process &",
18 | shell=True,
19 | check=True,
20 | )
21 | # time.sleep(2) # wait for the server to start
22 |
23 | response = session.post(
24 | "http://localhost:8000/knn/recommend", json={"user": 1, "n_rec": 1}, timeout=0.5
25 | )
26 | assert len(next(iter(response.json().values()))) == 1
27 | response = session.post(
28 | "http://localhost:8000/knn/recommend",
29 | json={"user": 33, "n_rec": 3},
30 | timeout=0.5,
31 | )
32 | assert len(next(iter(response.json().values()))) == 3
33 |
--------------------------------------------------------------------------------
/tests/serving/test_online_serving.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 |
3 | import pytest
4 |
5 | from libserving.serialization import online2redis, save_online
6 | from tests.utils_data import SAVE_PATH
7 |
8 |
9 | @pytest.mark.parametrize(
10 | "online_model",
11 | ["pure", "user_feat", "separate", "multi_sparse", "item_feat", "all"],
12 | indirect=True,
13 | )
14 | def test_online_serving(online_model, session, close_server):
15 | save_online(SAVE_PATH, online_model, version=1)
16 | online2redis(SAVE_PATH)
17 |
18 | subprocess.run(
19 | "sanic libserving.sanic_serving.online_deploy:app --no-access-logs --single-process &",
20 | shell=True,
21 | check=True,
22 | )
23 | subprocess.run("python tests/serving/mock_tf_server.py &", shell=True, check=True)
24 | # time.sleep(2) # wait for the server to start
25 |
26 | response = session.post(
27 | "http://localhost:8000/online/recommend",
28 | json={"user": 1, "n_rec": 1},
29 | timeout=0.5,
30 | )
31 | assert len(next(iter(response.json().values()))) == 1
32 |
33 | response = session.post(
34 | "http://localhost:8000/online/recommend",
35 | json={"user": "uuu", "n_rec": 3},
36 | timeout=0.5,
37 | )
38 | assert len(next(iter(response.json().values()))) == 3
39 |
40 | response = session.post(
41 | "http://localhost:8000/online/recommend",
42 | json={"user": 2, "n_rec": 3, "user_feats": {"sex": "male"}},
43 | timeout=0.5,
44 | )
45 | assert len(next(iter(response.json().values()))) == 3
46 |
47 | response = session.post(
48 | "http://localhost:8000/online/recommend",
49 | json={"user": 2, "n_rec": 3, "seq": [1, 2, 3, 10, 11, 11, 22, 1, 0, -1, 12, 1]},
50 | timeout=0.5,
51 | )
52 | assert len(next(iter(response.json().values()))) == 3
53 |
54 | response = session.post(
55 | "http://localhost:8000/online/recommend",
56 | json={
57 | "user": "uuu",
58 | "n_rec": 30000,
59 | "user_feats": {"sex": "bb", "age": 1000, "occupation": "ooo", "ggg": "eee"},
60 | "seq": [1, 2, 3, "??"],
61 | },
62 | timeout=0.5,
63 | )
64 | # noinspection PyUnresolvedReferences
65 | assert len(next(iter(response.json().values()))) == online_model.n_items
66 |
67 | response = session.post(
68 | "http://localhost:8000/online/recommend",
69 | json={
70 | "user": "uuu",
71 | "n_rec": 300,
72 | "item_feats": {"sex": "bb", "age": 1000, "occupation": "ooo", "ggg": "eee"},
73 | "item_seq": [1, 2, 3, "??"],
74 | },
75 | timeout=0.5,
76 | )
77 | assert "Invalid payload" in response.text
78 |
--------------------------------------------------------------------------------
/tests/serving/test_tf_serving.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 |
3 | import pytest
4 |
5 | from libreco.bases import TfBase
6 | from libserving.serialization import save_tf, tf2redis
7 | from tests.utils_data import SAVE_PATH
8 |
9 |
10 | @pytest.mark.parametrize(
11 | "tf_model", ["pure", "feat-all", "feat-user", "feat-item"], indirect=True
12 | )
13 | def test_tf_serving(tf_model, session, close_server):
14 | assert isinstance(tf_model, TfBase)
15 | save_tf(SAVE_PATH, tf_model, version=1)
16 | tf2redis(SAVE_PATH)
17 |
18 | subprocess.run(
19 | "sanic libserving.sanic_serving.tf_deploy:app --no-access-logs --single-process &",
20 | shell=True,
21 | check=True,
22 | )
23 | subprocess.run("python tests/serving/mock_tf_server.py &", shell=True, check=True)
24 | # time.sleep(2) # wait for the server to start
25 |
26 | response = session.post(
27 | "http://localhost:8000/tf/recommend", json={"user": 1, "n_rec": 1}, timeout=0.5
28 | )
29 | assert len(next(iter(response.json().values()))) == 1
30 | response = session.post(
31 | "http://localhost:8000/tf/recommend", json={"user": 33, "n_rec": 3}, timeout=0.5
32 | )
33 | assert len(next(iter(response.json().values()))) == 3
34 |
--------------------------------------------------------------------------------
/tests/test_consumed.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | import pytest
4 |
5 | from libreco.data.consumed import _fill_empty, _merge_dedup, interaction_consumed
6 |
7 |
8 | @pytest.mark.skipif(
9 | sys.version_info[:2] < (3, 7),
10 | reason="Rust implementation only supports Python >= 3.7.",
11 | )
12 | def test_remove_consecutive_duplicates():
13 | user_indices = [1, 1, 1, 2, 2, 1, 2, 3, 2, 3]
14 | item_indices = [11, 11, 999, 0, 11, 11, 999, 11, 999, 0]
15 | user_consumed, item_consumed = interaction_consumed(user_indices, item_indices)
16 | assert isinstance(user_consumed, dict)
17 | assert isinstance(item_consumed, dict)
18 | assert isinstance(user_consumed[1], list)
19 | assert isinstance(item_consumed[11], list)
20 | assert user_consumed[1] == [11, 999, 11]
21 | assert user_consumed[2] == [0, 11, 999]
22 | assert user_consumed[3] == [11, 0]
23 | assert item_consumed[11] == [1, 2, 1, 3]
24 | assert item_consumed[999] == [1, 2]
25 | assert item_consumed[0] == [2, 3]
26 |
27 |
28 | @pytest.mark.skipif(
29 | sys.version_info[:2] >= (3, 7),
30 | reason="Specific python 3.6 implementation",
31 | )
32 | def test_remove_duplicates():
33 | user_indices = [1, 1, 1, 2, 2, 1, 2, 3, 2, 3]
34 | item_indices = [11, 11, 999, 0, 11, 11, 999, 11, 999, 0]
35 | user_consumed, item_consumed = interaction_consumed(user_indices, item_indices)
36 | assert isinstance(user_consumed, dict)
37 | assert isinstance(item_consumed, dict)
38 | assert isinstance(user_consumed[1], list)
39 | assert isinstance(item_consumed[11], list)
40 | assert user_consumed[1] == [11, 999]
41 | assert user_consumed[2] == [0, 11, 999]
42 | assert user_consumed[3] == [11, 0]
43 | assert item_consumed[11] == [1, 2, 3]
44 | assert item_consumed[999] == [1, 2]
45 | assert item_consumed[0] == [2, 3]
46 |
47 |
48 | def test_merge_remove_duplicates():
49 | num = 3
50 | old_consumed = {0: [1, 2, 3], 1: [4, 5]}
51 | new_consumed = {0: [2, 1], 2: [7, 8]}
52 | consumed = _merge_dedup(new_consumed, num, old_consumed)
53 | assert consumed[0] == [1, 2, 3, 2, 1]
54 | assert consumed[1] == [4, 5]
55 | assert consumed[2] == [7, 8]
56 |
57 |
58 | def test_no_merge():
59 | num = 4
60 | old_consumed = {0: [1, 2, 3], 1: [4, 5], 2: [0], 3: [99]}
61 | new_consumed = {0: [2, 1], 2: [7, 8]}
62 | consumed = _fill_empty(new_consumed, num, old_consumed)
63 | assert consumed[0] == [2, 1]
64 | assert consumed[1] == [4, 5]
65 | assert consumed[2] == [7, 8]
66 | assert consumed[3] == [99]
67 |
--------------------------------------------------------------------------------
/tests/test_dgl.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import sys
3 |
4 | import pytest
5 |
6 | from libreco.graph import check_dgl
7 |
8 |
9 | def test_dgl(prepare_feat_data, monkeypatch):
10 | *_, data_info = prepare_feat_data
11 |
12 | with monkeypatch.context() as m:
13 | m.setitem(sys.modules, "dgl", None)
14 | with pytest.raises(ModuleNotFoundError):
15 | from libreco.algorithms import PinSageDGL
16 |
17 | _ = PinSageDGL("ranking", data_info)
18 |
19 | @check_dgl
20 | class ClsWithDGL:
21 | def __new__(cls, *args, **kwargs):
22 | if cls.dgl_error is not None:
23 | raise cls.dgl_error
24 | cls._dgl = importlib.import_module("dgl")
25 | return super().__new__(cls)
26 |
27 | model = ClsWithDGL()
28 | assert model.dgl_error is None
29 | assert model._dgl.__name__ == "dgl"
30 |
--------------------------------------------------------------------------------
/tests/test_initializers.py:
--------------------------------------------------------------------------------
1 | import itertools
2 |
3 | import numpy as np
4 | import pytest
5 |
6 | from libreco.utils.initializers import (
7 | he_init,
8 | truncated_normal,
9 | variance_scaling,
10 | xavier_init,
11 | )
12 |
13 |
14 | def test_initializers():
15 | np_rng = np.random.default_rng(42)
16 | mean, std, fan_in, fan_out, scale = 0.1, 0.01, 4, 2, 2.5
17 | variables = truncated_normal(np_rng, [3, 2], mean=0.1, scale=0.01)
18 | assert variables.shape == (3, 2)
19 | variables_in_range(variables, mean, std)
20 |
21 | variables = xavier_init(np_rng, fan_in, fan_out)
22 | std = np.sqrt(2.0 / (fan_in + fan_out))
23 | variables_in_range(variables, mean, std)
24 |
25 | variables = he_init(np_rng, fan_in, fan_out)
26 | std = 2.0 / np.sqrt(fan_in + fan_out)
27 | variables_in_range(variables, mean, std)
28 |
29 | variables = variance_scaling(np_rng, scale, fan_in, fan_out, mode="fan_in")
30 | std = np.sqrt(scale / fan_in)
31 | variables_in_range(variables, mean, std)
32 |
33 | variables = variance_scaling(np_rng, scale, fan_in, fan_out, mode="fan_out")
34 | std = np.sqrt(scale / fan_out)
35 | variables_in_range(variables, mean, std)
36 |
37 | variables = variance_scaling(np_rng, scale, fan_in, fan_out, mode="fan_average")
38 | std = np.sqrt(2.0 * scale / (fan_in + fan_out))
39 | variables_in_range(variables, mean, std)
40 |
41 | with pytest.raises(ValueError):
42 | _ = variance_scaling(np_rng, scale, fan_in, fan_out, mode="unknown")
43 |
44 |
45 | def variables_in_range(variables, mean, std):
46 | for v in itertools.chain.from_iterable(variables):
47 | assert (mean - 3 * std) < v < (mean + 3 * std)
48 |
--------------------------------------------------------------------------------
/tests/test_misc.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import pytest
4 |
5 | from libreco.utils.misc import colorize, time_block, time_func
6 |
7 |
8 | @time_func
9 | def long_work():
10 | time.sleep(0.1)
11 | print(colorize("done!", color="red", bold=True, highlight=True))
12 |
13 |
14 | def test_misc():
15 | long_work()
16 | with time_block("long work2", verbose=0):
17 | time.sleep(0.1)
18 | with pytest.raises(RuntimeError):
19 | with time_block("long work2", verbose=0):
20 | raise RuntimeError
21 |
--------------------------------------------------------------------------------
/tests/test_multi_sparse_processing.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import pandas as pd
4 | import pytest
5 |
6 | from libreco.data import split_multi_value
7 |
8 |
9 | def test_multi_sparse_processing():
10 | data_path = os.path.join(
11 | os.path.dirname(os.path.realpath(__file__)),
12 | "sample_data",
13 | "sample_movielens_genre.csv",
14 | )
15 | data = pd.read_csv(data_path, sep=",", header=0)
16 |
17 | with pytest.raises(AssertionError):
18 | # max_len must be list or tuple
19 | split_multi_value(data, multi_value_col=["genre"], sep="|", max_len=3)
20 |
21 | sep = "," # wrong separator
22 | data, *_ = split_multi_value(
23 | data,
24 | multi_value_col=["genre"],
25 | sep=sep,
26 | max_len=[3],
27 | pad_val="missing",
28 | user_col=["sex", "age", "occupation"],
29 | item_col=["genre"],
30 | )
31 | assert all(data["genre_2"].str.contains("missing"))
32 | assert all(data["genre_3"].str.contains("missing"))
33 |
34 | sep = "|"
35 | data = pd.read_csv(data_path, sep=",", header=0)
36 | data, multi_sparse_col, multi_user_col, multi_item_col = split_multi_value(
37 | data,
38 | multi_value_col=["genre"],
39 | sep=sep,
40 | max_len=[3],
41 | pad_val="missing",
42 | user_col=["sex", "age", "occupation"],
43 | item_col=["genre"],
44 | )
45 | assert multi_sparse_col == [["genre_1", "genre_2", "genre_3"]]
46 | assert multi_user_col == []
47 | assert multi_item_col == ["genre_1", "genre_2", "genre_3"]
48 | all_columns = data.columns.tolist()
49 | assert "genre" not in all_columns
50 | assert all_columns == [
51 | "user",
52 | "item",
53 | "label",
54 | "time",
55 | "sex",
56 | "age",
57 | "occupation",
58 | "genre_1",
59 | "genre_2",
60 | "genre_3",
61 | ]
62 |
--------------------------------------------------------------------------------
/tests/utils_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | from pathlib import Path
4 |
5 | import numpy as np
6 |
7 | from libreco.data import TransformedEvalSet, TransformedSet
8 |
9 | SAVE_PATH = os.path.join(str(Path(os.path.realpath(__file__)).parent), "save_path")
10 |
11 |
12 | def remove_path(path):
13 | if os.path.exists(path) and os.path.isdir(path):
14 | shutil.rmtree(path)
15 |
16 |
17 | def set_ranking_labels(data):
18 | if isinstance(data, TransformedSet):
19 | original_labels = data._labels.copy()
20 | data._labels[original_labels >= 4] = 1
21 | data._labels[original_labels < 4] = 0
22 | elif isinstance(data, TransformedEvalSet):
23 | original_labels = np.copy(data.labels)
24 | data.labels[original_labels >= 4] = 1
25 | data.labels[original_labels < 4] = 0
26 |
--------------------------------------------------------------------------------
/tests/utils_metrics.py:
--------------------------------------------------------------------------------
1 | def get_metrics(task):
2 | if task == "rating":
3 | return ["rmse", "mae", "r2"]
4 | else:
5 | return [
6 | "loss",
7 | "balanced_accuracy",
8 | "roc_auc",
9 | "roc_gauc",
10 | "pr_auc",
11 | "precision",
12 | "recall",
13 | "map",
14 | "ndcg",
15 | "coverage",
16 | ]
17 |
--------------------------------------------------------------------------------
/tests/utils_multi_sparse_models.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | from libreco.bases import TfBase
4 |
5 |
6 | def fit_multi_sparse(cls, train_data, eval_data, data_info, lr=None):
7 | if issubclass(cls, TfBase):
8 | tf.compat.v1.reset_default_graph()
9 |
10 | model = cls(
11 | task="ranking",
12 | data_info=data_info,
13 | loss_type="cross_entropy",
14 | embed_size=4,
15 | n_epochs=1,
16 | lr=1e-4 if not lr else lr,
17 | batch_size=100,
18 | )
19 | model.fit(
20 | train_data,
21 | neg_sampling=True,
22 | verbose=2,
23 | shuffle=True,
24 | eval_data=eval_data,
25 | metrics=["roc_auc", "precision", "map", "ndcg"],
26 | eval_user_num=40,
27 | )
28 | return model
29 |
--------------------------------------------------------------------------------
/tests/utils_pred.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from libreco.prediction import predict_data_with_feats
4 |
5 |
6 | def ptest_preds(model, task, pd_data, with_feats):
7 | user = pd_data.user.iloc[0]
8 | item = pd_data.item.iloc[0]
9 | pred = model.predict(user=user, item=item)
10 | # prediction in range
11 | if task == "rating":
12 | assert 1 <= pred <= 5
13 | else:
14 | assert 0 <= pred <= 1
15 |
16 | popular_pred = model.predict(
17 | user="cold user2", item="cold item2", cold_start="popular"
18 | )
19 | assert np.allclose(popular_pred, model.default_pred)
20 |
21 | cold_pred1 = model.predict(user="cold user1", item="cold item2")
22 | cold_pred2 = model.predict(user="cold user2", item="cold item2")
23 | assert cold_pred1 == cold_pred2
24 |
25 | if with_feats:
26 | assert len(predict_data_with_feats(model, pd_data[:5])) == 5
27 | model.predict(user=user, item=item, feats={"sex": "male", "genre_1": "crime"})
28 |
--------------------------------------------------------------------------------
/tests/utils_save_load.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | from libreco.bases import TfBase
4 | from libreco.data import DataInfo
5 | from tests.utils_data import SAVE_PATH
6 |
7 |
8 | def save_load_model(cls, model, data_info):
9 | model_name = cls.__name__.lower() + "_model"
10 | data_info.save(path=SAVE_PATH, model_name=model_name)
11 | model.save(SAVE_PATH, model_name, manual=True, inference_only=True)
12 |
13 | if issubclass(cls, TfBase) or hasattr(model, "sess"):
14 | tf.compat.v1.reset_default_graph()
15 | loaded_data_info = DataInfo.load(path=SAVE_PATH, model_name=model_name)
16 | loaded_model = cls.load(SAVE_PATH, model_name, loaded_data_info, manual=True)
17 | return loaded_model, loaded_data_info
18 |
--------------------------------------------------------------------------------