├── tests ├── __init__.py ├── loss │ ├── __init__.py │ └── test_pairwise_additive.py ├── utils │ ├── __init__.py │ ├── test_tensor_operations.py │ ├── test_file.py │ └── test_downloader.py ├── datasets │ ├── __init__.py │ ├── svmrank │ │ ├── __init__.py │ │ ├── test_istella.py │ │ ├── test_example3.py │ │ ├── test_istella_x.py │ │ ├── test_istella_s.py │ │ ├── test_mslr10k.py │ │ ├── test_mslr30k.py │ │ └── test_svmrank.py │ └── test_list_sampler.py ├── evaluation │ ├── __init__.py │ ├── test_arp.py │ ├── test_trec.py │ └── test_dcg.py ├── click_simulation │ ├── __init__.py │ └── test_pbm.py └── test_integration.py ├── pytorchltr ├── click_simulation │ ├── __init__.py │ └── pbm.py ├── datasets │ ├── svmrank │ │ ├── __init__.py │ │ ├── parser │ │ │ ├── .gitignore │ │ │ ├── __init__.py │ │ │ └── svmrank_parser.pyx │ │ ├── istella.py │ │ ├── istella_x.py │ │ ├── example3.py │ │ ├── istella_s.py │ │ ├── mslr30k.py │ │ ├── mslr10k.py │ │ └── svmrank.py │ ├── __init__.py │ └── list_sampler.py ├── __init__.py ├── evaluation │ ├── __init__.py │ ├── arp.py │ ├── trec.py │ └── dcg.py ├── utils │ ├── __init__.py │ ├── progress.py │ ├── tensor_operations.py │ ├── file.py │ └── downloader.py └── loss │ ├── __init__.py │ ├── pairwise_additive.py │ └── pairwise_lambda.py ├── requirements.in ├── pyproject.toml ├── codecov.yml ├── MANIFEST.in ├── dev-requirements.in ├── docs ├── source │ ├── index.rst │ ├── evaluation.rst │ ├── loss.rst │ ├── datasets.rst │ ├── conf.py │ ├── references.bib │ └── getting-started.rst └── Makefile ├── .readthedocs.yml ├── requirements.txt ├── LICENSE.md ├── .github └── workflows │ ├── python-test.yml │ └── python-publish.yml ├── .gitignore ├── setup.py ├── dev-requirements.txt ├── README.md └── examples └── 01-basic-usage.py /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/loss/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/click_simulation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/datasets/svmrank/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pytorchltr/click_simulation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pytorchltr/datasets/svmrank/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.in: -------------------------------------------------------------------------------- 1 | torch 2 | numpy 3 | scipy 4 | scikit-learn -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel", "numpy"] 3 | 4 | 
-------------------------------------------------------------------------------- /pytorchltr/datasets/svmrank/parser/.gitignore: -------------------------------------------------------------------------------- 1 | # Generated by cython 2 | svmrank_parser.c -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | codecov: 2 | disable_default_path_fixes: true 3 | fixes: 4 | - "::pytorchltr/" -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include pyproject.toml 2 | include pytorchltr/datasets/svmrank/parser/svmrank_parser.h 3 | -------------------------------------------------------------------------------- /pytorchltr/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | 3 | .. include::../README.md 4 | :start-line: 1 5 | 6 | """ 7 | -------------------------------------------------------------------------------- /dev-requirements.in: -------------------------------------------------------------------------------- 1 | pytest 2 | sphinx 3 | pytest-cov 4 | flake8 5 | sphinx-rtd-theme 6 | sphinxcontrib-bibtex 7 | sphinx-autodoc-typehints 8 | cython 9 | -------------------------------------------------------------------------------- /pytorchltr/datasets/svmrank/parser/__init__.py: -------------------------------------------------------------------------------- 1 | from pytorchltr.datasets.svmrank.parser.svmrank_parser import \ 2 | parse_svmrank_file # noqa: F401 3 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | PyTorchLTR 2 | ========== 3 | 4 | .. 
toctree:: 5 | :maxdepth: 2 6 | :caption: Contents: 7 | 8 | getting-started 9 | datasets 10 | loss 11 | evaluation 12 | -------------------------------------------------------------------------------- /pytorchltr/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from pytorchltr.evaluation.arp import arp # noqa: F401 2 | from pytorchltr.evaluation.dcg import ndcg # noqa: F401 3 | from pytorchltr.evaluation.dcg import dcg # noqa: F401 4 | from pytorchltr.evaluation.trec import generate_pytrec_eval # noqa: F401 5 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | sphinx: 4 | builder: html 5 | configuration: docs/source/conf.py 6 | 7 | python: 8 | version: 3.7 9 | install: 10 | - requirements: requirements.txt 11 | - requirements: dev-requirements.txt 12 | - method: setuptools 13 | path: . 
14 | -------------------------------------------------------------------------------- /pytorchltr/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from pytorchltr.utils.tensor_operations import mask_padded_values # noqa: F401 2 | from pytorchltr.utils.tensor_operations import tiebreak_argsort # noqa: F401 3 | from pytorchltr.utils.tensor_operations import rank_by_score # noqa: F401 4 | from pytorchltr.utils.tensor_operations import batch_pairs # noqa: F401 5 | from pytorchltr.utils.tensor_operations import rank_by_plackettluce # noqa: F401,E501 6 | -------------------------------------------------------------------------------- /pytorchltr/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from pytorchltr.datasets.svmrank.example3 import Example3 # noqa: F401 2 | from pytorchltr.datasets.svmrank.istella import Istella # noqa: F401 3 | from pytorchltr.datasets.svmrank.istella_s import IstellaS # noqa: F401 4 | from pytorchltr.datasets.svmrank.istella_x import IstellaX # noqa: F401 5 | from pytorchltr.datasets.svmrank.mslr10k import MSLR10K # noqa: F401 6 | from pytorchltr.datasets.svmrank.mslr30k import MSLR30K # noqa: F401 7 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # 2 | # This file is autogenerated by pip-compile 3 | # To update, run: 4 | # 5 | # pip-compile requirements.in 6 | # 7 | future==0.18.2 # via torch 8 | joblib==0.14.1 # via scikit-learn 9 | numpy==1.18.3 # via -r requirements.in, scikit-learn, scipy, torch 10 | scikit-learn==0.22.2.post1 # via -r requirements.in 11 | scipy==1.4.1 # via -r requirements.in, scikit-learn 12 | torch==1.5.0 # via -r requirements.in 13 | -------------------------------------------------------------------------------- /pytorchltr/loss/__init__.py: 
-------------------------------------------------------------------------------- 1 | from pytorchltr.loss.pairwise_additive import PairwiseHingeLoss # noqa: F401 2 | from pytorchltr.loss.pairwise_additive import PairwiseDCGHingeLoss # noqa: F401,E501 3 | from pytorchltr.loss.pairwise_additive import PairwiseLogisticLoss # noqa: F401,E501 4 | from pytorchltr.loss.pairwise_lambda import LambdaARPLoss1 # noqa: F401,E501 5 | from pytorchltr.loss.pairwise_lambda import LambdaARPLoss2 # noqa: F401,E501 6 | from pytorchltr.loss.pairwise_lambda import LambdaNDCGLoss1 # noqa: F401,E501 7 | from pytorchltr.loss.pairwise_lambda import LambdaNDCGLoss2 # noqa: F401,E501 8 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Rolf Jagerman 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | -------------------------------------------------------------------------------- /docs/source/evaluation.rst: -------------------------------------------------------------------------------- 1 | .. _evaluation: 2 | 3 | Evaluation 4 | ========== 5 | 6 | PyTorchLTR provides several built-in evaluation metrics including ARP 7 | :cite:`evaluation-joachims2017unbiased` and DCG 8 | :cite:`evaluation-kalervo2002cumulated`. Furthermore, the library has support 9 | for creating `pytrec_eval <https://github.com/cvangysel/pytrec_eval>`_ 10 | :cite:`evaluation-gysel2018pytreceval` compatible output.
11 | 12 | Example 13 | ------- 14 | 15 | .. doctest:: 16 | 17 | >>> import torch 18 | >>> from pytorchltr.evaluation import ndcg 19 | >>> scores = torch.tensor([[1.0, 0.0, 1.5], [1.5, 0.2, 0.5]]) 20 | >>> relevance = torch.tensor([[0, 1, 0], [0, 1, 1]]) 21 | >>> n = torch.tensor([3, 3]) 22 | >>> ndcg(scores, relevance, n, k=10) 23 | tensor([0.5000, 0.6934]) 24 | 25 | Built-in metrics 26 | ---------------- 27 | 28 | .. autofunction:: pytorchltr.evaluation.arp 29 | 30 | .. autofunction:: pytorchltr.evaluation.dcg 31 | 32 | .. autofunction:: pytorchltr.evaluation.ndcg 33 | 34 | Integration with pytrec_eval 35 | ---------------------------- 36 | 37 | .. autofunction:: pytorchltr.evaluation.generate_pytrec_eval 38 | 39 | .. rubric:: References 40 | 41 | .. bibliography:: references.bib 42 | :cited: 43 | :style: authoryearstyle 44 | :keyprefix: evaluation- 45 | -------------------------------------------------------------------------------- /.github/workflows/python-test.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | pull_request: 8 | schedule: 9 | - cron: '0 0 * * 2' # every Tuesday at 00:00 10 | 11 | jobs: 12 | test: 13 | 14 | strategy: 15 | matrix: 16 | os: [macos-latest, windows-latest, ubuntu-latest] 17 | python-version: [ '3.5', '3.6', '3.7', '3.8' ] 18 | 19 | runs-on: ${{ matrix.os }} 20 | 21 | steps: 22 | - uses: actions/checkout@v2 23 | - name: Set up Python 24 | uses: actions/setup-python@v2 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | architecture: 'x64' 28 | - name: Setup build environment 29 | run: | 30 | python -m pip install --upgrade pip 31 | pip install setuptools wheel twine 32 | - name: Install windows-specific dependencies 33 | if: matrix.os == 'windows-latest' 34 | run: | 35 | pip install torch==1.5.0+cpu -f https://download.pytorch.org/whl/torch_stable.html 36 | - name: Install dependencies 37 | run: | 38 | pip install -r 
requirements.txt 39 | pip install -r dev-requirements.txt 40 | python setup.py build_ext --inplace 41 | - name: Test 42 | run: | 43 | flake8 tests/ pytorchltr/ 44 | pytest tests --cov-report=xml --cov=pytorchltr 45 | - name: Report coverage 46 | uses: codecov/codecov-action@v1 47 | with: 48 | token: ${{ secrets.CODECOV_TOKEN }} 49 | file: ./coverage.xml 50 | -------------------------------------------------------------------------------- /tests/evaluation/test_arp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pytorchltr.evaluation import arp 3 | from pytest import approx 4 | 5 | 6 | def test_arp(): 7 | torch.manual_seed(42) 8 | scores = torch.FloatTensor([ 9 | [10.0, 5.0, 2.0, 3.0, 4.0], 10 | [5.0, 6.0, 4.0, 2.0, 5.5] 11 | ]) 12 | ys = torch.LongTensor([ 13 | [0, 1, 1, 0, 1], 14 | [1, 1, 0, 0, 0] 15 | ]) 16 | n = torch.LongTensor([5, 4])  # 2nd query: 5th doc is padding 17 | 18 | out = arp(scores, ys, n) 19 | expected = torch.FloatTensor([3.333333333, 1.5])  # (2+3+5)/3, (1+2)/2 20 | 21 | assert out.numpy() == approx(expected.numpy()) 22 | 23 | 24 | def test_arp_all_relevant(): 25 | torch.manual_seed(42) 26 | scores = torch.FloatTensor([ 27 | [10.0, 5.0, 2.0, 3.0, 4.0], 28 | [5.0, 6.0, 4.0, 2.0, 5.5] 29 | ]) 30 | ys = torch.LongTensor([ 31 | [1, 1, 1, 1, 1], 32 | [1, 1, 1, 1, 1] 33 | ]) 34 | n = torch.LongTensor([5, 4]) 35 | 36 | out = arp(scores, ys, n) 37 | expected = torch.FloatTensor([3.0, 2.5])  # mean of ranks 1..5, 1..4 38 | 39 | assert out.numpy() == approx(expected.numpy()) 40 | 41 | 42 | def test_arp_no_relevant(): 43 | torch.manual_seed(42) 44 | scores = torch.FloatTensor([ 45 | [10.0, 5.0, 2.0, 3.0, 4.0], 46 | [5.0, 6.0, 4.0, 2.0, 5.5] 47 | ]) 48 | ys = torch.LongTensor([ 49 | [0, 0, 0, 0, 0], 50 | [0, 0, 0, 0, 0] 51 | ]) 52 | n = torch.LongTensor([5, 4]) 53 | 54 | out = arp(scores, ys, n) 55 | expected = torch.FloatTensor([0.0, 0.0])  # no relevant docs 56 | 57 | assert out.numpy() == approx(expected.numpy()) 58 |
-------------------------------------------------------------------------------- /tests/datasets/svmrank/test_istella.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | from pytorchltr.datasets.svmrank.istella import Istella 5 | from tests.datasets.svmrank.test_svmrank import mock_svmrank_dataset 6 | 7 | 8 | pkg = "pytorchltr.datasets.svmrank.istella" 9 | 10 | 11 | def test_wrong_split_raises_error(): 12 | with mock_svmrank_dataset(pkg) as (tmpdir, mock_super, mock_vali): 13 | with pytest.raises(ValueError): 14 | Istella(tmpdir, split="nonexisting") 15 | 16 | 17 | def test_call_validate_download(): 18 | with mock_svmrank_dataset(pkg) as (tmpdir, mock_super, mock_vali): 19 | Istella(tmpdir, split="train") 20 | assert mock_vali.call_count == 1 21 | args, kwargs = mock_vali.call_args 22 | assert kwargs["location"] == tmpdir 23 | assert kwargs["validate_checksums"] 24 | assert isinstance(kwargs["expected_files"], list) 25 | 26 | 27 | def test_call_super_train(): 28 | with mock_svmrank_dataset(pkg) as (tmpdir, mock_super, mock_vali): 29 | Istella(tmpdir, split="train") 30 | assert mock_super.call_count == 1 31 | args, kwargs = mock_super.call_args 32 | assert kwargs["file"] == os.path.join(tmpdir, "full", "train.txt") 33 | assert kwargs["normalize"] 34 | assert not kwargs["filter_queries"] 35 | 36 | 37 | def test_call_super_test(): 38 | with mock_svmrank_dataset(pkg) as (tmpdir, mock_super, mock_vali): 39 | Istella(tmpdir, split="test") 40 | assert mock_super.call_count == 1 41 | args, kwargs = mock_super.call_args 42 | assert kwargs["file"] == os.path.join(tmpdir, "full", "test.txt") 43 | assert kwargs["normalize"] 44 | assert kwargs["filter_queries"] 45 | -------------------------------------------------------------------------------- /tests/datasets/svmrank/test_example3.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | from
pytorchltr.datasets.svmrank.example3 import Example3 5 | from tests.datasets.svmrank.test_svmrank import mock_svmrank_dataset 6 | 7 | 8 | pkg = "pytorchltr.datasets.svmrank.example3" 9 | 10 | 11 | def test_wrong_split_raises_error(): 12 | with mock_svmrank_dataset(pkg) as (tmpdir, mock_super, mock_vali): 13 | with pytest.raises(ValueError): 14 | Example3(tmpdir, split="nonexisting") 15 | 16 | 17 | def test_call_validate_download(): 18 | with mock_svmrank_dataset(pkg) as (tmpdir, mock_super, mock_vali): 19 | Example3(tmpdir, split="train") 20 | assert mock_vali.call_count == 1 21 | args, kwargs = mock_vali.call_args 22 | assert kwargs["location"] == tmpdir 23 | assert kwargs["validate_checksums"] 24 | assert isinstance(kwargs["expected_files"], list) 25 | 26 | 27 | def test_call_super_train(): 28 | with mock_svmrank_dataset(pkg) as (tmpdir, mock_super, mock_vali): 29 | Example3(tmpdir, split="train") 30 | assert mock_super.call_count == 1 31 | args, kwargs = mock_super.call_args 32 | assert kwargs["file"] == os.path.join(tmpdir, "example3", "train.dat") 33 | assert kwargs["normalize"] 34 | assert not kwargs["filter_queries"] 35 | 36 | 37 | def test_call_super_test(): 38 | with mock_svmrank_dataset(pkg) as (tmpdir, mock_super, mock_vali): 39 | Example3(tmpdir, split="test") 40 | assert mock_super.call_count == 1 41 | args, kwargs = mock_super.call_args 42 | assert kwargs["file"] == os.path.join(tmpdir, "example3", "test.dat") 43 | assert kwargs["normalize"] 44 | assert kwargs["filter_queries"] 45 | -------------------------------------------------------------------------------- /pytorchltr/evaluation/arp.py: -------------------------------------------------------------------------------- 1 | """Average Relevant Position.""" 2 | import torch as _torch 3 | from pytorchltr.utils import rank_by_score as _rank_by_score 4 | from pytorchltr.utils import mask_padded_values as _mask_padded_values 5 | 6 | 7 | def arp(scores: _torch.FloatTensor, relevance: _torch.LongTensor, 8 | n: _torch.LongTensor) ->
_torch.FloatTensor: 9 | r"""Average Relevant Position (ARP) 10 | 11 | .. math:: 12 | 13 | \text{arp}(\mathbf{s}, \mathbf{y}) 14 | = \frac{1}{\sum_{i=1}^n y_i} \sum_{i=1}^n y_{\pi_i} \cdot i 15 | 16 | where :math:`\pi_i` is the index of the item at rank :math:`i` after 17 | sorting the scores. 18 | 19 | Args: 20 | scores: A tensor of size (batch_size, list_size, 1) or 21 | (batch_size, list_size), indicating the scores per document per 22 | query. 23 | relevance: A tensor of size (batch_size, list_size), indicating the 24 | relevance judgements per document per query. 25 | n: A tensor of size (batch_size) indicating the number of docs per 26 | query. 27 | 28 | Returns: 29 | A tensor of size (batch_size) indicating the ARP of each query. 30 | """ 31 | # Compute relevance per rank 32 | rel_sort = _torch.gather(relevance, 1, _rank_by_score(scores, n)).float() 33 | arange = 1.0 + _torch.repeat_interleave( 34 | _torch.arange( 35 | rel_sort.shape[1], device=rel_sort.device, 36 | dtype=_torch.float)[None, :], 37 | rel_sort.shape[0], dim=0) 38 | _mask_padded_values(rel_sort, n, mask_value=0.0, mutate=True) 39 | srp = _torch.sum(arange * rel_sort, dim=1) 40 | nrp = _torch.sum(rel_sort, dim=1) 41 | nrp[nrp == 0.0] = 1.0 42 | return srp / nrp 43 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *.cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # Jupyter Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # SageMath parsed files 79 | *.sage.py 80 | 81 | # Environments 82 | .env 83 | .venv 84 | env/ 85 | venv/ 86 | ENV/ 87 | 88 | # Spyder project settings 89 | .spyderproject 90 | .spyproject 91 | 92 | # Rope project settings 93 | .ropeproject 94 | 95 | # mkdocs documentation 96 | /site 97 | 98 | # mypy 99 | .mypy_cache/ 100 | 101 | # IDEA project settings 102 | .idea/ 103 | 104 | # VSCode project settings 105 | .vscode/ 106 | 107 | # OSX 108 | .DS_Store 109 | 110 | # Cover code analysis 111 | cover/ 112 | -------------------------------------------------------------------------------- /pytorchltr/datasets/svmrank/parser/svmrank_parser.pyx: -------------------------------------------------------------------------------- 1 | cimport numpy as np 2 | import numpy as np 3 | 4 | 5 | cdef extern from "errno.h": 6 | int errno 7 | 8 | 9 | cdef extern from "svmrank_parser.h": 10 | cdef struct shape: 11 | unsigned long cols 12 | unsigned long rows 13 | int PARSE_OK 14 | int PARSE_FILE_ERROR 15 | int PARSE_FORMAT_ERROR 16 | int PARSE_MEMORY_ERROR 17 | int c_parse_svmrank_file "parse_svmrank_file" (char* path, double** xs, shape* xs_shape, int** ys, long** qids) nogil 18 | void init_svmrank_parser() 19 | 20 | 21 | def parse_svmrank_file(path):
22 | global errno 23 | 24 | # Initialize pointers 25 | cdef int* ys 26 | cdef long* qids 27 | cdef double* xs 28 | cdef shape xs_shape 29 | 30 | # Initialize path to read file from 31 | py_path_bytes = path.encode('UTF-8') 32 | cdef char* c_path = py_path_bytes 33 | cdef int result = 0 34 | 35 | # Initialize output array views 36 | cdef int[:] ys_view 37 | cdef long[:] qids_view 38 | cdef double[:,:] xs_view 39 | 40 | # Init parser and parse file 41 | init_svmrank_parser() 42 | with nogil: 43 | result = c_parse_svmrank_file(c_path, &xs, &xs_shape, &ys, &qids) 44 | 45 | if result == PARSE_OK: 46 | ys_view = ys 47 | ys_np = np.asarray(ys_view) 48 | qids_view = qids 49 | qids_np = np.asarray(qids_view) 50 | xs_view = xs 51 | xs_np = np.asarray(xs_view) 52 | 53 | return xs_np, ys_np, qids_np 54 | elif result == PARSE_FILE_ERROR: 55 | raise OSError(errno, "could not open file %s" % path) 56 | elif result == PARSE_FORMAT_ERROR: 57 | raise ValueError("could not parse file %s, not in SVMrank format" % path) 58 | elif result == PARSE_MEMORY_ERROR: 59 | raise OSError(errno, "could not allocate memory") 60 | -------------------------------------------------------------------------------- /tests/datasets/svmrank/test_istella_x.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | from pytorchltr.datasets.svmrank.istella_x import IstellaX 5 | from tests.datasets.svmrank.test_svmrank import mock_svmrank_dataset 6 | 7 | 8 | pkg = "pytorchltr.datasets.svmrank.istella_x" 9 | 10 | 11 | def test_wrong_split_raises_error(): 12 | with mock_svmrank_dataset(pkg) as (tmpdir, mock_super, mock_vali): 13 | with pytest.raises(ValueError): 14 | IstellaX(tmpdir, split="nonexisting") 15 | 16 | 17 | def test_call_validate_download(): 18 | with mock_svmrank_dataset(pkg) as (tmpdir, mock_super, mock_vali): 19 | IstellaX(tmpdir, split="train") 20 | assert mock_vali.call_count == 1 21 | args, kwargs = mock_vali.call_args 22 | assert
kwargs["location"] == tmpdir 23 | assert kwargs["validate_checksums"] 24 | assert isinstance(kwargs["expected_files"], list) 25 | 26 | 27 | def test_call_super_train(): 28 | with mock_svmrank_dataset(pkg) as (tmpdir, mock_super, mock_vali): 29 | IstellaX(tmpdir, split="train") 30 | assert mock_super.call_count == 1 31 | args, kwargs = mock_super.call_args 32 | assert kwargs["file"] == os.path.join(tmpdir, "train.txt") 33 | assert kwargs["normalize"] 34 | assert not kwargs["filter_queries"] 35 | 36 | 37 | def test_call_super_vali(): 38 | with mock_svmrank_dataset(pkg) as (tmpdir, mock_super, mock_vali): 39 | IstellaX(tmpdir, split="vali") 40 | assert mock_super.call_count == 1 41 | args, kwargs = mock_super.call_args 42 | assert kwargs["file"] == os.path.join(tmpdir, "vali.txt") 43 | assert kwargs["normalize"] 44 | assert kwargs["filter_queries"] 45 | 46 | 47 | def test_call_super_test(): 48 | with mock_svmrank_dataset(pkg) as (tmpdir, mock_super, mock_vali): 49 | IstellaX(tmpdir, split="test") 50 | assert mock_super.call_count == 1 51 | args, kwargs = mock_super.call_args 52 | assert kwargs["file"] == os.path.join(tmpdir, "test.txt") 53 | assert kwargs["normalize"] 54 | assert kwargs["filter_queries"] 55 | -------------------------------------------------------------------------------- /tests/datasets/svmrank/test_istella_s.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | from pytorchltr.datasets.svmrank.istella_s import IstellaS 5 | from tests.datasets.svmrank.test_svmrank import mock_svmrank_dataset 6 | 7 | 8 | pkg = "pytorchltr.datasets.svmrank.istella_s" 9 | 10 | 11 | def test_wrong_split_raises_error(): 12 | with mock_svmrank_dataset(pkg) as (tmpdir, mock_super, mock_vali): 13 | with pytest.raises(ValueError): 14 | IstellaS(tmpdir, split="nonexisting") 15 | 16 | 17 | def test_call_validate_download(): 18 | with mock_svmrank_dataset(pkg) as (tmpdir, mock_super, mock_vali): 19 | IstellaS(tmpdir, split="train") 20
| assert mock_vali.call_count == 1 21 | args, kwargs = mock_vali.call_args 22 | assert kwargs["location"] == tmpdir 23 | assert kwargs["validate_checksums"] 24 | assert isinstance(kwargs["expected_files"], list) 25 | 26 | 27 | def test_call_super_train(): 28 | with mock_svmrank_dataset(pkg) as (tmpdir, mock_super, mock_vali): 29 | IstellaS(tmpdir, split="train") 30 | assert mock_super.call_count == 1 31 | args, kwargs = mock_super.call_args 32 | assert kwargs["file"] == os.path.join(tmpdir, "sample", "train.txt") 33 | assert kwargs["normalize"] 34 | assert not kwargs["filter_queries"] 35 | 36 | 37 | def test_call_super_vali(): 38 | with mock_svmrank_dataset(pkg) as (tmpdir, mock_super, mock_vali): 39 | IstellaS(tmpdir, split="vali") 40 | assert mock_super.call_count == 1 41 | args, kwargs = mock_super.call_args 42 | assert kwargs["file"] == os.path.join(tmpdir, "sample", "vali.txt") 43 | assert kwargs["normalize"] 44 | assert kwargs["filter_queries"] 45 | 46 | 47 | def test_call_super_test(): 48 | with mock_svmrank_dataset(pkg) as (tmpdir, mock_super, mock_vali): 49 | IstellaS(tmpdir, split="test") 50 | assert mock_super.call_count == 1 51 | args, kwargs = mock_super.call_args 52 | assert kwargs["file"] == os.path.join(tmpdir, "sample", "test.txt") 53 | assert kwargs["normalize"] 54 | assert kwargs["filter_queries"] 55 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | from setuptools import find_packages 4 | from setuptools import Extension 5 | 6 | import numpy 7 | 8 | try: 9 | from Cython.Build import cythonize 10 | CAN_CYTHONIZE = True 11 | except ImportError: 12 | CAN_CYTHONIZE = False 13 | 14 | 15 | def get_svmrank_parser_ext(): 16 | """ 17 | Gets the svmrank parser extension. 18 | 19 | This uses cython if possible when building from source, otherwise uses the 20 | packaged .c files to compile directly.
21 | """ 22 | path = "pytorchltr/datasets/svmrank/parser" 23 | pyx_path = os.path.join(path, "svmrank_parser.pyx") 24 | c_path = os.path.join(path, "svmrank_parser.c") 25 | if CAN_CYTHONIZE and os.path.exists(pyx_path): 26 | return cythonize([Extension( 27 | "pytorchltr.datasets.svmrank.parser.svmrank_parser", [pyx_path], 28 | define_macros=[("NPY_NO_DEPRECATED_API", "NPY_1_7_API_VERSION")])]) 29 | else: 30 | return [Extension("pytorchltr.datasets.svmrank.parser.svmrank_parser", 31 | [c_path])] 32 | 33 | 34 | with open("README.md", "rt") as f: 35 | long_description = f.read() 36 | 37 | 38 | setup( 39 | name="pytorchltr", 40 | version="0.2.1", 41 | description="Learning to Rank with PyTorch", 42 | long_description=long_description, 43 | long_description_content_type="text/markdown", 44 | url="https://github.com/rjagerman/pytorchltr", 45 | author="Rolf Jagerman", 46 | author_email="rjagerman@gmail.com", 47 | license="MIT", 48 | packages=find_packages(exclude=("tests", "tests.*",)), 49 | python_requires='>=3.5', 50 | ext_modules=get_svmrank_parser_ext(), 51 | include_dirs=[numpy.get_include()], 52 | install_requires=["numpy", 53 | "scikit-learn", 54 | "scipy", 55 | "torch"], 56 | tests_require=["pytest"], 57 | classifiers=[ 58 | "Programming Language :: Python :: 3", 59 | "License :: OSI Approved :: MIT License", 60 | "Operating System :: OS Independent", 61 | ] 62 | ) 63 | -------------------------------------------------------------------------------- /tests/test_integration.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tests.datasets.svmrank.test_svmrank import get_sample_dataset 3 | from pytorchltr.datasets.svmrank.svmrank import SVMRankDataset 4 | from pytorchltr.datasets.list_sampler import UniformSampler 5 | from pytorchltr.loss import PairwiseHingeLoss 6 | from pytorchltr.evaluation.arp import arp 7 | 8 | 9 | class Model(torch.nn.Module): 10 | def __init__(self, in_features): 11 | 
super().__init__() 12 | self.ff = torch.nn.Linear(in_features, 1) 13 | 14 | def forward(self, xs): 15 | return self.ff(xs) 16 | 17 | 18 | def test_basic_sgd_learning(): # SGD must lower the train ARP 19 | torch.manual_seed(42) 20 | 21 | dataset = get_sample_dataset() 22 | 23 | input_dim = dataset[0].features.shape[1] 24 | collate_fn = SVMRankDataset.collate_fn(UniformSampler(max_list_size=50)) 25 | model = Model(input_dim) 26 | optimizer = torch.optim.SGD(model.parameters(), lr=0.001) 27 | loss_fn = PairwiseHingeLoss() 28 | arp_per_epoch = torch.zeros(100) # mean train ARP of each epoch 29 | 30 | # Perform 100 epochs 31 | for epoch in range(100): 32 | 33 | # Load and iterate over dataset 34 | avg_arp = 0.0 35 | loader = torch.utils.data.DataLoader( 36 | dataset, batch_size=2, shuffle=True, collate_fn=collate_fn) 37 | for i, batch in enumerate(loader): 38 | # Get batch of samples 39 | xs, ys, n = batch.features, batch.relevance, batch.n 40 | 41 | # Compute loss 42 | loss = loss_fn(model(xs), ys, n) 43 | loss = loss.mean() 44 | 45 | # Perform SGD step 46 | optimizer.zero_grad() 47 | loss.backward() 48 | optimizer.step() 49 | 50 | # Evaluate ARP on the current training batch 51 | model.eval() 52 | arp_score = torch.mean(arp(model(xs), ys, n)) 53 | model.train() 54 | 55 | # Incremental running mean of the per-batch ARP in this epoch 56 | avg_arp = avg_arp + (float(arp_score) - avg_arp) / (i + 1) 57 | 58 | # Record the average ARP. 59 | arp_per_epoch[epoch] = avg_arp 60 | 61 | # Assert that the ARP was decreased by a significant amount from start to 62 | # finish.
63 | assert arp_per_epoch[-1] - arp_per_epoch[0] <= -0.40 64 | -------------------------------------------------------------------------------- /tests/datasets/svmrank/test_mslr10k.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | from pytorchltr.datasets.svmrank.mslr10k import MSLR10K 5 | from tests.datasets.svmrank.test_svmrank import mock_svmrank_dataset 6 | 7 | 8 | pkg = "pytorchltr.datasets.svmrank.mslr10k" 9 | 10 | 11 | def test_wrong_split_raises_error(): 12 | with mock_svmrank_dataset(pkg) as (tmpdir, mock_super, mock_vali): 13 | with pytest.raises(ValueError): 14 | MSLR10K(tmpdir, split="nonexisting") 15 | 16 | 17 | def test_wrong_fold_raises_error(): 18 | with mock_svmrank_dataset(pkg) as (tmpdir, mock_super, mock_vali): 19 | with pytest.raises(ValueError): 20 | MSLR10K(tmpdir, split="train", fold=99) 21 | 22 | 23 | def test_call_validate_download(): 24 | with mock_svmrank_dataset(pkg) as (tmpdir, mock_super, mock_vali): 25 | MSLR10K(tmpdir, split="train") 26 | mock_vali.assert_called_once() 27 | args, kwargs = mock_vali.call_args 28 | assert kwargs["location"] == tmpdir 29 | assert kwargs["validate_checksums"] 30 | assert isinstance(kwargs["expected_files"], list) 31 | 32 | 33 | def test_call_super_train(): 34 | with mock_svmrank_dataset(pkg) as (tmpdir, mock_super, mock_vali): 35 | MSLR10K(tmpdir, split="train", fold=1) 36 | mock_super.assert_called_once() 37 | args, kwargs = mock_super.call_args 38 | assert kwargs["file"] == os.path.join(tmpdir, "Fold1", "train.txt") 39 | assert kwargs["normalize"] 40 | assert not kwargs["filter_queries"] 41 | 42 | 43 | def test_call_super_vali(): 44 | with mock_svmrank_dataset(pkg) as (tmpdir, mock_super, mock_vali): 45 | MSLR10K(tmpdir, split="vali", fold=2) 46 | mock_super.assert_called_once() 47 | args, kwargs = mock_super.call_args 48 | assert kwargs["file"] == os.path.join(tmpdir, "Fold2", "vali.txt") 49 | assert kwargs["normalize"] 50 | assert
kwargs["filter_queries"] 51 | 52 | 53 | def test_call_super_test(): 54 | with mock_svmrank_dataset(pkg) as (tmpdir, mock_super, mock_vali): 55 | MSLR10K(tmpdir, split="test", fold=5) 56 | mock_super.assert_called_once() 57 | args, kwargs = mock_super.call_args 58 | assert kwargs["file"] == os.path.join(tmpdir, "Fold5", "test.txt") 59 | assert kwargs["normalize"] 60 | assert kwargs["filter_queries"] 61 | -------------------------------------------------------------------------------- /tests/datasets/svmrank/test_mslr30k.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | from pytorchltr.datasets.svmrank.mslr30k import MSLR30K 5 | from tests.datasets.svmrank.test_svmrank import mock_svmrank_dataset 6 | 7 | 8 | pkg = "pytorchltr.datasets.svmrank.mslr30k" 9 | 10 | 11 | def test_wrong_split_raises_error(): 12 | with mock_svmrank_dataset(pkg) as (tmpdir, mock_super, mock_vali): 13 | with pytest.raises(ValueError): 14 | MSLR30K(tmpdir, split="nonexisting") 15 | 16 | 17 | def test_wrong_fold_raises_error(): 18 | with mock_svmrank_dataset(pkg) as (tmpdir, mock_super, mock_vali): 19 | with pytest.raises(ValueError): 20 | MSLR30K(tmpdir, split="train", fold=99) 21 | 22 | 23 | def test_call_validate_download(): 24 | with mock_svmrank_dataset(pkg) as (tmpdir, mock_super, mock_vali): 25 | MSLR30K(tmpdir, split="train") 26 | mock_vali.assert_called_once() 27 | args, kwargs = mock_vali.call_args 28 | assert kwargs["location"] == tmpdir 29 | assert kwargs["validate_checksums"] 30 | assert isinstance(kwargs["expected_files"], list) 31 | 32 | 33 | def test_call_super_train(): 34 | with mock_svmrank_dataset(pkg) as (tmpdir, mock_super, mock_vali): 35 | MSLR30K(tmpdir, split="train", fold=1) 36 | mock_super.assert_called_once() 37 | args, kwargs = mock_super.call_args 38 | assert kwargs["file"] == os.path.join(tmpdir, "Fold1", "train.txt") 39 | assert kwargs["normalize"] 40 | assert not kwargs["filter_queries"] 41 | 42 | 43 | def
test_call_super_vali(): 44 | with mock_svmrank_dataset(pkg) as (tmpdir, mock_super, mock_vali): 45 | MSLR30K(tmpdir, split="vali", fold=2) 46 | mock_super.assert_called_once() 47 | args, kwargs = mock_super.call_args 48 | assert kwargs["file"] == os.path.join(tmpdir, "Fold2", "vali.txt") 49 | assert kwargs["normalize"] 50 | assert kwargs["filter_queries"] 51 | 52 | 53 | def test_call_super_test(): 54 | with mock_svmrank_dataset(pkg) as (tmpdir, mock_super, mock_vali): 55 | MSLR30K(tmpdir, split="test", fold=5) 56 | mock_super.assert_called_once() 57 | args, kwargs = mock_super.call_args 58 | assert kwargs["file"] == os.path.join(tmpdir, "Fold5", "test.txt") 59 | assert kwargs["normalize"] 60 | assert kwargs["filter_queries"] 61 | -------------------------------------------------------------------------------- /tests/evaluation/test_trec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pytorchltr.evaluation import generate_pytrec_eval 3 | 4 | 5 | def test_pytrec_eval_input(): 6 | scores = torch.FloatTensor([ 7 | [10.0, 5.0, 2.0, 3.0, 4.0], 8 | [5.0, 6.0, 4.0, 2.0, 5.5] 9 | ]) 10 | ys = torch.LongTensor([ 11 | [0, 1, 1, 0, 1], 12 | [3, 1, 0, 1, 0] 13 | ]) 14 | n = torch.LongTensor([5, 4]) 15 | 16 | qrels, run = generate_pytrec_eval(scores, ys, n) 17 | qrels_expected = { 18 | 'q0': {'d0': 0, 'd1': 1, 'd2': 1, 'd3': 0, 'd4': 1}, 19 | 'q1': {'d0': 3, 'd1': 1, 'd2': 0, 'd3': 1} 20 | } 21 | run_expected = { 22 | 'q0': {'d0': 10.0, 'd1': 5.0, 'd2': 2.0, 'd3': 3.0, 'd4': 4.0}, 23 | 'q1': {'d0': 5.0, 'd1': 6.0, 'd2': 4.0, 'd3': 2.0} 24 | } 25 | assert qrels_expected == qrels 26 | assert run_expected == run 27 | 28 | 29 | def test_pytrec_eval_input_noprefix(): 30 | scores = torch.FloatTensor([ 31 | [10.0, 5.0, 2.0, 3.0, 4.0], 32 | [5.0, 6.0, 4.0, 2.0, 5.5] 33 | ]) 34 | ys = torch.LongTensor([ 35 | [0, 1, 1, 0, 1], 36 | [3, 1, 0, 1, 0] 37 | ]) 38 | n = torch.LongTensor([5, 4]) 39 | 40 | qrels, run = generate_pytrec_eval(scores,
ys, n, q_prefix="", d_prefix="") 41 | qrels_expected = { 42 | '0': {'0': 0, '1': 1, '2': 1, '3': 0, '4': 1}, 43 | '1': {'0': 3, '1': 1, '2': 0, '3': 1} 44 | } 45 | run_expected = { 46 | '0': {'0': 10.0, '1': 5.0, '2': 2.0, '3': 3.0, '4': 4.0}, 47 | '1': {'0': 5.0, '1': 6.0, '2': 4.0, '3': 2.0} 48 | } 49 | assert qrels_expected == qrels 50 | assert run_expected == run 51 | 52 | 53 | def test_pytrec_eval_input_qids(): 54 | scores = torch.FloatTensor([ 55 | [10.0, 5.0, 2.0, 3.0, 4.0], 56 | [5.0, 6.0, 4.0, 2.0, 5.5] 57 | ]) 58 | ys = torch.LongTensor([ 59 | [0, 1, 1, 0, 1], 60 | [3, 1, 0, 1, 0] 61 | ]) 62 | n = torch.LongTensor([5, 4]) 63 | qid = torch.LongTensor([15623, 49998]) 64 | 65 | qrels, run = generate_pytrec_eval(scores, ys, n, qid, q_prefix="") 66 | qrels_expected = { 67 | '15623': {'d0': 0, 'd1': 1, 'd2': 1, 'd3': 0, 'd4': 1}, 68 | '49998': {'d0': 3, 'd1': 1, 'd2': 0, 'd3': 1} 69 | } 70 | run_expected = { 71 | '15623': {'d0': 10.0, 'd1': 5.0, 'd2': 2.0, 'd3': 3.0, 'd4': 4.0}, 72 | '49998': {'d0': 5.0, 'd1': 6.0, 'd2': 4.0, 'd3': 2.0} 73 | } 74 | assert qrels_expected == qrels 75 | assert run_expected == run 76 | -------------------------------------------------------------------------------- /pytorchltr/utils/progress.py: -------------------------------------------------------------------------------- 1 | import time 2 | import logging 3 | from typing import Callable 4 | from typing import Optional 5 | 6 | 7 | _PROGRESS_FN_TYPE = Callable[[int, Optional[int], bool], None] 8 | 9 | 10 | def _default_progress_str(progress: int, total: Optional[int], final: bool): 11 | """ 12 | Formats the default progress string. 13 | 14 | Args: 15 | progress: The progress so far as an integer. 16 | total: The total progress, or None if it is unknown. 17 | final: Whether this is the final call. 18 | 19 | Returns: 20 | A formatted string representing the progress so far.
21 | """ 22 | prefix = "completed " if final else "" 23 | if total is not None: 24 | percent = (100.0 * progress) / total if total > 0 else 100.0 25 | return "%s%d / %d (%3d%%)" % (prefix, progress, total, int(percent)) 26 | else: 27 | return "%s%d / ?" % (prefix, progress) # total is unknown 28 | 29 | 30 | class IntervalProgress: 31 | """ 32 | A progress hook function that reports to output at a specified interval. 33 | """ 34 | def __init__(self, interval: float = 1.0, 35 | progress_str: _PROGRESS_FN_TYPE = _default_progress_str): 36 | self.interval = interval 37 | self.progress_str = progress_str 38 | self.last_update = time.time() - interval # report on first call 39 | 40 | def __call__(self, progress: int, total: Optional[int], final: bool): 41 | if final or time.time() - self.last_update >= self.interval: 42 | self.progress(progress, total, final) 43 | self.last_update = time.time() 44 | 45 | def progress(self, progress: int, total: Optional[int], final: bool): 46 | """Processes the progress so far. Called only once per interval. 47 | 48 | Args: 49 | progress: The progress so far. 50 | total: The total to reach, or None if it is unknown. 51 | final: Whether this is the final progress call. 52 | """ 53 | raise NotImplementedError 54 | 55 | 56 | class LoggingProgress(IntervalProgress): 57 | """ 58 | An interval progress hook that reports to logging.info. 59 | """ 60 | def progress(self, progress, total, final): 61 | logging.info(self.progress_str(progress, total, final)) 62 | 63 | 64 | class TerminalProgress(IntervalProgress): 65 | """ 66 | An interval progress hook that writes to the terminal via print.
67 | """ 68 | def progress(self, progress, total, final): 69 | print("\033[K" + self.progress_str(progress, total, final), 70 | end="\n" if final else "\r") # overwrite until final 71 | -------------------------------------------------------------------------------- /pytorchltr/datasets/list_sampler.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import torch as _torch 3 | 4 | 5 | class ListSampler: 6 | def __init__(self, max_list_size: Optional[int] = None): 7 | self._max_list_size = max_list_size 8 | 9 | def max_list_size(self, relevance): 10 | size = relevance.shape[0] 11 | if self._max_list_size is not None: 12 | size = min(self._max_list_size, size) 13 | return size 14 | 15 | def __call__(self, relevance: _torch.LongTensor) -> _torch.LongTensor: 16 | return _torch.arange(self.max_list_size(relevance), dtype=_torch.long) 17 | 18 | 19 | class UniformSampler(ListSampler): 20 | def __init__(self, max_list_size: Optional[int] = None, 21 | generator: Optional[_torch.Generator] = None): 22 | super().__init__(max_list_size) 23 | self.rng_kw = {"generator": generator} if generator is not None else {} 24 | 25 | def __call__(self, relevance: _torch.LongTensor) -> _torch.LongTensor: 26 | perm = _torch.randperm(relevance.shape[0], **self.rng_kw) 27 | return perm[0:self.max_list_size(relevance)] 28 | 29 | 30 | class BalancedRelevanceSampler(UniformSampler): 31 | def __init__(self, max_list_size: Optional[int] = None, 32 | generator: Optional[_torch.Generator] = None): 33 | super().__init__(max_list_size, generator) 34 | 35 | def __call__(self, relevance: _torch.LongTensor) -> _torch.LongTensor: 36 | # Get the unique relevance grades and randomly permute them 37 | unique_rel = _torch.unique(relevance) 38 | unique_rel = unique_rel[_torch.randperm( 39 | unique_rel.shape[0], **self.rng_kw)] 40 | 41 | # Randomly shuffle the relevance grades 42 | perm = _torch.randperm(relevance.shape[0], **self.rng_kw) 43 | rel_shuffled =
relevance[perm] 44 | 45 | # Get maximum size and create a stack for output 46 | max_size = self.max_list_size(relevance) 47 | stacked_rel = -_torch.ones((unique_rel.shape[0], relevance.shape[0]), 48 | dtype=_torch.long, device=relevance.device) 49 | 50 | # For each unique relevance grade fill a row in the stacked_rel matrix 51 | # with the indices of that relevance grade 52 | for j, rel in enumerate(unique_rel): 53 | idxs = _torch.where(rel_shuffled == rel)[0] 54 | size = min(max_size, idxs.shape[0]) 55 | stacked_rel[j, 0:size] = idxs[0:size] 56 | 57 | # Create output by removing unused indices and undoing the permutation 58 | out = stacked_rel.T.reshape(relevance.shape[0] * unique_rel.shape[0]) 59 | out = out[out != -1] 60 | out = perm[out.to(perm.device)] # perm may not be on relevance.device 61 | return out[0:max_size] 62 | -------------------------------------------------------------------------------- /dev-requirements.txt: -------------------------------------------------------------------------------- 1 | # 2 | # This file is autogenerated by pip-compile 3 | # To update, run: 4 | # 5 | # pip-compile dev-requirements.in 6 | # 7 | alabaster==0.7.12 8 | # via sphinx 9 | attrs==19.3.0 10 | # via pytest 11 | babel==2.8.0 12 | # via sphinx 13 | certifi==2020.4.5.1 14 | # via requests 15 | chardet==3.0.4 16 | # via requests 17 | coverage==5.1 18 | # via pytest-cov 19 | cython==0.29.21 20 | # via -r dev-requirements.in 21 | docutils==0.16 22 | # via 23 | # pybtex-docutils 24 | # sphinx 25 | flake8==3.8.2 26 | # via -r dev-requirements.in 27 | idna==2.9 28 | # via requests 29 | imagesize==1.2.0 30 | # via sphinx 31 | jinja2==2.11.3 32 | # via sphinx 33 | latexcodec==2.0.1 34 | # via pybtex 35 | markupsafe==1.1.1 36 | # via jinja2 37 | mccabe==0.6.1 38 | # via flake8 39 | more-itertools==8.2.0 40 | # via pytest 41 | oset==0.1.3 42 | # via sphinxcontrib-bibtex 43 | packaging==20.3 44 | # via 45 | # pytest 46 | # sphinx 47 | pluggy==0.13.1 48 | # via pytest 49 | py==1.8.1 50 | # via pytest 51 | pybtex-docutils==0.2.2 52 |
# via sphinxcontrib-bibtex 53 | pybtex==0.22.2 54 | # via 55 | # pybtex-docutils 56 | # sphinxcontrib-bibtex 57 | pycodestyle==2.6.0 58 | # via flake8 59 | pyflakes==2.2.0 60 | # via flake8 61 | pygments==2.6.1 62 | # via sphinx 63 | pyparsing==2.4.7 64 | # via packaging 65 | pytest-cov==2.9.0 66 | # via -r dev-requirements.in 67 | pytest==5.4.1 68 | # via 69 | # -r dev-requirements.in 70 | # pytest-cov 71 | pytz==2020.1 72 | # via babel 73 | pyyaml==5.3.1 74 | # via pybtex 75 | requests==2.23.0 76 | # via sphinx 77 | six==1.14.0 78 | # via 79 | # latexcodec 80 | # packaging 81 | # pybtex 82 | # pybtex-docutils 83 | snowballstemmer==2.0.0 84 | # via sphinx 85 | sphinx-autodoc-typehints==1.11.0 86 | # via -r dev-requirements.in 87 | sphinx-rtd-theme==0.5.0 88 | # via -r dev-requirements.in 89 | sphinx==3.0.3 90 | # via 91 | # -r dev-requirements.in 92 | # sphinx-autodoc-typehints 93 | # sphinx-rtd-theme 94 | # sphinxcontrib-bibtex 95 | sphinxcontrib-applehelp==1.0.2 96 | # via sphinx 97 | sphinxcontrib-bibtex==1.0.0 98 | # via -r dev-requirements.in 99 | sphinxcontrib-devhelp==1.0.2 100 | # via sphinx 101 | sphinxcontrib-htmlhelp==1.0.3 102 | # via sphinx 103 | sphinxcontrib-jsmath==1.0.1 104 | # via sphinx 105 | sphinxcontrib-qthelp==1.0.3 106 | # via sphinx 107 | sphinxcontrib-serializinghtml==1.1.4 108 | # via sphinx 109 | urllib3==1.25.9 110 | # via requests 111 | wcwidth==0.1.9 112 | # via pytest 113 | 114 | # The following packages are considered to be unsafe in a requirements file: 115 | # setuptools 116 | -------------------------------------------------------------------------------- /docs/source/loss.rst: -------------------------------------------------------------------------------- 1 | .. _loss: 2 | 3 | Loss functions 4 | ============== 5 | 6 | PyTorchLTR provides several common loss functions for LTR. Each loss function 7 | operates on a batch of query-document lists with corresponding relevance 8 | labels.
9 | 10 | The input to an LTR loss function comprises three tensors: 11 | 12 | - scores: A tensor of size :math:`(N, \texttt{list_size})`: the item scores 13 | - relevance: A tensor of size :math:`(N, \texttt{list_size})`: the relevance labels 14 | - n: A tensor of size :math:`(N)`: the number of docs per learning instance. 15 | 16 | It produces the following output: 17 | 18 | - output: A tensor of size :math:`(N)`: the loss per learning instance in the batch. 19 | 20 | Example 21 | ------- 22 | 23 | The following is a usage example for the pairwise hinge loss but the same usage 24 | pattern holds for all the other losses. 25 | 26 | .. doctest:: 27 | 28 | >>> import torch 29 | >>> from pytorchltr.loss import PairwiseHingeLoss 30 | >>> scores = torch.tensor([[0.5, 2.0, 1.0], [0.9, -1.2, 0.0]]) 31 | >>> relevance = torch.tensor([[2, 0, 1], [0, 1, 0]]) 32 | >>> n = torch.tensor([3, 2]) 33 | >>> loss_fn = PairwiseHingeLoss() 34 | >>> loss_fn(scores, relevance, n) 35 | tensor([6.0000, 3.1000]) 36 | >>> loss_fn(scores, relevance, n).mean() 37 | tensor(4.5500) 38 | 39 | Additive ranking losses 40 | ----------------------- 41 | 42 | Additive ranking losses optimize linearly decomposable ranking metrics 43 | :cite:`loss-joachims2002optimizing,loss-agarwal2019general`. These loss 44 | functions optimize an upper bound on the rank of relevant documents via either 45 | a hinge or logistic formulation. 46 | 47 | .. autoclass:: pytorchltr.loss.PairwiseHingeLoss 48 | :members: 49 | 50 | .. automethod:: forward 51 | 52 | .. autoclass:: pytorchltr.loss.PairwiseDCGHingeLoss 53 | :members: 54 | 55 | .. automethod:: forward 56 | 57 | .. autoclass:: pytorchltr.loss.PairwiseLogisticLoss 58 | :members: 59 | 60 | .. automethod:: __init__ 61 | .. automethod:: forward 62 | 63 | 64 | LambdaLoss 65 | ---------- 66 | 67 | LambdaLoss :cite:`loss-wang2018lambdaloss` is a probabilistic framework for 68 | ranking metric optimization.
We provide implementations for ARPLoss1, ARPLoss2, 69 | NDCGLoss1 and NDCGLoss2. 70 | 71 | .. autoclass:: pytorchltr.loss.LambdaARPLoss1 72 | :members: 73 | 74 | .. automethod:: __init__ 75 | .. automethod:: forward 76 | 77 | .. autoclass:: pytorchltr.loss.LambdaARPLoss2 78 | :members: 79 | 80 | .. automethod:: __init__ 81 | .. automethod:: forward 82 | 83 | .. autoclass:: pytorchltr.loss.LambdaNDCGLoss1 84 | :members: 85 | 86 | .. automethod:: __init__ 87 | .. automethod:: forward 88 | 89 | .. autoclass:: pytorchltr.loss.LambdaNDCGLoss2 90 | :members: 91 | 92 | .. automethod:: __init__ 93 | .. automethod:: forward 94 | 95 | .. rubric:: References 96 | 97 | .. bibliography:: references.bib 98 | :cited: 99 | :style: authoryearstyle 100 | :keyprefix: loss- 101 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch Learning to Rank (LTR) 2 | 3 | [![Build](https://img.shields.io/github/workflow/status/rjagerman/pytorchltr/Test/master)](https://github.com/rjagerman/pytorchltr/actions?query=branch%3Amaster+workflow%3ATest) 4 | [![Documentation](https://img.shields.io/readthedocs/pytorchltr)](https://pytorchltr.readthedocs.io/) 5 | [![Coverage](https://img.shields.io/codecov/c/github/rjagerman/pytorchltr/master)](https://codecov.io/gh/rjagerman/pytorchltr) 6 | [![CodeFactor](https://img.shields.io/codefactor/grade/github/rjagerman/pytorchltr/master)](https://www.codefactor.io/repository/github/rjagerman/pytorchltr) 7 | [![License](https://img.shields.io/github/license/rjagerman/pytorchltr)](https://github.com/rjagerman/pytorchltr/blob/master/LICENSE.md) 8 | 9 | This is a library for Learning to Rank (LTR) with PyTorch. 10 | The goal of this library is to support the infrastructure necessary for performing LTR experiments in PyTorch. 
11 | 12 | ## Installation 13 | 14 | In your virtualenv simply run: 15 | 16 | pip install pytorchltr 17 | 18 | Note that this library requires Python 3.5 or higher. 19 | 20 | ## Documentation 21 | 22 | Documentation is available [here](https://pytorchltr.readthedocs.io/). 23 | 24 | ## Example 25 | 26 | See `examples/01-basic-usage.py` for a more complete example that includes evaluation. 27 | 28 | ```python 29 | import torch 30 | from pytorchltr.datasets import Example3 31 | from pytorchltr.loss import PairwiseHingeLoss 32 | 33 | # Load dataset 34 | train = Example3(split="train") 35 | collate_fn = train.collate_fn() 36 | 37 | # Setup model, optimizer and loss 38 | model = torch.nn.Linear(train[0].features.shape[1], 1) 39 | optimizer = torch.optim.SGD(model.parameters(), lr=0.1) 40 | loss = PairwiseHingeLoss() 41 | 42 | # Train for 3 epochs 43 | for epoch in range(3): 44 | loader = torch.utils.data.DataLoader(train, batch_size=2, collate_fn=collate_fn) 45 | for batch in loader: 46 | xs, ys, n = batch.features, batch.relevance, batch.n 47 | l = loss(model(xs), ys, n).mean() 48 | optimizer.zero_grad() 49 | l.backward() 50 | optimizer.step() 51 | ``` 52 | 53 | ## Dataset Disclaimer 54 | This library provides utilities to automatically download and prepare several public LTR datasets. 55 | We cannot vouch for the quality, correctness or usefulness of these datasets. 56 | We do not host or distribute these datasets and it is ultimately **your responsibility** to determine whether you have permission to use each dataset under its respective license.
57 | 58 | ## Citing 59 | If you find this software useful for your research, we kindly ask you to cite the following publication: 60 | 61 | @inproceedings{jagerman2020accelerated, 62 | author = {Jagerman, Rolf and de Rijke, Maarten}, 63 | title = {Accelerated Convergence for Counterfactual Learning to Rank}, 64 | year = {2020}, 65 | publisher = {Association for Computing Machinery}, 66 | address = {New York, NY, USA}, 67 | booktitle = {Proceedings of the 43rd International ACM SIGIR Conference on Research and Development in Information Retrieval}, 68 | doi = {10.1145/3397271.3401069}, 69 | series = {SIGIR’20} 70 | } 71 | -------------------------------------------------------------------------------- /pytorchltr/evaluation/trec.py: -------------------------------------------------------------------------------- 1 | """Generate pytrec_eval runs from model and labels.""" 2 | from typing import Optional 3 | from typing import Tuple 4 | from typing import Dict 5 | import torch as _torch 6 | 7 | 8 | _PYTREC_RETURN_TYPE = Tuple[ 9 | Dict[str, Dict[str, int]], 10 | Dict[str, Dict[str, float]]] 11 | 12 | 13 | def generate_pytrec_eval(scores: _torch.FloatTensor, 14 | relevance: _torch.LongTensor, 15 | n: _torch.LongTensor, 16 | qids: Optional[_torch.LongTensor] = None, 17 | qid_offset: int = 0, 18 | q_prefix: str = "q", 19 | d_prefix: str = "d") -> _PYTREC_RETURN_TYPE: 20 | """Generates `pytrec_eval <https://github.com/cvangysel/pytrec_eval>`_ 21 | qrels and runs from the given batch. 22 | 23 | Example usage: 24 | 25 | ..
doctest:: 26 | 27 | >>> import json 28 | >>> import torch 29 | >>> import pytrec_eval 30 | >>> from pytorchltr.evaluation.trec import generate_pytrec_eval 31 | >>> scores = torch.tensor([[1.0, 0.0, 1.5], [1.5, 0.2, 0.5]]) 32 | >>> relevance = torch.tensor([[0, 1, 0], [0, 1, 1]]) 33 | >>> n = torch.tensor([3, 3]) 34 | >>> qrel, run = generate_pytrec_eval(scores, relevance, n) 35 | >>> evaluator = pytrec_eval.RelevanceEvaluator(qrel, {'map', 'ndcg'}) 36 | >>> print(json.dumps(evaluator.evaluate(run), indent=1)) 37 | { 38 | "q0": { 39 | "map": 0.3333333333333333, 40 | "ndcg": 0.5 41 | }, 42 | "q1": { 43 | "map": 0.5833333333333333, 44 | "ndcg": 0.6934264036172708 45 | } 46 | } 47 | 48 | Args: 49 | scores: A FloatTensor of size (batch_size, list_size) indicating the 50 | scores of each document. 51 | relevance: A LongTensor of size (batch_size, list_size) indicating the 52 | relevance of each document. 53 | n: A LongTensor of size (batch_size) indicating the number of docs per 54 | query. 55 | qids: A LongTensor of size (batch_size) indicating the qid of each 56 | query. 57 | qid_offset: An offset to increment all qids in this batch with. Only 58 | used if `qids` is None. 59 | q_prefix: A string prefix to add for query identifiers. 60 | d_prefix: A string prefix to add for doc identifiers. 61 | 62 | Returns: 63 | A tuple containing a qrel dict and a run dict. 64 | """ 65 | qrel = {} 66 | run = {} 67 | for i in range(scores.shape[0]): 68 | 69 | # Use the given qid, or the batch index plus qid_offset. 70 | if qids is not None: 71 | qid = "{:s}{:d}".format(q_prefix, int(qids[i])) 72 | else: 73 | qid = "{:s}{:d}".format(q_prefix, i + qid_offset) 74 | qrel[qid] = {} 75 | run[qid] = {} 76 | 77 | # Export relevance and score of the first n[i] docs only.
78 | for d in range(n[i]): 79 | docid = "{:s}{:d}".format(d_prefix, d) 80 | qrel[qid][docid] = int(relevance[i, d]) 81 | run[qid][docid] = float(scores[i, d]) 82 | 83 | return qrel, run 84 | -------------------------------------------------------------------------------- /examples/01-basic-usage.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # This script trains and evaluates a linear ranker on the Example3 toy dataset. 4 | # 5 | # Usage with expected output: 6 | # 7 | # $ PYTHONPATH=. python examples/01-basic-usage.py 8 | # [INFO, file] checking dataset files in './datasets/example3' 9 | # [WARNING, file] dataset file(s) in './datasets/example3' are missing or corrupt 10 | # [INFO, downloader] starting download from 'http://download.joachims.org/svm_light/examples/example3.tar.gz' to './datasets/example3/example3.tar.gz' 11 | # finished downloading [307B] 12 | # [INFO, file] extracting tar file at './datasets/example3/example3.tar.gz' to './datasets/example3' 13 | # [INFO, file] successfully checked all dataset files 14 | # [INFO, svmrank] loading svmrank dataset from ./datasets/example3/example3/train.dat 15 | # [INFO, file] checking dataset files in './datasets/example3' 16 | # [INFO, file] successfully checked all dataset files 17 | # [INFO, svmrank] loading svmrank dataset from ./datasets/example3/example3/test.dat 18 | # [INFO, 01-basic-usage] Test nDCG at start: 0.8617 19 | # [INFO, 01-basic-usage] Test nDCG after epoch 1: 0.8617 20 | # [INFO, 01-basic-usage] Test nDCG after epoch 2: 1.0000 21 | # [INFO, 01-basic-usage] Test nDCG after epoch 3: 1.0000 22 | 23 | import torch 24 | from pytorchltr.datasets import Example3 25 | from pytorchltr.evaluation import ndcg 26 | from pytorchltr.loss import PairwiseHingeLoss 27 | import logging 28 | 29 | 30 | # Set up logging 31 | logging.basicConfig( 32 | format="[%(levelname)s, %(module)s] %(message)s", 33 | level=logging.INFO) 34 | 35 | # Seed
randomness 36 | torch.manual_seed(42) 37 | 38 | # Load the example3 toy dataset 39 | train = Example3(split="train") 40 | test = Example3(split="test") 41 | collate_fn = train.collate_fn() 42 | 43 | # Create model, optimizer and loss to optimize 44 | model = torch.nn.Linear(train[0].features.shape[1], 1) 45 | optimizer = torch.optim.SGD(model.parameters(), lr=0.1) 46 | loss_fn = PairwiseHingeLoss() 47 | 48 | 49 | @torch.no_grad() # Evaluate the model on the test split (no gradients needed) 50 | def evaluate(): 51 | model.eval() 52 | loader = torch.utils.data.DataLoader(test, batch_size=2, shuffle=True, 53 | collate_fn=test.collate_fn()) 54 | ndcg_score = 0.0 55 | for batch in loader: 56 | xs, ys, n = batch.features, batch.relevance, batch.n 57 | ndcg_score += float(torch.sum(ndcg(model(xs), ys, n, k=10))) 58 | 59 | ndcg_score /= len(test) 60 | model.train() 61 | return ndcg_score 62 | 63 | 64 | logging.info("Test nDCG at start: %.4f", evaluate()) 65 | 66 | # Train model for 3 epochs 67 | for epoch in range(3): 68 | loader = torch.utils.data.DataLoader(train, batch_size=2, shuffle=True, 69 | collate_fn=collate_fn) 70 | for batch in loader: 71 | xs, ys, n = batch.features, batch.relevance, batch.n 72 | loss = loss_fn(model(xs), ys, n).mean() 73 | optimizer.zero_grad() 74 | loss.backward() 75 | optimizer.step() 76 | 77 | logging.info("Test nDCG after epoch %d: %.4f", epoch + 1, evaluate()) 78 | -------------------------------------------------------------------------------- /pytorchltr/datasets/svmrank/istella.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional 3 | 4 | from pytorchltr.utils.downloader import DefaultDownloadProgress 5 | from pytorchltr.utils.downloader import Downloader 6 | from pytorchltr.utils.file import validate_and_download 7 | from pytorchltr.utils.file import extract_tar 8 | from pytorchltr.utils.file import dataset_dir 9 | from pytorchltr.datasets.svmrank.svmrank
import SVMRankDataset 10 | 11 | 12 | class Istella(SVMRankDataset): 13 | """ 14 | Utility class for downloading and using the istella dataset: 15 | http://quickrank.isti.cnr.it/istella-dataset/. 16 | """ 17 | 18 | downloader = Downloader( 19 | url="http://library.istella.it/dataset/istella-letor.tar.gz", 20 | target="istella-letor.tar.gz", 21 | sha256_checksum="d45899d9a6a0e48afb250aac7ee3dc50e73e263687f15761d754515cd8284e0b", # noqa: E501 22 | progress_fn=DefaultDownloadProgress(), 23 | postprocess_fn=extract_tar) 24 | 25 | expected_files = [ 26 | {"path": "full/train.txt", "sha256": "8b08ca4c36281e408bfa026303461a6162398a1eec22fb79db0709c76bf6d189"}, # noqa: E501 27 | {"path": "full/test.txt", "sha256": "4b34523af8e030718f0216bcae0a35caee1e34e5198ff3e76c08a75db61aac65"}, # noqa: E501 28 | ] 29 | 30 | splits = { 31 | "train": "train.txt", 32 | "test": "test.txt" 33 | } 34 | 35 | def __init__(self, location: str = dataset_dir("istella"), 36 | split: str = "train", normalize: bool = True, 37 | filter_queries: Optional[bool] = None, download: bool = True, 38 | validate_checksums: bool = True): 39 | """ 40 | Args: 41 | location: Directory where the dataset is located. 42 | split: The data split to load ("train" or "test") 43 | normalize: Whether to perform query-level feature 44 | normalization. 45 | filter_queries: Whether to filter out queries that 46 | have no relevant items. If not given this will filter queries 47 | for the test set but not the train set. 48 | download: Whether to download the dataset if it does not 49 | exist. 50 | validate_checksums: Whether to validate the dataset files 51 | via sha256. 52 | """ 53 | # Check if specified split exists. 54 | if split not in Istella.splits.keys(): 55 | raise ValueError("unrecognized data split '%s'" % str(split)) 56 | 57 | # Validate dataset exists and is correct, or download it. 
58 | validate_and_download( 59 | location=location, 60 | expected_files=Istella.expected_files, 61 | downloader=Istella.downloader if download else None, 62 | validate_checksums=validate_checksums) 63 | 64 | # Only filter queries on non-train splits. 65 | if filter_queries is None: 66 | filter_queries = False if split == "train" else True 67 | 68 | # Initialize the dataset. 69 | datafile = os.path.join(location, "full", Istella.splits[split]) 70 | super().__init__(file=datafile, sparse=False, normalize=normalize, 71 | filter_queries=filter_queries, zero_based="auto") 72 | -------------------------------------------------------------------------------- /pytorchltr/datasets/svmrank/istella_x.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional 3 | 4 | from pytorchltr.utils.downloader import DefaultDownloadProgress 5 | from pytorchltr.utils.downloader import Downloader 6 | from pytorchltr.utils.file import validate_and_download 7 | from pytorchltr.utils.file import extract_tar 8 | from pytorchltr.utils.file import dataset_dir 9 | from pytorchltr.datasets.svmrank.svmrank import SVMRankDataset 10 | 11 | 12 | class IstellaX(SVMRankDataset): 13 | """ 14 | Utility class for downloading and using the istella-X dataset: 15 | http://quickrank.isti.cnr.it/istella-dataset/. 
16 | """ 17 | 18 | downloader = Downloader( 19 | url="http://quickrank.isti.cnr.it/istella-datasets-mirror/istella-X.tar.gz", # noqa: E501 20 | target="istella-X.tar.gz", 21 | sha256_checksum="e67be60d1c6a68983a669e4de9df2c395914717b2017cb3b68a97eb89f2ea763", # noqa: E501 22 | progress_fn=DefaultDownloadProgress(), 23 | postprocess_fn=extract_tar) 24 | 25 | expected_files = [ 26 | {"path": "train.txt", "sha256": "710378b3536a2156ae0747f0683e2e811e992e7a12014408ff93fbb1ea5a3340"}, # noqa: E501 27 | {"path": "test.txt", "sha256": "c2b808bae0ccbc40df9519b6a3906a297c8207eb273304c22d506a8612fa9eef"}, # noqa: E501 28 | {"path": "vali.txt", "sha256": "9de516e2e0dfd0e0a29cf4b233df8e22f2b547f260f7881c7dc1c7237101ee7e"} # noqa: E501 29 | ] 30 | 31 | splits = { 32 | "train": "train.txt", 33 | "test": "test.txt", 34 | "vali": "vali.txt" 35 | } 36 | 37 | def __init__(self, location: str = dataset_dir("istella_x"), 38 | split: str = "train", normalize: bool = True, 39 | filter_queries: Optional[bool] = None, download: bool = True, 40 | validate_checksums: bool = True): 41 | """ 42 | Args: 43 | location: Directory where the dataset is located. 44 | split: The data split to load ("train", "test" or "vali") 45 | normalize: Whether to perform query-level feature 46 | normalization. 47 | filter_queries: Whether to filter out queries that 48 | have no relevant items. If not given this will filter queries 49 | for the test set but not the train set. 50 | download: Whether to download the dataset if it does not 51 | exist. 52 | validate_checksums: Whether to validate the dataset files 53 | via sha256. 54 | """ 55 | # Check if specified split exists. 56 | if split not in IstellaX.splits.keys(): 57 | raise ValueError("unrecognized data split '%s'" % str(split)) 58 | 59 | # Validate dataset exists and is correct, or download it. 
60 | validate_and_download( 61 | location=location, 62 | expected_files=IstellaX.expected_files, 63 | downloader=IstellaX.downloader if download else None, 64 | validate_checksums=validate_checksums) 65 | 66 | # Only filter queries on non-train splits. 67 | if filter_queries is None: 68 | filter_queries = False if split == "train" else True 69 | 70 | # Initialize the dataset. 71 | datafile = os.path.join(location, IstellaX.splits[split]) 72 | super().__init__(file=datafile, sparse=False, normalize=normalize, 73 | filter_queries=filter_queries, zero_based="auto") 74 | -------------------------------------------------------------------------------- /pytorchltr/datasets/svmrank/example3.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional 3 | 4 | from pytorchltr.utils.downloader import DefaultDownloadProgress 5 | from pytorchltr.utils.downloader import Downloader 6 | from pytorchltr.utils.file import validate_and_download 7 | from pytorchltr.utils.file import extract_tar 8 | from pytorchltr.utils.file import dataset_dir 9 | from pytorchltr.datasets.svmrank.svmrank import SVMRankDataset 10 | 11 | 12 | class Example3(SVMRankDataset): 13 | """ 14 | Utility class for loading and using the Example3 dataset: 15 | http://www.cs.cornell.edu/people/tj/svm_light/svm_rank.html 16 | 17 | This dataset is a very small toy sample which is useful as a sanity check 18 | for testing your code. 
19 | """ 20 | 21 | downloader = Downloader( 22 | url="http://download.joachims.org/svm_light/examples/example3.tar.gz", 23 | target="example3.tar.gz", 24 | sha256_checksum="c46e97b66d3c9d5f37f7c3a2201aa2c4ea2a4e8a768f8794b10152c22648106b", # noqa: E501 25 | progress_fn=DefaultDownloadProgress(), 26 | postprocess_fn=extract_tar) 27 | 28 | expected_files = [ 29 | {"path": os.path.join("example3", "train.dat"), "sha256": "503aa66c6a1b1bb8a86b14e52163dcdb5bcffc017981afdff4cf026eacc592cf"}, # noqa: E501 30 | {"path": os.path.join("example3", "test.dat"), "sha256": "81aaac13dfc5180edce38a588cec80ee00b5d85662e00d1b7ac1d3f98242698e"} # noqa: E501 31 | ] 32 | 33 | splits = { 34 | "train": os.path.join("example3", "train.dat"), 35 | "test": os.path.join("example3", "test.dat") 36 | } 37 | 38 | def __init__(self, location: str = dataset_dir("example3"), 39 | split: str = "train", 40 | normalize: bool = True, filter_queries: Optional[bool] = None, 41 | download: bool = True, validate_checksums: bool = True): 42 | """ 43 | Args: 44 | location: Directory where the dataset is located. 45 | split: The data split to load ("train" or "test") 46 | normalize: Whether to perform query-level feature 47 | normalization. 48 | filter_queries: Whether to filter out queries that 49 | have no relevant items. If not given this will filter queries 50 | for the test set but not the train set. 51 | download: Whether to download the dataset if it does not 52 | exist. 53 | validate_checksums: Whether to validate the dataset files 54 | via sha256. 55 | """ 56 | # Check if specified split exists. 57 | if split not in Example3.splits.keys(): 58 | raise ValueError("unrecognized data split '%s'" % split) 59 | 60 | # Validate dataset exists and is correct, or download it. 
61 | validate_and_download( 62 | location=location, 63 | expected_files=Example3.expected_files, 64 | downloader=Example3.downloader if download else None, 65 | validate_checksums=validate_checksums) 66 | 67 | # Only filter queries on non-train splits. 68 | if filter_queries is None: 69 | filter_queries = False if split == "train" else True 70 | 71 | # Initialize the dataset. 72 | super().__init__(file=os.path.join(location, Example3.splits[split]), 73 | sparse=False, normalize=normalize, 74 | filter_queries=filter_queries, zero_based="auto") 75 | -------------------------------------------------------------------------------- /pytorchltr/datasets/svmrank/istella_s.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional 3 | 4 | from pytorchltr.utils.downloader import DefaultDownloadProgress 5 | from pytorchltr.utils.downloader import Downloader 6 | from pytorchltr.utils.file import validate_and_download 7 | from pytorchltr.utils.file import extract_tar 8 | from pytorchltr.utils.file import dataset_dir 9 | from pytorchltr.datasets.svmrank.svmrank import SVMRankDataset 10 | 11 | 12 | class IstellaS(SVMRankDataset): 13 | """ 14 | Utility class for downloading and using the istella-s dataset: 15 | http://quickrank.isti.cnr.it/istella-dataset/. 16 | 17 | This dataset is a smaller sampled version of the Istella dataset. 
18 | """ 19 | 20 | downloader = Downloader( 21 | url="http://library.istella.it/dataset/istella-s-letor.tar.gz", 22 | target="istella-s-letor.tar.gz", 23 | sha256_checksum="41b21116a3650cc043dbe16f02ee39f4467f9405b37fdbcc9a6a05e230a38981", # noqa: E501 24 | progress_fn=DefaultDownloadProgress(), 25 | postprocess_fn=extract_tar) 26 | 27 | expected_files = [ 28 | {"path": "sample/train.txt", "sha256": "5cda4187b88b597ca6fa98d6cd3a6d7551ba69ffff700eae63a59fbc8af385b8"}, # noqa: E501 29 | {"path": "sample/test.txt", "sha256": "2cd1a4f46fa21ea2489073979b5e5913146d1e451f1d6e268bc4e472d39da5d7"}, # noqa: E501 30 | {"path": "sample/vali.txt", "sha256": "f225dd95772fa65dc685351ff2643945fbd9a9c0e874aa1c538d485595e7c890"} # noqa: E501 31 | ] 32 | 33 | splits = { 34 | "train": "train.txt", 35 | "test": "test.txt", 36 | "vali": "vali.txt" 37 | } 38 | 39 | def __init__(self, location: str = dataset_dir("istella_s"), 40 | split: str = "train", normalize: bool = True, 41 | filter_queries: Optional[bool] = None, download: bool = True, 42 | validate_checksums: bool = True): 43 | """ 44 | Args: 45 | location: Directory where the dataset is located. 46 | split: The data split to load ("train", "test" or "vali") 47 | normalize: Whether to perform query-level feature 48 | normalization. 49 | filter_queries: Whether to filter out queries that 50 | have no relevant items. If not given this will filter queries 51 | for the test set but not the train set. 52 | download: Whether to download the dataset if it does not 53 | exist. 54 | validate_checksums: Whether to validate the dataset files 55 | via sha256. 56 | """ 57 | # Check if specified split exists. 58 | if split not in IstellaS.splits.keys(): 59 | raise ValueError("unrecognized data split '%s'" % str(split)) 60 | 61 | # Validate dataset exists and is correct, or download it. 
62 | validate_and_download( 63 | location=location, 64 | expected_files=IstellaS.expected_files, 65 | downloader=IstellaS.downloader if download else None, 66 | validate_checksums=validate_checksums) 67 | 68 | # Only filter queries on non-train splits. 69 | if filter_queries is None: 70 | filter_queries = False if split == "train" else True 71 | 72 | # Initialize the dataset. 73 | datafile = os.path.join(location, "sample", IstellaS.splits[split]) 74 | super().__init__(file=datafile, sparse=False, normalize=normalize, 75 | filter_queries=filter_queries, zero_based="auto") 76 | -------------------------------------------------------------------------------- /docs/source/datasets.rst: -------------------------------------------------------------------------------- 1 | .. _datasets: 2 | 3 | Datasets 4 | ======== 5 | 6 | PyTorchLTR provides several LTR datasets utility classes that can be used to 7 | automatically process and/or download the dataset files. 8 | 9 | .. warning:: 10 | 11 | PyTorchLTR provides utilities to automatically download and prepare several 12 | public LTR datasets. We cannot vouch for the quality, correctness or 13 | usefulness of these datasets. We do not host or distribute any datasets and 14 | it is ultimately **your responsibility** to determine whether you have 15 | permission to use each dataset under its respective license. 16 | 17 | Example 18 | ------- 19 | 20 | The following is a usage example for the small Example3 dataset. 21 | 22 | .. 
code-block:: python 23 | 24 | >>> from pytorchltr.datasets import Example3 25 | >>> train = Example3(split="train") 26 | >>> test = Example3(split="test") 27 | >>> print(len(train)) 28 | 3 29 | >>> print(len(test)) 30 | 1 31 | >>> sample = train[0] 32 | >>> print(sample["features"]) 33 | tensor([[1.0000, 1.0000, 0.0000, 0.3333, 0.0000], 34 | [0.0000, 0.0000, 1.0000, 0.0000, 1.0000], 35 | [0.0000, 1.0000, 0.0000, 1.0000, 0.0000], 36 | [0.0000, 0.0000, 1.0000, 0.6667, 0.0000]]) 37 | >>> print(sample["relevance"]) 38 | tensor([3, 2, 1, 1]) 39 | >>> print(sample["n"]) 40 | 4 41 | 42 | .. note:: 43 | 44 | PyTorchLTR looks for dataset files in (and downloads them to) the following 45 | locations: 46 | 47 | * The :code:`location` arg if it is specified in the constructor of each 48 | respective Dataset class. 49 | * :code:`$PYTORCHLTR_DATASET_PATH/{dataset_name}` if 50 | :code:`$PYTORCHLTR_DATASET_PATH` is a defined environment variable. 51 | * :code:`$DATASET_PATH/{dataset_name}` if :code:`$DATASET_PATH` is a defined 52 | environment variable. 53 | * :code:`$HOME/.pytorchltr_datasets/{dataset_name}` if all the above fail. 54 | 55 | 56 | SVMRank datasets 57 | ---------------- 58 | 59 | Example3 60 | ^^^^^^^^ 61 | .. autoclass:: pytorchltr.datasets.Example3 62 | :members: 63 | 64 | .. automethod:: __init__ 65 | .. automethod:: collate_fn 66 | .. automethod:: __getitem__ 67 | .. automethod:: __len__ 68 | 69 | Istella 70 | ^^^^^^^^^ 71 | .. autoclass:: pytorchltr.datasets.Istella 72 | :members: 73 | 74 | .. automethod:: __init__ 75 | .. automethod:: collate_fn 76 | .. automethod:: __getitem__ 77 | .. automethod:: __len__ 78 | 79 | Istella-S 80 | ^^^^^^^^^ 81 | .. autoclass:: pytorchltr.datasets.IstellaS 82 | :members: 83 | 84 | .. automethod:: __init__ 85 | .. automethod:: collate_fn 86 | .. automethod:: __getitem__ 87 | .. automethod:: __len__ 88 | 89 | Istella-X 90 | ^^^^^^^^^ 91 | .. autoclass:: pytorchltr.datasets.IstellaX 92 | :members: 93 | 94 | .. 
automethod:: __init__ 95 | .. automethod:: collate_fn 96 | .. automethod:: __getitem__ 97 | .. automethod:: __len__ 98 | 99 | MSLR-WEB10K 100 | ^^^^^^^^^^^ 101 | .. autoclass:: pytorchltr.datasets.MSLR10K 102 | :members: 103 | 104 | .. automethod:: __init__ 105 | .. automethod:: collate_fn 106 | .. automethod:: __getitem__ 107 | .. automethod:: __len__ 108 | 109 | MSLR-WEB30K 110 | ^^^^^^^^^^^ 111 | .. autoclass:: pytorchltr.datasets.MSLR30K 112 | :members: 113 | 114 | .. automethod:: __init__ 115 | .. automethod:: collate_fn 116 | .. automethod:: __getitem__ 117 | .. automethod:: __len__ 118 | -------------------------------------------------------------------------------- /pytorchltr/evaluation/dcg.py: -------------------------------------------------------------------------------- 1 | """DCG and NDCG evaluation metrics.""" 2 | from typing import Optional 3 | 4 | import torch as _torch 5 | from pytorchltr.utils import rank_by_score as _rank_by_score 6 | 7 | 8 | def ndcg(scores: _torch.FloatTensor, relevance: _torch.LongTensor, 9 | n: _torch.LongTensor, k: Optional[int] = None, 10 | exp: Optional[bool] = True) -> _torch.FloatTensor: 11 | r"""Normalized Discounted Cumulative Gain (NDCG) 12 | 13 | .. math:: 14 | 15 | \text{ndcg}(\mathbf{s}, \mathbf{y}) 16 | = \frac{\text{dcg}(\mathbf{s}, \mathbf{y})} 17 | {\text{dcg}(\mathbf{y}, \mathbf{y})} 18 | 19 | Args: 20 | scores: A tensor of size (batch_size, list_size, 1) or 21 | (batch_size, list_size), indicating the scores per document per 22 | query. 23 | relevance: A tensor of size (batch_size, list_size), indicating the 24 | relevance judgements per document per query. 25 | n: A tensor of size (batch_size) indicating the number of docs per 26 | query. 27 | k: An integer indicating the cutoff for ndcg. 28 | exp: A boolean indicating whether to use the exponential notation of 29 | DCG. 30 | 31 | Returns: 32 | A tensor of size (batch_size, list_size) indicating the NDCG of each 33 | query at every rank. 
If k is not None, then this returns a tensor of 34 | size (batch_size), indicating the NDCG@k of each query. 35 | """ 36 | idcg = dcg(relevance.float(), relevance, n, k, exp) 37 | idcg[idcg == 0.0] = 1.0 38 | return dcg(scores, relevance, n, k, exp) / idcg 39 | 40 | 41 | def dcg(scores: _torch.FloatTensor, relevance: _torch.LongTensor, 42 | n: _torch.LongTensor, k: Optional[int] = None, 43 | exp: Optional[bool] = True) -> _torch.FloatTensor: 44 | r"""Discounted Cumulative Gain (DCG) 45 | 46 | .. math:: 47 | 48 | \text{dcg}(\mathbf{s}, \mathbf{y}) 49 | = \sum_{i=1}^n \frac{\text{gain}(y_{\pi_i})}{\log_2(1 + i)} 50 | 51 | where :math:`\pi_i` is the index of the item at rank :math:`i` after 52 | sorting the scores, and: 53 | 54 | .. math:: 55 | :nowrap: 56 | 57 | \[ 58 | \text{gain}(y_i) = \left\{ 59 | \begin{array}{ll} 60 | 2^{y_i} - 1 & \text{if } \texttt{exp=True} \\ 61 | y_i & \text{otherwise} 62 | \end{array} 63 | \right. 64 | \] 65 | 66 | 67 | Args: 68 | scores: A tensor of size (batch_size, list_size, 1) or 69 | (batch_size, list_size), indicating the scores per document per 70 | query. 71 | relevance: A tensor of size (batch_size, list_size), indicating the 72 | relevance judgements per document per query. 73 | n: A tensor of size (batch_size) indicating the number of docs per 74 | query. 75 | k: An integer indicating the cutoff for ndcg. 76 | exp: A boolean indicating whether to use the exponential notation of 77 | DCG. 78 | 79 | Returns: 80 | A tensor of size (batch_size, list_size) indicating the DCG of each 81 | query at every rank. If k is not None, then this returns a tensor of 82 | size (batch_size), indicating the DCG@k of each query. 
83 | """ 84 | # Compute relevance per rank 85 | rel_sort = _torch.gather(relevance, 1, _rank_by_score(scores, n)).float() 86 | arange = _torch.repeat_interleave( 87 | _torch.arange(scores.shape[1], dtype=_torch.float, 88 | device=scores.device).reshape( 89 | (1, scores.shape[1])), 90 | scores.shape[0], dim=0) 91 | if exp: 92 | rel_sort = (2.0 ** rel_sort - 1.0) 93 | per_rank_dcg = rel_sort / _torch.log2(arange + 2.0) 94 | dcg = _torch.cumsum(per_rank_dcg, dim=1) 95 | 96 | # Do cutoff at k (or return all dcg@k as an array) 97 | if k is not None: 98 | dcg = dcg[:, :k][:, -1] 99 | return dcg 100 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 
12 | # 13 | # import os 14 | # import sys 15 | # sys.path.insert(0, os.path.abspath('.')) 16 | from pybtex.style.formatting.unsrt import Style as UnsrtStyle 17 | from pybtex.style.labels import BaseLabelStyle 18 | from pybtex.plugin import register_plugin 19 | from collections import Counter 20 | 21 | 22 | # -- Project information ----------------------------------------------------- 23 | 24 | project = 'pytorchltr' 25 | copyright = '2020, Rolf Jagerman' 26 | author = 'Rolf Jagerman' 27 | 28 | 29 | # -- General configuration --------------------------------------------------- 30 | 31 | # Add any Sphinx extension module names here, as strings. They can be 32 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 33 | # ones. 34 | extensions = [ 35 | 'sphinx.ext.mathjax', 36 | 'sphinx.ext.githubpages', 37 | 'sphinx.ext.autodoc', 38 | 'sphinx.ext.doctest', 39 | 'sphinx.ext.napoleon', 40 | 'sphinx_rtd_theme', 41 | 'sphinxcontrib.bibtex', 42 | 'sphinx_autodoc_typehints', 43 | ] 44 | 45 | # Add any paths that contain templates here, relative to this directory. 46 | templates_path = ['_templates'] 47 | 48 | 49 | # -- Options for HTML output ------------------------------------------------- 50 | 51 | # The theme to use for HTML and HTML Help pages. See the documentation for 52 | # a list of builtin themes. 53 | # 54 | html_theme = 'sphinx_rtd_theme' 55 | 56 | # Add any paths that contain custom static files (such as style sheets) here, 57 | # relative to this directory. They are copied after the builtin static files, 58 | # so a file named "default.css" will overwrite the builtin "default.css". 
59 | html_static_path = [] 60 | 61 | 62 | # -- Extension configuration ------------------------------------------------- 63 | 64 | class AuthorYearLabelStyle(BaseLabelStyle): 65 | def format_labels(self, sorted_entries): 66 | # First generate processed entries 67 | processed_entries = [] 68 | for entry in sorted_entries: 69 | year = "" 70 | if "year" in entry.fields: 71 | year = entry.fields["year"] 72 | if len(year) == 4 and year.startswith("20"): 73 | year = year[-2:] 74 | authors = "" 75 | author_limit = 3 76 | for author in entry.persons["author"][:author_limit]: 77 | authors += author.last_names[0][0] 78 | if len(entry.persons["author"]) > author_limit: 79 | authors += "+" 80 | processed_entries.append("%s%s" % (authors, year)) 81 | 82 | # Mark duplicates with incremental alphabet 83 | counts = Counter(processed_entries) 84 | marked = Counter() 85 | out_entries = [] 86 | for entry in processed_entries: 87 | if counts[entry] > 1: 88 | marked[entry] += 1 89 | entry += _number_to_alphabet(marked[entry] - 1) 90 | out_entries.append(entry) 91 | 92 | return out_entries 93 | 94 | 95 | def _number_to_alphabet(number): 96 | out = chr((number % 26) + 97) 97 | if number >= 26: 98 | out = _number_to_alphabet(number // 26 - 1) + out 99 | return out 100 | 101 | 102 | class AuthorYearStyle(UnsrtStyle): 103 | default_label_style = AuthorYearLabelStyle 104 | 105 | 106 | register_plugin('pybtex.style.formatting', 'authoryearstyle', AuthorYearStyle) 107 | -------------------------------------------------------------------------------- /docs/source/references.bib: -------------------------------------------------------------------------------- 1 | @inproceedings{wang2018lambdaloss, 2 | author = {Wang, Xuanhui and Li, Cheng and Golbandi, Nadav and Bendersky, Michael and Najork, Marc}, 3 | title = {The LambdaLoss Framework for Ranking Metric Optimization}, 4 | year = {2018}, 5 | isbn = {9781450360142}, 6 | publisher = {Association for Computing Machinery}, 7 | address = {New York, 
NY, USA}, 8 | doi = {10.1145/3269206.3271784}, 9 | booktitle = {Proceedings of the 27th ACM International Conference on Information and Knowledge Management}, 10 | pages = {1313–1322}, 11 | numpages = {10}, 12 | keywords = {lambdarank, ranking metric optimization, lambdaloss}, 13 | location = {Torino, Italy}, 14 | series = {CIKM ’18} 15 | } 16 | 17 | @inproceedings{agarwal2019general, 18 | author = {Agarwal, Aman and Takatsu, Kenta and Zaitsev, Ivan and Joachims, Thorsten}, 19 | title = {A General Framework for Counterfactual Learning-to-Rank}, 20 | year = {2019}, 21 | isbn = {9781450361729}, 22 | publisher = {Association for Computing Machinery}, 23 | address = {New York, NY, USA}, 24 | doi = {10.1145/3331184.3331202}, 25 | booktitle = {Proceedings of the 42nd International ACM SIGIR Conference on Research and Development in Information Retrieval}, 26 | pages = {5–14}, 27 | numpages = {10}, 28 | keywords = {learning to rank, presentation bias, counterfactual inference}, 29 | location = {Paris, France}, 30 | series = {SIGIR’19} 31 | } 32 | 33 | @inproceedings{joachims2002optimizing, 34 | author = {Joachims, Thorsten}, 35 | title = {Optimizing Search Engines Using Clickthrough Data}, 36 | year = {2002}, 37 | isbn = {158113567X}, 38 | publisher = {Association for Computing Machinery}, 39 | address = {New York, NY, USA}, 40 | doi = {10.1145/775047.775067}, 41 | booktitle = {Proceedings of the Eighth ACM SIGKDD International Conference on Knowledge Discovery and Data Mining}, 42 | pages = {133–142}, 43 | numpages = {10}, 44 | location = {Edmonton, Alberta, Canada}, 45 | series = {KDD ’02} 46 | } 47 | 48 | @article{kalervo2002cumulated, 49 | author = {J\"{a}rvelin, Kalervo and Kek\"{a}l\"{a}inen, Jaana}, 50 | title = {Cumulated Gain-Based Evaluation of IR Techniques}, 51 | year = {2002}, 52 | issue_date = {October 2002}, 53 | publisher = {Association for Computing Machinery}, 54 | address = {New York, NY, USA}, 55 | volume = {20}, 56 | number = {4}, 57 | issn = 
{1046-8188}, 58 | doi = {10.1145/582415.582418}, 59 | journal = {ACM Trans. Inf. Syst.}, 60 | month = oct, 61 | pages = {422–446}, 62 | numpages = {25}, 63 | keywords = {cumulated gain, Graded relevance judgments} 64 | } 65 | 66 | @inproceedings{joachims2017unbiased, 67 | author = {Joachims, Thorsten and Swaminathan, Adith and Schnabel, Tobias}, 68 | title = {Unbiased Learning-to-Rank with Biased Feedback}, 69 | year = {2017}, 70 | isbn = {9781450346757}, 71 | publisher = {Association for Computing Machinery}, 72 | address = {New York, NY, USA}, 73 | doi = {10.1145/3018661.3018699}, 74 | booktitle = {Proceedings of the Tenth ACM International Conference on Web Search and Data Mining}, 75 | pages = {781–789}, 76 | numpages = {9}, 77 | keywords = {implicit feedback, learning to rank, click models, ranking svm, propensity weighting}, 78 | location = {Cambridge, United Kingdom}, 79 | series = {WSDM ’17} 80 | } 81 | 82 | @inproceedings{gysel2018pytreceval, 83 | author = {Van Gysel, Christophe and de Rijke, Maarten}, 84 | title = {Pytrec_eval: An Extremely Fast Python Interface to Trec_eval}, 85 | year = {2018}, 86 | isbn = {9781450356572}, 87 | publisher = {Association for Computing Machinery}, 88 | address = {New York, NY, USA}, 89 | doi = {10.1145/3209978.3210065}, 90 | booktitle = {The 41st International ACM SIGIR Conference on Research & Development in Information Retrieval}, 91 | pages = {873–876}, 92 | numpages = {4}, 93 | keywords = {ir evaluation, toolkits}, 94 | location = {Ann Arbor, MI, USA}, 95 | series = {SIGIR ’18} 96 | } 97 | 98 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to PyPI 2 | 3 | on: 4 | release: 5 | types: [created] 6 | 7 | jobs: 8 | deploy-sdist: 9 | 10 | runs-on: ubuntu-latest 11 | 12 | steps: 13 | - uses: actions/checkout@v2 14 | - name: Set up Python 15 | uses: 
actions/setup-python@v2 16 | with: 17 | python-version: '3.6' 18 | architecture: 'x64' 19 | - name: Install dependencies 20 | run: | 21 | python -m pip install --upgrade pip 22 | pip install --upgrade setuptools wheel twine 23 | pip install -r requirements.txt 24 | pip install -r dev-requirements.txt 25 | python setup.py build_ext 26 | - name: Build source wheel 27 | run: | 28 | python setup.py sdist 29 | - name: Upload wheel 30 | env: 31 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 32 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 33 | run: | 34 | twine upload dist/* 35 | 36 | deploy-macos: 37 | 38 | strategy: 39 | fail-fast: false 40 | matrix: 41 | python-version: [ '3.5', '3.6', '3.7', '3.8' ] 42 | 43 | runs-on: macos-latest 44 | 45 | steps: 46 | - uses: actions/checkout@v2 47 | - name: Set up Python 48 | uses: actions/setup-python@v2 49 | with: 50 | python-version: ${{ matrix.python-version }} 51 | architecture: 'x64' 52 | - name: Install dependencies 53 | run: | 54 | python -m pip install --upgrade pip 55 | pip install --upgrade setuptools wheel twine 56 | pip install -r requirements.txt 57 | pip install -r dev-requirements.txt 58 | python setup.py build_ext 59 | - name: Build macos wheel 60 | run: | 61 | python setup.py bdist_wheel 62 | - name: Upload wheel 63 | env: 64 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 65 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 66 | run: | 67 | twine upload dist/* 68 | 69 | deploy-windows: 70 | 71 | strategy: 72 | fail-fast: false 73 | matrix: 74 | python-version: [ '3.5', '3.6', '3.7', '3.8' ] 75 | 76 | runs-on: windows-latest 77 | 78 | steps: 79 | - uses: actions/checkout@v2 80 | - name: Set up Python 81 | uses: actions/setup-python@v2 82 | with: 83 | python-version: ${{ matrix.python-version }} 84 | architecture: 'x64' 85 | - name: Install dependencies 86 | run: | 87 | python -m pip install --upgrade pip 88 | pip install --upgrade setuptools wheel twine 89 | pip install torch==1.5.0+cpu -f 
https://download.pytorch.org/whl/torch_stable.html 90 | pip install -r requirements.txt 91 | pip install -r dev-requirements.txt 92 | python setup.py build_ext 93 | - name: Build windows wheel 94 | run: | 95 | python setup.py bdist_wheel 96 | - name: Upload wheel 97 | env: 98 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 99 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 100 | run: | 101 | twine upload dist/* 102 | 103 | deploy-manylinux: 104 | 105 | strategy: 106 | fail-fast: false 107 | matrix: 108 | python-version: [ 'cp35-cp35m', 'cp36-cp36m', 'cp37-cp37m', 'cp38-cp38' ] 109 | 110 | runs-on: ubuntu-latest 111 | container: quay.io/pypa/manylinux1_x86_64 112 | 113 | steps: 114 | - uses: actions/checkout@v1 115 | - name: Install dependencies 116 | run: | 117 | /opt/python/${{ matrix.python-version }}/bin/python -m pip install --upgrade pip setuptools wheel twine 118 | /opt/python/${{ matrix.python-version }}/bin/python -m pip install -r requirements.txt 119 | /opt/python/${{ matrix.python-version }}/bin/python -m pip install -r dev-requirements.txt 120 | - name: Build manylinux wheel 121 | env: 122 | CFLAGS: -std=c99 123 | run: | 124 | /opt/python/${{ matrix.python-version }}/bin/python setup.py build_ext 125 | /opt/python/${{ matrix.python-version }}/bin/pip wheel . 
--no-deps -w wheelhouse/ 126 | for whl in wheelhouse/*.whl; do if auditwheel show "$whl"; then auditwheel repair "$whl" --plat "manylinux1_x86_64" -w ./dist/; fi; done 127 | - name: Upload wheel 128 | env: 129 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 130 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 131 | run: | 132 | /opt/python/${{ matrix.python-version }}/bin/twine upload dist/* 133 | -------------------------------------------------------------------------------- /tests/evaluation/test_dcg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pytorchltr.evaluation import dcg 3 | from pytorchltr.evaluation import ndcg 4 | from pytest import approx 5 | 6 | 7 | def _generate_data(): 8 | scores = torch.FloatTensor([ 9 | [10.0, 5.0, 2.0, 3.0, 4.0], 10 | [5.0, 6.0, 4.0, 2.0, 5.5] 11 | ]) 12 | ys = torch.LongTensor([ 13 | [0, 1, 1, 0, 1], 14 | [3, 1, 0, 1, 0] 15 | ]) 16 | n = torch.LongTensor([5, 4]) 17 | return scores, ys, n 18 | 19 | 20 | def test_dcg3(): 21 | torch.manual_seed(42) 22 | scores, ys, n = _generate_data() 23 | out = dcg(scores, ys, n, k=3, exp=True) 24 | 25 | expected = torch.FloatTensor([ 26 | 1.1309297535714575, 27 | 5.4165082750002025]) 28 | 29 | assert out.numpy() == approx(expected.numpy()) 30 | 31 | 32 | def test_ndcg3(): 33 | torch.manual_seed(42) 34 | scores, ys, n = _generate_data() 35 | out = ndcg(scores, ys, n, k=3, exp=True) 36 | 37 | expected = torch.FloatTensor([ 38 | 1.1309297535714575 / 2.1309297535714578, 39 | 5.4165082750002025 / 8.130929753571458]) 40 | 41 | assert out.numpy() == approx(expected.numpy()) 42 | 43 | 44 | def test_dcg5_exp(): 45 | torch.manual_seed(42) 46 | scores, ys, n = _generate_data() 47 | out = dcg(scores, ys, n, k=5, exp=True) 48 | 49 | expected = torch.FloatTensor([ 50 | 1.5177825608059992, 51 | 5.847184833073595]) 52 | 53 | assert out.numpy() == approx(expected.numpy()) 54 | 55 | 56 | def test_ndcg5_exp(): 57 | torch.manual_seed(42) 58 | scores, 
ys, n = _generate_data() 59 | out = ndcg(scores, ys, n, k=5, exp=True) 60 | 61 | expected = torch.FloatTensor([ 62 | 1.5177825608059992 / 2.1309297535714578, 63 | 5.847184833073595 / 8.130929753571458]) 64 | 65 | assert out.numpy() == approx(expected.numpy()) 66 | 67 | 68 | def test_dcg5_nonexp(): 69 | torch.manual_seed(42) 70 | scores, ys, n = _generate_data() 71 | out = dcg(scores, ys, n, k=5, exp=False) 72 | 73 | expected = torch.FloatTensor([ 74 | 1.5177825608059992, 75 | 3.3234658187877653]) 76 | 77 | assert out.numpy() == approx(expected.numpy()) 78 | 79 | 80 | def test_ndcg5_nonexp(): 81 | torch.manual_seed(42) 82 | scores, ys, n = _generate_data() 83 | out = ndcg(scores, ys, n, k=5, exp=False) 84 | 85 | expected = torch.FloatTensor([ 86 | 1.5177825608059992 / 2.1309297535714578, 87 | 3.3234658187877653 / 4.130929753571458]) 88 | 89 | assert out.numpy() == approx(expected.numpy()) 90 | 91 | 92 | def test_dcg_all_relevant(): 93 | torch.manual_seed(42) 94 | scores = torch.FloatTensor([ 95 | [10.0, 5.0, 2.0, 3.0, 4.0], 96 | [5.0, 6.0, 4.0, 2.0, 5.5] 97 | ]) 98 | ys = torch.LongTensor([ 99 | [1, 1, 1, 1, 1], 100 | [1, 1, 1, 1, 1] 101 | ]) 102 | n = torch.LongTensor([5, 4]) 103 | out = dcg(scores, ys, n, k=5, exp=False) 104 | 105 | expected = torch.sum(torch.repeat_interleave( 106 | 1.0 / torch.log2(2.0 + torch.arange(5, dtype=torch.float))[None, :], 107 | 2, dim=0), dim=1) 108 | 109 | assert out.numpy() == approx(expected.numpy()) 110 | 111 | 112 | def test_ndcg_all_relevant(): 113 | torch.manual_seed(42) 114 | scores = torch.FloatTensor([ 115 | [10.0, 5.0, 2.0, 3.0, 4.0], 116 | [5.0, 6.0, 4.0, 2.0, 5.5] 117 | ]) 118 | ys = torch.LongTensor([ 119 | [1, 1, 1, 1, 1], 120 | [1, 1, 1, 1, 1] 121 | ]) 122 | n = torch.LongTensor([5, 4]) 123 | out = ndcg(scores, ys, n, k=5, exp=False) 124 | expected = torch.ones(2) 125 | 126 | assert out.numpy() == approx(expected.numpy()) 127 | 128 | 129 | def test_dcg_no_relevant(): 130 | torch.manual_seed(42) 131 | scores = 
torch.FloatTensor([ 132 | [10.0, 5.0, 2.0, 3.0, 4.0], 133 | [5.0, 6.0, 4.0, 2.0, 5.5] 134 | ]) 135 | ys = torch.LongTensor([ 136 | [0, 0, 0, 0, 0], 137 | [0, 0, 0, 0, 0] 138 | ]) 139 | n = torch.LongTensor([5, 4]) 140 | 141 | out = dcg(scores, ys, n, k=5, exp=False) 142 | expected = torch.zeros(2) 143 | 144 | assert out.numpy() == approx(expected.numpy()) 145 | 146 | 147 | def test_ndcg_no_relevant(): 148 | torch.manual_seed(42) 149 | scores = torch.FloatTensor([ 150 | [10.0, 5.0, 2.0, 3.0, 4.0], 151 | [5.0, 6.0, 4.0, 2.0, 5.5] 152 | ]) 153 | ys = torch.LongTensor([ 154 | [0, 0, 0, 0, 0], 155 | [0, 0, 0, 0, 0] 156 | ]) 157 | n = torch.LongTensor([5, 4]) 158 | 159 | out = ndcg(scores, ys, n, k=5, exp=False) 160 | expected = torch.zeros(2) 161 | 162 | assert out.numpy() == approx(expected.numpy()) 163 | -------------------------------------------------------------------------------- /tests/utils/test_tensor_operations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pytorchltr.utils.tensor_operations import rank_by_plackettluce 3 | from pytest import approx 4 | 5 | 6 | def repeat_rank_stats(fn, runs=100): 7 | """Repeats a ranking function multiple times and records how often each 8 | item gets placed at each rank. 9 | 10 | Args: 11 | fn: The function that generates rankings. 12 | runs: How many runs to repeat. 
13 | """ 14 | shape = fn().shape 15 | out = torch.zeros((shape[0], shape[1], shape[1])) 16 | for i in range(runs): 17 | ranking = fn() 18 | for batch in range(shape[0]): 19 | for j, r_j in enumerate(ranking[batch, :]): 20 | out[batch, j, r_j] += 1 21 | return out / runs 22 | 23 | 24 | def test_plackettluce_3d_input(): 25 | torch.manual_seed(42) 26 | scores = torch.FloatTensor([[[5.0], [3.0], [2.0]]]) 27 | n = torch.IntTensor([3]) 28 | fn = lambda: rank_by_plackettluce(scores, n) # noqa: E731 29 | out = repeat_rank_stats(fn, runs=100) 30 | expected = torch.nn.Softmax(dim=1)(scores.reshape((1, 3))) 31 | 32 | # Assert that rank 1 has a proportional amount of each doc 33 | out = out[0, 0, :].numpy() 34 | expected = expected[0, :].numpy() 35 | assert out == approx(expected, abs=0.1) 36 | 37 | 38 | def test_plackettluce_rank_1(): 39 | torch.manual_seed(42) 40 | scores = torch.FloatTensor([[5.0, 3.0, 2.0, 1.0]]) 41 | n = torch.IntTensor([4]) 42 | fn = lambda: rank_by_plackettluce(scores, n) # noqa: E731 43 | out = repeat_rank_stats(fn, runs=100) 44 | expected = torch.nn.Softmax(dim=1)(scores) 45 | 46 | # Assert that rank 1 has a proportional amount of each doc 47 | out = out[0, 0, :].numpy() 48 | expected = expected[0, :].numpy() 49 | assert out == approx(expected, abs=0.1) 50 | 51 | 52 | def test_plackettluce_rank_2(): 53 | torch.manual_seed(42) 54 | scores = torch.FloatTensor([[5.0, 3.0, 2.0, 1.0]]) 55 | n = torch.IntTensor([4]) 56 | fn = lambda: rank_by_plackettluce(scores, n) # noqa: E731 57 | out = repeat_rank_stats(fn, runs=100) 58 | softmax = torch.nn.Softmax(dim=1)(scores) 59 | 60 | # Compute: 61 | # P(doc in rank 2) = 62 | # sum_{i} P(doc_i in rank 1) * P(doc in rank 2 | doc_i in rank 1) 63 | expected = torch.zeros(softmax.shape) 64 | for i in range(4): 65 | # P(doc_i in rank1): 66 | p_rank_1 = softmax[0, i] 67 | 68 | # P(doc in rank 2 | doc_i in rank 1): 69 | r_scores = scores.clone() 70 | r_scores[scores == scores[0, i]] = 0.0 71 | r_softmax = 
torch.exp(r_scores) / torch.sum(torch.exp(r_scores)) 72 | 73 | # P(doc in rank 2): 74 | expected += p_rank_1 * r_softmax 75 | 76 | # Assert that rank 2 has a proportional amount of each doc 77 | out = out[0, 1, :].numpy() 78 | expected = expected[0, :].numpy() 79 | assert out == approx(expected, abs=0.1) 80 | 81 | 82 | def test_plackettluce_place_padded_docs_last(): 83 | torch.manual_seed(42) 84 | scores = torch.FloatTensor([[5.0, 3.0, 2.0, 1.0, 10.0]]) 85 | n = torch.IntTensor([4]) 86 | fn = lambda: rank_by_plackettluce(scores, n) # noqa: E731 87 | out = repeat_rank_stats(fn, runs=100) 88 | 89 | # Assert that 4th doc is placed last always 90 | assert out[0, 4, 0].numpy() == approx(0.0) 91 | assert out[0, 4, 1].numpy() == approx(0.0) 92 | assert out[0, 4, 2].numpy() == approx(0.0) 93 | assert out[0, 4, 3].numpy() == approx(0.0) 94 | assert out[0, 4, 4].numpy() == approx(1.0) 95 | 96 | 97 | def test_plackettluce_batch(): 98 | torch.manual_seed(42) 99 | scores = torch.FloatTensor([[5.0, 3.0, 2.0, 1.0], 100 | [10.0, 3.0, 10.0, 100.0]]) 101 | n = torch.IntTensor([4, 4]) 102 | fn = lambda: rank_by_plackettluce(scores, n) # noqa: E731 103 | out = repeat_rank_stats(fn, runs=100) 104 | expected = torch.nn.Softmax(dim=1)(scores) 105 | 106 | # Assert that both rows in the batch are correct 107 | out_1 = out[0, 0, :].numpy() 108 | expected_1 = expected[0, :].numpy() 109 | assert out_1 == approx(expected_1, abs=0.1) 110 | 111 | out_2 = out[1, 0, :].numpy() 112 | expected_2 = expected[1, :].numpy() 113 | assert out_2 == approx(expected_2, abs=0.1) 114 | 115 | 116 | def test_plackettluce_negative_input(): 117 | torch.manual_seed(42) 118 | scores = torch.FloatTensor([[-1.0, -2.0, 0.0]]) 119 | n = torch.IntTensor([3]) 120 | fn = lambda: rank_by_plackettluce(scores, n) # noqa: E731 121 | out = repeat_rank_stats(fn, runs=100) 122 | expected = torch.nn.Softmax(dim=1)(scores) 123 | 124 | # Assert that rank 1 has a proportional amount of each doc 125 | out = out[0, 0, :].numpy() 126 | 
expected = expected[0, :].numpy() 127 | assert out == approx(expected, abs=0.1) 128 | -------------------------------------------------------------------------------- /pytorchltr/utils/tensor_operations.py: -------------------------------------------------------------------------------- 1 | """Common utils for the library.""" 2 | from typing import Optional 3 | import torch as _torch 4 | 5 | 6 | def mask_padded_values(xs: _torch.FloatTensor, n: _torch.LongTensor, 7 | mask_value: float = -float('inf'), 8 | mutate: bool = False): 9 | """Turns padded values into given mask value. 10 | 11 | Args: 12 | xs: A tensor of size (batch_size, list_size, 1) containing padded 13 | values. 14 | n: A tensor of size (batch_size) containing list size of each query. 15 | mask_value: The value to mask with (default: -inf). 16 | mutate: Whether to mutate the values of xs or return a copy. 17 | """ 18 | mask = _torch.repeat_interleave( 19 | _torch.arange(xs.shape[1], device=xs.device).reshape((1, xs.shape[1])), 20 | xs.shape[0], dim=0) 21 | n_mask = _torch.repeat_interleave( 22 | n.reshape((n.shape[0], 1)), xs.shape[1], dim=1) 23 | if not mutate: 24 | xs = xs.clone() 25 | xs[mask >= n_mask] = mask_value 26 | return xs 27 | 28 | 29 | def tiebreak_argsort( 30 | x: _torch.FloatTensor, 31 | descending: bool = True, 32 | generator: Optional[_torch.Generator] = None) -> _torch.LongTensor: 33 | """Computes a per-row argsort of matrix x with random tiebreaks. 34 | 35 | Args: 36 | x: A 2D tensor where each row will be argsorted. 37 | descending: Whether to sort in descending order. 38 | 39 | Returns: 40 | A 2D tensor of the same size as x, where each row is the argsort of x, 41 | with ties broken randomly. 
42 | """ 43 | rng_kwargs = {"generator": generator} if generator is not None else {} 44 | p = _torch.randperm(x.shape[1], device=x.device, **rng_kwargs) 45 | return p[_torch.argsort(x[:, p], descending=descending)] 46 | 47 | 48 | def rank_by_score( 49 | scores: _torch.FloatTensor, 50 | n: _torch.LongTensor, 51 | generator: Optional[_torch.Generator] = None) -> _torch.LongTensor: 52 | """Sorts scores in decreasing order. 53 | 54 | This method ensures that padded documents are placed last and ties are 55 | broken randomly. 56 | 57 | Args: 58 | scores: A tensor of size (batch_size, list_size, 1) or 59 | (batch_size, list_size) containing scores. 60 | n: A tensor of size (batch_size) containing list size of each query. 61 | """ 62 | if scores.dim() == 3: 63 | scores = scores.reshape((scores.shape[0], scores.shape[1])) 64 | return tiebreak_argsort(mask_padded_values(scores, n), generator=generator) 65 | 66 | 67 | def rank_by_plackettluce( 68 | scores: _torch.FloatTensor, n: _torch.LongTensor, 69 | generator: Optional[_torch.Generator] = None) -> _torch.LongTensor: 70 | """Samples a ranking from a plackett luce distribution. 71 | 72 | This method ensures that padded documents are placed last. 73 | 74 | Args: 75 | scores: A tensor of size (batch_size, list_size, 1) or 76 | (batch_size, list_size) containing scores. 77 | n: A tensor of size (batch_size) containing list size of each query. 78 | """ 79 | if scores.dim() == 3: 80 | scores = scores.reshape((scores.shape[0], scores.shape[1])) 81 | masked_scores = mask_padded_values(scores, n) 82 | 83 | # This implementation uses reservoir sampling, which comes down to doing 84 | # Uniform(0, 1) ^ (1 / p) and then sorting by the resulting values. The 85 | # following implementation is a numerically stable variant that operates in 86 | # log-space. 
87 | log_p = _torch.nn.LogSoftmax(dim=1)(masked_scores) 88 | rng_kwargs = {"generator": generator} if generator is not None else {} 89 | u = _torch.rand(log_p.shape, device=scores.device, **rng_kwargs) 90 | r = _torch.log(-_torch.log(u)) - log_p 91 | return tiebreak_argsort(r, descending=False, generator=generator) 92 | 93 | 94 | def batch_pairs(x: _torch.Tensor) -> _torch.Tensor: 95 | """Returns a pair matrix 96 | 97 | This matrix contains all pairs (i, j) as follows: 98 | p[_, i, j, 0] = x[_, i] 99 | p[_, i, j, 1] = x[_, j] 100 | 101 | Args: 102 | x: The input batch of dimension (batch_size, list_size) or 103 | (batch_size, list_size, 1). 104 | 105 | Returns: 106 | Two tensors of size (batch_size, list_size ^ 2, 2) containing 107 | all pairs. 108 | """ 109 | 110 | if x.dim() == 2: 111 | x = x.reshape((x.shape[0], x.shape[1], 1)) 112 | 113 | # Construct broadcasted x_{:,i,0...list_size} 114 | x_ij = _torch.repeat_interleave(x, x.shape[1], dim=2) 115 | 116 | # Construct broadcasted x_{:,0...list_size,i} 117 | x_ji = _torch.repeat_interleave(x.permute(0, 2, 1), x.shape[1], dim=1) 118 | 119 | return _torch.stack([x_ij, x_ji], dim=3) 120 | -------------------------------------------------------------------------------- /pytorchltr/datasets/svmrank/mslr30k.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional 3 | 4 | from pytorchltr.utils.downloader import DefaultDownloadProgress 5 | from pytorchltr.utils.downloader import Downloader 6 | from pytorchltr.utils.file import validate_and_download 7 | from pytorchltr.utils.file import extract_zip 8 | from pytorchltr.utils.file import dataset_dir 9 | from pytorchltr.datasets.svmrank.svmrank import SVMRankDataset 10 | 11 | 12 | class MSLR30K(SVMRankDataset): 13 | """ 14 | Utility class for downloading and using the MSLR-WEB30K dataset: 15 | https://www.microsoft.com/en-us/research/project/mslr/. 
16 | """ 17 | 18 | downloader = Downloader( 19 | url="https://api.onedrive.com/v1.0/shares/s!AtsMfWUz5l8nbXGPBlwD1rnFdBY/root/content", # noqa: E501 20 | target="MSLR-WEB30K.zip", 21 | sha256_checksum="08cb7977e1d5cbdeb57a9a2537a0923dbca6d46a76db9a6afc69e043c85341ae", # noqa: E501 22 | progress_fn=DefaultDownloadProgress(), 23 | postprocess_fn=extract_zip) 24 | 25 | per_fold_expected_files = { 26 | 1: [ 27 | {"path": "Fold1/train.txt", "sha256": "40b8eee4d1221cf8205e81603441c1757dd024a3baac25a47756210b03c031d6"}, # noqa: E501 28 | {"path": "Fold1/test.txt", "sha256": "9a4668fd2615e6772d2e5c4d558d084b2daaf2405571eaf3e4d0526f4da096c7"}, # noqa: E501 29 | {"path": "Fold1/vali.txt", "sha256": "7647834b84c849a61e5cf3c999a2f72a4785613286fd972a5615e9fcb58f94d8"} # noqa: E501 30 | ], 31 | 2: [ 32 | {"path": "Fold2/train.txt", "sha256": "a6d12dc4cc8c2dd0743b58c49cad12bd0f6f1cbeda54b54a91e5b54e0b96e7ca"}, # noqa: E501 33 | {"path": "Fold2/test.txt", "sha256": "d192c023baebdae148d902d716aeebf3df2f2f1ce5aee12be4f0b8bb76a8c04a"}, # noqa: E501 34 | {"path": "Fold2/vali.txt", "sha256": "9a4668fd2615e6772d2e5c4d558d084b2daaf2405571eaf3e4d0526f4da096c7"} # noqa: E501 35 | ], 36 | 3: [ 37 | {"path": "Fold3/train.txt", "sha256": "3d0eb52a6702b2c48750a6de89e757cfac499c8ec1d38e3843cb531c059f2f74"}, # noqa: E501 38 | {"path": "Fold3/test.txt", "sha256": "ba2487a27c21dceea0b9afaadbb6b392d5f652eb0fb8649cbd201bb894c47c12"}, # noqa: E501 39 | {"path": "Fold3/vali.txt", "sha256": "d192c023baebdae148d902d716aeebf3df2f2f1ce5aee12be4f0b8bb76a8c04a"} # noqa: E501 40 | ], 41 | 4: [ 42 | {"path": "Fold4/train.txt", "sha256": "21e389656be3c2bfe92eb4fb898d2f30dc24990fc524aada7471b949f176778d"}, # noqa: E501 43 | {"path": "Fold4/test.txt", "sha256": "a7ba03d708ae6b21556a8a6859d52e37cc0dcb2cbff992c630ef6173bb02122a"}, # noqa: E501 44 | {"path": "Fold4/vali.txt", "sha256": "ba2487a27c21dceea0b9afaadbb6b392d5f652eb0fb8649cbd201bb894c47c12"} # noqa: E501 45 | ], 46 | 5: [ 47 | {"path": "Fold5/train.txt", 
"sha256": "a19da1799650b08c0dc3baaec6c590469ef1773148082e1a10b8504f2f5e9a8b"}, # noqa: E501 48 | {"path": "Fold5/test.txt", "sha256": "7647834b84c849a61e5cf3c999a2f72a4785613286fd972a5615e9fcb58f94d8"}, # noqa: E501 49 | {"path": "Fold5/vali.txt", "sha256": "a7ba03d708ae6b21556a8a6859d52e37cc0dcb2cbff992c630ef6173bb02122a"} # noqa: E501 50 | ] 51 | } 52 | 53 | splits = { 54 | "train": "train.txt", 55 | "test": "test.txt", 56 | "vali": "vali.txt" 57 | } 58 | 59 | def __init__(self, location: str = dataset_dir("MSLR30K"), 60 | split: str = "train", fold: int = 1, normalize: bool = True, 61 | filter_queries: Optional[bool] = None, download: bool = True, 62 | validate_checksums: bool = True): 63 | """ 64 | Args: 65 | location: Directory where the dataset is located. 66 | split: The data split to load ("train", "test" or "vali") 67 | fold: Which data fold to load (1...5) 68 | normalize: Whether to perform query-level feature 69 | normalization. 70 | filter_queries: Whether to filter out queries that 71 | have no relevant items. If not given this will filter queries 72 | for the test set but not the train set. 73 | download: Whether to download the dataset if it does not 74 | exist. 75 | validate_checksums: Whether to validate the dataset files 76 | via sha256. 77 | """ 78 | # Check if specified split and fold exists. 79 | if split not in MSLR30K.splits.keys(): 80 | raise ValueError("unrecognized data split '%s'" % str(split)) 81 | 82 | if fold not in MSLR30K.per_fold_expected_files.keys(): 83 | raise ValueError("unrecognized data fold '%s'" % str(fold)) 84 | 85 | # Validate dataset exists and is correct, or download it. 86 | validate_and_download( 87 | location=location, 88 | expected_files=MSLR30K.per_fold_expected_files[fold], 89 | downloader=MSLR30K.downloader if download else None, 90 | validate_checksums=validate_checksums) 91 | 92 | # Only filter queries on non-train splits. 
93 | if filter_queries is None: 94 | filter_queries = False if split == "train" else True 95 | 96 | # Initialize the dataset. 97 | datafile = os.path.join(location, "Fold%d" % fold, 98 | MSLR30K.splits[split]) 99 | super().__init__(file=datafile, sparse=False, normalize=normalize, 100 | filter_queries=filter_queries, zero_based="auto") 101 | -------------------------------------------------------------------------------- /pytorchltr/datasets/svmrank/mslr10k.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional 3 | 4 | from pytorchltr.utils.downloader import DefaultDownloadProgress 5 | from pytorchltr.utils.downloader import Downloader 6 | from pytorchltr.utils.file import validate_and_download 7 | from pytorchltr.utils.file import extract_zip 8 | from pytorchltr.utils.file import dataset_dir 9 | from pytorchltr.datasets.svmrank.svmrank import SVMRankDataset 10 | 11 | 12 | class MSLR10K(SVMRankDataset): 13 | """ 14 | Utility class for downloading and using the MSLR-WEB10K dataset: 15 | https://www.microsoft.com/en-us/research/project/mslr/. 16 | 17 | This dataset is a smaller sampled version of the MSLR-WEB30K dataset. 
18 | """ 19 | 20 | downloader = Downloader( 21 | url="https://api.onedrive.com/v1.0/shares/s!AtsMfWUz5l8nbOIoJ6Ks0bEMp78/root/content", # noqa: E501 22 | target="MSLR-WEB10K.zip", 23 | sha256_checksum="2902142ea33f18c59414f654212de5063033b707d5c3939556124b1120d3a0ba", # noqa: E501 24 | progress_fn=DefaultDownloadProgress(), 25 | postprocess_fn=extract_zip) 26 | 27 | per_fold_expected_files = { 28 | 1: [ 29 | {"path": "Fold1/train.txt", "sha256": "6eb3fae4e1186e1242a6520f53a98abdbcde5b926dd19a28e51239284b1d55dc"}, # noqa: E501 30 | {"path": "Fold1/test.txt", "sha256": "33fe002374a4fce58c4e12863e4eee74745d5672a26f3e4ddacc20ccfe7d6ba0"}, # noqa: E501 31 | {"path": "Fold1/vali.txt", "sha256": "e86fb3fe7e8a5f16479da7ce04f783ae85735f17f66016786c3ffc797dd9d4db"} # noqa: E501 32 | ], 33 | 2: [ 34 | {"path": "Fold2/train.txt", "sha256": "40e4a2fcc237d9c164cbb6a3f2fa91fe6cf7d46a419d2f73e21cf090285659eb"}, # noqa: E501 35 | {"path": "Fold2/test.txt", "sha256": "44add582ccd674cf63af24d3bf6e1074e87a678db77f00b44c37980a3010917a"}, # noqa: E501 36 | {"path": "Fold2/vali.txt", "sha256": "33fe002374a4fce58c4e12863e4eee74745d5672a26f3e4ddacc20ccfe7d6ba0"} # noqa: E501 37 | ], 38 | 3: [ 39 | {"path": "Fold3/train.txt", "sha256": "f13005ceb8de0db76c93b02ee4b2bded6f925097d3ab7938931e8d07aa72acd7"}, # noqa: E501 40 | {"path": "Fold3/test.txt", "sha256": "c0a5a3c6bd7790d0b4ff3d5e961d0c8c5f8ff149089ce492540fa63035801b7a"}, # noqa: E501 41 | {"path": "Fold3/vali.txt", "sha256": "44add582ccd674cf63af24d3bf6e1074e87a678db77f00b44c37980a3010917a"} # noqa: E501 42 | ], 43 | 4: [ 44 | {"path": "Fold4/train.txt", "sha256": "6c1677cf9b2ed491e26ac6b8c8ca7dfae9c1a375e2bce8cba6df36ab67ce5836"}, # noqa: E501 45 | {"path": "Fold4/test.txt", "sha256": "dc6083c24a5f0c03df3c91ad3eed7542694115b998acf046e51432cb7a22b848"}, # noqa: E501 46 | {"path": "Fold4/vali.txt", "sha256": "c0a5a3c6bd7790d0b4ff3d5e961d0c8c5f8ff149089ce492540fa63035801b7a"} # noqa: E501 47 | ], 48 | 5: [ 49 | {"path": "Fold5/train.txt", 
"sha256": "4249797a2f0f46bff279973f0fb055d4a78f67f337769eabd56e82332c044794"}, # noqa: E501 50 | {"path": "Fold5/test.txt", "sha256": "e86fb3fe7e8a5f16479da7ce04f783ae85735f17f66016786c3ffc797dd9d4db"}, # noqa: E501 51 | {"path": "Fold5/vali.txt", "sha256": "dc6083c24a5f0c03df3c91ad3eed7542694115b998acf046e51432cb7a22b848"} # noqa: E501 52 | ] 53 | } 54 | 55 | splits = { 56 | "train": "train.txt", 57 | "test": "test.txt", 58 | "vali": "vali.txt" 59 | } 60 | 61 | def __init__(self, location: str = dataset_dir("MSLR10K"), 62 | split: str = "train", fold: int = 1, normalize: bool = True, 63 | filter_queries: Optional[bool] = None, download: bool = True, 64 | validate_checksums: bool = True): 65 | """ 66 | Args: 67 | location: Directory where the dataset is located. 68 | split: The data split to load ("train", "test" or "vali") 69 | fold: Which data fold to load (1...5) 70 | normalize: Whether to perform query-level feature 71 | normalization. 72 | filter_queries: Whether to filter out queries that 73 | have no relevant items. If not given this will filter queries 74 | for the test set but not the train set. 75 | download: Whether to download the dataset if it does not 76 | exist. 77 | validate_checksums: Whether to validate the dataset files 78 | via sha256. 79 | """ 80 | # Check if specified split and fold exists. 81 | if split not in MSLR10K.splits.keys(): 82 | raise ValueError("unrecognized data split '%s'" % str(split)) 83 | 84 | if fold not in MSLR10K.per_fold_expected_files.keys(): 85 | raise ValueError("unrecognized data fold '%s'" % str(fold)) 86 | 87 | # Validate dataset exists and is correct, or download it. 88 | validate_and_download( 89 | location=location, 90 | expected_files=MSLR10K.per_fold_expected_files[fold], 91 | downloader=MSLR10K.downloader if download else None, 92 | validate_checksums=validate_checksums) 93 | 94 | # Only filter queries on non-train splits. 
95 | if filter_queries is None: 96 | filter_queries = False if split == "train" else True 97 | 98 | # Initialize the dataset. 99 | datafile = os.path.join(location, "Fold%d" % fold, 100 | MSLR10K.splits[split]) 101 | super().__init__(file=datafile, sparse=False, normalize=normalize, 102 | filter_queries=filter_queries, zero_based="auto") 103 | -------------------------------------------------------------------------------- /docs/source/getting-started.rst: -------------------------------------------------------------------------------- 1 | Getting started guide 2 | ===================== 3 | 4 | This guide describes step-by-step instructions for running a full example of 5 | training a neural network on a LTR task. 6 | 7 | Data loading 8 | ------------ 9 | 10 | .. warning:: 11 | 12 | PyTorchLTR provides utilities to automatically download and prepare several 13 | public LTR datasets. We do not host or distribute these datasets and it is 14 | ultimately **your responsibility** to determine whether you have permission 15 | to use each dataset under its respective license. 16 | 17 | The first step is loading the dataset. For this guide we will use the 18 | MSLR-WEB10K dataset which is a learning to rank dataset containing 10,000 19 | queries split across a training (60%), validation (20%) and test (20%) split. 20 | We will use the first fold of the dataset. The following code will 21 | automatically download the dataset if it has not yet been downloaded. 22 | Downloading and processing the data can take a few minutes. 23 | 24 | .. code-block:: python 25 | 26 | >>> from pytorchltr.datasets import MSLR10K 27 | >>> train = MSLR10K(split="train", fold=1) 28 | >>> test = MSLR10K(split="test", fold=1) 29 | 30 | A complete list of available datasets is available :ref:`here `. 31 | 32 | Building a scoring function 33 | --------------------------- 34 | 35 | Next, we will set up a scoring function. 
For this guide we will use a simple 36 | feedforward neural network with ReLU activation functions. This network will, 37 | given a query-document feature vector, predict a score as output. 38 | 39 | .. code-block:: python 40 | 41 | >>> import torch 42 | >>> class Model(torch.nn.Module): 43 | >>> def __init__(self, in_features): 44 | >>> super().__init__() 45 | >>> self.l1 = torch.nn.Linear(in_features, 50) 46 | >>> self.l2 = torch.nn.Linear(50, 10) 47 | >>> self.l3 = torch.nn.Linear(10, 1) 48 | >>> def forward(self, x): 49 | >>> o1 = torch.nn.functional.relu(self.l1(x)) 50 | >>> o2 = torch.nn.functional.relu(self.l2(o1)) 51 | >>> return self.l3(o2) 52 | 53 | With our model class defined, we can now create an instance. We can use the 54 | train dataset to extract the dimensionality of the input feature vectors and 55 | create a model instance. 56 | 57 | .. code-block:: python 58 | 59 | >>> torch.manual_seed(42) 60 | >>> dimensionality = train[0].features.shape[1] 61 | >>> model = Model(dimensionality) 62 | 63 | Training the model 64 | ------------------ 65 | 66 | Next, we will train the model using a basic training loop. First we set up the 67 | loss function and optimizer. For this example we will use a simple 68 | pairwise hinge loss. More information about the available loss functions can be 69 | found :ref:`here `. 70 | 71 | .. code-block:: python 72 | 73 | >>> from pytorchltr.loss import PairwiseHingeLoss 74 | >>> optimizer = torch.optim.Adagrad(model.parameters(), lr=0.1) 75 | >>> loss_fn = PairwiseHingeLoss() 76 | 77 | Next, we will implement the actual training loop which will run for 20 epochs. 78 | We use a collate function with a maximum list size of 20 to truncate the list 79 | of each training instance to a maximum of 20 items. This will significantly 80 | speed up the training process for pairwise losses which have a computational 81 | complexity that is quadratic in the list size. 82 | 83 | .. 
code-block:: python 84 | 85 | >>> from pytorchltr.datasets.list_sampler import UniformSampler 86 | >>> for epoch in range(1, 21): 87 | >>> loader = torch.utils.data.DataLoader( 88 | >>> train, batch_size=16, shuffle=True, 89 | >>> collate_fn=train.collate_fn(UniformSampler(max_list_size=20))) 90 | >>> for batch in loader: 91 | >>> xs, ys, n = batch.features, batch.relevance, batch.n 92 | >>> loss = loss_fn(model(xs), ys, n).mean() 93 | >>> optimizer.zero_grad() 94 | >>> loss.backward() 95 | >>> optimizer.step() 96 | >>> print("Finished epoch %d" % epoch) 97 | Finished epoch 1 98 | Finished epoch 2 99 | Finished epoch 3 100 | Finished epoch 4 101 | Finished epoch 5 102 | Finished epoch 6 103 | Finished epoch 7 104 | Finished epoch 8 105 | Finished epoch 9 106 | Finished epoch 10 107 | Finished epoch 11 108 | Finished epoch 12 109 | Finished epoch 13 110 | Finished epoch 14 111 | Finished epoch 15 112 | Finished epoch 16 113 | Finished epoch 17 114 | Finished epoch 18 115 | Finished epoch 19 116 | Finished epoch 20 117 | 118 | Evaluating the trained model 119 | ---------------------------- 120 | 121 | Finally we will evaluate the model using :math:`ndcg@10` on the test set. To do 122 | so we iterate over the test set in batches and compute :math:`ndcg@10` on each 123 | batch. To compute the average :math:`ndcg@10` on the test set we take the sum 124 | of all scores and finally divide by the length of the test set. 125 | 126 | .. 
code-block:: python 127 | 128 | >>> from pytorchltr.evaluation import ndcg 129 | >>> loader = torch.utils.data.DataLoader( 130 | >>> test, batch_size=16, collate_fn=test.collate_fn()) 131 | >>> final_score = 0.0 132 | >>> for batch in loader: 133 | >>> xs, ys, n = batch.features, batch.relevance, batch.n 134 | >>> ndcg_score = ndcg(model(xs), ys, n, k=10) 135 | >>> final_score += float(torch.sum(ndcg_score)) 136 | 137 | >>> print("ndcg@10 on test set: %f" % (final_score / len(test))) 138 | ndcg@10 on test set: 0.445652 139 | 140 | Additional information about available evaluation metrics and how to integrate 141 | with :code:`pytrec_eval` can be found :ref:`here `. 142 | -------------------------------------------------------------------------------- /pytorchltr/click_simulation/pbm.py: -------------------------------------------------------------------------------- 1 | """Collection of PBM-based click simulators.""" 2 | from typing import Optional 3 | from typing import Tuple 4 | 5 | import torch as _torch 6 | from pytorchltr.utils import mask_padded_values as _mask_padded_values 7 | 8 | 9 | _SIM_RETURN_TYPE = Tuple[_torch.LongTensor, _torch.FloatTensor] 10 | 11 | 12 | def simulate_pbm(rankings: _torch.LongTensor, ys: _torch.LongTensor, 13 | n: _torch.LongTensor, relevance_probs: _torch.FloatTensor, 14 | cutoff: Optional[int] = None, 15 | eta: float = 1.0) -> _SIM_RETURN_TYPE: 16 | """Simulates clicks according to a position-biased user model. 17 | 18 | Args: 19 | rankings: A tensor of size (batch_size, list_size) of rankings. 20 | ys: A tensor of size (batch_size, list_size) of relevance labels. 21 | n: A tensor of size (batch_size) indicating the nr docs per query. 22 | relevance_prob: A tensor of size (max_relevance) where the entry at 23 | index "i" indicates the probability of clicking a document with 24 | relevance label "i" (given that it is observed). 25 | cutoff: The maximum list size to simulate. 
26 | eta: The severity of position bias (0.0 = no bias) 27 | 28 | Returns: 29 | A tuple of two tensors of size (batch_size, list_size), where the first 30 | indicates the clicks with 0.0 and 1.0 and the second indicates the 31 | propensity of observing each document. 32 | """ 33 | # Cutoff at n for observation probabilities. 34 | if cutoff is not None: 35 | n = _torch.min(_torch.ones_like(n) * cutoff, n) 36 | 37 | # Compute position-biased observation probabilities. 38 | ranks = 1.0 + _torch.arange( 39 | rankings.shape[1], device=rankings.device, dtype=_torch.float) 40 | obs_probs = 1.0 / (1.0 + ranks) ** eta 41 | obs_probs = _torch.repeat_interleave( 42 | obs_probs[None, :], rankings.shape[0], dim=0) 43 | obs_probs = _mask_padded_values(obs_probs, n, mask_value=0.0, mutate=True) 44 | 45 | # Compute relevance labels at every rank. 46 | ranked_ys = _torch.gather(ys, 1, rankings) 47 | 48 | # Compute click probabilities (given observed). 49 | relevance_probs = _torch.repeat_interleave( 50 | relevance_probs[None, :], rankings.shape[0], dim=0) 51 | click_probs = _torch.gather(relevance_probs, 1, ranked_ys) 52 | 53 | # Sample clicks from bernoulli distribution with probabilities. 54 | clicks = _torch.bernoulli(click_probs * obs_probs) 55 | 56 | # Invert back to regular ranking. 57 | invert_ranking = _torch.argsort(rankings, dim=1) 58 | 59 | # Return click realization and propensities. 60 | return ( 61 | _torch.gather(clicks, 1, invert_ranking).to(dtype=_torch.long), 62 | _torch.gather(obs_probs, 1, invert_ranking) 63 | ) 64 | 65 | 66 | def simulate_perfect(rankings: _torch.LongTensor, ys: _torch.LongTensor, 67 | n: _torch.LongTensor, cutoff: Optional[int] = None): 68 | """Simulates clicks according to a perfect user model. 69 | 70 | Args: 71 | rankings: A tensor of size (batch_size, list_size) of rankings. 72 | ys: A tensor of size (batch_size, list_size) of relevance labels. 73 | n: A tensor of size (batch_size) indicating the nr docs per query. 
74 | cutoff: The maximum list size to simulate. 75 | 76 | Returns: 77 | A tuple of two tensors of size (batch_size, list_size), where the first 78 | indicates the clicks with 0.0 and 1.0 and the second indicates the 79 | propensity of observing each document. 80 | """ 81 | rel_probs = _torch.FloatTensor( 82 | [0.0, 0.2, 0.4, 0.8, 1.0], device=rankings.device) 83 | return simulate_pbm(rankings, ys, n, rel_probs, cutoff, 0.0) 84 | 85 | 86 | def simulate_position(rankings: _torch.LongTensor, ys: _torch.LongTensor, 87 | n: _torch.LongTensor, cutoff: Optional[int] = None, 88 | eta: float = 1.0) -> _SIM_RETURN_TYPE: 89 | """Simulates clicks according to a binary position-biased user model. 90 | 91 | Args: 92 | rankings: A tensor of size (batch_size, list_size) of rankings. 93 | ys: A tensor of size (batch_size, list_size) of relevance labels. 94 | n: A tensor of size (batch_size) indicating the nr docs per query. 95 | cutoff: The maximum list size to simulate. 96 | eta: The severity of position bias (0.0 = no bias) 97 | 98 | Returns: 99 | A tuple of two tensors of size (batch_size, list_size), where the first 100 | indicates the clicks with 0.0 and 1.0 and the second indicates the 101 | propensity of observing each document. 102 | """ 103 | rel_probs = _torch.FloatTensor( 104 | [0.1, 0.1, 0.1, 1.0, 1.0], device=rankings.device) 105 | return simulate_pbm(rankings, ys, n, rel_probs, cutoff, eta) 106 | 107 | 108 | def simulate_nearrandom(rankings: _torch.LongTensor, ys: _torch.LongTensor, 109 | n: _torch.LongTensor, cutoff: Optional[int] = None, 110 | eta: float = 1.0) -> _SIM_RETURN_TYPE: 111 | """Simulates clicks according to a near-random user model. 112 | 113 | Args: 114 | rankings: A tensor of size (batch_size, list_size) of rankings. 115 | ys: A tensor of size (batch_size, list_size) of relevance labels. 116 | n: A tensor of size (batch_size) indicating the nr docs per query. 117 | cutoff: The maximum list size to simulate. 
118 | eta: The severity of position bias (0.0 = no bias) 119 | 120 | Returns: 121 | A tuple of two tensors of size (batch_size, list_size), where the first 122 | indicates the clicks with 0.0 and 1.0 and the second indicates the 123 | propensity of observing each document. 124 | """ 125 | rel_probs = _torch.FloatTensor( 126 | [0.4, 0.45, 0.5, 0.55, 0.6], device=rankings.device) 127 | return simulate_pbm(rankings, ys, n, rel_probs, cutoff, eta) 128 | -------------------------------------------------------------------------------- /tests/loss/test_pairwise_additive.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pytorchltr.loss import PairwiseHingeLoss 3 | from pytorchltr.loss import PairwiseDCGHingeLoss 4 | from pytorchltr.loss import PairwiseLogisticLoss 5 | from math import log 6 | from math import log2 7 | from math import exp 8 | from pytest import approx 9 | 10 | 11 | def test_pairwise_hinge_autoreshape_scores(): 12 | loss_fn = PairwiseHingeLoss() 13 | scores = torch.FloatTensor([[0.0, 0.0, 1.0, 2.0, 1.0]]) 14 | ys = torch.LongTensor([[[0], [0], [1], [2], [1]]]) 15 | n = torch.LongTensor([5]) 16 | loss = loss_fn(scores, ys, n) 17 | 18 | # Hinge loss: 1.0 + \sum_{(d_i, d_j) : y_i > y_j} hinge(f(d_j) - f(d_i))) 19 | assert loss.item() == approx(0.0) 20 | 21 | 22 | def test_pairwise_hinge_autoreshape_relevance(): 23 | loss_fn = PairwiseHingeLoss() 24 | scores = torch.FloatTensor([[[0.0], [0.0], [1.0], [2.0], [1.0]]]) 25 | ys = torch.LongTensor([[0, 0, 1, 2, 1]]) 26 | n = torch.LongTensor([5]) 27 | loss = loss_fn(scores, ys, n) 28 | 29 | # Hinge loss: 1.0 + \sum_{(d_i, d_j) : y_i > y_j} hinge(f(d_j) - f(d_i))) 30 | assert loss.item() == approx(0.0) 31 | 32 | 33 | def test_pairwise_hinge_perfect(): 34 | loss_fn = PairwiseHingeLoss() 35 | scores = torch.FloatTensor([[0.0, 0.0, 1.0, 2.0, 1.0]]) 36 | ys = torch.LongTensor([[0, 0, 1, 2, 1]]) 37 | n = torch.LongTensor([5]) 38 | loss = loss_fn(scores, ys, n) 
39 | 40 | # Hinge loss: 1.0 + \sum_{(d_i, d_j) : y_i > y_j} hinge(f(d_j) - f(d_i))) 41 | assert loss.item() == approx(0.0) 42 | 43 | 44 | def test_pairwise_hinge_2(): 45 | loss_fn = PairwiseHingeLoss() 46 | scores = torch.FloatTensor([[[0.0], [0.0], [1.0], [1.0], [1.0]]]) 47 | ys = torch.LongTensor([[0, 0, 1, 2, 1]]) 48 | n = torch.LongTensor([5]) 49 | loss = loss_fn(scores, ys, n) 50 | 51 | # Hinge loss: 1.0 + \sum_{(d_i, d_j) : y_i > y_j} hinge(f(d_j) - f(d_i))) 52 | assert loss.item() == approx(2.0) 53 | 54 | 55 | def test_pairwise_hinge_3(): 56 | loss_fn = PairwiseHingeLoss() 57 | scores = torch.FloatTensor([[[0.0], [0.0], [1.0], [-5.0], [1.0]]]) 58 | ys = torch.LongTensor([[0, 0, 1, 2, 1]]) 59 | n = torch.LongTensor([5]) 60 | loss = loss_fn(scores, ys, n) 61 | 62 | # Hinge loss: 1.0 + \sum_{(d_i, d_j) : y_i > y_j} hinge(f(d_j) - f(d_i))) 63 | assert loss.item() == approx(7.0 + 7.0 + 6.0 + 6.0) 64 | 65 | 66 | def test_pairwise_hinge_batch(): 67 | loss_fn = PairwiseHingeLoss() 68 | scores = torch.FloatTensor([ 69 | [[0.0], [10.0], [1.0], [0.5], [1.0]], 70 | [[1.0], [3.5], [6.0], [4.3], [10.0]]]) 71 | ys = torch.LongTensor([ 72 | [0, 2, 1, 2, 1], 73 | [1, 2, 2, 1, 0]]) 74 | n = torch.LongTensor([5, 4]) 75 | loss = loss_fn(scores, ys, n) 76 | 77 | # Hinge loss: 1.0 + \sum_{(d_i, d_j) : y_i > y_j} hinge(f(d_j) - f(d_i))) 78 | assert loss[0].item() == approx(0.5 + 1.5 + 1.5) 79 | assert loss[1].item() == approx(5.3 - 3.5) 80 | 81 | 82 | def test_pairwise_hinge_cutoff(): 83 | loss_fn = PairwiseHingeLoss() 84 | scores = torch.FloatTensor([[[1.0], [3.5], [6.0], [4.3], [8.0]]]) 85 | ys = torch.LongTensor([[1, 2, 2, 1, 0]]) 86 | n1 = torch.LongTensor([3]) 87 | n2 = torch.LongTensor([4]) 88 | n3 = torch.LongTensor([5]) 89 | loss1 = loss_fn(scores, ys, n1) 90 | loss2 = loss_fn(scores, ys, n2) 91 | loss3 = loss_fn(scores, ys, n3) 92 | 93 | # Hinge loss: 1.0 + \sum_{(d_i, d_j) : y_i > y_j} hinge(f(d_j) - f(d_i))) 94 | assert loss1.item() == approx(0.0) 95 | assert 
loss2.item() == approx(5.3 - 3.5) 96 | assert loss3.item() == approx( 97 | (5.3 - 3.5) + (9.0 - 4.3) + (9.0 - 6.0) + (9.0 - 3.5) + (9.0 - 1.0)) 98 | 99 | 100 | def test_pairwise_dcghinge_perfect(): 101 | loss_fn = PairwiseDCGHingeLoss() 102 | scores = torch.FloatTensor([[[0.0], [0.0], [1.0], [2.0], [1.0]]]) 103 | ys = torch.LongTensor([[0, 0, 1, 2, 1]]) 104 | n = torch.LongTensor([5]) 105 | loss = loss_fn(scores, ys, n) 106 | 107 | # DCG-modified Hinge loss: 108 | # -1.0 / log2(1.0 + \sum_{(d_i, d_j) : y_i > y_j} hinge(f(d_j) - f(d_i)))) 109 | assert loss.item() == approx(-1.0 / log(2.0 + 0.0)) 110 | 111 | 112 | def test_pairwise_dcghinge_worst(): 113 | loss_fn = PairwiseDCGHingeLoss() 114 | scores = torch.FloatTensor([[[3.0], [3.0], [1.0], [0.0], [1.0]]]) 115 | ys = torch.LongTensor([[0, 0, 1, 2, 1]]) 116 | n = torch.LongTensor([5]) 117 | loss = loss_fn(scores, ys, n) 118 | 119 | # DCG-modified Hinge loss: 120 | # -1.0 / log2(1.0 + \sum_{(d_i, d_j) : y_i > y_j} hinge(f(d_j) - f(d_i)))) 121 | assert loss.item() == approx(-1.0 / log(2.0 + 24.0)) 122 | 123 | 124 | def test_pairwise_logistic_perfect(): 125 | loss_fn = PairwiseLogisticLoss() 126 | scores = torch.FloatTensor([[[0.0], [0.0], [1.0], [2.0], [1.0]]]) 127 | ys = torch.LongTensor([[0, 0, 1, 2, 1]]) 128 | n = torch.LongTensor([5]) 129 | loss = loss_fn(scores, ys, n) 130 | 131 | # DCG-modified Hinge loss: 132 | # \sum_{(d_i, d_j) : y_i > y_j} log2(1.0 + exp(-1.0 * (f(d_j) - f(d_i))))) 133 | d1 = (2.0 - 1.0) # 2 times 134 | d2 = (2.0 - 0.0) # 2 times 135 | d3 = (1.0 - 0.0) # 4 times 136 | assert loss.item() == approx( 137 | log2(1.0 + exp(-1.0 * d1)) * 2 + 138 | log2(1.0 + exp(-1.0 * d2)) * 2 + 139 | log2(1.0 + exp(-1.0 * d3)) * 4) 140 | 141 | 142 | def test_pairwise_logistic_worst(): 143 | loss_fn = PairwiseLogisticLoss() 144 | scores = torch.FloatTensor([[[3.0], [3.0], [1.0], [0.0], [1.0]]]) 145 | ys = torch.LongTensor([[0, 0, 1, 2, 1]]) 146 | n = torch.LongTensor([5]) 147 | loss = loss_fn(scores, ys, n) 148 | 
149 | # Logistic loss: 150 | # \sum_{(d_i, d_j) : y_i > y_j} log2(1.0 + exp(-1.0 * (f(d_j) - f(d_i))))) 151 | d1 = (0.0 - 3.0) # 2 times 152 | d2 = (0.0 - 1.0) # 2 times 153 | d3 = (1.0 - 3.0) # 4 times 154 | assert loss.item() == approx( 155 | log2(1.0 + exp(-1.0 * d1)) * 2 + 156 | log2(1.0 + exp(-1.0 * d2)) * 2 + 157 | log2(1.0 + exp(-1.0 * d3)) * 4) 158 | -------------------------------------------------------------------------------- /pytorchltr/loss/pairwise_additive.py: -------------------------------------------------------------------------------- 1 | import torch as _torch 2 | from pytorchltr.utils import batch_pairs 3 | 4 | 5 | class _PairwiseAdditiveLoss(_torch.nn.Module): 6 | """Pairwise additive ranking losses. 7 | 8 | Implementation of linearly decomposible additive pairwise ranking losses. 9 | This includes RankSVM hinge loss and variations. 10 | """ 11 | def __init__(self): 12 | r"""""" 13 | super().__init__() 14 | 15 | def _loss_per_doc_pair(self, score_pairs: _torch.FloatTensor, 16 | rel_pairs: _torch.LongTensor) -> _torch.FloatTensor: 17 | """Computes a loss on given score pairs and relevance pairs. 18 | 19 | Args: 20 | score_pairs: A tensor of shape (batch_size, list_size, 21 | list_size, 2), where each entry (:, i, j, :) indicates a pair 22 | of scores for doc i and j. 23 | rel_pairs: A tensor of shape (batch_size, list_size, list_size, 2), 24 | where each entry (:, i, j, :) indicates the relevance 25 | for doc i and j. 26 | 27 | Returns: 28 | A tensor of shape (batch_size, list_size, list_size) with a loss 29 | per document pair. 30 | """ 31 | raise NotImplementedError 32 | 33 | def _loss_reduction(self, 34 | loss_pairs: _torch.FloatTensor) -> _torch.FloatTensor: 35 | """Reduces the paired loss to a per sample loss. 36 | 37 | Args: 38 | loss_pairs: A tensor of shape (batch_size, list_size, list_size) 39 | where each entry indicates the loss of doc pair i and j. 
40 | 41 | Returns: 42 | A tensor of shape (batch_size) indicating the loss per training 43 | sample. 44 | """ 45 | return loss_pairs.view(loss_pairs.shape[0], -1).sum(1) 46 | 47 | def _loss_modifier(self, loss: _torch.FloatTensor) -> _torch.FloatTensor: 48 | """A modifier to apply to the loss.""" 49 | return loss 50 | 51 | def forward(self, scores: _torch.FloatTensor, relevance: _torch.LongTensor, 52 | n: _torch.LongTensor) -> _torch.FloatTensor: 53 | """Computes the loss for given batch of samples. 54 | 55 | Args: 56 | scores: A batch of per-query-document scores. 57 | relevance: A batch of per-query-document relevance labels. 58 | n: A batch of per-query number of documents (for padding purposes). 59 | """ 60 | # Reshape relevance if necessary. 61 | if relevance.ndimension() == 2: 62 | relevance = relevance.reshape( 63 | (relevance.shape[0], relevance.shape[1], 1)) 64 | if scores.ndimension() == 2: 65 | scores = scores.reshape((scores.shape[0], scores.shape[1], 1)) 66 | 67 | # Compute pairwise differences for scores and relevances. 68 | score_pairs = batch_pairs(scores) 69 | rel_pairs = batch_pairs(relevance) 70 | 71 | # Compute loss per doc pair. 72 | loss_pairs = self._loss_per_doc_pair(score_pairs, rel_pairs) 73 | 74 | # Mask out padded documents per query in the batch 75 | n_grid = n[:, None, None].repeat(1, score_pairs.shape[1], 76 | score_pairs.shape[2]) 77 | arange = _torch.arange(score_pairs.shape[1], 78 | device=score_pairs.device) 79 | range_grid = _torch.max(*_torch.meshgrid([arange, arange])) 80 | range_grid = range_grid[None, :, :].repeat(n.shape[0], 1, 1) 81 | loss_pairs[n_grid <= range_grid] = 0.0 82 | 83 | # Reduce final list loss from per doc pair loss to a per query loss. 84 | loss = self._loss_reduction(loss_pairs) 85 | 86 | # Apply a loss modifier. 
87 | loss = self._loss_modifier(loss) 88 | 89 | # Return loss 90 | return loss 91 | 92 | 93 | class PairwiseHingeLoss(_PairwiseAdditiveLoss): 94 | r"""Pairwise hinge loss formulation of SVMRank: 95 | 96 | .. math:: 97 | l(\mathbf{s}, \mathbf{y}) = \sum_{y_i > y _j} max\left( 98 | 0, 1 - (s_i - s_j) 99 | \right) 100 | 101 | Shape: 102 | - input scores: :math:`(N, \texttt{list_size})` 103 | - input relevance: :math:`(N, \texttt{list_size})` 104 | - input n: :math:`(N)` 105 | - output: :math:`(N)` 106 | """ 107 | def _loss_per_doc_pair(self, score_pairs, rel_pairs): 108 | score_pair_diffs = score_pairs[:, :, :, 0] - score_pairs[:, :, :, 1] 109 | rel_pair_diffs = rel_pairs[:, :, :, 0] - rel_pairs[:, :, :, 1] 110 | loss = 1.0 - score_pair_diffs 111 | loss[rel_pair_diffs <= 0.0] = 0.0 112 | loss[loss < 0.0] = 0.0 113 | return loss 114 | 115 | 116 | class PairwiseDCGHingeLoss(PairwiseHingeLoss): 117 | r"""Pairwise DCG-modified hinge loss: 118 | 119 | .. math:: 120 | l(\mathbf{s}, \mathbf{y}) = 121 | \frac{-1}{\log\left( 122 | 2 + \sum_{y_i > y_j} 123 | max\left(0, 1 - (s_i - s_j)\right) 124 | \right)} 125 | 126 | Shape: 127 | - input scores: :math:`(N, \texttt{list_size})` 128 | - input relevance: :math:`(N, \texttt{list_size})` 129 | - input n: :math:`(N)` 130 | - output: :math:`(N)` 131 | """ 132 | def _loss_modifier(self, loss): 133 | return -1.0 / _torch.log(2.0 + loss) 134 | 135 | 136 | class PairwiseLogisticLoss(_PairwiseAdditiveLoss): 137 | r"""Pairwise logistic loss formulation of RankNet: 138 | 139 | .. math:: 140 | l(\mathbf{s}, \mathbf{y}) = \sum_{y_i > y_j} \log_2\left(1 + e^{ 141 | -\sigma \left(s_i - s_j\right) 142 | }\right) 143 | 144 | Shape: 145 | - input scores: :math:`(N, \texttt{list_size})` 146 | - input relevance: :math:`(N, \texttt{list_size})` 147 | - input n: :math:`(N)` 148 | - output: :math:`(N)` 149 | """ 150 | def __init__(self, sigma: float = 1.0): 151 | """ 152 | Args: 153 | sigma: Steepness of the logistic curve. 
154 | """ 155 | super().__init__() 156 | self.sigma = sigma 157 | 158 | def _loss_per_doc_pair(self, score_pairs, rel_pairs): 159 | score_pair_diffs = score_pairs[:, :, :, 0] - score_pairs[:, :, :, 1] 160 | rel_pair_diffs = rel_pairs[:, :, :, 0] - rel_pairs[:, :, :, 1] 161 | loss = _torch.log2(1.0 + _torch.exp(-self.sigma * score_pair_diffs)) 162 | loss[rel_pair_diffs <= 0.0] = 0.0 163 | return loss 164 | -------------------------------------------------------------------------------- /tests/datasets/test_list_sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pytorchltr.datasets.list_sampler import ListSampler 3 | from pytorchltr.datasets.list_sampler import UniformSampler 4 | from pytorchltr.datasets.list_sampler import BalancedRelevanceSampler 5 | 6 | from pytest import approx 7 | 8 | 9 | def rng(seed=1608637542): 10 | gen = torch.Generator() 11 | gen.manual_seed(seed) 12 | return gen 13 | 14 | 15 | def test_list_sampler(): 16 | sampler = ListSampler(max_list_size=5) 17 | relevance = torch.tensor([0, 0, 1, 0, 0, 0, 2, 1], dtype=torch.long) 18 | idxs = sampler(relevance) 19 | expected = torch.arange(5) 20 | assert idxs.equal(expected) 21 | 22 | 23 | def test_list_sampler_unlimited(): 24 | sampler = ListSampler(max_list_size=None) 25 | relevance = torch.tensor([0, 0, 1, 0, 0, 0, 2, 1], dtype=torch.long) 26 | idxs = sampler(relevance) 27 | expected = torch.arange(8) 28 | assert idxs.equal(expected) 29 | 30 | 31 | def test_list_sampler_single(): 32 | sampler = ListSampler(max_list_size=1) 33 | relevance = torch.tensor([0], dtype=torch.long) 34 | idxs = sampler(relevance) 35 | expected = torch.arange(1) 36 | assert idxs.equal(expected) 37 | 38 | 39 | def test_uniform_no_generator(): 40 | torch.manual_seed(1608637542) 41 | sampler = UniformSampler(max_list_size=5) 42 | relevance = torch.tensor([0, 1, 0, 0, 2, 0, 1, 0, 0, 0], dtype=torch.long) 43 | idxs = sampler(relevance) 44 | assert idxs.shape == 
(5,) 45 | 46 | 47 | def test_uniform_single(): 48 | sampler = UniformSampler(max_list_size=1, generator=rng()) 49 | relevance = torch.tensor([0], dtype=torch.long) 50 | idxs = sampler(relevance) 51 | expected = torch.tensor([0], dtype=torch.long) 52 | assert idxs.equal(expected) 53 | 54 | 55 | def test_uniform_single_2(): 56 | sampler = UniformSampler(max_list_size=1, generator=rng()) 57 | relevance = torch.tensor([0, 1, 2, 0, 0, 1, 0, 0, 2], dtype=torch.long) 58 | idxs = sampler(relevance) 59 | assert idxs.shape == (1,) 60 | 61 | 62 | def test_uniform_multiple(): 63 | sampler = UniformSampler(max_list_size=3, generator=rng()) 64 | relevance = torch.tensor([0, 0, 0, 1, 0, 0, 2], dtype=torch.long) 65 | idxs = sampler(relevance) 66 | assert idxs.shape == (3,) 67 | 68 | 69 | def test_uniform_multiple_2(): 70 | sampler = UniformSampler(max_list_size=9, generator=rng()) 71 | relevance = torch.tensor([0, 1], dtype=torch.long) 72 | idxs = sampler(relevance) 73 | assert idxs.shape == (2,) 74 | 75 | 76 | def test_uniform_unlimited(): 77 | sampler = UniformSampler(max_list_size=None, generator=rng()) 78 | relevance = torch.tensor([0, 0, 0, 1, 0, 0, 2, 1, 0, 0, 0, 0, 1], 79 | dtype=torch.long) 80 | idxs = sampler(relevance) 81 | assert idxs.shape == (13,) 82 | 83 | 84 | def test_uniform_stat(): 85 | sampler = UniformSampler(max_list_size=1, generator=rng()) 86 | relevance = torch.tensor([0, 0, 1, 0, 0, 0, 2, 1], dtype=torch.long) 87 | n = 1000 88 | hist = torch.zeros(3) 89 | for i in range(n): 90 | idxs = sampler(relevance) 91 | hist[relevance[idxs].item()] += 1 92 | hist /= n 93 | assert hist[0].item() == approx(5.0 / 8.0, abs=0.05) 94 | assert hist[1].item() == approx(2.0 / 8.0, abs=0.05) 95 | assert hist[2].item() == approx(1.0 / 8.0, abs=0.05) 96 | 97 | 98 | def test_uniform_stat_large(): 99 | sampler = UniformSampler(max_list_size=5, generator=rng()) 100 | relevance = torch.tensor([0, 0, 1, 0, 0, 0, 2, 1], dtype=torch.long) 101 | n = 1000 102 | for idx in range(5): 103 | 
hist = torch.zeros(3) 104 | for i in range(n): 105 | idxs = sampler(relevance) 106 | hist[relevance[idxs][idx].item()] += 1 107 | hist /= n 108 | assert hist[0].item() == approx(5.0 / 8.0, abs=0.05) 109 | assert hist[1].item() == approx(2.0 / 8.0, abs=0.05) 110 | assert hist[2].item() == approx(1.0 / 8.0, abs=0.05) 111 | 112 | 113 | def test_balanced_no_generator(): 114 | torch.manual_seed(1608637542) 115 | sampler = BalancedRelevanceSampler(max_list_size=5) 116 | relevance = torch.tensor([0, 1, 0, 0, 2, 0, 1, 0, 0, 0], dtype=torch.long) 117 | idxs = sampler(relevance) 118 | assert idxs.shape == (5,) 119 | 120 | 121 | def test_balanced_single(): 122 | sampler = BalancedRelevanceSampler(max_list_size=1, generator=rng()) 123 | relevance = torch.tensor([0], dtype=torch.long) 124 | idxs = sampler(relevance) 125 | expected = torch.tensor([0], dtype=torch.long) 126 | assert idxs.equal(expected) 127 | 128 | 129 | def test_balanced_single_2(): 130 | sampler = BalancedRelevanceSampler(max_list_size=1, generator=rng()) 131 | relevance = torch.tensor([0, 1, 2, 0, 0, 1, 0, 0, 2], dtype=torch.long) 132 | idxs = sampler(relevance) 133 | assert idxs.shape == (1,) 134 | 135 | 136 | def test_balanced_multiple(): 137 | sampler = BalancedRelevanceSampler(max_list_size=3, generator=rng()) 138 | relevance = torch.tensor([0, 0, 0, 1, 0, 0, 2], dtype=torch.long) 139 | idxs = sampler(relevance) 140 | assert idxs.shape == (3,) 141 | 142 | 143 | def test_balanced_multiple_2(): 144 | sampler = BalancedRelevanceSampler(max_list_size=9, generator=rng()) 145 | relevance = torch.tensor([0, 1], dtype=torch.long) 146 | idxs = sampler(relevance) 147 | assert idxs.shape == (2,) 148 | 149 | 150 | def test_balanced_unlimited(): 151 | sampler = BalancedRelevanceSampler(max_list_size=None, generator=rng()) 152 | relevance = torch.tensor([0, 0, 0, 1, 0, 0, 2, 1, 0, 0, 0, 0, 1], 153 | dtype=torch.long) 154 | idxs = sampler(relevance) 155 | assert idxs.shape == (13,) 156 | 157 | 158 | def 
test_balanced_stat(): 159 | sampler = BalancedRelevanceSampler(max_list_size=1, generator=rng()) 160 | relevance = torch.tensor([0, 0, 1, 0, 0, 0, 2, 1], dtype=torch.long) 161 | n = 1000 162 | hist = torch.zeros(3) 163 | for i in range(n): 164 | idxs = sampler(relevance) 165 | hist[relevance[idxs].item()] += 1 166 | hist /= n 167 | assert hist[0].item() == approx(1.0 / 3.0, abs=0.05) 168 | assert hist[1].item() == approx(1.0 / 3.0, abs=0.05) 169 | assert hist[2].item() == approx(1.0 / 3.0, abs=0.05) 170 | 171 | 172 | def test_balanced_stat_large(): 173 | sampler = BalancedRelevanceSampler(max_list_size=7, generator=rng()) 174 | relevance = torch.tensor([0, 0, 1, 0, 0, 0, 2, 1, 0, 0, 0], 175 | dtype=torch.long) 176 | n = 1000 177 | expected = torch.tensor([ 178 | [1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0], 179 | [1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0], 180 | [1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0], 181 | [1.0 / 2.0, 1.0 / 2.0, 0.0], 182 | [1.0 / 2.0, 1.0 / 2.0, 0.0], 183 | [1.0, 0.0, 0.0], 184 | [1.0, 0.0, 0.0], 185 | ]) 186 | for idx in range(7): 187 | hist = torch.zeros(3) 188 | for i in range(n): 189 | idxs = sampler(relevance) 190 | hist[relevance[idxs][idx].item()] += 1 191 | hist /= n 192 | assert hist[0].item() == approx(expected[idx, 0].item(), abs=0.05) 193 | assert hist[1].item() == approx(expected[idx, 1].item(), abs=0.05) 194 | assert hist[2].item() == approx(expected[idx, 2].item(), abs=0.05) 195 | -------------------------------------------------------------------------------- /pytorchltr/utils/file.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import logging 3 | import os 4 | import tarfile 5 | import zipfile 6 | from typing import Dict 7 | from typing import List 8 | from typing import Optional 9 | from typing import Union 10 | 11 | 12 | class ChecksumError(Exception): 13 | """A ChecksumError occurs when a checksum validation fails. 
14 | 15 | Attributes: 16 | msg (str): Human readable string describing the exception 17 | origin (str): The origin of the checksum error. 18 | expected (str): The expected checksum. 19 | actual (str): The checksum that occurre 20 | """ 21 | def __init__(self, origin: str, expected: str, actual: str): 22 | """ 23 | Args: 24 | origin: The origin of the checksum error. 25 | expected: The expected checksum. 26 | actual: The actual checksum that occurred. 27 | """ 28 | self.origin = origin 29 | self.expected = expected 30 | self.actual = actual 31 | self.msg = "'%s' checksum error: expected '%s', but got '%s'" % ( 32 | origin, expected, actual) 33 | 34 | 35 | def sha256_checksum(path: str, chunk_size: int = 4*1024*1024): 36 | """ 37 | Computes the sha256 checksum on the file in given path. 38 | 39 | Args: 40 | path: The file to compute the sha256 checksum for. 41 | chunk_size: Chunk size to read per time, prevents loading the full file 42 | into memory (default 4MiB). 43 | 44 | Returns: 45 | The sha256 checksum hex digest as a string. 46 | """ 47 | hash_sha256 = hashlib.sha256() 48 | with open(path, "rb") as f: 49 | for chunk in iter(lambda: f.read(chunk_size), b""): 50 | hash_sha256.update(chunk) 51 | return hash_sha256.hexdigest() 52 | 53 | 54 | def validate_file(path: str, sha256: Optional[str] = None): 55 | """ 56 | Performs a validation check on the file in given path: 57 | * Checks if the file exists. 58 | * (Optional) Checks if the file's sha256 matches given sha256. 59 | 60 | This function will raise a FileNotFoundError if the given file does not 61 | exist or it will raise an AssertionError if the sha256 checksum fails. If 62 | no exception is raised, all checks have passed. 63 | 64 | Args: 65 | path: The path to the file to check. 66 | sha256: (Optional) the sha256 checksum to compare against. 
67 | """ 68 | logging.debug("checking if '%s' exists and is a file", path) 69 | if not (os.path.exists(path) and os.path.isfile(path)): 70 | raise FileNotFoundError("could not find expected file '%s'" % path) 71 | 72 | if sha256 is not None: 73 | logging.debug("checking sha256 checksum of '%s'", path) 74 | actual = sha256_checksum(path) 75 | if actual != sha256: 76 | raise ChecksumError(path, sha256, actual) 77 | 78 | 79 | def validate_expected_files(location: str, 80 | expected_files: List[Union[str, Dict[str, str]]], 81 | validate_checksums: bool = False): 82 | """ 83 | Performs a validation check on a set of expected files: 84 | * Checks if all files exists 85 | * (Optional) Checks if each file's sha256 matches given sha256. 86 | 87 | This function will raise a FileNotFoundError or a ChecksumError if 88 | something is missing or if a checksum fails. A RuntimeError may be raised 89 | if expected_files has an incorrect format. 90 | 91 | Args: 92 | location: The directory to check 93 | expected_files: A list of expected files for this resource. Each entry 94 | in the list should be either a string indicating the path of the 95 | file or a dict containing a 'path' and 'sha256' key for the path 96 | and sha256 checksum of the file. 97 | validate_checksums: Whether to validate checksums or skip them. 
98 | """ 99 | if expected_files is not None: 100 | for f in expected_files: 101 | if isinstance(f, str): 102 | validate_file(os.path.join(location, f)) 103 | elif isinstance(f, dict) and "path" in f and "sha256" in f: 104 | validate_file( 105 | os.path.join(location, f["path"]), 106 | f["sha256"] if validate_checksums else None) 107 | else: 108 | raise ValueError( 109 | "entries in expected_files should be either of type " 110 | "str or a dict containing 'path' and 'sha256' keys.") 111 | 112 | 113 | _DOWNLOADER_TYPE = "pytorchltr.utils.downloader.Downloader" 114 | 115 | 116 | def validate_and_download(location: str, 117 | expected_files: List[Union[str, Dict[str, str]]], 118 | downloader: Optional[_DOWNLOADER_TYPE] = None, 119 | validate_checksums: bool = False): 120 | """Validates expected files at given location and attempts to download 121 | if validation fails. 122 | 123 | Args: 124 | location: The location to check. 125 | expected_files: (Optional) a list of expected files for this resource. 126 | Each entry in the list should be either a string indicating the 127 | path of the file or a dict containing a 'path' and 'sha256' key for 128 | the path and sha256 checksum of the file. 129 | downloader: The downloader to use when downloading files. 130 | validate_checksums: Whether to validate checksums. 
131 | """ 132 | try: 133 | logging.info("checking dataset files in '%s'", location) 134 | validate_expected_files( 135 | location, expected_files, validate_checksums) 136 | logging.info("successfully checked all dataset files") 137 | except (FileNotFoundError, ChecksumError): 138 | logging.warning("dataset file(s) in '%s' are missing or corrupt", 139 | location) 140 | if downloader is not None: 141 | downloader.download(location) 142 | validate_expected_files( 143 | location, expected_files, validate_checksums) 144 | logging.info("successfully checked all dataset files") 145 | else: 146 | raise 147 | 148 | 149 | def extract_tar(path: str, destination: str): 150 | """ 151 | Extracts the .tar[.gz] file at given path to given destination. 152 | 153 | Args: 154 | path: The location of the .tar[.gz] file. 155 | destination: The destination to extract to. 156 | """ 157 | logging.info("extracting tar file at '%s' to '%s'", path, destination) 158 | with tarfile.open(path) as f: 159 | f.extractall(destination) 160 | 161 | 162 | def extract_zip(path: str, destination: str): 163 | """ 164 | Extracts the .zip file at given path to given destination. 165 | 166 | Args: 167 | path: The location of the .zip file. 168 | destination: The destination to extract to. 169 | """ 170 | logging.info("extracting zip file at '%s' to '%s'", path, destination) 171 | with zipfile.ZipFile(path, "r") as f: 172 | f.extractall(destination) 173 | 174 | 175 | def dataset_dir(name: str) -> str: 176 | """ 177 | Returns the location of the dataset directory. 178 | 179 | Args: 180 | name: The name of the dataset. 181 | 182 | Returns: 183 | The path to the dataset directory. 
184 | """ 185 | dataset_path = os.path.join(os.environ.get("HOME", "."), 186 | ".pytorchltr_datasets") 187 | dataset_path = os.environ.get("DATASET_PATH", dataset_path) 188 | dataset_path = os.environ.get("PYTORCHLTR_DATASET_PATH", dataset_path) 189 | return os.path.join(dataset_path, name) 190 | -------------------------------------------------------------------------------- /tests/utils/test_file.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | from unittest import mock 4 | from pathlib import Path 5 | 6 | import pytest 7 | from pytorchltr.utils.file import validate_expected_files 8 | from pytorchltr.utils.file import validate_and_download 9 | from pytorchltr.utils.file import ChecksumError 10 | 11 | 12 | def test_validate_expected_files_simple_succeeds(): 13 | expected_files = ["file1.txt", "file2.txt"] 14 | actual_files = ["file1.txt", "file2.txt"] 15 | 16 | with tempfile.TemporaryDirectory() as tmpdir: 17 | 18 | # Create actual files 19 | for actual_file in actual_files: 20 | Path(os.path.join(tmpdir, actual_file)).touch() 21 | 22 | # Assert that validation succeeds 23 | validate_expected_files(tmpdir, expected_files) 24 | 25 | 26 | def test_validate_expected_files_simple_fails(): 27 | expected_files = ["file1.txt", "file2.txt", "file3.txt"] 28 | actual_files = ["file1.txt", "file2.txt"] 29 | 30 | with tempfile.TemporaryDirectory() as tmpdir: 31 | 32 | # Create actual files 33 | for actual_file in actual_files: 34 | Path(os.path.join(tmpdir, actual_file)).touch() 35 | 36 | # Assert that validation raises file not found error 37 | with pytest.raises(FileNotFoundError): 38 | validate_expected_files(tmpdir, expected_files) 39 | 40 | 41 | def test_validate_expected_files_checksum_succeeds(): 42 | expected_files = [ 43 | {"path": "file1.txt", "sha256": "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824"}, # noqa: E501 44 | {"path": "file2.txt", "sha256": 
"486ea46224d1bb4fb680f34f7c9ad96a8f24ec88be73ea8e5a6c65260e9cb8a7"} # noqa: E501 45 | ] 46 | actual_files = [ 47 | {"path": "file1.txt", "contents": "hello"}, 48 | {"path": "file2.txt", "contents": "world"} 49 | ] 50 | 51 | with tempfile.TemporaryDirectory() as tmpdir: 52 | 53 | # Create actual files 54 | for actual_file in actual_files: 55 | path = os.path.join(tmpdir, actual_file["path"]) 56 | with open(path, "wt") as file_handle: 57 | file_handle.write(actual_file["contents"]) 58 | 59 | # Assert that validation succeeds 60 | validate_expected_files( 61 | tmpdir, expected_files, validate_checksums=True) 62 | 63 | 64 | def test_validate_expected_files_checksum_fails(): 65 | expected_files = [ 66 | {"path": "file1.txt", "sha256": "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824"}, # noqa: E501 67 | {"path": "file2.txt", "sha256": "486ea46224d1bb4fb680f34f7c9ad96a8f24ec88be73ea8e5a6c65260e9cb8a7"} # noqa: E501 68 | ] 69 | actual_files = [ 70 | {"path": "file1.txt", "contents": "goodbye"}, 71 | {"path": "file2.txt", "contents": "world"} 72 | ] 73 | 74 | with tempfile.TemporaryDirectory() as tmpdir: 75 | 76 | # Create actual files 77 | for actual_file in actual_files: 78 | path = os.path.join(tmpdir, actual_file["path"]) 79 | with open(path, "wt") as file_handle: 80 | file_handle.write(actual_file["contents"]) 81 | 82 | # Assert that validation fails with checksum error 83 | with pytest.raises(ChecksumError): 84 | validate_expected_files( 85 | tmpdir, expected_files, validate_checksums=True) 86 | 87 | 88 | def test_validate_expected_files_format_fails(): 89 | expected_files = [ 90 | {"filepath": "file1.txt", "sha256": "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824"}, # noqa: E501 91 | {"path": "file2.txt", "checksum": "486ea46224d1bb4fb680f34f7c9ad96a8f24ec88be73ea8e5a6c65260e9cb8a7"} # noqa: E501 92 | ] 93 | actual_files = [ 94 | {"path": "file1.txt", "contents": "hello"}, 95 | {"path": "file2.txt", "contents": "world"} 96 | ] 97 | 
98 | with tempfile.TemporaryDirectory() as tmpdir: 99 | # Create actual files 100 | for actual_file in actual_files: 101 | path = os.path.join(tmpdir, actual_file["path"]) 102 | with open(path, "wt") as file_handle: 103 | file_handle.write(actual_file["contents"]) 104 | 105 | # Assert that validation fails with checksum error 106 | with pytest.raises(ValueError): 107 | validate_expected_files( 108 | tmpdir, expected_files, validate_checksums=True) 109 | 110 | 111 | def test_validate_and_download_calls_download(): 112 | downloader = mock.MagicMock() 113 | expected_files = ["file1.txt"] 114 | actual_files = ["file1.txt"] 115 | 116 | def create_files_side_effect(location): 117 | for actual_file in actual_files: 118 | Path(os.path.join(location, actual_file)).touch() 119 | 120 | downloader.download.side_effect = create_files_side_effect 121 | 122 | with tempfile.TemporaryDirectory() as tmpdir: 123 | # Call validate_and_download and assert download gets triggered. 124 | validate_and_download( 125 | tmpdir, expected_files, downloader=downloader, 126 | validate_checksums=False) 127 | downloader.download.assert_called_once_with(tmpdir) 128 | 129 | 130 | def test_validate_and_download_skips_download(): 131 | downloader = mock.MagicMock() 132 | expected_files = ["file1.txt"] 133 | actual_files = ["file1.txt"] 134 | 135 | def create_files_side_effect(location): 136 | for actual_file in actual_files: 137 | Path(os.path.join(location, actual_file)).touch() 138 | 139 | downloader.download.side_effect = create_files_side_effect 140 | 141 | with tempfile.TemporaryDirectory() as tmpdir: 142 | # Create files already, so that validate_and_download can skip the 143 | # download call. 144 | create_files_side_effect(tmpdir) 145 | 146 | # Call validate_and_download and assert download was not triggered. 
147 | validate_and_download( 148 | tmpdir, expected_files, downloader=downloader, 149 | validate_checksums=False) 150 | downloader.download.assert_not_called() 151 | 152 | 153 | def test_validate_and_download_fails_after_download_fails(): 154 | downloader = mock.MagicMock() 155 | expected_files = ["file1.txt"] 156 | 157 | with tempfile.TemporaryDirectory() as tmpdir: 158 | # Call validate_and_download and assert that file not found errors 159 | # is raised when no files are actually downloaded. 160 | with pytest.raises(FileNotFoundError): 161 | validate_and_download( 162 | tmpdir, expected_files, downloader=downloader, 163 | validate_checksums=False) 164 | 165 | # Assert download was called 166 | downloader.download.assert_called_once_with(tmpdir) 167 | 168 | 169 | def test_validate_and_download_fails_without_downloader(): 170 | expected_files = ["file1.txt"] 171 | 172 | with tempfile.TemporaryDirectory() as tmpdir: 173 | # Call validate_and_download and assert that file not found error 174 | # is raised when no files are actually downloaded. 175 | with pytest.raises(FileNotFoundError): 176 | validate_and_download( 177 | tmpdir, expected_files, downloader=None, 178 | validate_checksums=False) 179 | 180 | 181 | def test_validate_and_download_succeeds_without_downloader(): 182 | expected_files = ["file1.txt"] 183 | actual_files = ["file1.txt"] 184 | 185 | with tempfile.TemporaryDirectory() as tmpdir: 186 | # Create actual files. 187 | for actual_file in actual_files: 188 | Path(os.path.join(tmpdir, actual_file)).touch() 189 | 190 | # Call validate_and_download and make sure nothing gets raised. 
191 | validate_and_download( 192 | tmpdir, expected_files, downloader=None, 193 | validate_checksums=False) 194 | -------------------------------------------------------------------------------- /pytorchltr/utils/downloader.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | from collections import deque 5 | from urllib.request import urlopen 6 | from typing import Callable 7 | from typing import Optional 8 | 9 | from pytorchltr.utils.file import ChecksumError 10 | from pytorchltr.utils.file import validate_file 11 | from pytorchltr.utils.progress import LoggingProgress 12 | from pytorchltr.utils.progress import TerminalProgress 13 | 14 | 15 | _PROGRESS_FN_TYPE = Callable[[int, Optional[int], bool], None] 16 | _POSTPROCESS_FN_TYPE = Callable[[str, str], None] 17 | 18 | 19 | class Downloader: 20 | """ 21 | Downloader allows downloading from a given url to a given target file. 22 | 23 | Attributes: 24 | url (str): The URL to download from. 25 | target (str): The target file name to save as. 26 | sha256_checksum (str, optional): The SHA256 checksum of the file which 27 | will be checked to see if a download and is needed and to validate 28 | the downloaded contents. 29 | force_download (bool): If set to True, a call to download will always 30 | download, regardless of whether the target file already exists. 31 | chunk_size (int): The download chunk size, which is the maximum amount 32 | of bytes read per file-write. Default is 32*1024 (32KB) 33 | create_dirs (bool): Whether to create directories automatically if the 34 | target file should be downloaded in a non-existing directory. 35 | progress_fn (callable, optional): A callable function that reports the 36 | download progress after every download chunk. 37 | postprocess_fn (callable, optional): A callable function that is called 38 | when the download finished. 
Typical use cases include extracting or 39 | unzipping a downloaded archive. 40 | """ 41 | def __init__(self, url: str, target: str, 42 | sha256_checksum: Optional[str] = None, 43 | force_download: bool = False, chunk_size: int = 32*1024, 44 | create_dirs: bool = True, 45 | progress_fn: _PROGRESS_FN_TYPE = None, 46 | postprocess_fn: _POSTPROCESS_FN_TYPE = None, 47 | expected_files=None): 48 | """ 49 | Creates a downloader that downloads from a url. 50 | 51 | Unless force_download is set to True, this class will skip the download 52 | if destination already exists and its (optional) sha256 checksum 53 | matches. 54 | 55 | Args: 56 | url: The url to download from. 57 | target: The target name to save. 58 | sha256_checksum: sha256 checksum to validate against. 59 | force_download: Forces a re-dowload always. 60 | chunk_size: Chunk size to download in. 61 | create_dirs: Whether to create directories. 62 | progress_fn: A callable progress function. 63 | postprocess_fn: Function to call after successfull download. 64 | expected_files: A list of expected files for this downloader. Each 65 | entry in the list should be either a string indicating the path 66 | of the file or a dict containing a 'path' and 'sha256' key for 67 | the path and sha256 checksum of the file. 68 | """ 69 | self.url = url 70 | self.target = target 71 | self.sha256_checksum = sha256_checksum 72 | self.force_download = force_download 73 | self.chunk_size = chunk_size 74 | self.progress_fn = progress_fn 75 | self.create_dirs = create_dirs 76 | self.postprocess_fn = postprocess_fn 77 | self.expected_files = expected_files 78 | 79 | def download(self, destination: str): 80 | """ 81 | Downloads to given destination. Unless force_download is True, this 82 | will skip the download if the file already exists and its (optional) 83 | sha256 checksum matches. 
84 | 85 | Args: 86 | destination: The destination to download to 87 | """ 88 | path = os.path.join(destination, self.target) 89 | if self.force_download or self._should_download(path): 90 | self._download(path) 91 | validate_file(path, self.sha256_checksum) 92 | if self.postprocess_fn is not None: 93 | self.postprocess_fn(path, destination) 94 | 95 | def _should_download(self, path: str): 96 | """ 97 | Checks if the file should be downloaded. 98 | 99 | Args: 100 | path: The path to check. 101 | """ 102 | try: 103 | validate_file(path, self.sha256_checksum) 104 | return False 105 | except (FileNotFoundError, ChecksumError): 106 | logging.debug("could not validate file at '%s'", path) 107 | return True 108 | 109 | def _download(self, path: str): 110 | """ 111 | Downloads the file from self.url to path. 112 | 113 | Args: 114 | path: The path to download the file to. 115 | """ 116 | logging.info("starting download from '%s' to '%s'", self.url, path) 117 | 118 | # Create directories 119 | if self.create_dirs: 120 | os.makedirs(os.path.dirname(path), exist_ok=True) 121 | 122 | # Open URL 123 | with urlopen(self.url) as response: 124 | 125 | # Get total size if response provides it (otherwise None) 126 | total_size = response.info().get("Content-Length") 127 | bytes_read = 0 128 | if total_size: 129 | total_size = int(total_size.strip()) 130 | 131 | # Download to destination 132 | with open(path, "wb") as f: 133 | self._progress(bytes_read, total_size, False) 134 | for chunk in iter(lambda: response.read(self.chunk_size), b""): 135 | f.write(chunk) 136 | bytes_read += len(chunk) 137 | self._progress(bytes_read, total_size, False) 138 | self._progress(bytes_read, total_size, True) 139 | 140 | def _progress(self, bytes_read: int, total_size: Optional[int], 141 | final: bool): 142 | """ 143 | Reports download progress to the progress hook if it exists. 144 | 145 | Args: 146 | bytes_read: How many bytes have been read so far. 
147 | total_size: The total size of the download (None if unknown). 148 | final: True when the download finishes, False otherwise. 149 | """ 150 | if self.progress_fn is not None: 151 | self.progress_fn(bytes_read, total_size, final) 152 | 153 | 154 | class LoggingDownloadProgress(LoggingProgress): 155 | def __init__(self, interval=1.0): 156 | super().__init__(interval=interval, progress_str=_progress_string) 157 | 158 | 159 | class TerminalDownloadProgress(TerminalProgress): 160 | def __init__(self, interval=1.0): 161 | super().__init__(interval=interval, progress_str=_progress_string) 162 | 163 | 164 | def _progress_string(bytes_read: int, total_size: Optional[int], final: bool): 165 | """ 166 | Returns a human-readable string representing the download progress. 167 | 168 | Args: 169 | bytes_read: The number of bytes read so far. 170 | total_size: The total number of bytes to read or None if unknown. 171 | """ 172 | if final: 173 | return "finished downloading [%s]" % _to_human_readable(bytes_read) 174 | if total_size is None: 175 | return "downloading [%s / ?]" % _to_human_readable(bytes_read) 176 | else: 177 | percent = int((100.0 * bytes_read) / total_size) 178 | return "downloading %3d%% [%s / %s]" % ( 179 | percent, 180 | _to_human_readable(bytes_read), 181 | _to_human_readable(total_size)) 182 | 183 | 184 | def _to_human_readable(nr_of_bytes: int): 185 | """ 186 | Returns a human-readable string representation of given bytes. 
187 | 188 | Args: 189 | nr_of_bytes: The number bytes 190 | """ 191 | # Convert to human readble byte format 192 | byte_unit = deque(["B", "KB", "MB", "GB", "TB"]) 193 | while len(byte_unit) > 1 and nr_of_bytes > 1024.0: 194 | byte_unit.popleft() 195 | nr_of_bytes /= 1024.0 196 | 197 | byte_unit = byte_unit.popleft() 198 | if nr_of_bytes < 10.0 and byte_unit != "B": 199 | return "%.1f%s" % (nr_of_bytes, byte_unit) 200 | return "%d%s" % (nr_of_bytes, byte_unit) 201 | 202 | 203 | # Set default progress hook depending on whether the stdout is a terminal. 204 | if sys.stdout.isatty(): 205 | DefaultDownloadProgress = TerminalDownloadProgress 206 | else: 207 | DefaultDownloadProgress = LoggingDownloadProgress 208 | -------------------------------------------------------------------------------- /tests/click_simulation/test_pbm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pytorchltr.click_simulation.pbm import simulate_perfect 3 | from pytorchltr.click_simulation.pbm import simulate_position 4 | from pytorchltr.click_simulation.pbm import simulate_nearrandom 5 | from pytest import approx 6 | 7 | 8 | def _generate_test_data(): 9 | """Generates a small batch of test data for simulations.""" 10 | rankings = torch.LongTensor([ 11 | [3, 4, 0, 2, 1], 12 | [1, 0, 2, 4, 3] 13 | ]) 14 | ys = torch.LongTensor([ 15 | [1, 0, 4, 0, 2], 16 | [4, 3, 0, 0, 0] 17 | ]) 18 | n = torch.LongTensor([5, 3]) 19 | return rankings, ys, n 20 | 21 | 22 | def _monte_carlo_simulation(rankings, ys, n, click_fn, nr=100): 23 | """Runs a click simulation repeatedly and reports averages. 24 | 25 | Arguments: 26 | rankings: The batch of rankings. 27 | ys: The batch of relevance labels. 28 | n: The nr of docs per query in the batch. 29 | click_fn: The click simulate function to call. 30 | nr: Number of simulations to run. 31 | 32 | Returns: 33 | The averaged output of `click_fn` across `nr` runs. 
34 | """ 35 | clicks = simulate_perfect(rankings, ys, n) # only used to size the accumulators below 36 | torch.manual_seed(4200) # fixed seed keeps the Monte Carlo averages reproducible 37 | click_aggregate = torch.zeros_like(clicks[0]).to(torch.float) 38 | prop_aggregate = torch.zeros_like(clicks[1]).to(torch.float) 39 | for i in range(nr): 40 | clicks, props = click_fn(rankings, ys, n) 41 | click_aggregate += clicks.to(dtype=torch.float) 42 | prop_aggregate += props 43 | return click_aggregate / float(nr), prop_aggregate / float(nr) 44 | 45 | 46 | def test_perfect_clicks(): # Perfect clicks without cutoff: every non-padded document has propensity 1. 47 | rankings, ys, n = _generate_test_data() 48 | clicks, props = _monte_carlo_simulation(rankings, ys, n, simulate_perfect) 49 | rel_expected = torch.FloatTensor([ 50 | [0.2, 0.0, 1.0, 0.0, 0.4], 51 | [1.0, 0.8, 0.0, 0.0, 0.0] 52 | ]) 53 | props_expected = torch.FloatTensor([ 54 | [1.0, 1.0, 1.0, 1.0, 1.0], 55 | [1.0, 1.0, 1.0, 0.0, 0.0] 56 | ]) 57 | clicks_expected = rel_expected * props_expected 58 | assert clicks_expected.numpy() == approx(clicks.numpy(), abs=0.1) 59 | assert props_expected.numpy() == approx(props.numpy(), abs=0.1) 60 | 61 | 62 | def test_perfect_clicks_cutoff_3(): # cutoff=3: only the top-3 ranked documents are observed (propensity 1). 63 | rankings, ys, n = _generate_test_data() 64 | 65 | def click_fn(rankings, ys, n): 66 | return simulate_perfect(rankings, ys, n, cutoff=3) 67 | clicks, props = _monte_carlo_simulation(rankings, ys, n, click_fn) 68 | rel_expected = torch.FloatTensor([ 69 | [0.2, 0.0, 1.0, 0.0, 0.4], 70 | [1.0, 0.8, 0.0, 0.0, 0.0] 71 | ]) 72 | props_expected = torch.FloatTensor([ 73 | [1.0, 0.0, 0.0, 1.0, 1.0], 74 | [1.0, 1.0, 1.0, 0.0, 0.0] 75 | ]) 76 | clicks_expected = rel_expected * props_expected 77 | assert clicks_expected.numpy() == approx(clicks.numpy(), abs=0.1) 78 | assert props_expected.numpy() == approx(props.numpy(), abs=0.1) 79 | 80 | 81 | def test_perfect_clicks_cutoff_2(): # cutoff=2: only the top-2 ranked documents are observed. 82 | rankings, ys, n = _generate_test_data() 83 | 84 | def click_fn(rankings, ys, n): 85 | return simulate_perfect(rankings, ys, n, cutoff=2) 86 | clicks, props = _monte_carlo_simulation(rankings, ys, n, click_fn) 87 | rel_expected =
torch.FloatTensor([ 88 | [0.2, 0.0, 1.0, 0.0, 0.4], 89 | [1.0, 0.8, 0.0, 0.0, 0.0] 90 | ]) 91 | props_expected = torch.FloatTensor([ 92 | [0.0, 0.0, 0.0, 1.0, 1.0], 93 | [1.0, 1.0, 0.0, 0.0, 0.0] 94 | ]) 95 | clicks_expected = rel_expected * props_expected 96 | assert clicks_expected.numpy() == approx(clicks.numpy(), abs=0.1) 97 | assert props_expected.numpy() == approx(props.numpy(), abs=0.1) 98 | 99 | 100 | def test_position_clicks(): # Position-biased clicks: expected propensity of the document at 0-based rank r is 1 / (r + 2); click probability is 1.0 for relevance >= 3, else 0.1. 101 | rankings, ys, n = _generate_test_data() 102 | clicks, props = _monte_carlo_simulation(rankings, ys, n, simulate_position) 103 | rel_expected = torch.FloatTensor([ 104 | [0.1, 0.1, 1.0, 0.1, 0.1], 105 | [1.0, 1.0, 0.1, 0.1, 0.1] 106 | ]) 107 | props_expected = torch.FloatTensor([ 108 | [1/4.0, 1/6.0, 1/5.0, 1/2.0, 1/3.0], 109 | [1/3.0, 1/2.0, 1/4.0, 0.0, 0.0] 110 | ]) 111 | clicks_expected = rel_expected * props_expected 112 | assert clicks_expected.numpy() == approx(clicks.numpy(), abs=0.1) 113 | assert props_expected.numpy() == approx(props.numpy(), abs=0.1) 114 | 115 | 116 | def test_position_clicks_eta_2(): # eta=2.0 squares the default position propensities. 117 | rankings, ys, n = _generate_test_data() 118 | 119 | def click_fn(rankings, ys, n): 120 | return simulate_position(rankings, ys, n, eta=2.0) 121 | clicks, props = _monte_carlo_simulation(rankings, ys, n, click_fn) 122 | rel_expected = torch.FloatTensor([ 123 | [0.1, 0.1, 1.0, 0.1, 0.1], 124 | [1.0, 1.0, 0.1, 0.1, 0.1] 125 | ]) 126 | props_expected = torch.FloatTensor([ 127 | [1/4.0, 1/6.0, 1/5.0, 1/2.0, 1/3.0], 128 | [1/3.0, 1/2.0, 1/4.0, 0.0, 0.0] 129 | ]) ** 2.0 130 | clicks_expected = rel_expected * props_expected 131 | assert clicks_expected.numpy() == approx(clicks.numpy(), abs=0.1) 132 | assert props_expected.numpy() == approx(props.numpy(), abs=0.1) 133 | 134 | 135 | def test_position_clicks_eta_0(): # eta=0.0 makes every non-padded propensity 1. 136 | rankings, ys, n = _generate_test_data() 137 | 138 | def click_fn(rankings, ys, n): 139 | return simulate_position(rankings, ys, n, eta=0.0) 140 | clicks, props = _monte_carlo_simulation(rankings, ys, n,
click_fn) 141 | rel_expected = torch.FloatTensor([ 142 | [0.1, 0.1, 1.0, 0.1, 0.1], 143 | [1.0, 1.0, 0.1, 0.1, 0.1] 144 | ]) 145 | props_expected = torch.FloatTensor([ 146 | [1/4.0, 1/6.0, 1/5.0, 1/2.0, 1/3.0], 147 | [1/3.0, 1/2.0, 1/4.0, 0.0, 0.0] 148 | ]) ** 0.0 149 | props_expected[1, 3] = 0.0 # padded slots (second query has n=3) stay unobserved 150 | props_expected[1, 4] = 0.0 151 | clicks_expected = rel_expected * props_expected 152 | assert clicks_expected.numpy() == approx(clicks.numpy(), abs=0.1) 153 | assert props_expected.numpy() == approx(props.numpy(), abs=0.1) 154 | 155 | 156 | def test_position_clicks_cutoff_3(): # cutoff=3 zeroes the propensities of documents ranked below the top 3. 157 | rankings, ys, n = _generate_test_data() 158 | 159 | def click_fn(rankings, ys, n): 160 | return simulate_position(rankings, ys, n, cutoff=3) 161 | clicks, props = _monte_carlo_simulation(rankings, ys, n, click_fn) 162 | rel_expected = torch.FloatTensor([ 163 | [0.1, 0.1, 1.0, 0.1, 0.1], 164 | [1.0, 1.0, 0.1, 0.1, 0.1] 165 | ]) 166 | props_expected = torch.FloatTensor([ 167 | [1/4.0, 0.0, 0.0, 1/2.0, 1/3.0], 168 | [1/3.0, 1/2.0, 1/4.0, 0.0, 0.0] 169 | ]) 170 | clicks_expected = rel_expected * props_expected 171 | assert clicks_expected.numpy() == approx(clicks.numpy(), abs=0.1) 172 | assert props_expected.numpy() == approx(props.numpy(), abs=0.1) 173 | 174 | 175 | def test_nearrandom_clicks(): # Near-random clicks: same position propensities; expected click probability is 0.4 + 0.05 * relevance. 176 | rankings, ys, n = _generate_test_data() 177 | clicks, props = _monte_carlo_simulation( 178 | rankings, ys, n, simulate_nearrandom) 179 | rel_expected = torch.FloatTensor([ 180 | [0.45, 0.4, 0.6, 0.4, 0.5], 181 | [0.6, 0.55, 0.4, 0.4, 0.4] 182 | ]) 183 | props_expected = torch.FloatTensor([ 184 | [1/4.0, 1/6.0, 1/5.0, 1/2.0, 1/3.0], 185 | [1/3.0, 1/2.0, 1/4.0, 0.0, 0.0] 186 | ]) 187 | clicks_expected = rel_expected * props_expected 188 | assert clicks_expected.numpy() == approx(clicks.numpy(), abs=0.1) 189 | assert props_expected.numpy() == approx(props.numpy(), abs=0.1) 190 | 191 | 192 | def test_nearrandom_clicks_eta_2(): # eta=2.0 squares the near-random propensities. 193 | rankings, ys, n = _generate_test_data() 194 | 195 | def
click_fn(rankings, ys, n): 196 | return simulate_nearrandom(rankings, ys, n, eta=2.0) 197 | clicks, props = _monte_carlo_simulation(rankings, ys, n, click_fn) 198 | rel_expected = torch.FloatTensor([ 199 | [0.45, 0.4, 0.6, 0.4, 0.5], 200 | [0.6, 0.55, 0.4, 0.4, 0.4] 201 | ]) 202 | props_expected = torch.FloatTensor([ 203 | [1/4.0, 1/6.0, 1/5.0, 1/2.0, 1/3.0], 204 | [1/3.0, 1/2.0, 1/4.0, 0.0, 0.0] 205 | ]) ** 2.0 206 | clicks_expected = rel_expected * props_expected 207 | assert clicks_expected.numpy() == approx(clicks.numpy(), abs=0.1) 208 | assert props_expected.numpy() == approx(props.numpy(), abs=0.1) 209 | 210 | 211 | def test_nearrandom_clicks_eta_0(): # eta=0.0 makes every non-padded near-random propensity 1. 212 | rankings, ys, n = _generate_test_data() 213 | 214 | def click_fn(rankings, ys, n): 215 | return simulate_nearrandom(rankings, ys, n, eta=0.0) 216 | clicks, props = _monte_carlo_simulation(rankings, ys, n, click_fn) 217 | rel_expected = torch.FloatTensor([ 218 | [0.45, 0.4, 0.6, 0.4, 0.5], 219 | [0.6, 0.55, 0.4, 0.4, 0.4] 220 | ]) 221 | props_expected = torch.FloatTensor([ 222 | [1/4.0, 1/6.0, 1/5.0, 1/2.0, 1/3.0], 223 | [1/3.0, 1/2.0, 1/4.0, 0.0, 0.0] 224 | ]) ** 0.0 225 | props_expected[1, 3] = 0.0 # padded slots (second query has n=3) stay unobserved 226 | props_expected[1, 4] = 0.0 227 | clicks_expected = rel_expected * props_expected 228 | assert clicks_expected.numpy() == approx(clicks.numpy(), abs=0.1) 229 | assert props_expected.numpy() == approx(props.numpy(), abs=0.1) 230 | 231 | 232 | def test_nearrandom_clicks_cutoff_3(): # cutoff=3 zeroes near-random propensities below the top 3 ranks. 233 | rankings, ys, n = _generate_test_data() 234 | 235 | def click_fn(rankings, ys, n): 236 | return simulate_nearrandom(rankings, ys, n, cutoff=3) 237 | clicks, props = _monte_carlo_simulation(rankings, ys, n, click_fn) 238 | rel_expected = torch.FloatTensor([ 239 | [0.45, 0.4, 0.6, 0.4, 0.5], 240 | [0.6, 0.55, 0.4, 0.4, 0.4] 241 | ]) 242 | props_expected = torch.FloatTensor([ 243 | [1/4.0, 0.0, 0.0, 1/2.0, 1/3.0], 244 | [1/3.0, 1/2.0, 1/4.0, 0.0, 0.0] 245 | ]) 246 | clicks_expected = rel_expected * props_expected 247 | assert
clicks_expected.numpy() == approx(clicks.numpy(), abs=0.1) 248 | assert props_expected.numpy() == approx(props.numpy(), abs=0.1) 249 | -------------------------------------------------------------------------------- /pytorchltr/loss/pairwise_lambda.py: -------------------------------------------------------------------------------- 1 | import torch as _torch 2 | from pytorchltr.utils import batch_pairs as _batch_pairs 3 | from pytorchltr.utils import rank_by_score as _rank_by_score 4 | 5 | 6 | class LambdaLoss(_torch.nn.Module): 7 | """LambdaLoss.""" 8 | def __init__(self, sigma: float = 1.0): 9 | """ 10 | Args: 11 | sigma: Steepness of the logistic curve. 12 | """ 13 | super().__init__() 14 | self.sigma = sigma 15 | 16 | def _loss_per_doc_pair(self, score_pairs: _torch.FloatTensor, 17 | rel_pairs: _torch.LongTensor, 18 | n: _torch.LongTensor) -> _torch.FloatTensor: 19 | """Computes a loss on given score pairs and relevance pairs. 20 | 21 | Args: 22 | score_pairs: A tensor of shape (batch_size, list_size, 23 | list_size, 2), where each entry (:, i, j, :) indicates a pair 24 | of scores for doc i and j. 25 | rel_pairs: A tensor of shape (batch_size, list_size, list_size, 2), 26 | where each entry (:, i, j, :) indicates the relevance 27 | for doc i and j. 28 | n: A batch of per-query number of documents (for padding purposes). 29 | 30 | Returns: 31 | A tensor of shape (batch_size, list_size, list_size) with a loss 32 | per document pair. 33 | """ 34 | raise NotImplementedError 35 | 36 | def _loss_reduction(self, 37 | loss_pairs: _torch.FloatTensor) -> _torch.FloatTensor: 38 | """Reduces the paired loss to a per sample loss. 39 | 40 | Args: 41 | loss_pairs: A tensor of shape (batch_size, list_size, list_size) 42 | where each entry indicates the loss of doc pair i and j. 43 | 44 | Returns: 45 | A tensor of shape (batch_size) indicating the loss per training 46 | sample. 
47 | """ 48 | return loss_pairs.view(loss_pairs.shape[0], -1).sum(1) 49 | 50 | def forward(self, scores: _torch.FloatTensor, relevance: _torch.LongTensor, 51 | n: _torch.LongTensor) -> _torch.FloatTensor: 52 | """Computes the loss for given batch of samples. 53 | 54 | Args: 55 | scores: A batch of per-query-document scores. 56 | relevance: A batch of per-query-document relevance labels. 57 | n: A batch of per-query number of documents (for padding purposes). 58 | """ 59 | # Reshape relevance if necessary. 60 | if relevance.ndimension() == 2: 61 | relevance = relevance.reshape( 62 | (relevance.shape[0], relevance.shape[1], 1)) 63 | if scores.ndimension() == 2: 64 | scores = scores.reshape((scores.shape[0], scores.shape[1], 1)) 65 | 66 | # Compute ranking and sort scores and relevance 67 | ranking = _rank_by_score(scores, n) 68 | ranking = ranking.view((ranking.shape[0], ranking.shape[1], 1)) 69 | scores = _torch.gather(scores, 1, ranking) 70 | relevance = _torch.gather(relevance, 1, ranking) 71 | 72 | # Compute pairwise differences for scores and relevances. 73 | score_pairs = _batch_pairs(scores) 74 | rel_pairs = _batch_pairs(relevance) 75 | 76 | # Compute loss per doc pair. 77 | loss_pairs = self._loss_per_doc_pair(score_pairs, rel_pairs, n) 78 | 79 | # Mask out padded documents per query in the batch 80 | n_grid = n[:, None, None].repeat(1, score_pairs.shape[1], 81 | score_pairs.shape[2]) 82 | arange = _torch.arange(score_pairs.shape[1], 83 | device=score_pairs.device) 84 | range_grid = _torch.max(*_torch.meshgrid([arange, arange])) 85 | range_grid = range_grid[None, :, :].repeat(n.shape[0], 1, 1) 86 | loss_pairs[n_grid <= range_grid] = 0.0 87 | 88 | # Reduce final list loss from per doc pair loss to a per query loss. 89 | loss = self._loss_reduction(loss_pairs) 90 | 91 | # Return loss 92 | return loss 93 | 94 | 95 | class LambdaARPLoss1(LambdaLoss): 96 | r"""ARP Loss 1: 97 | 98 | .. 
math:: 99 | l(\mathbf{s}, \mathbf{y}) 100 | = -\sum_{i=1}^n \sum_{j=1}^n \log_2 101 | \left( 102 | \frac{1}{1 + e^{-\sigma (s_{\pi_i} - s_{\pi_j})}} 103 | \right)^{y_{\pi_i}} 104 | 105 | where :math:`\pi_i` is the index of the item at rank :math:`i` after 106 | sorting the scores 107 | 108 | Shape: 109 | - input scores: :math:`(N, \texttt{list_size})` 110 | - input relevance: :math:`(N, \texttt{list_size})` 111 | - input n: :math:`(N)` 112 | - output: :math:`(N)` 113 | """ 114 | def _loss_per_doc_pair(self, score_pairs, rel_pairs, n): 115 | score_diffs = score_pairs[:, :, :, 0] - score_pairs[:, :, :, 1] 116 | sigmoid = (1.0 / (1.0 + _torch.exp(-self.sigma * score_diffs))) 117 | return -(_torch.log2(sigmoid ** rel_pairs[:, :, :, 0])) 118 | 119 | 120 | class LambdaARPLoss2(LambdaLoss): 121 | r""" 122 | ARP Loss 2: 123 | 124 | .. math:: 125 | l(\mathbf{s}, \mathbf{y}) = \sum_{y_i > y_j} |y_i - y_j| \log_2 \left( 126 | 1 + e^{-\sigma(s_i - s_j)} 127 | \right) 128 | 129 | Shape: 130 | - input scores: :math:`(N, \texttt{list_size})` 131 | - input relevance: :math:`(N, \texttt{list_size})` 132 | - input n: :math:`(N)` 133 | - output: :math:`(N)` 134 | """ 135 | def _loss_per_doc_pair(self, score_pairs, rel_pairs, n): 136 | score_diffs = score_pairs[:, :, :, 0] - score_pairs[:, :, :, 1] 137 | rel_diffs = rel_pairs[:, :, :, 0] - rel_pairs[:, :, :, 1] 138 | loss = _torch.log2(1.0 + _torch.exp(-self.sigma * score_diffs)) 139 | loss[rel_diffs <= 0] = 0.0 140 | return rel_diffs * loss 141 | 142 | 143 | class LambdaNDCGLoss1(LambdaLoss): 144 | r""" 145 | NDCG Loss 1: 146 | 147 | .. math:: 148 | l(\mathbf{s}, \mathbf{y}) 149 | = -\sum_{i=1}^n \sum_{j=1}^n \log_2 150 | \left( 151 | \frac{1}{1 + e^{-\sigma (s_{\pi_i} - s_{\pi_j})}} 152 | \right)^{\frac{G_{\pi_i}}{D_i}} 153 | 154 | where :math:`\pi_i` is the index of the item at rank :math:`i` after 155 | sorting the scores and 156 | :math:`G_{\pi_i} = \frac{2^{y_{\pi_i}} - 1}{\text{maxDCG}}` and 157 | :math:`D_i = \log_2(1 + i)`. 
158 | 159 | Shape: 160 | - input scores: :math:`(N, \texttt{list_size})` 161 | - input relevance: :math:`(N, \texttt{list_size})` 162 | - input n: :math:`(N)` 163 | - output: :math:`(N)` 164 | """ 165 | def _loss_per_doc_pair(self, score_pairs, rel_pairs, n): 166 | score_diffs = score_pairs[:, :, :, 0] - score_pairs[:, :, :, 1] 167 | gains = _ndcg_gains(score_pairs, rel_pairs, n)[:, :, :, 0] 168 | arange = _torch.arange(score_pairs.shape[1], 169 | device=score_pairs.device) 170 | discounts = _torch.log2(2.0 + arange) 171 | exponent = gains / discounts[None, :, None] 172 | sigmoid = (1.0 / (1.0 + _torch.exp(-self.sigma * score_diffs))) 173 | return -(_torch.log2(sigmoid ** exponent)) 174 | 175 | 176 | class LambdaNDCGLoss2(LambdaLoss): 177 | r""" 178 | NDCG Loss 2: 179 | 180 | .. math:: 181 | l(\mathbf{s}, \mathbf{y}) = \sum_{y_i > y_j} \log_2 182 | \left( 183 | \frac{1}{1 + e^{-\sigma (s_{\pi_i} - s_{\pi_j})}} 184 | \right)^{\delta_{ij} | G_{\pi_i} - G_{\pi_j} |} 185 | 186 | where :math:`\pi_i` is the index of the item at rank :math:`i` after 187 | sorting the scores and 188 | :math:`G_{\pi_i} = \frac{2^{y_{\pi_i}} - 1}{\text{maxDCG}}` and 189 | :math:`\delta_{ij} = \left|\frac{1}{D_{|i-j|}} - \frac{1}{D_{|i-j|+1}} 190 | \right|` and :math:`D_i = \log_2(1 + i)`. 
191 | 192 | Shape: 193 | - input scores: :math:`(N, \texttt{list_size})` 194 | - input relevance: :math:`(N, \texttt{list_size})` 195 | - input n: :math:`(N)` 196 | - output: :math:`(N)` 197 | """ 198 | def _loss_per_doc_pair(self, score_pairs, rel_pairs, n): 199 | # Compute diffs for different parts of the loss function 200 | score_diffs = score_pairs[:, :, :, 0] - score_pairs[:, :, :, 1] 201 | rel_diffs = rel_pairs[:, :, :, 0] - rel_pairs[:, :, :, 1] 202 | gains = _ndcg_gains(score_pairs, rel_pairs, n) 203 | gain_diffs = gains[:, :, :, 0] - gains[:, :, :, 1] 204 | 205 | # Compute delta_{i, j} tensor 206 | arange = _torch.arange(score_pairs.shape[1] + 1, 207 | device=score_pairs.device) 208 | discounts = _torch.log2(2.0 + arange) 209 | idx1 = _torch.abs(arange[:-1, None] - arange[None, :-1]) 210 | idx2 = idx1 + 1 211 | delta = _torch.abs(1.0 / discounts[idx1] - 1.0 / discounts[idx2]) 212 | 213 | # Compute final loss 214 | exponent = delta[None, :, :] * _torch.abs(gain_diffs) 215 | sigmoid = (1.0 / (1.0 + _torch.exp(-self.sigma * score_diffs))) 216 | loss = _torch.log2(sigmoid ** exponent) 217 | loss[rel_diffs <= 0] = 0.0 218 | return -loss 219 | 220 | 221 | def _ndcg_gains(score_pairs: _torch.FloatTensor, rel_pairs: _torch.LongTensor, 222 | n: _torch.LongTensor, exp: bool = True) -> _torch.FloatTensor: 223 | gains = rel_pairs[:, :, :, :] 224 | if exp: 225 | gains = (2 ** gains) - 1.0 226 | max_dcg = _max_dcg(rel_pairs[:, :, 0, 0], n, exp) 227 | max_dcg[max_dcg == 0.0] = 1.0 228 | return gains / max_dcg[:, None, None, None] 229 | 230 | 231 | def _max_dcg(relevance: _torch.FloatTensor, n: _torch.LongTensor, 232 | exp: bool = True) -> _torch.FloatTensor: 233 | ranking = _rank_by_score(relevance.double(), n) 234 | arange = _torch.arange(ranking.shape[1], 235 | device=relevance.device) 236 | discounts = _torch.log2(2.0 + arange) 237 | gains = _torch.gather(relevance, 1, ranking) 238 | gains[n[:, None] <= arange[None, :]] = 0.0 239 | if exp: 240 | gains = (2 ** gains) - 
1.0 241 | return _torch.sum(gains / discounts[None, :], dim=1) 242 | -------------------------------------------------------------------------------- /pytorchltr/datasets/svmrank/svmrank.py: -------------------------------------------------------------------------------- 1 | """Data loading for SVMRank-style data sets.""" 2 | from typing import Callable 3 | from typing import List 4 | from typing import Optional 5 | from typing import Union 6 | 7 | import numpy as _np 8 | import torch as _torch 9 | import logging 10 | 11 | from scipy.sparse import coo_matrix as _coo_matrix 12 | from sklearn.datasets import load_svmlight_file as _load_svmlight_file 13 | from torch.utils.data import Dataset as _Dataset 14 | from pytorchltr.datasets.list_sampler import ListSampler 15 | from pytorchltr.datasets.svmrank.parser import parse_svmrank_file 16 | 17 | 18 | class SVMRankItem: 19 | """A single item from a 20 | :obj:`pytorchltr.datasets.svmrank.SVMRankDataset`.""" 21 | def __init__(self, features: _torch.FloatTensor, 22 | relevance: _torch.LongTensor, n: int, qid: int, sparse: bool): 23 | self.features = features 24 | self.relevance = relevance 25 | self.n = n 26 | self.qid = qid 27 | self.sparse = sparse 28 | 29 | 30 | class SVMRankBatch: 31 | """A batch of items from a 32 | :obj:`pytorchltr.datasets.svmrank.SVMRankDataset`.""" 33 | def __init__(self, features: _torch.FloatTensor, 34 | relevance: _torch.LongTensor, n: _torch.LongTensor, 35 | qid: _torch.LongTensor, sparse: bool): 36 | self.features = features 37 | self.relevance = relevance 38 | self.n = n 39 | self.qid = qid 40 | self.sparse = sparse 41 | 42 | 43 | _COLLATE_RETURN_TYPE = Callable[[List[SVMRankItem]], SVMRankBatch] 44 | 45 | 46 | class SVMRankDataset(_Dataset): 47 | def __init__(self, file: str, sparse: bool = False, 48 | normalize: bool = False, filter_queries: bool = False, 49 | zero_based: Union[str, int] = "auto"): 50 | """Creates an SVMRank-style dataset from a file. 
51 | 52 | Args: 53 | file: The path to load the dataset from. 54 | sparse: Whether to load the features as sparse features. 55 | normalize: Whether to perform query-level normalization (requires 56 | non-sparse features). 57 | filter_queries: Whether to filter queries that have no relevant 58 | documents associated with them. 59 | zero_based: The zero based index. 60 | """ 61 | logging.info("loading svmrank dataset from %s", file) 62 | 63 | # Load svmlight file 64 | if not sparse: 65 | # Use faster cython dense parser 66 | self._xs, self._ys, qids = parse_svmrank_file(file) 67 | else: 68 | # Use sklearn's sparse-support parser 69 | self._xs, self._ys, qids = _load_svmlight_file( 70 | file, query_id=True, zero_based=zero_based) 71 | 72 | # Compute query offsets and unique qids 73 | self._offsets = _np.hstack( 74 | [[0], _np.where(qids[1:] != qids[:-1])[0] + 1, [len(qids)]]) 75 | self._unique_qids = qids[self._offsets[:-1]] 76 | 77 | # Densify 78 | self._sparse = sparse 79 | # if not sparse: 80 | # self._xs = self._xs.A 81 | 82 | # Normalize xs 83 | if normalize: 84 | if sparse: 85 | raise NotImplementedError( 86 | "Normalization without dense features is not supported.") 87 | self._normalize() 88 | 89 | # Filter queries without any relevant documents 90 | if filter_queries: 91 | indices = [] 92 | for i, (start, end) in enumerate(zip(self._offsets[:-1], 93 | self._offsets[1:])): 94 | if _np.sum(self._ys[start:end]) > 0.0: 95 | indices.append(i) 96 | self._indices = _np.array(indices) 97 | else: 98 | self._indices = _np.arange(len(self._unique_qids)) 99 | 100 | # Compute qid map and dataset length 101 | self._qid_map = { 102 | self._unique_qids[self._indices[index]]: index 103 | for index in range(len(self._indices)) 104 | } 105 | self._n = len(self._indices) 106 | 107 | def _normalize(self): 108 | """Performs query-level feature normalization on the dataset.""" 109 | for start, end in zip(self._offsets[:-1], self._offsets[1:]): 110 | self._xs[start:end, :] -= 
_np.min(self._xs[start:end, :], axis=0) 111 | m = _np.max(self._xs[start:end, :], axis=0) 112 | m[m == 0.0] = 1.0 113 | self._xs[start:end, :] /= m 114 | 115 | def get_index(self, qid: int) -> int: 116 | """Returns the dataset item index for given qid (if it exists). 117 | 118 | Args: 119 | qid: The qid to look up. 120 | 121 | Returns: 122 | The corresponding the dataset index for given qid. 123 | """ 124 | return self._qid_map[qid] 125 | 126 | @staticmethod 127 | def collate_fn(list_sampler: Optional[ListSampler] = None) -> _COLLATE_RETURN_TYPE: # noqa: E501 128 | r"""Returns a collate_fn that can be used to collate batches. 129 | Args: 130 | list_sampler: Sampler to use for sampling lists of documents. 131 | """ 132 | if list_sampler is None: 133 | list_sampler = ListSampler() 134 | 135 | def _collate_fn(batch: List[SVMRankItem]) -> SVMRankBatch: 136 | # Check if batch is sparse or not 137 | sparse = batch[0].sparse 138 | 139 | # Compute list size 140 | list_size = max([list_sampler.max_list_size(b.relevance) 141 | for b in batch]) 142 | 143 | # Create output tensors from batch 144 | if sparse: 145 | out_features = [] 146 | else: 147 | out_features = _torch.zeros( 148 | (len(batch), list_size, batch[0].features.shape[1])) 149 | out_relevance = _torch.zeros( 150 | (len(batch), list_size), dtype=_torch.long) 151 | out_qid = _torch.zeros(len(batch), dtype=_torch.long) 152 | out_n = _torch.zeros(len(batch), dtype=_torch.long) 153 | 154 | # Collate the whole batch 155 | for batch_index, sample in enumerate(batch): 156 | 157 | # Generate random indices when we exceed the list_size. 
158 | xs = sample.features 159 | if xs.shape[0] > list_size: 160 | rng_indices = list_sampler(sample.relevance) 161 | 162 | # Collate features 163 | if sparse: 164 | xs_coalesce = xs.coalesce() 165 | ind = xs_coalesce.indices() 166 | val = xs_coalesce.values() 167 | if xs.shape[0] > list_size: 168 | mask = [ind[0, :] == i for i in rng_indices] 169 | for i in range(len(mask)): 170 | ind[0, mask[i]] = int(i) 171 | ind = ind[:, sum(mask)] 172 | val = val[sum(mask)] 173 | ind_l = _torch.ones((1, ind.shape[1]), 174 | dtype=ind.dtype) * batch_index 175 | ind = _torch.cat([ind_l, ind], dim=0) 176 | out_features.append((ind, val)) 177 | else: 178 | if xs.shape[0] > list_size: 179 | out_features[batch_index, :, :] = xs[rng_indices, :] 180 | else: 181 | out_features[batch_index, 0:xs.shape[0], :] = xs 182 | 183 | # Collate relevance 184 | if xs.shape[0] > list_size: 185 | rel = sample.relevance[rng_indices] 186 | rel_n = len(rng_indices) 187 | out_relevance[batch_index, 0:rel_n] = rel 188 | else: 189 | rel = sample.relevance 190 | rel_n = len(sample.relevance) 191 | out_relevance[batch_index, 0:rel_n] = rel 192 | 193 | # Collate qid and n 194 | out_qid[batch_index] = int(sample.qid) 195 | out_n[batch_index] = min(int(sample.n), list_size) 196 | 197 | if sparse: 198 | ind = _torch.cat([d[0] for d in out_features], dim=1) 199 | val = _torch.cat([d[1] for d in out_features], dim=0) 200 | size = (len(batch), list_size, batch[0].features.shape[1]) 201 | out_features = _torch.sparse.FloatTensor( 202 | ind, val, _torch.Size(size)) 203 | 204 | return SVMRankBatch(out_features, out_relevance, out_n, out_qid, 205 | sparse) 206 | 207 | return _collate_fn 208 | 209 | def __getitem__(self, index: int) -> SVMRankItem: 210 | r""" 211 | Returns the item at given index. 212 | 213 | Args: 214 | index (int): The index. 215 | 216 | Returns: 217 | A :obj:`pytorchltr.datasets.svmrank.SVMRankItem` that contains 218 | features, relevance, qid, n and sparse fields. 
219 | 220 | """ 221 | # Extract query features and relevance labels 222 | qid = self._unique_qids[self._indices[index]] 223 | start = self._offsets[self._indices[index]] 224 | end = self._offsets[self._indices[index] + 1] 225 | features = self._xs[start:end, :] 226 | y = _torch.LongTensor(self._ys[start:end]) 227 | n = end - start 228 | 229 | # Compute sparse or dense torch tensor 230 | if self._sparse: 231 | coo = _coo_matrix(features) 232 | ind = _torch.LongTensor(_np.vstack((coo.row, coo.col))) 233 | val = _torch.FloatTensor(coo.data) 234 | features = _torch.sparse.FloatTensor( 235 | ind, val, _torch.Size(coo.shape)) 236 | else: 237 | features = _torch.FloatTensor(features) 238 | 239 | # Return data sample 240 | return SVMRankItem(features, y, n, qid, self._sparse) 241 | 242 | def __len__(self) -> int: 243 | r""" 244 | Returns: 245 | int: 246 | The length of the dataset. 247 | """ 248 | return self._n 249 | -------------------------------------------------------------------------------- /tests/utils/test_downloader.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | import os 3 | import contextlib 4 | import pytest 5 | from unittest import mock 6 | from pytorchltr.utils.downloader import Downloader 7 | from pytorchltr.utils.downloader import LoggingDownloadProgress 8 | from pytorchltr.utils.downloader import TerminalDownloadProgress 9 | from pytorchltr.utils.file import ChecksumError 10 | 11 | 12 | @contextlib.contextmanager 13 | def mock_urlopen(read, info): 14 | urlopen_fn_to_mock = "pytorchltr.utils.downloader.urlopen" 15 | with mock.patch(urlopen_fn_to_mock) as urlopen_mock: 16 | urlopen_obj = mock.MagicMock() 17 | urlopen_obj.read.side_effect = read 18 | urlopen_obj.info.return_value = info 19 | urlopen_obj.__enter__.return_value = urlopen_obj 20 | urlopen_obj.__exit__.return_value = None 21 | urlopen_mock.return_value = urlopen_obj 22 | yield 23 | 24 | 25 | def test_basic_download(): 26 | with 
mock_urlopen(read=[b"mocked", b"content"], info={}): 27 | with tempfile.TemporaryDirectory() as tmpdir: 28 | downloader = Downloader("http://mocked", "file.dat") 29 | downloader.download(tmpdir) 30 | with open(os.path.join(tmpdir, "file.dat"), "rb") as f: 31 | assert f.read() == b"mockedcontent" 32 | 33 | 34 | def test_empty_download(): 35 | with mock_urlopen(read=[], info={}): 36 | with tempfile.TemporaryDirectory() as tmpdir: 37 | downloader = Downloader("http://mocked", "file.dat") 38 | downloader.download(tmpdir) 39 | with open(os.path.join(tmpdir, "file.dat"), "rb") as f: 40 | assert f.read() == b"" 41 | 42 | 43 | def test_download_sha256_succeeds(): 44 | sha256 = "d6185849aea1f15d847aba8b45cf54da8f9bc5895484d2f04830941f68148864" 45 | with mock_urlopen(read=[b"mocked", b"content"], info={}): 46 | with tempfile.TemporaryDirectory() as tmpdir: 47 | downloader = Downloader( 48 | "http://mocked", "file.dat", 49 | sha256_checksum=sha256) 50 | downloader.download(tmpdir) 51 | 52 | 53 | def test_download_sha256_fails(): 54 | sha256 = "f13b7a15b26814ac2e5c8ab37ef135623ee0aed1cac5cc583e1b99fa688e0e29" 55 | with mock_urlopen(read=[b"mocked", b"content"], info={}): 56 | with tempfile.TemporaryDirectory() as tmpdir: 57 | downloader = Downloader( 58 | "http://mocked", "file.dat", 59 | sha256_checksum=sha256) 60 | with pytest.raises(ChecksumError): 61 | downloader.download(tmpdir) 62 | 63 | 64 | def test_download_only_once(): 65 | with mock_urlopen(read=[b"mocked", b"content"], info={}): 66 | with tempfile.TemporaryDirectory() as tmpdir: 67 | progress_fn = mock.MagicMock() 68 | downloader = Downloader("http://mocked", "file.dat", 69 | progress_fn=progress_fn) 70 | 71 | # Should trigger download and 4 progress updates 72 | assert progress_fn.call_count == 0 73 | downloader.download(tmpdir) 74 | assert progress_fn.call_count == 4 75 | 76 | # Should NOT trigger download and no more progress updates 77 | downloader.download(tmpdir) 78 | assert progress_fn.call_count == 4 79 | 
80 | 81 | def test_download_postprocess_fn(): 82 | with mock_urlopen(read=[b"mocked", b"content"], info={}): 83 | with tempfile.TemporaryDirectory() as tmpdir: 84 | postprocess_fn = mock.MagicMock() 85 | downloader = Downloader("http://mocked", "file.dat", 86 | postprocess_fn=postprocess_fn) 87 | downloader.download(tmpdir) 88 | postprocess_fn.assert_has_calls([ 89 | mock.call(os.path.join(tmpdir, "file.dat"), tmpdir) 90 | ]) 91 | 92 | 93 | def test_download_progress_with_known_length(): 94 | with mock_urlopen(read=[b"mocked", b"content"], 95 | info={"Content-Length": "13"}): 96 | with tempfile.TemporaryDirectory() as tmpdir: 97 | progress_fn = mock.MagicMock() 98 | downloader = Downloader("http://mocked", "file.dat", 99 | progress_fn=progress_fn) 100 | downloader.download(tmpdir) 101 | progress_fn.assert_has_calls([ 102 | mock.call(0, 13, False), # Initial progress (0 bytes) 103 | mock.call(6, 13, False), # Read 6 bytes b"mocked" 104 | mock.call(13, 13, False), # Read 7 bytes b"content" 105 | mock.call(13, 13, True) # Final call should be True 106 | ]) 107 | 108 | 109 | def test_download_progress_with_unknown_length(): 110 | with mock_urlopen(read=[b"moc", b"ked", b"con", b"ten", b"t"], info={}): 111 | with tempfile.TemporaryDirectory() as tmpdir: 112 | progress_fn = mock.MagicMock() 113 | downloader = Downloader("http://mocked", "file.dat", 114 | progress_fn=progress_fn) 115 | downloader.download(tmpdir) 116 | progress_fn.assert_has_calls([ 117 | mock.call(0, None, False), # Initial progress (0 bytes) 118 | mock.call(3, None, False), # Read 3 bytes b"moc" 119 | mock.call(6, None, False), # Read 3 bytes b"ked" 120 | mock.call(9, None, False), # Read 3 bytes b"con" 121 | mock.call(12, None, False), # Read 3 bytes b"ten" 122 | mock.call(13, None, False), # Read 1 byte b"t" 123 | mock.call(13, None, True) # Final call should be True 124 | ]) 125 | 126 | 127 | def test_download_logging_progress_with_known_length(): 128 | logging_fn_to_mock =
"pytorchltr.utils.progress.logging" 129 | with mock.patch(logging_fn_to_mock) as logging_mock: 130 | with mock_urlopen(read=[b"mocked", b"content"], 131 | info={"Content-Length": "13"}): 132 | with tempfile.TemporaryDirectory() as tmpdir: 133 | downloader = Downloader( 134 | "http://mocked", "file.dat", 135 | progress_fn=LoggingDownloadProgress(interval=0.0)) 136 | downloader.download(tmpdir) 137 | logging_mock.info.assert_has_calls([ 138 | mock.call("downloading 0% [0B / 13B]"), 139 | mock.call("downloading 46% [6B / 13B]"), 140 | mock.call("downloading 100% [13B / 13B]"), 141 | mock.call("finished downloading [13B]"), 142 | ]) 143 | 144 | 145 | def test_download_logging_progress_with_unknown_length(): 146 | logging_fn_to_mock = "pytorchltr.utils.progress.logging" 147 | with mock.patch(logging_fn_to_mock) as logging_mock: 148 | with mock_urlopen(read=[b"mocked", b"content"], info={}): 149 | with tempfile.TemporaryDirectory() as tmpdir: 150 | downloader = Downloader( 151 | "http://mocked", "file.dat", 152 | progress_fn=LoggingDownloadProgress(interval=0.0)) 153 | downloader.download(tmpdir) 154 | logging_mock.info.assert_has_calls([ 155 | mock.call("downloading [0B / ?]"), 156 | mock.call("downloading [6B / ?]"), 157 | mock.call("downloading [13B / ?]"), 158 | mock.call("finished downloading [13B]"), 159 | ]) 160 | 161 | 162 | def test_download_logging_progress_with_kilobytes(): 163 | logging_fn_to_mock = "pytorchltr.utils.progress.logging" 164 | with mock.patch(logging_fn_to_mock) as logging_mock: 165 | with mock_urlopen(read=[b"mocked" * 1024, b"content" * 1024], 166 | info={"Content-Length": "13312"}): 167 | with tempfile.TemporaryDirectory() as tmpdir: 168 | downloader = Downloader( 169 | "http://mocked", "file.dat", 170 | progress_fn=LoggingDownloadProgress(interval=0.0)) 171 | downloader.download(tmpdir) 172 | logging_mock.info.assert_has_calls([ 173 | mock.call("downloading 0% [0B / 13KB]"), 174 | mock.call("downloading 46% [6.0KB / 13KB]"), 175 |
mock.call("downloading 100% [13KB / 13KB]"), 176 | mock.call("finished downloading [13KB]"), 177 | ]) 178 | 179 | 180 | def test_download_logging_progress_with_megabytes(): 181 | logging_fn_to_mock = "pytorchltr.utils.progress.logging" 182 | with mock.patch(logging_fn_to_mock) as logging_mock: 183 | with mock_urlopen(read=[b"moc" * 1024 * 120, 184 | b"moc" * 1024 * 748, 185 | b"ked" * 1024 * 768, 186 | b"content123" * 1024 * 300], 187 | info={"Content-Length": "8097792"}): 188 | with tempfile.TemporaryDirectory() as tmpdir: 189 | downloader = Downloader( 190 | "http://mocked", "file.dat", 191 | progress_fn=LoggingDownloadProgress(interval=0.0)) 192 | downloader.download(tmpdir) 193 | logging_mock.info.assert_has_calls([ 194 | mock.call("downloading 0% [0B / 7.7MB]"), 195 | mock.call("downloading 4% [360KB / 7.7MB]"), 196 | mock.call("downloading 32% [2.5MB / 7.7MB]"), 197 | mock.call("downloading 62% [4.8MB / 7.7MB]"), 198 | mock.call("downloading 100% [7.7MB / 7.7MB]"), 199 | mock.call("finished downloading [7.7MB]"), 200 | ]) 201 | 202 | 203 | def test_download_terminal_progress(): 204 | with mock.patch("builtins.print") as print_mock: 205 | with mock_urlopen(read=[b"mocked", b"content"], 206 | info={"Content-Length": "13"}): 207 | with tempfile.TemporaryDirectory() as tmpdir: 208 | downloader = Downloader( 209 | "http://mocked", "file.dat", 210 | progress_fn=TerminalDownloadProgress(interval=0.0)) 211 | downloader.download(tmpdir) 212 | print_mock.assert_has_calls([ 213 | mock.call("\033[Kdownloading 0% [0B / 13B]", end="\r"), 214 | mock.call("\033[Kdownloading 46% [6B / 13B]", end="\r"), 215 | mock.call("\033[Kdownloading 100% [13B / 13B]", end="\r"), 216 | mock.call("\033[Kfinished downloading [13B]", end="\n"), 217 | ]) 218 | -------------------------------------------------------------------------------- /tests/datasets/svmrank/test_svmrank.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import
pickle 3 | import tempfile 4 | from unittest import mock 5 | 6 | from pytest import raises 7 | from pytest import approx 8 | from pytorchltr.datasets.svmrank.svmrank import SVMRankDataset 9 | from pytorchltr.datasets.list_sampler import UniformSampler 10 | 11 | 12 | def get_sample_dataset(*args, **kwargs): 13 | """Get sample dataset, uses the same arguments as `SVMRankDataset`.""" 14 | dataset_file = "tests/datasets/resources/dataset.txt" 15 | return SVMRankDataset(dataset_file, *args, **kwargs) 16 | 17 | 18 | @contextlib.contextmanager 19 | def mock_svmrank_dataset(package_str="pytorchltr.datasets.svmrank.svmrank"): 20 | validate_and_download_str = package_str + ".validate_and_download" 21 | super_str = "pytorchltr.datasets.svmrank.svmrank.SVMRankDataset.__init__" 22 | with tempfile.TemporaryDirectory() as tmpdir: 23 | with mock.patch(validate_and_download_str) as mock_vali: 24 | with mock.patch(super_str) as mock_super: 25 | yield tmpdir, mock_super, mock_vali 26 | 27 | 28 | def test_basic(): 29 | 30 | # Load data set. 31 | dataset = get_sample_dataset() 32 | 33 | # Check data set size. 34 | assert len(dataset) == 4 35 | 36 | # Get first sample. 37 | sample = dataset[0] 38 | x, y, q = sample.features, sample.relevance, sample.qid 39 | assert x.shape == (6, 45) 40 | assert y.shape == (6,) 41 | assert x[1, 2] == 1.0 42 | assert y[1] == 2.0 43 | assert q == 1 44 | 45 | # Get second sample. 46 | sample = dataset[1] 47 | x, y, q = sample.features, sample.relevance, sample.qid 48 | assert x.shape == (9, 45) 49 | assert y.shape == (9,) 50 | assert float(x[5, 3]) == approx(0.422507) 51 | assert y[5] == 1.0 52 | assert q == 16 53 | 54 | # Get third sample. 55 | sample = dataset[2] 56 | x, y, q = sample.features, sample.relevance, sample.qid 57 | assert x.shape == (14, 45) 58 | assert y.shape == (14,) 59 | assert float(x[12, 2]) == approx(0.461538) 60 | assert y[12] == 0.0 61 | assert q == 60 62 | 63 | # Get fourth sample.
64 | sample = dataset[3] 65 | x, y, q = sample.features, sample.relevance, sample.qid 66 | assert x.shape == (10, 45) 67 | assert y.shape == (10,) 68 | assert float(x[8, 2]) == approx(0.25) 69 | assert y[8] == 0.0 70 | assert q == 63 71 | 72 | 73 | def test_sparse(): 74 | 75 | # Load data set. 76 | dataset_sparse = get_sample_dataset(sparse=True) 77 | dataset_dense = get_sample_dataset(sparse=False) 78 | 79 | # Check data set size. 80 | assert len(dataset_dense) == len(dataset_sparse) 81 | 82 | # Check sparse and dense return same samples. 83 | for i in range(len(dataset_dense)): 84 | sample_dense = dataset_dense[i] 85 | sample_sparse = dataset_sparse[i] 86 | assert sample_sparse.qid == sample_dense.qid 87 | assert sample_sparse.n == sample_dense.n 88 | assert sample_sparse.features.to_dense().numpy() == approx( 89 | sample_dense.features.numpy()) 90 | assert sample_sparse.relevance.numpy() == approx( 91 | sample_dense.relevance.numpy()) 92 | 93 | 94 | def test_normalize(): 95 | 96 | # Load data set. 97 | dataset = get_sample_dataset(normalize=True) 98 | 99 | # Check data set size. 100 | assert len(dataset) == 4 101 | 102 | # Get first sample and assert its contents are as expected.
103 | sample = dataset[0] 104 | x, y, q = sample.features, sample.relevance, sample.qid 105 | assert x.shape == (6, 45) 106 | assert y.shape == (6,) 107 | assert q == 1 108 | 109 | assert float(x[0, 1]) == approx(1.0) 110 | assert float(x[1, 1]) == approx(0.5) 111 | assert float(x[2, 1]) == approx(0.25) 112 | assert float(x[3, 1]) == approx(0.0) 113 | assert float(x[4, 1]) == approx(0.125) 114 | assert float(x[5, 1]) == approx(0.5) 115 | 116 | assert float(x[0, 0]) == approx(0.24242424242424246) 117 | assert float(x[1, 0]) == approx(0.12121212121212122) 118 | assert float(x[2, 0]) == approx(0.060606060606060615) 119 | assert float(x[3, 0]) == approx(0.0) 120 | assert float(x[4, 0]) == approx(1.0) 121 | assert float(x[5, 0]) == approx(0.12121212121212122) 122 | 123 | 124 | def test_sparse_normalize(): 125 | 126 | # This should raise an error as it is not implemented. 127 | with raises(NotImplementedError): 128 | get_sample_dataset(sparse=True, normalize=True) 129 | 130 | 131 | def test_serialize(): 132 | 133 | # Load data set. 134 | dataset = get_sample_dataset(normalize=True) 135 | 136 | # Attempt to serialize and deserialize it. 137 | serialized = pickle.dumps(dataset) 138 | deserialized = pickle.loads(serialized) 139 | 140 | # Assert original and deserialized versions are the same. 141 | assert len(dataset) == len(deserialized) 142 | for i in range(len(dataset)): 143 | sample1 = dataset[i] 144 | x1, y1, q1 = sample1.features, sample1.relevance, sample1.qid 145 | sample2 = deserialized[i] 146 | x2, y2, q2 = sample2.features, sample2.relevance, sample2.qid 147 | assert x1.numpy() == approx(x2.numpy()) 148 | assert y1.numpy() == approx(y2.numpy()) 149 | assert q1 == q2 150 | 151 | 152 | def test_serialize_sparse(): 153 | 154 | # Load data set. 155 | dataset = get_sample_dataset(sparse=True) 156 | 157 | # Attempt to serialize and deserialize it.
158 | serialized = pickle.dumps(dataset) 159 | deserialized = pickle.loads(serialized) 160 | 161 | # Assert original and deserialized versions are the same. 162 | assert len(dataset) == len(deserialized) 163 | for i in range(len(dataset)): 164 | sample1 = dataset[i] 165 | x1, y1, q1 = sample1.features, sample1.relevance, sample1.qid 166 | sample2 = deserialized[i] 167 | x2, y2, q2 = sample2.features, sample2.relevance, sample2.qid 168 | assert x1.to_dense().numpy() == approx(x2.to_dense().numpy()) 169 | assert y1.numpy() == approx(y2.numpy()) 170 | assert q1 == q2 171 | 172 | 173 | def test_double_serialize(): 174 | 175 | # Load data set. 176 | dataset = get_sample_dataset(normalize=True) 177 | 178 | # Attempt to serialize and deserialize it multiple times. 179 | s1 = pickle.dumps(dataset) 180 | d1 = pickle.loads(s1) 181 | s2 = pickle.dumps(d1) 182 | deserialized = pickle.loads(s2) 183 | 184 | # Assert original and deserialized versions are the same. 185 | assert len(dataset) == len(deserialized) 186 | for i in range(len(dataset)): 187 | sample1 = dataset[i] 188 | x1, y1, q1 = sample1.features, sample1.relevance, sample1.qid 189 | sample2 = deserialized[i] 190 | x2, y2, q2 = sample2.features, sample2.relevance, sample2.qid 191 | assert x1.numpy() == approx(x2.numpy()) 192 | assert y1.numpy() == approx(y2.numpy()) 193 | assert q1 == q2 194 | 195 | 196 | def test_collate_sparse_10(): 197 | 198 | # Load data set. 199 | dataset = get_sample_dataset(sparse=True) 200 | 201 | # Construct a batch of three samples and collate it with a maximum list 202 | # size of 10. 203 | batch = [dataset[0], dataset[1], dataset[2]] 204 | collate_fn = SVMRankDataset.collate_fn(UniformSampler(max_list_size=10)) 205 | 206 | # Assert resulting tensor shape is as expected. 207 | tensor_batch = collate_fn(batch) 208 | assert tensor_batch.features.shape == (3, 10, 45) 209 | 210 | 211 | def test_collate_dense_10(): 212 | 213 | # Load data set.
214 | dataset = get_sample_dataset(sparse=False) 215 | 216 | # Construct a batch of three samples and collate it with a maximum list 217 | # size of 10. 218 | batch = [dataset[0], dataset[1], dataset[2]] 219 | collate_fn = SVMRankDataset.collate_fn(UniformSampler(max_list_size=10)) 220 | 221 | # Assert resulting tensor shape is as expected. 222 | tensor_batch = collate_fn(batch) 223 | assert tensor_batch.features.shape == (3, 10, 45) 224 | 225 | 226 | def test_collate_sparse_3(): 227 | 228 | # Load data set. 229 | dataset = get_sample_dataset(sparse=True) 230 | 231 | # Construct a batch of three samples and collate it with a maximum list 232 | # size of 3. 233 | batch = [dataset[0], dataset[1], dataset[2]] 234 | collate_fn = SVMRankDataset.collate_fn(UniformSampler(max_list_size=3)) 235 | 236 | # Assert resulting tensor shape is as expected. 237 | tensor_batch = collate_fn(batch) 238 | assert tensor_batch.features.shape == (3, 3, 45) 239 | 240 | 241 | def test_collate_dense_3(): 242 | 243 | # Load data set. 244 | dataset = get_sample_dataset(sparse=False) 245 | 246 | # Construct a batch of three samples and collate it with a maximum list 247 | # size of 3. 248 | batch = [dataset[0], dataset[1], dataset[2]] 249 | collate_fn = SVMRankDataset.collate_fn(UniformSampler(max_list_size=3)) 250 | 251 | # Assert resulting tensor shape is as expected. 252 | tensor_batch = collate_fn(batch) 253 | assert tensor_batch.features.shape == (3, 3, 45) 254 | 255 | 256 | def test_collate_sparse_all(): 257 | 258 | # Load data set. 259 | dataset = get_sample_dataset(sparse=True) 260 | 261 | # Construct a batch of three samples and collate it with an unlimited 262 | # maximum list size. 263 | batch = [dataset[0], dataset[1], dataset[2]] 264 | collate_fn = SVMRankDataset.collate_fn(UniformSampler(max_list_size=None)) 265 | 266 | # Assert resulting tensor shape is as expected.
267 | tensor_batch = collate_fn(batch) 268 | assert tensor_batch.features.shape == (3, 14, 45) 269 | 270 | 271 | def test_collate_dense_all(): 272 | 273 | # Load data set. 274 | dataset = get_sample_dataset(sparse=False) 275 | 276 | # Construct a batch of three samples and collate it with an unlimited 277 | # maximum list size. 278 | batch = [dataset[0], dataset[1], dataset[2]] 279 | collate_fn = SVMRankDataset.collate_fn(UniformSampler(max_list_size=None)) 280 | 281 | # Assert resulting tensor shape is as expected. 282 | tensor_batch = collate_fn(batch) 283 | assert tensor_batch.features.shape == (3, 14, 45) 284 | 285 | 286 | def test_filter_queries(): 287 | # Load data set. 288 | dataset_filtered = get_sample_dataset(filter_queries=True) 289 | dataset = get_sample_dataset(filter_queries=False) 290 | assert len(dataset_filtered) != len(dataset) 291 | 292 | # Assert qids match for the queries kept by the filter. 293 | assert dataset_filtered[0].qid == dataset[0].qid 294 | assert dataset_filtered[1].qid == dataset[1].qid 295 | assert dataset_filtered[2].qid == dataset[3].qid 296 | 297 | 298 | def test_get_index(): 299 | # Load data set. 300 | dataset = get_sample_dataset() 301 | 302 | # Assert that get_index for each qid matches the index. 303 | for i in range(len(dataset)): 304 | assert dataset.get_index(dataset[i].qid) == i 305 | 306 | 307 | def test_get_index_filtered_queries(): 308 | # Load data set. 309 | dataset = get_sample_dataset(filter_queries=True) 310 | 311 | # Assert that get_index for each qid matches the index. 312 | for i in range(len(dataset)): 313 | assert dataset.get_index(dataset[i].qid) == i 314 | --------------------------------------------------------------------------------