├── .coveragerc ├── .github └── workflows │ ├── ci.yml │ ├── ci_serving.yml │ ├── ci_win_mac.yml │ ├── compatibility-tensorflow1.yml │ ├── compatibility-tensorflow2.yml │ ├── rust.yml │ ├── test-install.yml │ └── wheels.yml ├── .gitignore ├── .readthedocs.yaml ├── LICENSE ├── MANIFEST.in ├── README.md ├── _config.yml ├── codecov.yml ├── distributed ├── spark │ ├── pom.xml │ ├── spark-recommender.iml │ └── src │ │ └── main │ │ └── scala │ │ ├── TestScala.scala │ │ └── com │ │ └── libreco │ │ ├── data │ │ └── DataSplitter.scala │ │ ├── evaluate │ │ ├── EvalClassifier.scala │ │ ├── EvalRecommender.scala │ │ └── EvalRegressor.scala │ │ ├── example │ │ ├── AlsExample.scala │ │ ├── ClassifierExample.scala │ │ └── RegressorExample.scala │ │ ├── feature │ │ ├── FeatureEngineering.scala │ │ └── MultiHotEncoder.scala │ │ ├── model │ │ ├── Classifier.scala │ │ ├── Recommender.scala │ │ └── Regressor.scala │ │ └── utils │ │ ├── Context.scala │ │ ├── FilterNAs.scala │ │ └── ItemNameConverter.scala └── youtube_distributed.py ├── docker ├── Dockerfile └── README.md ├── docs ├── Makefile ├── make.bat ├── md_doc │ ├── autoint_feature.jpg │ ├── implementation_details.md │ ├── python_serving_guide.md │ ├── rust_serving_guide.md │ └── user_guide.md ├── requirements.txt └── source │ ├── _static │ ├── autoint_feature.jpg │ └── css │ │ └── custom.css │ ├── api │ ├── algorithms │ │ ├── als.rst │ │ ├── autoint.rst │ │ ├── bases.rst │ │ ├── bpr.rst │ │ ├── caser.rst │ │ ├── deepfm.rst │ │ ├── deepwalk.rst │ │ ├── din.rst │ │ ├── fm.rst │ │ ├── graphsage.rst │ │ ├── graphsage_dgl.rst │ │ ├── index.rst │ │ ├── item2vec.rst │ │ ├── item_cf.rst │ │ ├── item_cf_rs.rst │ │ ├── lightgcn.rst │ │ ├── ncf.rst │ │ ├── ngcf.rst │ │ ├── pinsage.rst │ │ ├── pinsage_dgl.rst │ │ ├── rnn4rec.rst │ │ ├── sim.rst │ │ ├── svd.rst │ │ ├── svdpp.rst │ │ ├── swing.rst │ │ ├── transformer.rst │ │ ├── two_tower.rst │ │ ├── user_cf.rst │ │ ├── user_cf_rs.rst │ │ ├── wavenet.rst │ │ ├── wide_deep.rst │ │ ├── youtube_ranking.rst │ │ └── youtube_retrieval.rst │ ├── data │ │ ├── data_info.rst │ │ ├── dataset.rst │ │ ├── index.rst │ │ ├── split.rst │ │ └── transformed.rst │ ├── evaluation.rst │ └── serialization.rst │ ├── conf.py │ ├── index.rst │ ├── installation.rst │ ├── internal │ ├── data_info.rst │ ├── implementation_details.rst │ └── index.rst │ ├── serving_guide │ ├── online.rst │ ├── python.rst │ └── rust.rst │ ├── tutorial.rst │ └── user_guide │ ├── data_processing.rst │ ├── embedding.rst │ ├── evaluation_save_load.rst │ ├── feature_engineering.rst │ ├── index.rst │ ├── model_retrain.rst │ ├── model_train.rst │ └── recommendation.rst ├── examples ├── changing_feature_example.py ├── feat_example.py ├── feat_ranking_example.py ├── feat_rating_example.py ├── knn_embedding_example.py ├── model_retrain_example.py ├── multi_sparse_example.py ├── multi_sparse_processing_example.py ├── pure_example.py ├── pure_ranking_example.py ├── pure_rating_example.py ├── sample_data │ ├── sample_movielens_genre.csv │ ├── sample_movielens_merged.csv │ └── sample_movielens_rating.dat ├── save_load_example.py ├── seq_example.py ├── split_data_example.py └── tutorial.ipynb ├── libreco ├── __init__.py ├── algorithms │ ├── __init__.py │ ├── _als.pyx │ ├── _bpr.pyx │ ├── als.py │ ├── autoint.py │ ├── bpr.py │ ├── caser.py │ ├── deepfm.py │ ├── deepwalk.py │ ├── din.py │ ├── fm.py │ ├── graphsage.py │ ├── graphsage_dgl.py │ ├── item2vec.py │ ├── item_cf.py │ ├── item_cf_rs.py │ ├── lightgcn.py │ ├── ncf.py │ ├── ngcf.py │ ├── pinsage.py │ ├── pinsage_dgl.py │ ├── 
rnn4rec.py │ ├── sim.py │ ├── svd.py │ ├── svdpp.py │ ├── swing.py │ ├── torch_modules │ │ ├── __init__.py │ │ ├── graphsage_module.py │ │ ├── lightgcn_module.py │ │ ├── ngcf_module.py │ │ └── pinsage_module.py │ ├── transformer.py │ ├── two_tower.py │ ├── user_cf.py │ ├── user_cf_rs.py │ ├── wave_net.py │ ├── wide_deep.py │ ├── youtube_ranking.py │ └── youtube_retrieval.py ├── bases │ ├── __init__.py │ ├── base.py │ ├── cf_base.py │ ├── cf_base_rs.py │ ├── dyn_embed_base.py │ ├── embed_base.py │ ├── gensim_base.py │ ├── meta.py │ ├── sage_base.py │ └── tf_base.py ├── batch │ ├── __init__.py │ ├── batch_data.py │ ├── batch_unit.py │ ├── collators.py │ ├── enums.py │ ├── sequence.py │ └── tf_feed_dicts.py ├── data │ ├── __init__.py │ ├── consumed.py │ ├── data_info.py │ ├── dataset.py │ ├── processing.py │ ├── split.py │ └── transformed.py ├── evaluation │ ├── __init__.py │ ├── computation.py │ ├── evaluate.py │ └── metrics.py ├── feature │ ├── __init__.py │ ├── column_mapping.py │ ├── multi_sparse.py │ ├── sparse.py │ ├── ssl.py │ ├── unique.py │ └── update.py ├── graph │ ├── __init__.py │ ├── from_dgl.py │ ├── inference.py │ ├── message.py │ └── neighbor_walk.py ├── layers │ ├── __init__.py │ ├── activation.py │ ├── attention.py │ ├── convolutional.py │ ├── dense.py │ ├── embedding.py │ ├── normalization.py │ ├── recurrent.py │ └── transformer.py ├── prediction │ ├── __init__.py │ ├── predict.py │ └── preprocess.py ├── recommendation │ ├── __init__.py │ ├── cold_start.py │ ├── preprocess.py │ ├── ranking.py │ └── recommend.py ├── sampling │ ├── __init__.py │ ├── negatives.py │ └── random_walks.py ├── tfops │ ├── __init__.py │ ├── configs.py │ ├── features.py │ ├── loss.py │ ├── rebuild.py │ ├── variables.py │ └── version.py ├── torchops │ ├── __init__.py │ ├── configs.py │ ├── loss.py │ └── rebuild.py ├── training │ ├── __init__.py │ ├── dispatch.py │ ├── tf_trainer.py │ ├── torch_trainer.py │ └── trainer.py └── utils │ ├── __init__.py │ ├── _similarities.pyx │ ├── constants.py │ ├── exception.py │ ├── initializers.py │ ├── misc.py │ ├── save_load.py │ ├── similarities.py │ ├── sparse.py │ └── validate.py ├── libserving ├── .dockerignore ├── Dockerfile-py ├── Dockerfile-rs ├── __init__.py ├── actix_serving │ ├── Cargo.lock │ ├── Cargo.toml │ ├── build.rs │ ├── proto │ │ ├── recommend.proto │ │ ├── tensorflow │ │ │ └── core │ │ │ │ └── framework │ │ │ │ ├── resource_handle.proto │ │ │ │ ├── tensor.proto │ │ │ │ ├── tensor_shape.proto │ │ │ │ └── types.proto │ │ └── tensorflow_serving │ │ │ └── apis │ │ │ ├── model.proto │ │ │ ├── predict.proto │ │ │ └── prediction_service.proto │ ├── rustfmt.toml │ ├── src │ │ ├── bin │ │ │ ├── benchmark.rs │ │ │ ├── realtime.rs │ │ │ ├── realtime_grpc_client.rs │ │ │ └── realtime_grpc_server.rs │ │ ├── embed_deploy.rs │ │ ├── knn_deploy.rs │ │ ├── lib.rs │ │ ├── main.rs │ │ ├── online_deploy.rs │ │ ├── online_deploy_grpc.rs │ │ ├── tf_deploy.rs │ │ └── utils │ │ │ ├── common.rs │ │ │ ├── constants.rs │ │ │ ├── errors.rs │ │ │ ├── faiss.rs │ │ │ ├── features.rs │ │ │ ├── mod.rs │ │ │ └── redis_ops.rs │ └── tests │ │ ├── common │ │ └── mod.rs │ │ ├── embed.rs │ │ ├── knn.rs │ │ └── tf.rs ├── crate-index-config ├── docker-compose-py.yml ├── docker-compose-rs.yml ├── docker-compose-tf-serving.yml ├── request.py ├── sanic_serving │ ├── __init__.py │ ├── benchmark.py │ ├── common.py │ ├── embed_deploy.py │ ├── knn_deploy.py │ ├── online_deploy.py │ └── tf_deploy.py └── serialization │ ├── __init__.py │ ├── common.py │ ├── embed.py │ ├── knn.py │ ├── online.py │ 
├── redis.py │ └── tfmodel.py ├── pyproject.toml ├── python-package-conda.yml ├── requirements-dev.txt ├── requirements-serving.txt ├── requirements.txt ├── rust ├── .gitignore ├── Cargo.toml ├── README.md ├── pyproject.toml ├── recfarm │ └── __init__.py ├── rustfmt.toml └── src │ ├── graph.rs │ ├── incremental.rs │ ├── inference.rs │ ├── item_cf.rs │ ├── lib.rs │ ├── ordering.rs │ ├── serialization.rs │ ├── similarities.rs │ ├── sparse.rs │ ├── swing.rs │ ├── user_cf.rs │ └── utils.rs ├── setup.cfg ├── setup.py └── tests ├── __init__.py ├── compatibility_test.py ├── conftest.py ├── models ├── __init__.py ├── test_als.py ├── test_autoint.py ├── test_base.py ├── test_bpr.py ├── test_caser.py ├── test_deepfm.py ├── test_deepwalk.py ├── test_din.py ├── test_fm.py ├── test_graphsage.py ├── test_graphsage_dgl.py ├── test_item2vec.py ├── test_item_cf.py ├── test_item_cf_rs.py ├── test_lightgcn.py ├── test_ncf.py ├── test_ngcf.py ├── test_pinsage.py ├── test_pinsage_dgl.py ├── test_rnn4rec.py ├── test_sim.py ├── test_svd.py ├── test_svdpp.py ├── test_swing.py ├── test_transformer.py ├── test_two_tower.py ├── test_user_cf.py ├── test_user_cf_rs.py ├── test_wave_net.py ├── test_wide_deep.py ├── test_youtube_ranking.py ├── test_youtube_retrieval.py └── utils_tf.py ├── retrain ├── __init__.py ├── test_als_retrain.py ├── test_gensim_model_retrain.py ├── test_rs_cf_retrain.py ├── test_rs_swing_retrain.py ├── test_tfmodel_retrain_feat.py ├── test_tfmodel_retrain_pure.py ├── test_thmodel_retrain_feat.py ├── test_thmodel_retrain_feat_dgl.py ├── test_thmodel_retrain_pure.py └── test_two_tower_retrain.py ├── sample_data ├── sample_movielens_genre.csv ├── sample_movielens_merged.csv └── sample_movielens_rating.dat ├── serving ├── __init__.py ├── conftest.py ├── mock_tf_server.py ├── setup_coverage.sh ├── subprocess_coverage_setup.py ├── test_embed_serving.py ├── test_faiss_index.py ├── test_knn_serving.py ├── test_online_serving.py ├── test_serialization.py └── test_tf_serving.py ├── test_collators.py ├── test_consumed.py ├── test_data.py ├── test_dgl.py ├── test_feature.py ├── test_initializers.py ├── test_knn_embed.py ├── test_misc.py ├── test_multi_sparse_processing.py ├── test_multiprocessing_seeds.py ├── test_rank_reco.py ├── test_similarities.py ├── test_split_data.py ├── test_tf_layers.py ├── utils_data.py ├── utils_metrics.py ├── utils_multi_sparse_models.py ├── utils_pred.py ├── utils_reco.py └── utils_save_load.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | parallel = True 3 | 4 | concurrency = multiprocessing 5 | 6 | source = 7 | libreco/ 8 | libserving/serialization/ 9 | libserving/sanic_serving/ 10 | 11 | omit = 12 | libreco/utils/exception.py 13 | libreco/utils/sampling.py 14 | libserving/sanic_serving/benchmark.py 15 | 16 | [report] 17 | exclude_lines = 18 | pragma: no cover 19 | raise AssertionError 20 | raise NameError 21 | raise NotImplementedError 22 | raise OSError.* 23 | raise ValueError 24 | raise SanicException.* 25 | except .*redis.* 26 | except \(ImportError, ModuleNotFoundError\): 27 | except ValidationError. 
28 | if __name__ == .__main__.: 29 | @(abc\.)?abstractmethod 30 | 31 | precision = 2 32 | 33 | show_missing = True 34 | 35 | skip_empty = True 36 | 37 | [html] 38 | directory = html-coverage-report 39 | 40 | title = LibRecommender Coverage Report 41 | 42 | [xml] 43 | output = coverage.xml 44 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | pull_request: 8 | branches: 9 | - master 10 | # Manual run 11 | workflow_dispatch: 12 | 13 | jobs: 14 | build: 15 | name: testing 16 | runs-on: ${{ matrix.os }} 17 | strategy: 18 | fail-fast: false 19 | matrix: 20 | os: [ubuntu-22.04] 21 | python-version: [3.7, 3.8, 3.9, '3.10', '3.11'] 22 | 23 | steps: 24 | - uses: actions/checkout@v4 25 | - name: Set up Python ${{ matrix.python-version }} 26 | uses: actions/setup-python@v5 27 | with: 28 | python-version: ${{ matrix.python-version }} 29 | cache: 'pip' 30 | 31 | - name: Display Python version 32 | run: python -c "import sys; print(sys.version)" 33 | 34 | - name: Install dependencies 35 | run: | 36 | python -m pip install -U pip wheel setuptools 37 | python -m pip install -r requirements-dev.txt 38 | python -m pip install -e . 39 | 40 | - name: Lint with flake8 41 | run: | 42 | flake8 libreco/ libserving/ tests/ examples/ 43 | 44 | - name: Lint with ruff 45 | run: | 46 | ruff check libreco/ libserving/ tests/ examples/ 47 | if: matrix.python-version != '3.6' 48 | 49 | - name: Test 50 | run: | 51 | python -m pip install pytest 52 | python -m pytest tests/ --ignore="tests/serving" 53 | if: matrix.python-version != '3.10' 54 | 55 | - name: Test with coverage 56 | run: | 57 | bash tests/serving/setup_coverage.sh 58 | coverage --version && coverage erase 59 | coverage run -m pytest tests/ --ignore="tests/serving" 60 | coverage combine && coverage report 61 | coverage xml 62 | if: matrix.python-version == '3.10' 63 | 64 | - name: Upload coverage to Codecov 65 | uses: codecov/codecov-action@v4 66 | with: 67 | file: ./coverage.xml 68 | flags: CI 69 | name: python${{ matrix.python-version }}-test 70 | token: ${{ secrets.CODECOV_TOKEN }} 71 | fail_ci_if_error: false 72 | verbose: true 73 | if: matrix.python-version == '3.10' 74 | 75 | - name: Upload coverage to Codacy 76 | uses: codacy/codacy-coverage-reporter-action@v1 77 | with: 78 | project-token: ${{ secrets.CODACY_PROJECT_TOKEN }} 79 | coverage-reports: ./coverage.xml 80 | if: matrix.python-version == '3.10' 81 | -------------------------------------------------------------------------------- /.github/workflows/ci_serving.yml: -------------------------------------------------------------------------------- 1 | name: CI-Serving 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | pull_request: 8 | branches: 9 | - master 10 | # Manual run 11 | workflow_dispatch: 12 | 13 | jobs: 14 | testing: 15 | runs-on: ${{ matrix.os }} 16 | strategy: 17 | fail-fast: false 18 | matrix: 19 | os: [ubuntu-22.04] 20 | python-version: [3.9, '3.10', '3.11'] 21 | 22 | steps: 23 | - uses: actions/checkout@v4 24 | - name: Set up Python ${{ matrix.python-version }} 25 | uses: actions/setup-python@v5 26 | with: 27 | python-version: ${{ matrix.python-version }} 28 | cache: 'pip' 29 | 30 | - name: Display Python version 31 | run: python -c "import sys; print(sys.version)" 32 | 33 | - name: Install dependencies 34 | run: | 35 | python -m pip install -U pip wheel setuptools 36 | python 
-m pip install numpy>=1.19.5 37 | python -m pip install "scipy>=1.2.1,<1.13.0" 38 | python -m pip install pandas>=1.0.0 39 | python -m pip install scikit-learn>=0.20.0 40 | python -m pip install "tensorflow>=1.15.0,<2.16.0" 41 | python -m pip install torch>=1.10.0 42 | python -m pip install gensim>=4.0.0 43 | python -m pip install tqdm 44 | python -m pip install recfarm 45 | python -m pip install -r requirements-serving.txt 46 | python -m pip install -e . 47 | 48 | - name: Set up Redis 49 | uses: shogo82148/actions-setup-redis@v1 50 | with: 51 | redis-version: '7.x' 52 | 53 | - name: Test Redis 54 | run: redis-cli ping 55 | 56 | - name: Test 57 | run: | 58 | python -m pip install pytest 59 | python -m pytest tests/serving 60 | if: matrix.python-version != '3.10' 61 | 62 | - name: Test with coverage 63 | run: | 64 | python -m pip install pytest coverage 65 | bash tests/serving/setup_coverage.sh 66 | coverage --version && coverage erase 67 | coverage run -m pytest tests/serving 68 | coverage combine && coverage report 69 | coverage xml 70 | if: matrix.python-version == '3.10' 71 | 72 | - name: Upload coverage to Codecov 73 | uses: codecov/codecov-action@v4 74 | with: 75 | file: ./coverage.xml 76 | flags: CI 77 | name: python${{ matrix.python-version }}-serving 78 | token: ${{ secrets.CODECOV_TOKEN }} 79 | fail_ci_if_error: false 80 | verbose: true 81 | if: matrix.python-version == '3.10' 82 | -------------------------------------------------------------------------------- /.github/workflows/ci_win_mac.yml: -------------------------------------------------------------------------------- 1 | name: CI-Windows-macOS 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - master 7 | # Manual run 8 | workflow_dispatch: 9 | 10 | jobs: 11 | testing: 12 | runs-on: ${{ matrix.os }} 13 | strategy: 14 | fail-fast: false 15 | matrix: 16 | os: [macos-latest, windows-latest] 17 | python-version: [3.8, '3.10'] 18 | 19 | steps: 20 | - uses: actions/checkout@v4 21 | - name: Set up Python ${{ matrix.python-version }} 22 | uses: actions/setup-python@v5 23 | with: 24 | python-version: ${{ matrix.python-version }} 25 | cache: 'pip' 26 | 27 | - name: Display Python version 28 | run: python -c "import sys; print(sys.version)" 29 | 30 | - name: Install dependencies 31 | run: | 32 | python -m pip install -U pip wheel setuptools 33 | python -m pip install numpy>=1.19.5 34 | python -m pip install "scipy>=1.2.1,<1.13.0" 35 | python -m pip install pandas>=1.0.0 36 | python -m pip install scikit-learn>=0.20.0 37 | python -m pip install "tensorflow>=1.15.0,<2.16.0" 38 | python -m pip install torch>=1.10.0 39 | python -m pip install "smart_open<7.0.0" 40 | python -m pip install gensim>=4.0.0 41 | python -m pip install tqdm 42 | python -m pip install -e . 
43 | 44 | - name: Install DGL on Windows 45 | run: python -m pip install 'dgl<=1.1.0' -f https://data.dgl.ai/wheels/repo.html 46 | if: matrix.os == 'windows-latest' 47 | 48 | - name: Install DGL on macOS 49 | run: | 50 | python -m pip install 'dgl<2.0.0' -f https://data.dgl.ai/wheels/repo.html 51 | if: matrix.os == 'macos-latest' 52 | 53 | - name: Install dataclasses 54 | run: | 55 | python -m pip install dataclasses 56 | if: matrix.python-version == '3.6' 57 | 58 | - name: Install recfarm 59 | run: | 60 | python -m pip install recfarm 61 | if: matrix.python-version != '3.6' 62 | 63 | - name: Test with pytest 64 | run: | 65 | python -m pip install pytest 66 | python -m pytest tests/ --ignore="tests/serving" 67 | -------------------------------------------------------------------------------- /.github/workflows/rust.yml: -------------------------------------------------------------------------------- 1 | name: Rust 2 | 3 | on: 4 | workflow_dispatch: 5 | 6 | permissions: 7 | contents: read 8 | 9 | jobs: 10 | linux: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v4 14 | - uses: actions/setup-python@v5 15 | with: 16 | python-version: '3.10' 17 | - name: Build wheels 18 | uses: PyO3/maturin-action@v1 19 | with: 20 | target: x86_64 21 | working-directory: ./rust 22 | args: --release --out dist 23 | sccache: 'true' 24 | manylinux: auto 25 | rust-toolchain: stable 26 | - name: Upload wheels 27 | uses: actions/upload-artifact@v4 28 | with: 29 | name: wheels-linux 30 | path: ./rust/dist 31 | 32 | windows: 33 | runs-on: windows-latest 34 | steps: 35 | - uses: actions/checkout@v4 36 | - uses: actions/setup-python@v5 37 | with: 38 | python-version: '3.10' 39 | - name: Build wheels 40 | uses: PyO3/maturin-action@v1 41 | with: 42 | target: x86_64 43 | args: --release --out dist --manifest-path rust/Cargo.toml 44 | sccache: 'true' 45 | rust-toolchain: stable 46 | - name: Upload wheels 47 | uses: actions/upload-artifact@v4 48 | with: 49 | name: wheels-windows 50 | path: dist 51 | 52 | macos: 53 | runs-on: macos-latest 54 | steps: 55 | - uses: actions/checkout@v4 56 | - uses: actions/setup-python@v5 57 | with: 58 | python-version: '3.10' 59 | - name: Build wheels 60 | uses: PyO3/maturin-action@v1 61 | with: 62 | target: x86_64 63 | args: --release --out dist --manifest-path rust/Cargo.toml 64 | sccache: 'true' 65 | rust-toolchain: stable 66 | - name: Upload wheels 67 | uses: actions/upload-artifact@v4 68 | with: 69 | name: wheels-macos 70 | path: dist 71 | 72 | sdist: 73 | runs-on: ubuntu-latest 74 | steps: 75 | - uses: actions/checkout@v4 76 | - name: Build sdist 77 | uses: PyO3/maturin-action@v1 78 | with: 79 | command: sdist 80 | args: --out dist --manifest-path rust/Cargo.toml 81 | - name: Upload sdist 82 | uses: actions/upload-artifact@v4 83 | with: 84 | name: sdist 85 | path: dist 86 | -------------------------------------------------------------------------------- /.github/workflows/test-install.yml: -------------------------------------------------------------------------------- 1 | name: test install library 2 | 3 | on: 4 | # Manual run 5 | workflow_dispatch: 6 | 7 | jobs: 8 | build: 9 | name: test install 10 | runs-on: ${{ matrix.os }} 11 | strategy: 12 | fail-fast: false 13 | matrix: 14 | os: [ubuntu-20.04, windows-latest, macos-12] 15 | python-version: [3.6, 3.7, 3.8, 3.9, '3.10', '3.11'] 16 | 17 | steps: 18 | - uses: actions/checkout@v4 19 | - name: Set up Python ${{ matrix.python-version }} 20 | uses: actions/setup-python@v5 21 | with: 22 | python-version: ${{ 
matrix.python-version }} 23 | 24 | - name: Display Python version 25 | run: python -c "import sys; print(sys.version)" 26 | 27 | - name: Install dependencies 28 | run: | 29 | python -m pip install --upgrade pip 30 | python -m pip install $(sed "/nmslib*/d;/dgl*/d;/dataclasses*/d" requirements.txt | tr -d ' ') 31 | python -m pip install Cython==0.29.37 32 | python -m pip install -e . 33 | 34 | - name: Install dataclasses 35 | run: | 36 | python -m pip install dataclasses 37 | if: matrix.python-version == '3.6' 38 | 39 | - name: Test install 40 | run: python -c "import libreco; print('libreco --', libreco.__version__)" 41 | 42 | - name: Test install dgl 43 | run: | 44 | python -m pip install 'dgl<2.0.0' -f https://data.dgl.ai/wheels/repo.html 45 | python -c "import dgl; print('dgl --', dgl.__version__)" 46 | 47 | - name: Test running on Linux 48 | run: | 49 | cp tests/sample_data/sample_movielens_rating.dat /home/runner/work/ 50 | cp tests/compatibility_test.py /home/runner/work/ && cd /home/runner/work/ 51 | python -m compatibility_test 52 | if: matrix.os == 'ubuntu-20.04' 53 | 54 | - name: Test running on Windows 55 | run: | 56 | copy D:\a\LibRecommender\LibRecommender\tests\sample_data\sample_movielens_rating.dat D:\a\ 57 | copy D:\a\LibRecommender\LibRecommender\tests\compatibility_test.py D:\a\ && Set-Location -Path "D:\a\" 58 | python -m compatibility_test 59 | if: matrix.os == 'windows-latest' 60 | 61 | - name: Test running on macOS 62 | run: | 63 | cp tests/sample_data/sample_movielens_rating.dat /Users/runner/work/ 64 | cp tests/compatibility_test.py /Users/runner/work/ && cd /Users/runner/work/ 65 | python -m compatibility_test 66 | if: matrix.os == 'macos-latest' 67 | 68 | - name: Test install nmslib 69 | run: | 70 | python -m pip install nmslib 71 | python -c "import nmslib; print('nmslib --', nmslib.__version__)" 72 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # .idea folder 10 | .idea/ 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | version: 2 6 | 7 | build: 8 | os: ubuntu-22.04 9 | tools: 10 | python: "3.10" 11 | 12 | sphinx: 13 | configuration: docs/source/conf.py 14 | 15 | python: 16 | install: 17 | - requirements: docs/requirements.txt 18 | - method: pip 19 | path: . 
20 | # system_packages: true 21 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright 2023 massquantity 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the “Software”), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.. -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | include LICENSE 3 | include requirements.txt requirements-serving.txt 4 | recursive-include libreco *.c *.cpp *.pyx 5 | recursive-include libserving *.py *.sh 6 | recursive-exclude examples * 7 | recursive-exclude distributed * 8 | recursive-exclude tests * 9 | -------------------------------------------------------------------------------- /_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-hacker -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | codecov: 2 | notify: 3 | after_n_builds: 1 4 | 5 | coverage: 6 | status: 7 | project: 8 | default: 9 | target: auto 10 | threshold: 1% 11 | patch: 12 | default: 13 | target: '50' 14 | informational: true 15 | 16 | ignore: 17 | - "docs/*" 18 | - "docker/*" 19 | -------------------------------------------------------------------------------- /distributed/spark/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 4.0.0 6 | 7 | com.massquantity 8 | spark-recommender 9 | 1.0-SNAPSHOT 10 | 11 | 12 | 13 | org.apache.maven.plugins 14 | maven-compiler-plugin 15 | 16 | 1.8 17 | 1.8 18 | 19 | 20 | 21 | 22 | 23 | 24 | UTF-8 25 | 1.8 26 | 2.3.3 27 | 2.11 28 | 1.8 29 | 1.8 30 | 31 | 32 | 33 | 34 | junit 35 | junit 36 | 4.13.1 37 | test 38 | 39 | 40 | 41 | org.apache.spark 42 | spark-core_${scala.version} 43 | ${spark.version} 44 | 45 | 46 | 47 | org.apache.spark 48 | spark-sql_${scala.version} 49 | ${spark.version} 50 | 51 | 52 | 53 | org.apache.spark 54 | spark-mllib_${scala.version} 55 | ${spark.version} 56 | 57 | 58 | 59 | 60 | 61 | alimaven 62 | aliyun maven 63 | http://maven.aliyun.com/nexus/content/groups/public/ 64 | 65 | true 66 | 67 | 68 | false 69 | 70 | 71 | 72 | 73 | -------------------------------------------------------------------------------- 
/distributed/spark/src/main/scala/TestScala.scala: -------------------------------------------------------------------------------- 1 | import org.apache.log4j.{Level, Logger} 2 | import org.apache.spark.ml.{Pipeline, Transformer} 3 | import org.apache.spark.ml.param.ParamMap 4 | import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCols} 5 | import org.apache.spark.ml.util.DefaultParamsWritable 6 | import org.apache.spark.sql.functions.{array_contains, col, split} 7 | import org.apache.spark.sql.types.{IntegerType, StructField, StructType} 8 | import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} 9 | import org.apache.spark.{SparkConf, SparkContext} 10 | import org.apache.spark.ml.util.Identifiable 11 | 12 | 13 | object TestScala { 14 | def main(args: Array[String]): Unit = { 15 | printf("test scala...") 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /distributed/spark/src/main/scala/com/libreco/example/AlsExample.scala: -------------------------------------------------------------------------------- 1 | package com.libreco.example 2 | 3 | import org.apache.spark.sql.functions.{round, sum, lit} 4 | import com.libreco.utils.{Context, ItemNameConverter} 5 | import com.libreco.data.DataSplitter 6 | import com.libreco.model.Recommender 7 | 8 | import scala.util.Random 9 | 10 | 11 | object AlsExample extends Context { 12 | import spark.implicits._ 13 | 14 | def main(args: Array[String]): Unit = { 15 | 16 | val dataPath = this.getClass.getResource("/ml-1m/ratings.dat").toString 17 | val splitter = new DataSplitter() 18 | var data = spark.read.textFile(dataPath) 19 | .map(splitter.parseRating("::")) 20 | .toDF("user", "item", "rating", "timestamp") 21 | .withColumn("label", lit(1)) 22 | data.show(4, truncate = false) 23 | // data.columns.foreach(x => println(s"$x -> ${data.filter(data(x).isNull).count}")) 24 | // data.groupBy("rating").count().orderBy($"count".desc) 25 | // .withColumn("percent", round($"count" / sum("count").over(), 4)).show() 26 | 27 | data = data.sample(withReplacement = false, 0.1) // sample 0.1 28 | 29 | val model = new Recommender() 30 | time(model.train(data), "Training") 31 | val transformedData = model.transform(data) 32 | transformedData.show(4, truncate = false) 33 | 34 | val movieMap = ItemNameConverter.getId2ItemName() 35 | val rec = model.recommendForUsers(data, 10, movieMap) 36 | rec.show(20, truncate = false) 37 | 38 | // val model = new Recommender() 39 | // time(model.train(data, evaluate = true, num = 10), "Evaluating") 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /distributed/spark/src/main/scala/com/libreco/example/RegressorExample.scala: -------------------------------------------------------------------------------- 1 | package com.libreco.example 2 | 3 | import org.apache.spark.ml.linalg.SparseVector 4 | import com.libreco.utils.{Context, ItemNameConverter} 5 | import com.libreco.model.Regressor 6 | 7 | 8 | object RegressorExample extends Context { 9 | import spark.implicits._ 10 | 11 | def main(args: Array[String]): Unit = { 12 | val movieNameConverter = ItemNameConverter.getItemName() 13 | val userPath = this.getClass.getResource("/ml-1m/users.dat").toString 14 | val moviePath = this.getClass.getResource("/ml-1m/movies.dat").toString 15 | val ratingPath = this.getClass.getResource("/ml-1m/ratings.dat").toString 16 | 17 | val users = spark.read.textFile(userPath) 18 | .selectExpr("split(value, '::') as col") 19 | .selectExpr( 20 | 
"cast(col[0] as int) as user", 21 | "cast(col[1] as string) as sex", 22 | "cast(col[2] as int) as age", 23 | "cast(col[3] as int) as occupation") 24 | val items = spark.read.textFile(moviePath) 25 | .selectExpr("split(value, '::') as col") 26 | .selectExpr( 27 | "cast(col[0] as int) as item", 28 | "cast(col[1] as string) as movie", 29 | "cast(col[2] as string) as genre") 30 | .withColumn("movieName", movieNameConverter($"movie")) 31 | .drop($"movie") 32 | .withColumnRenamed("movieName", "movie") 33 | .select("item", "movie", "genre") 34 | var ratings = spark.read.textFile(ratingPath) 35 | .selectExpr("split(value, '::') as col") 36 | .selectExpr( 37 | "cast(col[0] as int) as user", 38 | "cast(col[1] as int) as item", 39 | "cast(col[2] as int) as rating", 40 | "cast(col[3] as long) as timestamp") 41 | 42 | ratings = ratings.sample(withReplacement = false, 0.1) 43 | val temp = ratings.join(users, Seq("user"), "left") 44 | val data = temp.join(items, Seq("item"), "left") 45 | data.show(4, truncate = false) 46 | 47 | // val model = new Regressor(Some("glr")) 48 | // time(model.train(data, evaluate = false), "Training") 49 | // val transformedData = model.transform(data) 50 | // transformedData.show(4, truncate = false) 51 | 52 | val model = new Regressor(Some("gbdt")) 53 | time(model.train(data, evaluate = true, debug = true), "Evaluating") 54 | } 55 | } 56 | 57 | 58 | 59 | 60 | 61 | 62 | -------------------------------------------------------------------------------- /distributed/spark/src/main/scala/com/libreco/model/Recommender.scala: -------------------------------------------------------------------------------- 1 | package com.libreco.model 2 | 3 | import org.apache.spark.sql.DataFrame 4 | import org.apache.spark.ml.recommendation.{ALS, ALSModel} 5 | import org.apache.spark.sql.functions.{coalesce, typedLit} 6 | import com.libreco.utils.Context 7 | import com.libreco.evaluate.EvalRecommender 8 | 9 | import scala.collection.Map 10 | 11 | class Recommender extends Serializable with Context{ 12 | import spark.implicits._ 13 | var model: ALSModel = _ 14 | 15 | def train(df: DataFrame, evaluate: Boolean = false, num: Int = 10): Unit = { 16 | df.cache() 17 | if (evaluate) { 18 | val evalModel = new EvalRecommender(num, "ndcg") 19 | evalModel.eval(df) 20 | } 21 | else { 22 | val als = new ALS() 23 | .setMaxIter(20) 24 | .setRegParam(0.01) 25 | .setUserCol("user") 26 | .setItemCol("item") 27 | .setRank(50) 28 | .setImplicitPrefs(true) 29 | .setRatingCol("label") 30 | model = als.fit(df) 31 | model.setColdStartStrategy("drop") 32 | } 33 | df.unpersist() 34 | } 35 | 36 | def transform(df: DataFrame): DataFrame = { 37 | model.transform(df) 38 | } 39 | 40 | def recommendForUsers(df: DataFrame, 41 | num: Int, 42 | ItemNameMap: Map[Int, String] = Map.empty): DataFrame = { 43 | 44 | val rec = model.recommendForUserSubset(df, num) 45 | .selectExpr("user", "explode(recommendations) as predAndProb") 46 | .select("user", "predAndProb.*").toDF("user", "item", "prob") // pred means specific item 47 | 48 | val nameMapCol = typedLit(ItemNameMap) 49 | rec.withColumn("name", coalesce(nameMapCol($"item"))) 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /distributed/spark/src/main/scala/com/libreco/model/Regressor.scala: -------------------------------------------------------------------------------- 1 | package com.libreco.model 2 | 3 | import org.apache.spark.sql.DataFrame 4 | import com.libreco.feature.FeatureEngineering 5 | import 
com.libreco.evaluate.EvalRegressor 6 | import org.apache.spark.ml.regression.{GBTRegressor, GeneralizedLinearRegression} 7 | import org.apache.spark.ml.{Pipeline, PipelineModel, PipelineStage} 8 | 9 | import scala.util.Random 10 | 11 | class Regressor(algo: Option[String] = Some("gbdt")) extends Serializable { 12 | var pipelineModel: PipelineModel = _ 13 | 14 | def train(df: DataFrame, evaluate: Boolean = false, debug: Boolean = false): Unit = { 15 | val prePipelineStages: Array[PipelineStage] = FeatureEngineering.preProcessPipeline(df) 16 | if (debug) { 17 | val pipeline = new Pipeline().setStages(prePipelineStages) 18 | pipelineModel = pipeline.fit(df) 19 | val transformed = pipelineModel.transform(df) 20 | transformed.show(4, truncate = false) 21 | } 22 | if (evaluate) { 23 | val evalModel = new EvalRegressor(algo, prePipelineStages) 24 | evalModel.eval(df) 25 | } 26 | else { 27 | algo match { 28 | case Some("gbdt") => 29 | val model = new GBTRegressor() 30 | .setFeaturesCol("featureVector") 31 | .setLabelCol("rating") 32 | .setPredictionCol("pred") 33 | .setFeatureSubsetStrategy("auto") 34 | .setMaxDepth(3) 35 | .setMaxIter(20) 36 | .setStepSize(0.01) 37 | .setSubsamplingRate(0.8) 38 | .setSeed(Random.nextLong()) 39 | val pipelineStages = prePipelineStages ++ Array(model) 40 | val pipeline = new Pipeline().setStages(pipelineStages) 41 | pipelineModel = pipeline.fit(df) 42 | 43 | case Some("glr") => 44 | val model = new GeneralizedLinearRegression() 45 | .setFeaturesCol("featureVector") 46 | .setLabelCol("rating") 47 | .setPredictionCol("pred") 48 | .setFamily("gaussian") 49 | .setLink("identity") 50 | .setRegParam(0.0) 51 | val pipelineStages = prePipelineStages ++ Array(model) 52 | val pipeline = new Pipeline().setStages(pipelineStages) 53 | pipelineModel = pipeline.fit(df) 54 | 55 | case _ => 56 | println("Model must either be GBTRegressor or GeneralizedLinearRegression") 57 | System.exit(1) 58 | } 59 | } 60 | } 61 | 62 | def transform(dataset: DataFrame): DataFrame = { 63 | pipelineModel.transform(dataset) 64 | } 65 | } 66 | 67 | 68 | -------------------------------------------------------------------------------- /distributed/spark/src/main/scala/com/libreco/utils/Context.scala: -------------------------------------------------------------------------------- 1 | package com.libreco.utils 2 | 3 | import org.apache.log4j.{Level, Logger} 4 | import org.apache.spark.SparkConf 5 | import org.apache.spark.sql.SparkSession 6 | 7 | trait Context { 8 | Logger.getLogger("org").setLevel(Level.ERROR) 9 | Logger.getLogger("com").setLevel(Level.ERROR) 10 | 11 | lazy val sparkConf: SparkConf = new SparkConf() 12 | .setAppName("Spark Recommender") 13 | .setMaster("local[*]") 14 | // .set("spark.core.max", "4") 15 | 16 | lazy val spark: SparkSession = SparkSession 17 | .builder() 18 | .config(sparkConf) 19 | .getOrCreate() 20 | 21 | def time[T](block: => T, info: String): T = { 22 | val t0 = System.nanoTime() 23 | val result = block 24 | val t1 = System.nanoTime() 25 | println(f"$info time: ${(t1 - t0) / 1e9d}%.2fs") 26 | result 27 | } 28 | } -------------------------------------------------------------------------------- /distributed/spark/src/main/scala/com/libreco/utils/FilterNAs.scala: -------------------------------------------------------------------------------- 1 | package com.libreco.utils 2 | 3 | import org.apache.spark.sql.{DataFrame, Dataset, Column} 4 | import org.apache.spark.sql.functions.col 5 | 6 | 7 | object FilterNAs { 8 | def filter(data: DataFrame): DataFrame = { 9 |
println(s"find and filter NAs for each column...") 10 | data.columns.foreach(col => println(s"$col -> ${data.filter(data(col).isNull).count}")) 11 | val allCols: Array[Column] = data.columns.map(col) // col is a function to get Column 12 | val nullFilter: Column = allCols.map(_.isNotNull).reduce(_ && _) 13 | data.select(allCols: _*).filter(nullFilter) 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /distributed/spark/src/main/scala/com/libreco/utils/ItemNameConverter.scala: -------------------------------------------------------------------------------- 1 | package com.libreco.utils 2 | 3 | import org.apache.spark.sql.expressions.UserDefinedFunction 4 | import org.apache.spark.sql.{DataFrame, Dataset} 5 | import org.apache.spark.sql.functions.udf 6 | 7 | import scala.collection.Map 8 | import scala.util.matching.Regex 9 | 10 | 11 | object ItemNameConverter extends Context{ 12 | import spark.implicits._ 13 | 14 | def getId2ItemName(): Map[Int, String] = { 15 | val itemDataPath = this.getClass.getResource("/ml-1m/movies.dat").toString 16 | val itemData: Dataset[String] = spark.read.textFile(itemDataPath) 17 | 18 | itemData.flatMap { line: String => 19 | val Array(id, movieName, _*): Array[String] = line.split("::") 20 | if (id.isEmpty) { 21 | None 22 | } else { 23 | val pattern = new Regex("(.+)(\\(\\d+\\))") 24 | val name = for (m <- pattern.findFirstMatchIn(movieName)) yield m.group(1) 25 | Some(id.toInt, name.mkString) 26 | } 27 | }.collect().toMap 28 | } 29 | 30 | def getItemName(): UserDefinedFunction = { 31 | val itemDataPath = this.getClass.getResource("/ml-1m/movies.dat").toString 32 | val itemData: Dataset[String] = spark.read.textFile(itemDataPath) 33 | 34 | val itemMap = itemData.map { line: String => 35 | val Array(_, movieName, _*): Array[String] = line.split("::") 36 | val pattern = new Regex("(.+)(\\(\\d+\\))") 37 | val name = for (m <- pattern.findFirstMatchIn(movieName)) yield m.group(1) 38 | (movieName, name.mkString) 39 | }.collect().toMap 40 | udf((origName: String) => itemMap(origName)) 41 | } 42 | 43 | def main(args: Array[String]): Unit = { 44 | val mm = getId2ItemName() 45 | for (i <- 1 to 5) { 46 | println(s"Id2ItemName: $i -> ${mm(i)}") 47 | } 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.9-slim 2 | 3 | WORKDIR /root 4 | 5 | RUN apt-get update && \ 6 | apt-get install --no-install-recommends -y gcc g++ && \ 7 | apt-get clean && \ 8 | rm -rf /var/lib/apt/lists/* 9 | 10 | ADD ../requirements.txt /root 11 | RUN pip install --no-cache-dir -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple 12 | RUN pip install --no-cache-dir LibRecommender 13 | RUN pip install --no-cache-dir jupyterlab==3.5.0 -i https://pypi.tuna.tsinghua.edu.cn/simple 14 | 15 | RUN jupyter server --generate-config --allow-root 16 | # password generated based on https://jupyter-notebook.readthedocs.io/en/stable/config.html 17 | RUN echo "c.ServerApp.password = 'argon2:\$argon2id\$v=19\$m=10240,t=10,p=8\$1xV3ym3i6fh/Y9WrkfOfag\$pbATSK3YAwGw1GqdzGqhCw'" >> /root/.jupyter/jupyter_server_config.py 18 | RUN echo "c.ServerApp.ip = '0.0.0.0'" >> /root/.jupyter/jupyter_server_config.py 19 | RUN echo "c.ServerApp.port = 8888" >> /root/.jupyter/jupyter_server_config.py 20 | 21 | ADD ../examples /root/examples 22 | 23 | EXPOSE 8888 24 | 25 | CMD ["jupyter", "lab", "--allow-root", 
"--notebook-dir=/root", "--no-browser"] 26 | -------------------------------------------------------------------------------- /docker/README.md: -------------------------------------------------------------------------------- 1 | # LibRecommender in Docker 2 | 3 | Users can run [JupyterLab](https://jupyterlab.readthedocs.io/en/stable/) in a docker container and use the library without installing the package. 4 | 5 | 1. Pull image from Docker Hub 6 | ```shell 7 | $ docker pull massquantity/librecommender:latest 8 | ``` 9 | 2. Start a docker container by running following command: 10 | ```shell 11 | $ docker run --rm -p 8889:8888 massquantity/librecommender:latest 12 | ``` 13 | This command exposes the container port 8888 to 8889 on your machine. Feel free to change 8889 to any port you want, 14 | but make sure it is available. 15 | 16 | Or if you want to use your own data on your machine, try following command: 17 | ```shell 18 | $ docker run --rm -p 8889:8888 -v $(pwd):/root/data:ro massquantity/librecommender:latest 19 | ``` 20 | The `-v` flag mounts the current directory to `/root/data` in the container, and the `ro` option means readonly. 21 | You can change `(pwd)` to the directory you want. For more information see [Use bind mounts](https://docs.docker.com/storage/bind-mounts/) 22 | 23 | 3. Open the JupyterLab in a browser with `http://localhost:8889` 24 | 25 | 4. Enter `LibRecommender` as the password. 26 | 27 | 5. The `examples` folder in the repository has been included in the container, so one can use the magic command in the notebook to run some example scripts: 28 | ```shell 29 | cd examples 30 | %run pure_ranking_example.py 31 | ``` 32 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 
21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/md_doc/autoint_feature.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/massquantity/LibRecommender/9eadf8a3d6901a898630da113b92de48bd2897fb/docs/md_doc/autoint_feature.jpg -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx==6.1.3 2 | sphinx_copybutton==0.5.2 3 | sphinx-inline-tabs==2022.1.2b11 4 | furo==2022.12.7 5 | numpy==1.23.4 6 | cython==0.29.30 7 | scipy==1.8.1 8 | pandas==1.4.3 9 | scikit-learn==1.1.1 10 | tensorflow-cpu==2.10.1 11 | torch==1.11.0 --index-url https://download.pytorch.org/whl/cpu 12 | gensim>=4.0.0 13 | tqdm==4.64.0 14 | ujson==5.4.0 15 | redis==4.3.4 16 | -------------------------------------------------------------------------------- /docs/source/_static/autoint_feature.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/massquantity/LibRecommender/9eadf8a3d6901a898630da113b92de48bd2897fb/docs/source/_static/autoint_feature.jpg -------------------------------------------------------------------------------- /docs/source/api/algorithms/als.rst: -------------------------------------------------------------------------------- 1 | ALS 2 | --- 3 | 4 | .. autoclass:: libreco.algorithms.ALS 5 | :members: 6 | :inherited-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/api/algorithms/autoint.rst: -------------------------------------------------------------------------------- 1 | AutoInt 2 | ------- 3 | 4 | .. autoclass:: libreco.algorithms.AutoInt 5 | :members: 6 | :inherited-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/api/algorithms/bases.rst: -------------------------------------------------------------------------------- 1 | Base Classes 2 | ------------ 3 | 4 | .. autoclass:: libreco.bases.Base 5 | :members: 6 | :inherited-members: 7 | :show-inheritance: 8 | 9 | .. autoclass:: libreco.bases.EmbedBase 10 | :members: 11 | :inherited-members: 12 | :show-inheritance: 13 | 14 | .. autoclass:: libreco.bases.TfBase 15 | :members: 16 | :inherited-members: 17 | :show-inheritance: 18 | 19 | .. autoclass:: libreco.bases.CfBase 20 | :members: 21 | :inherited-members: 22 | :show-inheritance: 23 | 24 | .. autoclass:: libreco.bases.RsCfBase 25 | :members: 26 | :inherited-members: 27 | :show-inheritance: 28 | 29 | .. autoclass:: libreco.bases.GensimBase 30 | :members: 31 | :inherited-members: 32 | :show-inheritance: 33 | 34 | .. autoclass:: libreco.bases.SageBase 35 | :members: 36 | :inherited-members: 37 | :show-inheritance: 38 | 39 | .. 
autoclass:: libreco.bases.DynEmbedBase 40 | :members: 41 | :inherited-members: 42 | :show-inheritance: 43 | -------------------------------------------------------------------------------- /docs/source/api/algorithms/bpr.rst: -------------------------------------------------------------------------------- 1 | BPR 2 | --- 3 | 4 | .. autoclass:: libreco.algorithms.BPR 5 | :members: 6 | :inherited-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/api/algorithms/caser.rst: -------------------------------------------------------------------------------- 1 | Caser 2 | ----- 3 | 4 | .. autoclass:: libreco.algorithms.Caser 5 | :members: 6 | :inherited-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/api/algorithms/deepfm.rst: -------------------------------------------------------------------------------- 1 | DeepFM 2 | ------ 3 | 4 | .. autoclass:: libreco.algorithms.DeepFM 5 | :members: 6 | :inherited-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/api/algorithms/deepwalk.rst: -------------------------------------------------------------------------------- 1 | DeepWalk 2 | -------- 3 | 4 | .. autoclass:: libreco.algorithms.DeepWalk 5 | :members: 6 | :inherited-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/api/algorithms/din.rst: -------------------------------------------------------------------------------- 1 | DIN 2 | --- 3 | 4 | .. autoclass:: libreco.algorithms.DIN 5 | :members: 6 | :inherited-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/api/algorithms/fm.rst: -------------------------------------------------------------------------------- 1 | FM 2 | -- 3 | 4 | .. autoclass:: libreco.algorithms.FM 5 | :members: 6 | :inherited-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/api/algorithms/graphsage.rst: -------------------------------------------------------------------------------- 1 | GraphSage 2 | --------- 3 | 4 | .. autoclass:: libreco.algorithms.GraphSage 5 | :members: 6 | :inherited-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/api/algorithms/graphsage_dgl.rst: -------------------------------------------------------------------------------- 1 | GraphSageDGL 2 | ------------ 3 | 4 | .. autoclass:: libreco.algorithms.GraphSageDGL 5 | :members: 6 | :inherited-members: 7 | :show-inheritance: 8 | :exclude-members: transform_blocks 9 | -------------------------------------------------------------------------------- /docs/source/api/algorithms/index.rst: -------------------------------------------------------------------------------- 1 | Algorithms 2 | ========== 3 | 4 | .. 
toctree:: 5 | :maxdepth: 1 6 | 7 | bases 8 | user_cf 9 | user_cf_rs 10 | item_cf 11 | item_cf_rs 12 | svd 13 | svdpp 14 | als 15 | ncf 16 | bpr 17 | wide_deep 18 | fm 19 | deepfm 20 | youtube_retrieval 21 | youtube_ranking 22 | autoint 23 | din 24 | item2vec 25 | rnn4rec 26 | caser 27 | wavenet 28 | deepwalk 29 | ngcf 30 | lightgcn 31 | graphsage 32 | graphsage_dgl 33 | pinsage 34 | pinsage_dgl 35 | two_tower 36 | transformer 37 | sim 38 | swing 39 | -------------------------------------------------------------------------------- /docs/source/api/algorithms/item2vec.rst: -------------------------------------------------------------------------------- 1 | Item2Vec 2 | -------- 3 | 4 | .. autoclass:: libreco.algorithms.Item2Vec 5 | :members: 6 | :inherited-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/api/algorithms/item_cf.rst: -------------------------------------------------------------------------------- 1 | ItemCF 2 | ------ 3 | 4 | .. autoclass:: libreco.algorithms.ItemCF 5 | :members: 6 | :inherited-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/api/algorithms/item_cf_rs.rst: -------------------------------------------------------------------------------- 1 | RsItemCF 2 | -------- 3 | 4 | .. autoclass:: libreco.algorithms.RsItemCF 5 | :members: 6 | :inherited-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/api/algorithms/lightgcn.rst: -------------------------------------------------------------------------------- 1 | LightGCN 2 | -------- 3 | 4 | .. autoclass:: libreco.algorithms.LightGCN 5 | :members: 6 | :inherited-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/api/algorithms/ncf.rst: -------------------------------------------------------------------------------- 1 | NCF 2 | --- 3 | 4 | .. autoclass:: libreco.algorithms.NCF 5 | :members: 6 | :inherited-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/api/algorithms/ngcf.rst: -------------------------------------------------------------------------------- 1 | NGCF 2 | ---- 3 | 4 | .. autoclass:: libreco.algorithms.NGCF 5 | :members: 6 | :inherited-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/api/algorithms/pinsage.rst: -------------------------------------------------------------------------------- 1 | PinSage 2 | ------- 3 | 4 | .. autoclass:: libreco.algorithms.PinSage 5 | :members: 6 | :inherited-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/api/algorithms/pinsage_dgl.rst: -------------------------------------------------------------------------------- 1 | PinSageDGL 2 | ---------- 3 | 4 | .. autoclass:: libreco.algorithms.PinSageDGL 5 | :members: 6 | :inherited-members: 7 | :show-inheritance: 8 | :exclude-members: transform_blocks 9 | -------------------------------------------------------------------------------- /docs/source/api/algorithms/rnn4rec.rst: -------------------------------------------------------------------------------- 1 | RNN4Rec 2 | ------- 3 | 4 | .. 
autoclass:: libreco.algorithms.RNN4Rec 5 | :members: 6 | :inherited-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/api/algorithms/sim.rst: -------------------------------------------------------------------------------- 1 | SIM 2 | --- 3 | 4 | .. autoclass:: libreco.algorithms.SIM 5 | :members: 6 | :inherited-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/api/algorithms/svd.rst: -------------------------------------------------------------------------------- 1 | SVD 2 | --- 3 | 4 | .. autoclass:: libreco.algorithms.SVD 5 | :members: 6 | :inherited-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/api/algorithms/svdpp.rst: -------------------------------------------------------------------------------- 1 | SVD++ 2 | ----- 3 | 4 | .. autoclass:: libreco.algorithms.SVDpp 5 | :members: 6 | :inherited-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/api/algorithms/swing.rst: -------------------------------------------------------------------------------- 1 | Swing 2 | ----- 3 | 4 | .. autoclass:: libreco.algorithms.Swing 5 | :members: 6 | :inherited-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/api/algorithms/transformer.rst: -------------------------------------------------------------------------------- 1 | Transformer 2 | ----------- 3 | 4 | .. autoclass:: libreco.algorithms.Transformer 5 | :members: 6 | :inherited-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/api/algorithms/two_tower.rst: -------------------------------------------------------------------------------- 1 | TwoTower 2 | -------- 3 | 4 | .. autoclass:: libreco.algorithms.TwoTower 5 | :members: 6 | :inherited-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/api/algorithms/user_cf.rst: -------------------------------------------------------------------------------- 1 | UserCF 2 | ------ 3 | 4 | .. autoclass:: libreco.algorithms.UserCF 5 | :members: 6 | :inherited-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/api/algorithms/user_cf_rs.rst: -------------------------------------------------------------------------------- 1 | RsUserCF 2 | -------- 3 | 4 | .. autoclass:: libreco.algorithms.RsUserCF 5 | :members: 6 | :inherited-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/api/algorithms/wavenet.rst: -------------------------------------------------------------------------------- 1 | WaveNet 2 | ------- 3 | 4 | .. autoclass:: libreco.algorithms.WaveNet 5 | :members: 6 | :inherited-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/api/algorithms/wide_deep.rst: -------------------------------------------------------------------------------- 1 | Wide & Deep 2 | ----------- 3 | 4 | .. 
autoclass:: libreco.algorithms.WideDeep 5 | :members: 6 | :inherited-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/api/algorithms/youtube_ranking.rst: -------------------------------------------------------------------------------- 1 | YouTubeRanking 2 | -------------- 3 | 4 | .. autoclass:: libreco.algorithms.YouTubeRanking 5 | :members: 6 | :inherited-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/api/algorithms/youtube_retrieval.rst: -------------------------------------------------------------------------------- 1 | YouTubeRetrieval 2 | ---------------- 3 | 4 | .. autoclass:: libreco.algorithms.YouTubeRetrieval 5 | :members: 6 | :inherited-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/api/data/data_info.rst: -------------------------------------------------------------------------------- 1 | DataInfo 2 | ======== 3 | 4 | .. autoclass:: libreco.data.DataInfo 5 | :members: 6 | :special-members: __repr__ 7 | 8 | .. autoclass:: libreco.data.MultiSparseInfo 9 | -------------------------------------------------------------------------------- /docs/source/api/data/dataset.rst: -------------------------------------------------------------------------------- 1 | Dataset 2 | ======= 3 | 4 | .. automodule:: libreco.data.dataset 5 | :members: 6 | :inherited-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/api/data/index.rst: -------------------------------------------------------------------------------- 1 | Data 2 | ==== 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | 7 | dataset 8 | data_info 9 | split 10 | transformed -------------------------------------------------------------------------------- /docs/source/api/data/split.rst: -------------------------------------------------------------------------------- 1 | Split 2 | ===== 3 | 4 | .. autofunction:: libreco.data.random_split 5 | 6 | .. autofunction:: libreco.data.split_by_ratio 7 | 8 | .. autofunction:: libreco.data.split_by_num 9 | 10 | .. autofunction:: libreco.data.split_by_ratio_chrono 11 | 12 | .. autofunction:: libreco.data.split_by_num_chrono 13 | 14 | .. autofunction:: libreco.data.split_multi_value 15 | -------------------------------------------------------------------------------- /docs/source/api/data/transformed.rst: -------------------------------------------------------------------------------- 1 | TransformedSet 2 | ============== 3 | 4 | .. autoclass:: libreco.data.TransformedSet 5 | :members: 6 | :special-members: __len__ 7 | 8 | .. autoclass:: libreco.data.TransformedEvalSet 9 | :members: 10 | :special-members: __len__ 11 | -------------------------------------------------------------------------------- /docs/source/api/evaluation.rst: -------------------------------------------------------------------------------- 1 | Evaluation 2 | ========== 3 | 4 | .. automodule:: libreco.evaluation 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/api/serialization.rst: -------------------------------------------------------------------------------- 1 | Serialization 2 | ============= 3 | 4 | .. autofunction:: libserving.serialization.save_knn 5 | 6 | .. autofunction:: libserving.serialization.save_embed 7 | 8 | .. 
autofunction:: libserving.serialization.save_tf 9 | 10 | .. autofunction:: libserving.serialization.save_online 11 | 12 | .. autofunction:: libserving.serialization.knn2redis 13 | 14 | .. autofunction:: libserving.serialization.embed2redis 15 | 16 | .. autofunction:: libserving.serialization.tf2redis 17 | 18 | .. autofunction:: libserving.serialization.online2redis 19 | -------------------------------------------------------------------------------- /docs/source/internal/data_info.rst: -------------------------------------------------------------------------------- 1 | Data Info 2 | ========= 3 | 4 | The :class:`~libreco.data.DataInfo` object stores almost all the useful information from the original data. 5 | We admit there may be too much information in this object, but for ease of use of the library, 6 | we've decided not to split it. 7 | So almost every model has a ``data_info`` attribute that is used to make recommendations. 8 | Additionally, when saving and loading a model, the corresponding *DataInfo* should also be saved and loaded. 9 | 10 | When using a ``feat`` model, the :class:`~libreco.data.DataInfo` object stores the unique features of 11 | all users/items in the training data. However, if a user/item has different categories or values 12 | in the training data (which may be unlikely if the data is clean :)), only the last one will be stored. 13 | For example, if in one sample a user's age is 20, and in another sample this user's age becomes 25, 14 | then only 25 will be kept. So here we basically assume the data is always sorted by time, 15 | and you should sort it by time if it isn't. 16 | 17 | Therefore, when you call ``model.predict(user=..., item=...)`` or ``model.recommend_user(user=...)`` 18 | for a feat model, the model will use the stored feature information in DataInfo. 19 | 20 | The :class:`~libreco.data.DataInfo` object also stores users' consumed items, which can be useful in sequence models 21 | and the ``unconsumed`` sampler. 22 | 23 | Changing User/Item Features 24 | --------------------------- 25 | It is also possible to change the unique user/item feature values stored in *DataInfo*, 26 | so that the new features will be used in prediction and recommendation. 27 | 28 | .. code-block:: python3 29 | 30 | >>> data_info.assign_user_features(user_data=data) 31 | >>> data_info.assign_item_features(item_data=data) 32 | 33 | The passed ``data`` argument is a ``pandas.DataFrame`` that contains the user/item information. 34 | Be careful with this assign operation if you are not sure whether the features in ``data`` are useful. 35 | 36 | .. SeeAlso:: 37 | 38 | `changing_feature_example.py `_ 39 | -------------------------------------------------------------------------------- /docs/source/internal/index.rst: -------------------------------------------------------------------------------- 1 | Internal 2 | ======== 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | 7 | implementation_details 8 | data_info 9 | -------------------------------------------------------------------------------- /docs/source/user_guide/embedding.rst: -------------------------------------------------------------------------------- 1 | Embedding 2 | ========= 3 | 4 | According to the `algorithm list `_, 5 | there are some algorithms that can generate user and item embeddings after training. 6 | So LibRecommender provides public APIs to get them: 7 | 8 | .. code-block:: python3 9 | 10 | >>> model = RNN4Rec(task="ranking", ...) 11 | >>> model.fit(train_data, ...)
12 | >>> model.get_user_embedding(user=1) # get user embedding for user 1 13 | >>> model.get_item_embedding(item=2) # get item embedding for item 2 14 | 15 | One can also search for similar users/items based on embeddings. By default, 16 | we use `nmslib `_ to do approximate similarity 17 | searching since it's generally fast, but some people may find it difficult 18 | to build and install the library, especially on the Windows platform or with Python >= 3.10. 19 | So one can fall back to numpy similarity calculation if nmslib is not available. 20 | 21 | .. code-block:: python3 22 | 23 | >>> model = RNN4Rec(task="ranking", ...) 24 | >>> model.fit(train_data, ...) 25 | >>> model.init_knn(approximate=True, sim_type="cosine") 26 | >>> model.search_knn_users(user=1, k=3) 27 | >>> model.search_knn_items(item=2, k=3) 28 | 29 | Before searching, one should call :func:`~libreco.bases.EmbedBase.init_knn` to initialize the index. 30 | Set ``approximate=True`` if you can use nmslib; otherwise, set ``approximate=False``. 31 | The ``sim_type`` parameter should be either ``cosine`` or ``inner-product``. 32 | 33 | 34 | Dynamic Embedding Generation 35 | ---------------------------- 36 | It is also common to generate user embeddings based on features or behavior sequences. 37 | Once the user embedding has been generated, you can use it to perform similarity search with all the item embeddings. 38 | 39 | This can be useful in the cold-start scenario, so LibRecommender provides an API for dynamic user embeddings: 40 | 41 | .. code-block:: python3 42 | 43 | >>> model = RNN4Rec(task="ranking", norm_embed=True, ...) 44 | >>> model.fit(train_data, ...) 45 | >>> user_embed = model.dyn_user_embedding(user=1, seq=[0, 10]) 46 | 47 | >>> model2 = YouTubeRetrieval(task="ranking", norm_embed=False, ...) 48 | >>> model2.fit(train_data, ...) 49 | >>> user_embed = model2.dyn_user_embedding(user="cold user", user_feats={"sex": "F"}, seq=[0, 10]) 50 | 51 | .. SeeAlso:: 52 | 53 | `knn_embedding_example.py `_ 54 | 55 | -------------------------------------------------------------------------------- /docs/source/user_guide/evaluation_save_load.rst: -------------------------------------------------------------------------------- 1 | Evaluation & Save/Load 2 | ====================== 3 | 4 | Evaluate During Training 5 | ------------------------ 6 | 7 | The standard procedure in LibRecommender is evaluating during training. 8 | However, for some complex models, doing full evaluation on eval data can be very 9 | time-consuming, so you can specify some evaluation parameters to speed this up. 10 | 11 | The default value of ``eval_batch_size`` is 8192, and you can use a higher value if 12 | you have enough machine or GPU memory. Conversely, if you encounter a memory error during 13 | evaluation, try reducing ``eval_batch_size``. 14 | 15 | The ``eval_user_num`` parameter controls how many users to use in evaluation. 16 | By default, it is ``None``, which uses all the users in eval data. 17 | You can use a smaller value if the evaluation is slow, and this will sample ``eval_user_num`` 18 | users randomly from eval data. 19 | 20 | .. code-block:: python3 21 | 22 | model.fit( 23 | train_data, 24 | verbose=2, 25 | shuffle=True, 26 | eval_data=eval_data, 27 | metrics=metrics, 28 | k=10, # parameter of metrics, e.g.
recall at k, ndcg at k 29 | eval_batch_size=8192, 30 | eval_user_num=100, 31 | ) 32 | 33 | 34 | Evaluate After Training 35 | ----------------------- 36 | 37 | After training, one can use the :func:`~libreco.evaluation.evaluate` function to 38 | evaluate on test data directly. 39 | 40 | Note that if your evaluation data (typically in :class:`pandas.DataFrame` format) **is implicit and only contains positive labels**, 41 | then negative sampling is needed, which can be enabled by passing ``neg_sampling=True``: 42 | 43 | .. literalinclude:: ../../../examples/save_load_example.py 44 | :caption: From file `examples/save_load_example.py `_ 45 | :name: save_load_example.py 46 | :lines: 85-94 47 | 48 | Save/Load Model 49 | --------------- 50 | 51 | In general, we may want to save/load a model for two reasons: 52 | 53 | 1. Save the model, then load it to make some predictions and recommendations. This is called inference. 54 | 2. Save the model, then load it to retrain the model when we get some new data. 55 | 56 | The ``save/load`` API mainly deals with the first one; the retraining problem is quite 57 | different and will be covered in :doc:`model_retrain`. 58 | When making predictions and recommendations, it may be unnecessary to save all the model 59 | variables, so one can pass ``inference_only=True`` to save only the essential parts of the model. 60 | 61 | After loading the model, one can also evaluate it directly; 62 | see `save_load_example.py `_ for typical usages. 63 | -------------------------------------------------------------------------------- /docs/source/user_guide/index.rst: -------------------------------------------------------------------------------- 1 | :orphan: 2 | 3 | User Guide 4 | ========== 5 | 6 | The purpose of this guide is to illustrate some main features that LibRecommender provides. 7 | Example usages are all listed in the `examples `_ folder. 8 | 9 | ..
toctree:: 10 | :maxdepth: 1 11 | 12 | data_processing 13 | feature_engineering 14 | model_train 15 | evaluation_save_load 16 | recommendation 17 | embedding 18 | model_retrain 19 | -------------------------------------------------------------------------------- /examples/feat_example.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | from libreco.algorithms import YouTubeRanking 4 | from libreco.data import DatasetFeat, split_by_ratio_chrono 5 | 6 | if __name__ == "__main__": 7 | data = pd.read_csv("sample_data/sample_movielens_merged.csv", sep=",", header=0) 8 | 9 | # split into train and test data based on time 10 | train_data, test_data = split_by_ratio_chrono(data, test_size=0.2) 11 | 12 | # specify complete columns information 13 | sparse_col = ["sex", "occupation", "genre1", "genre2", "genre3"] 14 | dense_col = ["age"] 15 | user_col = ["sex", "age", "occupation"] 16 | item_col = ["genre1", "genre2", "genre3"] 17 | 18 | train_data, data_info = DatasetFeat.build_trainset( 19 | train_data, user_col, item_col, sparse_col, dense_col 20 | ) 21 | test_data = DatasetFeat.build_testset(test_data) 22 | print(data_info) # n_users: 5953, n_items: 3209, data density: 0.4213 % 23 | 24 | ytb_ranking = YouTubeRanking( 25 | task="ranking", 26 | data_info=data_info, 27 | embed_size=16, 28 | n_epochs=3, 29 | lr=1e-4, 30 | batch_size=512, 31 | use_bn=True, 32 | hidden_units=(128, 64, 32), 33 | ) 34 | ytb_ranking.fit( 35 | train_data, 36 | neg_sampling=True, # sample negative items train and eval data 37 | verbose=2, 38 | shuffle=True, 39 | eval_data=test_data, 40 | metrics=["loss", "roc_auc", "precision", "recall", "map", "ndcg"], 41 | ) 42 | 43 | # predict preference of user 2211 to item 110 44 | print("prediction: ", ytb_ranking.predict(user=2211, item=110)) 45 | # recommend 7 items for user 2211 46 | print("recommendation: ", ytb_ranking.recommend_user(user=2211, n_rec=7)) 47 | 48 | # cold-start prediction 49 | print( 50 | "cold prediction: ", 51 | ytb_ranking.predict(user="ccc", item="not item", cold_start="average"), 52 | ) 53 | # cold-start recommendation 54 | print( 55 | "cold recommendation: ", 56 | ytb_ranking.recommend_user(user="are we good?", n_rec=7, cold_start="popular"), 57 | ) 58 | -------------------------------------------------------------------------------- /examples/knn_embedding_example.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | from libreco.algorithms import RNN4Rec 4 | from libreco.data import DatasetPure, split_by_ratio_chrono 5 | from libreco.utils.misc import colorize 6 | 7 | try: 8 | import nmslib # noqa: F401 9 | 10 | approximate = True 11 | print_str = "using `nmslib` for similarity search" 12 | print(f"{colorize(print_str, 'cyan')}") 13 | except (ImportError, ModuleNotFoundError): 14 | approximate = False 15 | print_str = "failed to import `nmslib`, using `numpy` for similarity search" 16 | print(f"{colorize(print_str, 'cyan')}") 17 | 18 | 19 | if __name__ == "__main__": 20 | data = pd.read_csv( 21 | "sample_data/sample_movielens_rating.dat", 22 | sep="::", 23 | names=["user", "item", "label", "time"], 24 | ) 25 | 26 | train_data, eval_data = split_by_ratio_chrono(data, test_size=0.2) 27 | train_data, data_info = DatasetPure.build_trainset(train_data) 28 | eval_data = DatasetPure.build_evalset(eval_data) 29 | 30 | rnn = RNN4Rec( 31 | "ranking", 32 | data_info, 33 | rnn_type="lstm", 34 | loss_type="cross_entropy", 35 | embed_size=16, 36 
| n_epochs=2, 37 | lr=0.001, 38 | lr_decay=False, 39 | hidden_units=16, 40 | reg=None, 41 | batch_size=2048, 42 | num_neg=1, 43 | dropout_rate=None, 44 | recent_num=10, 45 | tf_sess_config=None, 46 | ) 47 | rnn.fit(train_data, neg_sampling=True, verbose=2) 48 | 49 | # `sim_type` should either be `cosine` or `inner-product` 50 | rnn.init_knn(approximate=approximate, sim_type="cosine") 51 | print("embedding for user 1: ", rnn.get_user_embedding(user=1)) 52 | print("embedding for item 2: ", rnn.get_item_embedding(item=2)) 53 | print() 54 | 55 | print(" 3 most similar users for user 1: ", rnn.search_knn_users(user=1, k=3)) 56 | print(" 3 most similar items for item 2: ", rnn.search_knn_items(item=2, k=3)) 57 | print() 58 | 59 | user_embed = rnn.dyn_user_embedding(user=1, seq=[0, 10, 100]) 60 | print("generate embedding for user 1: ", user_embed) 61 | -------------------------------------------------------------------------------- /examples/multi_sparse_example.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | from libreco.algorithms import DeepFM 4 | from libreco.data import DatasetFeat, split_by_ratio_chrono 5 | 6 | if __name__ == "__main__": 7 | data = pd.read_csv("sample_data/sample_movielens_merged.csv", sep=",", header=0) 8 | train_data, eval_data = split_by_ratio_chrono(data, test_size=0.2) 9 | 10 | # specify complete columns information 11 | sparse_col = ["sex", "occupation"] 12 | multi_sparse_col = [["genre1", "genre2", "genre3"]] # should be list of list 13 | dense_col = ["age"] 14 | user_col = ["sex", "age", "occupation"] 15 | item_col = ["genre1", "genre2", "genre3"] 16 | 17 | train_data, data_info = DatasetFeat.build_trainset( 18 | train_data=train_data, 19 | user_col=user_col, 20 | item_col=item_col, 21 | sparse_col=sparse_col, 22 | dense_col=dense_col, 23 | multi_sparse_col=multi_sparse_col, 24 | pad_val=["missing"], # specify padding value 25 | ) 26 | eval_data = DatasetFeat.build_testset(eval_data) 27 | print(data_info) 28 | 29 | deepfm = DeepFM( 30 | "ranking", 31 | data_info, 32 | embed_size=16, 33 | n_epochs=2, 34 | lr=1e-4, 35 | lr_decay=False, 36 | reg=None, 37 | batch_size=2048, 38 | num_neg=1, 39 | use_bn=False, 40 | dropout_rate=None, 41 | hidden_units=(128, 64, 32), 42 | tf_sess_config=None, 43 | multi_sparse_combiner="sqrtn", # specify multi_sparse combiner 44 | ) 45 | 46 | deepfm.fit( 47 | train_data, 48 | neg_sampling=True, 49 | verbose=2, 50 | shuffle=True, 51 | eval_data=eval_data, 52 | metrics=[ 53 | "loss", 54 | "balanced_accuracy", 55 | "roc_auc", 56 | "pr_auc", 57 | "precision", 58 | "recall", 59 | "map", 60 | "ndcg", 61 | ], 62 | ) 63 | 64 | print("prediction: ", deepfm.predict(user=1, item=2333)) 65 | print("recommendation: ", deepfm.recommend_user(user=1, n_rec=7)) 66 | -------------------------------------------------------------------------------- /examples/pure_example.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | from libreco.algorithms import LightGCN # pure data, algorithm LightGCN 4 | from libreco.data import DatasetPure, random_split 5 | from libreco.evaluation import evaluate 6 | 7 | if __name__ == "__main__": 8 | data = pd.read_csv( 9 | "sample_data/sample_movielens_rating.dat", 10 | sep="::", 11 | names=["user", "item", "label", "time"], 12 | ) 13 | 14 | # split whole data into three folds for training, evaluating and testing 15 | train_data, eval_data, test_data = random_split(data, multi_ratios=[0.8, 0.1, 
0.1]) 16 | 17 | train_data, data_info = DatasetPure.build_trainset(train_data) 18 | eval_data = DatasetPure.build_evalset(eval_data) 19 | test_data = DatasetPure.build_testset(test_data) 20 | print(data_info) # n_users: 5894, n_items: 3253, data sparsity: 0.4172 % 21 | 22 | lightgcn = LightGCN( 23 | task="ranking", 24 | data_info=data_info, 25 | loss_type="bpr", 26 | embed_size=16, 27 | n_epochs=3, 28 | lr=1e-3, 29 | batch_size=2048, 30 | num_neg=1, 31 | device="cuda", 32 | ) 33 | # monitor metrics on eval_data during training 34 | lightgcn.fit( 35 | train_data, 36 | neg_sampling=True, # sample negative items for train and eval data 37 | verbose=2, 38 | eval_data=eval_data, 39 | metrics=["loss", "roc_auc", "precision", "recall", "ndcg"], 40 | ) 41 | 42 | # do final evaluation on test data 43 | print( 44 | "evaluate_result: ", 45 | evaluate( 46 | model=lightgcn, 47 | data=test_data, 48 | neg_sampling=True, # sample negative items for test data 49 | metrics=["loss", "roc_auc", "precision", "recall", "ndcg"], 50 | ), 51 | ) 52 | # predict preference of user 2211 to item 110 53 | print("prediction: ", lightgcn.predict(user=2211, item=110)) 54 | # recommend 7 items for user 2211 55 | print("recommendation: ", lightgcn.recommend_user(user=2211, n_rec=7)) 56 | 57 | # cold-start prediction 58 | print( 59 | "cold prediction: ", 60 | lightgcn.predict(user="ccc", item="not item", cold_start="average"), 61 | ) 62 | # cold-start recommendation 63 | print( 64 | "cold recommendation: ", 65 | lightgcn.recommend_user(user="are we good?", n_rec=7, cold_start="popular"), 66 | ) 67 | -------------------------------------------------------------------------------- /examples/seq_example.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import pandas as pd 4 | 5 | from libreco.algorithms import DIN 6 | from libreco.data import DatasetFeat 7 | 8 | if __name__ == "__main__": 9 | start_time = time.perf_counter() 10 | data = pd.read_csv("sample_data/sample_movielens_merged.csv", sep=",", header=0) 11 | 12 | # specify complete columns information 13 | sparse_col = ["sex", "occupation", "genre1", "genre2", "genre3"] 14 | dense_col = ["age"] 15 | user_col = ["sex", "age", "occupation"] 16 | item_col = ["genre1", "genre2", "genre3"] 17 | 18 | train_data, data_info = DatasetFeat.build_trainset( 19 | data, user_col, item_col, sparse_col, dense_col 20 | ) 21 | 22 | din = DIN( 23 | "ranking", 24 | data_info, 25 | loss_type="focal", 26 | embed_size=16, 27 | n_epochs=1, 28 | lr=3e-3, 29 | lr_decay=False, 30 | reg=None, 31 | batch_size=64, 32 | sampler="popular", 33 | num_neg=1, 34 | use_bn=True, 35 | hidden_units=(110, 32), 36 | recent_num=10, 37 | tf_sess_config=None, 38 | use_tf_attention=True, 39 | ) 40 | din.fit(train_data, neg_sampling=True, verbose=2, shuffle=True, eval_data=None) 41 | 42 | print( 43 | "feat recommendation: ", 44 | din.recommend_user(user=4617, n_rec=7, user_feats={"sex": "F", "age": 3}), 45 | ) 46 | print( 47 | "seq recommendation1: ", 48 | din.recommend_user( 49 | user=4617, 50 | n_rec=7, 51 | seq=[4, 0, 1, 222, "cold item", 222, 12, 1213, 1197, 1193], 52 | ), 53 | ) 54 | print( 55 | "seq recommendation2: ", 56 | din.recommend_user( 57 | user=4617, 58 | n_rec=7, 59 | seq=["cold item", 1270, 2161, 110, 3827, 12, 34, 1273, 1589], 60 | ), 61 | ) 62 | print( 63 | "feat & seq recommendation1: ", 64 | din.recommend_user( 65 | user=1, n_rec=7, user_feats={"sex": "F", "age": 3}, seq=[4, 0, 1, 222] 66 | ), 67 | ) 68 | print( 69 | "feat & seq 
recommendation2: ", 70 | din.recommend_user( 71 | user=1, n_rec=7, user_feats={"sex": "M", "age": 33}, seq=[4, 0, 337, 1497] 72 | ), 73 | ) 74 | -------------------------------------------------------------------------------- /examples/split_data_example.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import pandas as pd 4 | 5 | from libreco.data import ( 6 | random_split, 7 | split_by_num, 8 | split_by_num_chrono, 9 | split_by_ratio, 10 | split_by_ratio_chrono, 11 | ) 12 | 13 | if __name__ == "__main__": 14 | start_time = time.perf_counter() 15 | data = pd.read_csv("sample_data/sample_movielens_merged.csv", sep=",", header=0) 16 | 17 | train_data, eval_data, test_data = random_split(data, multi_ratios=[0.8, 0.1, 0.1]) 18 | 19 | train_data2, eval_data2 = split_by_ratio(data, test_size=0.2) 20 | print(train_data2.shape, eval_data2.shape) 21 | 22 | train_data3, eval_data3 = split_by_num(data, test_size=1) 23 | print(train_data3.shape, eval_data3.shape) 24 | 25 | train_data4, eval_data4 = split_by_ratio_chrono(data, test_size=0.2) 26 | print(train_data4.shape, eval_data4.shape) 27 | 28 | train_data5, eval_data5 = split_by_num_chrono(data, test_size=1) 29 | print(train_data5.shape, eval_data5.shape) 30 | -------------------------------------------------------------------------------- /libreco/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.5.1" 2 | -------------------------------------------------------------------------------- /libreco/algorithms/__init__.py: -------------------------------------------------------------------------------- 1 | from .als import ALS 2 | from .autoint import AutoInt 3 | from .bpr import BPR 4 | from .caser import Caser 5 | from .deepfm import DeepFM 6 | from .deepwalk import DeepWalk 7 | from .din import DIN 8 | from .fm import FM 9 | from .graphsage import GraphSage 10 | from .graphsage_dgl import GraphSageDGL 11 | from .item2vec import Item2Vec 12 | from .item_cf import ItemCF 13 | from .item_cf_rs import RsItemCF 14 | from .lightgcn import LightGCN 15 | from .ncf import NCF 16 | from .ngcf import NGCF 17 | from .pinsage import PinSage 18 | from .pinsage_dgl import PinSageDGL 19 | from .rnn4rec import RNN4Rec 20 | from .sim import SIM 21 | from .svd import SVD 22 | from .svdpp import SVDpp 23 | from .swing import Swing 24 | from .transformer import Transformer 25 | from .two_tower import TwoTower 26 | from .user_cf import UserCF 27 | from .user_cf_rs import RsUserCF 28 | from .wave_net import WaveNet 29 | from .wide_deep import WideDeep 30 | from .youtube_ranking import YouTubeRanking 31 | from .youtube_retrieval import YouTubeRetrieval 32 | 33 | __all__ = [ 34 | "UserCF", 35 | "RsUserCF", 36 | "ItemCF", 37 | "RsItemCF", 38 | "SVD", 39 | "SVDpp", 40 | "ALS", 41 | "BPR", 42 | "NCF", 43 | "YouTubeRetrieval", 44 | "YouTubeRanking", 45 | "FM", 46 | "WideDeep", 47 | "DeepFM", 48 | "AutoInt", 49 | "DIN", 50 | "RNN4Rec", 51 | "Caser", 52 | "WaveNet", 53 | "Item2Vec", 54 | "DeepWalk", 55 | "NGCF", 56 | "LightGCN", 57 | "PinSage", 58 | "PinSageDGL", 59 | "GraphSage", 60 | "GraphSageDGL", 61 | "TwoTower", 62 | "Transformer", 63 | "SIM", 64 | "Swing", 65 | ] 66 | -------------------------------------------------------------------------------- /libreco/algorithms/item2vec.py: -------------------------------------------------------------------------------- 1 | """Implementation of Item2Vec.""" 2 | from gensim.models import Word2Vec 3 | from tqdm 
import tqdm 4 | 5 | from ..bases import GensimBase 6 | 7 | 8 | class Item2Vec(GensimBase): 9 | """*Item2Vec* algorithm. 10 | 11 | .. WARNING:: 12 | Item2Vec can only be used in the ``ranking`` task. 13 | 14 | Parameters 15 | ---------- 16 | task : {'ranking'} 17 | Recommendation task. See :ref:`Task`. 18 | data_info : :class:`~libreco.data.DataInfo` object 19 | Object that contains useful information for training and inference. 20 | embed_size : int, default: 16 21 | Vector size of embeddings. 22 | norm_embed : bool, default: False 23 | Whether to l2 normalize output embeddings. 24 | window_size : int, default: 5 25 | Maximum item distance within a sequence during training. 26 | n_epochs : int, default: 10 27 | Number of epochs for training. 28 | n_threads : int, default: 0 29 | Number of threads to use; `0` will use all cores. 30 | seed : int, default: 42 31 | Random seed. 32 | lower_upper_bound : tuple or None, default: None 33 | Lower and upper score bound for `rating` task. 34 | 35 | References 36 | ---------- 37 | *Oren Barkan and Noam Koenigstein.* `Item2Vec: Neural Item Embedding for Collaborative Filtering 38 | `_. 39 | """ 40 | 41 | def __init__( 42 | self, 43 | task, 44 | data_info=None, 45 | embed_size=16, 46 | norm_embed=False, 47 | window_size=5, 48 | n_epochs=10, 49 | n_threads=0, 50 | seed=42, 51 | lower_upper_bound=None, 52 | ): 53 | super().__init__( 54 | task, 55 | data_info, 56 | embed_size, 57 | norm_embed, 58 | window_size, 59 | n_epochs, 60 | n_threads, 61 | seed, 62 | lower_upper_bound, 63 | ) 64 | assert task == "ranking", "Item2Vec is only suitable for ranking" 65 | self.all_args = locals() 66 | 67 | def get_data(self): 68 | return _ItemCorpus(self.user_consumed) 69 | 70 | def build_model(self): 71 | model = Word2Vec( 72 | vector_size=self.embed_size, 73 | window=self.window_size, 74 | sg=1, 75 | hs=0, 76 | negative=5, 77 | seed=self.seed, 78 | min_count=1, 79 | workers=self.workers, 80 | sorted_vocab=0, 81 | ) 82 | model.build_vocab(self.data, update=False) 83 | return model 84 | 85 | 86 | class _ItemCorpus: 87 | def __init__(self, user_consumed): 88 | self.item_seqs = user_consumed.values() 89 | self.i = 0 90 | 91 | def __iter__(self): 92 | for items in tqdm(self.item_seqs, desc=f"Item2vec iter{self.i}"): 93 | yield list(map(str, items)) 94 | self.i += 1 95 | -------------------------------------------------------------------------------- /libreco/algorithms/item_cf_rs.py: -------------------------------------------------------------------------------- 1 | """Implementation of RsItemCF.""" 2 | from ..bases import RsCfBase 3 | 4 | 5 | class RsItemCF(RsCfBase): 6 | """*Item Collaborative Filtering* algorithm implemented in Rust. 7 | 8 | Parameters 9 | ---------- 10 | task : {'rating', 'ranking'} 11 | Recommendation task. See :ref:`Task`. 12 | data_info : :class:`~libreco.data.DataInfo` object 13 | Object that contains useful information for training and inference. 14 | k_sim : int, default: 20 15 | Number of similar items to use. 16 | num_threads : int, default: 1 17 | Number of threads to use. 18 | min_common : int, default: 1 19 | Minimum number of common items to consider when computing similarities. 20 | mode : {'forward', 'invert'}, default: 'invert' 21 | Whether to use a forward index or an inverted index. 22 | seed : int, default: 42 23 | Random seed. 24 | lower_upper_bound : tuple or None, default: None 25 | Lower and upper score bound for `rating` task.
26 | """ 27 | 28 | def __init__( 29 | self, 30 | task, 31 | data_info, 32 | k_sim=20, 33 | num_threads=1, 34 | min_common=1, 35 | mode="invert", 36 | seed=42, 37 | lower_upper_bound=None, 38 | ): 39 | super().__init__( 40 | task, 41 | data_info, 42 | k_sim, 43 | num_threads, 44 | min_common, 45 | mode, 46 | seed, 47 | lower_upper_bound, 48 | ) 49 | self.all_args = locals() 50 | -------------------------------------------------------------------------------- /libreco/algorithms/torch_modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .graphsage_module import GraphSageDGLModel, GraphSageModel 2 | from .lightgcn_module import LightGCNModel 3 | from .ngcf_module import NGCFModel 4 | from .pinsage_module import PinSageDGLModel, PinSageModel 5 | 6 | __all__ = [ 7 | "GraphSageModel", 8 | "GraphSageDGLModel", 9 | "LightGCNModel", 10 | "NGCFModel", 11 | "PinSageModel", 12 | "PinSageDGLModel", 13 | ] 14 | -------------------------------------------------------------------------------- /libreco/algorithms/user_cf_rs.py: -------------------------------------------------------------------------------- 1 | """Implementation of RsUserCF.""" 2 | from ..bases import RsCfBase 3 | 4 | 5 | class RsUserCF(RsCfBase): 6 | """*User Collaborative Filtering* algorithm implemented in Rust. 7 | 8 | Parameters 9 | ---------- 10 | task : {'rating', 'ranking'} 11 | Recommendation task. See :ref:`Task`. 12 | data_info : :class:`~libreco.data.DataInfo` object 13 | Object that contains useful information for training and inference. 14 | k_sim : int, default: 20 15 | Number of similar items to use. 16 | num_threads : int, default: 1 17 | Number of threads to use. 18 | min_common : int, default: 1 19 | Number of minimum common users to consider when computing similarities. 20 | mode : {'forward', 'invert'}, default: 'invert' 21 | Whether to use forward index or invert index. 22 | seed : int, default: 42 23 | Random seed. 24 | lower_upper_bound : tuple or None, default: None 25 | Lower and upper score bound for `rating` task. 
26 | """ 27 | 28 | def __init__( 29 | self, 30 | task, 31 | data_info, 32 | k_sim=20, 33 | num_threads=1, 34 | min_common=1, 35 | mode="invert", 36 | seed=42, 37 | lower_upper_bound=None, 38 | ): 39 | super().__init__( 40 | task, 41 | data_info, 42 | k_sim, 43 | num_threads, 44 | min_common, 45 | mode, 46 | seed, 47 | lower_upper_bound, 48 | ) 49 | self.all_args = locals() 50 | -------------------------------------------------------------------------------- /libreco/bases/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import Base 2 | from .cf_base import CfBase 3 | from .cf_base_rs import RsCfBase 4 | from .dyn_embed_base import DynEmbedBase 5 | from .embed_base import EmbedBase 6 | from .gensim_base import GensimBase 7 | from .meta import ModelMeta 8 | from .sage_base import SageBase 9 | from .tf_base import TfBase 10 | 11 | __all__ = [ 12 | "Base", 13 | "CfBase", 14 | "RsCfBase", 15 | "DynEmbedBase", 16 | "EmbedBase", 17 | "GensimBase", 18 | "ModelMeta", 19 | "SageBase", 20 | "TfBase", 21 | ] 22 | -------------------------------------------------------------------------------- /libreco/bases/meta.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta 2 | 3 | from ..tfops import rebuild_tf_model 4 | from ..torchops import rebuild_torch_model 5 | 6 | 7 | class ModelMeta(ABCMeta): 8 | def __new__(mcs, cls_name, bases, cls_dict, **kwargs): 9 | backend = kwargs["backend"] if "backend" in kwargs else "none" 10 | if bases[0].__name__ == "TfBase" or backend == "tensorflow": 11 | cls_dict["rebuild_model"] = rebuild_tf_model 12 | elif backend == "torch": 13 | cls_dict["rebuild_model"] = rebuild_torch_model 14 | return super().__new__(mcs, cls_name, bases, cls_dict) 15 | -------------------------------------------------------------------------------- /libreco/batch/__init__.py: -------------------------------------------------------------------------------- 1 | from .batch_data import adjust_batch_size, get_batch_loader 2 | from .tf_feed_dicts import get_tf_feeds 3 | 4 | __all__ = ["adjust_batch_size", "get_batch_loader", "get_tf_feeds"] 5 | -------------------------------------------------------------------------------- /libreco/batch/enums.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class FeatType(Enum): 5 | SPARSE = "sparse" 6 | DENSE = "dense" 7 | 8 | 9 | class Backend(Enum): 10 | TF = "tensorflow" 11 | TORCH = "torch" 12 | -------------------------------------------------------------------------------- /libreco/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_info import DataInfo, MultiSparseInfo 2 | from .dataset import DatasetFeat, DatasetPure 3 | from .processing import process_data, split_multi_value 4 | from .split import ( 5 | random_split, 6 | split_by_num, 7 | split_by_num_chrono, 8 | split_by_ratio, 9 | split_by_ratio_chrono, 10 | ) 11 | from .transformed import TransformedEvalSet, TransformedSet 12 | 13 | __all__ = [ 14 | "DatasetPure", 15 | "DatasetFeat", 16 | "DataInfo", 17 | "MultiSparseInfo", 18 | "process_data", 19 | "split_multi_value", 20 | "split_by_num", 21 | "split_by_ratio", 22 | "split_by_num_chrono", 23 | "split_by_ratio_chrono", 24 | "random_split", 25 | "TransformedSet", 26 | "TransformedEvalSet", 27 | ] 28 | -------------------------------------------------------------------------------- 
/libreco/data/consumed.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from collections import OrderedDict, defaultdict 3 | 4 | import numpy as np 5 | 6 | 7 | def interaction_consumed(user_indices, item_indices): 8 | """The underlying rust function will remove consecutive repeated elements.""" 9 | if isinstance(user_indices, np.ndarray): 10 | user_indices = user_indices.tolist() 11 | if isinstance(item_indices, np.ndarray): 12 | item_indices = item_indices.tolist() 13 | 14 | try: 15 | from recfarm import build_consumed_unique 16 | 17 | return build_consumed_unique(user_indices, item_indices) 18 | except ModuleNotFoundError: # pragma: no cover 19 | return _interaction_consumed(user_indices, item_indices) 20 | 21 | 22 | def _interaction_consumed(user_indices, item_indices): # pragma: no cover 23 | user_consumed = defaultdict(list) 24 | item_consumed = defaultdict(list) 25 | for u, i in zip(user_indices, item_indices): 26 | user_consumed[u].append(i) 27 | item_consumed[i].append(u) 28 | return _remove_duplicates(user_consumed, item_consumed) 29 | 30 | 31 | def _remove_duplicates(user_consumed, item_consumed): # pragma: no cover 32 | # keys will preserve order in dict since Python3.7 33 | if sys.version_info[:2] >= (3, 7): 34 | dict_func = dict.fromkeys 35 | else: # pragma: no cover 36 | dict_func = OrderedDict.fromkeys 37 | user_dedup = {u: list(dict_func(items)) for u, items in user_consumed.items()} 38 | item_dedup = {i: list(dict_func(users)) for i, users in item_consumed.items()} 39 | return user_dedup, item_dedup 40 | 41 | 42 | def update_consumed( 43 | user_indices, item_indices, n_users, n_items, old_info, merge_behavior 44 | ): 45 | user_consumed, item_consumed = interaction_consumed(user_indices, item_indices) 46 | if merge_behavior: 47 | user_consumed = _merge_dedup(user_consumed, n_users, old_info.user_consumed) 48 | item_consumed = _merge_dedup(item_consumed, n_items, old_info.item_consumed) 49 | else: 50 | user_consumed = _fill_empty(user_consumed, n_users, old_info.user_consumed) 51 | item_consumed = _fill_empty(item_consumed, n_items, old_info.item_consumed) 52 | return user_consumed, item_consumed 53 | 54 | 55 | def _merge_dedup(new_consumed, num, old_consumed): 56 | result = dict() 57 | for i in range(num): 58 | assert i in new_consumed or i in old_consumed 59 | if i in new_consumed and i in old_consumed: 60 | result[i] = old_consumed[i] + new_consumed[i] 61 | else: 62 | result[i] = new_consumed[i] if i in new_consumed else old_consumed[i] 63 | return result 64 | 65 | 66 | # some users may not appear in new data 67 | def _fill_empty(consumed, num, old_consumed): 68 | return {i: consumed[i] if i in consumed else old_consumed[i] for i in range(num)} 69 | 70 | 71 | # def _remove_first_duplicates(consumed): 72 | # return pd.Series(consumed).drop_duplicates(keep="last").tolist() 73 | -------------------------------------------------------------------------------- /libreco/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .evaluate import evaluate, print_metrics 2 | 3 | __all__ = ["evaluate", "print_metrics"] 4 | -------------------------------------------------------------------------------- /libreco/evaluation/computation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from tqdm import tqdm 4 | 5 | from ..data import TransformedEvalSet 6 | from ..prediction.preprocess import 
convert_id 7 | from ..utils.validate import check_labels 8 | 9 | 10 | def build_eval_transformed_data(model, data, neg_sampling, seed): 11 | if isinstance(data, pd.DataFrame): 12 | assert "user" in data and "item" in data and "label" in data 13 | users = data["user"].tolist() 14 | items = data["item"].tolist() 15 | user_indices, item_indices = convert_id(model, users, items, inner_id=False) 16 | labels = data["label"].to_numpy(dtype=np.float32) 17 | data = TransformedEvalSet(user_indices, item_indices, labels) 18 | if neg_sampling and not data.has_sampled: 19 | num_neg = model.num_neg or 1 if hasattr(model, "num_neg") else 1 20 | data.build_negatives(model.n_items, num_neg, seed=seed) 21 | else: 22 | check_labels(model, data.labels, neg_sampling) 23 | return data 24 | 25 | 26 | def compute_preds(model, data, batch_size): 27 | y_pred = list() 28 | y_label = list() 29 | for i in tqdm(range(0, len(data), batch_size), desc="eval_pointwise"): 30 | user_indices, item_indices, labels = data[i : i + batch_size] 31 | preds = model.predict(user_indices, item_indices, inner_id=True) 32 | y_pred.extend(preds) 33 | y_label.extend(labels) 34 | return y_pred, y_label 35 | 36 | 37 | def compute_probs(model, data, batch_size): 38 | return compute_preds(model, data, batch_size) 39 | 40 | 41 | def compute_recommends(model, users, k, num_batch_users): 42 | y_recommends = dict() 43 | for i in tqdm(range(0, len(users), num_batch_users), desc="eval_listwise"): 44 | batch_users = users[i : i + num_batch_users] 45 | batch_recs = model.recommend_user( 46 | user=batch_users, 47 | n_rec=k, 48 | inner_id=True, 49 | filter_consumed=True, 50 | random_rec=False, 51 | ) 52 | y_recommends.update(batch_recs) 53 | return y_recommends 54 | -------------------------------------------------------------------------------- /libreco/feature/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/massquantity/LibRecommender/9eadf8a3d6901a898630da113b92de48bd2897fb/libreco/feature/__init__.py -------------------------------------------------------------------------------- /libreco/feature/column_mapping.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict, defaultdict 2 | 3 | import numpy as np 4 | 5 | 6 | # format: {column_family_name: {column_name: index}} 7 | # if no such family, default format would be: {column_family_name: {[]: []} 8 | def col_name2index(user_col=None, item_col=None, sparse_col=None, dense_col=None): 9 | name_mapping = defaultdict(OrderedDict) 10 | if sparse_col: 11 | sparse_col_dict = {col: i for i, col in enumerate(sparse_col)} 12 | name_mapping["sparse_col"].update(sparse_col_dict) 13 | if dense_col: 14 | dense_col_dict = {col: i for i, col in enumerate(dense_col)} 15 | name_mapping["dense_col"].update(dense_col_dict) 16 | 17 | if user_col and sparse_col: 18 | user_sparse_col = _extract_common_col(sparse_col, user_col) 19 | for col in user_sparse_col: 20 | name_mapping["user_sparse_col"].update( 21 | {col: name_mapping["sparse_col"][col]} 22 | ) 23 | if user_col and dense_col: 24 | user_dense_col = _extract_common_col(dense_col, user_col) 25 | for col in user_dense_col: 26 | name_mapping["user_dense_col"].update({col: name_mapping["dense_col"][col]}) 27 | 28 | if item_col and sparse_col: 29 | item_sparse_col = _extract_common_col(sparse_col, item_col) 30 | for col in item_sparse_col: 31 | name_mapping["item_sparse_col"].update( 32 | {col: 
name_mapping["sparse_col"][col]} 33 | ) 34 | if item_col and dense_col: 35 | item_dense_col = _extract_common_col(dense_col, item_col) 36 | for col in item_dense_col: 37 | name_mapping["item_dense_col"].update({col: name_mapping["dense_col"][col]}) 38 | 39 | return dict(name_mapping) 40 | 41 | 42 | # `np.intersect1d` will return the sorted common column names, 43 | # but we also want to preserve the original order of common column in col1 and col2 44 | def _extract_common_col(col1, col2): 45 | common_col, indices_in_col1, _ = np.intersect1d( 46 | col1, col2, assume_unique=True, return_indices=True 47 | ) 48 | return common_col[np.lexsort((common_col, indices_in_col1))] 49 | -------------------------------------------------------------------------------- /libreco/feature/ssl.py: -------------------------------------------------------------------------------- 1 | """Feature Generation for Self-Supervised Learning.""" 2 | import numpy as np 3 | from sklearn.metrics import mutual_info_score 4 | 5 | 6 | def get_ssl_features(model, batch_size): 7 | ssl_feats = dict() 8 | rng, n_items = model.data_info.np_rng, model.n_items 9 | replace = False if batch_size < n_items else True 10 | item_indices = rng.choice(n_items, size=batch_size, replace=replace) 11 | feat_indices = model.data_info.item_sparse_unique[item_indices] 12 | # add offset since default embedding has 0 index 13 | sparse_indices = np.hstack( 14 | [np.expand_dims(item_indices + 1, 1), feat_indices + n_items + 1] 15 | ) 16 | feat_num = sparse_indices.shape[1] 17 | mid_point = feat_num // 2 18 | if model.ssl_pattern.startswith("cfm"): 19 | seed_col = rng.integers(feat_num) 20 | left_cols = model.sparse_feat_mutual_info[seed_col] 21 | right_cols = np.setdiff1d(range(feat_num), left_cols) 22 | elif model.ssl_pattern.endswith("complementary"): 23 | random_cols = rng.permutation(feat_num) 24 | left_cols, right_cols = np.split(random_cols, [mid_point]) 25 | else: 26 | left_cols = rng.permutation(feat_num)[:mid_point] 27 | right_cols = rng.permutation(feat_num)[:mid_point] 28 | 29 | left_sparse_indices = sparse_indices.copy() 30 | left_sparse_indices[:, left_cols] = 0 31 | ssl_feats.update({model.ssl_left_sparse_indices: left_sparse_indices}) 32 | right_sparse_indices = sparse_indices.copy() 33 | right_sparse_indices[:, right_cols] = 0 34 | ssl_feats.update({model.ssl_right_sparse_indices: right_sparse_indices}) 35 | 36 | if model.item_dense: 37 | dense_values = model.data_info.item_dense_unique[item_indices] 38 | ssl_feats.update({model.ssl_left_dense_values: dense_values}) 39 | ssl_feats.update({model.ssl_right_dense_values: dense_values}) 40 | return ssl_feats 41 | 42 | 43 | def get_mutual_info(data, data_info): 44 | """Compute mutual information for each pair of item sparse features.""" 45 | item_indices = np.expand_dims(data.item_indices, 1) 46 | feat_indices = data.sparse_indices[:, data_info.item_sparse_col.index] 47 | sparse_indices = np.hstack([item_indices, feat_indices]) 48 | feat_num = sparse_indices.shape[1] 49 | pairwise_mutual_info = np.zeros((feat_num, feat_num)) 50 | # assign self mutual info to impossible value 51 | np.fill_diagonal(pairwise_mutual_info, -1) 52 | for i in range(feat_num): 53 | for j in range(i + 1, feat_num): 54 | mi = mutual_info_score(sparse_indices[:, i], sparse_indices[:, j]) 55 | pairwise_mutual_info[i][j] = pairwise_mutual_info[j][i] = mi 56 | 57 | n = feat_num // 2 58 | topn_mutual_info = np.argsort(pairwise_mutual_info, axis=1)[:, -n:] 59 | return {i: topn_mutual_info[i] for i in range(feat_num)} 60 | 
-------------------------------------------------------------------------------- /libreco/feature/unique.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def construct_unique_feat( 5 | user_indices, 6 | item_indices, 7 | sparse_indices, 8 | dense_values, 9 | col_name_mapping, 10 | unique_feat, 11 | ): 12 | # use mergesort to preserve order 13 | sort_kind = "quicksort" if unique_feat else "mergesort" 14 | user_pos = np.argsort(user_indices, kind=sort_kind) 15 | item_pos = np.argsort(item_indices, kind=sort_kind) 16 | 17 | user_sparse_matrix, item_sparse_matrix = None, None 18 | user_dense_matrix, item_dense_matrix = None, None 19 | if "user_sparse_col" in col_name_mapping: 20 | user_sparse_col = list(col_name_mapping["user_sparse_col"].values()) 21 | user_sparse_matrix = _compress_unique_values( 22 | sparse_indices, user_sparse_col, user_indices, user_pos 23 | ) 24 | if "item_sparse_col" in col_name_mapping: 25 | item_sparse_col = list(col_name_mapping["item_sparse_col"].values()) 26 | item_sparse_matrix = _compress_unique_values( 27 | sparse_indices, item_sparse_col, item_indices, item_pos 28 | ) 29 | if "user_dense_col" in col_name_mapping: 30 | user_dense_col = list(col_name_mapping["user_dense_col"].values()) 31 | user_dense_matrix = _compress_unique_values( 32 | dense_values, user_dense_col, user_indices, user_pos 33 | ) 34 | if "item_dense_col" in col_name_mapping: 35 | item_dense_col = list(col_name_mapping["item_dense_col"].values()) 36 | item_dense_matrix = _compress_unique_values( 37 | dense_values, item_dense_col, item_indices, item_pos 38 | ) 39 | return ( 40 | user_sparse_matrix, 41 | user_dense_matrix, 42 | item_sparse_matrix, 43 | item_dense_matrix, 44 | ) 45 | 46 | 47 | # https://stackoverflow.com/questions/46390376/drop-duplicates-from-structured-numpy-array-python3-x 48 | def _compress_unique_values(orig_val, col, indices, pos): 49 | values = np.take(orig_val, col, axis=1) 50 | values = values.reshape(-1, 1) if orig_val.ndim == 1 else values 51 | indices = indices[pos] 52 | mask = np.empty(len(indices), dtype=bool) 53 | mask[:-1] = indices[:-1] != indices[1:] 54 | mask[-1] = True 55 | mask = pos[mask] 56 | unique_values = values[mask] 57 | assert len(np.unique(indices)) == len(unique_values) 58 | return unique_values 59 | -------------------------------------------------------------------------------- /libreco/graph/__init__.py: -------------------------------------------------------------------------------- 1 | from .from_dgl import ( 2 | build_i2i_homo_graph, 3 | build_subgraphs, 4 | build_u2i_hetero_graph, 5 | check_dgl, 6 | compute_i2i_edge_scores, 7 | compute_u2i_edge_scores, 8 | pairs_from_dgl_graph, 9 | ) 10 | from .neighbor_walk import NeighborWalker, NeighborWalkerDGL 11 | 12 | __all__ = [ 13 | "build_i2i_homo_graph", 14 | "build_subgraphs", 15 | "build_u2i_hetero_graph", 16 | "check_dgl", 17 | "compute_i2i_edge_scores", 18 | "compute_u2i_edge_scores", 19 | "pairs_from_dgl_graph", 20 | "NeighborWalker", 21 | "NeighborWalkerDGL", 22 | ] 23 | -------------------------------------------------------------------------------- /libreco/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .attention import ( 2 | compute_causal_mask, 3 | compute_seq_mask, 4 | din_attention, 5 | multi_head_attention, 6 | tf_attention, 7 | ) 8 | from .convolutional import conv_nn, max_pool 9 | from .dense import dense_nn, shared_dense, tf_dense 10 | from .embedding 
import embedding_lookup, seq_embeds_pooling, sparse_embeds_pooling 11 | from .normalization import layer_normalization, normalize_embeds, rms_norm 12 | from .recurrent import tf_rnn 13 | 14 | __all__ = [ 15 | "compute_causal_mask", 16 | "compute_seq_mask", 17 | "conv_nn", 18 | "dense_nn", 19 | "din_attention", 20 | "embedding_lookup", 21 | "layer_normalization", 22 | "max_pool", 23 | "multi_head_attention", 24 | "normalize_embeds", 25 | "rms_norm", 26 | "shared_dense", 27 | "seq_embeds_pooling", 28 | "sparse_embeds_pooling", 29 | "tf_attention", 30 | "tf_dense", 31 | "tf_rnn", 32 | ] 33 | -------------------------------------------------------------------------------- /libreco/layers/activation.py: -------------------------------------------------------------------------------- 1 | from ..tfops import tf 2 | 3 | 4 | def gelu(x): 5 | return 0.5 * x * (1.0 + tf.math.erf(x / tf.cast(1.4142135623730951, tf.float32))) 6 | 7 | 8 | def swish(x): 9 | return x * tf.sigmoid(x) 10 | -------------------------------------------------------------------------------- /libreco/layers/convolutional.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | from ..tfops import get_tf_version, tf 4 | 5 | 6 | def conv_nn( 7 | filters, kernel_size, strides, padding, activation, dilation_rate=1, version=None 8 | ): 9 | tf_version = get_tf_version(version) 10 | if tf_version >= "2.0.0": 11 | net = tf.keras.layers.Conv1D( 12 | filters=filters, 13 | kernel_size=kernel_size, 14 | strides=strides, 15 | padding=padding, 16 | activation=activation, 17 | dilation_rate=dilation_rate, 18 | ) 19 | else: 20 | net = partial( 21 | tf.layers.conv1d, 22 | filters=filters, 23 | kernel_size=kernel_size, 24 | strides=strides, 25 | padding=padding, 26 | activation=activation, 27 | ) 28 | return net 29 | 30 | 31 | def max_pool(pool_size, strides, padding, version=None): 32 | tf_version = get_tf_version(version) 33 | if tf_version >= "2.0.0": 34 | net = tf.keras.layers.MaxPool1D( 35 | pool_size=pool_size, strides=strides, padding=padding 36 | ) 37 | else: 38 | net = partial( 39 | tf.layers.max_pooling1d, 40 | pool_size=pool_size, 41 | strides=strides, 42 | padding=padding, 43 | ) 44 | return net 45 | -------------------------------------------------------------------------------- /libreco/layers/embedding.py: -------------------------------------------------------------------------------- 1 | from ..tfops import tf 2 | 3 | 4 | def embedding_lookup( 5 | indices, 6 | var_name=None, 7 | var_shape=None, 8 | initializer=None, 9 | regularizer=None, 10 | reuse_layer=None, 11 | embed_var=None, 12 | scope_name="embedding", 13 | ): 14 | reuse = tf.AUTO_REUSE if reuse_layer else None 15 | with tf.variable_scope(scope_name, reuse=reuse): 16 | if embed_var is None: 17 | embed_var = tf.get_variable( 18 | name=var_name, 19 | shape=var_shape, 20 | initializer=initializer, 21 | regularizer=regularizer, 22 | ) 23 | return tf.nn.embedding_lookup(embed_var, indices) 24 | 25 | 26 | def sparse_embeds_pooling( 27 | sparse_indices, 28 | var_name, 29 | var_shape, 30 | initializer, 31 | regularizer=None, 32 | reuse_layer=None, 33 | combiner="sqrtn", 34 | scope_name="sparse_embeds_pooling", 35 | ): 36 | reuse = tf.AUTO_REUSE if reuse_layer else None 37 | with tf.variable_scope(scope_name, reuse=reuse): 38 | embed_var = tf.get_variable( 39 | name=var_name, 40 | shape=var_shape, 41 | initializer=initializer, 42 | regularizer=regularizer, 43 | ) 44 | # unknown user will return 0-vector in 
`safe_embedding_lookup_sparse` 45 | return tf.nn.safe_embedding_lookup_sparse( 46 | embed_var, 47 | sparse_indices, 48 | sparse_weights=None, 49 | combiner=combiner, 50 | default_id=None, 51 | ) 52 | 53 | 54 | def seq_embeds_pooling( 55 | seq_indices, 56 | seq_lens, 57 | n_items, 58 | var_name, 59 | var_shape, 60 | initializer=None, 61 | regularizer=None, 62 | reuse_layer=None, 63 | scope_name="seq_embeds_pooling", 64 | ): 65 | reuse = tf.AUTO_REUSE if reuse_layer else None 66 | with tf.variable_scope(scope_name, reuse=reuse): 67 | embed_var = tf.get_variable( 68 | name=var_name, 69 | shape=var_shape, 70 | initializer=initializer, 71 | regularizer=regularizer, 72 | ) 73 | # unknown items are padded to 0-vector 74 | embed_size = var_shape[1] 75 | zero_padding_op = tf.scatter_update( 76 | embed_var, n_items, tf.zeros([embed_size], dtype=tf.float32) 77 | ) 78 | with tf.control_dependencies([zero_padding_op]): 79 | # B * seq * K 80 | multi_item_embed = tf.nn.embedding_lookup(embed_var, seq_indices) 81 | 82 | return tf.div_no_nan( 83 | tf.reduce_sum(multi_item_embed, axis=1), 84 | tf.expand_dims(tf.sqrt(tf.cast(seq_lens, tf.float32)), axis=1), 85 | ) 86 | -------------------------------------------------------------------------------- /libreco/layers/normalization.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.linalg 3 | 4 | from ..tfops import tf 5 | 6 | 7 | def layer_normalization(inputs, reuse_layer=False, scope_name="layer_norm"): 8 | reuse = tf.AUTO_REUSE if reuse_layer else None 9 | with tf.variable_scope(scope_name, reuse=reuse): 10 | dim = inputs.get_shape().as_list()[-1] 11 | scale = tf.get_variable("scale", shape=[dim], initializer=tf.ones_initializer()) 12 | bias = tf.get_variable("bias", shape=[dim], initializer=tf.zeros_initializer()) 13 | mean = tf.reduce_mean(inputs, axis=-1, keepdims=True) 14 | variance = tf.reduce_mean( 15 | tf.squared_difference(inputs, mean), axis=-1, keepdims=True 16 | ) 17 | outputs = (inputs - mean) * tf.rsqrt(variance + 1e-8) 18 | return outputs * scale + bias 19 | 20 | 21 | def rms_norm(inputs, reuse_layer=False, scope_name="rms_norm"): 22 | """Root mean square layer normalization.""" 23 | reuse = tf.AUTO_REUSE if reuse_layer else None 24 | with tf.variable_scope(scope_name, reuse=reuse): 25 | dim = inputs.get_shape().as_list()[-1] 26 | scale = tf.get_variable("scale", shape=[dim], initializer=tf.ones_initializer()) 27 | mean_square = tf.reduce_mean(tf.square(inputs), axis=-1, keepdims=True) 28 | outputs = inputs * tf.rsqrt(mean_square + 1e-8) 29 | return outputs * scale 30 | 31 | 32 | def normalize_embeds(*embeds, backend): 33 | normed_embeds = [] 34 | for e in embeds: 35 | if backend == "tf": 36 | ne = tf.linalg.l2_normalize(e, axis=1) 37 | elif backend == "torch": 38 | norms = torch.linalg.norm(e, dim=1, keepdim=True) 39 | ne = e / norms 40 | else: 41 | norms = np.linalg.norm(e, axis=1, keepdims=True) 42 | ne = e / norms 43 | normed_embeds.append(ne) 44 | return normed_embeds[0] if len(embeds) == 1 else normed_embeds 45 | -------------------------------------------------------------------------------- /libreco/layers/recurrent.py: -------------------------------------------------------------------------------- 1 | from ..tfops import get_tf_version, tf 2 | 3 | 4 | def tf_rnn( 5 | inputs, 6 | rnn_type, 7 | lengths, 8 | maxlen, 9 | hidden_units, 10 | dropout_rate, 11 | use_ln, 12 | is_training, 13 | version=None, 14 | ): 15 | tf_version = get_tf_version(version) 16 | if 
tf_version >= "2.0.0": 17 | # cell_type = ( 18 | # tf.keras.layers.LSTMCell 19 | # if self.rnn_type.endswith("lstm") 20 | # else tf.keras.layers.GRUCell 21 | # ) 22 | # cells = [cell_type(size) for size in self.hidden_units] 23 | # masks = tf.sequence_mask(self.user_interacted_len, self.max_seq_len) 24 | # tf2_rnn = tf.keras.layers.RNN(cells, return_state=True) 25 | # output, *state = tf2_rnn(seq_item_embed, mask=masks) 26 | 27 | rnn_layer = ( 28 | tf.keras.layers.LSTM if rnn_type.endswith("lstm") else tf.keras.layers.GRU 29 | ) 30 | output = inputs 31 | masks = tf.sequence_mask(lengths, maxlen) 32 | for units in hidden_units: 33 | output = rnn_layer( 34 | units, 35 | return_sequences=True, 36 | dropout=dropout_rate, 37 | recurrent_dropout=dropout_rate, 38 | activation=None if use_ln else "tanh", 39 | )(output, mask=masks, training=is_training) 40 | 41 | if use_ln: 42 | output = tf.keras.layers.LayerNormalization()(output) 43 | output = tf.keras.activations.get("tanh")(output) 44 | 45 | return output[:, -1, :] 46 | 47 | else: 48 | cell_type = ( 49 | tf.nn.rnn_cell.LSTMCell 50 | if rnn_type.endswith("lstm") 51 | else tf.nn.rnn_cell.GRUCell 52 | ) 53 | cells = [cell_type(size) for size in hidden_units] 54 | stacked_cells = tf.nn.rnn_cell.MultiRNNCell(cells) 55 | zero_state = stacked_cells.zero_state(tf.shape(inputs)[0], dtype=tf.float32) 56 | _, state = tf.nn.dynamic_rnn( 57 | cell=stacked_cells, 58 | inputs=inputs, 59 | sequence_length=lengths, 60 | initial_state=zero_state, 61 | time_major=False, 62 | ) 63 | return state[-1][1] if rnn_type == "lstm" else state[-1] 64 | -------------------------------------------------------------------------------- /libreco/prediction/__init__.py: -------------------------------------------------------------------------------- 1 | from .predict import ( 2 | normalize_prediction, 3 | predict_data_with_feats, 4 | predict_from_embedding, 5 | predict_tf_feat, 6 | ) 7 | 8 | __all__ = [ 9 | "predict_data_with_feats", 10 | "predict_from_embedding", 11 | "predict_tf_feat", 12 | "normalize_prediction", 13 | ] 14 | -------------------------------------------------------------------------------- /libreco/recommendation/__init__.py: -------------------------------------------------------------------------------- 1 | from .cold_start import cold_start_rec, popular_recommendations 2 | from .ranking import rank_recommendations 3 | from .recommend import ( 4 | check_dynamic_rec_feats, 5 | construct_rec, 6 | recommend_from_embedding, 7 | recommend_tf_feat, 8 | ) 9 | 10 | __all__ = [ 11 | "check_dynamic_rec_feats", 12 | "cold_start_rec", 13 | "construct_rec", 14 | "popular_recommendations", 15 | "rank_recommendations", 16 | "recommend_from_embedding", 17 | "recommend_tf_feat", 18 | ] 19 | -------------------------------------------------------------------------------- /libreco/recommendation/cold_start.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def popular_recommendations(data_info, inner_id, n_rec): 5 | popular_recs = data_info.np_rng.choice(data_info.popular_items, n_rec) 6 | if inner_id: 7 | return np.array([data_info.item2id[i] for i in popular_recs]) 8 | else: 9 | return popular_recs 10 | 11 | 12 | def average_recommendations(data_info, default_recs, inner_id, n_rec): 13 | average_recs = data_info.np_rng.choice(default_recs, n_rec) 14 | if inner_id: 15 | return average_recs 16 | else: 17 | return np.array([data_info.id2item[i] for i in average_recs]) 18 | 19 | 20 | def 
cold_start_rec(data_info, default_recs, cold_start, users, n_rec, inner_id): 21 | if cold_start not in ("average", "popular"): 22 | raise ValueError(f"Unknown cold start strategy: {cold_start}") 23 | result_recs = dict() 24 | for u in users: 25 | if cold_start == "average": 26 | result_recs[u] = average_recommendations( 27 | data_info, default_recs, inner_id, n_rec 28 | ) 29 | elif cold_start == "popular": 30 | result_recs[u] = popular_recommendations(data_info, inner_id, n_rec) 31 | return result_recs 32 | -------------------------------------------------------------------------------- /libreco/recommendation/ranking.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy.random import default_rng 3 | from scipy.special import expit, softmax 4 | 5 | # Numpy doc states that it is recommended to use new random API 6 | # https://numpy.org/doc/stable/reference/random/index.html 7 | np_rng = default_rng() 8 | 9 | 10 | def rank_recommendations( 11 | task, 12 | user_ids, 13 | model_preds, 14 | n_rec, 15 | n_items, 16 | user_consumed, 17 | filter_consumed=True, 18 | random_rec=False, 19 | return_scores=False, 20 | ): 21 | if n_rec > n_items: 22 | raise ValueError(f"`n_rec` {n_rec} exceeds num of items {n_items}") 23 | if model_preds.ndim == 1: 24 | assert len(model_preds) % n_items == 0 25 | batch_size = int(len(model_preds) / n_items) 26 | all_preds = model_preds.reshape(batch_size, n_items) 27 | else: 28 | batch_size = len(model_preds) 29 | all_preds = model_preds 30 | all_ids = np.tile(np.arange(n_items), (batch_size, 1)) 31 | 32 | batch_ids, batch_preds = [], [] 33 | for i in range(batch_size): 34 | user = user_ids[i] 35 | ids = all_ids[i] 36 | preds = all_preds[i] 37 | consumed = user_consumed[user] if user in user_consumed else [] 38 | if filter_consumed and consumed and n_rec + len(consumed) <= n_items: 39 | ids, preds = filter_items(ids, preds, consumed) 40 | if random_rec: 41 | ids, preds = random_select(ids, preds, n_rec) 42 | else: 43 | ids, preds = partition_select(ids, preds, n_rec) 44 | batch_ids.append(ids) 45 | batch_preds.append(preds) 46 | 47 | ids, preds = np.array(batch_ids), np.array(batch_preds) 48 | indices = np.argsort(preds, axis=1)[:, ::-1] 49 | ids = np.take_along_axis(ids, indices, axis=1) 50 | if return_scores: 51 | scores = np.take_along_axis(preds, indices, axis=1) 52 | if task == "ranking": 53 | scores = expit(scores) 54 | return ids, scores 55 | else: 56 | return ids 57 | 58 | 59 | def filter_items(ids, preds, items): 60 | mask = np.isin(ids, items, assume_unique=True, invert=True) 61 | return ids[mask], preds[mask] 62 | 63 | 64 | # add `**0.75` to lower probability of high score items 65 | def get_reco_probs(preds): 66 | p = np.power(softmax(preds), 0.75) + 1e-8 # avoid zero probs 67 | return p / p.sum() 68 | 69 | 70 | def random_select(ids, preds, n_rec): 71 | p = get_reco_probs(preds) 72 | mask = np_rng.choice(len(preds), n_rec, p=p, replace=False, shuffle=False) 73 | return ids[mask], preds[mask] 74 | 75 | 76 | def partition_select(ids, preds, n_rec): 77 | mask = np.argpartition(preds, -n_rec)[-n_rec:] 78 | return ids[mask], preds[mask] 79 | -------------------------------------------------------------------------------- /libreco/sampling/__init__.py: -------------------------------------------------------------------------------- 1 | from .negatives import ( 2 | neg_probs_from_frequency, 3 | negatives_from_out_batch, 4 | negatives_from_popular, 5 | negatives_from_random, 6 | 
negatives_from_unconsumed, 7 | pos_probs_from_frequency, 8 | ) 9 | from .random_walks import ( 10 | bipartite_neighbors, 11 | bipartite_neighbors_with_weights, 12 | pairs_from_random_walk, 13 | ) 14 | 15 | __all__ = [ 16 | "bipartite_neighbors", 17 | "bipartite_neighbors_with_weights", 18 | "negatives_from_out_batch", 19 | "negatives_from_popular", 20 | "negatives_from_random", 21 | "negatives_from_unconsumed", 22 | "neg_probs_from_frequency", 23 | "pairs_from_random_walk", 24 | "pos_probs_from_frequency", 25 | ] 26 | -------------------------------------------------------------------------------- /libreco/tfops/__init__.py: -------------------------------------------------------------------------------- 1 | from .configs import ( 2 | attention_config, 3 | dropout_config, 4 | lr_decay_config, 5 | reg_config, 6 | sess_config, 7 | ) 8 | from .loss import choose_tf_loss 9 | from .rebuild import rebuild_tf_model 10 | from .variables import get_variable_from_graph, modify_variable_names, var_list_by_name 11 | from .version import TF_VERSION, get_tf_version, tf 12 | 13 | __all__ = [ 14 | "attention_config", 15 | "dropout_config", 16 | "get_variable_from_graph", 17 | "lr_decay_config", 18 | "reg_config", 19 | "sess_config", 20 | "rebuild_tf_model", 21 | "choose_tf_loss", 22 | "modify_variable_names", 23 | "var_list_by_name", 24 | "tf", 25 | "TF_VERSION", 26 | "get_tf_version", 27 | ] 28 | -------------------------------------------------------------------------------- /libreco/tfops/configs.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | 3 | from .version import tf 4 | 5 | 6 | def attention_config(att_embed_size): 7 | if not att_embed_size: 8 | att_embed_size = (8, 8, 8) 9 | att_layer_num = 3 10 | elif isinstance(att_embed_size, int): 11 | att_embed_size = [att_embed_size] 12 | att_layer_num = 1 13 | elif isinstance(att_embed_size, (list, tuple)): 14 | att_layer_num = len(att_embed_size) 15 | else: 16 | raise ValueError("att_embed_size must be int or list") 17 | return att_embed_size, att_layer_num 18 | 19 | 20 | def reg_config(reg): 21 | if not reg: 22 | return None 23 | elif isinstance(reg, float) and reg > 0.0: 24 | return tf.keras.regularizers.l2(reg) 25 | else: 26 | raise ValueError("reg must be float and positive...") 27 | 28 | 29 | def dropout_config(dropout_rate): 30 | if not dropout_rate: 31 | return 0.0 32 | elif dropout_rate <= 0.0 or dropout_rate >= 1.0: 33 | raise ValueError("dropout_rate must be in (0.0, 1.0)") 34 | else: 35 | return dropout_rate 36 | 37 | 38 | def lr_decay_config(initial_lr, default_decay_steps, **kwargs): 39 | decay_steps = kwargs.get("decay_steps", default_decay_steps) 40 | decay_rate = kwargs.get("decay_rate", 0.96) 41 | global_steps = tf.Variable(0, trainable=False, name="global_steps") 42 | learning_rate = tf.train.exponential_decay( 43 | initial_lr, global_steps, decay_steps, decay_rate, staircase=True 44 | ) 45 | return learning_rate, global_steps 46 | 47 | 48 | def sess_config(tf_sess_config=None): 49 | if not tf_sess_config: 50 | # Session config based on: 51 | # https://software.intel.com/content/www/us/en/develop/articles/tips-to-improve-performance-for-popular-deep-learning-frameworks-on-multi-core-cpus.html 52 | # https://github.com/tensorflow/tensorflow/blob/v2.10.0/tensorflow/core/protobuf/config.proto#L452 53 | tf_sess_config = { 54 | "intra_op_parallelism_threads": 0, 55 | "inter_op_parallelism_threads": 0, 56 | "allow_soft_placement": True, 57 | "device_count": {"CPU": 
multiprocessing.cpu_count()}, 58 | } 59 | # os.environ["OMP_NUM_THREADS"] = f"{self.cpu_num}" 60 | 61 | config = tf.ConfigProto(**tf_sess_config) 62 | return tf.Session(config=config) 63 | -------------------------------------------------------------------------------- /libreco/tfops/version.py: -------------------------------------------------------------------------------- 1 | import tensorflow 2 | 3 | tf = tensorflow.compat.v1 4 | tf.disable_v2_behavior() 5 | 6 | TF_VERSION = tf.__version__ 7 | 8 | 9 | def get_tf_version(version): 10 | if version is not None: 11 | assert isinstance(version, str) 12 | return version 13 | else: 14 | return TF_VERSION 15 | -------------------------------------------------------------------------------- /libreco/torchops/__init__.py: -------------------------------------------------------------------------------- 1 | from .configs import device_config, hidden_units_config, set_torch_seed 2 | from .loss import ( 3 | binary_cross_entropy_loss, 4 | bpr_loss, 5 | compute_pair_scores, 6 | focal_loss, 7 | max_margin_loss, 8 | pairwise_bce_loss, 9 | pairwise_focal_loss, 10 | ) 11 | from .rebuild import rebuild_torch_model 12 | 13 | __all__ = [ 14 | "binary_cross_entropy_loss", 15 | "bpr_loss", 16 | "compute_pair_scores", 17 | "device_config", 18 | "hidden_units_config", 19 | "focal_loss", 20 | "max_margin_loss", 21 | "pairwise_bce_loss", 22 | "pairwise_focal_loss", 23 | "rebuild_torch_model", 24 | "set_torch_seed", 25 | ] 26 | -------------------------------------------------------------------------------- /libreco/torchops/configs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def device_config(device): 5 | if device == "cuda" and torch.cuda.is_available(): 6 | return torch.device("cuda") 7 | else: 8 | return torch.device("cpu") 9 | 10 | 11 | def hidden_units_config(hidden_units): 12 | if isinstance(hidden_units, int): 13 | return [hidden_units] 14 | elif not isinstance(hidden_units, (list, tuple)): 15 | raise ValueError( 16 | f"`hidden_units` must be one of (int, list of int, tuple of int), " 17 | f"got: {type(hidden_units)}, {hidden_units}" 18 | ) 19 | for i in hidden_units: 20 | if not isinstance(i, int): 21 | raise ValueError(f"`hidden_units` contains not int value: {hidden_units}") 22 | return list(hidden_units) 23 | 24 | 25 | def set_torch_seed(seed): 26 | torch.manual_seed(seed) 27 | if torch.cuda.is_available(): 28 | torch.cuda.manual_seed(seed) 29 | torch.cuda.manual_seed_all(seed) 30 | # torch.backends.cudnn.deterministic = True 31 | # torch.backends.cudnn.benchmark = False 32 | # torch.use_deterministic_algorithms(True) 33 | -------------------------------------------------------------------------------- /libreco/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/massquantity/LibRecommender/9eadf8a3d6901a898630da113b92de48bd2897fb/libreco/training/__init__.py -------------------------------------------------------------------------------- /libreco/training/dispatch.py: -------------------------------------------------------------------------------- 1 | from .tf_trainer import TensorFlowTrainer, WideDeepTrainer, YoutubeRetrievalTrainer 2 | from .torch_trainer import GraphTrainer, TorchTrainer 3 | from ..utils.constants import SageModels, TfTrainModels 4 | 5 | 6 | def get_trainer(model): 7 | train_params = { 8 | "model": model, 9 | "task": model.task, 10 | "loss_type": model.loss_type, 11 | "n_epochs": 
model.n_epochs, 12 | "lr": model.lr, 13 | "lr_decay": model.lr_decay, 14 | "epsilon": model.epsilon, 15 | "batch_size": model.batch_size, 16 | "sampler": model.sampler, 17 | "num_neg": model.__dict__.get("num_neg"), 18 | } 19 | 20 | if TfTrainModels.contains(model.model_name): 21 | if model.model_name == "YouTubeRetrieval": 22 | train_params["num_sampled_per_batch"] = model.num_sampled_per_batch 23 | tf_trainer_cls = YoutubeRetrievalTrainer 24 | elif model.model_name == "WideDeep": 25 | tf_trainer_cls = WideDeepTrainer 26 | else: 27 | tf_trainer_cls = TensorFlowTrainer 28 | return tf_trainer_cls(**train_params) 29 | else: 30 | train_params.update( 31 | { 32 | "amsgrad": model.amsgrad, 33 | "reg": model.reg, 34 | "margin": model.margin, 35 | "device": model.device, 36 | } 37 | ) 38 | if SageModels.contains(model.model_name): 39 | torch_trainer_cls = GraphTrainer 40 | else: 41 | torch_trainer_cls = TorchTrainer 42 | return torch_trainer_cls(**train_params) 43 | -------------------------------------------------------------------------------- /libreco/training/trainer.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | from ..batch import adjust_batch_size 4 | from ..utils.validate import is_listwise_training 5 | 6 | 7 | class BaseTrainer(abc.ABC): 8 | def __init__( 9 | self, 10 | model, 11 | task, 12 | loss_type, 13 | n_epochs, 14 | lr, 15 | lr_decay, 16 | epsilon, 17 | batch_size, 18 | sampler, 19 | num_neg, 20 | ): 21 | self.model = model 22 | self.task = task 23 | self.loss_type = loss_type 24 | self.n_epochs = n_epochs 25 | self.lr = lr 26 | self.lr_decay = lr_decay 27 | self.epsilon = epsilon 28 | self.batch_size = adjust_batch_size(model, batch_size) 29 | self.sampler = sampler 30 | self.num_neg = num_neg 31 | 32 | def _check_params(self): 33 | if not is_listwise_training(self.model): 34 | n_items = self.model.data_info.n_items 35 | assert 0 < self.num_neg < n_items, ( 36 | f"`num_neg` should be positive and smaller than total items, " 37 | f"got {self.num_neg}, {n_items}" 38 | ) 39 | if self.sampler not in ("random", "unconsumed", "popular"): 40 | raise ValueError( 41 | f"`sampler` must be one of (`random`, `unconsumed`, `popular`), " 42 | f"got {self.sampler}" 43 | ) 44 | 45 | @abc.abstractmethod 46 | def run(self, *args, **kwargs): 47 | raise NotImplementedError 48 | -------------------------------------------------------------------------------- /libreco/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/massquantity/LibRecommender/9eadf8a3d6901a898630da113b92de48bd2897fb/libreco/utils/__init__.py -------------------------------------------------------------------------------- /libreco/utils/constants.py: -------------------------------------------------------------------------------- 1 | from enum import Enum, unique 2 | 3 | 4 | class StrEnum(str, Enum): 5 | @classmethod 6 | def contains(cls, x): 7 | return x in cls.__members__.values() # cls._member_names_ 8 | 9 | 10 | @unique 11 | class FeatModels(StrEnum): 12 | WIDEDEEP = "WideDeep" 13 | FM = "FM" 14 | DEEPFM = "DeepFM" 15 | YOUTUBERETRIEVAL = "YouTubeRetrieval" 16 | YOUTUBERANKING = "YouTubeRanking" 17 | AUTOINT = "AutoInt" 18 | DIN = "DIN" 19 | GRAPHSAGE = "GraphSage" 20 | GRAPHSAGEDGL = "GraphSageDGL" 21 | PINSAGE = "PinSage" 22 | PINSAGEDGL = "PinSageDGL" 23 | TWOTOWER = "TwoTower" 24 | TRANSFORMER = "Transformer" 25 | SIM = "SIM" 26 | 27 | 28 | @unique 29 | class SequenceModels(StrEnum): 
30 | YOUTUBERETRIEVAL = "YouTubeRetrieval" 31 | YOUTUBERANKING = "YouTubeRanking" 32 | DIN = "DIN" 33 | RNN4REC = "RNN4Rec" 34 | CASER = "Caser" 35 | WAVENET = "WaveNet" 36 | TRANSFORMER = "Transformer" 37 | SIM = "SIM" 38 | 39 | 40 | @unique 41 | class TfTrainModels(StrEnum): 42 | SVD = "SVD" 43 | SVDPP = "SVDpp" 44 | NCF = "NCF" 45 | BPR = "BPR" 46 | WIDEDEEP = "WideDeep" 47 | FM = "FM" 48 | DEEPFM = "DeepFM" 49 | YOUTUBERETRIEVAL = "YouTubeRetrieval" 50 | YOUTUBERANKING = "YouTubeRanking" 51 | AUTOINT = "AutoInt" 52 | DIN = "DIN" 53 | RNN4REC = "RNN4Rec" 54 | CASER = "Caser" 55 | WAVENET = "WaveNet" 56 | TWOTOWER = "TwoTower" 57 | TRANSFORMER = "Transformer" 58 | SIM = "SIM" 59 | 60 | 61 | @unique 62 | class EmbeddingModels(StrEnum): 63 | SVD = "SVD" 64 | SVDPP = "SVDpp" 65 | ALS = "ALS" 66 | BPR = "BPR" 67 | YOUTUBERETRIEVAL = "YouTubeRetrieval" 68 | ITEM2VEC = "Item2Vec" 69 | RNN4REC = "RNN4Rec" 70 | CASER = "Caser" 71 | WAVENET = "WaveNet" 72 | DEEPWALK = "DeepWalk" 73 | NGCF = "NGCF" 74 | LIGHTGCN = "LightGCN" 75 | GRAPHSAGE = "GraphSage" 76 | GRAPHSAGEDGL = "GraphSageDGL" 77 | PINSAGE = "PinSage" 78 | PINSAGEDGL = "PinSageDGL" 79 | TWOTOWER = "TwoTower" 80 | 81 | 82 | @unique 83 | class SageModels(StrEnum): 84 | GRAPHSAGE = "GraphSage" 85 | GRAPHSAGEDGL = "GraphSageDGL" 86 | PINSAGE = "PinSage" 87 | PINSAGEDGL = "PinSageDGL" 88 | 89 | 90 | @unique 91 | class UserEmbedModels(StrEnum): 92 | """Models can only generate user embeddings dynamically.""" 93 | 94 | YOUTUBERETRIEVAL = "YouTubeRetrieval" 95 | RNN4REC = "RNN4Rec" 96 | CASER = "Caser" 97 | WAVENET = "WaveNet" 98 | -------------------------------------------------------------------------------- /libreco/utils/exception.py: -------------------------------------------------------------------------------- 1 | class NotSamplingError(Exception): 2 | """Exception related to sampling data 3 | 4 | If client wants to use batch_sampling and then evaluation on the dataset, 5 | but forgot to do whole data sampling beforehand, this exception will be 6 | raised. Because in this case, unsampled data can't be evaluated. 
7 | """ 8 | 9 | pass 10 | -------------------------------------------------------------------------------- /libreco/utils/initializers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def truncated_normal( 5 | np_rng: np.random.Generator, 6 | shape, 7 | mean=0.0, 8 | scale=0.05, 9 | tolerance=5, 10 | ): 11 | # total_num = np.multiply(*shape) 12 | total_num = shape if len(shape) == 1 else np.multiply(*shape) 13 | array = np_rng.normal(mean, scale, total_num).astype(np.float32) 14 | upper_limit, lower_limit = mean + 2 * scale, mean - 2 * scale 15 | for _ in range(tolerance): 16 | index = np.logical_or((array > upper_limit), (array < lower_limit)) 17 | num = len(np.where(index)[0]) 18 | if num == 0: 19 | break 20 | array[index] = np_rng.normal(mean, scale, num) 21 | return array.reshape(*shape) 22 | 23 | 24 | def xavier_init(np_rng, fan_in, fan_out): 25 | std = np.sqrt(2.0 / (fan_in + fan_out)) 26 | return truncated_normal(np_rng, mean=0.0, scale=std, shape=[fan_in, fan_out]) 27 | 28 | 29 | def he_init(np_rng, fan_in, fan_out): 30 | std = 2.0 / np.sqrt(fan_in + fan_out) 31 | # std = np.sqrt(2.0 / fan_in) 32 | return truncated_normal(np_rng, mean=0.0, scale=std, shape=[fan_in, fan_out]) 33 | 34 | 35 | def variance_scaling(np_rng, scale, fan_in=None, fan_out=None, mode="fan_in"): 36 | """ 37 | xavier: mode = "fan_average", scale = 1.0 38 | he: mode = "fan_in", scale = 2.0 39 | he2: mode = "fan_average", scale = 2.0 40 | """ 41 | if mode == "fan_in": 42 | std = np.sqrt(scale / fan_in) 43 | elif mode == "fan_out": 44 | std = np.sqrt(scale / fan_out) 45 | elif mode == "fan_average": 46 | std = np.sqrt(2.0 * scale / (fan_in + fan_out)) 47 | else: 48 | raise ValueError("mode must be one of these: fan_in, fan_out, fan_average") 49 | return truncated_normal(np_rng, mean=0.0, scale=std, shape=[fan_in, fan_out]) 50 | -------------------------------------------------------------------------------- /libreco/utils/sparse.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List 3 | 4 | from scipy.sparse import csr_matrix 5 | 6 | 7 | @dataclass 8 | class SparseMatrix: 9 | sparse_indices: List[int] 10 | sparse_indptr: List[int] 11 | sparse_data: List[float] 12 | 13 | 14 | def build_sparse(matrix: csr_matrix, transpose: bool = False): 15 | m = matrix.T.tocsr() if transpose else matrix 16 | return SparseMatrix( 17 | m.indices.tolist(), 18 | m.indptr.tolist(), 19 | m.data.tolist(), 20 | ) 21 | -------------------------------------------------------------------------------- /libserving/.dockerignore: -------------------------------------------------------------------------------- 1 | */*/bin 2 | */target 3 | */request* 4 | */*/utils/*22* -------------------------------------------------------------------------------- /libserving/Dockerfile-py: -------------------------------------------------------------------------------- 1 | FROM python:3.7-slim 2 | 3 | WORKDIR /app 4 | 5 | RUN pip install --no-cache-dir numpy==1.19.5 -i https://pypi.tuna.tsinghua.edu.cn/simple 6 | RUN pip install --no-cache-dir sanic==22.6.2 -i https://pypi.tuna.tsinghua.edu.cn/simple 7 | RUN pip install --no-cache-dir aiohttp==3.8.1 -i https://pypi.tuna.tsinghua.edu.cn/simple 8 | RUN pip install --no-cache-dir pydantic==1.9.1 -i https://pypi.tuna.tsinghua.edu.cn/simple 9 | RUN pip install --no-cache-dir ujson==5.4.0 -i https://pypi.tuna.tsinghua.edu.cn/simple 10 | RUN pip install 
--no-cache-dir redis==4.3.4 -i https://pypi.tuna.tsinghua.edu.cn/simple 11 | RUN pip install --no-cache-dir faiss-cpu==1.7.2 -i https://pypi.tuna.tsinghua.edu.cn/simple 12 | 13 | COPY sanic_serving /app/sanic_serving 14 | 15 | ENV PYTHONPATH=/app 16 | 17 | EXPOSE 8000 18 | -------------------------------------------------------------------------------- /libserving/Dockerfile-rs: -------------------------------------------------------------------------------- 1 | FROM debian:bullseye-slim AS faiss-builder 2 | 3 | WORKDIR /cmake 4 | # install blas used in faiss 5 | RUN echo "deb http://mirrors.tuna.tsinghua.edu.cn/debian/ bullseye main contrib non-free" >/etc/apt/sources.list && \ 6 | apt-get update && \ 7 | apt-get install -y gcc g++ make wget git libblas-dev liblapack-dev && \ 8 | apt-get clean && \ 9 | rm -rf /var/lib/apt/lists/* 10 | 11 | # install cmake to build faiss 12 | RUN wget https://cmake.org/files/LatestRelease/cmake-3.25.0-linux-x86_64.tar.gz 13 | RUN tar -zxf cmake-3.25.0-linux-x86_64.tar.gz -C /cmake --strip-components 1 14 | RUN ln -s /cmake/bin/cmake /usr/bin/cmake 15 | RUN cmake --version 16 | 17 | WORKDIR /faiss 18 | # clone branch `c_api_head` in faiss repository 19 | RUN git clone -b c_api_head https://github.com/Enet4/faiss.git . 20 | # COPY ./faiss /faiss 21 | RUN cmake -B build . \ 22 | -DFAISS_ENABLE_C_API=ON \ 23 | -DBUILD_SHARED_LIBS=ON \ 24 | -DCMAKE_BUILD_TYPE=Release \ 25 | -DFAISS_ENABLE_GPU=OFF \ 26 | -DFAISS_ENABLE_PYTHON=OFF \ 27 | -DBUILD_TESTING=OFF 28 | RUN make -C build/c_api 29 | 30 | FROM rust:1.64-slim-bullseye AS rust-builder 31 | 32 | WORKDIR /serving_build 33 | 34 | RUN echo "deb http://mirrors.tuna.tsinghua.edu.cn/debian/ bullseye main contrib non-free" >/etc/apt/sources.list && \ 35 | apt-get update && \ 36 | apt-get install -y libblas-dev liblapack-dev && \ 37 | apt-get clean && \ 38 | rm -rf /var/lib/apt/lists/* 39 | 40 | # cache crate index 41 | COPY crate-index-config /usr/local/cargo/config 42 | RUN cargo init 43 | COPY actix_serving/Cargo.toml actix_serving/Cargo.lock /serving_build/ 44 | RUN cargo fetch 45 | 46 | COPY actix_serving/src /serving_build/src 47 | COPY --from=faiss-builder /faiss/build/c_api/libfaiss_c.so /usr/lib 48 | COPY --from=faiss-builder /faiss/build/faiss/libfaiss.so /usr/lib 49 | ENV LD_LIBRARY_PATH=/usr/lib 50 | RUN cargo build --release 51 | 52 | FROM debian:bullseye-slim 53 | 54 | WORKDIR /app 55 | 56 | # need gcc & blas for faiss 57 | RUN echo "deb http://mirrors.tuna.tsinghua.edu.cn/debian/ bullseye main contrib non-free" >/etc/apt/sources.list && \ 58 | apt-get update && \ 59 | apt-get install -y gcc libblas-dev liblapack-dev && \ 60 | apt-get clean && \ 61 | rm -rf /var/lib/apt/lists/* 62 | 63 | COPY --from=faiss-builder /faiss/build/c_api/libfaiss_c.so /usr/lib 64 | COPY --from=faiss-builder /faiss/build/faiss/libfaiss.so /usr/lib 65 | ENV LD_LIBRARY_PATH=/usr/lib 66 | 67 | COPY --from=rust-builder /serving_build/target/release/actix_serving /app 68 | 69 | USER 1001 70 | 71 | EXPOSE 8080 72 | 73 | CMD ["/app/actix_serving"] 74 | -------------------------------------------------------------------------------- /libserving/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.5.1" 2 | -------------------------------------------------------------------------------- /libserving/actix_serving/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "actix_serving" 3 | version = "1.2.0" 4 | 
edition = "2021" 5 | description = "Online model serving for LibRecommender" 6 | license = "MIT" 7 | 8 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 9 | 10 | [dependencies] 11 | actix-web = "4.3" 12 | clap = { version = "4.1.8", features = ["derive"] } 13 | deadpool-redis = "0.11.1" 14 | env_logger = "0.9.2" 15 | faiss = "0.11.0" 16 | fnv = "1.0.7" 17 | futures = "0.3.25" 18 | log = "0.4.15" 19 | num_cpus = "1.0" 20 | once_cell = "1.18" 21 | prost = "0.11" 22 | # openssl = { version = "0.10", features = ["vendored"] } 23 | redis = { version = "0.22.1", features = ["connection-manager", "tokio-comp"] } 24 | reqwest = { version = "0.11", default-features = false, features = ["blocking", "json", "rustls-tls"] } 25 | serde = { version = "1.0", features = ["derive"] } 26 | serde_json = "1.0" 27 | serde_with = "3.0" 28 | thiserror = "1.0" 29 | tokio = { version = "1.0", features = ["macros", "rt-multi-thread"] } 30 | tonic = "0.9" 31 | walkdir = "2.3.2" 32 | 33 | [build-dependencies] 34 | tonic-build = "0.9" 35 | 36 | [dev-dependencies] 37 | assert_cmd = "2.0.8" 38 | pretty_assertions = "1.0" 39 | 40 | [profile.release] 41 | strip = true 42 | opt-level = 3 43 | lto = true 44 | codegen-units = 1 45 | -------------------------------------------------------------------------------- /libserving/actix_serving/build.rs: -------------------------------------------------------------------------------- 1 | fn main() -> Result<(), Box> { 2 | tonic_build::compile_protos("proto/recommend.proto")?; 3 | 4 | tonic_build::configure() 5 | .build_server(false) 6 | .compile( 7 | &["proto/tensorflow_serving/apis/prediction_service.proto"], 8 | &["proto"], 9 | )?; 10 | 11 | Ok(()) 12 | } 13 | -------------------------------------------------------------------------------- /libserving/actix_serving/proto/recommend.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package recommend; 4 | 5 | service Recommend { 6 | rpc GetRecommendation(RecRequest) returns (RecResponse); 7 | } 8 | 9 | message RecRequest { 10 | string user = 1; 11 | int32 n_rec = 2; 12 | map user_feats = 3; 13 | repeated int32 seq = 4; 14 | } 15 | 16 | message RecResponse { 17 | repeated string items = 1; 18 | } 19 | 20 | message Feature { 21 | oneof value { 22 | string string_val = 1; 23 | int32 int_val = 2; 24 | float float_val = 3; 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /libserving/actix_serving/proto/tensorflow/core/framework/resource_handle.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorflow; 4 | 5 | import "tensorflow/core/framework/tensor_shape.proto"; 6 | import "tensorflow/core/framework/types.proto"; 7 | 8 | option cc_enable_arenas = true; 9 | option java_outer_classname = "ResourceHandle"; 10 | option java_multiple_files = true; 11 | option java_package = "org.tensorflow.framework"; 12 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/resource_handle_go_proto"; 13 | 14 | // Protocol buffer representing a handle to a tensorflow resource. Handles are 15 | // not valid across executions, but can be serialized back and forth from within 16 | // a single run. 17 | message ResourceHandleProto { 18 | // Unique name for the device containing the resource. 19 | string device = 1; 20 | 21 | // Container in which this resource is placed. 
22 | string container = 2; 23 | 24 | // Unique name of this resource. 25 | string name = 3; 26 | 27 | // Hash code for the type of the resource. Is only valid in the same device 28 | // and in the same execution. 29 | uint64 hash_code = 4; 30 | 31 | // For debug-only, the name of the type pointed to by this handle, if 32 | // available. 33 | string maybe_type_name = 5; 34 | 35 | // Protocol buffer representing a pair of (data type, tensor shape). 36 | message DtypeAndShape { 37 | DataType dtype = 1; 38 | TensorShapeProto shape = 2; 39 | } 40 | 41 | // Data types and shapes for the underlying resource. 42 | repeated DtypeAndShape dtypes_and_shapes = 6; 43 | 44 | reserved 7; 45 | } -------------------------------------------------------------------------------- /libserving/actix_serving/proto/tensorflow/core/framework/tensor_shape.proto: -------------------------------------------------------------------------------- 1 | // Protocol buffer representing the shape of tensors. 2 | 3 | syntax = "proto3"; 4 | 5 | option cc_enable_arenas = true; 6 | option java_outer_classname = "TensorShapeProtos"; 7 | option java_multiple_files = true; 8 | option java_package = "org.tensorflow.framework"; 9 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/tensor_shape_go_proto"; 10 | 11 | package tensorflow; 12 | 13 | // Dimensions of a tensor. 14 | message TensorShapeProto { 15 | // One dimension of the tensor. 16 | message Dim { 17 | // Size of the tensor in that dimension. 18 | // This value must be >= -1, but values of -1 are reserved for "unknown" 19 | // shapes (values of -1 mean "unknown" dimension). Certain wrappers 20 | // that work with TensorShapeProto may fail at runtime when deserializing 21 | // a TensorShapeProto containing a dim value of -1. 22 | int64 size = 1; 23 | 24 | // Optional name of the tensor dimension. 25 | string name = 2; 26 | }; 27 | 28 | // Dimensions of the tensor, such as {"input", 30}, {"output", 40} 29 | // for a 30 x 40 2D tensor. If an entry has size -1, this 30 | // corresponds to a dimension of unknown size. The names are 31 | // optional. 32 | // 33 | // The order of entries in "dim" matters: It indicates the layout of the 34 | // values in the tensor in-memory representation. 35 | // 36 | // The first entry in "dim" is the outermost dimension used to layout the 37 | // values, the last entry is the innermost dimension. This matches the 38 | // in-memory layout of RowMajor Eigen tensors. 39 | // 40 | // If "dim.size()" > 0, "unknown_rank" must be false. 41 | repeated Dim dim = 2; 42 | 43 | // If true, the number of dimensions in the shape is unknown. 44 | // 45 | // If true, "dim.size()" must be 0. 46 | bool unknown_rank = 3; 47 | }; -------------------------------------------------------------------------------- /libserving/actix_serving/proto/tensorflow_serving/apis/model.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorflow.serving; 4 | option cc_enable_arenas = true; 5 | 6 | import "google/protobuf/wrappers.proto"; 7 | 8 | // Metadata for an inference request such as the model name and version. 9 | message ModelSpec { 10 | // Required servable name. 11 | string name = 1; 12 | 13 | // Optional choice of which version of the model to use. 14 | // 15 | // Recommended to be left unset in the common case. Should be specified only 16 | // when there is a strong version consistency requirement. 
17 | // 18 | // When left unspecified, the system will serve the best available version. 19 | // This is typically the latest version, though during version transitions, 20 | // notably when serving on a fleet of instances, may be either the previous or 21 | // new version. 22 | oneof version_choice { 23 | // Use this specific version number. 24 | google.protobuf.Int64Value version = 2; 25 | 26 | // Use the version associated with the given label. 27 | string version_label = 4; 28 | } 29 | 30 | // A named signature to evaluate. If unspecified, the default signature will 31 | // be used. 32 | string signature_name = 3; 33 | } -------------------------------------------------------------------------------- /libserving/actix_serving/proto/tensorflow_serving/apis/predict.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorflow.serving; 4 | 5 | import "tensorflow/core/framework/tensor.proto"; 6 | import "tensorflow_serving/apis/model.proto"; 7 | 8 | option cc_enable_arenas = true; 9 | 10 | // PredictRequest specifies which TensorFlow model to run, as well as 11 | // how inputs are mapped to tensors and how outputs are filtered before 12 | // returning to user. 13 | message PredictRequest { 14 | // Model Specification. If version is not specified, will use the latest 15 | // (numerical) version. 16 | ModelSpec model_spec = 1; 17 | 18 | // Input tensors. 19 | // Names of input tensor are alias names. The mapping from aliases to real 20 | // input tensor names is stored in the SavedModel export as a prediction 21 | // SignatureDef under the 'inputs' field. 22 | map inputs = 2; 23 | 24 | // Output filter. 25 | // Names specified are alias names. The mapping from aliases to real output 26 | // tensor names is stored in the SavedModel export as a prediction 27 | // SignatureDef under the 'outputs' field. 28 | // Only tensors specified here will be run/fetched and returned, with the 29 | // exception that when none is specified, all tensors specified in the 30 | // named signature will be run/fetched and returned. 31 | repeated string output_filter = 3; 32 | 33 | // Reserved field 4. 34 | reserved 4; 35 | } 36 | 37 | // Response for PredictRequest on successful run. 38 | message PredictResponse { 39 | // Effective Model Specification used to process PredictRequest. 40 | ModelSpec model_spec = 2; 41 | 42 | // Output tensors. 43 | map outputs = 1; 44 | } -------------------------------------------------------------------------------- /libserving/actix_serving/proto/tensorflow_serving/apis/prediction_service.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorflow.serving; 4 | 5 | option cc_enable_arenas = true; 6 | 7 | import "tensorflow_serving/apis/predict.proto"; 8 | 9 | // open source marker; do not remove 10 | // PredictionService provides access to machine-learned models loaded by 11 | // model_servers. 12 | service PredictionService { 13 | // Predict -- provides access to loaded TensorFlow model. 
14 | rpc Predict(PredictRequest) returns (PredictResponse); 15 | } 16 | -------------------------------------------------------------------------------- /libserving/actix_serving/rustfmt.toml: -------------------------------------------------------------------------------- 1 | array_width = 50 2 | chain_width = 50 3 | edition = "2021" 4 | group_imports = "StdExternalCrate" 5 | imports_granularity = "Module" 6 | max_width = 100 7 | reorder_imports = true 8 | reorder_modules = true 9 | use_field_init_shorthand = true 10 | -------------------------------------------------------------------------------- /libserving/actix_serving/src/bin/realtime.rs: -------------------------------------------------------------------------------- 1 | use actix_web::{middleware::Logger, web, App, HttpServer}; 2 | use once_cell::sync::Lazy; 3 | 4 | use actix_serving::common::get_env; 5 | use actix_serving::online_serving; 6 | use actix_serving::redis_ops::create_redis_pool; 7 | use actix_serving::tf_deploy::{init_tf_state, TfAppState}; 8 | 9 | static TF_STATE: Lazy> = Lazy::new(|| web::Data::new(init_tf_state())); 10 | 11 | #[actix_web::main] 12 | async fn main() -> std::io::Result<()> { 13 | let (redis_host, port, workers, log_level) = get_env()?; 14 | std::env::set_var("RUST_LOG", log_level); 15 | env_logger::init(); 16 | 17 | let redis_pool = web::Data::new(create_redis_pool(redis_host)?); 18 | HttpServer::new(move || { 19 | App::new() 20 | .wrap(Logger::default()) 21 | .app_data(redis_pool.clone()) 22 | .app_data(web::Data::clone(&TF_STATE)) 23 | .service(online_serving) 24 | }) 25 | .workers(workers) 26 | .bind(("0.0.0.0", port))? 27 | .run() 28 | .await 29 | } 30 | -------------------------------------------------------------------------------- /libserving/actix_serving/src/bin/realtime_grpc_client.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | 3 | use actix_serving::online_deploy_grpc::recommend_proto::recommend_client::RecommendClient; 4 | use actix_serving::online_deploy_grpc::recommend_proto::{ 5 | feature::Value as FeatValue, Feature, RecRequest, 6 | }; 7 | 8 | #[tokio::main] 9 | async fn main() -> Result<(), Box> { 10 | let mut client = RecommendClient::connect("http://[::1]:50051").await?; 11 | 12 | let user = String::from("1"); 13 | let n_rec = 11; 14 | let feature_sparse = ( 15 | String::from("sex"), 16 | Feature { 17 | value: Some(FeatValue::StringVal(String::from("F"))), 18 | }, 19 | ); 20 | let feature_dense = ( 21 | String::from("age"), 22 | Feature { 23 | value: Some(FeatValue::IntVal(33)), 24 | }, 25 | ); 26 | let request = RecRequest { 27 | user, 28 | n_rec, 29 | user_feats: HashMap::from([feature_sparse, feature_dense]), 30 | seq: vec![1, 2, 3], 31 | }; 32 | 33 | let response = client 34 | .get_recommendation(tonic::Request::new(request)) 35 | .await?; 36 | 37 | println!("rec for user: {:?}", response.into_inner().items); 38 | Ok(()) 39 | } 40 | -------------------------------------------------------------------------------- /libserving/actix_serving/src/bin/realtime_grpc_server.rs: -------------------------------------------------------------------------------- 1 | use tonic::transport::Server; 2 | 3 | use actix_serving::online_deploy_grpc::recommend_proto::recommend_server::RecommendServer; 4 | use actix_serving::online_deploy_grpc::RecommendService; 5 | use actix_serving::redis_ops; 6 | 7 | #[tokio::main(worker_threads = 4)] 8 | async fn main() -> Result<(), Box> { 9 | std::env::set_var("RUST_LOG", "info"); 10 | 
env_logger::init(); 11 | 12 | let addr = "[::1]:50051".parse()?; 13 | let redis_pool = redis_ops::create_redis_pool(String::from("127.0.0.1")) 14 | .expect("Failed to connect to redis pool"); 15 | let service = RecommendService { redis_pool }; 16 | 17 | Server::builder() 18 | .add_service(RecommendServer::new(service)) 19 | .serve(addr) 20 | .await?; 21 | 22 | Ok(()) 23 | } 24 | -------------------------------------------------------------------------------- /libserving/actix_serving/src/lib.rs: -------------------------------------------------------------------------------- 1 | pub mod embed_deploy; 2 | pub mod knn_deploy; 3 | pub mod online_deploy; 4 | #[allow(clippy::too_many_arguments)] 5 | pub mod online_deploy_grpc; 6 | pub mod tf_deploy; 7 | pub mod utils; 8 | 9 | pub use embed_deploy::embed_serving; 10 | pub use knn_deploy::knn_serving; 11 | pub use online_deploy::online_serving; 12 | pub use tf_deploy::tf_serving; 13 | pub use utils::common; 14 | pub use utils::constants; 15 | pub use utils::errors; 16 | pub use utils::faiss; 17 | pub use utils::features; 18 | pub use utils::redis_ops; 19 | -------------------------------------------------------------------------------- /libserving/actix_serving/src/utils/common.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | 3 | use serde::{Deserialize, Serialize}; 4 | use serde_json::Value; 5 | 6 | use crate::errors::ServingError; 7 | 8 | #[derive(Serialize, Deserialize)] 9 | pub struct Payload { 10 | pub user: String, 11 | pub n_rec: usize, 12 | } 13 | 14 | #[derive(Serialize, Deserialize)] 15 | pub struct RealtimePayload { 16 | pub user: String, 17 | pub n_rec: usize, 18 | pub user_feats: Option>, 19 | pub seq: Option>, 20 | } 21 | 22 | #[derive(Debug, Serialize, Deserialize)] 23 | pub struct Recommendation { 24 | pub rec_list: Vec, 25 | } 26 | 27 | #[derive(Deserialize)] 28 | pub struct Prediction { 29 | pub outputs: Vec, 30 | } 31 | 32 | #[derive(Debug, Deserialize)] 33 | pub struct RankedItems { 34 | pub outputs: Vec, 35 | } 36 | 37 | pub fn get_env() -> Result<(String, u16, usize, String), ServingError> { 38 | let host = std::env::var("REDIS_HOST").unwrap_or_else(|_| String::from("127.0.0.1")); 39 | let port = std::env::var("PORT") 40 | .map_err(|e| ServingError::EnvError(e, "PORT"))? 
41 | .parse::()?; 42 | let workers = std::env::var("WORKERS").map_or(Ok(4), |w| w.parse::())?; 43 | let log_level = std::env::var("RUST_LOG").unwrap_or_else(|_| String::from("info")); 44 | Ok((host, port, workers, log_level)) 45 | } 46 | -------------------------------------------------------------------------------- /libserving/actix_serving/src/utils/constants.rs: -------------------------------------------------------------------------------- 1 | use crate::redis_ops::RedisFeatKeys; 2 | 3 | pub const KNN_MODELS: [&'static str; 2] = ["UserCF", "ItemCF"]; 4 | 5 | pub const EMBED_MODELS: [&'static str; 14] = [ 6 | "SVD", 7 | "SVDpp", 8 | "ALS", 9 | "BPR", 10 | "YouTubeRetrieval", 11 | "Item2Vec", 12 | "RNN4Rec", 13 | "Caser", 14 | "WaveNet", 15 | "DeepWalk", 16 | "NGCF", 17 | "LightGCN", 18 | "PinSage", 19 | "PinSageDGL", 20 | ]; 21 | 22 | pub const CROSS_FEAT_MODELS: [&'static str; 6] = [ 23 | "WideDeep", 24 | "FM", 25 | "DeepFM", 26 | "YouTubeRanking", 27 | "AutoInt", 28 | "DIN", 29 | ]; 30 | 31 | pub const SEQ_EMBED_MODELS: [&'static str; 3] = ["RNN4Rec", "Caser", "WaveNet"]; 32 | 33 | pub const USER_ID_EMBED_MODELS: [&'static str; 2] = ["Caser", "WaveNet"]; 34 | 35 | pub const SEPARATE_FEAT_MODELS: [&'static str; 1] = ["TwoTower"]; 36 | 37 | pub const SPARSE_SEQ_MODELS: [&'static str; 1] = ["YouTubeRetrieval"]; 38 | 39 | pub const CROSS_SEQ_MODELS: [&'static str; 2] = ["YouTubeRanking", "DIN"]; 40 | 41 | pub const SPARSE_REDIS_KEYS: RedisFeatKeys = RedisFeatKeys { 42 | user_index: "user_sparse_col_index", 43 | item_index: "item_sparse_col_index", 44 | user_value: "user_sparse_values", 45 | item_value: "item_sparse_values", 46 | }; 47 | 48 | pub const DENSE_REDIS_KEYS: RedisFeatKeys = RedisFeatKeys { 49 | user_index: "user_dense_col_index", 50 | item_index: "item_dense_col_index", 51 | user_value: "user_dense_values", 52 | item_value: "item_dense_values", 53 | }; 54 | -------------------------------------------------------------------------------- /libserving/actix_serving/src/utils/errors.rs: -------------------------------------------------------------------------------- 1 | use actix_web::http::StatusCode; 2 | use actix_web::HttpResponse; 3 | 4 | pub type ServingResult = std::result::Result; 5 | 6 | #[derive(thiserror::Error, Debug)] 7 | pub enum ServingError { 8 | #[error("error: failed to get environment variable `{1}`")] 9 | EnvError(#[source] std::env::VarError, &'static str), 10 | #[error("faiss error: {0}")] 11 | FaissError(#[source] faiss::error::Error), 12 | #[error(transparent)] 13 | IoError(#[from] std::io::Error), 14 | #[error(transparent)] 15 | JsonParseError(#[from] serde_json::Error), 16 | #[error("error: `{0}` doesn't exist in redis")] 17 | NotExist(&'static str), 18 | #[error("error: `{0}` not found")] 19 | NotFound(&'static str), 20 | #[error("error: {0}")] 21 | Other(&'static str), 22 | #[error(transparent)] 23 | ParseError(#[from] std::num::ParseIntError), 24 | #[error("error: redis error, {0}")] 25 | RedisError(#[from] redis::RedisError), 26 | #[error("error: failed to create redis pool, {0}")] 27 | RedisCreatePoolError(#[from] deadpool_redis::CreatePoolError), 28 | #[error("error: failed to get redis pool, {0}")] 29 | RedisGetPoolError(#[from] deadpool_redis::PoolError), 30 | #[error("error: failed to execute tokio blocking task, {0}")] 31 | TaskError(#[from] tokio::task::JoinError), 32 | #[error("error: failed to get prediction from tf serving, {0}")] 33 | TfServingError(#[from] reqwest::Error), 34 | #[error("error: request timeout")] 35 | Timeout, 36 | 
#[error("error: unknown model `{0}`")] 37 | UnknownModel(String), 38 | } 39 | 40 | impl actix_web::error::ResponseError for ServingError { 41 | fn status_code(&self) -> StatusCode { 42 | match *self { 43 | ServingError::NotExist(_) => StatusCode::BAD_REQUEST, 44 | ServingError::Timeout => StatusCode::REQUEST_TIMEOUT, 45 | ServingError::TfServingError(_) => StatusCode::GATEWAY_TIMEOUT, 46 | _ => StatusCode::INTERNAL_SERVER_ERROR, 47 | } 48 | } 49 | 50 | fn error_response(&self) -> HttpResponse { 51 | HttpResponse::build(self.status_code()).body(self.to_string()) 52 | } 53 | } 54 | 55 | impl From for std::io::Error { 56 | fn from(e: ServingError) -> Self { 57 | std::io::Error::new(std::io::ErrorKind::Other, e.to_string()) 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /libserving/actix_serving/src/utils/faiss.rs: -------------------------------------------------------------------------------- 1 | use std::path::{Component::RootDir, Path, PathBuf}; 2 | 3 | use walkdir::WalkDir; 4 | 5 | use crate::errors::{ServingError, ServingResult}; 6 | 7 | pub(crate) fn find_index_path(path: Option) -> ServingResult { 8 | let cur_dir = path.map_or(std::env::current_dir()?, PathBuf::from); 9 | // search in two level parent directory 10 | let dual_parent = cur_dir 11 | .parent() 12 | .and_then(|p| p.parent()) 13 | .unwrap_or_else(|| Path::new(RootDir.as_os_str())); 14 | let walk_dirs = WalkDir::new(dual_parent) 15 | .into_iter() 16 | .filter_map(|d| d.ok()); 17 | for entry in walk_dirs { 18 | let file_name = entry.file_name().to_string_lossy(); 19 | if file_name.starts_with("faiss_index") && !entry.path().is_dir() { 20 | log::info!("Found faiss index in {}", entry.path().display()); 21 | return Ok(entry.path().to_string_lossy().into_owned()); 22 | } 23 | } 24 | Err(ServingError::NotFound("faiss index")) 25 | } 26 | -------------------------------------------------------------------------------- /libserving/actix_serving/src/utils/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod common; 2 | #[allow(clippy::redundant_static_lifetimes)] 3 | pub mod constants; 4 | pub mod errors; 5 | pub mod faiss; 6 | pub mod features; 7 | pub mod redis_ops; 8 | -------------------------------------------------------------------------------- /libserving/actix_serving/tests/common/mod.rs: -------------------------------------------------------------------------------- 1 | use std::process::Command; 2 | use std::time; 3 | 4 | use assert_cmd::prelude::*; 5 | use serde::Serialize; 6 | 7 | #[derive(Serialize)] 8 | pub struct InvalidParam { 9 | pub user: i32, 10 | pub n_rec: usize, 11 | } 12 | 13 | pub fn start_server(model_type: &str) { 14 | Command::cargo_bin("actix_serving") 15 | .unwrap() 16 | .env("REDIS_HOST", "localhost") 17 | .env("PORT", "8080") 18 | .env("MODEL_TYPE", model_type) 19 | .env("RUST_LOG", "debug") 20 | .env("WORKERS", "2") 21 | .spawn() 22 | .expect("Failed to start actix server"); 23 | // std::env::set_var("RUST_LOG", "debug"); 24 | // let cmd = Command::new("./target/debug/actix_serving") 25 | // .env_clear() 26 | // .env("WORKERS", "2") 27 | // .spawn() 28 | // .unwrap(); 29 | std::thread::sleep(time::Duration::from_secs(1)); 30 | } 31 | 32 | pub fn stop_server() { 33 | Command::new("pkill") 34 | .arg("actix_serving") 35 | .output() 36 | .expect("Failed to stop actix server"); 37 | std::thread::sleep(time::Duration::from_millis(200)); 38 | } 39 | 
-------------------------------------------------------------------------------- /libserving/actix_serving/tests/embed.rs: -------------------------------------------------------------------------------- 1 | use actix_serving::common::{Payload, Recommendation}; 2 | use pretty_assertions::assert_eq; 3 | 4 | mod common; 5 | use common::{start_server, stop_server, InvalidParam}; 6 | 7 | // cargo test --package actix_serving --test embed -- --test-threads=1 8 | #[test] 9 | fn test_main_embed_serving() { 10 | start_server("embed"); 11 | let req = Payload { 12 | user: String::from("10"), 13 | n_rec: 3, 14 | }; 15 | let resp: Recommendation = reqwest::blocking::Client::new() 16 | .post("http://localhost:8080/embed/recommend") 17 | .json(&req) 18 | .send() 19 | .unwrap() 20 | .json() 21 | .unwrap(); 22 | assert_eq!(resp.rec_list.len(), 3); 23 | stop_server(); 24 | } 25 | 26 | #[test] 27 | fn test_bad_request() { 28 | start_server("embed"); 29 | let invalid_req = InvalidParam { user: 10, n_rec: 3 }; 30 | let resp = reqwest::blocking::Client::new() 31 | .post("http://localhost:8080/embed/recommend") 32 | .json(&invalid_req) 33 | .send() 34 | .unwrap(); 35 | assert_eq!(resp.status(), reqwest::StatusCode::BAD_REQUEST); 36 | stop_server(); 37 | } 38 | 39 | #[test] 40 | fn test_not_found() { 41 | start_server("embed"); 42 | let req = Payload { 43 | user: String::from("10"), 44 | n_rec: 3, 45 | }; 46 | let resp = reqwest::blocking::Client::new() 47 | .post("http://localhost:8080/nooo_embed/recommend") 48 | .json(&req) 49 | .send() 50 | .unwrap(); 51 | assert_eq!(resp.status(), reqwest::StatusCode::NOT_FOUND); 52 | assert_eq!( 53 | resp.text().unwrap(), 54 | "`nooo_embed/recommend` is not available, make sure you've started the right service." 55 | ); 56 | stop_server(); 57 | } 58 | 59 | #[test] 60 | fn test_method_not_allowed() { 61 | start_server("embed"); 62 | let resp = reqwest::blocking::get("http://localhost:8080/embed/recommend").unwrap(); 63 | assert_eq!(resp.status(), reqwest::StatusCode::METHOD_NOT_ALLOWED); 64 | stop_server(); 65 | } 66 | -------------------------------------------------------------------------------- /libserving/actix_serving/tests/knn.rs: -------------------------------------------------------------------------------- 1 | use actix_serving::common::{Payload, Recommendation}; 2 | use pretty_assertions::assert_eq; 3 | 4 | mod common; 5 | use common::{start_server, stop_server, InvalidParam}; 6 | 7 | // cargo test --package actix_serving --test knn -- --test-threads=1 8 | #[test] 9 | fn test_main_knn_serving() { 10 | start_server("knn"); 11 | let req = Payload { 12 | user: String::from("10"), 13 | n_rec: 3, 14 | }; 15 | let resp: Recommendation = reqwest::blocking::Client::new() 16 | .post("http://localhost:8080/knn/recommend") 17 | .json(&req) 18 | .send() 19 | .unwrap() 20 | .json() 21 | .unwrap(); 22 | assert_eq!(resp.rec_list.len(), 3); 23 | stop_server(); 24 | } 25 | 26 | #[test] 27 | fn test_bad_request() { 28 | start_server("knn"); 29 | let invalid_req = InvalidParam { user: 10, n_rec: 3 }; 30 | let resp = reqwest::blocking::Client::new() 31 | .post("http://localhost:8080/knn/recommend") 32 | .json(&invalid_req) 33 | .send() 34 | .unwrap(); 35 | assert_eq!(resp.status(), reqwest::StatusCode::BAD_REQUEST); 36 | stop_server(); 37 | } 38 | 39 | #[test] 40 | fn test_not_found() { 41 | start_server("knn"); 42 | let req = Payload { 43 | user: String::from("10"), 44 | n_rec: 3, 45 | }; 46 | let resp = reqwest::blocking::Client::new() 47 | 
.post("http://localhost:8080/nooo_knn/recommend") 48 | .json(&req) 49 | .send() 50 | .unwrap(); 51 | assert_eq!(resp.status(), reqwest::StatusCode::NOT_FOUND); 52 | assert_eq!( 53 | resp.text().unwrap(), 54 | "`nooo_knn/recommend` is not available, make sure you've started the right service." 55 | ); 56 | stop_server(); 57 | } 58 | 59 | #[test] 60 | fn test_method_not_allowed() { 61 | start_server("knn"); 62 | let resp = reqwest::blocking::get("http://localhost:8080/knn/recommend").unwrap(); 63 | assert_eq!(resp.status(), reqwest::StatusCode::METHOD_NOT_ALLOWED); 64 | stop_server(); 65 | } 66 | -------------------------------------------------------------------------------- /libserving/actix_serving/tests/tf.rs: -------------------------------------------------------------------------------- 1 | use actix_serving::common::{Payload, Recommendation}; 2 | use pretty_assertions::assert_eq; 3 | 4 | mod common; 5 | use common::{start_server, stop_server, InvalidParam}; 6 | 7 | // cargo test --package actix_serving --test tf -- --test-threads=1 8 | #[test] 9 | fn test_main_tf_serving() { 10 | start_server("tf"); 11 | let req = Payload { 12 | user: String::from("10"), 13 | n_rec: 3, 14 | }; 15 | let resp: Recommendation = reqwest::blocking::Client::new() 16 | .post("http://localhost:8080/tf/recommend") 17 | .json(&req) 18 | .send() 19 | .unwrap() 20 | .json() 21 | .unwrap(); 22 | assert_eq!(resp.rec_list.len(), 3); 23 | stop_server(); 24 | } 25 | 26 | #[test] 27 | fn test_bad_request() { 28 | start_server("tf"); 29 | let invalid_req = InvalidParam { user: 10, n_rec: 3 }; 30 | let resp = reqwest::blocking::Client::new() 31 | .post("http://localhost:8080/tf/recommend") 32 | .json(&invalid_req) 33 | .send() 34 | .unwrap(); 35 | assert_eq!(resp.status(), reqwest::StatusCode::BAD_REQUEST); 36 | stop_server(); 37 | } 38 | 39 | #[test] 40 | fn test_not_found() { 41 | start_server("tf"); 42 | let req = Payload { 43 | user: String::from("10"), 44 | n_rec: 3, 45 | }; 46 | let resp = reqwest::blocking::Client::new() 47 | .post("http://localhost:8080/nooo_tf/recommend") 48 | .json(&req) 49 | .send() 50 | .unwrap(); 51 | assert_eq!(resp.status(), reqwest::StatusCode::NOT_FOUND); 52 | assert_eq!( 53 | resp.text().unwrap(), 54 | "`nooo_tf/recommend` is not available, make sure you've started the right service." 
55 | ); 56 | stop_server(); 57 | } 58 | 59 | #[test] 60 | fn test_method_not_allowed() { 61 | start_server("tf"); 62 | let resp = reqwest::blocking::get("http://localhost:8080/tf/recommend").unwrap(); 63 | assert_eq!(resp.status(), reqwest::StatusCode::METHOD_NOT_ALLOWED); 64 | stop_server(); 65 | } 66 | -------------------------------------------------------------------------------- /libserving/crate-index-config: -------------------------------------------------------------------------------- 1 | [source] 2 | 3 | [source.crates-io] 4 | replace-with = "tuna" 5 | 6 | [source.tuna] 7 | registry = "https://mirrors.tuna.tsinghua.edu.cn/git/crates.io-index.git" 8 | 9 | [registries] 10 | 11 | [registries.bfsu] 12 | index = "https://mirrors.bfsu.edu.cn/git/crates.io-index.git" 13 | 14 | [registries.hit] 15 | index = "https://mirrors.hit.edu.cn/crates.io-index.git" 16 | 17 | [registries.nju] 18 | index = "https://mirror.nju.edu.cn/git/crates.io-index.git" 19 | 20 | [registries.rsproxy] 21 | index = "https://rsproxy.cn/crates.io-index" 22 | 23 | [registries.sjtu] 24 | index = "https://mirrors.sjtug.sjtu.edu.cn/git/crates.io-index" 25 | 26 | [registries.tuna] 27 | index = "https://mirrors.tuna.tsinghua.edu.cn/git/crates.io-index.git" 28 | 29 | [registries.ustc] 30 | index = "git://mirrors.ustc.edu.cn/crates.io-index" 31 | 32 | [net] 33 | git-fetch-with-cli = false -------------------------------------------------------------------------------- /libserving/docker-compose-py.yml: -------------------------------------------------------------------------------- 1 | version: "1" 2 | services: 3 | libserving: 4 | image: docker.io/massquantity/sanic-serving:0.1.0 5 | ports: 6 | - '8000:8000' 7 | # command: sanic sanic_serving.embed_deploy:app --host=0.0.0.0 --port=8000 --dev --access-logs -v --workers 2 8 | command: sanic sanic_serving.tf_deploy:app --host=0.0.0.0 --port=8000 --no-access-logs --workers 8 9 | # command: python sanic_serving/knn_deploy.py 10 | environment: 11 | - REDIS_HOST=redis 12 | networks: 13 | - server 14 | volumes: 15 | - './embed_model:/app/faiss_index_path' 16 | restart: always 17 | depends_on: 18 | - redis 19 | 20 | redis: 21 | image: docker.io/redis:7.0.4-alpine 22 | ports: 23 | - '6379:6379' 24 | command: redis-server --save 60 1 --loglevel warning 25 | networks: 26 | - server 27 | volumes: 28 | - './redis_data:/data' 29 | restart: always 30 | 31 | volumes: 32 | embed_model: {} 33 | redis_data: {} 34 | 35 | networks: 36 | server: {} 37 | -------------------------------------------------------------------------------- /libserving/docker-compose-rs.yml: -------------------------------------------------------------------------------- 1 | version: "1" 2 | services: 3 | libserving: 4 | image: docker.io/massquantity/actix-serving:0.1.0 5 | ports: 6 | - '8080:8080' 7 | command: /app/actix_serving 8 | environment: 9 | - PORT=8080 10 | - MODEL_TYPE=embed 11 | - REDIS_HOST=redis 12 | - WORKERS=8 13 | - RUST_LOG=info 14 | networks: 15 | - server 16 | volumes: 17 | - './embed_model:/app/faiss_index_path' 18 | restart: always 19 | depends_on: 20 | - redis 21 | 22 | redis: 23 | image: docker.io/redis:7.0.4-alpine 24 | ports: 25 | - '6379:6379' 26 | command: redis-server --save 60 1 --loglevel warning 27 | networks: 28 | - server 29 | volumes: 30 | - './redis_data:/data' 31 | restart: always 32 | 33 | volumes: 34 | embed_model: {} 35 | redis_data: {} 36 | 37 | networks: 38 | server: {} 39 | -------------------------------------------------------------------------------- 
/libserving/docker-compose-tf-serving.yml: -------------------------------------------------------------------------------- 1 | version: "1" 2 | services: 3 | libserving: 4 | environment: 5 | - TF_SERVING_HOST=tensorflow-serving 6 | depends_on: 7 | - tensorflow-serving 8 | 9 | tensorflow-serving: 10 | image: docker.io/tensorflow/serving:2.8.2 11 | ports: 12 | - '8500:8500' 13 | - '8501:8501' 14 | environment: 15 | - MODEL_BASE_PATH=/usr/local/tf_model 16 | - MODEL_NAME=youtuberanking 17 | networks: 18 | - server 19 | volumes: 20 | - './tf_model:/usr/local/tf_model' 21 | restart: always 22 | 23 | volumes: 24 | tf_model: {} 25 | 26 | networks: 27 | server: {} 28 | -------------------------------------------------------------------------------- /libserving/request.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | import requests 5 | 6 | 7 | def parse_args(): 8 | parser = argparse.ArgumentParser() 9 | 10 | parser.add_argument("--host", default="127.0.0.1") 11 | parser.add_argument("--port", default=8000) 12 | parser.add_argument("--user", type=str, help="user id") 13 | parser.add_argument("--n_rec", type=int, help="num of recommendations") 14 | parser.add_argument("--algo", type=str, help="type of serving algorithm") 15 | parser.add_argument("--user_feats", help="user features, type: dict") 16 | parser.add_argument("--seq", help="user behavior sequence, type: list") 17 | return parser.parse_args() 18 | 19 | 20 | def main(): 21 | args = parse_args() 22 | url = f"http://{args.host}:{args.port}/{args.algo}/recommend" 23 | data = {"user": args.user, "n_rec": args.n_rec} 24 | if args.user_feats: 25 | data["user_feats"] = json.loads(args.user_feats) 26 | if args.seq: 27 | data["seq"] = json.loads(args.seq) 28 | 29 | response = requests.post(url, json=data, timeout=1) 30 | if response.status_code != 200: 31 | print(f"Failed to get recommendation: {url}") 32 | print(response.text) 33 | response.raise_for_status() 34 | try: 35 | print(response.json()) 36 | except json.JSONDecodeError: 37 | print("Failed to decode response json") 38 | 39 | 40 | if __name__ == "__main__": 41 | main() 42 | -------------------------------------------------------------------------------- /libserving/sanic_serving/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/massquantity/LibRecommender/9eadf8a3d6901a898630da113b92de48bd2897fb/libserving/sanic_serving/__init__.py -------------------------------------------------------------------------------- /libserving/sanic_serving/benchmark.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import asyncio 3 | import time 4 | from concurrent.futures import ThreadPoolExecutor, as_completed 5 | 6 | import aiohttp 7 | import requests 8 | import ujson 9 | 10 | REQUEST_LIMIT = 64 11 | 12 | 13 | def parse_args(): 14 | parser = argparse.ArgumentParser() 15 | 16 | parser.add_argument("--host", default="127.0.0.1") 17 | parser.add_argument("--port", default=8000) 18 | parser.add_argument("--user", type=str, help="user id") 19 | parser.add_argument("--n_rec", type=int, help="num of recommendations") 20 | parser.add_argument("--n_times", type=int, help="num of requests") 21 | parser.add_argument("--n_threads", type=int, default=1, help="num of threads") 22 | parser.add_argument("--algo", type=str, help="type of algorithm") 23 | return parser.parse_args() 24 | 25 | 26 | async def 
get_reco_async( 27 | session: aiohttp.ClientSession, url: str, data: dict, semaphore: asyncio.Semaphore 28 | ): 29 | async with semaphore, session.post(url, json=data) as resp: 30 | # if semaphore.locked(): 31 | # await asyncio.sleep(1.0) 32 | resp.raise_for_status() 33 | reco = await resp.json(loads=ujson.loads) 34 | return reco 35 | 36 | 37 | async def main_async(args): 38 | url = f"http://{args.host}:{args.port}/{args.algo}/recommend" 39 | data = {"user": args.user, "n_rec": args.n_rec} 40 | semaphore = asyncio.Semaphore(REQUEST_LIMIT) 41 | async with aiohttp.ClientSession() as session: 42 | tasks = [ 43 | get_reco_async(session, url, data, semaphore) for _ in range(args.n_times) 44 | ] 45 | # await asyncio.gather(*tasks, return_exceptions=True) 46 | for future in asyncio.as_completed(tasks): 47 | _ = await future 48 | 49 | 50 | def get_reco_sync(url: str, data: dict): 51 | resp = requests.post(url, json=data, timeout=1) 52 | resp.raise_for_status() 53 | return resp.json() 54 | 55 | 56 | def main_sync(args): 57 | url = f"http://{args.host}:{args.port}/{args.algo}/recommend" 58 | data = {"user": args.user, "n_rec": args.n_rec} 59 | with ThreadPoolExecutor(max_workers=args.n_threads) as executor: 60 | futures = [ 61 | executor.submit(get_reco_sync, url, data) for _ in range(args.n_times) 62 | ] 63 | for future in as_completed(futures): 64 | _ = future.result() 65 | 66 | 67 | if __name__ == "__main__": 68 | args = parse_args() 69 | 70 | start = time.perf_counter() 71 | asyncio.run(main_async(args)) 72 | duration = time.perf_counter() - start 73 | print( 74 | f"total time {duration}s for async requests, " 75 | f"{duration / args.n_times * 1000} ms/request" 76 | ) 77 | 78 | start = time.perf_counter() 79 | main_sync(args) 80 | duration = time.perf_counter() - start 81 | print( 82 | f"total time {duration}s for sync requests, " 83 | f"{duration / args.n_times * 1000} ms/request" 84 | ) 85 | -------------------------------------------------------------------------------- /libserving/sanic_serving/common.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from typing import Callable, Dict, List, Optional, Type, Union 3 | 4 | from pydantic import BaseModel, Extra, ValidationError 5 | from sanic.exceptions import SanicException 6 | from sanic.log import logger 7 | from sanic.request import Request 8 | 9 | 10 | class Params(BaseModel, extra=Extra.forbid): 11 | user: Union[str, int] 12 | n_rec: int 13 | user_feats: Optional[Dict[str, Union[str, int, float]]] = None 14 | seq: Optional[List[Union[str, int]]] = None 15 | 16 | 17 | def validate(model: Type[object]): 18 | def decorator(func: Callable): 19 | @functools.wraps(func) 20 | async def decorated_function(request: Request, **kwargs): 21 | try: 22 | params = model(**request.json) 23 | kwargs["params"] = params 24 | except ValidationError as e: 25 | logger.error(f"Invalid request body: {request.json}") 26 | raise SanicException( 27 | f"Invalid payload: `{request.json}`, please check key name and value type." 
28 | ) from e 29 | 30 | return await func(request, **kwargs) 31 | 32 | return decorated_function 33 | 34 | return decorator 35 | -------------------------------------------------------------------------------- /libserving/serialization/__init__.py: -------------------------------------------------------------------------------- 1 | from .embed import save_embed, save_faiss_index 2 | from .knn import save_knn 3 | from .online import save_online 4 | from .redis import embed2redis, knn2redis, online2redis, tf2redis 5 | from .tfmodel import save_tf 6 | 7 | __all__ = [ 8 | "save_knn", 9 | "save_embed", 10 | "save_faiss_index", 11 | "save_online", 12 | "save_tf", 13 | "knn2redis", 14 | "embed2redis", 15 | "tf2redis", 16 | "online2redis", 17 | ] 18 | -------------------------------------------------------------------------------- /libserving/serialization/embed.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | 5 | from libreco.bases import EmbedBase 6 | 7 | from .common import ( 8 | check_path_exists, 9 | save_id_mapping, 10 | save_model_name, 11 | save_to_json, 12 | save_user_consumed, 13 | ) 14 | 15 | 16 | def save_embed(path: str, model: EmbedBase): 17 | """Save Embed model to disk. 18 | 19 | Parameters 20 | ---------- 21 | path : str 22 | Model saving path. 23 | model : EmbedBase 24 | Model to save. 25 | """ 26 | check_path_exists(path) 27 | save_model_name(path, model) 28 | save_id_mapping(path, model.data_info) 29 | save_user_consumed(path, model.data_info) 30 | save_vectors(path, model.user_embeds_np, model.n_users, "user_embed.json") 31 | save_vectors(path, model.item_embeds_np, model.n_items, "item_embed.json") 32 | 33 | 34 | def save_vectors(path: str, embeds: np.ndarray, num: int, name: str): 35 | embed_path = os.path.join(path, name) 36 | embed_dict = dict() 37 | for i in range(num): 38 | embed_dict[i] = embeds[i].tolist() 39 | save_to_json(embed_path, embed_dict) 40 | 41 | 42 | def save_faiss_index(path: str, model: EmbedBase, nlist: int = 80, nprobe: int = 10): 43 | import faiss 44 | 45 | check_path_exists(path) 46 | index_path = os.path.join(path, "faiss_index.bin") 47 | item_embeds = model.item_embeds_np[: model.n_items].astype(np.float32) 48 | d = item_embeds.shape[1] 49 | quantizer = faiss.IndexFlatIP(d) 50 | index = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_INNER_PRODUCT) 51 | index.train(item_embeds) 52 | index.add(item_embeds) 53 | index.nprobe = nprobe 54 | faiss.write_index(index, index_path) 55 | -------------------------------------------------------------------------------- /libserving/serialization/knn.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from scipy import sparse 4 | 5 | from libreco.bases import CfBase 6 | 7 | from .common import ( 8 | check_path_exists, 9 | save_id_mapping, 10 | save_model_name, 11 | save_to_json, 12 | save_user_consumed, 13 | ) 14 | 15 | 16 | def save_knn(path: str, model: CfBase, k: int): 17 | """Save KNN model to disk. 18 | 19 | Parameters 20 | ---------- 21 | path : str 22 | Model saving path. 23 | model : CfBase 24 | Model to save. 25 | k : int 26 | Number of similar users/items to save. 
27 | """ 28 | check_path_exists(path) 29 | save_model_name(path, model) 30 | save_id_mapping(path, model.data_info) 31 | save_user_consumed(path, model.data_info) 32 | save_sim_matrix(path, model.sim_matrix, k) 33 | 34 | 35 | def save_sim_matrix(path: str, sim_matrix: sparse.csr_matrix, k: int): 36 | k_sims = dict() 37 | num = len(sim_matrix.indptr) - 1 38 | indices = sim_matrix.indices.tolist() 39 | indptr = sim_matrix.indptr.tolist() 40 | data = sim_matrix.data.tolist() 41 | for i in range(num): 42 | i_slice = slice(indptr[i], indptr[i + 1]) 43 | sorted_sims = sorted(zip(indices[i_slice], data[i_slice]), key=lambda x: -x[1]) 44 | k_sims[i] = sorted_sims[:k] 45 | sim_path = os.path.join(path, "sim.json") 46 | save_to_json(sim_path, k_sims) 47 | -------------------------------------------------------------------------------- /python-package-conda.yml: -------------------------------------------------------------------------------- 1 | name: Python Package using Conda 2 | 3 | on: [push] 4 | 5 | jobs: 6 | build-linux: 7 | runs-on: ubuntu-latest 8 | strategy: 9 | max-parallel: 5 10 | 11 | steps: 12 | - uses: actions/checkout@v2 13 | - name: Set up Python 3.10 14 | uses: actions/setup-python@v2 15 | with: 16 | python-version: 3.10 17 | - name: Add conda to system path 18 | run: | 19 | # $CONDA is an environment variable pointing to the root of the miniconda directory 20 | echo $CONDA/bin >> $GITHUB_PATH 21 | - name: Install dependencies 22 | run: | 23 | conda env update --file environment.yml --name base 24 | - name: Lint with flake8 25 | run: | 26 | conda install flake8 27 | # stop the build if there are Python syntax errors or undefined names 28 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 29 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 30 | flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 31 | - name: Test with pytest 32 | run: | 33 | conda install pytest 34 | pytest 35 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | -r requirements.txt 2 | 3 | flake8 4 | ruff >= 0.3.0 ; python_version > "3.6" 5 | pytest 6 | coverage 7 | cython >= 0.29.0,<3 8 | smart_open < 7.0.0 ; python_version == "3.6" 9 | pyyaml ; python_version >= "3.8" 10 | pydantic ; python_version >= "3.8" 11 | -------------------------------------------------------------------------------- /requirements-serving.txt: -------------------------------------------------------------------------------- 1 | sanic >= 22.3 2 | ujson 3 | requests 4 | pydantic >= 2.0 5 | aiohttp 6 | redis >= 4.2.0 7 | faiss-cpu == 1.7.2 ; python_version == "3.6" 8 | faiss-cpu >= 1.5.2 ; python_version > "3.6" 9 | protobuf < 4.24.0 10 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy >= 1.19.5 2 | scipy >= 1.2.1, < 1.13.0 3 | pandas >= 1.0.0 4 | scikit-learn >= 0.20.0 5 | tensorflow >= 1.15.0, < 2.16.0 6 | torch >= 1.10.0 7 | gensim >= 4.0.0 8 | tqdm 9 | nmslib ; python_version < "3.11" 10 | dgl < 2.0.0 -f https://data.dgl.ai/wheels/repo.html 11 | dataclasses ; python_version == "3.6" 12 | recfarm ; python_version >= "3.7" 13 | -------------------------------------------------------------------------------- /rust/.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | .pytest_cache/ 6 | *.py[cod] 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | .venv/ 14 | env/ 15 | bin/ 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | include/ 26 | man/ 27 | venv/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | 32 | # Installer logs 33 | pip-log.txt 34 | pip-delete-this-directory.txt 35 | pip-selfcheck.json 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .cache 42 | nosetests.xml 43 | coverage.xml 44 | 45 | # Translations 46 | *.mo 47 | 48 | # Mr Developer 49 | .mr.developer.cfg 50 | .project 51 | .pydevproject 52 | 53 | # Rope 54 | .ropeproject 55 | 56 | # Django stuff: 57 | *.log 58 | *.pot 59 | 60 | .DS_Store 61 | 62 | # Sphinx documentation 63 | docs/_build/ 64 | 65 | # PyCharm 66 | .idea/ 67 | 68 | # VSCode 69 | .vscode/ 70 | 71 | # Pyenv 72 | .python-version 73 | -------------------------------------------------------------------------------- /rust/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "recfarm" 3 | version = "0.2.0" 4 | edition = "2021" 5 | 6 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 7 | [lib] 8 | name = "recfarm" 9 | crate-type = ["cdylib"] 10 | 11 | [dependencies] 12 | bincode = "1.3.3" 13 | dashmap = "5.5.3" 14 | flate2 = "1.0" 15 | fxhash = "0.2.1" 16 | pyo3 = { version = "0.23.3", features = ["abi3-py37"] } 17 | rand = { version = "0.8", features = ["default", "alloc"] } 18 | rayon = "1.8" 19 | serde = { version = "1.0", features = ["derive"] } 20 | 21 | [profile.release] 22 | strip = true 23 | 
opt-level = 3 24 | lto = true 25 | codegen-units = 4 26 | -------------------------------------------------------------------------------- /rust/README.md: -------------------------------------------------------------------------------- 1 | # RecFarm 2 | 3 | Rust implementation for some non-deep learning recommender algorithms. 4 | 5 | ## Installation 6 | 7 | ```shell 8 | $ pip install recfarm 9 | ``` 10 | 11 | Requires Python >= 3.7. 12 | -------------------------------------------------------------------------------- /rust/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["maturin>=1.2,<2.0"] 3 | build-backend = "maturin" 4 | 5 | [project] 6 | name = "recfarm" 7 | description = "Rust implementation for some non-deep learning recommender algorithms." 8 | authors = [ 9 | { name = "massquantity", email = "jinxin_madie@163.com" }, 10 | ] 11 | readme = "README.md" 12 | requires-python = ">=3.7" 13 | classifiers = [ 14 | "Programming Language :: Rust", 15 | "Programming Language :: Python :: Implementation :: CPython", 16 | ] 17 | dynamic = ["version"] 18 | 19 | [tool.maturin] 20 | features = ["pyo3/extension-module"] 21 | -------------------------------------------------------------------------------- /rust/recfarm/__init__.py: -------------------------------------------------------------------------------- 1 | from recfarm import recfarm 2 | from recfarm.recfarm import ( 3 | ItemCF, 4 | UserCF, 5 | Swing, 6 | __version__, 7 | build_consumed_unique, 8 | load_item_cf, 9 | load_user_cf, 10 | save_item_cf, 11 | save_user_cf, 12 | save_swing, 13 | load_swing, 14 | ) 15 | 16 | __all__ = ["recfarm", "UserCF", "ItemCF", "Swing"] 17 | -------------------------------------------------------------------------------- /rust/rustfmt.toml: -------------------------------------------------------------------------------- 1 | array_width = 50 2 | chain_width = 50 3 | edition = "2021" 4 | group_imports = "StdExternalCrate" 5 | imports_granularity = "Module" 6 | max_width = 100 7 | reorder_imports = true 8 | reorder_modules = true 9 | use_field_init_shorthand = true 10 | -------------------------------------------------------------------------------- /rust/src/lib.rs: -------------------------------------------------------------------------------- 1 | #![allow(clippy::too_many_arguments)] 2 | 3 | use pyo3::prelude::*; 4 | 5 | mod graph; 6 | mod incremental; 7 | mod inference; 8 | mod item_cf; 9 | mod ordering; 10 | mod serialization; 11 | mod similarities; 12 | mod sparse; 13 | mod swing; 14 | mod user_cf; 15 | mod utils; 16 | 17 | const VERSION: &str = env!("CARGO_PKG_VERSION"); 18 | 19 | /// RecFarm module 20 | #[pymodule] 21 | fn recfarm(m: &Bound<'_, PyModule>) -> PyResult<()> { 22 | m.add_class::()?; 23 | m.add_class::()?; 24 | m.add_class::()?; 25 | m.add_function(wrap_pyfunction!(user_cf::save, m)?)?; 26 | m.add_function(wrap_pyfunction!(user_cf::load, m)?)?; 27 | m.add_function(wrap_pyfunction!(item_cf::save, m)?)?; 28 | m.add_function(wrap_pyfunction!(item_cf::load, m)?)?; 29 | m.add_function(wrap_pyfunction!(swing::save, m)?)?; 30 | m.add_function(wrap_pyfunction!(swing::load, m)?)?; 31 | m.add_function(wrap_pyfunction!(utils::build_consumed, m)?)?; 32 | m.add("__version__", VERSION)?; 33 | Ok(()) 34 | } 35 | -------------------------------------------------------------------------------- /rust/src/ordering.rs: -------------------------------------------------------------------------------- 1 | use std::cmp::Ordering; 2 
| 3 | /// 0: similarity, 1: label 4 | #[derive(Debug)] 5 | pub(crate) struct SimOrd(pub f32, pub f32); 6 | 7 | impl Ord for SimOrd { 8 | fn cmp(&self, other: &Self) -> Ordering { 9 | self.0 10 | .partial_cmp(&other.0) 11 | .unwrap_or(Ordering::Equal) 12 | } 13 | } 14 | 15 | impl PartialOrd for SimOrd { 16 | fn partial_cmp(&self, other: &Self) -> Option { 17 | Some(self.cmp(other)) 18 | } 19 | } 20 | 21 | impl PartialEq for SimOrd { 22 | fn eq(&self, other: &Self) -> bool { 23 | self.0 == other.0 24 | } 25 | } 26 | 27 | impl Eq for SimOrd {} 28 | 29 | #[cfg(test)] 30 | mod tests { 31 | use std::collections::BinaryHeap; 32 | 33 | use super::*; 34 | 35 | #[test] 36 | fn test_sim_max_heap() { 37 | let mut heap = BinaryHeap::new(); 38 | heap.push(SimOrd(1.1, 1.1)); 39 | heap.push(SimOrd(0.0, 1.1)); 40 | heap.push(SimOrd(-2.0, 0.0)); 41 | heap.push(SimOrd(-0.2, 3.3)); 42 | heap.push(SimOrd(8.8, 8.8)); 43 | assert_eq!(heap.pop(), Some(SimOrd(8.8, 8.8))); 44 | assert_eq!(heap.pop(), Some(SimOrd(1.1, 1.1))); 45 | assert_eq!(heap.pop(), Some(SimOrd(0.0, 1.1))); 46 | assert_eq!(heap.pop(), Some(SimOrd(-0.2, 3.3))); 47 | assert_eq!(heap.pop(), Some(SimOrd(-2.0, 0.0))); 48 | assert_eq!(heap.pop(), None); 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /rust/src/serialization.rs: -------------------------------------------------------------------------------- 1 | use std::fs::{File, OpenOptions}; 2 | use std::io::{Read, Write}; 3 | use std::path::Path; 4 | 5 | use flate2::read::GzDecoder; 6 | use flate2::write::GzEncoder; 7 | use flate2::Compression; 8 | use pyo3::exceptions::PyIOError; 9 | use pyo3::PyResult; 10 | use serde::de::DeserializeOwned; 11 | use serde::Serialize; 12 | 13 | pub fn save_model( 14 | model: &T, 15 | path: &str, 16 | model_name: &str, 17 | class_name: &str, 18 | ) -> PyResult<()> { 19 | let file_name = format!("{model_name}.gz"); 20 | let model_path = Path::new(path).join(file_name); 21 | let file = OpenOptions::new() 22 | .write(true) 23 | .create(true) 24 | .open(model_path.as_path())?; 25 | let mut encoder = GzEncoder::new(file, Compression::new(1)); 26 | let model_bytes: Vec = match bincode::serialize(model) { 27 | Ok(bytes) => bytes, 28 | Err(e) => return Err(PyIOError::new_err(e.to_string())), 29 | }; 30 | encoder.write_all(&model_bytes)?; 31 | encoder.finish()?; 32 | println!( 33 | "Save `{class_name}` model to `{}`", 34 | model_path.canonicalize()?.display() 35 | ); 36 | Ok(()) 37 | } 38 | 39 | pub fn load_model( 40 | path: &str, 41 | model_name: &str, 42 | class_name: &str, 43 | ) -> PyResult { 44 | let file_name = format!("{model_name}.gz"); 45 | let model_path = Path::new(path).join(file_name); 46 | let file = File::open(model_path.as_path())?; 47 | let mut decoder = GzDecoder::new(file); 48 | let mut model_bytes: Vec = Vec::new(); 49 | decoder.read_to_end(&mut model_bytes)?; 50 | let model: T = match bincode::deserialize(&model_bytes) { 51 | Ok(m) => m, 52 | Err(e) => return Err(PyIOError::new_err(e.to_string())), 53 | }; 54 | println!( 55 | "Load `{class_name}` model from `{}`", 56 | model_path.canonicalize()?.display() 57 | ); 58 | Ok(model) 59 | } 60 | -------------------------------------------------------------------------------- /rust/src/utils.rs: -------------------------------------------------------------------------------- 1 | use fxhash::FxHashMap; 2 | use pyo3::prelude::*; 3 | use pyo3::types::{IntoPyDict, PyDict, PyList}; 4 | 5 | /// (x1, x2, prod, count) 6 | pub(crate) type CumValues = (i32, i32, f32, 
usize); 7 | 8 | #[pyfunction] 9 | #[pyo3(name = "build_consumed_unique")] 10 | pub fn build_consumed<'py>( 11 | py: Python<'py>, 12 | user_indices: &Bound<'py, PyList>, 13 | item_indices: &Bound<'py, PyList>, 14 | ) -> PyResult<(Bound<'py, PyDict>, Bound<'py, PyDict>)> { 15 | let add_or_insert = |mapping: &mut FxHashMap>, k: i32, v: i32| { 16 | mapping 17 | .entry(k) 18 | .and_modify(|consumed| consumed.push(v)) 19 | .or_insert_with(|| vec![v]); 20 | }; 21 | let user_indices: Vec = user_indices.extract()?; 22 | let item_indices: Vec = item_indices.extract()?; 23 | let mut user_consumed: FxHashMap> = FxHashMap::default(); 24 | let mut item_consumed: FxHashMap> = FxHashMap::default(); 25 | for (&u, &i) in user_indices.iter().zip(item_indices.iter()) { 26 | add_or_insert(&mut user_consumed, u, i); 27 | add_or_insert(&mut item_consumed, i, u); 28 | } 29 | // remove consecutive repeated elements 30 | user_consumed.values_mut().for_each(|v| v.dedup()); 31 | item_consumed.values_mut().for_each(|v| v.dedup()); 32 | let user_consumed_py: Bound<'py, PyDict> = user_consumed.into_py_dict(py)?; 33 | let item_consumed_py: Bound<'py, PyDict> = item_consumed.into_py_dict(py)?; 34 | Ok((user_consumed_py, item_consumed_py)) 35 | } 36 | 37 | #[cfg(test)] 38 | mod tests { 39 | use super::*; 40 | 41 | #[test] 42 | fn test_build_consumed() -> Result<(), Box> { 43 | let get_values = |mapping: &Bound<'_, PyDict>, k: i32| -> PyResult> { 44 | mapping.get_item(k)?.unwrap().extract() 45 | }; 46 | pyo3::prepare_freethreaded_python(); 47 | Ok(Python::with_gil(|py| -> PyResult<()> { 48 | let user_indices = PyList::new(py, vec![1, 1, 1, 2, 2, 1, 2, 3, 2, 3])?; 49 | let item_indices = PyList::new(py, vec![11, 11, 999, 0, 11, 11, 999, 11, 999, 0])?; 50 | let (user_consumed, item_consumed) = build_consumed(py, &user_indices, &item_indices)?; 51 | assert_eq!(get_values(&user_consumed, 1)?, vec![11, 999, 11]); 52 | assert_eq!(get_values(&user_consumed, 2)?, vec![0, 11, 999]); 53 | assert_eq!(get_values(&user_consumed, 3)?, vec![11, 0]); 54 | assert_eq!(get_values(&item_consumed, 11)?, vec![1, 2, 1, 3]); 55 | assert_eq!(get_values(&item_consumed, 999)?, vec![1, 2]); 56 | assert_eq!(get_values(&item_consumed, 0)?, vec![2, 3]); 57 | Ok(()) 58 | })?) 
59 | } 60 | } 61 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 88 3 | ignore = E501,E203,F811,F401,W503,W504 4 | count = True 5 | show-source = True 6 | statistics = True 7 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/massquantity/LibRecommender/9eadf8a3d6901a898630da113b92de48bd2897fb/tests/__init__.py -------------------------------------------------------------------------------- /tests/compatibility_test.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pandas as pd 4 | import tensorflow as tf 5 | 6 | import libreco 7 | from libreco.algorithms import Caser, RNN4Rec 8 | from libreco.data import DatasetPure, split_by_ratio_chrono 9 | from libreco.tfops import TF_VERSION 10 | 11 | if __name__ == "__main__": 12 | print(f"tensorflow version: {TF_VERSION}") 13 | print(libreco) 14 | from libreco.algorithms._als import als_update 15 | from libreco.utils._similarities import forward_cosine, invert_cosine 16 | 17 | print("Cython functions: ", invert_cosine, forward_cosine, als_update) 18 | 19 | cur_path = Path(".").parent 20 | if Path.exists(cur_path / "sample_movielens_rating.dat"): 21 | data_path = cur_path / "sample_movielens_rating.dat" 22 | else: 23 | data_path = cur_path / "sample_data" / "sample_movielens_rating.dat" 24 | 25 | pd_data = pd.read_csv(data_path, sep="::", names=["user", "item", "label", "time"]) 26 | train_data, eval_data = split_by_ratio_chrono(pd_data, test_size=0.2) 27 | train_data, data_info = DatasetPure.build_trainset(train_data) 28 | eval_data = DatasetPure.build_evalset(eval_data) 29 | 30 | rnn = RNN4Rec( 31 | "ranking", 32 | data_info, 33 | rnn_type="lstm", 34 | loss_type="cross_entropy", 35 | embed_size=16, 36 | n_epochs=1, 37 | lr=0.001, 38 | lr_decay=False, 39 | hidden_units=(16, 16), 40 | reg=None, 41 | batch_size=256, 42 | num_neg=1, 43 | dropout_rate=None, 44 | recent_num=10, 45 | tf_sess_config=None, 46 | ) 47 | rnn.fit( 48 | train_data, 49 | neg_sampling=True, 50 | verbose=2, 51 | shuffle=True, 52 | eval_data=eval_data, 53 | metrics=[ 54 | "loss", 55 | "balanced_accuracy", 56 | "roc_auc", 57 | "pr_auc", 58 | "precision", 59 | "recall", 60 | "map", 61 | "ndcg", 62 | ], 63 | num_workers=2, 64 | ) 65 | print("prediction: ", rnn.predict(user=1, item=2)) 66 | print("recommendation: ", rnn.recommend_user(user=1, n_rec=7)) 67 | 68 | tf.compat.v1.reset_default_graph() 69 | caser = Caser( 70 | "ranking", 71 | data_info=data_info, 72 | loss_type="cross_entropy", 73 | embed_size=16, 74 | n_epochs=1, 75 | lr=1e-4, 76 | batch_size=2048, 77 | ) 78 | caser.fit( 79 | train_data, 80 | neg_sampling=True, 81 | verbose=2, 82 | shuffle=True, 83 | eval_data=eval_data, 84 | metrics=[ 85 | "loss", 86 | "balanced_accuracy", 87 | "roc_auc", 88 | "pr_auc", 89 | "precision", 90 | "recall", 91 | "map", 92 | "ndcg", 93 | ], 94 | num_workers=2, 95 | ) 96 | -------------------------------------------------------------------------------- /tests/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/massquantity/LibRecommender/9eadf8a3d6901a898630da113b92de48bd2897fb/tests/models/__init__.py 
-------------------------------------------------------------------------------- /tests/models/test_base.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from libreco.bases import Base 4 | 5 | 6 | class NCF(Base): 7 | def __init__(self, task, data_info, lower_upper_bound=None): 8 | super().__init__(task, data_info, lower_upper_bound) 9 | 10 | def fit(self, train_data, **kwargs): 11 | raise NotImplementedError 12 | 13 | def predict(self, user, item, **kwargs): 14 | raise NotImplementedError 15 | 16 | def recommend_user(self, user, n_rec, **kwargs): 17 | raise NotImplementedError 18 | 19 | def save(self, path, model_name, **kwargs): 20 | raise NotImplementedError 21 | 22 | @classmethod 23 | def load(cls, path, model_name, data_info, **kwargs): 24 | raise NotImplementedError 25 | 26 | 27 | def test_base(prepare_pure_data): 28 | _, train_data, _, data_info = prepare_pure_data 29 | with pytest.raises(ValueError): 30 | _ = NCF(task="unknown", data_info=data_info) 31 | with pytest.raises(AssertionError): 32 | _ = NCF(task="rating", data_info=data_info, lower_upper_bound=1) 33 | 34 | model = NCF(task="rating", data_info=data_info, lower_upper_bound=[1, 5]) 35 | with pytest.raises(NotImplementedError): 36 | model.fit(train_data) 37 | with pytest.raises(NotImplementedError): 38 | model.predict(1, 2) 39 | with pytest.raises(NotImplementedError): 40 | model.recommend_user(1, 7) 41 | with pytest.raises(NotImplementedError): 42 | model.save("path", "model_name") 43 | with pytest.raises(NotImplementedError): 44 | NCF.load("path", "model_name", data_info) 45 | -------------------------------------------------------------------------------- /tests/models/test_deepwalk.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import tensorflow as tf 3 | 4 | from libreco.algorithms import DeepWalk 5 | from tests.utils_data import remove_path, set_ranking_labels 6 | from tests.utils_metrics import get_metrics 7 | from tests.utils_pred import ptest_preds 8 | from tests.utils_reco import ptest_recommends 9 | from tests.utils_save_load import save_load_model 10 | 11 | 12 | @pytest.mark.parametrize("task", ["rating", "ranking"]) 13 | @pytest.mark.parametrize( 14 | "norm_embed, n_walks, walk_length, window_size", 15 | [(True, 10, 10, 5), (False, 1, 1, 1)], 16 | ) 17 | @pytest.mark.parametrize("neg_sampling", [True, False, None]) 18 | def test_deepwalk( 19 | pure_data_small, task, norm_embed, n_walks, walk_length, window_size, neg_sampling 20 | ): 21 | tf.compat.v1.reset_default_graph() 22 | pd_data, train_data, eval_data, data_info = pure_data_small 23 | if neg_sampling is False: 24 | set_ranking_labels(train_data) 25 | set_ranking_labels(eval_data) 26 | 27 | if task == "rating": 28 | with pytest.raises(AssertionError): 29 | _ = DeepWalk(task, data_info) 30 | elif neg_sampling is None: 31 | with pytest.raises(AssertionError): 32 | DeepWalk(task, data_info).fit(train_data, neg_sampling) 33 | else: 34 | model = DeepWalk( 35 | task=task, 36 | data_info=data_info, 37 | embed_size=16, 38 | n_epochs=2, 39 | norm_embed=norm_embed, 40 | window_size=window_size, 41 | n_walks=n_walks, 42 | walk_length=walk_length, 43 | ) 44 | model.fit( 45 | train_data, 46 | neg_sampling, 47 | verbose=2, 48 | shuffle=True, 49 | eval_data=eval_data, 50 | metrics=get_metrics(task), 51 | ) 52 | ptest_preds(model, task, pd_data, with_feats=False) 53 | ptest_recommends(model, data_info, pd_data, with_feats=False) 54 | 55 | # test save 
and load model 56 | loaded_model, loaded_data_info = save_load_model(DeepWalk, model, data_info) 57 | with pytest.raises(RuntimeError): 58 | loaded_model.fit(train_data, neg_sampling) 59 | ptest_preds(loaded_model, task, pd_data, with_feats=False) 60 | ptest_recommends(loaded_model, loaded_data_info, pd_data, with_feats=False) 61 | model.save("not_existed_path", "deepwalk2") 62 | remove_path("not_existed_path") 63 | -------------------------------------------------------------------------------- /tests/models/test_item2vec.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import tensorflow as tf 3 | 4 | from libreco.algorithms import Item2Vec 5 | from tests.utils_data import remove_path, set_ranking_labels 6 | from tests.utils_metrics import get_metrics 7 | from tests.utils_pred import ptest_preds 8 | from tests.utils_reco import ptest_recommends 9 | from tests.utils_save_load import save_load_model 10 | 11 | 12 | @pytest.mark.parametrize("task", ["rating", "ranking"]) 13 | @pytest.mark.parametrize("norm_embed, window_size", [(True, 5), (False, None)]) 14 | @pytest.mark.parametrize("neg_sampling", [True, False, None]) 15 | def test_item2vec(pure_data_small, task, norm_embed, window_size, neg_sampling): 16 | tf.compat.v1.reset_default_graph() 17 | pd_data, train_data, eval_data, data_info = pure_data_small 18 | if neg_sampling is False: 19 | set_ranking_labels(train_data) 20 | set_ranking_labels(eval_data) 21 | 22 | if task == "rating": 23 | with pytest.raises(AssertionError): 24 | _ = Item2Vec(task, data_info) 25 | elif neg_sampling is None: 26 | with pytest.raises(AssertionError): 27 | Item2Vec(task, data_info).fit(train_data, neg_sampling) 28 | else: 29 | model = Item2Vec( 30 | task=task, 31 | data_info=data_info, 32 | embed_size=16, 33 | n_epochs=2, 34 | norm_embed=norm_embed, 35 | window_size=window_size, 36 | ) 37 | model.fit( 38 | train_data, 39 | neg_sampling, 40 | verbose=2, 41 | shuffle=True, 42 | eval_data=eval_data, 43 | metrics=get_metrics(task), 44 | ) 45 | ptest_preds(model, task, pd_data, with_feats=False) 46 | ptest_recommends(model, data_info, pd_data, with_feats=False) 47 | 48 | # test save and load model 49 | loaded_model, loaded_data_info = save_load_model(Item2Vec, model, data_info) 50 | with pytest.raises(RuntimeError): 51 | loaded_model.fit(train_data, neg_sampling) 52 | ptest_preds(loaded_model, task, pd_data, with_feats=False) 53 | ptest_recommends(loaded_model, loaded_data_info, pd_data, with_feats=False) 54 | model.save("not_existed_path", "item2vec2") 55 | remove_path("not_existed_path") 56 | -------------------------------------------------------------------------------- /tests/models/utils_tf.py: -------------------------------------------------------------------------------- 1 | from libreco.tfops import tf 2 | 3 | 4 | def ptest_tf_variables(model): 5 | var_names = [v.name for v in tf.trainable_variables()] 6 | if hasattr(model, "user_variables"): 7 | for v in model.user_variables: 8 | assert f"{v}:0" in var_names 9 | if hasattr(model, "item_variables"): 10 | for v in model.item_variables: 11 | assert f"{v}:0" in var_names 12 | if hasattr(model, "sparse_variables"): 13 | for v in model.sparse_variables: 14 | assert f"{v}:0" in var_names 15 | if hasattr(model, "dense_variables"): 16 | for v in model.dense_variables: 17 | assert f"{v}:0" in var_names 18 | -------------------------------------------------------------------------------- /tests/retrain/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/massquantity/LibRecommender/9eadf8a3d6901a898630da113b92de48bd2897fb/tests/retrain/__init__.py -------------------------------------------------------------------------------- /tests/serving/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/massquantity/LibRecommender/9eadf8a3d6901a898630da113b92de48bd2897fb/tests/serving/__init__.py -------------------------------------------------------------------------------- /tests/serving/mock_tf_server.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sanic import Sanic 3 | from sanic.log import logger 4 | from sanic.request import Request 5 | from sanic.response import HTTPResponse, json 6 | 7 | app = Sanic("mock-tf-server") 8 | 9 | 10 | @app.post("/v1/models/") 11 | async def tf_serving(request: Request, model_name: str) -> HTTPResponse: 12 | logger.info(f"Mock predictions for {model_name.replace(':predict', '')}") 13 | rng = np.random.default_rng(42) 14 | payload = request.json["inputs"] 15 | if "k" in payload: 16 | return json({"outputs": rng.integers(0, 20, size=payload["k"]).tolist()}) 17 | n_items = len(request.json["inputs"]["item_indices"]) 18 | return json({"outputs": rng.normal(size=n_items).tolist()}) 19 | 20 | 21 | if __name__ == "__main__": 22 | app.run( 23 | host="0.0.0.0", port=8501, debug=False, access_log=False, single_process=True 24 | ) 25 | -------------------------------------------------------------------------------- /tests/serving/setup_coverage.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | PYTHON_SITE_PATH=$(python -c 'import sysconfig; print(sysconfig.get_paths()["purelib"])') 4 | FILE_NAME="sitecustomize.py" 5 | FULL_PATH="${PYTHON_SITE_PATH}/${FILE_NAME}" 6 | 7 | echo "coverage path: ${FULL_PATH}" 8 | cp tests/serving/subprocess_coverage_setup.py "$FULL_PATH" 9 | -------------------------------------------------------------------------------- /tests/serving/subprocess_coverage_setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import coverage 4 | 5 | # essential for subprocess parallel coverage like sanic serving 6 | # https://coverage.readthedocs.io/en/latest/subprocess.html 7 | os.environ["COVERAGE_PROCESS_START"] = ".coveragerc" 8 | coverage.process_startup() 9 | -------------------------------------------------------------------------------- /tests/serving/test_embed_serving.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | from pathlib import Path 3 | 4 | from libserving.serialization import embed2redis, save_embed, save_faiss_index 5 | from tests.utils_data import SAVE_PATH, remove_path 6 | 7 | 8 | def test_embed_serving(embed_model, session, close_server): 9 | save_embed(SAVE_PATH, embed_model) 10 | embed2redis(SAVE_PATH) 11 | faiss_path = str(Path(__file__).parents[2] / "libserving" / "embed_model") 12 | save_faiss_index(faiss_path, embed_model, 40, 10) 13 | 14 | subprocess.run( 15 | "sanic libserving.sanic_serving.embed_deploy:app --no-access-logs --single-process &", 16 | shell=True, 17 | check=True, 18 | ) 19 | # time.sleep(2) # wait for the server to start 20 | 21 | response = session.post( 22 | "http://localhost:8000/embed/recommend", 23 | json={"user": 1, "n_rec": 1}, 24 
| timeout=0.5, 25 | ) 26 | assert len(next(iter(response.json().values()))) == 1 27 | response = session.post( 28 | "http://localhost:8000/embed/recommend", 29 | json={"user": 33, "n_rec": 3}, 30 | timeout=0.5, 31 | ) 32 | assert len(next(iter(response.json().values()))) == 3 33 | remove_path(faiss_path) 34 | -------------------------------------------------------------------------------- /tests/serving/test_faiss_index.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | 5 | from libreco.algorithms import BPR 6 | from libserving.serialization import save_faiss_index 7 | from tests.utils_data import SAVE_PATH 8 | 9 | 10 | def test_faiss_index(embed_model): 11 | import faiss 12 | 13 | save_faiss_index(SAVE_PATH, embed_model, 80, 10) 14 | index = faiss.read_index(os.path.join(SAVE_PATH, "faiss_index.bin")) 15 | _, ids = index.search(embed_model.user_embeds_np[0].reshape(1, -1), 10) 16 | assert ids.shape == (1, 10) 17 | assert index.ntotal == embed_model.n_items 18 | assert index.d == embed_model.embed_size + 1 # embed + bias 19 | 20 | 21 | @pytest.fixture 22 | def embed_model(prepare_pure_data): 23 | _, train_data, _, data_info = prepare_pure_data 24 | model = BPR( 25 | data_info=data_info, 26 | n_epochs=2, 27 | lr=1e-4, 28 | batch_size=2048, 29 | use_tf=False, 30 | optimizer="adam", 31 | ) 32 | model.fit(train_data, neg_sampling=True, verbose=2) 33 | return model 34 | -------------------------------------------------------------------------------- /tests/serving/test_knn_serving.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | import pytest 4 | 5 | from libreco.bases import CfBase 6 | from libserving.serialization import knn2redis, save_knn 7 | from tests.utils_data import SAVE_PATH 8 | 9 | 10 | @pytest.mark.parametrize("knn_model", ["UserCF", "ItemCF"], indirect=True) 11 | def test_knn_serving(knn_model, session, close_server): 12 | assert isinstance(knn_model, CfBase) 13 | save_knn(SAVE_PATH, knn_model, k=10) 14 | knn2redis(SAVE_PATH) 15 | 16 | subprocess.run( 17 | "sanic libserving.sanic_serving.knn_deploy:app --no-access-logs --single-process &", 18 | shell=True, 19 | check=True, 20 | ) 21 | # time.sleep(2) # wait for the server to start 22 | 23 | response = session.post( 24 | "http://localhost:8000/knn/recommend", json={"user": 1, "n_rec": 1}, timeout=0.5 25 | ) 26 | assert len(next(iter(response.json().values()))) == 1 27 | response = session.post( 28 | "http://localhost:8000/knn/recommend", 29 | json={"user": 33, "n_rec": 3}, 30 | timeout=0.5, 31 | ) 32 | assert len(next(iter(response.json().values()))) == 3 33 | -------------------------------------------------------------------------------- /tests/serving/test_online_serving.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | import pytest 4 | 5 | from libserving.serialization import online2redis, save_online 6 | from tests.utils_data import SAVE_PATH 7 | 8 | 9 | @pytest.mark.parametrize( 10 | "online_model", 11 | ["pure", "user_feat", "separate", "multi_sparse", "item_feat", "all"], 12 | indirect=True, 13 | ) 14 | def test_online_serving(online_model, session, close_server): 15 | save_online(SAVE_PATH, online_model, version=1) 16 | online2redis(SAVE_PATH) 17 | 18 | subprocess.run( 19 | "sanic libserving.sanic_serving.online_deploy:app --no-access-logs --single-process &", 20 | shell=True, 21 | check=True, 22 | ) 23 | 
subprocess.run("python tests/serving/mock_tf_server.py &", shell=True, check=True) 24 | # time.sleep(2) # wait for the server to start 25 | 26 | response = session.post( 27 | "http://localhost:8000/online/recommend", 28 | json={"user": 1, "n_rec": 1}, 29 | timeout=0.5, 30 | ) 31 | assert len(next(iter(response.json().values()))) == 1 32 | 33 | response = session.post( 34 | "http://localhost:8000/online/recommend", 35 | json={"user": "uuu", "n_rec": 3}, 36 | timeout=0.5, 37 | ) 38 | assert len(next(iter(response.json().values()))) == 3 39 | 40 | response = session.post( 41 | "http://localhost:8000/online/recommend", 42 | json={"user": 2, "n_rec": 3, "user_feats": {"sex": "male"}}, 43 | timeout=0.5, 44 | ) 45 | assert len(next(iter(response.json().values()))) == 3 46 | 47 | response = session.post( 48 | "http://localhost:8000/online/recommend", 49 | json={"user": 2, "n_rec": 3, "seq": [1, 2, 3, 10, 11, 11, 22, 1, 0, -1, 12, 1]}, 50 | timeout=0.5, 51 | ) 52 | assert len(next(iter(response.json().values()))) == 3 53 | 54 | response = session.post( 55 | "http://localhost:8000/online/recommend", 56 | json={ 57 | "user": "uuu", 58 | "n_rec": 30000, 59 | "user_feats": {"sex": "bb", "age": 1000, "occupation": "ooo", "ggg": "eee"}, 60 | "seq": [1, 2, 3, "??"], 61 | }, 62 | timeout=0.5, 63 | ) 64 | # noinspection PyUnresolvedReferences 65 | assert len(next(iter(response.json().values()))) == online_model.n_items 66 | 67 | response = session.post( 68 | "http://localhost:8000/online/recommend", 69 | json={ 70 | "user": "uuu", 71 | "n_rec": 300, 72 | "item_feats": {"sex": "bb", "age": 1000, "occupation": "ooo", "ggg": "eee"}, 73 | "item_seq": [1, 2, 3, "??"], 74 | }, 75 | timeout=0.5, 76 | ) 77 | assert "Invalid payload" in response.text 78 | -------------------------------------------------------------------------------- /tests/serving/test_tf_serving.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | import pytest 4 | 5 | from libreco.bases import TfBase 6 | from libserving.serialization import save_tf, tf2redis 7 | from tests.utils_data import SAVE_PATH 8 | 9 | 10 | @pytest.mark.parametrize( 11 | "tf_model", ["pure", "feat-all", "feat-user", "feat-item"], indirect=True 12 | ) 13 | def test_tf_serving(tf_model, session, close_server): 14 | assert isinstance(tf_model, TfBase) 15 | save_tf(SAVE_PATH, tf_model, version=1) 16 | tf2redis(SAVE_PATH) 17 | 18 | subprocess.run( 19 | "sanic libserving.sanic_serving.tf_deploy:app --no-access-logs --single-process &", 20 | shell=True, 21 | check=True, 22 | ) 23 | subprocess.run("python tests/serving/mock_tf_server.py &", shell=True, check=True) 24 | # time.sleep(2) # wait for the server to start 25 | 26 | response = session.post( 27 | "http://localhost:8000/tf/recommend", json={"user": 1, "n_rec": 1}, timeout=0.5 28 | ) 29 | assert len(next(iter(response.json().values()))) == 1 30 | response = session.post( 31 | "http://localhost:8000/tf/recommend", json={"user": 33, "n_rec": 3}, timeout=0.5 32 | ) 33 | assert len(next(iter(response.json().values()))) == 3 34 | -------------------------------------------------------------------------------- /tests/test_consumed.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import pytest 4 | 5 | from libreco.data.consumed import _fill_empty, _merge_dedup, interaction_consumed 6 | 7 | 8 | @pytest.mark.skipif( 9 | sys.version_info[:2] < (3, 7), 10 | reason="Rust implementation only supports Python >= 3.7.", 
11 | ) 12 | def test_remove_consecutive_duplicates(): 13 | user_indices = [1, 1, 1, 2, 2, 1, 2, 3, 2, 3] 14 | item_indices = [11, 11, 999, 0, 11, 11, 999, 11, 999, 0] 15 | user_consumed, item_consumed = interaction_consumed(user_indices, item_indices) 16 | assert isinstance(user_consumed, dict) 17 | assert isinstance(item_consumed, dict) 18 | assert isinstance(user_consumed[1], list) 19 | assert isinstance(item_consumed[11], list) 20 | assert user_consumed[1] == [11, 999, 11] 21 | assert user_consumed[2] == [0, 11, 999] 22 | assert user_consumed[3] == [11, 0] 23 | assert item_consumed[11] == [1, 2, 1, 3] 24 | assert item_consumed[999] == [1, 2] 25 | assert item_consumed[0] == [2, 3] 26 | 27 | 28 | @pytest.mark.skipif( 29 | sys.version_info[:2] >= (3, 7), 30 | reason="Specific python 3.6 implementation", 31 | ) 32 | def test_remove_duplicates(): 33 | user_indices = [1, 1, 1, 2, 2, 1, 2, 3, 2, 3] 34 | item_indices = [11, 11, 999, 0, 11, 11, 999, 11, 999, 0] 35 | user_consumed, item_consumed = interaction_consumed(user_indices, item_indices) 36 | assert isinstance(user_consumed, dict) 37 | assert isinstance(item_consumed, dict) 38 | assert isinstance(user_consumed[1], list) 39 | assert isinstance(item_consumed[11], list) 40 | assert user_consumed[1] == [11, 999] 41 | assert user_consumed[2] == [0, 11, 999] 42 | assert user_consumed[3] == [11, 0] 43 | assert item_consumed[11] == [1, 2, 3] 44 | assert item_consumed[999] == [1, 2] 45 | assert item_consumed[0] == [2, 3] 46 | 47 | 48 | def test_merge_remove_duplicates(): 49 | num = 3 50 | old_consumed = {0: [1, 2, 3], 1: [4, 5]} 51 | new_consumed = {0: [2, 1], 2: [7, 8]} 52 | consumed = _merge_dedup(new_consumed, num, old_consumed) 53 | assert consumed[0] == [1, 2, 3, 2, 1] 54 | assert consumed[1] == [4, 5] 55 | assert consumed[2] == [7, 8] 56 | 57 | 58 | def test_no_merge(): 59 | num = 4 60 | old_consumed = {0: [1, 2, 3], 1: [4, 5], 2: [0], 3: [99]} 61 | new_consumed = {0: [2, 1], 2: [7, 8]} 62 | consumed = _fill_empty(new_consumed, num, old_consumed) 63 | assert consumed[0] == [2, 1] 64 | assert consumed[1] == [4, 5] 65 | assert consumed[2] == [7, 8] 66 | assert consumed[3] == [99] 67 | -------------------------------------------------------------------------------- /tests/test_dgl.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import sys 3 | 4 | import pytest 5 | 6 | from libreco.graph import check_dgl 7 | 8 | 9 | def test_dgl(prepare_feat_data, monkeypatch): 10 | *_, data_info = prepare_feat_data 11 | 12 | with monkeypatch.context() as m: 13 | m.setitem(sys.modules, "dgl", None) 14 | with pytest.raises(ModuleNotFoundError): 15 | from libreco.algorithms import PinSageDGL 16 | 17 | _ = PinSageDGL("ranking", data_info) 18 | 19 | @check_dgl 20 | class ClsWithDGL: 21 | def __new__(cls, *args, **kwargs): 22 | if cls.dgl_error is not None: 23 | raise cls.dgl_error 24 | cls._dgl = importlib.import_module("dgl") 25 | return super().__new__(cls) 26 | 27 | model = ClsWithDGL() 28 | assert model.dgl_error is None 29 | assert model._dgl.__name__ == "dgl" 30 | -------------------------------------------------------------------------------- /tests/test_initializers.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | 3 | import numpy as np 4 | import pytest 5 | 6 | from libreco.utils.initializers import ( 7 | he_init, 8 | truncated_normal, 9 | variance_scaling, 10 | xavier_init, 11 | ) 12 | 13 | 14 | def test_initializers(): 15 | np_rng = 
np.random.default_rng(42) 16 | mean, std, fan_in, fan_out, scale = 0.1, 0.01, 4, 2, 2.5 17 | variables = truncated_normal(np_rng, [3, 2], mean=0.1, scale=0.01) 18 | assert variables.shape == (3, 2) 19 | variables_in_range(variables, mean, std) 20 | 21 | variables = xavier_init(np_rng, fan_in, fan_out) 22 | std = np.sqrt(2.0 / (fan_in + fan_out)) 23 | variables_in_range(variables, mean, std) 24 | 25 | variables = he_init(np_rng, fan_in, fan_out) 26 | std = 2.0 / np.sqrt(fan_in + fan_out) 27 | variables_in_range(variables, mean, std) 28 | 29 | variables = variance_scaling(np_rng, scale, fan_in, fan_out, mode="fan_in") 30 | std = np.sqrt(scale / fan_in) 31 | variables_in_range(variables, mean, std) 32 | 33 | variables = variance_scaling(np_rng, scale, fan_in, fan_out, mode="fan_out") 34 | std = np.sqrt(scale / fan_out) 35 | variables_in_range(variables, mean, std) 36 | 37 | variables = variance_scaling(np_rng, scale, fan_in, fan_out, mode="fan_average") 38 | std = np.sqrt(2.0 * scale / (fan_in + fan_out)) 39 | variables_in_range(variables, mean, std) 40 | 41 | with pytest.raises(ValueError): 42 | _ = variance_scaling(np_rng, scale, fan_in, fan_out, mode="unknown") 43 | 44 | 45 | def variables_in_range(variables, mean, std): 46 | for v in itertools.chain.from_iterable(variables): 47 | assert (mean - 3 * std) < v < (mean + 3 * std) 48 | -------------------------------------------------------------------------------- /tests/test_misc.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import pytest 4 | 5 | from libreco.utils.misc import colorize, time_block, time_func 6 | 7 | 8 | @time_func 9 | def long_work(): 10 | time.sleep(0.1) 11 | print(colorize("done!", color="red", bold=True, highlight=True)) 12 | 13 | 14 | def test_misc(): 15 | long_work() 16 | with time_block("long work2", verbose=0): 17 | time.sleep(0.1) 18 | with pytest.raises(RuntimeError): 19 | with time_block("long work2", verbose=0): 20 | raise RuntimeError 21 | -------------------------------------------------------------------------------- /tests/test_multi_sparse_processing.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pandas as pd 4 | import pytest 5 | 6 | from libreco.data import split_multi_value 7 | 8 | 9 | def test_multi_sparse_processing(): 10 | data_path = os.path.join( 11 | os.path.dirname(os.path.realpath(__file__)), 12 | "sample_data", 13 | "sample_movielens_genre.csv", 14 | ) 15 | data = pd.read_csv(data_path, sep=",", header=0) 16 | 17 | with pytest.raises(AssertionError): 18 | # max_len must be list or tuple 19 | split_multi_value(data, multi_value_col=["genre"], sep="|", max_len=3) 20 | 21 | sep = "," # wrong separator 22 | data, *_ = split_multi_value( 23 | data, 24 | multi_value_col=["genre"], 25 | sep=sep, 26 | max_len=[3], 27 | pad_val="missing", 28 | user_col=["sex", "age", "occupation"], 29 | item_col=["genre"], 30 | ) 31 | assert all(data["genre_2"].str.contains("missing")) 32 | assert all(data["genre_3"].str.contains("missing")) 33 | 34 | sep = "|" 35 | data = pd.read_csv(data_path, sep=",", header=0) 36 | data, multi_sparse_col, multi_user_col, multi_item_col = split_multi_value( 37 | data, 38 | multi_value_col=["genre"], 39 | sep=sep, 40 | max_len=[3], 41 | pad_val="missing", 42 | user_col=["sex", "age", "occupation"], 43 | item_col=["genre"], 44 | ) 45 | assert multi_sparse_col == [["genre_1", "genre_2", "genre_3"]] 46 | assert multi_user_col == [] 47 | assert multi_item_col == 
["genre_1", "genre_2", "genre_3"] 48 | all_columns = data.columns.tolist() 49 | assert "genre" not in all_columns 50 | assert all_columns == [ 51 | "user", 52 | "item", 53 | "label", 54 | "time", 55 | "sex", 56 | "age", 57 | "occupation", 58 | "genre_1", 59 | "genre_2", 60 | "genre_3", 61 | ] 62 | -------------------------------------------------------------------------------- /tests/utils_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | 7 | from libreco.data import TransformedEvalSet, TransformedSet 8 | 9 | SAVE_PATH = os.path.join(str(Path(os.path.realpath(__file__)).parent), "save_path") 10 | 11 | 12 | def remove_path(path): 13 | if os.path.exists(path) and os.path.isdir(path): 14 | shutil.rmtree(path) 15 | 16 | 17 | def set_ranking_labels(data): 18 | if isinstance(data, TransformedSet): 19 | original_labels = data._labels.copy() 20 | data._labels[original_labels >= 4] = 1 21 | data._labels[original_labels < 4] = 0 22 | elif isinstance(data, TransformedEvalSet): 23 | original_labels = np.copy(data.labels) 24 | data.labels[original_labels >= 4] = 1 25 | data.labels[original_labels < 4] = 0 26 | -------------------------------------------------------------------------------- /tests/utils_metrics.py: -------------------------------------------------------------------------------- 1 | def get_metrics(task): 2 | if task == "rating": 3 | return ["rmse", "mae", "r2"] 4 | else: 5 | return [ 6 | "loss", 7 | "balanced_accuracy", 8 | "roc_auc", 9 | "roc_gauc", 10 | "pr_auc", 11 | "precision", 12 | "recall", 13 | "map", 14 | "ndcg", 15 | "coverage", 16 | ] 17 | -------------------------------------------------------------------------------- /tests/utils_multi_sparse_models.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from libreco.bases import TfBase 4 | 5 | 6 | def fit_multi_sparse(cls, train_data, eval_data, data_info, lr=None): 7 | if issubclass(cls, TfBase): 8 | tf.compat.v1.reset_default_graph() 9 | 10 | model = cls( 11 | task="ranking", 12 | data_info=data_info, 13 | loss_type="cross_entropy", 14 | embed_size=4, 15 | n_epochs=1, 16 | lr=1e-4 if not lr else lr, 17 | batch_size=100, 18 | ) 19 | model.fit( 20 | train_data, 21 | neg_sampling=True, 22 | verbose=2, 23 | shuffle=True, 24 | eval_data=eval_data, 25 | metrics=["roc_auc", "precision", "map", "ndcg"], 26 | eval_user_num=40, 27 | ) 28 | return model 29 | -------------------------------------------------------------------------------- /tests/utils_pred.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from libreco.prediction import predict_data_with_feats 4 | 5 | 6 | def ptest_preds(model, task, pd_data, with_feats): 7 | user = pd_data.user.iloc[0] 8 | item = pd_data.item.iloc[0] 9 | pred = model.predict(user=user, item=item) 10 | # prediction in range 11 | if task == "rating": 12 | assert 1 <= pred <= 5 13 | else: 14 | assert 0 <= pred <= 1 15 | 16 | popular_pred = model.predict( 17 | user="cold user2", item="cold item2", cold_start="popular" 18 | ) 19 | assert np.allclose(popular_pred, model.default_pred) 20 | 21 | cold_pred1 = model.predict(user="cold user1", item="cold item2") 22 | cold_pred2 = model.predict(user="cold user2", item="cold item2") 23 | assert cold_pred1 == cold_pred2 24 | 25 | if with_feats: 26 | assert len(predict_data_with_feats(model, pd_data[:5])) 
== 5 27 | model.predict(user=user, item=item, feats={"sex": "male", "genre_1": "crime"}) 28 | -------------------------------------------------------------------------------- /tests/utils_save_load.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from libreco.bases import TfBase 4 | from libreco.data import DataInfo 5 | from tests.utils_data import SAVE_PATH 6 | 7 | 8 | def save_load_model(cls, model, data_info): 9 | model_name = cls.__name__.lower() + "_model" 10 | data_info.save(path=SAVE_PATH, model_name=model_name) 11 | model.save(SAVE_PATH, model_name, manual=True, inference_only=True) 12 | 13 | if issubclass(cls, TfBase) or hasattr(model, "sess"): 14 | tf.compat.v1.reset_default_graph() 15 | loaded_data_info = DataInfo.load(path=SAVE_PATH, model_name=model_name) 16 | loaded_model = cls.load(SAVE_PATH, model_name, loaded_data_info, manual=True) 17 | return loaded_model, loaded_data_info 18 | --------------------------------------------------------------------------------
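The helpers in tests/utils_data.py, tests/utils_pred.py and tests/utils_save_load.py above all revolve around the same train → evaluate → save → reload cycle that tests/compatibility_test.py runs end to end. The following is a minimal sketch of that cycle outside the test suite, assuming the bundled sample MovieLens ratings and illustrative hyperparameters; the save path and model name are placeholders.

import pandas as pd
import tensorflow as tf

from libreco.algorithms import RNN4Rec
from libreco.data import DataInfo, DatasetPure, split_by_ratio_chrono

# build train/eval sets from the sample ratings shipped with the repo
data = pd.read_csv(
    "examples/sample_data/sample_movielens_rating.dat",
    sep="::",
    names=["user", "item", "label", "time"],
)
train, evals = split_by_ratio_chrono(data, test_size=0.2)
train, data_info = DatasetPure.build_trainset(train)
evals = DatasetPure.build_evalset(evals)

model = RNN4Rec("ranking", data_info, loss_type="cross_entropy", n_epochs=1)
model.fit(train, neg_sampling=True, eval_data=evals, metrics=["roc_auc"])

# save/load round trip, mirroring tests/utils_save_load.py
data_info.save(path="save_path", model_name="rnn4rec_model")
model.save("save_path", "rnn4rec_model", manual=True, inference_only=True)
tf.compat.v1.reset_default_graph()  # RNN4Rec is a TensorFlow-based model
loaded_info = DataInfo.load(path="save_path", model_name="rnn4rec_model")
loaded = RNN4Rec.load("save_path", "rnn4rec_model", loaded_info, manual=True)
print(loaded.recommend_user(user=1, n_rec=7))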