├── .editorconfig ├── .github ├── ISSUE_TEMPLATE │ ├── bug.yml │ ├── config.yml │ ├── feature_request.yml │ └── question.yml ├── PULL_REQUEST_TEMPLATE.md └── workflows │ ├── publish.yml │ ├── publish_dev.yml │ └── test.yml ├── .gitignore ├── .gitmodules ├── .pylintrc ├── .readthedocs.yml ├── CHANGELOG.md ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.rst ├── LICENSE ├── Makefile ├── README.md ├── SECURITY.md ├── bandit.yml ├── benchmark ├── __init__.py └── requirements.txt ├── datasets └── KION │ └── README.md ├── docker └── Dockerfile ├── docs ├── Makefile ├── make.bat ├── requirements.txt └── source │ ├── _static │ └── theme.css │ ├── _templates │ ├── custom-base-template.rst │ ├── custom-class-template.rst │ └── custom-module-template.rst │ ├── api.rst │ ├── conf.py │ ├── dataset.rst │ ├── examples.rst │ ├── faq.rst │ ├── features.rst │ ├── index.rst │ ├── metrics.rst │ ├── model_selection.rst │ ├── models.rst │ ├── support.rst │ ├── tools.rst │ ├── tutorials.rst │ └── visuals.rst ├── examples ├── 1_simple_example.ipynb ├── 2_cross_validation.ipynb ├── 3_metrics.ipynb ├── 4_dataset_with_features.ipynb ├── 5_benchmark_iALS_with_features.ipynb ├── 6_benchmark_lightfm_inference.ipynb ├── 7_visualization.ipynb ├── 8_debiased_metrics.ipynb ├── 9_model_configs_and_saving.ipynb └── tutorials │ ├── baselines_extended_tutorial.ipynb │ ├── transformers_advanced_training_guide.ipynb │ ├── transformers_customization_guide.ipynb │ └── transformers_tutorial.ipynb ├── poetry.lock ├── poetry.toml ├── pyproject.toml ├── rectools ├── __init__.py ├── columns.py ├── compat.py ├── dataset │ ├── __init__.py │ ├── dataset.py │ ├── features.py │ ├── identifiers.py │ ├── interactions.py │ └── torch_datasets.py ├── exceptions.py ├── metrics │ ├── __init__.py │ ├── auc.py │ ├── base.py │ ├── catalog.py │ ├── classification.py │ ├── debias.py │ ├── distances.py │ ├── diversity.py │ ├── dq.py │ ├── intersection.py │ ├── novelty.py │ ├── popularity.py │ ├── ranking.py │ ├── scoring.py │ └── 
serendipity.py ├── model_selection │ ├── __init__.py │ ├── cross_validate.py │ ├── last_n_split.py │ ├── random_split.py │ ├── splitter.py │ ├── time_split.py │ └── utils.py ├── models │ ├── __init__.py │ ├── base.py │ ├── ease.py │ ├── implicit_als.py │ ├── implicit_bpr.py │ ├── implicit_knn.py │ ├── lightfm.py │ ├── nn │ │ ├── __init__.py │ │ ├── dssm.py │ │ ├── item_net.py │ │ └── transformers │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── bert4rec.py │ │ │ ├── constants.py │ │ │ ├── data_preparator.py │ │ │ ├── lightning.py │ │ │ ├── negative_sampler.py │ │ │ ├── net_blocks.py │ │ │ ├── sasrec.py │ │ │ ├── similarity.py │ │ │ └── torch_backbone.py │ ├── popular.py │ ├── popular_in_category.py │ ├── pure_svd.py │ ├── random.py │ ├── rank │ │ ├── __init__.py │ │ ├── compat.py │ │ ├── rank.py │ │ ├── rank_implicit.py │ │ └── rank_torch.py │ ├── serialization.py │ ├── utils.py │ └── vector.py ├── tools │ ├── __init__.py │ └── ann.py ├── types.py ├── utils │ ├── __init__.py │ ├── array_set_ops.py │ ├── config.py │ ├── indexing.py │ ├── misc.py │ └── serialization.py ├── version.py └── visuals │ ├── __init__.py │ ├── metrics_app.py │ └── visual_app.py ├── scripts └── copyright.py ├── setup.cfg └── tests ├── __init__.py ├── dataset ├── __init__.py ├── test_dataset.py ├── test_features.py ├── test_identifiers.py ├── test_interactions.py └── test_torch_dataset.py ├── metrics ├── __init__.py ├── test_auc.py ├── test_base.py ├── test_catalog.py ├── test_classification.py ├── test_debias.py ├── test_distances.py ├── test_diversity.py ├── test_dq.py ├── test_intersection.py ├── test_novelty.py ├── test_popularity.py ├── test_ranking.py ├── test_scoring.py └── test_serendipity.py ├── model_selection ├── __init__.py ├── test_cross_validate.py ├── test_last_n_split.py ├── test_random_split.py ├── test_splitter.py ├── test_time_split.py └── test_utils.py ├── models ├── __init__.py ├── data.py ├── nn │ ├── __init__.py │ ├── test_dssm.py │ ├── test_item_net.py │ └── 
transformers │ │ ├── __init__.py │ │ ├── test_base.py │ │ ├── test_bert4rec.py │ │ ├── test_data_preparator.py │ │ ├── test_sasrec.py │ │ └── utils.py ├── rank │ ├── __init__.py │ ├── test_rank.py │ ├── test_rank_implicit.py │ └── test_rank_torch.py ├── test_base.py ├── test_ease.py ├── test_implicit_als.py ├── test_implicit_bpr.py ├── test_implicit_knn.py ├── test_lightfm.py ├── test_popular.py ├── test_popular_in_category.py ├── test_pure_svd.py ├── test_random.py ├── test_serialization.py ├── test_utils.py ├── test_vector.py └── utils.py ├── test_compat.py ├── testing_utils.py ├── tools ├── __init__.py └── test_ann.py ├── utils ├── __init__.py ├── test_array_set_ops.py ├── test_indexing.py └── test_misc.py └── visuals ├── __init__.py ├── test_metrics_app.py └── test_visual_app.py /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*] 4 | charset = utf-8 5 | end_of_line = lf 6 | indent_size = 4 7 | indent_style = space 8 | insert_final_newline = true 9 | trim_trailing_whitespace = true 10 | max_line_length = 120 11 | tab_width = 4 12 | 13 | [{*.yml, *.yaml, *.json, *.xml}] 14 | indent_size = 2 15 | 16 | [Makefile] 17 | indent_style = tab 18 | 19 | [*.md] 20 | trim_trailing_whitespace = false 21 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug.yml: -------------------------------------------------------------------------------- 1 | name: Bug report 2 | description: Create a report to help us improve 3 | labels: [bug, good first issue] 4 | 5 | body: 6 | - type: markdown 7 | attributes: 8 | value: 9 | Thanks for taking the time to report a bug! Before you start, please check if someone else has already reported the same issue. 10 | 11 | - type: textarea 12 | id: what-happened 13 | attributes: 14 | label: What happened? 15 | description: Tell us what happened. 16 | placeholder: A clear and concise description of what the bug is. 
17 | validations: 18 | required: true 19 | 20 | - type: textarea 21 | id: what-you-expected 22 | attributes: 23 | label: Expected behavior 24 | description: What did you expect to happen? 25 | validations: 26 | required: false 27 | 28 | - type: textarea 29 | id: logs 30 | attributes: 31 | label: Relevant logs and/or screenshots 32 | description: Please provide any relevant logs or screenshots. 33 | validations: 34 | required: false 35 | 36 | - type: input 37 | id: os-version 38 | attributes: 39 | label: Operating System 40 | description: What operating system are you using? 41 | validations: 42 | required: true 43 | 44 | - type: input 45 | id: python-version 46 | attributes: 47 | label: Python Version 48 | description: What version of Python are you using? 49 | validations: 50 | required: true 51 | 52 | - type: textarea 53 | id: rectools-version 54 | attributes: 55 | label: RecTools version 56 | description: What version of RecTools are you using? 57 | validations: 58 | required: true 59 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | contact_links: 3 | - name: RecTools Documentation 4 | url: https://rectools.readthedocs.io/ 5 | about: Please check the documentation before asking a question. 6 | - name: RecTools Telegram Channel 7 | url: https://t.me/RecTools_Support 8 | about: Feel free to ask questions in the telegram channel. 
9 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.yml: -------------------------------------------------------------------------------- 1 | name: Feature request 2 | description: Suggest a new feature or enhancement 3 | labels: [feature, optimization, good first issue] 4 | 5 | body: 6 | - type: markdown 7 | attributes: 8 | value: | 9 | Thanks for taking the time to suggest a new feature! Before you start, please check if someone else has already suggested the same feature. 10 | 11 | - type: textarea 12 | id: feature-description 13 | attributes: 14 | label: Feature Description 15 | description: Tell us about the feature you would like to see. 16 | placeholder: A clear and concise description of what you want to happen. 17 | validations: 18 | required: true 19 | 20 | - type: textarea 21 | id: why-this-feature 22 | attributes: 23 | label: Why this feature? 24 | description: Why do you want this feature? How will it help you? 25 | validations: 26 | required: true 27 | 28 | - type: textarea 29 | id: additional-context 30 | attributes: 31 | label: Additional context 32 | description: Any other context or screenshots about the feature request. 33 | validations: 34 | required: false 35 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/question.yml: -------------------------------------------------------------------------------- 1 | name: Question 2 | description: Ask a question about this project 3 | labels: [question] 4 | 5 | body: 6 | - type: markdown 7 | attributes: 8 | value: | 9 | Thanks for taking the time to ask a question! Before you start, please check if someone else has already asked the same question. 10 | 11 | - type: textarea 12 | id: question 13 | attributes: 14 | label: Your Question 15 | description: Please write your question here. 16 | placeholder: Write your question here. 
17 | validations: 18 | required: true 19 | 20 | - type: input 21 | id: os-version 22 | attributes: 23 | label: Operating System 24 | description: What operating system are you using? 25 | validations: 26 | required: false 27 | 28 | - type: input 29 | id: python-version 30 | attributes: 31 | label: Python Version 32 | description: What version of Python are you using? 33 | validations: 34 | required: false 35 | 36 | - type: textarea 37 | id: rectools-version 38 | attributes: 39 | label: RecTools version 40 | description: What version of RecTools are you using? 41 | validations: 42 | required: false 43 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | 9 | 10 | ## Description 11 | 12 | 16 | 17 | ## Type of change 18 | 19 | 22 | 23 | - [ ] Bug fix (non-breaking change which fixes an issue) 24 | - [ ] New feature (non-breaking change which adds functionality) 25 | - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) 26 | - [ ] Optimization 27 | 28 | ## How Has This Been Tested? 29 | Before submitting a PR, please check yourself against the following list. It would save us quite a lot of time. 30 | - Have you read the contribution guide? 31 | - Have you updated the relevant docstrings? We're using Numpy format, please double-check yourself 32 | - Does your change require any new tests? 33 | - Have you updated the changelog file? 
34 | 35 | 39 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish 2 | 3 | on: 4 | release: 5 | types: 6 | - created 7 | 8 | jobs: 9 | 10 | publish: 11 | runs-on: ubuntu-22.04 12 | permissions: 13 | id-token: write 14 | environment: 15 | name: production 16 | url: https://pypi.org/p/rectools 17 | 18 | steps: 19 | - name: Dump GitHub context 20 | env: 21 | GITHUB_CONTEXT: ${{ toJson(github) }} 22 | run: echo "$GITHUB_CONTEXT" 23 | 24 | - uses: actions/checkout@v4 25 | 26 | - name: Set up Python 27 | uses: actions/setup-python@v5 28 | with: 29 | python-version: "3.10" 30 | 31 | - name: Install poetry 32 | run: pip install urllib3==1.26.15 poetry==1.8.3 33 | 34 | - name: Install Dependencies 35 | run: poetry install --no-dev 36 | 37 | - name: Build 38 | run: poetry build 39 | 40 | - name: Publish 41 | uses: pypa/gh-action-pypi-publish@release/v1 42 | -------------------------------------------------------------------------------- /.github/workflows/publish_dev.yml: -------------------------------------------------------------------------------- 1 | name: Publish to test PyPI 2 | 3 | on: 4 | workflow_dispatch: 5 | 6 | jobs: 7 | 8 | publish: 9 | runs-on: ubuntu-22.04 10 | permissions: 11 | id-token: write 12 | environment: 13 | name: development 14 | url: https://test.pypi.org/p/rectools 15 | 16 | steps: 17 | - name: Dump GitHub context 18 | env: 19 | GITHUB_CONTEXT: ${{ toJson(github) }} 20 | run: echo "$GITHUB_CONTEXT" 21 | 22 | - uses: actions/checkout@v4 23 | 24 | - name: Set up Python 25 | uses: actions/setup-python@v5 26 | with: 27 | python-version: "3.10" 28 | 29 | - name: Install poetry 30 | run: pip install urllib3==1.26.15 poetry==1.4.0 31 | 32 | - name: Install Dependencies 33 | run: poetry install --no-dev 34 | 35 | - name: Build 36 | run: poetry build 37 | 38 | - name: Publish 39 | uses: 
pypa/gh-action-pypi-publish@release/v1 40 | with: 41 | repository-url: https://test.pypi.org/legacy/ 42 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | pull_request: 8 | types: [opened, synchronize] 9 | 10 | 11 | jobs: 12 | lint: 13 | runs-on: ubuntu-22.04 14 | 15 | steps: 16 | - uses: actions/checkout@v4 17 | 18 | - name: Set up Python 19 | uses: actions/setup-python@v5 20 | with: 21 | python-version: "3.10" 22 | 23 | - name: Install poetry 24 | run: pip install urllib3==1.26.15 poetry==1.8.5 25 | 26 | - name: Load cached venv 27 | id: cached-poetry-dependencies 28 | uses: actions/cache@v4 29 | with: 30 | path: .venv 31 | key: venv-lint-${{ runner.os }}-${{ hashFiles('**/poetry.lock') }} 32 | 33 | - name: Install dependencies 34 | if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' 35 | run: make install 36 | 37 | - name: Static analysis 38 | run: make lint 39 | 40 | test: 41 | name: test ${{ matrix.python-version }} 42 | runs-on: ubuntu-22.04 43 | strategy: 44 | fail-fast: false 45 | matrix: 46 | python-version: [ "3.9", "3.10", "3.11", "3.12", "3.13" ] 47 | 48 | steps: 49 | - uses: actions/checkout@v4 50 | 51 | - name: Set up Python 52 | uses: actions/setup-python@v5 53 | with: 54 | python-version: ${{ matrix.python-version }} 55 | 56 | - name: Install poetry 57 | run: pip install urllib3==1.26.15 poetry==1.8.5 58 | 59 | - name: Load cached venv 60 | id: cached-poetry-dependencies 61 | uses: actions/cache@v4 62 | with: 63 | path: .venv 64 | key: venv-test-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }} 65 | 66 | - name: Install dependencies 67 | if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' 68 | run: make install 69 | 70 | - name: Run tests 71 | run: make test 72 | 73 | - name: Upload 
coverage 74 | if: matrix.python-version == '3.9' && ! startsWith(github.base_ref, 'experimental/') 75 | uses: codecov/codecov-action@v4 76 | with: 77 | fail_ci_if_error: true 78 | files: ./coverage.xml 79 | token: ${{ secrets.CODECOV_TOKEN }} 80 | verbose: true 81 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # MacOS 2 | .DS_Store 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # Unit test / coverage reports 33 | htmlcov/ 34 | .tox/ 35 | .nox/ 36 | .coverage 37 | .coverage.* 38 | .coverage_report 39 | .cache 40 | nosetests.xml 41 | coverage.xml 42 | *.cover 43 | *.py,cover 44 | .hypothesis/ 45 | .pytest_cache/ 46 | cover/ 47 | .reports/ 48 | 49 | # PyBuilder 50 | .pybuilder/ 51 | target/ 52 | 53 | # Sphinx documentation 54 | docs/build/ 55 | docs/source/api/ 56 | 57 | # Jupyter checkpoints 58 | .ipynb_checkpoints 59 | 60 | # IPython 61 | profile_default/ 62 | ipython_config.py 63 | 64 | # Environments 65 | .env 66 | .venv 67 | env/ 68 | venv/ 69 | ENV/ 70 | env.bak/ 71 | venv.bak/ 72 | 73 | # mypy 74 | .mypy_cache/ 75 | .dmypy.json 76 | dmypy.json 77 | 78 | # Pycharm 79 | .idea/ 80 | 81 | # VS code 82 | settings.json 83 | .vscode/ 84 | .vs/ 85 | 86 | # Pytorch-lightning 87 | lightning_logs 88 | 89 | # benchmarks 90 | benchmark_results/ 91 | 92 | # Data 93 | *.zip 94 | *.csv 95 | *.dat 96 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 
1 | [submodule "datasets/KION/data"] 2 | path = datasets/KION/data 3 | url = https://github.com/irsafilo/KION_DATASET.git 4 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | build: 4 | os: ubuntu-22.04 5 | tools: 6 | python: "3.12" 7 | jobs: 8 | pre_build: 9 | - cp -r examples docs/source/ 10 | post_create_environment: 11 | - pip install poetry 12 | post_install: 13 | - VIRTUAL_ENV=$READTHEDOCS_VIRTUALENV_PATH poetry install -E all --no-root --with docs 14 | 15 | sphinx: 16 | builder: html 17 | configuration: docs/source/conf.py 18 | fail_on_warning: false 19 | 20 | formats: 21 | - pdf 22 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Conduct 4 | 5 | * We are committed to providing a friendly and welcoming environment for all 6 | * Be kind and respectful, no need to be rude or mean 7 | * Any design decision or way of implementing carries a trade-off. People have different opinions, and that's ok. There is seldom a right answer. Please respect that 8 | * Keep your critique structured. If you still wish to be unstructured, keep it to a minimum 9 | * Insulters will be excluded from interactions 10 | * Harassment - public or private - is unacceptable, as well as trolling, flaming and baiting. That's all there is to it. No matter who you are 11 | 12 | ## Responsibilities 13 | 14 | Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior.
15 | 16 | Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful. 17 | 18 | ## Scope 19 | 20 | This Code of Conduct applies both within project spaces and in public spaces when an individual is representing the project or its community. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers. 21 | -------------------------------------------------------------------------------- /CONTRIBUTING.rst: -------------------------------------------------------------------------------- 1 | Contributing Guide 2 | ================== 3 | 4 | Welcome! There are many ways to contribute, including submitting bug 5 | reports, improving documentation, submitting feature requests, reviewing 6 | new submissions, or contributing code that can be incorporated into the 7 | project. 8 | 9 | For support questions please use `Telegram Channel `_ 10 | or open an issue of type `Question` 11 | 12 | Feature Requests 13 | ---------------- 14 | 15 | Please create a new GitHub issue for any significant changes and 16 | enhancements that you wish to make. Provide the feature you would like 17 | to see, why you need it, and how it will work. Discuss your ideas 18 | transparently and get community feedback before proceeding. 19 | 20 | Significant changes that you wish to contribute to the project should be 21 | discussed first in a GitHub issue that clearly outlines the changes and 22 | benefits of the feature. 
23 | 24 | Small Changes can directly be crafted and submitted to the GitHub 25 | Repository as a Pull Request. 26 | 27 | Pull Request Process 28 | -------------------- 29 | 30 | #. Fork RecTools `main repository `_ 31 | on GitHub. See `this guide `_ if you have questions. 32 | #. Clone your fork from GitHub to your local disk. 33 | #. Create a virtual environment and install dependencies including all 34 | extras and development dependencies. 35 | 36 | #. Make sure you have ``python>=3.9`` and ``poetry>=1.5.0`` installed 37 | #. Deactivate any active virtual environments. Deactivate conda ``base`` 38 | environment if applicable 39 | #. Run ``make install`` command which will create a virtual env and 40 | install everything that is needed with poetry. See `poetry usage details `_ 41 | 42 | #. Implement your changes. Check the following after you are done: 43 | 44 | #. Docstrings. Please ensure that all new public classes and methods 45 | have docstrings. We use numpy style. Also check that examples are 46 | provided 47 | #. Code styling. Autoformat with ``make format`` 48 | #. Linters. Run checks with ``make lint`` 49 | #. Tests. Make sure you've covered new features with tests. Check 50 | code correctness with ``make test`` 51 | #. Coverage. Check with ``make coverage`` 52 | #. Changelog. Please describe your changes in `CHANGELOG.MD `_ 53 | 54 | #. Create a pull request from your fork. See `instructions `_ 55 | 56 | 57 | You may merge the Pull Request in once you have the approval of one 58 | of the core developers, or if you do not have permission to do that, you 59 | may request a reviewer to merge it for you. 60 | 61 | Review Process 62 | -------------- 63 | 64 | We keep pull requests open for a few days for multiple people to have 65 | the chance to review/comment. 66 | 67 | After feedback has been given, we expect responses within two weeks. 68 | After two weeks, we may close the pull request if it isn't showing any 69 | activity.
70 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | VENV=.venv 2 | REPORTS=.reports 3 | 4 | BENCHMARK=benchmark 5 | SOURCES=rectools 6 | TESTS=tests 7 | SCRIPTS=scripts 8 | 9 | 10 | 11 | # Installation 12 | 13 | .reports: 14 | mkdir ${REPORTS} 15 | 16 | .venv: 17 | poetry install -E all --no-root 18 | 19 | install: .venv .reports 20 | 21 | 22 | # Linters 23 | 24 | .isort: 25 | poetry run isort --check ${SOURCES} ${TESTS} ${SCRIPTS} ${BENCHMARK} 26 | 27 | .black: 28 | poetry run black --check --diff ${SOURCES} ${TESTS} ${SCRIPTS} ${BENCHMARK} 29 | 30 | .pylint: 31 | poetry run pylint --jobs 4 ${SOURCES} ${TESTS} ${SCRIPTS} ${BENCHMARK} 32 | 33 | .mypy: 34 | poetry run mypy ${SOURCES} ${TESTS} ${SCRIPTS} ${BENCHMARK} 35 | 36 | .flake8: 37 | poetry run flake8 ${SOURCES} ${TESTS} ${SCRIPTS} ${BENCHMARK} 38 | 39 | .bandit: 40 | poetry run bandit -q -c bandit.yml -r ${SOURCES} ${TESTS} ${SCRIPTS} ${BENCHMARK} 41 | 42 | .codespell: 43 | poetry run codespell ${SOURCES} ${TESTS} ${SCRIPTS} ${BENCHMARK} 44 | 45 | 46 | # Fixers & formatters 47 | 48 | .isort_fix: 49 | poetry run isort ${SOURCES} ${TESTS} ${SCRIPTS} ${BENCHMARK} 50 | 51 | .autopep8_fix: 52 | poetry run autopep8 --in-place -r ${SOURCES} ${TESTS} ${SCRIPTS} ${BENCHMARK} 53 | 54 | .black_fix: 55 | poetry run black -q ${SOURCES} ${TESTS} ${SCRIPTS} ${BENCHMARK} 56 | 57 | 58 | # Tests 59 | 60 | .pytest: 61 | poetry run pytest ${TESTS} --cov=${SOURCES} --cov-report=xml 62 | 63 | .doctest: 64 | poetry run pytest --doctest-modules ${SOURCES} --ignore=rectools/tools/ann.py 65 | 66 | coverage: .venv .reports 67 | poetry run coverage run --source ${SOURCES} --module pytest 68 | poetry run coverage report 69 | poetry run coverage html -d ${REPORTS}/coverage_html 70 | poetry run coverage xml -o ${REPORTS}/coverage.xml -i 71 | 72 | 73 | # Generalization 74 | 75 | .format: .isort_fix .autopep8_fix 
.black_fix 76 | format: .venv .format 77 | 78 | .lint: .isort .black .flake8 .codespell .mypy .pylint .bandit 79 | lint: .venv .lint 80 | 81 | .test: .pytest .doctest 82 | test: .venv .test 83 | 84 | 85 | # Copyright 86 | 87 | copyright: 88 | poetry run python -m scripts.copyright --check ${SOURCES} ${TESTS} ${SCRIPTS} ${BENCHMARK} 89 | 90 | copyright_fix: 91 | poetry run python -m scripts.copyright ${SOURCES} ${TESTS} ${SCRIPTS} ${BENCHMARK} 92 | 93 | 94 | # Cleaning 95 | 96 | clean: 97 | rm -rf build dist .eggs *.egg-info 98 | rm -rf ${VENV} 99 | rm -rf ${REPORTS} 100 | find . -type d -name '.mypy_cache' -exec rm -rf {} + 101 | find . -type d -name '*pytest_cache*' -exec rm -rf {} + 102 | 103 | reinstall: clean install 104 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | Security Policy 2 | ========== 3 | 4 | **Supported Python versions** 5 | 6 | 3.9 or above 7 | 8 | **Product development security recommendations** 9 | 10 | Update dependencies to last stable version. 11 | Build SBOM for the project. 12 | Perform SAST (Static Application Security Testing) where possible. 13 | 14 | **Product development security requirements** 15 | 16 | No binaries in repository. 17 | No passwords, keys, access tokens in source code. 18 | No "Critical" and/or "High" vulnerabilities in contributed source code. 19 | https://en.wikipedia.org/wiki/Common_Vulnerability_Scoring_System 20 | 21 | **Vulnerability reports** 22 | 23 | Please, use email rectools-team@mts.ru for reporting security issues or anything that can cause any 24 | consequences for security. Please avoid any public disclosure (including registering issues) at least until it is fixed. Thank you in advance for understanding. 
25 | -------------------------------------------------------------------------------- /bandit.yml: -------------------------------------------------------------------------------- 1 | skips: ['B101', 'B301', 'B403'] 2 | -------------------------------------------------------------------------------- /benchmark/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022-2024 MTS (Mobile Telesystems) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Benchmarks for RecTools models 17 | =================================================== 18 | 19 | Benchmark RecTools models on a dataset chosen. 20 | One can choose either a built-in dataset or a custom dataset. 21 | Benchmark calculation includes fitting a model, evaluating recommender metrics and profiling 22 | time and memory required for calculations. 
23 | 24 | Subpackages 25 | ----------- 26 | src - Tools for benchmark calculation 27 | """ 28 | -------------------------------------------------------------------------------- /benchmark/requirements.txt: -------------------------------------------------------------------------------- 1 | memory-profiler>=0.55.0 2 | pyyaml>=5.1 3 | requests>=2.20.0 4 | types-requests>=2.27.10 5 | types-PyYAML>=6.0.4 6 | -------------------------------------------------------------------------------- /datasets/KION/README.md: -------------------------------------------------------------------------------- 1 | # MTS Kion Implicit Contextualised Sequential Dataset for Movie Recommendation 2 | 3 | ## Dataset Description 4 | This is an official repository of the Kion Movies Recommendation Dataset. 5 | The data was gathered from the users of [MTS Kion](https://kion.ru/home) video streaming platform from 13.03.2021 to 22.08.2022. 6 | It is a sample of anonymized data with added random noise. 7 | 8 | The public part of the dataset includes 5,476,251 interactions of 962,179 users with 15,706 items. The dataset includes user_item interactions and their characteristics (such as temporal information, watch duration and watch percentage), as well as user demographics and rich movies meta-information. 9 | 10 | The *private* part of the dataset contains movies that the users watched within one week following the period covered by the public dataset. It is not released to the general public; however, there is a public sandbox where researchers can measure the MAP@10 metric on the private part of the data. The sandbox is accessible at https://ods.ai/competitions/competition-recsys-21/leaderboard/public_sandbox. 11 | 12 | To make a submission, you need to use the [sample_submission.csv](https://github.com/irsafilo/KION_DATASET/blob/main/sample_submission.csv) file, and replace the sample item ids with the ids of the recommended items according to your recommendation model.
13 | 14 | 15 | ### The dataset consists of three parts: 16 | 1. **Interactions.csv** - contains user-item implicit interactions, watch percentages, watch durations 17 | 2. **Users.cvs** - contains users demographics information (sex, age band, income level band, kids flag) 18 | 3. **Items.cvs** - contains items meta-information (title, original title, year, genres, keywords, descriptions, countries, studios, actors, directors) 19 | 20 | #### The users and items files have two versions: 21 | 22 | * **data_original** - original meta-information in Russian language 23 | * **data_en** - english version of the metadata translated with Facebook FAIR’s WMT19 Ru->En machine translation model. 24 | 25 | ## Comparison with MovieLens-25M and Netflix datasets 26 | 27 | ### Quantitative comparison: 28 | | | **Netflix** |**Movielens-25M** | **Kion** | 29 | |------------------------------|-------------|------------------|--------------------| 30 | | Users | 480,189 | 162,541 | 962,179 | 31 | | Items | 17,770 | 59,047 | 15,706 | 32 | | Interactions | 100,480,507 | 25,000,095 | 5,476,251 | 33 | | Avg. Sequence Length | 209.25 | 153.80 | 5.69 | 34 | | Sparsity | 98.82% | 99.73% | 99.9% | 35 | 36 | 37 | ### Qualitative comparison: 38 | | **Dataset Name** | **Netflix**. | **Movielens-25M** | **Kion** | 39 | |----------------------------------------|---------------------|------------------------------------------|---------------------------------| 40 | | Type | Explicit (Ratings) | Explicit (Rating) | Implicit (Interactions) | 41 | | Interaction registration time. 
| After watching | After watching | At watching | 42 | | Interaction features | Date, Rating | Date, Rating | Date, Duration, Watched Percent | 43 | | User features | None | None | Age, Income, Kids | 44 | | Item features | Release Year, Title |Release Year, Title, Genres, Tags | Content Type, Title, Original Title, Release Year, Genres, Countries, For Kids, Age Rating, Studios, Directors, Actors, Description, Keyword | 45 | 46 | # Kion challenge 47 | This dataset was used for the Kion challenge recommendation contest [ (Official website in Russian Language)](https://ods.ai/competitions/competition-recsys-21). 48 | 49 | This table contains results of the winners of the competition, measured on the private part of the dataset: 50 | 51 | | Position | Name | MAP@10 | Solution Type | 52 | |---------------------|-------------------|--------|--------------------------------| 53 | | 1 | Oleg Lashinin | 0.1221 | Neural and Non-Neural ensemble | 54 | | 2 | Aleksandr Petrov | 0.1214 | Neural and Non-Neural ensemble | 55 | | 3 | Stepan Zimin | 0.1213 | Non-Neural ensemble | 56 | | 4 | Daria Tikhonovich | 0.1148 | Gradient Boosting Trees | 57 | | 5 | Olga | 0.1135 | Gradient Boosting Trees | 58 | | Popularity baseline | | 0.0910 | | 59 | 60 | ## Acknowledgements 61 | 62 | [Igor Belkov](https://github.com/OzmundSedler), [Irina Elisova](https://github.com/ElisovaIra). 63 | 64 | We would like to acknowledge Kion challenge participants Oleg Lashinin, Stepan Zimin, and Olga for providing descriptions of their Kion Challenge solutions, MTS Holding for providing the Kion dataset, and the ODS.ai international platform for hosting the competition.
65 | 66 | ## Citations 67 | 68 | If you use this dataset in your research, please cite our work: 69 | 70 | ``` 71 | @article{petrov2022mts, 72 | title={MTS Kion Implicit Contextualised Sequential Dataset for Movie Recommendation}, 73 | author={Aleksandr Petrov, Ildar Safilo, Daria Tikhonovich and Dmitry Ignatov}, 74 | year={2022}, 75 | booktitle={Proceedings of the ACM RecSys CARS Workshop 2022, September 23d, 2022 Seattle, WA, USA } 76 | } 77 | ``` 78 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.9 2 | 3 | WORKDIR /usr/app 4 | 5 | RUN pip install rectools 6 | 7 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.https://www.sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx==5.1.1 2 | nbsphinx==0.8.9 3 | sphinx-rtd-theme==1.0.0 4 | -------------------------------------------------------------------------------- /docs/source/_static/theme.css: -------------------------------------------------------------------------------- 1 | .wy-side-nav-search { 2 | background-color:rgb(255, 255, 255); 3 | } -------------------------------------------------------------------------------- /docs/source/_templates/custom-base-template.rst: -------------------------------------------------------------------------------- 1 | {{ fullname.split(".")[-1] | escape | underline}} 2 | 3 | .. 
currentmodule:: {{ module }} 4 | 5 | .. auto{{ objtype }}:: {{ objname }} -------------------------------------------------------------------------------- /docs/source/_templates/custom-class-template.rst: -------------------------------------------------------------------------------- 1 | {{ fullname.split(".")[-1] | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | 6 | .. autoclass:: {{ objname }} 7 | :members: 8 | :show-inheritance: 9 | :exclude-members: __init__ 10 | {% set allow_inherited = "zero_grad" not in inherited_members %} {# no inheritance for torch.nn.Modules #} 11 | {%if allow_inherited %} 12 | :inherited-members: 13 | {% endif %} 14 | 15 | {% block methods %} 16 | {% set allowed_methods = [] %} 17 | {% for item in methods %}{% if not item.startswith("_") and (item not in inherited_members or allow_inherited) %} 18 | {% set a=allowed_methods.append(item) %} 19 | {% endif %}{%- endfor %} 20 | {% if allowed_methods %} 21 | .. rubric:: {{ _('Methods') }} 22 | 23 | .. autosummary:: 24 | {% for item in allowed_methods %} 25 | ~{{ name }}.{{ item }} 26 | {%- endfor %} 27 | {% endif %} 28 | {% endblock %} 29 | 30 | {% block attributes %} 31 | {% set allowed_attributes = [] %} 32 | {% for item in attributes %}{% if not item.startswith("_") and (item not in inherited_members or allow_inherited) %} 33 | {% set a=allowed_attributes.append(item) %} 34 | {% endif %}{%- endfor %} 35 | {% if allowed_attributes %} 36 | .. rubric:: {{ _('Attributes') }} 37 | 38 | .. autosummary:: 39 | {% for item in allowed_attributes %} 40 | ~{{ name }}.{{ item }} 41 | {%- endfor %} 42 | {% endif %} 43 | {% endblock %} -------------------------------------------------------------------------------- /docs/source/_templates/custom-module-template.rst: -------------------------------------------------------------------------------- 1 | {{ fullname.split(".")[-1] | escape | underline}} 2 | 3 | .. 
automodule:: {{ fullname }} 4 | 5 | {% block attributes %} 6 | {% if attributes %} 7 | .. rubric:: Module Attributes 8 | 9 | .. autosummary:: 10 | :toctree: 11 | {% for item in attributes %} 12 | {{ item }} 13 | {%- endfor %} 14 | {% endif %} 15 | {% endblock %} 16 | 17 | {% block functions %} 18 | {% if functions %} 19 | .. rubric:: {{ _('Functions') }} 20 | 21 | .. autosummary:: 22 | :toctree: 23 | :template: custom-base-template.rst 24 | {% for item in functions %} 25 | {{ item }} 26 | {%- endfor %} 27 | {% endif %} 28 | {% endblock %} 29 | 30 | {% block classes %} 31 | {% if classes %} 32 | .. rubric:: {{ _('Classes') }} 33 | 34 | .. autosummary:: 35 | :toctree: 36 | :template: custom-class-template.rst 37 | {% for item in classes %} 38 | {{ item }} 39 | {%- endfor %} 40 | {% endif %} 41 | {% endblock %} 42 | 43 | {% block exceptions %} 44 | {% if exceptions %} 45 | .. rubric:: {{ _('Exceptions') }} 46 | 47 | .. autosummary:: 48 | :toctree: 49 | {% for item in exceptions %} 50 | {{ item }} 51 | {%- endfor %} 52 | {% endif %} 53 | {% endblock %} 54 | 55 | {% block modules %} 56 | {% if modules %} 57 | .. rubric:: Modules 58 | 59 | .. autosummary:: 60 | :toctree: 61 | :template: custom-module-template.rst 62 | :recursive: 63 | {% for item in modules %} 64 | {{ item }} 65 | {%- endfor %} 66 | {% endif %} 67 | {% endblock %} -------------------------------------------------------------------------------- /docs/source/api.rst: -------------------------------------------------------------------------------- 1 | API 2 | ==== 3 | 4 | .. currentmodule:: rectools 5 | 6 | .. 
autosummary:: 7 | :toctree: api 8 | :template: custom-module-template.rst 9 | :recursive: 10 | 11 | rectools.dataset 12 | rectools.metrics 13 | rectools.model_selection 14 | rectools.models 15 | rectools.tools 16 | rectools.visuals 17 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | from pathlib import Path 15 | import shutil 16 | import sys 17 | from unittest.mock import Mock 18 | 19 | from sphinx.ext.autosummary import Autosummary 20 | from sphinx.application import Sphinx 21 | 22 | CURRENT_DIR = Path(__file__).parent.absolute() 23 | ROOT_DIR = CURRENT_DIR.parents[1] 24 | sys.path.insert(0, str(ROOT_DIR)) 25 | 26 | 27 | # -- Project information ----------------------------------------------------- 28 | 29 | project = "RecTools" 30 | copyright = """ 31 | 2022 MTS (Mobile Telesystems) 32 | 33 | Licensed under the Apache License, Version 2.0 (the "License"); 34 | you may not use this file except in compliance with the License. 
35 | You may obtain a copy of the License at 36 | 37 | http://www.apache.org/licenses/LICENSE-2.0 38 | 39 | Unless required by applicable law or agreed to in writing, software 40 | distributed under the License is distributed on an "AS IS" BASIS, 41 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 42 | See the License for the specific language governing permissions and 43 | limitations under the License. 44 | """ 45 | author = "MTS Big Data" 46 | 47 | # The full version, including alpha/beta/rc tags 48 | # release = "0.2.0" 49 | 50 | 51 | # -- General configuration --------------------------------------------------- 52 | 53 | # Add any Sphinx extension module names here, as strings. They can be 54 | # extensions coming with Sphinx (named "sphinx.ext.*") or your custom 55 | # ones. 56 | extensions = [ 57 | "sphinx.ext.autodoc", 58 | "sphinx.ext.napoleon", 59 | "sphinx.ext.autosummary", 60 | "sphinx.ext.viewcode", 61 | "sphinx.ext.intersphinx", 62 | "sphinx.ext.mathjax", 63 | "nbsphinx" 64 | ] 65 | 66 | autodoc_typehints = "both" 67 | autodoc_typehints_description_target = "all" 68 | # add_module_names = False 69 | 70 | # PACKAGES = [rectools.__name__] 71 | 72 | 73 | # setup configuration 74 | def skip(app, what, name, obj, skip, options): 75 | """ 76 | Handler for the ``autodoc-skip-member`` event: return True (skip the member) for ``__init__`` and for underscore-prefixed names when ``what`` is "function" or "method"; otherwise return the incoming ``skip`` decision unchanged 77 | """ 78 | if name == "__init__": 79 | return True 80 | if name.startswith("_") and what in ("function", "method"): 81 | return True 82 | return skip 83 | 84 | 85 | def get_by_name(string: str): 86 | """ 87 | Import by name and return imported module/function/class 88 | Args: 89 | string (str): module/function/class to import, e.g. 
"pandas.read_csv" will return read_csv function as 90 | defined by pandas 91 | Returns: 92 | imported object 93 | """ 94 | class_name = string.split(".")[-1] 95 | module_name = ".".join(string.split(".")[:-1]) 96 | 97 | if module_name == "":  # no dot in the name: look the attribute up in this conf module 98 | return getattr(sys.modules[__name__], class_name) 99 | 100 | mod = __import__(module_name, fromlist=[class_name]) 101 | return getattr(mod, class_name) 102 | 103 | 104 | class ModuleAutoSummary(Autosummary):  # Autosummary variant whose arguments are module names; each is expanded into that module's public members 105 | def get_items(self, names): 106 | new_names = [] 107 | for name in names: 108 | mod = sys.modules[name]  # NOTE: the module must already be imported, otherwise this raises KeyError 109 | mod_items = getattr(mod, "__all__", mod.__dict__)  # use the declared public API when there is one 110 | for t in mod_items: 111 | if "." not in t and not t.startswith("_"):  # ignore dotted and underscore-prefixed names 112 | obj = get_by_name(f"{name}.{t}") 113 | if hasattr(obj, "__module__"): 114 | mod_name = obj.__module__ 115 | t = f"{mod_name}.{t}"  # qualify the name with the module that defines the object 116 | if t.startswith("rectools"):  # keep only names whose qualified form starts with "rectools" 117 | new_names.append(t) 118 | new_items = super().get_items(sorted(new_names, key=lambda x: x.split(".")[-1]))  # ordered by the unqualified object name 119 | return new_items 120 | 121 | 122 | def setup(app: Sphinx):  # connects `skip` to "autodoc-skip-member" and registers the "moduleautosummary" directive 123 | app.connect("autodoc-skip-member", skip) 124 | app.add_directive("moduleautosummary", ModuleAutoSummary) 125 | 126 | 127 | autosummary_generate = True 128 | 129 | # Add any paths that contain templates here, relative to this directory. 
130 | templates_path = ["_templates"] 131 | 132 | # List of patterns, relative to source directory, that match files and 133 | # directories to ignore when looking for source files. 134 | # This pattern also affects html_static_path and html_extra_path. 135 | exclude_patterns = ["_build", "_templates"] 136 | 137 | 138 | # -- Options for HTML output ------------------------------------------------- 139 | 140 | # The theme to use for HTML and HTML Help pages. See the documentation for 141 | # a list of builtin themes. 
142 | # 143 | html_theme = "sphinx_rtd_theme" 144 | 145 | 146 | html_theme_options = { 147 | "collapse_navigation": False, 148 | "display_version": True, 149 | "logo_only": True, 150 | } 151 | 152 | # html_context = { 153 | # "css_files": [ 154 | # "_static/theme.css" 155 | # ], 156 | # } 157 | 158 | # The name of an image file (relative to this directory) to place at the top 159 | # of the sidebar. 160 | # html_logo = "_static/logo.jpeg" 161 | 162 | # The name of an image file (relative to this directory) to use as a favicon of 163 | # the docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 164 | # pixels large. 165 | #html_favicon = None 166 | 167 | # Add any paths that contain custom static files (such as style sheets) here, 168 | # relative to this directory. They are copied after the builtin static files, 169 | # so a file named "default.css" will overwrite the builtin "default.css". 170 | html_static_path = ["_static"] 171 | 172 | -------------------------------------------------------------------------------- /docs/source/dataset.rst: -------------------------------------------------------------------------------- 1 | Dataset 2 | ======= 3 | 4 | .. _dataset: 5 | 6 | .. currentmodule:: rectools 7 | 8 | Details of RecTools Dataset 9 | --------------------------- 10 | 11 | See the API documentation for further details on Dataset: 12 | 13 | .. currentmodule:: rectools 14 | 15 | .. moduleautosummary:: 16 | :toctree: api/ 17 | :template: custom-module-template.rst 18 | :recursive: 19 | 20 | rectools.dataset 21 | -------------------------------------------------------------------------------- /docs/source/examples.rst: -------------------------------------------------------------------------------- 1 | Examples 2 | ========= 3 | 4 | See examples here: https://github.com/MobileTeleSystems/RecTools/tree/main/examples 5 | 6 | .. 
toctree:: 7 | :maxdepth: 4 8 | :glob: 9 | 10 | examples/1_simple_example 11 | examples/2_cross_validation 12 | examples/3_metrics 13 | examples/4_dataset_with_features 14 | examples/5_benchmark_iALS_with_features 15 | examples/6_benchmark_lightfm_inference 16 | examples/7_visualization 17 | examples/8_debiased_metrics 18 | examples/9_model_configs_and_saving 19 | -------------------------------------------------------------------------------- /docs/source/faq.rst: -------------------------------------------------------------------------------- 1 | FAQ 2 | === 3 | 4 | .. currentmodule:: rectools 5 | 6 | 1. What kind of features should I use: Dense or Sparse? 7 | It depends. In most cases you're better off using `SparseFeatures` because they are better suited 8 | to categorical features. Even if you have a feature with real numerical values you're often better off 9 | if you binarize or discretize it. But there are exceptions to this rule, e.g. ALS features. 10 | 11 | 2. How do I calculate several metrics at once? 12 | Use the `calc_metrics` function. It calculates a batch of metrics more efficiently. 13 | It's similar to `reports` from `sklearn`. 14 | 15 | 3. What is the benefit of model wrappers? 16 | They all have the same set of parameters allowing for easier usage. 17 | They also extend the existing functionality, e.g. with filters that eliminate items 18 | that have already been seen, whitelists, features in ALS, and I2I recommendations. Wrappers have a unified output interface 19 | that is easy to use as input to calculate metrics. They also speed up some models. 20 | 21 | 4. What is the benefit of using `Dataset`? 22 | It's an easy-to-use wrapper around interactions, features and the mapping between item and user ids in feature sets and 23 | those in the interaction matrix. 24 | 25 | 5. Why do I need to pass a `Dataset` object as an argument to the `recommend` method? 26 | It conceals the mapping between internal and external user and item ids. 
Additionally it allows to filter out items 27 | that users have already seen. Some models, such as `LightFM` or `DSSM`, require to pass features. 28 | 29 | 6. Should the same `Dataset` object be used for fitting of a model and for inference of recommendations? 30 | It almost always has to be exactly the same `Dataset` object. 31 | 32 | One of possible exceptions is if during the fitting stage you use both viewing and purchase of an item 33 | as a positive event but you want exempt an item from being recommended only if it was purchased. 34 | In this case you should pass all interactions to train a model and only purchases to infer recommendations. 35 | 36 | Another exception is if a model requires to pass features to infer recommendations and values of those features 37 | have changed. 38 | -------------------------------------------------------------------------------- /docs/source/features.rst: -------------------------------------------------------------------------------- 1 | Components 2 | ========== 3 | 4 | .. currentmodule:: rectools 5 | 6 | Basic Concepts 7 | -------------- 8 | 9 | Columns 10 | ~~~~~~~ 11 | Names of columns are fixed. They are `user_id`, `item_id`, `weight` (numerical value of interaction's importance), 12 | `datetime` (date and time of interaction), `rank` (rank of recommendation according to score) 13 | and `score` (numeric value estimating how good recommendation it is). 14 | Column names are fixed in order to not constantly require mapping of columns in data and their actual meaning. 15 | So you'll need to rename your columns. 16 | 17 | .. currentmodule:: rectools.columns 18 | 19 | .. moduleautosummary:: 20 | :toctree: api/ 21 | :template: custom-module-template.rst 22 | :recursive: 23 | 24 | rectools.columns 25 | 26 | Identifiers 27 | ~~~~~~~~~~~ 28 | Mappings of external identifiers of users or items to internal ones. 
29 | Recommendation systems always require to have a mapping between external item ids in data sources 30 | and internal ids in interaction matrix. Managing such mapping requires a lot of diligence. RecTools does it for you. 31 | Every user and item must have a unique id. 32 | External ids may be any unique hashable values, internal - always integers from ``0`` to ``n_objects-1``. 33 | 34 | Interactions 35 | ~~~~~~~~~~~~ 36 | This table stores history of interactions between users and items. It carries the most importance. 37 | Interactions table might also contain column describing importance of an interaction. Also timestamp of interaction. 38 | If no such column is provided, all interactions are assumed to be of equal importance. 39 | 40 | User Features 41 | ~~~~~~~~~~~~~ 42 | This table stores data about users. 43 | It might include age, gender or any other features which may prove to be important for a recommender model. 44 | 45 | Item Features 46 | ~~~~~~~~~~~~~ 47 | This table stores data about items. 48 | It might include category, price or any other features which may prove to be important for a recommender model. 49 | 50 | Hot, warm, cold 51 | ~~~~~~~~~~~~~~~ 52 | There is a concept of a temperature we're using for users and items: 53 | 54 | * **hot** - the ones that are present in interactions used for training (they may or may not have features); 55 | * **warm** - the ones that are not in interactions, but have some features; 56 | * **cold** - the ones we don't know anything about (they are not in interactions and don't have any features). 57 | 58 | All the models are able to generate recommendations for the *hot* users (items). 59 | But as for warm and cold ones, there may be all possible combinations (neither of them, only cold, only warm, both). 
60 | The important thing is that if model is able to recommend for cold users (items), but not for warm ones (see table below), 61 | it is still able to recommend for warm ones, but they will be considered as cold (no personalisation should be expected). 62 | 63 | All of the above concepts are combined in `Dataset`. 64 | `Dataset` is used to build recommendation models and infer recommendations. 65 | 66 | .. include:: dataset.rst 67 | 68 | .. include:: models.rst 69 | 70 | What are you waiting for? Train and apply them! 71 | 72 | Recommendation Table 73 | ~~~~~~~~~~~~~~~~~~~~ 74 | Recommendation table contains recommendations for each user. 75 | It has a fixed set of columns, though they are different for i2i and u2i recommendations. 76 | Recommendation table can also be used for calculation of metrics. 77 | 78 | 79 | .. include:: metrics.rst 80 | 81 | Oops, yeah, can't forget about them. 82 | 83 | 84 | .. include:: model_selection.rst 85 | 86 | 87 | .. include:: tools.rst 88 | 89 | 90 | .. include:: visuals.rst 91 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. highlight:: console 2 | .. highlight:: python 3 | 4 | Welcome to RecTools's documentation! 5 | ==================================== 6 | 7 | RecTools is an easy-to-use Python library which makes the process of building recommendation systems easier, 8 | faster and more structured than ever before. The aim is to collect ready-to-use solutions and best practices in one place to make processes 9 | of creating your first MVP and deploying model to production as fast and easy as possible. 10 | The package also includes useful tools, such as ANN indexes for vector models and fast metric calculation. 11 | 12 | Quick Start 13 | ----------- 14 | 15 | Download data. 16 | 17 | .. 
code-block:: bash 18 | 19 | $ wget https://files.grouplens.org/datasets/movielens/ml-1m.zip 20 | $ unzip ml-1m.zip 21 | 22 | Train model and infer recommendations. 23 | 24 | .. code-block:: python 25 | 26 | import pandas as pd 27 | from implicit.nearest_neighbours import TFIDFRecommender 28 | 29 | from rectools import Columns 30 | from rectools.dataset import Dataset 31 | from rectools.models import ImplicitItemKNNWrapperModel 32 | 33 | # Read the data 34 | ratings = pd.read_csv( 35 | "ml-1m/ratings.dat", 36 | sep="::", 37 | engine="python", # Because of 2-chars separators 38 | header=None, 39 | names=[Columns.User, Columns.Item, Columns.Weight, Columns.Datetime], 40 | ) 41 | 42 | # Create dataset 43 | dataset = Dataset.construct(ratings) 44 | 45 | # Fit model 46 | model = ImplicitItemKNNWrapperModel(TFIDFRecommender(K=10)) 47 | model.fit(dataset) 48 | 49 | # Make recommendations 50 | recos = model.recommend( 51 | users=ratings[Columns.User].unique(), 52 | dataset=dataset, 53 | k=10, 54 | filter_viewed=True, 55 | ) 56 | 57 | Installation 58 | ------------ 59 | PyPI 60 | ~~~~ 61 | Install from PyPi using pip 62 | 63 | .. code-block:: bash 64 | 65 | $ pip install rectools 66 | 67 | RecTools is compatible with all operating systems and with Python 3.9+. 68 | The default version doesn't contain all the dependencies. Optional dependencies are the following: 69 | 70 | lightfm: adds wrapper for LightFM model, 71 | torch: adds models based on neural nets, 72 | nmslib: adds fast ANN recommenders. 73 | all: all extra dependencies 74 | 75 | Install RecTools with selected dependencies: 76 | 77 | .. code-block:: bash 78 | 79 | $ pip install rectools[lightfm,torch] 80 | 81 | Why RecTools? 82 | ------------- 83 | The one, the only and the best. 84 | 85 | RecTools provides unified interface for most commonly used recommender models. They include Implicit ALS, Implicit KNN, 86 | LightFM, SVD and DSSM. Recommendations based on popularity and random are also possible. 
87 | For model validation, RecTools contains an implementation of the time split methodology and numerous metrics 88 | to evaluate a model's performance. Besides the basic ones, they also include Diversity, Novelty and Serendipity. 89 | The package also provides tools that let you calculate metrics as easily and as fast as possible. 90 | 91 | .. toctree:: 92 | :hidden: 93 | :caption: Table of Contents 94 | :titlesonly: 95 | :maxdepth: 2 96 | 97 | features 98 | api 99 | examples 100 | tutorials 101 | benchmarks 102 | faq 103 | support 104 | -------------------------------------------------------------------------------- /docs/source/metrics.rst: -------------------------------------------------------------------------------- 1 | Metrics 2 | ======== 3 | 4 | .. _metrics: 5 | 6 | .. currentmodule:: rectools 7 | 8 | Details of RecTools Metrics 9 | --------------------------- 10 | 11 | See the API documentation for further details on Metrics: 12 | 13 | .. currentmodule:: rectools 14 | 15 | .. moduleautosummary:: 16 | :toctree: api/ 17 | :template: custom-module-template.rst 18 | :recursive: 19 | 20 | rectools.metrics 21 | -------------------------------------------------------------------------------- /docs/source/model_selection.rst: -------------------------------------------------------------------------------- 1 | Model selection 2 | =============== 3 | 4 | .. _model_selection: 5 | 6 | .. currentmodule:: rectools 7 | 8 | Details of RecTools Model selection 9 | ----------------------------------- 10 | 11 | See the API documentation for further details on Model selection: 12 | 13 | .. currentmodule:: rectools 14 | 15 | .. moduleautosummary:: 16 | :toctree: api/ 17 | :template: custom-module-template.rst 18 | :recursive: 19 | 20 | rectools.model_selection 21 | -------------------------------------------------------------------------------- /docs/source/models.rst: -------------------------------------------------------------------------------- 1 | Models 2 | ====== 3 | 4 | .. 
_models: 5 | 6 | .. currentmodule:: rectools 7 | 8 | Details of RecTools Models 9 | -------------------------- 10 | 11 | 12 | +-----------------------------+-------------------+---------------------+---------------------+ 13 | | Model | Supports features | Recommends for warm | Recommends for cold | 14 | +=============================+===================+=====================+=====================+ 15 | | SASRecModel | Yes | No | No | 16 | +-----------------------------+-------------------+---------------------+---------------------+ 17 | | BERT4RecModel | Yes | No | No | 18 | +-----------------------------+-------------------+---------------------+---------------------+ 19 | | DSSMModel | Yes | Yes | No | 20 | +-----------------------------+-------------------+---------------------+---------------------+ 21 | | EASEModel | No | No | No | 22 | +-----------------------------+-------------------+---------------------+---------------------+ 23 | | ImplicitALSWrapperModel | Yes | No | No | 24 | +-----------------------------+-------------------+---------------------+---------------------+ 25 | | ImplicitBPRWrapperModel | No | No | No | 26 | +-----------------------------+-------------------+---------------------+---------------------+ 27 | | ImplicitItemKNNWrapperModel | No | No | No | 28 | +-----------------------------+-------------------+---------------------+---------------------+ 29 | | LightFMWrapperModel | Yes | Yes | Yes | 30 | +-----------------------------+-------------------+---------------------+---------------------+ 31 | | PopularModel | No | No | Yes | 32 | +-----------------------------+-------------------+---------------------+---------------------+ 33 | | PopularInCategoryModel | No | No | Yes | 34 | +-----------------------------+-------------------+---------------------+---------------------+ 35 | | PureSVDModel | No | No | No | 36 | +-----------------------------+-------------------+---------------------+---------------------+ 37 | | RandomModel | 
No | No | Yes | 38 | +-----------------------------+-------------------+---------------------+---------------------+ 39 | 40 | 41 | See the API documentation for further details on Models: 42 | 43 | .. currentmodule:: rectools 44 | 45 | .. moduleautosummary:: 46 | :toctree: api/ 47 | :template: custom-module-template.rst 48 | :recursive: 49 | 50 | rectools.models 51 | -------------------------------------------------------------------------------- /docs/source/support.rst: -------------------------------------------------------------------------------- 1 | Support 2 | ======= 3 | 4 | .. currentmodule:: rectools 5 | 6 | If something went wrong and you can't find a way to fix it please contact us at `Telegram Channel `__ 7 | -------------------------------------------------------------------------------- /docs/source/tools.rst: -------------------------------------------------------------------------------- 1 | Tools 2 | ===== 3 | 4 | .. _tools: 5 | 6 | .. currentmodule:: rectools 7 | 8 | Details of RecTools Tools 9 | ------------------------- 10 | 11 | See the API documentation for further details on Tools: 12 | 13 | .. currentmodule:: rectools 14 | 15 | .. moduleautosummary:: 16 | :toctree: api/ 17 | :template: custom-module-template.rst 18 | :recursive: 19 | 20 | rectools.tools 21 | -------------------------------------------------------------------------------- /docs/source/tutorials.rst: -------------------------------------------------------------------------------- 1 | Tutorials 2 | ========= 3 | 4 | See tutorials here: https://github.com/MobileTeleSystems/RecTools/tree/main/examples/tutorials 5 | 6 | .. 
toctree:: 7 | :maxdepth: 4 8 | :glob: 9 | 10 | examples/tutorials/baselines_extended_tutorial 11 | examples/tutorials/transformers_tutorial 12 | examples/tutorials/transformers_advanced_training_guide 13 | examples/tutorials/transformers_customization_guide 14 | -------------------------------------------------------------------------------- /docs/source/visuals.rst: -------------------------------------------------------------------------------- 1 | Visuals 2 | =============== 3 | 4 | .. _visuals: 5 | 6 | .. currentmodule:: rectools 7 | 8 | Details of RecTools Visuals 9 | ----------------------------------- 10 | 11 | See the API documentation for further details on Visuals: 12 | 13 | .. currentmodule:: rectools 14 | 15 | .. moduleautosummary:: 16 | :toctree: api/ 17 | :template: custom-module-template.rst 18 | :recursive: 19 | 20 | rectools.visuals 21 | -------------------------------------------------------------------------------- /poetry.toml: -------------------------------------------------------------------------------- 1 | [virtualenvs] 2 | create = true 3 | in-project = true 4 | -------------------------------------------------------------------------------- /rectools/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022-2024 MTS (Mobile Telesystems) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """ 16 | MTS Advanced Recommender Systems package for Python 17 | =================================================== 18 | 19 | RecTools provides convenient wrappers for popular recommendation 20 | algorithms (ItemKNN, ALS, LightFM, etc.) and offers its own 21 | implementations and optimizations. It also provides tools 22 | for metrics computation, easy data conversion, and preparing 23 | models for production-ready systems. 24 | 25 | See https://rectools.readthedocs.io for complete documentation. 26 | 27 | Subpackages 28 | ----------- 29 | dataset - Data and identifiers conversion 30 | metrics - Metrics calculation 31 | model_selection - Cross-validation 32 | models - Recommendation models 33 | tools - Useful instruments 34 | visuals - Visualization apps 35 | """ 36 | 37 | from .columns import Columns 38 | from .types import AnyIds, AnySequence, ExternalId, ExternalIds, InternalId, InternalIds 39 | from .version import VERSION 40 | 41 | __version__ = VERSION 42 | 43 | __all__ = ( 44 | "Columns", 45 | "AnyIds", 46 | "AnySequence", 47 | "ExternalId", 48 | "ExternalIds", 49 | "InternalId", 50 | "InternalIds", 51 | "__version__", 52 | ) 53 | -------------------------------------------------------------------------------- /rectools/columns.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022-2024 MTS (Mobile Telesystems) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
class Columns:
    """Canonical column names shared by interaction and recommendation tables."""

    User = "user_id"
    Item = "item_id"
    TargetItem = "target_item_id"
    Weight = "weight"
    Datetime = "datetime"
    Rank = "rank"
    Score = "score"
    Model = "model"
    Split = "i_split"
    # Ready-made column sets built from the names above.
    UserItem = [User, Item]
    Interactions = [User, Item, Weight, Datetime]
    Recommendations = [User, Item, Score, Rank]
    RecommendationsI2I = [TargetItem, Item, Score, Rank]


class RequirementUnavailable:
    """
    Base for placeholder classes that stand in for objects whose optional dependencies are missing.

    Subclasses set `requirement` to the name of the extra that has to be installed.
    Any attempt to instantiate such a class raises `ImportError`.
    """

    requirement: str = NotImplemented

    def __new__(cls, *args: tp.Any, **kwargs: tp.Any) -> tp.Any:
        """Refuse instantiation and tell the user which extra must be installed."""
        extra = cls.requirement
        message = (
            f"Requirement `{extra}` is not satisfied. Run `pip install rectools[{extra}]` "
            f"to install extra requirements before accessing {cls.__name__}."
        )
        raise ImportError(message)


class LightFMWrapperModel(RequirementUnavailable):
    """Placeholder used when the `lightfm` extra is not installed."""

    requirement = "lightfm"


class DSSMModel(RequirementUnavailable):
    """Placeholder used when the `torch` extra is not installed."""

    requirement = "torch"


class SASRecModel(RequirementUnavailable):
    """Placeholder used when the `torch` extra is not installed."""

    requirement = "torch"


class BERT4RecModel(RequirementUnavailable):
    """Placeholder used when the `torch` extra is not installed."""

    requirement = "torch"


class ItemToItemAnnRecommender(RequirementUnavailable):
    """Placeholder used when the `nmslib` extra is not installed."""

    requirement = "nmslib"


class UserToItemAnnRecommender(RequirementUnavailable):
    """Placeholder used when the `nmslib` extra is not installed."""

    requirement = "nmslib"


class VisualApp(RequirementUnavailable):
    """Placeholder used when the `visuals` extra is not installed."""

    requirement = "visuals"


class ItemToItemVisualApp(RequirementUnavailable):
    """Placeholder used when the `visuals` extra is not installed."""

    requirement = "visuals"


class MetricsApp(RequirementUnavailable):
    """Placeholder used when the `visuals` extra is not installed."""

    requirement = "visuals"
License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Data conversion tools (:mod:`rectools.dataset`). 17 | ========================================================== 18 | 19 | Data and identifiers conversion tools for future working with 20 | models. 21 | 22 | 23 | Data Containers 24 | --------------- 25 | `dataset.IdMap` - Mapping between external and internal identifiers. 26 | `dataset.DenseFeatures` - Container for dense features. 27 | `dataset.SparseFeatures` - Container for sparse features. 28 | `dataset.Interactions` - Container for interactions. 29 | `dataset.Dataset` - Container for all data. 30 | 31 | """ 32 | 33 | 34 | from .dataset import Dataset 35 | from .features import DenseFeatures, Features, SparseFeatures 36 | from .identifiers import IdMap 37 | from .interactions import Interactions 38 | 39 | __all__ = ( 40 | "Dataset", 41 | "DenseFeatures", 42 | "SparseFeatures", 43 | "Features", 44 | "IdMap", 45 | "Interactions", 46 | ) 47 | -------------------------------------------------------------------------------- /rectools/exceptions.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 MTS (Mobile Telesystems) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
class NotFittedError(Exception):
    """
    Raised when a fittable object is used before it has been fitted.

    Parameters
    ----------
    obj_name : str
        Name of the object that has not been fitted; it is used in the error message.
    """

    def __init__(self, obj_name: str) -> None:
        # Forward `obj_name` to `Exception` so it is stored in `self.args`.
        # Pickle and copy rebuild an exception as `cls(*self.args)`; with empty `args`
        # that call fails with `TypeError` because `obj_name` is a required argument.
        super().__init__(obj_name)
        self.obj_name = obj_name

    def __str__(self) -> str:
        return f"{self.obj_name} isn't fitted, call method `fit` first."
21 | 22 | Metrics 23 | ------- 24 | `metrics.Precision` 25 | `metrics.Recall` 26 | `metrics.MAP` 27 | `metrics.NDCG` 28 | `metrics.MRR` 29 | `metrics.HitRate` 30 | `metrics.PartialAUC` 31 | `metrics.PAP` 32 | `metrics.F1Beta` 33 | `metrics.Accuracy` 34 | `metrics.MCC` 35 | `metrics.MeanInvUserFreq` 36 | `metrics.IntraListDiversity` 37 | `metrics.AvgRecPopularity` 38 | `metrics.Serendipity` 39 | `metrics.Intersection` 40 | `metrics.SufficientReco` 41 | `metrics.UnrepeatedReco` 42 | `metrics.CoveredUsers` 43 | `metrics.CatalogCoverage` 44 | 45 | Tools 46 | ----- 47 | `metrics.calc_metrics` - calculate a set of metrics efficiently 48 | `metrics.PairwiseDistanceCalculator` 49 | `metrics.PairwiseHammingDistanceCalculator` 50 | `metrics.SparsePairwiseHammingDistanceCalculator` 51 | `metrics.DebiasConfig` 52 | `metrics.debias_interactions` 53 | """ 54 | 55 | from .auc import PAP, PartialAUC 56 | from .catalog import CatalogCoverage 57 | from .classification import MCC, Accuracy, F1Beta, HitRate, Precision, Recall 58 | from .debias import DebiasConfig, debias_interactions 59 | from .distances import ( 60 | PairwiseDistanceCalculator, 61 | PairwiseHammingDistanceCalculator, 62 | SparsePairwiseHammingDistanceCalculator, 63 | ) 64 | from .diversity import IntraListDiversity 65 | from .dq import CoveredUsers, SufficientReco, UnrepeatedReco 66 | from .intersection import Intersection 67 | from .novelty import MeanInvUserFreq 68 | from .popularity import AvgRecPopularity 69 | from .ranking import MAP, MRR, NDCG 70 | from .scoring import calc_metrics 71 | from .serendipity import Serendipity 72 | 73 | __all__ = ( 74 | "Precision", 75 | "Recall", 76 | "F1Beta", 77 | "Accuracy", 78 | "MCC", 79 | "HitRate", 80 | "MAP", 81 | "NDCG", 82 | "PartialAUC", 83 | "PAP", 84 | "MRR", 85 | "CatalogCoverage", 86 | "MeanInvUserFreq", 87 | "IntraListDiversity", 88 | "AvgRecPopularity", 89 | "Serendipity", 90 | "calc_metrics", 91 | "PairwiseDistanceCalculator", 92 | 
ExternalItemId = tp.Union[str, int]
Catalog = tp.Collection[ExternalItemId]


@attr.s(auto_attribs=True)
class MetricAtK:
    """
    Base class of metrics that depends on `k` -
    a number of top recommendations used to calculate a metric.

    Warning: This class should not be used directly.
    Use derived classes instead.

    Parameters
    ----------
    k : int
        Number of items at the top of recommendations list that will be used to calculate metric.
    """

    k: int

    @classmethod
    def _check(
        cls,
        reco: pd.DataFrame,
        interactions: tp.Optional[pd.DataFrame] = None,
        prev_interactions: tp.Optional[pd.DataFrame] = None,
        ref_reco: tp.Optional[pd.DataFrame] = None,
    ) -> None:
        """Validate required columns of every passed table and the rank column of recommendation tables."""
        cls._check_columns(reco, "reco", (Columns.User, Columns.Item, Columns.Rank))
        cls._check_columns(interactions, "interactions", (Columns.User, Columns.Item))
        cls._check_columns(prev_interactions, "prev_interactions", (Columns.User, Columns.Item))
        cls._check_columns(ref_reco, "ref_reco", (Columns.User, Columns.Item, Columns.Rank))

        cls._check_rank_column(reco, "reco")
        cls._check_rank_column(ref_reco, "ref_reco")

    @staticmethod
    def _check_columns(df: tp.Optional[pd.DataFrame], name: str, required_columns: tp.Iterable[str]) -> None:
        """Raise `KeyError` if `df` is given and lacks any of `required_columns`."""
        if df is None:
            return
        required = set(required_columns)
        actual = set(df.columns)
        if not actual >= required:
            raise KeyError(f"Missed columns {required - actual} in '{name}' dataframe")

    @staticmethod
    def _check_rank_column(reco: tp.Optional[pd.DataFrame], df_name: str) -> None:
        """Warn (never raise) if the rank column is not of integer dtype or does not start from 1."""
        if reco is None or reco.empty:
            return
        ranks = reco[Columns.Rank]
        if ranks.dtype.kind not in ("i", "u"):
            warnings.warn(f"Expected integer dtype of '{Columns.Rank}' column in '{df_name}' dataframe.")
        min_rank = ranks.min()
        # `min()` is NaN when every rank is missing. `int(round(nan))` would raise ValueError,
        # while this check is only supposed to warn about suspicious ranks.
        if pd.isna(min_rank) or int(round(min_rank)) != 1:
            warnings.warn(f"Expected min value of '{Columns.Rank}' column in '{df_name}' dataframe to be equal to 1.")


def merge_reco(reco: pd.DataFrame, interactions: pd.DataFrame) -> pd.DataFrame:
    """
    Merge recommendation table with interactions table.

    Every interaction row is kept (left join on user and item);
    rank is null for interactions that are absent from recommendations.

    Parameters
    ----------
    reco : pd.DataFrame
        Recommendations table with columns `Columns.User`, `Columns.Item`, `Columns.Rank`.
    interactions : pd.DataFrame
        Interactions table with columns `Columns.User`, `Columns.Item`.

    Returns
    -------
    pd.DataFrame
        Result of merging.
    """
    merged = pd.merge(
        interactions.reindex(columns=Columns.UserItem),
        reco.reindex(columns=Columns.UserItem + [Columns.Rank]),
        on=Columns.UserItem,
        how="left",
    )
    return merged


def outer_merge_reco(reco: pd.DataFrame, interactions: pd.DataFrame) -> pd.DataFrame:
    """
    Merge recommendation table with interactions table with outer join. All ranks for all users are
    present with no skipping. Null ranks will be specified for test interactions that were not
    predicted in recommendations.
    This method is useful for AUC based ranking metrics.

    Parameters
    ----------
    reco : pd.DataFrame
        Recommendations table with columns `Columns.User`, `Columns.Item`, `Columns.Rank`.
    interactions : pd.DataFrame
        Interactions table with columns `Columns.User`, `Columns.Item`.

    Returns
    -------
    pd.DataFrame
        Result of merging with added `__test_positive` boolean column.
    """
    prepared_interactions = interactions.reindex(columns=Columns.UserItem).drop_duplicates()
    prepared_interactions["__test_positive"] = True
    # Only recommendations of users that have test interactions take part in the merge.
    test_users = prepared_interactions[Columns.User].unique()
    prepared_reco = reco[reco[Columns.User].isin(test_users)].reindex(columns=Columns.UserItem + [Columns.Rank])
    merged = pd.merge(
        prepared_interactions,
        prepared_reco,
        on=Columns.UserItem,
        how="outer",
    )
    # Build the complete 1..max_rank sequence per user so that skipped ranks appear as rows too.
    max_rank = prepared_reco.groupby(Columns.User)[Columns.Rank].max()
    full_ranks = max_rank.apply(lambda a: list(range(1, a + 1))).explode().rename(Columns.Rank)
    ranked_reco = merged.merge(full_ranks, on=[Columns.User, Columns.Rank], how="outer").sort_values(
        [Columns.User, Columns.Rank]
    )
    # The flag is either True (row came from interactions) or missing (row added by an outer join).
    # `notna()` yields a proper boolean column directly; `fillna(False)` on an object column relies on
    # silent downcasting, which is deprecated in pandas and otherwise leaves the column with object dtype.
    ranked_reco["__test_positive"] = ranked_reco["__test_positive"].notna()
    return ranked_reco
@attr.s
class CatalogCoverage(MetricAtK):
    """
    Count (or share) of items from catalog that is present in recommendations for all users.

    Parameters
    ----------
    k : int
        Number of items at the top of recommendations list that will be used to calculate metric.
    normalize: bool, default ``False``
        Flag, which says whether to normalize metric or not.
    """

    normalize: bool = attr.ib(default=False)

    def calc(self, reco: pd.DataFrame, catalog: Catalog) -> float:
        """
        Calculate metric value.

        Parameters
        ----------
        reco : pd.DataFrame
            Recommendations table with columns `Columns.User`, `Columns.Item`, `Columns.Rank`.
        catalog : collection
            Collection of unique item ids that could be used for recommendations.

        Returns
        -------
        float
            Value of metric (aggregated for all users).
        """
        # Number of distinct items among the top-k recommendations of all users.
        res = reco.loc[reco[Columns.Rank] <= self.k, Columns.Item].nunique()
        if self.normalize:
            return res / len(catalog)
        return res


CatalogMetric = CatalogCoverage


def calc_catalog_metrics(
    metrics: tp.Dict[str, CatalogMetric],
    reco: pd.DataFrame,
    catalog: Catalog,
) -> tp.Dict[str, float]:
    """
    Calculate metrics of catalog statistics for recommendations.

    Warning: It is not recommended to use this function directly.
    Use `calc_metrics` instead.

    Parameters
    ----------
    metrics : dict(str -> CatalogMetric)
        Dict of metric objects to calculate,
        where key is a metric name and value is a metric object.
    reco : pd.DataFrame
        Recommendations table with columns `Columns.User`, `Columns.Item`, `Columns.Rank`.
    catalog : collection
        Collection of unique item ids that could be used for recommendations.

    Returns
    -------
    dict(str->float)
        Dictionary where keys are the same as keys in `metrics`
        and values are metric calculation results.
    """
    return {metric_name: metric.calc(reco, catalog) for metric_name, metric in metrics.items()}


@attr.s
class Intersection(MetricAtK):
    """
    Metric to measure intersection in user-item pairs between recommendation lists.

    The intersection@k equals the share of ``reco`` that is present in ``ref_reco``.

    This corresponds to the following algorithm:
        1) filter ``reco`` by ``k``
        2) filter ``ref_reco`` by ``ref_k``
        3) calculate the proportion of items in ``reco`` that are also present in ``ref_reco``
    The second and third steps are equivalent to computing Recall@ref_k when:
        - Interactions consists of ``reco`` without the `Columns.Rank` column.
        - Recommendation table is ``ref_reco``

    Parameters
    ----------
    k : int
        Number of items in top of recommendations list that will be used to calculate metric.
    ref_k : int, optional
        Number of items in top of reference recommendations list that will be used to calculate metric.
        If ``ref_k`` is None then ``ref_reco`` will be filtered with ``ref_k = k``. Default: None.
    """

    ref_k: Optional[int] = attr.ib(default=None)

    def calc(self, reco: pd.DataFrame, ref_reco: pd.DataFrame) -> float:
        """
        Calculate metric value.

        Parameters
        ----------
        reco : pd.DataFrame
            Recommendations table with columns `Columns.User`, `Columns.Item`, `Columns.Rank`.
        ref_reco : pd.DataFrame
            Reference recommendations table with columns `Columns.User`, `Columns.Item`, `Columns.Rank`.

        Returns
        -------
        float
            Value of metric (average between users).
        """
        per_user = self.calc_per_user(reco, ref_reco)
        return per_user.mean()

    def calc_per_user(self, reco: pd.DataFrame, ref_reco: pd.DataFrame) -> pd.Series:
        """
        Calculate metric values for all users.

        Parameters
        ----------
        reco : pd.DataFrame
            Recommendations table with columns `Columns.User`, `Columns.Item`, `Columns.Rank`.
        ref_reco : pd.DataFrame
            Reference recommendations table with columns `Columns.User`, `Columns.Item`, `Columns.Rank`.

        Returns
        -------
        pd.Series:
            Values of metric (index - user id, values - metric value for every user).
        """
        self._check(reco, ref_reco=ref_reco)

        if ref_reco.shape[0] == 0:
            return pd.Series(index=pd.Series(name=Columns.User, dtype=int), dtype=np.float64)

        ref_k = self.ref_k if self.ref_k is not None else self.k

        # Optimisation for cross_validate: a table compared with itself intersects completely.
        # This only holds when the reference cut-off is not smaller than `k`; with `ref_k < k`
        # part of the top-k is missing from the top-ref_k, so the generic path must be used.
        # User ids keep their own dtype, because external ids are not necessarily integers.
        if ref_reco is reco and ref_k >= self.k:
            return pd.Series(
                data=1,
                index=pd.Index(reco[Columns.User].unique(), name=Columns.User),
                dtype=np.float64,
            )

        filtered_reco = reco[reco[Columns.Rank] <= self.k]

        # Top-k of `reco` plays the role of interactions, `ref_reco` the role of recommendations.
        recall = Recall(k=ref_k)

        return recall.calc_per_user(ref_reco, filtered_reco[Columns.UserItem])


IntersectionMetric = Intersection


def calc_intersection_metrics(
    metrics: Dict[str, IntersectionMetric],
    reco: pd.DataFrame,
    ref_reco: Union[pd.DataFrame, Dict[Hashable, pd.DataFrame]],
) -> Dict[str, float]:
    """
    Calculate intersection metrics.

    Warning: It is not recommended to use this function directly.
    Use `calc_metrics` instead.

    Parameters
    ----------
    metrics : dict(str -> IntersectionMetric)
        Dict of metric objects to calculate,
        where key is metric name and value is metric object.
    reco : pd.DataFrame
        Recommendations table with columns `Columns.User`, `Columns.Item`, `Columns.Rank`.
    ref_reco : Union[pd.DataFrame, Dict[Hashable, pd.DataFrame]]
        Reference recommendations table(s) with columns `Columns.User`, `Columns.Item`, `Columns.Rank`.

    Returns
    -------
    dict(str->float)
        Dictionary where keys are the same as keys in `metrics`
        and values are metric calculation results.
        When `ref_reco` is a dict, every result key is ``f"{metric_name}_{ref_reco_name}"``.
    """
    results = {}

    for metric_name, metric in metrics.items():
        if isinstance(ref_reco, pd.DataFrame):
            results[metric_name] = metric.calc(reco, ref_reco)
        else:
            for ref_reco_name, ref_reco_df in ref_reco.items():
                results[f"{metric_name}_{ref_reco_name}"] = metric.calc(reco, ref_reco_df)

    return results
20 | 21 | Splitters 22 | --------- 23 | `model_selection.Splitter` - base class for all splitters 24 | 25 | `model_selection.RandomSplitter` - split interactions randomly 26 | `model_selection.LastNSplitter` - split interactions by recent activity 27 | `model_selection.TimeRangeSplitter` - split interactions by time 28 | 29 | Model selection tools 30 | --------------------- 31 | `model_selection.cross_validate` - run cross validation on multiple models with multiple metrics 32 | """ 33 | 34 | from .cross_validate import cross_validate 35 | from .last_n_split import LastNSplitter 36 | from .random_split import RandomSplitter 37 | from .splitter import Splitter 38 | from .time_split import TimeRangeSplitter 39 | 40 | __all__ = ( 41 | "Splitter", 42 | "RandomSplitter", 43 | "LastNSplitter", 44 | "TimeRangeSplitter", 45 | "cross_validate", 46 | ) 47 | -------------------------------------------------------------------------------- /rectools/model_selection/last_n_split.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023-2024 MTS (Mobile Telesystems) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
class LastNSplitter(Splitter):
    """
    Leave-one-out / leave-k-out cross-validation splitter based on recent activity.

    For every user the last `n` interactions go to test and everything
    that happened earlier goes to train. Several folds are produced by sliding
    this window backwards over each user's interaction history.

    The scheme is common in research on sequential recommendations,
    but it doesn't fully prevent data leak from the future.

    It is also possible to exclude cold users and items and already seen items.

    Parameters
    ----------
    n : int
        Number of interactions for each user that will be included in test.
    n_splits : int, default 1
        Number of test folds.
    filter_cold_users : bool, default ``True``
        If `True`, users that are not present in train will be excluded from test.
        WARNING: both cold and warm users will be excluded from test.
    filter_cold_items : bool, default ``True``
        If `True`, items that are not present in train will be excluded from test.
        WARNING: both cold and warm items will be excluded from test.
    filter_already_seen : bool, default ``True``
        If ``True``, pairs (user, item) that are present in train will be excluded from test.

    Examples
    --------
    >>> from rectools import Columns
    >>> df = pd.DataFrame(
    ...     [
    ...         [1, 1, 1, "2021-09-01"],  # 0
    ...         [1, 2, 1, "2021-09-02"],  # 1
    ...         [1, 1, 1, "2021-09-03"],  # 2
    ...         [1, 2, 1, "2021-09-04"],  # 3
    ...         [1, 2, 1, "2021-09-05"],  # 4
    ...         [2, 1, 1, "2021-08-20"],  # 5
    ...         [2, 2, 1, "2021-08-21"],  # 6
    ...         [2, 2, 1, "2021-08-22"],  # 7
    ...     ],
    ...     columns=[Columns.User, Columns.Item, Columns.Weight, Columns.Datetime],
    ... ).astype({Columns.Datetime: "datetime64[ns]"})
    >>> interactions = Interactions(df)
    >>>
    >>> splitter = LastNSplitter(2, 2, False, False, False)
    >>> for train_ids, test_ids, _ in splitter.split(interactions):
    ...     print(train_ids, test_ids)
    [0] [1 2 5]
    [0 1 2 5] [3 4 6 7]
    >>>
    >>> splitter = LastNSplitter(2, 2, True, False, False)
    >>> for train_ids, test_ids, _ in splitter.split(interactions):
    ...     print(train_ids, test_ids)
    [0] [1 2]
    [0 1 2 5] [3 4 6 7]
    """

    def __init__(
        self,
        n: int,
        n_splits: int = 1,
        filter_cold_users: bool = True,
        filter_cold_items: bool = True,
        filter_already_seen: bool = True,
    ) -> None:
        super().__init__(filter_cold_users, filter_cold_items, filter_already_seen)
        self.n = n
        self.n_splits = n_splits

    def _split_without_filter(
        self,
        interactions: Interactions,
        collect_fold_stats: bool = False,
    ) -> tp.Iterator[tp.Tuple[np.ndarray, np.ndarray, tp.Dict[str, tp.Any]]]:
        """Yield (train positions, test positions, fold info) per fold, oldest window first."""
        frame = interactions.df
        positions = pd.RangeIndex(0, len(frame))

        # Recency rank inside each user's history: the latest event gets rank 1.
        recency = frame.groupby(Columns.User)[Columns.Datetime].rank(method="first", ascending=False)

        for fold_number in range(self.n_splits):
            # The first fold uses the window that lies furthest in the past.
            window = self.n_splits - fold_number - 1
            lower_bound = window * self.n  # excluded
            upper_bound = lower_bound + self.n  # included

            in_test = (recency > lower_bound) & (recency <= upper_bound)
            in_train = recency > upper_bound

            yield positions[in_train].values, positions[in_test].values, {"i_split": fold_number}
def get_not_seen_mask(
    train_users: np.ndarray,
    train_items: np.ndarray,
    test_users: np.ndarray,
    test_items: np.ndarray,
) -> np.ndarray:
    """
    Mark test interactions whose (user, item) pair does not occur in train.

    Parameters
    ----------
    train_users : np.ndarray
        Integer array of users in train interactions (one value per interaction, not unique users).
    train_items : np.ndarray
        Integer array of items in train interactions. Has same length as `train_users`.
    test_users : np.ndarray
        Integer array of users in test interactions (one value per interaction, not unique users).
    test_items : np.ndarray
        Integer array of items in test interactions. Has same length as `test_users`.

    Returns
    -------
    np.ndarray
        Boolean mask of same length as `test_users` (`test_items`).
        ``True`` means interaction not present in train.

    Raises
    ------
    ValueError
        If user and item arrays of the same part (train or test) differ in size.
    """
    if train_users.size != train_items.size:
        raise ValueError("Lengths of `train_users` and `train_items` must be the same")
    if test_users.size != test_items.size:
        raise ValueError("Lengths of `test_users` and `test_items` must be the same")

    # Trivial cases: nothing was seen in train, or there is nothing to mark.
    if train_users.size == 0:
        return np.ones(test_users.size, dtype=bool)
    if test_users.size == 0:
        return np.array([], dtype=bool)

    shape = (
        max(train_users.max(), test_users.max()) + 1,
        max(train_items.max(), test_items.max()) + 1,
    )
    # Row-compressed layout when there are fewer users than items, column-compressed otherwise.
    matrix_type = sparse.csr_matrix if shape[0] < shape[1] else sparse.csc_matrix

    train_matrix = matrix_type((np.ones(len(train_users), dtype=bool), (train_users, train_items)), shape=shape)
    test_matrix = matrix_type((np.ones(len(test_users), dtype=bool), (test_users, test_items)), shape=shape)

    # Element-wise product keeps only the pairs that are present in both train and test.
    seen = test_matrix.multiply(train_matrix).tocoo(copy=False)
    del train_matrix, test_matrix
    seen_pairs = np.vstack((seen.row, seen.col)).T.astype(test_users.dtype)
    del seen

    test_pairs = np.vstack((test_users, test_items)).T
    return isin_2d_int(test_pairs, seen_pairs, invert=True)
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # pylint: disable=wrong-import-position 16 | 17 | """ 18 | Recommendation models (:mod:`rectools.models`) 19 | ============================================== 20 | 21 | Convenient wrappers for popular recommendation 22 | algorithms (ItemKNN, ALS, LightFM), also some custom 23 | implementations. 24 | 25 | 26 | Models 27 | ------ 28 | `models.DSSMModel` 29 | `models.EASEModel` 30 | `models.ImplicitALSWrapperModel` 31 | `models.ImplicitBPRWrapperModel` 32 | `models.ImplicitItemKNNWrapperModel` 33 | `models.LightFMWrapperModel` 34 | `models.PopularModel` 35 | `models.PopularInCategoryModel` 36 | `models.PureSVDModel` 37 | `models.RandomModel` 38 | `models.nn.bert4rec.BERT4RecModel` 39 | `models.nn.sasrec.SASRecModel` 40 | """ 41 | 42 | from .ease import EASEModel 43 | from .implicit_als import ImplicitALSWrapperModel 44 | from .implicit_bpr import ImplicitBPRWrapperModel 45 | from .implicit_knn import ImplicitItemKNNWrapperModel 46 | from .popular import PopularModel 47 | from .popular_in_category import PopularInCategoryModel 48 | from .pure_svd import PureSVDModel 49 | from .random import RandomModel 50 | from .serialization import load_model, model_from_config, model_from_params 51 | 52 | try: 53 | from .lightfm import LightFMWrapperModel 54 | except ImportError: # pragma: no cover 55 | from ..compat import LightFMWrapperModel # type: ignore 56 | 57 | try: 58 | from .nn.dssm import DSSMModel 59 | from .nn.transformers.bert4rec import BERT4RecModel 60 | from .nn.transformers.sasrec import SASRecModel 61 | except 
ImportError: # pragma: no cover 62 | from ..compat import BERT4RecModel, DSSMModel, SASRecModel # type: ignore 63 | 64 | 65 | __all__ = ( 66 | "SASRecModel", 67 | "BERT4RecModel", 68 | "EASEModel", 69 | "ImplicitALSWrapperModel", 70 | "ImplicitBPRWrapperModel", 71 | "ImplicitItemKNNWrapperModel", 72 | "LightFMWrapperModel", 73 | "PopularModel", 74 | "PopularInCategoryModel", 75 | "PureSVDModel", 76 | "RandomModel", 77 | "DSSMModel", 78 | "load_model", 79 | "model_from_config", 80 | "model_from_params", 81 | ) 82 | -------------------------------------------------------------------------------- /rectools/models/nn/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MTS (Mobile Telesystems) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Recommendation models based on neural nets.""" 16 | -------------------------------------------------------------------------------- /rectools/models/nn/transformers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MTS (Mobile Telesystems) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Recommendation models based on transformers.""" 16 | -------------------------------------------------------------------------------- /rectools/models/nn/transformers/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MTS (Mobile Telesystems) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | PADDING_VALUE = "PAD" 16 | MASKING_VALUE = "MASK" 17 | -------------------------------------------------------------------------------- /rectools/models/nn/transformers/negative_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MTS (Mobile Telesystems) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
class TransformerNegativeSamplerBase:
    """
    Interface for negative samplers.

    Subclass it, implement `get_negatives` and pass the custom sampler
    to the model parameters to plug in your own sampling logic.

    Parameters
    ----------
    n_negatives : int
        Number of negatives.
    """

    def __init__(self, n_negatives: int, **kwargs: tp.Any) -> None:
        # Extra keyword arguments are accepted but not used by the base class
        self.n_negatives = n_negatives

    def get_negatives(
        self,
        batch_dict: tp.Dict,
        lowest_id: int,
        highest_id: int,
        session_len_limit: tp.Optional[int] = None,
        **kwargs: tp.Any,
    ) -> torch.Tensor:
        """Return sampled negatives. Must be implemented by subclasses."""
        raise NotImplementedError()


class CatalogUniformSampler(TransformerNegativeSamplerBase):
    """
    Sampler that draws negatives uniformly from the whole catalog.

    Parameters
    ----------
    n_negatives : int
        Number of negatives.
    """

    def get_negatives(
        self,
        batch_dict: tp.Dict,
        lowest_id: int,
        highest_id: int,
        session_len_limit: tp.Optional[int] = None,
        **kwargs: tp.Any,
    ) -> torch.Tensor:
        """
        Return sampled negatives.

        Ids are drawn with `torch.randint`, i.e. from ``[lowest_id, highest_id)``.
        The result has shape ``[batch_size, session_len, n_negatives]`` where
        ``batch_size`` is the first dimension of ``batch_dict["x"]`` and ``session_len``
        is `session_len_limit` when given, otherwise the second dimension of ``batch_dict["x"]``.
        """
        if session_len_limit is None:
            session_len = batch_dict["x"].shape[1]
        else:
            session_len = session_len_limit

        out_shape = (batch_dict["x"].shape[0], session_len, self.n_negatives)
        return torch.randint(low=lowest_id, high=highest_id, size=out_shape)
class SimilarityModuleBase(torch.nn.Module):
    """
    Base class for similarity module.

    Declares the interface only: every method raises ``NotImplementedError``
    and has to be overridden by a concrete similarity module.
    """

    def _get_full_catalog_logits(self, session_embs: torch.Tensor, item_embs: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError()

    def _get_pos_neg_logits(
        self, session_embs: torch.Tensor, item_embs: torch.Tensor, candidate_item_ids: torch.Tensor
    ) -> torch.Tensor:
        raise NotImplementedError()

    def forward(
        self,
        session_embs: torch.Tensor,
        item_embs: torch.Tensor,
        candidate_item_ids: tp.Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass to get logits."""
        raise NotImplementedError()

    def _recommend_u2i(
        self,
        user_embs: torch.Tensor,
        item_embs: torch.Tensor,
        user_ids: InternalIdsArray,
        k: int,
        sorted_item_ids_to_recommend: InternalIdsArray,
        ui_csr_for_filter: tp.Optional[sparse.csr_matrix],
    ) -> InternalRecoTriplet:
        """Recommend to users."""
        raise NotImplementedError()


class DistanceSimilarityModule(SimilarityModuleBase):
    """
    Distance similarity module.

    Computes logits as dot products between session and item embeddings.
    For the cosine distance both sides are L2-normalised first.

    Parameters
    ----------
    distance : str, default "dot"
        Distance to use. Must be one of `dist_available`.

    Raises
    ------
    ValueError
        If `distance` is not one of `dist_available`.
    """

    dist_available: tp.List[str] = [Distance.DOT, Distance.COSINE]
    # Lower bound for embedding norms, protects from division by zero
    epsilon_cosine_dist: torch.Tensor = torch.tensor([1e-8])

    def __init__(self, distance: str = "dot") -> None:
        super().__init__()
        if distance not in self.dist_available:
            raise ValueError("`dist` can only be either `dot` or `cosine`.")

        self.distance = Distance(distance)

    def _get_full_catalog_logits(self, session_embs: torch.Tensor, item_embs: torch.Tensor) -> torch.Tensor:
        """Score every session embedding against every row of `item_embs`."""
        logits = session_embs @ item_embs.T
        return logits

    def _get_pos_neg_logits(
        self, session_embs: torch.Tensor, item_embs: torch.Tensor, candidate_item_ids: torch.Tensor
    ) -> torch.Tensor:
        """Score every session embedding only against its own candidate items."""
        # [batch_size, session_max_len, len(candidate_item_ids), n_factors]
        pos_neg_embs = item_embs[candidate_item_ids]
        # [batch_size, session_max_len,len(item_ids)]
        logits = (pos_neg_embs @ session_embs.unsqueeze(-1)).squeeze(-1)
        return logits

    def _get_embeddings_norm(self, embeddings: torch.Tensor) -> torch.Tensor:
        """
        L2-normalise embedding vectors.

        The norm is taken over the last axis (the factors). For 2-D input this is identical
        to the former ``dim=1`` behaviour; for 3-D session embeddings
        (``[batch_size, session_max_len, n_factors]``, see `_get_pos_neg_logits`) it normalises
        every position's vector instead of normalising across session positions.
        """
        embedding_norm = torch.norm(embeddings, p=2, dim=-1, keepdim=True)
        # Clamp the norm from below so that zero vectors do not produce NaN
        embeddings = embeddings / torch.max(embedding_norm, self.epsilon_cosine_dist.to(embeddings))
        return embeddings

    def forward(
        self,
        session_embs: torch.Tensor,
        item_embs: torch.Tensor,
        candidate_item_ids: tp.Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass to get logits.

        If `candidate_item_ids` is ``None`` logits are computed for all rows of `item_embs`,
        otherwise only for the given candidates.
        """
        if self.distance == Distance.COSINE:
            session_embs = self._get_embeddings_norm(session_embs)
            item_embs = self._get_embeddings_norm(item_embs)

        if candidate_item_ids is None:
            return self._get_full_catalog_logits(session_embs, item_embs)
        return self._get_pos_neg_logits(session_embs, item_embs, candidate_item_ids)

    def _recommend_u2i(
        self,
        user_embs: torch.Tensor,
        item_embs: torch.Tensor,
        user_ids: InternalIdsArray,
        k: int,
        sorted_item_ids_to_recommend: InternalIdsArray,
        ui_csr_for_filter: tp.Optional[sparse.csr_matrix],
    ) -> InternalRecoTriplet:
        """Recommend to users."""
        ranker = TorchRanker(
            distance=self.distance,
            device=item_embs.device,
            subjects_factors=user_embs[user_ids],
            objects_factors=item_embs,
        )
        # The ranker only sees the selected users, so it is addressed with positions 0..n_rec_users-1
        user_ids_indices, all_reco_ids, all_scores = ranker.rank(
            subject_ids=np.arange(len(user_ids)),  # n_rec_users
            k=k,
            filter_pairs_csr=ui_csr_for_filter,  # [n_rec_users x n_items + n_item_extra_tokens]
            sorted_object_whitelist=sorted_item_ids_to_recommend,  # model_internal
        )
        # Map ranker positions back to the requested user ids
        all_user_ids = user_ids[user_ids_indices]
        return all_user_ids, all_reco_ids, all_scores
class TorchRanker(RequirementUnavailable):
    """
    Placeholder exported instead of the real `TorchRanker` when it cannot be imported.

    The package falls back to this class on ``ImportError``. `requirement` names
    the missing dependency; how it is reported is defined by `RequirementUnavailable`.
    """

    requirement = "torch"
class Distance(str, Enum):
    """Distance metric"""

    # Bigger value means closer vectors
    DOT = "dot"
    # Bigger value means closer vectors
    COSINE = "cosine"
    # Smaller value means closer vectors
    EUCLIDEAN = "euclidean"


class Ranker(tp.Protocol):
    """Protocol for all rankers"""

    def rank(
        self,
        subject_ids: InternalIds,
        k: tp.Optional[int] = None,
        filter_pairs_csr: tp.Optional[sparse.csr_matrix] = None,
        sorted_object_whitelist: tp.Optional[InternalIdsArray] = None,
    ) -> tp.Tuple[InternalIds, InternalIds, Scores]:  # pragma: no cover
        """Rank objects by corresponding embeddings.

        Parameters
        ----------
        subject_ids : InternalIds
            Array of ids to recommend for.
        k : int, optional, default ``None``
            Desired number of recommendations for every subject id.
            Return all recs if None.
        filter_pairs_csr : sparse.csr_matrix, optional, default ``None``
            Subject-object interactions that should be filtered from recommendations.
            This is relevant for u2i case.
        sorted_object_whitelist : InternalIdsArray, optional, default ``None``
            Whitelist of object ids.
            If given, only these items will be used for recommendations.
            Otherwise all items from dataset will be used.

        Returns
        -------
        (InternalIds, InternalIds, Scores)
            Array of subject ids, array of recommended items, sorted by score descending and array of scores.
        """


def load_model(f: FileLike) -> ModelBase:
    """
    Load model from file.

    The content is deserialized with `pickle`, which can execute arbitrary code,
    so only files from trusted sources should be loaded.

    Parameters
    ----------
    f : str or Path or file-like object
        Path to file or file-like object.

    Returns
    -------
    model
        Model instance.
    """
    return pickle.loads(read_bytes(f))


def model_from_config(config: tp.Union[dict, ModelConfig]) -> ModelBase:
    """
    Create model from config.

    The model class is taken from the ``cls`` entry of the config and
    the model itself is built by the ``from_config`` method of that class.

    Parameters
    ----------
    config : dict or ModelConfig
        Model config.

    Returns
    -------
    model
        Model instance.

    Raises
    ------
    ValueError
        If the config does not provide ``cls``.
    """
    if not isinstance(config, dict):
        model_cls = config.cls
    else:
        # In a dict the class is not validated yet, so let pydantic resolve it
        raw_cls = config.get("cls")
        model_cls = TypeAdapter(tp.Optional[ModelClass]).validate_python(raw_cls)

    if model_cls is None:
        raise ValueError("`cls` must be provided in the config to load the model")

    return model_cls.from_config(config)


def model_from_params(params: dict, sep: str = ".") -> ModelBase:
    """
    Create model from dict of parameters.
    Same as `model_from_config` but accepts flat dict.

    Parameters
    ----------
    params : dict
        Model parameters as a flat dict with keys separated by `sep`.
    sep : str, default "."
        Separator for nested keys.

    Returns
    -------
    model
        Model instance.
    """
    nested_config = unflatten_dict(params, sep=sep)
    return model_from_config(nested_config)
def get_viewed_item_ids(user_items: sparse.csr_matrix, user_id: InternalId) -> InternalIdsArray:
    """
    Return indices of items that user has interacted with.

    Parameters
    ----------
    user_items : csr_matrix
        Matrix of interactions.
    user_id : int
        Internal user id.

    Returns
    -------
    np.ndarray
        Internal item indices that user has interacted with.
    """
    # Column indices of a CSR row are stored between two consecutive `indptr` values
    row_start = user_items.indptr[user_id]
    row_stop = user_items.indptr[user_id + 1]
    return user_items.indices[row_start:row_stop]


def recommend_from_scores(
    scores: ScoresArray,
    k: int,
    sorted_blacklist: tp.Optional[InternalIdsArray] = None,
    sorted_whitelist: tp.Optional[InternalIdsArray] = None,
    ascending: bool = False,
) -> tp.Tuple[InternalIdsArray, ScoresArray]:
    """
    Prepare top-k recommendations for a user.

    Items are ordered by their scores for this particular user
    and can be restricted with a whitelist and a blacklist.

    If `I` - set of all items, `B` - set of blacklist items, `W` - set of whitelist items, then:
        - if `W` is ``None``, items are taken from `I - B`
        - if `W` is not ``None``, items are taken from `W - B`

    Parameters
    ----------
    scores : np.ndarray
        Array of floats. Scores of relevance of all items for this user. Shape ``(n_items,)``.
    k : int
        Desired number of final recommendations.
        If, after applying white- and blacklist, number of available items `n_available` is less than `k`,
        then `n_available` items will be returned without warning.
    sorted_blacklist : np.ndarray, optional, default ``None``
        Array of unique ints. Sorted inner item ids to exclude from recommendations.
    sorted_whitelist : np.ndarray, optional, default ``None``
        Array of unique ints. Sorted inner item ids to use in recommendations.
    ascending : bool, default False
        If False, the biggest scores go first, use it when scores measure similarity.
        If True, the smallest scores go first, use it when scores are distances.

    Returns
    -------
    (np.ndarray, np.ndarray)
        Recommended item ids ordered from best to worst and their scores in the same order.

    Raises
    ------
    ValueError
        If `k` is not positive.
    """
    if k <= 0:
        raise ValueError("`k` must be positive")

    if sorted_blacklist is not None:
        if sorted_whitelist is None:
            sorted_whitelist = np.arange(scores.size)
        is_banned = fast_isin_for_sorted_test_elements(sorted_whitelist, sorted_blacklist)
        items_to_recommend = sorted_whitelist[~is_banned]
    else:
        # Either the whitelist itself or ``None`` when no restriction was requested
        items_to_recommend = sorted_whitelist

    is_restricted = items_to_recommend is not None
    if is_restricted:
        scores = scores[items_to_recommend]

    # Selection below always looks for the biggest values, so flip the sign for distances
    if ascending:
        scores = -scores

    n_reco = min(k, scores.size)
    # Cheap selection of top `n_reco` positions first, full sort only for them
    top_positions = scores.argpartition(-n_reco)[-n_reco:]
    best_first = scores[top_positions].argsort()[::-1]
    top_positions = top_positions[best_first]

    reco_ids = items_to_recommend[top_positions] if is_restricted else top_positions
    reco_scores = scores[top_positions]

    if ascending:
        reco_scores = -reco_scores

    return reco_ids, reco_scores


def convert_arr_to_implicit_gpu_matrix(arr: np.ndarray) -> tp.Any:
    """
    Safely convert numpy array to implicit.gpu.Matrix.

    Parameters
    ----------
    arr : np.ndarray
        Array to be converted.

    Returns
    -------
    implicit.gpu.Matrix
        Matrix built from a float32 copy of `arr`.
    """
    # We need to explicitly create copy to handle transposed and sliced arrays correctly
    # since Matrix is created from a direct copy of the underlying memory block, and `.T` is just a view
    float_copy = arr.astype(np.float32).copy()
    return implicit.gpu.Matrix(float_copy)  # pragma: no cover
20 | 21 | 22 | Tools 23 | ----- 24 | `tools.ItemToItemAnnRecommender` 25 | `tools.UserToItemAnnRecommender` 26 | """ 27 | 28 | try: 29 | from .ann import ItemToItemAnnRecommender, UserToItemAnnRecommender 30 | except ImportError: # pragma: no cover 31 | from ..compat import ItemToItemAnnRecommender, UserToItemAnnRecommender # type: ignore 32 | 33 | __all__ = ( 34 | "ItemToItemAnnRecommender", 35 | "UserToItemAnnRecommender", 36 | ) 37 | -------------------------------------------------------------------------------- /rectools/types.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022-2024 MTS (Mobile Telesystems) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import typing as tp 16 | 17 | import numpy as np 18 | 19 | ExternalId = tp.Hashable 20 | ExternalIdsArray = np.ndarray 21 | ExternalIds = tp.Union[tp.Sequence[ExternalId], ExternalIdsArray] 22 | InternalId = int 23 | InternalIdsArray = np.ndarray 24 | InternalIds = tp.Union[tp.Sequence[InternalId], InternalIdsArray] 25 | AnyIdsArray = tp.Union[ExternalIdsArray, InternalIdsArray] 26 | AnyIds = tp.Union[ExternalIds, InternalIds] 27 | AnySequence = tp.Union[tp.Sequence[tp.Any], np.ndarray] 28 | -------------------------------------------------------------------------------- /rectools/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 MTS (Mobile Telesystems) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Utils (:mod:`rectools.utils`) 17 | ============================= 18 | 19 | Inner helpers. 
20 | 21 | 22 | Tools 23 | ----- 24 | `utils.fast_2d_int_unique` 25 | `utils.fast_isin` 26 | `utils.fast_isin_for_sorted_test_elements` 27 | `utils.get_element_ids` 28 | `utils.get_from_series_by_index` 29 | `utils.pairwise` 30 | `utils.log_at_base` 31 | `utils.is_instance` 32 | `utils.select_by_type` 33 | """ 34 | 35 | from .array_set_ops import ( 36 | fast_2d_2col_int_unique, 37 | fast_2d_int_unique, 38 | fast_isin, 39 | fast_isin_for_sorted_test_elements, 40 | isin_2d_int, 41 | ) 42 | from .indexing import get_element_ids, get_from_series_by_index 43 | from .misc import is_instance, log_at_base, pairwise, select_by_type 44 | 45 | __all__ = ( 46 | "fast_2d_int_unique", 47 | "fast_2d_2col_int_unique", 48 | "fast_isin", 49 | "fast_isin_for_sorted_test_elements", 50 | "isin_2d_int", 51 | "get_element_ids", 52 | "get_from_series_by_index", 53 | "pairwise", 54 | "log_at_base", 55 | "is_instance", 56 | "select_by_type", 57 | ) 58 | -------------------------------------------------------------------------------- /rectools/utils/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 MTS (Mobile Telesystems) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
def get_element_ids(elements: np.ndarray, test_elements: np.ndarray) -> np.ndarray:
    """
    Find index of each element of `elements` in `test_elements`.

    Similar to `np.searchsorted` but works for any arrays (not only sorted).

    All `elements` must be present in `test_elements`.

    Parameters
    ----------
    elements : np.ndarray
        Elements that indices you want to get.
    test_elements : np.ndarray
        Array in which you want to find indices.

    Returns
    -------
    np.ndarray
        Integer array with same shape as `elements`.

    Raises
    ------
    ValueError
        If there are elements from `elements` which are not in `test_elements`.

    Examples
    --------
    >>> get_element_ids(np.array([50, 20, 30]), np.array([10, 30, 40, 50, 60, 20]))
    array([3, 5, 1])

    """
    # Sort `test_elements` once so that binary search can locate every element.
    sort_test_element_ids = np.argsort(test_elements)
    sorted_test_elements = test_elements[sort_test_element_ids]
    ids_in_sorted_test_elements = np.searchsorted(sorted_test_elements, elements)
    try:
        # Translate positions in the sorted array back to positions in the original array.
        ids = sort_test_element_ids[ids_in_sorted_test_elements]
    except IndexError as e:
        # `searchsorted` returns `len(test_elements)` for values greater than every test element.
        raise ValueError("All `elements` must be in `test_elements`") from e
    # `searchsorted` yields an insertion point, not a match, so confirm that each hit is exact.
    if not (test_elements[ids] == elements).all():
        raise ValueError("All `elements` must be in `test_elements`")
    return ids


def get_from_series_by_index(
    series: pd.Series, ids: AnySequence, strict: bool = True, return_missing: bool = False
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """
    Get values from pd.Series by index.

    Analogue to `s[ids]` but it can process cases when ids are not in index.
    Processing is possible in 2 different ways: raise an error or skip.
    `s[ids]` returns NaN for nonexistent values.

    Missing ids are detected by index membership, not by looking for NaN in the selected values,
    so NaN values stored in `series` are returned like any other value.

    Parameters
    ----------
    series : pd.Series
        `pd.Series` from which values are extracted.
    ids : sequence
        Sequence of index labels.
    strict : bool, default True
        - if True, raise KeyError if at least one element of `ids` not in `s.index`;
        - if False, skip nonexistent `ids` and return values only for existent.
    return_missing : bool, default False
        If True, return a tuple of 2 arrays: values and missing indices.
        Works only if `strict` is False.

    Returns
    -------
    np.ndarray
        Array of values.
    np.ndarray, np.ndarray
        Tuple of 2 arrays: values and missing indices.
        Only if `strict` is False and `return_missing` is True.

    Raises
    ------
    KeyError
        If `strict` is ``True`` and at least one element of `ids` not in `s.index`.
    ValueError
        If `strict` and `return_missing` are both ``True``.

    Examples
    --------
    >>> s = pd.Series([10, 20, 30, 40, 50], index=[1, 2, 3, 4, 5])
    >>> get_from_series_by_index(s, [3, 1, 4])
    array([30, 10, 40])

    >>> get_from_series_by_index(s, [3, 7, 4])
    Traceback (most recent call last):
     ...
    KeyError: 'Some indices do not exist'

    >>> get_from_series_by_index(s, [3, 7, 4], strict=False)
    array([30, 40])

    >>> get_from_series_by_index(s, [3, 7, 4], strict=False, return_missing=True)
    (array([30, 40]), array([7]))
    """
    if strict and return_missing:
        raise ValueError("You can't use `strict` and `return_missing` together")

    r = series.reindex(ids)
    # A NaN in `r` may be a real value of `series`, so absence is decided by the labels themselves.
    is_missing = ~r.index.isin(series.index)

    if strict:
        if is_missing.any():
            raise KeyError("Some indices do not exist")
        return r.astype(series.dtype).values

    # Reindexing with absent labels may upcast the values (e.g. int -> float), so restore the dtype
    # after the absent rows have been removed.
    selected = r[~is_missing].astype(series.dtype).values
    if return_missing:
        missing = r.index[is_missing].values
        return selected, missing
    return selected
def _serialize_random_state(rs: tp.Optional[tp.Union[None, int, np.random.RandomState]]) -> tp.Union[None, int]:
    """
    Convert a random state to a simple (JSON-friendly) value.

    ``None`` and ``int`` values are returned unchanged; anything else is rejected with ``TypeError``.
    """
    is_simple = rs is None or isinstance(rs, int)
    if not is_simple:
        # NOBUG: We can add serialization using get/set_state, but it's not human readable
        raise TypeError("`random_state` must be ``None`` or have ``int`` type to convert it to simple type")
    return rs


def read_bytes(f: FileLike) -> bytes:
    """
    Read bytes from a file.

    `f` is either a path (``str`` or ``Path``), which is read from disk,
    or an already opened object whose ``read()`` result is returned as is.
    """
    if not isinstance(f, (str, Path)):
        return f.read()
    return Path(f).read_bytes()
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | VERSION = "0.14.0" 16 | -------------------------------------------------------------------------------- /rectools/visuals/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 MTS (Mobile Telesystems) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Visualization tools (:mod:`rectools.visuals`) 17 | ======================================================= 18 | 19 | Instruments to visualize recommender models performance 20 | 21 | Recommendations visualization 22 | --------- 23 | `visuals.VisualApp` - Jupyter app for visual comparison of recommendations 24 | `visuals.ItemToItemVisualApp` - Jupyter app for visual comparison of item-to-item recommendations 25 | `visuals.MetricsApp` - Jupyter app for visual metrics comparison 26 | """ 27 | 28 | try: 29 | from .visual_app import ItemToItemVisualApp, VisualApp 30 | except ImportError: # pragma: no cover 31 | from ..compat import ItemToItemVisualApp, VisualApp # type: ignore 32 | 33 | try: 34 | from .metrics_app import MetricsApp 35 | except ImportError: # pragma: no cover 36 | from ..compat import MetricsApp # type: ignore 37 | 38 | __all__ = ("VisualApp", "ItemToItemVisualApp", "MetricsApp") 39 | -------------------------------------------------------------------------------- 
/setup.cfg: -------------------------------------------------------------------------------- 1 | [tool:pytest] 2 | console_output_style = progress 3 | testpaths = tests 4 | junit_family = xunit2 5 | doctest_optionflags = DONT_ACCEPT_TRUE_FOR_1 NORMALIZE_WHITESPACE 6 | filterwarnings = 7 | ignore:LightFM was compiled without OpenMP support 8 | ignore:distutils Version classes are deprecated 9 | ignore:Converting sparse features to dense array may cause MemoryError 10 | ignore:OpenBLAS is configured to use 11 | 12 | [coverage:run] 13 | # the name of the data file to use for storing or reporting coverage. 14 | data_file = .coverage.data 15 | 16 | [coverage:report] 17 | # Any line of your source code that matches one of these 18 | # regexes is excluded from being reported as missing. 19 | exclude_lines = 20 | # Have to re-enable the standard pragma 21 | pragma: no cover 22 | 23 | # Don't complain about missing debug-only code: 24 | def __repr__ 25 | def __str__ 26 | 27 | # Don't complain if tests don't hit defensive assertion code: 28 | raise NotImplemented 29 | raise NotImplementedError 30 | @abstractmethod 31 | 32 | # Don't complain if non-runnable code isn't run: 33 | if __name__ == .__main__.: 34 | 35 | # ignore source code that can’t be found, emitting a warning instead of an exception. 
36 | ignore_errors = False 37 | 38 | [flake8] 39 | max-complexity = 10 40 | max-line-length = 120 41 | max-doc-length = 120 42 | exclude = .venv 43 | docstring-convention = numpy 44 | ignore = D205,D400,D105,D100,E203,W503 45 | per-file-ignores = 46 | tests/*: D100,D101,D102,D103,D104 47 | rectools/models/nn/dssm.py: D101,D102,N812 48 | rectools/dataset/torch_datasets.py: D101,D102 49 | rectools/models/implicit_als.py: N806 50 | 51 | [mypy] 52 | python_version = 3.9 53 | no_incremental = True 54 | ignore_missing_imports = True 55 | disallow_untyped_defs = True 56 | disallow_incomplete_defs = True 57 | disallow_subclassing_any = False 58 | disallow_any_generics = True 59 | no_implicit_optional = True 60 | warn_redundant_casts = True 61 | warn_unused_ignores = True 62 | warn_unreachable = True 63 | allow_untyped_decorators = True 64 | show_error_codes = True 65 | show_error_context = True 66 | show_column_numbers = True 67 | disable_error_code = type-arg 68 | 69 | [isort] 70 | profile = black 71 | line_length = 120 72 | wrap_length = 120 73 | multi_line_output = 3 74 | indent = 4 75 | force_grid_wrap = false 76 | atomic = True 77 | combine_star = True 78 | verbose = false 79 | include_trailing_comma = True 80 | use_parentheses = True 81 | case_sensitive = True 82 | 83 | [pycodestyle] 84 | max_line_length = 120 85 | 86 | [codespell] 87 | count = 88 | quiet-level = 3 89 | builtin = clear,rare,names,code 90 | check-filenames = 91 | ignore-words-list = als, uint, coo, arange, jupyter 92 | skip = *.ipynb 93 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022-2024 MTS (Mobile Telesystems) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /tests/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 MTS (Mobile Telesystems) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /tests/dataset/test_identifiers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022-2024 MTS (Mobile Telesystems) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
class TestIdMap:
    """Tests for `IdMap` built from the external ids ``["b", "c", "a"]``."""

    def setup_method(self) -> None:
        ext_ids = np.array(["b", "c", "a"])
        self.external_ids = ext_ids
        self.id_map = IdMap(ext_ids)

    def test_creation(self) -> None:
        stored = self.id_map.external_ids
        np.testing.assert_equal(stored, self.external_ids)

    def test_from_values_creation(self) -> None:
        # The repeated "c" must not produce a second entry.
        built = IdMap.from_values(["b", "c", "c", "a"])
        np.testing.assert_equal(built.external_ids, self.external_ids)

    def test_from_dict_creation(self) -> None:
        mapping: tp.Dict[tp.Hashable, int] = {"a": 2, "b": 0, "c": 1}
        built = IdMap.from_dict(mapping)
        np.testing.assert_equal(built.external_ids, self.external_ids)

    @pytest.mark.parametrize("existing_mapping", ({"a": "0", "b": "1"}, {"a": 1, "b": 2}, {"a": 0, "b": 2}))
    def test_from_dict_creation_with_incorrect_internal_ids(self, existing_mapping: tp.Dict[tp.Hashable, int]) -> None:
        # Internal ids given as strings, not starting from 0, or with gaps are all rejected.
        with pytest.raises(ValueError):
            IdMap.from_dict(existing_mapping)

    def test_size(self) -> None:
        assert self.id_map.size == 3

    @pytest.mark.parametrize("external_ids", (np.array(["a", "b"]), np.array([1, 2]), np.array([1, 2], dtype="O")))
    def test_external_dtype(self, external_ids: np.ndarray) -> None:
        expected_dtype = external_ids.dtype

        from_constructor = IdMap(external_ids)
        assert from_constructor.external_dtype == expected_dtype

        from_values = IdMap.from_values(external_ids)
        assert from_values.external_dtype == expected_dtype

    def test_to_internal(self) -> None:
        expected = pd.Series([0, 1, 2], index=self.external_ids)
        pd.testing.assert_series_equal(self.id_map.to_internal, expected)

    def test_to_external(self) -> None:
        expected = pd.Series(self.external_ids, index=pd.RangeIndex(0, 3))
        pd.testing.assert_series_equal(self.id_map.to_external, expected, check_index_type=True)

    def test_internal_ids(self) -> None:
        np.testing.assert_equal(self.id_map.internal_ids, np.array([0, 1, 2]))

    def test_get_sorted_inner(self) -> None:
        np.testing.assert_equal(self.id_map.get_sorted_internal(), np.array([0, 1, 2]))

    def test_get_extern_sorted_by_inner(self) -> None:
        result = self.id_map.get_external_sorted_by_internal()
        np.testing.assert_equal(result, self.external_ids)

    def test_convert_to_internal(self) -> None:
        # "e" is unknown, so the default strict conversion must fail.
        with pytest.raises(KeyError):
            self.id_map.convert_to_internal(["b", "a", "e", "a"])

    def test_convert_to_internal_not_strict(self) -> None:
        converted = self.id_map.convert_to_internal(["b", "a", "e", "a"], strict=False)
        np.testing.assert_equal(converted, np.array([0, 2, 2]))

    def test_convert_to_internal_with_return_missing(self) -> None:
        # pylint: disable=unpacking-non-sequence
        converted, not_found = self.id_map.convert_to_internal(["b", "a", "e", "a"], strict=False, return_missing=True)
        np.testing.assert_equal(converted, np.array([0, 2, 2]))
        np.testing.assert_equal(not_found, np.array(["e"]))

    def test_convert_to_external(self) -> None:
        # Internal id 4 does not exist, so the default strict conversion must fail.
        with pytest.raises(KeyError):
            self.id_map.convert_to_external([0, 2, 4, 2])

    def test_convert_to_external_not_strict(self) -> None:
        converted = self.id_map.convert_to_external([0, 2, 4, 2], strict=False)
        np.testing.assert_equal(converted, np.array(["b", "a", "a"]))

    def test_convert_to_external_with_return_missing(self) -> None:
        # pylint: disable=unpacking-non-sequence
        converted, not_found = self.id_map.convert_to_external([0, 2, 4, 2], strict=False, return_missing=True)
        np.testing.assert_equal(converted, np.array(["b", "a", "a"]))
        np.testing.assert_equal(not_found, np.array([4]))

    def test_add_ids(self) -> None:
        extended = self.id_map.add_ids(["d", "e", "c", "d"])
        np.testing.assert_equal(extended.external_ids, np.array(["b", "c", "a", "d", "e"]))

    def test_add_ids_with_raising_on_repeating_ids(self) -> None:
        # "c" is already present in the map.
        with pytest.raises(ValueError):
            self.id_map.add_ids(["d", "e", "c", "d"], raise_if_already_present=True)
class SomeMetric(MetricAtK):
    """Minimal `MetricAtK` subclass whose `calc` only forwards its inputs to `_check`."""

    def calc(
        self, reco: pd.DataFrame, interactions: pd.DataFrame, prev_interactions: pd.DataFrame, ref_reco: pd.DataFrame
    ) -> None:
        self._check(reco, interactions, prev_interactions, ref_reco)


class TestMetricAtK:
    """Tests of the input checks, exercised through `SomeMetric.calc`."""

    @pytest.fixture
    def data(self) -> tp.Dict[str, pd.DataFrame]:
        reco_columns = [Columns.User, Columns.Item, Columns.Rank]
        interactions_columns = [Columns.User, Columns.Item]
        interactions_rows = [[10, 100], [10, 200], [20, 200]]
        return {
            "reco": pd.DataFrame([[10, 100, 1], [10, 200, 2], [20, 200, 1]], columns=reco_columns),
            "interactions": pd.DataFrame(interactions_rows, columns=interactions_columns),
            "prev_interactions": pd.DataFrame(interactions_rows, columns=interactions_columns),
            "ref_reco": pd.DataFrame([[10, 100, 1], [10, 300, 2], [20, 200, 1]], columns=reco_columns),
        }

    @pytest.mark.parametrize("table", ("reco", "interactions", "prev_interactions", "ref_reco"))
    @pytest.mark.parametrize("column", (Columns.User, Columns.Item, Columns.Rank))
    def test_check_columns(self, data: tp.Dict[str, pd.DataFrame], table: str, column: str) -> None:
        frame = data[table]
        # Some parametrized combinations name a column the table never had; nothing to check then.
        if column not in frame:
            return
        metric = SomeMetric(1)
        frame.drop(columns=column, inplace=True)
        with pytest.raises(KeyError) as e:
            metric.calc(**data)
        message = e.value.args[0].lower()
        assert table in message
        assert column in message

    @pytest.mark.parametrize("table", ("reco", "ref_reco"))
    def test_check_rank_type(self, data: tp.Dict[str, pd.DataFrame], table: str) -> None:
        frame = data[table]
        frame[Columns.Rank] = frame[Columns.Rank].astype(float)
        metric = SomeMetric(1)
        with warnings.catch_warnings(record=True) as w:
            metric.calc(**data)
            assert len(w) == 1
            for phrase in (Columns.Rank, table, "dtype", "integer"):
                assert phrase in str(w[-1].message)

    @pytest.mark.parametrize("table", ("reco", "ref_reco"))
    def test_check_min_rank(self, data: tp.Dict[str, pd.DataFrame], table: str) -> None:
        frame = data[table]
        # Remap ranks so that the smallest rank in the table becomes 2.
        frame[Columns.Rank] = frame[Columns.Rank].map({1: 3, 2: 2})
        metric = SomeMetric(1)
        with warnings.catch_warnings(record=True) as w:
            metric.calc(**data)
            assert len(w) == 1
            for phrase in (Columns.Rank, table, "min value", "1"):
                assert phrase in str(w[-1].message)
class TestCatalogCoverage:
    """Tests for `CatalogCoverage` on a small fixed recommendations table."""

    def setup_method(self) -> None:
        users = [1, 1, 1, 2, 2, 3, 4]
        items = [1, 2, 3, 1, 2, 1, 1]
        ranks = [1, 2, 3, 1, 1, 3, 2]
        self.reco = pd.DataFrame({Columns.User: users, Columns.Item: items, Columns.Rank: ranks})

    @pytest.mark.parametrize("normalize,expected", ((True, 0.4), (False, 2.0)))
    def test_calc(self, normalize: bool, expected: float) -> None:
        coverage = CatalogCoverage(k=2, normalize=normalize)
        # The catalog holds 5 items: 0..4.
        result = coverage.calc(self.reco, np.arange(5))
        assert result == expected
class TestPairwiseHammingDistanceCalculator:
    """Tests for the dense Hamming distance calculator."""

    def test_correct_distance_values(self) -> None:
        rows = [
            ["i1", 0, 0],
            ["i2", 0, 1],
            ["i3", 1, 1],
            ["i4", 0, np.nan],
        ]
        features_df = pd.DataFrame(rows, columns=[Columns.Item, "feature_1", "feature_2"]).set_index(Columns.Item)
        calculator = PairwiseHammingDistanceCalculator(features_df)

        # "i4" has a NaN feature and "i5" is not in the features table at all.
        left = ["i1"] * 5
        right = ["i1", "i2", "i3", "i4", "i5"]
        expected = np.array([0, 1, 2, np.nan, np.nan])

        with pytest.warns(UserWarning, match="Some items has absent feature values"):
            actual = calculator[left, right]
        assert np.array_equal(actual, expected, equal_nan=True)


@pytest.mark.filterwarnings("ignore:Some items absent in mapper")
@pytest.mark.filterwarnings("ignore:Some items has absent feature values")
class TestSparsePairwiseHammingDistanceCalculator:
    """Tests for the sparse Hamming distance calculator."""

    @pytest.mark.parametrize(
        "left,right,expected",
        (
            # Correct features, mapper, item case
            (["i1", "i1", "i1"], ["i1", "i2", "i3"], [0, 1, 2]),
            # Features contain and not contain nan case
            (["i1", "i1", "i1", "i1"], ["i1", "i2", "i3", "i4"], [0, 1, 2, np.nan]),
            # Comparison absence item case
            (["i1", "i1", "i1", "i1"], ["i1", "i2", "i3", "i6"], [0, 1, 2, np.nan]),
            # Comparison empty items lists case
            ([], [], []),
            # IndexError case
            (["i1"], ["i5"], []),
        ),
    )
    def test_correct_distance_values(self, left: tp.List[str], right: tp.List[str], expected: tp.List[float]) -> None:
        # One row of features per item i1..i4; the mapper additionally knows "i5".
        feature_rows = [
            [0, 0],  # i1
            [0, 1],  # i2
            [1, 1],  # i3
            [0, np.nan],  # i4
        ]
        id_map = IdMap.from_values(["i1", "i2", "i3", "i4", "i5"])
        features = SparseFeatures(values=csr_matrix(feature_rows), names=(("f1", "v1"), ("f2", "v2")))
        calculator = SparsePairwiseHammingDistanceCalculator(features, id_map)

        if "i5" in right:
            with pytest.raises(IndexError):
                calculator[left, right]  # pylint: disable=pointless-statement
            return

        actual = calculator[left, right]
        assert np.array_equal(actual, expected, equal_nan=True)


class DummyPairwiseDistanceCalculator(PairwiseDistanceCalculator):
    """Calculator stub that reports zero distance for every requested pair."""

    def _get_distances_for_item_pairs(self, items_0: ExternalIds, items_1: ExternalIds) -> np.ndarray:
        n_pairs = len(items_0)
        return np.zeros(n_pairs)


# pylint: disable=expression-not-assigned
class TestPairwiseDistanceCalculatorBase:
    """Tests of index validation, exercised through the dummy calculator."""

    def test_raises_when_get_distance_for_not_a_pairs_of_items(self) -> None:
        calculator = DummyPairwiseDistanceCalculator()
        with pytest.raises(IndexError):
            calculator[["i1"], ["i2"], ["i3"]]  # type: ignore

    def test_raises_when_get_distance_for_not_a_sequence_of_items(self) -> None:
        calculator = DummyPairwiseDistanceCalculator()
        with pytest.raises(TypeError):
            calculator["i1", "i2"]

    def test_raises_when_different_lengths_of_indices_lists_for_item_pairs(self) -> None:
        calculator = DummyPairwiseDistanceCalculator()
        with pytest.raises(IndexError):
            calculator[["i1", "i2"], ["i3"]]
class TestIntraListDiversity:
    """Tests for `IntraListDiversity` backed by a Hamming distance calculator."""

    @pytest.fixture
    def distance_calculator(self) -> PairwiseHammingDistanceCalculator:
        rows = [
            ["i1", 0, 0],
            ["i2", 0, 1],
            ["i3", 1, 1],
        ]
        features_df = pd.DataFrame(rows, columns=[Columns.Item, "feature_1", "feature_2"])
        return PairwiseHammingDistanceCalculator(features_df.set_index(Columns.Item))

    @pytest.fixture
    def recommendations(self) -> pd.DataFrame:
        # "i4" (recommended to "u2") has no row in the features table of `distance_calculator`.
        rows = [["u1", "i1", 1], ["u1", "i2", 2], ["u1", "i3", 3], ["u2", "i1", 1], ["u2", "i4", 2], ["u3", "i1", 1]]
        return pd.DataFrame(rows, columns=[Columns.User, Columns.Item, Columns.Rank])

    @pytest.mark.parametrize(
        "k,expected",
        (
            (1, pd.Series(index=["u1", "u2", "u3"], data=[0, 0, 0])),
            (2, pd.Series(index=["u1", "u2", "u3"], data=[1, np.nan, 0])),
            (3, pd.Series(index=["u1", "u2", "u3"], data=[4 / 3, np.nan, 0])),
        ),
    )
    @pytest.mark.filterwarnings("ignore:Some items has absent feature values")
    def test_correct_ild_values(
        self,
        distance_calculator: PairwiseHammingDistanceCalculator,
        recommendations: pd.DataFrame,
        k: int,
        expected: pd.Series,
    ) -> None:
        metric = IntraListDiversity(k, distance_calculator)

        per_user = metric.calc_per_user(recommendations)
        pd.testing.assert_series_equal(per_user, expected, check_names=False)

        # The aggregated value must equal the mean of the per-user values.
        overall = metric.calc(recommendations)
        assert overall == expected.mean()
14 | 15 | # pylint: disable=attribute-defined-outside-init 16 | 17 | import pandas as pd 18 | 19 | from rectools import Columns 20 | from rectools.metrics import CoveredUsers, SufficientReco, UnrepeatedReco 21 | 22 | 23 | class TestSufficientReco: 24 | def setup_method(self) -> None: 25 | self.metric = SufficientReco(k=2, deep=False) 26 | self.deep_metric = SufficientReco(k=2, deep=True) 27 | self.reco = pd.DataFrame( 28 | { 29 | Columns.User: [1, 1, 1, 2, 2, 3, 4], 30 | Columns.Item: [1, 2, 3, 1, 2, 1, 1], 31 | Columns.Rank: [1, 2, 3, 1, 1, 3, 2], 32 | } 33 | ) 34 | 35 | def test_calc_deep(self) -> None: 36 | expected_metric_per_user = pd.Series( 37 | [1.0, 1.0, 0.0, 0.5], 38 | index=pd.Series([1, 2, 3, 4], name=Columns.User), 39 | ) 40 | pd.testing.assert_series_equal(self.deep_metric.calc_per_user(self.reco), expected_metric_per_user) 41 | assert self.deep_metric.calc(self.reco) == expected_metric_per_user.mean() 42 | 43 | def test_calc_default(self) -> None: 44 | expected_metric_per_user = pd.Series( 45 | [1, 1, 0, 0], 46 | index=pd.Series([1, 2, 3, 4], name=Columns.User), 47 | ) 48 | pd.testing.assert_series_equal(self.metric.calc_per_user(self.reco), expected_metric_per_user) 49 | assert self.metric.calc(self.reco) == expected_metric_per_user.mean() 50 | 51 | 52 | class TestUnrepeatedReco: 53 | def setup_method(self) -> None: 54 | self.metric = UnrepeatedReco(k=4, deep=False) 55 | self.deep_metric = UnrepeatedReco(k=4, deep=True) 56 | self.reco = pd.DataFrame( 57 | { 58 | Columns.User: [1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3], 59 | Columns.Item: [1, 2, 1, 1, 3, 4, 1, 2, 2, 1, 5], 60 | Columns.Rank: [1, 2, 1, 2, 3, 4, 1, 2, 3, 4, 5], 61 | } 62 | ) 63 | 64 | def test_calc_deep(self) -> None: 65 | expected_metric_per_user = pd.Series( 66 | [1.0, 0.75, 0.5], 67 | index=pd.Series([1, 2, 3], name=Columns.User), 68 | ) 69 | pd.testing.assert_series_equal(self.deep_metric.calc_per_user(self.reco), expected_metric_per_user) 70 | assert self.deep_metric.calc(self.reco) == 
expected_metric_per_user.mean() 71 | 72 | def test_calc_default(self) -> None: 73 | expected_metric_per_user = pd.Series( 74 | [1, 0, 0], 75 | index=pd.Series([1, 2, 3], name=Columns.User), 76 | ) 77 | pd.testing.assert_series_equal(self.metric.calc_per_user(self.reco), expected_metric_per_user) 78 | assert self.metric.calc(self.reco) == expected_metric_per_user.mean() 79 | 80 | 81 | class TestCoveredUsers: 82 | def setup_method(self) -> None: 83 | self.metric = CoveredUsers(k=4) 84 | self.reco = pd.DataFrame( 85 | { 86 | Columns.User: [1, 1, 2, 2, 2, 3, 3, 3, 3, 3], 87 | Columns.Item: [1, 2, 1, 1, 3, 1, 2, 2, 1, 5], 88 | Columns.Rank: [1, 2, 1, 2, 3, 1, 2, 3, 4, 5], 89 | } 90 | ) 91 | self.interactions = pd.DataFrame( 92 | { 93 | Columns.User: [1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 10, 11], 94 | Columns.Item: [1, 2, 1, 1, 3, 1, 2, 2, 1, 5, 5, 5], 95 | } 96 | ) 97 | 98 | def test_calc(self) -> None: 99 | expected_metric_per_user = pd.Series([1, 1, 1, 0, 0], index=pd.Series([1, 2, 3, 10, 11], name=Columns.User)) 100 | pd.testing.assert_series_equal( 101 | self.metric.calc_per_user(self.reco, self.interactions), expected_metric_per_user 102 | ) 103 | assert self.metric.calc(self.reco, self.interactions) == expected_metric_per_user.mean() 104 | -------------------------------------------------------------------------------- /tests/metrics/test_intersection.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 MTS (Mobile Telesystems) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import List 16 | 17 | import numpy as np 18 | import pandas as pd 19 | import pytest 20 | 21 | from rectools import Columns 22 | from rectools.metrics.intersection import Intersection, calc_intersection_metrics 23 | 24 | 25 | class TestIntersection: 26 | 27 | @pytest.mark.parametrize( 28 | "k,ref_k,expected_users,expected_intersection", 29 | ( 30 | (2, 2, [1, 2, 4, 5], [0.0, 1.0, 1 / 2, 1 / 3]), 31 | (3, None, [1, 2, 4, 5], [1 / 2, 1.0, 1 / 2, 2 / 3]), 32 | (3, 6, [1, 2, 4, 5], [1 / 2, 1.0, 1.0, 1.0]), 33 | ), 34 | ) 35 | def test_calc(self, k: int, ref_k: int, expected_users: List[int], expected_intersection: List[float]) -> None: 36 | reco = pd.DataFrame( 37 | { 38 | Columns.User: [1, 1, 2, 3, 3, 4, 4, 4, 5, 5, 5, 5], 39 | Columns.Item: [1, 2, 1, 1, 2, 1, 2, 3, 1, 2, 3, 4], 40 | Columns.Rank: [3, 1, 1, 7, 5, 2, 1, 8, 1, 2, 2, 9], 41 | } 42 | ) 43 | ref_reco = pd.DataFrame( 44 | { 45 | Columns.User: [1, 1, 2, 3, 3, 4, 4, 4, 5, 5, 5, 5], 46 | Columns.Item: [1, 3, 1, 1, 2, 1, 2, 3, 1, 2, 3, 4], 47 | Columns.Rank: [3, 2, 1, 1, 4, 5, 1, 2, 3, 4, 1, 7], 48 | } 49 | ) 50 | 51 | intersection_metric = Intersection(k=k, ref_k=ref_k) 52 | 53 | metric_per_user = intersection_metric.calc_per_user(reco, ref_reco) 54 | expected_metric_per_user = pd.Series( 55 | expected_intersection, 56 | index=pd.Series(expected_users, name=Columns.User), 57 | dtype=float, 58 | ) 59 | pd.testing.assert_series_equal(metric_per_user, expected_metric_per_user) 60 | 61 | metric = intersection_metric.calc(reco, ref_reco) 62 | assert np.allclose(metric, expected_metric_per_user.mean()) 63 | 64 | def test_when_no_ref_reco(self) -> None: 65 | reco = pd.DataFrame( 66 | { 67 | Columns.User: [1, 1, 1, 2, 2, 3, 4], 68 | Columns.Item: [1, 2, 3, 1, 2, 1, 1], 69 | Columns.Rank: [1, 2, 3, 1, 2, 1, 1], 70 | } 71 | ) 72 | empty_ref_reco = pd.DataFrame(columns=[Columns.User, 
Columns.Item, Columns.Rank], dtype=int) 73 | 74 | intersection_metric = Intersection(k=2) 75 | 76 | metric_per_user = intersection_metric.calc_per_user(reco, empty_ref_reco) 77 | expected_metric_per_user = pd.Series(index=pd.Series(name=Columns.User, dtype=int), dtype=np.float64) 78 | pd.testing.assert_series_equal(metric_per_user, expected_metric_per_user) 79 | 80 | metric = intersection_metric.calc(reco, empty_ref_reco) 81 | assert np.isnan(metric) 82 | 83 | 84 | class TestCalcIntersectionMetrics: 85 | 86 | @pytest.fixture 87 | def reco(self) -> pd.DataFrame: 88 | return pd.DataFrame( 89 | { 90 | Columns.User: [1, 1, 2], 91 | Columns.Item: [3, 2, 1], 92 | Columns.Rank: [1, 2, 1], 93 | } 94 | ) 95 | 96 | @pytest.fixture 97 | def ref_reco(self) -> pd.DataFrame: 98 | return pd.DataFrame( 99 | { 100 | Columns.User: [1, 1, 2], 101 | Columns.Item: [3, 5, 1], 102 | Columns.Rank: [1, 2, 1], 103 | } 104 | ) 105 | 106 | def test_single_ref_reco(self, reco: pd.DataFrame, ref_reco: pd.DataFrame) -> None: 107 | actual = calc_intersection_metrics( 108 | metrics={"int": Intersection(k=2, ref_k=1)}, 109 | reco=reco, 110 | ref_reco=ref_reco, 111 | ) 112 | assert actual == {"int": 0.75} 113 | 114 | def test_multiple_ref_reco(self, reco: pd.DataFrame, ref_reco: pd.DataFrame) -> None: 115 | actual = calc_intersection_metrics( 116 | metrics={"int": Intersection(k=2, ref_k=1)}, 117 | reco=reco, 118 | ref_reco={"one": ref_reco, "two": ref_reco}, 119 | ) 120 | assert actual == {"int_one": 0.75, "int_two": 0.75} 121 | -------------------------------------------------------------------------------- /tests/metrics/test_novelty.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 MTS (Mobile Telesystems) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import numpy as np 16 | import pandas as pd 17 | import pytest 18 | 19 | from rectools import Columns 20 | from rectools.metrics.novelty import MeanInvUserFreq 21 | 22 | 23 | class TestMeanInvUserFreq: 24 | @pytest.fixture 25 | def interactions(self) -> pd.DataFrame: 26 | interactions = pd.DataFrame( 27 | [ 28 | ["u1", "i1"], 29 | ["u1", "i2"], 30 | ["u2", "i1"], 31 | ["u3", "i1"], 32 | ], 33 | columns=[Columns.User, Columns.Item], 34 | ) 35 | return interactions 36 | 37 | @pytest.fixture 38 | def recommendations(self) -> pd.DataFrame: 39 | recommendations = pd.DataFrame( 40 | [ 41 | ["u1", "i3", 1], 42 | ["u2", "i2", 1], 43 | ["u2", "i3", 2], 44 | ["u3", "i1", 1], 45 | ["u3", "i2", 2], 46 | ], 47 | columns=[Columns.User, Columns.Item, Columns.Rank], 48 | ) 49 | return recommendations 50 | 51 | @pytest.mark.parametrize( 52 | "k,expected", 53 | ( 54 | (1, pd.Series(index=["u1", "u2", "u3"], data=[-np.log2(1 / 3), -np.log2(1 / 3), 0])), 55 | (2, pd.Series(index=["u1", "u2", "u3"], data=[-np.log2(1 / 3), -np.log2(1 / 3), -np.log2(1 / 3) / 2])), 56 | ), 57 | ) 58 | def test_correct_miuf_values( 59 | self, recommendations: pd.DataFrame, interactions: pd.DataFrame, k: int, expected: pd.Series 60 | ) -> None: 61 | miuf = MeanInvUserFreq(k) 62 | 63 | actual = miuf.calc_per_user(recommendations, interactions) 64 | pd.testing.assert_series_equal(actual, expected, check_names=False) 65 | 66 | actual_mean = miuf.calc(recommendations, interactions) 67 | assert actual_mean == expected.mean() 68 | 
-------------------------------------------------------------------------------- /tests/metrics/test_popularity.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 MTS (Mobile Telesystems) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import numpy as np 16 | import pandas as pd 17 | import pytest 18 | 19 | from rectools import Columns 20 | from rectools.metrics.popularity import AvgRecPopularity 21 | 22 | 23 | class TestAvgRecPopularity: 24 | @pytest.fixture 25 | def interactions(self) -> pd.DataFrame: 26 | interactions = pd.DataFrame( 27 | [["u1", "i1"], ["u1", "i2"], ["u2", "i1"], ["u2", "i3"], ["u3", "i1"], ["u3", "i2"]], 28 | columns=[Columns.User, Columns.Item], 29 | ) 30 | return interactions 31 | 32 | @pytest.fixture 33 | def recommendations(self) -> pd.DataFrame: 34 | recommendations = pd.DataFrame( 35 | [ 36 | ["u1", "i1", 1], 37 | ["u1", "i2", 2], 38 | ["u2", "i3", 1], 39 | ["u2", "i1", 2], 40 | ["u2", "i2", 3], 41 | ["u3", "i3", 1], 42 | ["u3", "i2", 2], 43 | ], 44 | columns=[Columns.User, Columns.Item, Columns.Rank], 45 | ) 46 | return recommendations 47 | 48 | @pytest.mark.parametrize( 49 | "k,normalize,expected", 50 | ( 51 | (1, False, pd.Series(index=["u1", "u2", "u3"], data=[3.0, 1.0, 1.0])), 52 | (3, False, pd.Series(index=["u1", "u2", "u3"], data=[2.5, 2.0, 1.5])), 53 | (1, True, pd.Series(index=["u1", "u2", "u3"], data=[0.5, np.divide(1, 6), 
np.divide(1, 6)])), 54 | (3, True, pd.Series(index=["u1", "u2", "u3"], data=[np.divide(5, 12), np.divide(1, 3), 0.25])), 55 | ), 56 | ) 57 | def test_correct_arp_values( 58 | self, recommendations: pd.DataFrame, interactions: pd.DataFrame, k: int, normalize: bool, expected: pd.Series 59 | ) -> None: 60 | arp = AvgRecPopularity(k, normalize) 61 | 62 | actual = arp.calc_per_user(recommendations, interactions) 63 | pd.testing.assert_series_equal(actual, expected, check_names=False) 64 | 65 | actual_mean = arp.calc(recommendations, interactions) 66 | assert actual_mean == expected.mean() 67 | 68 | def test_when_no_interactions( 69 | self, 70 | recommendations: pd.DataFrame, 71 | ) -> None: 72 | expected = pd.Series(index=recommendations[Columns.User].unique(), data=[0.0, 0.0, 0.0]) 73 | empty_interactions = pd.DataFrame(columns=[Columns.User, Columns.Item], dtype=int) 74 | arp = AvgRecPopularity(k=2) 75 | 76 | actual = arp.calc_per_user(recommendations, empty_interactions) 77 | pd.testing.assert_series_equal(actual, expected, check_names=False) 78 | 79 | actual_mean = arp.calc(recommendations, empty_interactions) 80 | assert actual_mean == expected.mean() 81 | 82 | @pytest.mark.parametrize( 83 | "k,expected", 84 | ( 85 | (1, pd.Series(index=["u1", "u2", "u3"], data=[3.0, 1.0, 1.0])), 86 | (3, pd.Series(index=["u1", "u2", "u3"], data=[2.5, np.divide(4, 3), 1.5])), 87 | ), 88 | ) 89 | def test_when_new_item_in_reco(self, interactions: pd.DataFrame, k: int, expected: pd.Series) -> None: 90 | reco = pd.DataFrame( 91 | [ 92 | ["u1", "i1", 1], 93 | ["u1", "i2", 2], 94 | ["u2", "i3", 1], 95 | ["u2", "i1", 2], 96 | ["u2", "i4", 3], 97 | ["u3", "i3", 1], 98 | ["u3", "i2", 2], 99 | ], 100 | columns=[Columns.User, Columns.Item, Columns.Rank], 101 | ) 102 | arp = AvgRecPopularity(k) 103 | 104 | actual = arp.calc_per_user(reco, interactions) 105 | pd.testing.assert_series_equal(actual, expected, check_names=False) 106 | 107 | actual_mean = arp.calc(reco, interactions) 108 | assert 
actual_mean == expected.mean() 109 | -------------------------------------------------------------------------------- /tests/metrics/test_serendipity.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 MTS (Mobile Telesystems) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import pandas as pd 16 | import pytest 17 | 18 | from rectools import Columns 19 | from rectools.metrics.serendipity import Serendipity 20 | 21 | 22 | class TestSerendipityCalculator: 23 | @pytest.fixture 24 | def interactions_train(self) -> pd.DataFrame: 25 | interactions_train = pd.DataFrame( 26 | [ 27 | ["u1", "i1"], 28 | ["u1", "i2"], 29 | ["u2", "i1"], 30 | ["u2", "i2"], 31 | ["u3", "i1"], 32 | ], 33 | columns=[Columns.User, Columns.Item], 34 | ) 35 | return interactions_train 36 | 37 | @pytest.fixture 38 | def interactions_test(self) -> pd.DataFrame: 39 | interactions_test = pd.DataFrame( 40 | [ 41 | ["u1", "i1"], 42 | ["u1", "i2"], 43 | ["u2", "i2"], 44 | ["u2", "i3"], 45 | ["u3", "i2"], 46 | ["u4", "i2"], 47 | ], 48 | columns=[Columns.User, Columns.Item], 49 | ) 50 | return interactions_test 51 | 52 | @pytest.fixture 53 | def recommendations(self) -> pd.DataFrame: 54 | recommendations = pd.DataFrame( 55 | [ 56 | ["u1", "i1", 1], 57 | ["u1", "i2", 2], 58 | ["u2", "i2", 1], 59 | ["u2", "i3", 2], 60 | ["u3", "i3", 1], 61 | ["u4", "i2", 1], 62 | ["u4", "i3", 2], 63 | ], 64 | 
columns=[Columns.User, Columns.Item, Columns.Rank], 65 | ) 66 | return recommendations 67 | 68 | @pytest.mark.parametrize( 69 | "k,expected", 70 | ( 71 | (1, pd.Series(index=["u1", "u2", "u3", "u4"], data=[0, 0.25, 0, 0.25])), 72 | (2, pd.Series(index=["u1", "u2", "u3", "u4"], data=[0, 0.5, 0, 0.125])), 73 | ), 74 | ) 75 | def test_correct_serendipity_values( 76 | self, 77 | recommendations: pd.DataFrame, 78 | interactions_train: pd.DataFrame, 79 | interactions_test: pd.DataFrame, 80 | k: int, 81 | expected: pd.Series, 82 | ) -> None: 83 | serendipity = Serendipity(k) 84 | 85 | actual = serendipity.calc_per_user( 86 | reco=recommendations, 87 | interactions=interactions_test, 88 | prev_interactions=interactions_train, 89 | catalog=["i1", "i2", "i3", "i4"], 90 | ) 91 | pd.testing.assert_series_equal(actual, expected, check_names=False) 92 | 93 | actual_mean = serendipity.calc( 94 | reco=recommendations, 95 | interactions=interactions_test, 96 | prev_interactions=interactions_train, 97 | catalog=["i1", "i2", "i3", "i4"], 98 | ) 99 | assert actual_mean == expected.mean() 100 | -------------------------------------------------------------------------------- /tests/model_selection/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 MTS (Mobile Telesystems) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | -------------------------------------------------------------------------------- /tests/model_selection/test_splitter.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 MTS (Mobile Telesystems) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import numpy as np 16 | import pandas as pd 17 | import pytest 18 | 19 | from rectools import Columns 20 | from rectools.dataset import Interactions 21 | from rectools.model_selection import Splitter 22 | 23 | 24 | class TestSplitter: 25 | @pytest.fixture 26 | def interactions(self) -> Interactions: 27 | df = pd.DataFrame( 28 | [ 29 | [1, 1, 1, "2021-09-01"], 30 | [1, 2, 1, "2021-09-02"], 31 | [2, 1, 1, "2021-09-02"], 32 | [2, 2, 1, "2021-09-03"], 33 | [3, 2, 1, "2021-09-03"], 34 | [3, 3, 1, "2021-09-03"], 35 | [3, 4, 1, "2021-09-04"], 36 | [1, 2, 1, "2021-09-04"], 37 | [3, 1, 1, "2021-09-05"], 38 | [4, 2, 1, "2021-09-05"], 39 | [3, 3, 1, "2021-09-06"], 40 | ], 41 | columns=[Columns.User, Columns.Item, Columns.Weight, Columns.Datetime], 42 | ).astype({Columns.Datetime: "datetime64[ns]"}) 43 | return Interactions(df) 44 | 45 | def test_not_implemented(self, interactions: Interactions) -> None: 46 | s = Splitter() 47 | with pytest.raises(NotImplementedError): 48 | for _, _, _ in s.split(interactions): 49 | pass 50 | 51 | @pytest.mark.parametrize("collect_fold_stats", [False, True]) 52 | def test_not_defined_fields(self, 
interactions: Interactions, collect_fold_stats: bool) -> None: 53 | s = Splitter() 54 | train_idx = np.array([1, 2, 3, 5, 7, 8]) 55 | test_idx = np.array([4, 6, 9, 10]) 56 | fold_info = {"info_from_split": 123} 57 | train_idx_new, test_idx_new, _ = s.filter(interactions, collect_fold_stats, train_idx, test_idx, fold_info) 58 | 59 | assert np.array_equal(train_idx, train_idx_new) 60 | assert sorted(test_idx_new) == [4] 61 | -------------------------------------------------------------------------------- /tests/model_selection/test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 MTS (Mobile Telesystems) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import typing as tp 16 | 17 | import numpy as np 18 | import pytest 19 | 20 | from rectools.model_selection.utils import get_not_seen_mask 21 | 22 | 23 | class TestGetNotSeenMask: 24 | @pytest.mark.parametrize( 25 | "train_users,train_items,test_users,test_items,expected", 26 | ( 27 | ([], [], [], [], []), 28 | ([1, 2], [10, 20], [], [], []), 29 | ([], [], [1, 2], [10, 20], [True, True]), 30 | ([1, 2, 3, 4, 2, 3], [10, 20, 30, 40, 22, 30], [1, 2, 3, 2], [10, 20, 33, 20], [False, False, True, False]), 31 | ), 32 | ) 33 | def test_correct( 34 | self, 35 | train_users: tp.List[int], 36 | train_items: tp.List[int], 37 | test_users: tp.List[int], 38 | test_items: tp.List[int], 39 | expected: tp.List[bool], 40 | ) -> None: 41 | actual = get_not_seen_mask(*(np.array(a) for a in (train_users, train_items, test_users, test_items))) 42 | np.testing.assert_equal(actual, expected) 43 | 44 | @pytest.mark.parametrize( 45 | "train_users,train_items,test_users,test_items,expected_error_type", 46 | ( 47 | ([1], [10, 20], [1], [10], ValueError), 48 | ([1], [10], [1, 2], [10], ValueError), 49 | ), 50 | ) 51 | def test_with_incorrect_arrays( 52 | self, 53 | train_users: tp.List[int], 54 | train_items: tp.List[int], 55 | test_users: tp.List[int], 56 | test_items: tp.List[int], 57 | expected_error_type: tp.Type[Exception], 58 | ) -> None: 59 | with pytest.raises(expected_error_type): 60 | get_not_seen_mask(*(np.array(a) for a in (train_users, train_items, test_users, test_items))) 61 | -------------------------------------------------------------------------------- /tests/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 MTS (Mobile Telesystems) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /tests/models/data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 MTS (Mobile Telesystems) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import pandas as pd 16 | 17 | from rectools import Columns 18 | from rectools.dataset import Dataset 19 | 20 | INTERACTIONS = pd.DataFrame( 21 | [ 22 | [10, 11], 23 | [10, 12], 24 | [10, 14], 25 | [20, 11], 26 | [20, 12], 27 | [20, 13], 28 | [30, 11], 29 | [30, 12], 30 | [30, 14], 31 | [30, 15], 32 | [40, 11], 33 | [40, 15], 34 | [40, 17], 35 | ], 36 | columns=Columns.UserItem, 37 | ) 38 | INTERACTIONS[Columns.Weight] = 1 39 | INTERACTIONS[Columns.Datetime] = "2021-09-09" 40 | 41 | DATASET = Dataset.construct(INTERACTIONS) 42 | -------------------------------------------------------------------------------- /tests/models/nn/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MTS (Mobile Telesystems) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /tests/models/nn/transformers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MTS (Mobile Telesystems) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /tests/models/nn/transformers/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MTS (Mobile Telesystems) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import pandas as pd 16 | from pytorch_lightning import Trainer 17 | from pytorch_lightning.callbacks import ModelCheckpoint 18 | 19 | from rectools import Columns 20 | 21 | 22 | def leave_one_out_mask(interactions: pd.DataFrame) -> pd.Series: 23 | rank = ( 24 | interactions.sort_values(Columns.Datetime, ascending=False, kind="stable") 25 | .groupby(Columns.User, sort=False) 26 | .cumcount() 27 | ) 28 | return rank == 0 29 | 30 | 31 | def custom_trainer() -> Trainer: 32 | return Trainer( 33 | max_epochs=3, 34 | min_epochs=3, 35 | deterministic=True, 36 | accelerator="cpu", 37 | enable_checkpointing=False, 38 | devices=1, 39 | ) 40 | 41 | 42 | def custom_trainer_ckpt() -> Trainer: 43 | return Trainer( 44 | max_epochs=3, 45 | min_epochs=3, 46 | deterministic=True, 47 | accelerator="cpu", 48 | devices=1, 49 | callbacks=ModelCheckpoint(filename="last_epoch"), 50 | ) 51 | 52 | 53 | def custom_trainer_multiple_ckpt() -> Trainer: 54 | return Trainer( 55 | max_epochs=3, 56 | min_epochs=3, 57 | deterministic=True, 58 | accelerator="cpu", 59 | devices=1, 60 | callbacks=ModelCheckpoint( 61 | monitor="train_loss", 62 | save_top_k=3, 63 | every_n_epochs=1, 64 | filename="{epoch}", 65 | ), 66 | ) 67 | -------------------------------------------------------------------------------- /tests/models/rank/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MTS (Mobile Telesystems) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /tests/models/rank/test_rank_implicit.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MTS (Mobile Telesystems) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import typing as tp 16 | 17 | import implicit.cpu 18 | import numpy as np 19 | import pytest 20 | from scipy import sparse 21 | 22 | from rectools.models.rank import Distance, ImplicitRanker 23 | 24 | T = tp.TypeVar("T") 25 | 26 | pytestmark = pytest.mark.filterwarnings("ignore:invalid value encountered in true_divide") 27 | 28 | 29 | class TestImplicitRanker: # pylint: disable=protected-access 30 | @pytest.fixture 31 | def subject_factors(self) -> np.ndarray: 32 | return np.array([[-4, 0, 3], [0, 1, 2]]) 33 | 34 | @pytest.fixture 35 | def object_factors(self) -> np.ndarray: 36 | return np.array( 37 | [ 38 | [-4, 0, 3], 39 | [0, 2, 4], 40 | [1, 10, 100], 41 | ] 42 | ) 43 | 44 | @pytest.mark.parametrize( 45 | "dense", 46 | ( 47 | (True), 48 | (False), 49 | ), 50 | ) 51 | def test_neginf_score( 52 | self, 53 | subject_factors: np.ndarray, 54 | object_factors: np.ndarray, 55 | dense: bool, 56 | ) -> None: 57 | if not dense: 58 | subject_factors = sparse.csr_matrix(subject_factors) 59 | implicit_ranker = 
ImplicitRanker( 60 | Distance.DOT, 61 | subjects_factors=subject_factors, 62 | objects_factors=object_factors, 63 | ) 64 | dummy_factors: np.ndarray = np.array([[1, 2]], dtype=np.float32) 65 | neginf = implicit.cpu.topk.topk( # pylint: disable=c-extension-no-member 66 | items=dummy_factors, 67 | query=dummy_factors, 68 | k=1, 69 | filter_items=np.array([0]), 70 | )[1][0][0] 71 | assert neginf <= implicit_ranker._get_neginf_score() <= -1e38 72 | 73 | @pytest.mark.parametrize( 74 | "dense", 75 | ( 76 | (True), 77 | (False), 78 | ), 79 | ) 80 | def test_mask_for_correct_scores( 81 | self, subject_factors: np.ndarray, object_factors: np.ndarray, dense: bool 82 | ) -> None: 83 | if not dense: 84 | subject_factors = sparse.csr_matrix(subject_factors) 85 | 86 | implicit_ranker = ImplicitRanker( 87 | Distance.DOT, 88 | subjects_factors=subject_factors, 89 | objects_factors=object_factors, 90 | ) 91 | neginf = implicit_ranker._get_neginf_score() 92 | scores: np.ndarray = np.array([7, 6, 0, 0], dtype=np.float32) 93 | 94 | actual = implicit_ranker._get_mask_for_correct_scores(scores) 95 | assert actual == [True] * 4 96 | 97 | actual = implicit_ranker._get_mask_for_correct_scores(np.append(scores, [neginf] * 2)) 98 | assert actual == [True] * 4 + [False] * 2 99 | 100 | actual = implicit_ranker._get_mask_for_correct_scores(np.append(scores, [neginf * 0.99] * 2)) 101 | assert actual == [True] * 6 102 | 103 | actual = implicit_ranker._get_mask_for_correct_scores(np.insert(scores, 0, neginf)) 104 | assert actual == [True] * 5 105 | 106 | @pytest.mark.parametrize("distance", (Distance.COSINE, Distance.EUCLIDEAN)) 107 | def test_raises( 108 | self, 109 | subject_factors: np.ndarray, 110 | object_factors: np.ndarray, 111 | distance: Distance, 112 | ) -> None: 113 | subject_factors = sparse.csr_matrix(subject_factors) 114 | with pytest.raises(ValueError): 115 | ImplicitRanker( 116 | distance=distance, 117 | subjects_factors=subject_factors, 118 | objects_factors=object_factors, 119 | 
def gen_rankers() -> tp.List[tp.Tuple[tp.Any, tp.Dict[str, tp.Any]]]:
    """Build ``(ranker class, constructor kwargs)`` pairs for parametrized tests.

    Every available device is combined with every batch size. ``"cuda:0"`` is
    included only when ``torch.cuda.is_available()`` reports a usable GPU.
    """
    devices = ["cpu"]
    if torch.cuda.is_available():
        devices.append("cuda:0")
    batch_sizes = [128, 1]

    return [
        (TorchRanker, {"device": device, "batch_size": batch_size})
        for device, batch_size in product(devices, batch_sizes)
    ]
| @pytest.mark.parametrize( 61 | "distance, expected_recs, expected_scores, dense", 62 | ( 63 | ( 64 | Distance.DOT, 65 | [2, 0, 1, 2, 1, 0], 66 | [296, 25, 12, 210, 10, 6], 67 | True, 68 | ), 69 | ( 70 | Distance.COSINE, 71 | [0, 2, 1, 1, 2, 0], 72 | [1, 0.5890328, 0.5366563, 1, 0.9344414, 0.5366563], 73 | True, 74 | ), 75 | ( 76 | Distance.EUCLIDEAN, 77 | [0, 1, 2, 1, 0, 2], 78 | [0, 4.58257569, 97.64220399, 2.23606798, 4.24264069, 98.41747812], 79 | True, 80 | ), 81 | ( 82 | Distance.DOT, 83 | [2, 0, 1, 2, 1, 0], 84 | [296, 25, 12, 210, 10, 6], 85 | False, 86 | ), 87 | ), 88 | ) 89 | @pytest.mark.parametrize("ranker_cls, ranker_args", gen_rankers()) 90 | def test_rank( 91 | self, 92 | ranker_cls: tp.Type[TorchRanker], 93 | ranker_args: tp.Dict[str, tp.Any], 94 | distance: Distance, 95 | expected_recs: tp.List[int], 96 | expected_scores: tp.List[float], 97 | subject_factors: np.ndarray, 98 | object_factors: np.ndarray, 99 | dense: bool, 100 | ) -> None: 101 | if not dense: 102 | subject_factors = sparse.csr_matrix(subject_factors) 103 | 104 | ranker: Ranker = ranker_cls( 105 | **ranker_args, 106 | distance=distance, 107 | subjects_factors=subject_factors, 108 | objects_factors=object_factors, 109 | ) 110 | 111 | _, actual_recs, actual_scores = ranker.rank( 112 | subject_ids=[0, 1], 113 | k=3, 114 | ) 115 | 116 | np.testing.assert_equal(actual_recs, expected_recs) 117 | np.testing.assert_almost_equal( 118 | actual_scores, 119 | expected_scores, 120 | decimal=EPS_DIGITS, 121 | ) 122 | -------------------------------------------------------------------------------- /tests/models/test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 MTS (Mobile Telesystems) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
class TestRecommendFromScores:
    """Tests for ``recommend_from_scores`` with blacklists, whitelists and both sort orders."""

    @pytest.mark.parametrize(
        "blacklist,whitelist,all_expected_ids",
        (
            (None, None, np.array([6, 0, 2, 4, 1, 3, 5])),
            (np.array([0, 1, 5, 6]), None, np.array([2, 4, 3])),
            (None, np.array([0, 2, 5, 6]), np.array([6, 0, 2, 5])),
            (np.array([0, 1, 5, 6]), np.array([0, 2, 5, 6]), np.array([2])),
            (np.array([0, 1, 2, 3]), np.array([1, 2, 3]), np.array([], dtype=int)),
        ),
    )
    @pytest.mark.parametrize("ascending", (True, False))
    def test_valid_cases(
        self, blacklist: np.ndarray, whitelist: np.ndarray, all_expected_ids: np.ndarray, ascending: bool
    ) -> None:
        """Returned ids and scores must match the expected ranking truncated to five items."""
        scores = np.array([10.5, 2, 7, 0, 5, -3, 100])
        # The parametrized ids are listed for descending order; flip them for ascending.
        ranked_ids = all_expected_ids[::-1] if ascending else all_expected_ids
        top_ids = ranked_ids[:5]

        actual_ids, actual_scores = recommend_from_scores(scores, 5, blacklist, whitelist, ascending)

        np.testing.assert_equal(actual_ids, top_ids)
        np.testing.assert_equal(actual_scores, scores[top_ids])

    @pytest.mark.parametrize("k", (-5, 0))
    def test_raises_when_k_is_not_positive(self, k: int) -> None:
        """A zero or negative ``k`` must be rejected with ``ValueError``."""
        scores = np.array([1, 2, 3])
        with pytest.raises(ValueError):
            recommend_from_scores(scores, k=k)
def assert_second_fit_refits_model(
    model: ModelBase, dataset: Dataset, pre_fit_callback: tp.Optional[tp.Callable[[], None]] = None
) -> None:
    """Check that fitting a model twice gives the same recommendations as fitting it once.

    Two deep copies of ``model`` are made: the first is fitted once, the second
    twice. Their user-to-item and item-to-item recommendations for every known
    user and item are then compared with an absolute tolerance of ``0.001``.

    ``pre_fit_callback`` (a no-op when omitted) is invoked before each copy and
    before each ``fit`` call.
    """
    callback: tp.Callable[[], None] = pre_fit_callback or _dummy_func

    def _fitted_copy(n_fits: int) -> ModelBase:
        # Same sequence as doing it by hand: callback, copy, then (callback, fit) per fit.
        callback()
        candidate = deepcopy(model)
        for _ in range(n_fits):
            callback()
            candidate.fit(dataset)
        return candidate

    fitted_once = _fitted_copy(1)
    fitted_twice = _fitted_copy(2)

    # Ask for as many recommendations as there are items in the dataset.
    n_items = dataset.item_id_map.external_ids.size

    u2i_once = fitted_once.recommend(dataset.user_id_map.external_ids, dataset, n_items, False)
    u2i_twice = fitted_twice.recommend(dataset.user_id_map.external_ids, dataset, n_items, False)
    pd.testing.assert_frame_equal(u2i_once, u2i_twice, atol=0.001)

    i2i_once = fitted_once.recommend_to_items(dataset.item_id_map.external_ids, dataset, n_items, False)
    i2i_twice = fitted_twice.recommend_to_items(dataset.item_id_map.external_ids, dataset, n_items, False)
    pd.testing.assert_frame_equal(i2i_once, i2i_twice, atol=0.001)
def assert_save_load_do_not_change_model(
    model: ModelBase,
    dataset: Dataset,
    check_configs: bool = True,
) -> None:
    """Check that saving a model to a file and loading it back preserves it.

    The model is written to a temporary file with ``model.save`` and restored
    with ``load_model``. The restored object must be an instance of the same
    class and must produce the same top-2 recommendations for the first two
    users of ``dataset``. When ``check_configs`` is true the configs returned by
    ``get_config`` must be equal as well.
    """

    def _reco_for_first_users(candidate: ModelBase) -> pd.DataFrame:
        first_users = dataset.user_id_map.external_ids[:2]
        return candidate.recommend(users=first_users, dataset=dataset, k=2, filter_viewed=False)

    with NamedTemporaryFile() as tmp_file:
        model.save(tmp_file.name)
        restored = load_model(tmp_file.name)

    assert isinstance(restored, model.__class__)

    reco_before = _reco_for_first_users(model)
    reco_after = _reco_for_first_users(restored)
    pd.testing.assert_frame_equal(reco_after, reco_before)

    if not check_configs:
        return

    config_before = model.get_config()
    config_after = restored.get_config()
    assert config_after == config_before
def get_successors(cls: tp.Type) -> tp.List[tp.Type]:
    """Collect every direct and indirect subclass of ``cls``.

    The result is in depth-first order: each direct subclass is immediately
    followed by its own successors. A class reachable through several
    inheritance paths appears once per path.
    """
    descendants: tp.List[tp.Type] = []
    for child in cls.__subclasses__():
        descendants += [child, *get_successors(child)]
    return descendants
@pytest.mark.parametrize(
    "model",
    [
        DSSMModel,
        SASRecModel,
        BERT4RecModel,
        ItemToItemAnnRecommender,
        UserToItemAnnRecommender,
        LightFMWrapperModel,
        VisualApp,
        ItemToItemVisualApp,
        MetricsApp,
        TorchRanker,
    ],
)
def test_raise_when_model_not_available(model: tp.Any) -> None:
    """Calling any of the listed compat classes with no arguments must raise ``ImportError``."""
    with pytest.raises(ImportError):
        model()
def assert_sparse_matrix_equal(actual: sparse.spmatrix, expected: sparse.spmatrix) -> None:
    """Assert that two sparse matrices have a compatible type and identical content.

    ``actual`` must be an instance of the concrete class of ``expected``, and
    the dense forms of both matrices must compare equal element by element.

    Raises
    ------
    AssertionError
        If the type check or the element-wise comparison fails.
    """
    expected_cls = type(expected)
    assert isinstance(actual, expected_cls)

    # Compare densified copies so the storage layout of the matrices does not matter.
    actual_dense = actual.toarray()
    expected_dense = expected.toarray()
    np.testing.assert_equal(actual_dense, expected_dense)
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /tests/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 MTS (Mobile Telesystems) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /tests/utils/test_indexing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022-2024 MTS (Mobile Telesystems) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
class TestGetElementIds:
    """Tests for ``get_element_ids``: positions of elements inside a reference array."""

    def test_when_elements_present(self) -> None:
        """Every element is found; its position in the reference array is returned."""
        elements = np.array([2, 5, 3, 8])
        test_elements = np.array([2, 3, 4, 8, 1, 5])
        expected_positions = np.array([0, 5, 1, 3])
        np.testing.assert_equal(get_element_ids(elements, test_elements), expected_positions)

    def test_raises_when_elements_not_present(self) -> None:
        """An element missing from the reference array leads to ``ValueError``."""
        elements = np.array([2, 5, 3, 8])
        test_elements = np.array([3, 4, 8, 1, 5])
        with pytest.raises(ValueError):
            get_element_ids(elements, test_elements)

    def test_when_elements_empty(self) -> None:
        """Empty input produces an empty result."""
        test_elements = np.array([2, 3, 4, 8, 1, 5])
        found = get_element_ids(np.array([]), test_elements)
        np.testing.assert_equal(found, np.array([]))

    def test_raises_when_test_elements_empty(self) -> None:
        """Nothing can be found in an empty reference array, so ``ValueError`` is expected."""
        elements = np.array([2, 5, 3, 8])
        with pytest.raises(ValueError):
            get_element_ids(elements, np.array([]))
2], [40, 20]), ([], []))) 53 | def test_raises_when_unknown_object( 54 | self, index_type: str, value_type: str, s_index: List[int], s_values: List[int] 55 | ) -> None: 56 | s = pd.Series(s_values, index=np.array(s_index, dtype=index_type), dtype=value_type) 57 | ids = np.array([1, 2, 4], dtype=index_type) 58 | with pytest.raises(KeyError): 59 | get_from_series_by_index(s, ids) 60 | 61 | def test_selects_known_objects(self, index_type: str, value_type: str) -> None: 62 | s = pd.Series([40, 20], index=np.array([4, 2], dtype=index_type), dtype=value_type) 63 | ids = np.array([2, 4, 1], dtype=index_type) 64 | actual = get_from_series_by_index(s, ids, strict=False) 65 | expected = np.array([20, 40], dtype=value_type) 66 | np.testing.assert_equal(actual, expected) 67 | 68 | def test_with_return_missing(self, index_type: str, value_type: str) -> None: 69 | s = pd.Series([40, 20], index=np.array([4, 2], dtype=index_type), dtype=value_type) 70 | ids = np.array([2, 4, 1], dtype=index_type) 71 | values, missing = get_from_series_by_index(s, ids, strict=False, return_missing=True) 72 | expected_values = np.array([20, 40], dtype=value_type) 73 | np.testing.assert_equal(values, expected_values) 74 | expected_missing = np.array([1], dtype=index_type) 75 | np.testing.assert_equal(missing, expected_missing) 76 | 77 | def test_raises_when_return_missing_and_strict(self, index_type: str, value_type: str) -> None: 78 | s = pd.Series([40, 20], index=np.array([4, 2], dtype=index_type), dtype=value_type) 79 | ids = np.array([2, 4, 1], dtype=index_type) 80 | with pytest.raises(ValueError): 81 | get_from_series_by_index(s, ids, return_missing=True) 82 | 83 | @pytest.mark.parametrize("s_index, s_values", (([4, 2], [40, 20]), ([], []))) 84 | def test_with_empty_ids(self, index_type: str, value_type: str, s_index: List[int], s_values: List[int]) -> None: 85 | s = pd.Series(s_values, index=np.array(s_index, dtype=index_type), dtype=value_type) 86 | ids = np.array([], dtype=index_type) 87 | 
class TestUnflattenDict:
    """Tests for ``unflatten_dict``: separator-joined keys become nested dictionaries."""

    def test_empty(self) -> None:
        """An empty mapping stays empty."""
        assert unflatten_dict({}) == {}

    def test_complex(self) -> None:
        """Keys with one and two separators are nested to the matching depth."""
        flattened = {"a.b": 1, "a.c": 2, "d": 3, "a.e.f": [10, 20]}
        expected = {
            "a": {
                "b": 1,
                "c": 2,
                "e": {"f": [10, 20]},
            },
            "d": 3,
        }
        assert unflatten_dict(flattened) == expected

    def test_simple(self) -> None:
        """Keys without a separator are returned unchanged."""
        flattened = {"a": 1, "b": 2}
        expected = {"a": 1, "b": 2}
        assert unflatten_dict(flattened) == expected

    def test_non_default_sep(self) -> None:
        """A custom separator passed through ``sep`` is used to split the keys."""
        flattened = {"a_b": 1, "a_c": 2, "d": 3}
        expected = {
            "a": {"b": 1, "c": 2},
            "d": 3,
        }
        assert unflatten_dict(flattened, sep="_") == expected
-------------------------------------------------------------------------------- 1 | # Copyright 2024 MTS (Mobile Telesystems) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | --------------------------------------------------------------------------------